torch_npu.npu_rotary_mul
接口原型
torch_npu.npu_rotary_mul(Tensor input, Tensor r1, Tensor r2, str rotary_mode='half') -> Tensor

在模型训练场景中,正向算子的输入Tensor input将被保留以供反向计算时使用。在r1,r2不需要计算反向梯度场景下(requires_grad=False),使用该接口相较融合前小算子使用的设备内存占用会有所增加。
功能描述
实现RotaryEmbedding旋转位置编码。支持FakeTensor模式。
- half模式:
x1, x2 = torch.chunk(input, 2, -1) x_new = torch.cat((-x2, x1), dim=-1) output = r1 * input + r2 * x_new
- interleave模式:
x1 = input[..., ::2] x2 = input[..., 1::2] x_new = rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ...(d two)", two=2) output = r1 * input + r2 * x_new
参数说明
- input:必选输入,4维Tensor,数据类型float16,bfloat16,float32。
- r1:必选输入,4维Tensor,数据类型float16,bfloat16,float32。
- r2:必选输入,4维Tensor,数据类型float16,bfloat16,float32。
- rotary_mode: 可选属性,数据类型string,用于选择计算模式,支持“half”、“interleave”两种模式。缺省为half。
输出说明
输出为Tensor,shape和dtype同输入Tensor input。
约束说明
- jit_compile=False场景(适用
Atlas A2 训练系列产品 ,Atlas A3 训练系列产品 ):- half模式:
input:layout支持:BNSD、BSND、SBND;D < 896,且为2的倍数;B, N < 1000;当需要计算cos/sin的反向梯度时,B*N <= 1024。
r1、r2:数据范围:[-1, 1];对应input layout的支持情况:
- x为BNSD: 11SD、B1SD、BNSD;
- x为BSND: 1S1D、BS1D、BSND;
- x为SBND: S11D、SB1D、SBND。
half模式下,当输入layout是BNSD,且D为非32Bytes对齐时,建议不使用该融合算子(模型启动脚本中不开启--use-fused-rotary-pos-emb选项),否则可能出现性能下降。
- half模式:
- jit_compile=True场景(适用
Atlas 训练系列产品 ,Atlas A2 训练系列产品 ,Atlas 推理系列产品 ):仅支持rotary_mode为half模式,且r1/r2 layout一般为11SD、1S1D、S11D。
shape要求输入为4维,其中B维度和N维度数值需小于等于1000,D维度数值为128。
广播场景下,广播轴的总数据量不能超过1024。
支持的型号
Atlas 训练系列产品 Atlas A2 训练系列产品 Atlas A3 训练系列产品 Atlas 推理系列产品
调用示例
1 2 3 4 5 6 | import torch import torch_npu x = torch.rand(2, 2, 5, 128).npu() r1 = torch.rand(1, 2, 1, 128).npu() r2 = torch.rand(1, 2, 1, 128).npu() out = torch_npu.npu_rotary_mul(x, r1, r2) |
父主题: torch_npu