
ElewiseParam

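| Attribute | Type | Default | Description |
|---|---|---|---|
| elewise_type | torch_atb.ElewiseParam.ElewiseType | torch_atb.ElewiseParam.ElewiseType.ELEWISE_UNDEFINED | The default type is not usable; you must set this parameter. |
| quant_param | torch_atb.ElewiseParam.QuantParam | - | - |
| muls_param | torch_atb.ElewiseParam.MulsParam | - | - |
| out_tensor_type | torch_atb.AclDataType | torch_atb.AclDataType.ACL_DT_UNDEFINED | With the default value, the output tensor data type is inferred automatically from the input tensors. |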
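The table states that the default `elewise_type` is unusable and that `ACL_DT_UNDEFINED` triggers output-type inference. The sketch below makes those defaults concrete with a hypothetical pure-Python mirror of the parameter structure. `ElewiseConfig`, `QuantConfig`, `MulsConfig` and the trimmed `ElewiseType` enum are illustrative names, not part of torch_atb. The `validate()` check only models the rule stated in the table.

```python
from dataclasses import dataclass, field
from enum import Enum, auto


class ElewiseType(Enum):
    # Hypothetical stand-in for torch_atb.ElewiseParam.ElewiseType (subset only);
    # the auto() values are arbitrary and do not match the real enum.
    ELEWISE_UNDEFINED = auto()
    ELEWISE_ADD = auto()
    ELEWISE_MULS = auto()


@dataclass
class QuantConfig:
    # Defaults copied from the ElewiseParam.QuantParam table.
    input_scale: float = 1.0
    asymmetric: bool = False
    input_offset: int = 0


@dataclass
class MulsConfig:
    # Default copied from the ElewiseParam.MulsParam table.
    var_attr: float = 0.0


@dataclass
class ElewiseConfig:
    # Hypothetical mirror of ElewiseParam with the documented defaults.
    elewise_type: ElewiseType = ElewiseType.ELEWISE_UNDEFINED
    quant_param: QuantConfig = field(default_factory=QuantConfig)
    muls_param: MulsConfig = field(default_factory=MulsConfig)
    # ACL_DT_UNDEFINED means "infer the output dtype from the inputs".
    out_tensor_type: str = "ACL_DT_UNDEFINED"

    def validate(self) -> None:
        # The default elewise_type is not usable, so reject it up front.
        if self.elewise_type is ElewiseType.ELEWISE_UNDEFINED:
            raise ValueError("elewise_type must be set explicitly")
```

For example, `ElewiseConfig(elewise_type=ElewiseType.ELEWISE_ADD).validate()` passes, while `ElewiseConfig().validate()` raises `ValueError`.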

ElewiseParam.ElewiseType

Enumeration values:

  • ELEWISE_UNDEFINED
  • ELEWISE_CAST
  • ELEWISE_MULS
  • ELEWISE_COS
  • ELEWISE_SIN
  • ELEWISE_SUB
  • ELEWISE_EQUAL
  • ELEWISE_NEG
  • ELEWISE_QUANT
  • ELEWISE_LOGICAL_NOT
  • ELEWISE_ADD
  • ELEWISE_MUL
  • ELEWISE_QUANT_PER_CHANNEL
  • ELEWISE_DEQUANT_PER_CHANNEL
  • ELEWISE_REALDIV
  • ELEWISE_LOGICAL_AND
  • ELEWISE_LOGICAL_OR
  • ELEWISE_LESS
  • ELEWISE_GREATER
  • ELEWISE_DYNAMIC_QUANT
  • ELEWISE_TANH
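This page lists the enum members but does not define what each one computes. The sketch below is a CPU reference for a subset of them, inferred from the member names only (an assumption, not confirmed here). It can be handy for sanity-checking small NPU results. `reference_elewise` is a hypothetical helper that works on equal-length Python lists and does not broadcast. Casting and the quantization variants are deliberately left out.

```python
import math
import operator

# Assumed element-wise semantics, inferred from the member names only.
UNARY = {
    "ELEWISE_NEG": operator.neg,
    "ELEWISE_COS": math.cos,
    "ELEWISE_SIN": math.sin,
    "ELEWISE_TANH": math.tanh,
    "ELEWISE_LOGICAL_NOT": operator.not_,
}

BINARY = {
    "ELEWISE_ADD": operator.add,
    "ELEWISE_SUB": operator.sub,
    "ELEWISE_MUL": operator.mul,
    "ELEWISE_REALDIV": operator.truediv,
    "ELEWISE_EQUAL": operator.eq,
    "ELEWISE_LESS": operator.lt,
    "ELEWISE_GREATER": operator.gt,
    "ELEWISE_LOGICAL_AND": lambda a, b: bool(a) and bool(b),
    "ELEWISE_LOGICAL_OR": lambda a, b: bool(a) or bool(b),
}


def reference_elewise(op_name, *inputs):
    """Apply the assumed op element by element to equal-length lists."""
    if op_name in UNARY:
        (x,) = inputs
        return [UNARY[op_name](a) for a in x]
    if op_name in BINARY:
        x, y = inputs
        if len(x) != len(y):
            raise ValueError("this sketch does not broadcast")
        return [BINARY[op_name](a, b) for a, b in zip(x, y)]
    # CAST, MULS and the quantization ops need extra parameters.
    raise NotImplementedError(op_name)
```

For instance, `reference_elewise("ELEWISE_ADD", [1, 2], [3, 4])` returns `[4, 6]`.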

ElewiseParam.QuantParam

| Attribute | Type | Default | Description |
|---|---|---|---|
| input_scale | float | 1.0 | - |
| asymmetric | bool | False | - |
| input_offset | int | 0 | - |
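The QuantParam table gives names and defaults but no formula. The sketch below shows one common way such parameters drive per-tensor int8 quantization. Several parts are assumptions rather than documented ATB behavior:

- dividing by `input_scale`;
- applying `input_offset` only when `asymmetric` is True, following the usual symmetric and asymmetric convention;
- clamping to [-128, 127];
- rounding with Python's `round()`.

`quantize_int8` is a hypothetical helper. Consult the ATB operator specification for the exact kernel formula.

```python
def quantize_int8(values, input_scale=1.0, asymmetric=False, input_offset=0):
    """Hypothetical per-tensor affine quantization to int8 (formula assumed)."""
    if input_scale == 0:
        raise ValueError("input_scale must be non-zero")
    # Symmetric quantization conventionally uses a zero point of 0.
    zero_point = input_offset if asymmetric else 0
    out = []
    for x in values:
        # round() rounds halves to even; the NPU kernel may round differently.
        q = round(x / input_scale) + zero_point
        # Saturate to the int8 range.
        out.append(max(-128, min(127, q)))
    return out
```

With the documented defaults (`input_scale=1.0`, `asymmetric=False`, `input_offset=0`), this sketch reduces to rounding and saturating each value.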

ElewiseParam.MulsParam

| Attribute | Type | Default | Description |
|---|---|---|---|
| var_attr | float | 0.0 | - |
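Assuming `ELEWISE_MULS` multiplies every input element by the scalar `var_attr` (this page does not spell out the formula), the CPU reference below can help check NPU outputs. It also highlights a consequence of the documented default: if `var_attr` is left at 0.0, every output element is zero. `reference_muls` is a hypothetical helper, not a torch_atb function.

```python
def reference_muls(values, var_attr=0.0):
    """Assumed semantics of ELEWISE_MULS: multiply every element by var_attr."""
    # The default of 0.0 mirrors the MulsParam table and zeroes the output.
    return [x * var_attr for x in values]
```

For example, `reference_muls([1.0, -2.0], 2.5)` gives `[2.5, -5.0]`, while `reference_muls([3.0])` gives `[0.0]`.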

Example

import torch
import torch_npu  # registers the NPU device so that .npu() is available
import torch_atb


def elewise_add():
    # Two random float16 inputs of the same shape, moved to the NPU.
    input1_npu = torch.randn(2, 3, dtype=torch.float16).npu()
    input2_npu = torch.randn(2, 3, dtype=torch.float16).npu()
    print("input1: ", input1_npu)
    print("input2: ", input2_npu)

    # Configure an element-wise addition; elewise_type must be set explicitly.
    elewise_param = torch_atb.ElewiseParam(elewise_type=torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD)
    elewise = torch_atb.Operation(elewise_param)

    # forward() takes the input tensors as a list and returns the output tensors.
    outputs = elewise.forward([input1_npu, input2_npu])
    print("outputs: ", outputs)


if __name__ == "__main__":
    elewise_add()