def forward
函数功能
根据不同的op_type设置,执行线性变换。
函数原型
def forward(self, input: Tensor) -> Tensor:
参数说明
参数名 |
输入/输出 |
类型 |
说明 |
---|---|---|---|
input |
输入 |
torch.Tensor |
输入张量,形状通常为(batch_size, sequence_length, hidden_size)或类似的多维张量。最后一维的大小必须等于in_features。 |
返回值说明
线性变换后的结果。
父主题: class Linear