
LinearParam

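| Attribute | Type | Default | Description |
|---|---|---|---|
| transpose_a | bool | False | Whether to transpose the first input (A). |
| transpose_b | bool | True | Whether to transpose the weight (B). |
| has_bias | bool | True | Whether a bias is added to the result. |
| out_data_type | torch_atb.AclDataType | torch_atb.AclDataType.ACL_DT_UNDEFINED | Output data type. With the default, the output tensor data type is inferred automatically from the input tensors. |
| en_accum | bool | False | Whether to enable accumulation. |
| matmul_type | torch_atb.LinearParam.MatmulType | torch_atb.LinearParam.MatmulType.MATMUL_UNDEFINED | Matmul type; see LinearParam.MatmulType. |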
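As a rough mental model of the layout flags in the table, the pure-Python sketch below computes op(A) · op(B), optionally plus a bias, on the CPU, where op() transposes its argument when the corresponding flag is set. It is not the torch_atb implementation: the names `transpose` and `linear_reference` are hypothetical, and the bias is assumed to have shape (n,) and to be added to every output row. Passing `bias=None` stands in for `has_bias=False`.

```python
from typing import List, Optional

Matrix = List[List[float]]


def transpose(x: Matrix) -> Matrix:
    """Return the transpose of a row-major matrix."""
    return [list(col) for col in zip(*x)]


def linear_reference(a: Matrix, b: Matrix, bias: Optional[List[float]] = None,
                     transpose_a: bool = False, transpose_b: bool = True) -> Matrix:
    """Hypothetical CPU reference for op(a) @ op(b) (+ bias)."""
    if transpose_a:
        a = transpose(a)  # a was stored as (k, m)
    if transpose_b:
        b = transpose(b)  # b was stored as (n, k), the default layout
    k = len(a[0])
    if len(b) != k:
        raise ValueError(f"inner dimensions differ: {k} vs {len(b)}")
    n = len(b[0])
    # Plain triple-loop matrix product: out[r][j] = sum_i a[r][i] * b[i][j].
    out = [[sum(row[i] * b[i][j] for i in range(k)) for j in range(n)] for row in a]
    if bias is not None:
        # Assumption: bias has shape (n,) and is broadcast over the m rows.
        out = [[v + bias[j] for j, v in enumerate(row)] for row in out]
    return out


if __name__ == "__main__":
    a = [[1, 2], [3, 4]]          # (m, k) = (2, 2)
    b = [[1, 0], [0, 1], [1, 1]]  # (n, k) = (3, 2), default transpose_b=True
    print(linear_reference(a, b))                      # [[1, 2, 3], [3, 4, 7]]
    print(linear_reference(a, b, bias=[10, 20, 30]))   # [[11, 22, 33], [13, 24, 37]]
```

With the default `transpose_b=True`, the weight is expected in (n, k) layout, which is why the calling example further down sets `transpose_b=False` when it creates the weight as (k, n).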

LinearParam.MatmulType

Enum members:

  • MATMUL_UNDEFINED
  • MATMUL_EIN_SUM

Example

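import torch
import torch_npu  # registers the NPU device so that .npu() is available
import torch_atb


def linear():
    # A is (m, k); with transpose_b=False the weight B is (k, n), so the output is (m, n).
    m, k, n = 3, 4, 5
    input_tensor = torch.randn(m, k, dtype=torch.float16).npu()
    weight_tensor = torch.randn(k, n, dtype=torch.float16).npu()
    print(input_tensor)
    print(weight_tensor)

    # No bias input, and the weight is supplied untransposed as (k, n).
    linear_param = torch_atb.LinearParam(has_bias=False, transpose_b=False)
    linear_op = torch_atb.Operation(linear_param)

    def linear_run():
        # forward() takes the input tensors as a list and returns a list of outputs.
        linear_outputs = linear_op.forward([input_tensor, weight_tensor])
        # Cast the float16 result to float32 for printing.
        return [linear_outputs[0].to(torch.float32)]

    outputs = linear_run()
    print("outputs: ", outputs)


if __name__ == "__main__":
    linear()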
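The example overrides `transpose_b=False` so the weight can be created as (k, n). If you keep the default `transpose_b=True`, the weight must be allocated as (n, k) instead. The hypothetical helper `linear_shapes` below is only a bookkeeping sketch, not part of torch_atb. It assumes plain 2-D inputs and returns the shapes to allocate for A, B and the output under each flag combination.

```python
from typing import Dict, Tuple

Shape = Tuple[int, int]


def linear_shapes(m: int, k: int, n: int, transpose_a: bool = False,
                  transpose_b: bool = True) -> Dict[str, Shape]:
    """Hypothetical helper: 2-D shapes for input A, weight B and the output."""
    return {
        "input": (k, m) if transpose_a else (m, k),
        "weight": (n, k) if transpose_b else (k, n),
        "output": (m, n),  # the result is (m, n) regardless of the flags
    }


if __name__ == "__main__":
    # Settings used in the calling example (transpose_b=False).
    print(linear_shapes(3, 4, 5, transpose_b=False))
    # {'input': (3, 4), 'weight': (4, 5), 'output': (3, 5)}
    # Default transpose_b=True: the weight would have to be (5, 4).
    print(linear_shapes(3, 4, 5))
    # {'input': (3, 4), 'weight': (5, 4), 'output': (3, 5)}
```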