torch_npu.npu_rotary_mul(Tensor x, Tensor r1, Tensor r2): -> Tensor
x1, x2 = torch.chunk(x, 2, -1) x_new = torch.cat((-x2, x1), dim=-1) output = r1 * x + r2 * x_new
输出为Tensor,shape和dtype同输入Tensor x。
shape要求输入为4维,其中B维度和N维度数值需小于等于1000,D维度数值为128。
广播场景下,广播轴的总数据量不能超过1024。
import torch import torch_npu x = torch.rand(2, 2, 5, 128).npu() r1 = torch.rand(1, 2, 1, 128).npu() r2 = torch.rand(1, 2, 1, 128).npu() out = torch_npu.npu_rotary_mul(x, r1, r2)