昇腾社区首页
中文
注册

SoftmaxParam

属性

类型

默认值

描述

axes

List[int]

list()

-

约束说明

axes的长度大小需小于DEFAULT_SVECTOR_SIZE,否则会抛出长度溢出异常。

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
import torch
import torch_atb  

def softmax():
    input_npu = torch.randn(2, 16, 256, dtype=torch.float32).npu()
    print("input: ", input_npu)
    softmax_param = torch_atb.SoftmaxParam(axes = [0])
    softmax = torch_atb.Operation(softmax_param)

    def softmax_run():
        softmax_outputs = softmax.forward([input_npu])
        return softmax_outputs

    outputs = softmax_run()
    print("outputs: ", outputs)

if __name__ == "__main__":
    softmax()