RopeParam
| 属性 | 类型 | 默认值 | 描述 |
|---|---|---|---|
| rotary_coeff | int | 4 | - |
| cos_format | int | 0 | - |
调用示例
import torch
import torch_atb


def rope():
    """Run one torch_atb Rope operation on randomly generated fp16 inputs.

    Builds the operation from ``RopeParam(rotary_coeff=4)``; ``cos_format`` is
    left at its default. It then generates five input tensors and moves them to
    the NPU with ``Tensor.npu()``. Finally it executes the operation and prints
    both the inputs and whatever ``forward`` returns.
    """
    rope_param = torch_atb.RopeParam(rotary_coeff=4)
    # Named rope_op (not rope) so it does not shadow this enclosing function.
    rope_op = torch_atb.Operation(rope_param)

    def gen_inputs():
        """Return the five host-side input tensors, in the order passed to forward."""
        batch = 1
        seqlen = 4
        # NOTE(review): the token count is assumed to equal the sum of the
        # per-sequence lengths in intensor4; deriving it keeps the two in sync.
        ntoken = batch * seqlen
        hidden_size = 4096
        head_size = 128  # hidden_size / head_size = 32 heads
        # NOTE(review): presumed roles are query, key, cos, sin and the int32
        # sequence-length tensor; confirm against the ATB Rope input spec.
        intensor0 = torch.rand(ntoken, hidden_size).half()
        intensor1 = torch.rand(ntoken, hidden_size).half()
        intensor2 = torch.rand(ntoken, head_size).half()
        intensor3 = torch.rand(ntoken, head_size).half()
        intensor4 = torch.tensor([seqlen] * batch, dtype=torch.int32)
        return [intensor0, intensor1, intensor2, intensor3, intensor4]

    in_tensors = gen_inputs()
    in_tensors_npu = [tensor.npu() for tensor in in_tensors]
    print("in_tensors: ", in_tensors_npu)

    outputs = rope_op.forward(in_tensors_npu)
    print("outputs: ", outputs)


if __name__ == "__main__":
    rope()
父主题: OpParam