LinearParam
| Attribute | Type | Default | Description |
|---|---|---|---|
| transpose_a | bool | False | - |
| transpose_b | bool | True | - |
| has_bias | bool | True | - |
| out_data_type | torch_atb.AclDataType | torch_atb.AclDataType.ACL_DT_UNDEFINED | Infers the output tensor data type automatically from the input tensors. |
| en_accum | bool | False | - |
| matmul_type | torch_atb.LinearParam.MatmulType | torch_atb.LinearParam.MatmulType.MATMUL_UNDEFINED | - |
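The following pure-Python reference sketch makes the transpose and bias flags concrete. It assumes the conventional linear-layer semantics, output = op(A) @ op(B) + bias. Here op transposes its argument when the matching transpose_a / transpose_b flag is set, and bias is a length-n vector added to every output row when has_bias is True. `linear_reference` and `transpose` are hypothetical helpers for illustration only and are not part of torch_atb. The sketch ignores out_data_type, en_accum and matmul_type.

```python
from typing import List, Optional

Matrix = List[List[float]]


def transpose(mat: Matrix) -> Matrix:
    # Swap rows and columns of a 2-D list.
    return [list(row) for row in zip(*mat)]


def linear_reference(
    x: Matrix,
    weight: Matrix,
    bias: Optional[List[float]] = None,
    transpose_a: bool = False,
    transpose_b: bool = True,
) -> Matrix:
    """Hypothetical CPU reference: op(x) @ op(weight), plus bias if given.

    Passing a bias here plays the role of has_bias=True (an assumption
    about the semantics, not the torch_atb implementation).
    """
    a = transpose(x) if transpose_a else x
    b = transpose(weight) if transpose_b else weight
    m, k = len(a), len(a[0])
    if len(b) != k:
        raise ValueError(f"inner dimensions differ: {k} vs {len(b)}")
    n = len(b[0])
    out = [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
           for i in range(m)]
    if bias is not None:
        # Broadcast a length-n bias across every output row.
        out = [[out[i][j] + bias[j] for j in range(n)] for i in range(m)]
    return out


if __name__ == "__main__":
    x = [[1, 2], [3, 4]]              # A: (m=2, k=2)
    w = [[1, 0], [0, 1], [1, 1]]      # B stored as (n=3, k=2), default transpose_b=True
    print(linear_reference(x, w))     # [[1, 2, 3], [3, 4, 7]]
```

With the default transpose_b = True, the weight is stored as (n, k). This matches the (out_features, in_features) layout commonly used for linear-layer weights.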
LinearParam.MatmulType
Enum members:
- MATMUL_UNDEFINED
- MATMUL_EIN_SUM

Usage example
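```python
import torch
import torch_npu  # Ascend PyTorch adapter; provides the .npu() tensor method
import torch_atb


def linear():
    m, k, n = 3, 4, 5
    # Input A has shape (m, k).
    input_tensor = torch.randn(m, k, dtype=torch.float16).npu()
    # Weight B is created as (k, n), so transpose_b must be False below.
    weight_tensor = torch.randn(k, n, dtype=torch.float16).npu()
    print(input_tensor)
    print(weight_tensor)

    # No bias input; the weight is used as-is (not transposed).
    linear_param = torch_atb.LinearParam(has_bias=False, transpose_b=False)
    linear_op = torch_atb.Operation(linear_param)

    def linear_run():
        # forward() takes a list of input tensors and returns a list of outputs.
        linear_outputs = linear_op.forward([input_tensor, weight_tensor])
        # Cast to float32 for readable printing; the result has shape (m, n).
        return [linear_outputs[0].to(torch.float32)]

    outputs = linear_run()
    print("outputs: ", outputs)


if __name__ == "__main__":
    linear()
```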
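The example passes transpose_b=False because weight_tensor is created with shape (k, n). Under the default transpose_b=True, the weight would presumably be expected in (n, k) layout instead. The hypothetical helper below is not part of torch_atb. It encodes this shape rule for 2-D inputs under the conventional matmul semantics and can be used to sanity-check tensor layouts before building a LinearParam.

```python
from typing import Tuple


def linear_output_shape(
    input_shape: Tuple[int, int],
    weight_shape: Tuple[int, int],
    transpose_a: bool = False,
    transpose_b: bool = True,
) -> Tuple[int, int]:
    """Hypothetical helper: infer the (m, n) output shape of a 2-D linear op."""
    # With transpose_a=True the input is assumed to be stored as (k, m).
    m, k = input_shape[::-1] if transpose_a else input_shape
    # With transpose_b=True the weight is assumed to be stored as (n, k); otherwise (k, n).
    n, k_w = weight_shape if transpose_b else weight_shape[::-1]
    if k != k_w:
        raise ValueError(f"reduction dimension mismatch: {k} != {k_w}")
    return m, n


if __name__ == "__main__":
    # The usage example's configuration: A (3, 4), B (4, 5), transpose_b=False.
    print(linear_output_shape((3, 4), (4, 5), transpose_b=False))  # (3, 5)
```

With the defaults, the same (4, 5) weight raises a ValueError here because k = 4 does not match 5. This is why the example overrides transpose_b.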