昇腾社区首页
中文
注册

GatherParam

| 属性 | 类型 | 默认值 | 描述 |
| --- | --- | --- | --- |
| axis | int | 0 | - |
| batch_dims | int | 0 | - |

调用示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import torch_atb  

def gather():
    """Run a torch_atb gather operation on random inputs and print the results.

    Builds an ``Operation`` from ``GatherParam(axis=1)``, creates a random
    float16 tensor of shape [3, 5] and a random int64 tensor of shape [3, 4]
    with values drawn from [0, 5), moves both with ``.npu()``, and prints the
    host-side inputs followed by whatever ``forward`` returns.

    Takes no arguments and returns nothing; all output goes to stdout.
    """
    op = torch_atb.Operation(torch_atb.GatherParam(axis=1))

    values = torch.randn([3, 5], dtype=torch.float16)
    values_npu = values.npu()
    print("intensor0: ", values)

    # Presumably the gather indices into `values` along axis 1 (the upper
    # bound 5 matches that dimension) -- confirm against the operator docs.
    index = torch.randint(0, 5, [3, 4], dtype=torch.int64)
    index_npu = index.npu()
    print("intensor1: ", index)

    # forward receives the device tensors as one list, in input order.
    result = op.forward([values_npu, index_npu])
    print("outputs: ", result)

# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    gather()