LinearParam
Attribute | Type | Default | Description |
---|---|---|---|
transpose_a | bool | False | - |
transpose_b | bool | True | - |
has_bias | bool | True | - |
out_data_type | torch_atb.AclDataType | torch_atb.AclDataType.ACL_DT_UNDEFINED | Data type of the output tensors. With the default value, it is inferred automatically from the input tensors. |
en_accum | bool | False | - |
matmul_type | torch_atb.LinearParam.MatmulType | torch_atb.LinearParam.MatmulType.MATMUL_UNDEFINED | - |
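The table gives no descriptions for the transpose and bias flags. The sketch below is a plain-Python reference that shows one plausible reading of them: the output is op(A) @ op(B), plus an optional bias. It is a sketch under stated assumptions, not the torch_atb implementation. `linear_reference` and `transpose` are hypothetical names. Treating the bias as a length-n vector broadcast over rows is also an assumption. `out_data_type`, `en_accum` and `matmul_type` are not modeled.

```python
from typing import List, Optional

Matrix = List[List[float]]


def transpose(x: Matrix) -> Matrix:
    """Return the transpose of a row-major 2-D list."""
    return [list(row) for row in zip(*x)]


def linear_reference(
    a: Matrix,
    b: Matrix,
    bias: Optional[List[float]] = None,
    transpose_a: bool = False,
    transpose_b: bool = True,  # same default as LinearParam.transpose_b
    has_bias: bool = True,  # same default as LinearParam.has_bias
) -> Matrix:
    """Hypothetical CPU reference: op(a) @ op(b), plus bias if has_bias (assumed row broadcast)."""
    if transpose_a:
        a = transpose(a)  # a was given as (k, m)
    if transpose_b:
        b = transpose(b)  # b was given as (n, k)
    k = len(a[0])
    if len(b) != k:
        raise ValueError(f"inner dimensions differ: {k} vs {len(b)}")
    n = len(b[0])
    # Plain triple-loop matrix product: out[i][j] = sum_p a[i][p] * b[p][j]
    out = [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)] for row in a]
    if has_bias:
        if bias is None or len(bias) != n:
            raise ValueError("has_bias=True needs a bias of length n")
        out = [[v + bias[j] for j, v in enumerate(row)] for row in out]
    return out


if __name__ == "__main__":
    a = [[1, 2], [3, 4]]              # (m=2, k=2)
    w = [[1, 0], [0, 1], [1, 1]]      # (n=3, k=2), matching the default transpose_b=True
    print(linear_reference(a, w, bias=[10, 20, 30]))  # [[11, 22, 33], [13, 24, 37]]
```

Keeping the defaults identical to the table means a call with no keyword arguments mirrors a default-constructed `LinearParam`. That holds only under the assumptions above.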
LinearParam.MatmulType
Enum values:
- MATMUL_UNDEFINED
- MATMUL_EIN_SUM
Example
```python
import torch
import torch_npu  # registers the NPU device so that .npu() is available
import torch_atb


def linear():
    m, k, n = 3, 4, 5
    # Input has shape (m, k); the weight has shape (k, n) because transpose_b=False below
    input_tensor = torch.randn(m, k, dtype=torch.float16).npu()
    weight_tensor = torch.randn(k, n, dtype=torch.float16).npu()
    print(input_tensor)
    print(weight_tensor)

    # No bias tensor is passed, so has_bias is set to False (its default is True)
    linear_param = torch_atb.LinearParam(has_bias=False, transpose_b=False)
    linear_op = torch_atb.Operation(linear_param)

    def linear_run():
        # forward() takes the input tensors as a list and returns a list of output tensors
        linear_outputs = linear_op.forward([input_tensor, weight_tensor])
        return [linear_outputs[0].to(torch.float32)]

    outputs = linear_run()
    print("outputs: ", outputs)


if __name__ == "__main__":
    linear()
```
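The example passes `transpose_b=False` because its weight is laid out as (k, n). The default `transpose_b=True` implies an (n, k) weight. The hypothetical helper below makes that shape rule explicit for 2-D inputs, assuming the transpose flags simply swap the two dimensions of the corresponding input. It can be used to sanity-check tensor shapes before building the operation.

```python
from typing import Tuple

Shape = Tuple[int, int]


def expected_output_shape(
    a_shape: Shape, b_shape: Shape, transpose_a: bool = False, transpose_b: bool = True
) -> Shape:
    """Hypothetical helper: (m, n) output shape for 2-D inputs under the transpose flags."""
    # With transpose_a, the first input is assumed to be stored as (k, m)
    m, k_a = (a_shape[1], a_shape[0]) if transpose_a else a_shape
    # With transpose_b (the default), the weight is assumed to be stored as (n, k)
    k_b, n = (b_shape[1], b_shape[0]) if transpose_b else b_shape
    if k_a != k_b:
        raise ValueError(f"inner dimensions differ: {k_a} vs {k_b}")
    return (m, n)


if __name__ == "__main__":
    # Shapes from the example above: input (3, 4), weight (4, 5), transpose_b=False
    print(expected_output_shape((3, 4), (4, 5), transpose_b=False))  # (3, 5)
    # The same output with the default transpose_b=True needs a (5, 4) weight
    print(expected_output_shape((3, 4), (5, 4)))  # (3, 5)
```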