GatherParam
Attribute | Type | Default | Description |
---|---|---|---|
axis | int | 0 | - |
batch_dims | int | 0 | - |
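This page does not describe `axis` or `batch_dims`. The sketch below is a CPU reference that *assumes* `GatherParam` follows `tf.gather` / ONNX Gather semantics, which this page does not confirm, so check it against the ATB operator specification before relying on it. `reference_gather` is a hypothetical helper, not part of `torch_atb`. Under this assumption the output shape is `params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:]`.

```python
import numpy as np


def reference_gather(params, indices, axis=0, batch_dims=0):
    """Hypothetical CPU reference for a tf.gather-style gather (assumed semantics).

    Output shape: params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:]
    """
    params = np.asarray(params)
    indices = np.asarray(indices)
    # Normalize negative axis / batch_dims the same way tf.gather does
    if axis < 0:
        axis += params.ndim
    if batch_dims < 0:
        batch_dims += indices.ndim
    if not 0 <= batch_dims <= axis:
        raise ValueError("batch_dims must satisfy 0 <= batch_dims <= axis")
    if params.shape[:batch_dims] != indices.shape[:batch_dims]:
        raise ValueError("leading batch dimensions of params and indices must match")
    if batch_dims == 0:
        # np.take implements the batch_dims == 0 case directly
        return np.take(params, indices, axis=axis)
    # Peel off one batch dimension: pair params[i] with indices[i] and recurse
    return np.stack([
        reference_gather(p, i, axis - 1, batch_dims - 1)
        for p, i in zip(params, indices)
    ])


if __name__ == "__main__":
    params = np.arange(15, dtype=np.float16).reshape(3, 5)
    idx = np.array([[0, 1, 2, 3], [4, 3, 2, 1], [0, 0, 0, 0]])
    print(reference_gather(params, idx, axis=1).shape)                # (3, 3, 4)
    print(reference_gather(params, idx, axis=1, batch_dims=1).shape)  # (3, 4)
```

Under this assumption, the call example below (a `[3, 5]` input, `[3, 4]` indices, `axis=1`, default `batch_dims=0`) would produce a `[3, 3, 4]` output. With `batch_dims=1` on 2-D inputs, the result reduces to an element-wise gather equivalent to `np.take_along_axis` (and `torch.gather`).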
Usage example
```python
import torch
import torch_npu  # registers the "npu" device with PyTorch so that .npu() is available
import torch_atb


def gather():
    # Configure a Gather operation along dimension 1
    gather_param = torch_atb.GatherParam(axis=1)
    gather = torch_atb.Operation(gather_param)

    # Input data: a 3x5 float16 tensor, copied to the NPU
    intensor0 = torch.randn([3, 5], dtype=torch.float16)
    intensor0_npu = intensor0.npu()
    print("intensor0: ", intensor0)

    # Indices: int64 values in [0, 5), i.e. valid positions along dimension 1
    intensor1 = torch.randint(0, 5, [3, 4], dtype=torch.int64)
    intensor1_npu = intensor1.npu()
    print("intensor1: ", intensor1)

    def gather_run():
        # forward() takes the input tensors as a list and returns the outputs
        gather_outputs = gather.forward([intensor0_npu, intensor1_npu])
        return gather_outputs

    outputs = gather_run()
    print("outputs: ", outputs)


if __name__ == "__main__":
    gather()
```