torch_npu.npu_rotary_mul

接口原型

torch_npu.npu_rotary_mul(Tensor input, Tensor r1, Tensor r2, str rotary_mode='half') -> Tensor

在模型训练场景中,正向算子的输入Tensor input将被保留以供反向计算时使用。在r1,r2不需要计算反向梯度场景下(requires_grad=False),使用该接口相较融合前小算子使用的设备内存占用会有所增加。

功能描述

实现RotaryEmbedding旋转位置编码。支持FakeTensor模式。

参数说明

输出说明

输出为Tensor,shape和dtype同输入Tensor input。

约束说明

支持的型号

调用示例

1
2
3
4
5
6
import torch
import torch_npu
x = torch.rand(2, 2, 5, 128).npu()
r1 = torch.rand(1, 2, 1, 128).npu()
r2 = torch.rand(1, 2, 1, 128).npu()
out = torch_npu.npu_rotary_mul(x, r1, r2)