昇腾社区首页
中文
注册

RopeParam

| 属性 | 类型 | 默认值 | 描述 |
| --- | --- | --- | --- |
| rotary_coeff | int | 4 | - |
| cos_format | int | 0 | - |

调用示例

import torch
import torch_atb

def rope():
    """Run a single Rope operation on random half-precision inputs.

    Builds a ``torch_atb`` operation from ``RopeParam(rotary_coeff=4)``,
    creates five host tensors, moves them to the NPU, runs ``forward`` once
    and prints both the inputs and the outputs. Returns nothing.

    NOTE(review): the five inputs presumably correspond to query, key, cos,
    sin and sequence length -- confirm against the ATB Rope documentation.
    """
    param = torch_atb.RopeParam(rotary_coeff=4)
    rope_op = torch_atb.Operation(param)

    ntoken, seqlen = 4, 4
    hidden_size, head_size = 4096, 128

    # Two (ntoken, hidden_size) tensors followed by two (ntoken, head_size)
    # tensors, drawn in this order so the random stream is consumed the same
    # way on every run with a fixed seed.
    widths = (hidden_size, hidden_size, head_size, head_size)
    host_tensors = [torch.rand(ntoken, width).half() for width in widths]
    # Final input: a one-element int32 tensor holding the sequence length.
    host_tensors.append(torch.tensor([seqlen], dtype=torch.int32))

    device_tensors = [t.npu() for t in host_tensors]
    print("in_tensors: ", device_tensors)

    outputs = rope_op.forward(device_tensors)
    print("outputs: ", outputs)

# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    rope()