昇腾社区首页
中文
注册

def forward

函数功能

根据不同的op_type设置,执行线性变换。

函数原型

def forward(self, input: Tensor) -> Tensor:

参数说明

参数名

输入/输出

类型

说明

input

输入

torch.Tensor

输入张量,形状通常为(batch_size, sequence_length, hidden_size)或类似的多维张量。最后一维的大小必须等于in_features。

返回值说明

线性变换后的结果。