SoftmaxParam
属性  | 
类型  | 
默认值  | 
描述  | 
|---|---|---|---|
axes  | 
List[int]  | 
list()  | 
-  | 
约束说明
axes的长度大小需小于DEFAULT_SVECTOR_SIZE,否则会抛出长度溢出异常。
调用示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18  | import torch import torch_atb def softmax(): input_npu = torch.randn(2, 16, 256, dtype=torch.float32).npu() print("input: ", input_npu) softmax_param = torch_atb.SoftmaxParam(axes = [0]) softmax = torch_atb.Operation(softmax_param) def softmax_run(): softmax_outputs = softmax.forward([input_npu]) return softmax_outputs outputs = softmax_run() print("outputs: ", outputs) if __name__ == "__main__": softmax()  | 
父主题: OpParam