def forward
函数功能
根据参数if_fused选择不同的实现方式,对输入张量hidden_states进行RMS归一化操作。
函数原型
def forward(self, hidden_states, if_fused=True):
参数说明
参数名 |
输入/输出 |
类型 |
说明 |
---|---|---|---|
hidden_states |
输入 |
torch.Tensor |
输入张量,形状通常为(batch_size, sequence_length, hidden_size)或类似的多维张量。最后一维的大小必须等于hidden_size。 |
if_fused |
输入 |
bool |
是否使用融合算子,默认为True。 |
返回值说明
归一化后的张量,形状与输入张量hidden_states相同。
父主题: class RMSNorm