SoftmaxParam
属性 |
类型 |
默认值 |
描述 |
---|---|---|---|
axes |
List[int] |
list() |
- |
约束说明
axes的长度大小需小于DEFAULT_SVECTOR_SIZE,否则会抛出长度溢出异常。
调用示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | import torch import torch_atb def softmax(): input_npu = torch.randn(2, 16, 256, dtype=torch.float32).npu() print("input: ", input_npu) softmax_param = torch_atb.SoftmaxParam(axes = [0]) softmax = torch_atb.Operation(softmax_param) def softmax_run(): softmax_outputs = softmax.forward([input_npu]) return softmax_outputs outputs = softmax_run() print("outputs: ", outputs) if __name__ == "__main__": softmax() |
父主题: OpParam