昇腾社区首页
中文
注册

def forward

函数功能

根据参数if_fused选择不同的实现方式,对输入张量hidden_states进行RMS归一化操作。

函数原型

def forward(self, hidden_states, if_fused=True):

参数说明

参数名

输入/输出

类型

说明

hidden_states

输入

torch.Tensor

输入张量,形状通常为(batch_size, sequence_length, hidden_size)或类似的多维张量。最后一维的大小必须等于hidden_size。

if_fused

输入

bool

是否使用融合算子,默认为True。

返回值说明

归一化后的张量,形状与输入张量hidden_states相同。