昇腾社区首页
EN
注册

torch_npu.npu_rotary_mul

接口原型

torch_npu.npu_rotary_mul(Tensor input, Tensor r1, Tensor r2, str rotary_mode='half') -> Tensor

在模型训练场景中,正向算子的输入Tensor input将被保留以供反向计算时使用。在r1,r2不需要计算反向梯度场景下(requires_grad=False),使用该接口相较融合前小算子使用的设备内存占用会有所增加。

功能描述

实现RotaryEmbedding旋转位置编码。支持FakeTensor模式。

  • half模式:
    x1, x2 = torch.chunk(input, 2, -1)
    x_new = torch.cat((-x2, x1), dim=-1)
    output = r1 * input + r2 * x_new
  • interleave模式:
    x1 = input[..., ::2]
    x2 = input[..., 1::2]
    x_new = rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ...(d two)", two=2)
    output = r1 * input + r2 * x_new

参数说明

  • input:必选输入,4维Tensor,数据类型float16,bfloat16,float32。
  • r1:必选输入,4维Tensor,数据类型float16,bfloat16,float32。
  • r2:必选输入,4维Tensor,数据类型float16,bfloat16,float32。
  • rotary_mode: 可选属性,数据类型string,用于选择计算模式,支持“half”、“interleave”两种模式。缺省为half。

输出说明

输出为Tensor,shape和dtype同输入Tensor input。

约束说明

  • jit_compile=False场景(适用Atlas A2 训练系列产品Atlas A3 训练系列产品):
    • half模式:

      input:layout支持:BNSD、BSND、SBND;D < 896,且为2的倍数;B, N < 1000;当需要计算cos/sin的反向梯度时,B*N <= 1024。

      r1、r2:数据范围:[-1, 1];对应input layout的支持情况:

      • x为BNSD: 11SD、B1SD、BNSD;
      • x为BSND: 1S1D、BS1D、BSND;
      • x为SBND: S11D、SB1D、SBND。

        half模式下,当输入layout是BNSD,且D为非32Bytes对齐时,建议不使用该融合算子(模型启动脚本中不开启--use-fused-rotary-pos-emb选项),否则可能出现性能下降。

    • interleave模式:

      input:layout支持:BNSD、BSND、SBND; B * N < 1000; D < 896, 且D为2的倍数;

      r1、r2:数据范围:[-1, 1];对应input layout的支持情况:

      • x为BNSD: 11SD;
      • x为BSND: 1S1D;
      • x为SBND: S11D
  • jit_compile=True场景(适用Atlas 训练系列产品Atlas A2 训练系列产品Atlas 推理系列产品):

    仅支持rotary_mode为half模式,且r1/r2 layout一般为11SD、1S1D、S11D。

    shape要求输入为4维,其中B维度和N维度数值需小于等于1000,D维度数值为128。

    广播场景下,广播轴的总数据量不能超过1024。

支持的型号

  • Atlas 训练系列产品
  • Atlas A2 训练系列产品
  • Atlas A3 训练系列产品
  • Atlas 推理系列产品

调用示例

1
2
3
4
5
6
import torch
import torch_npu
x = torch.rand(2, 2, 5, 128).npu()
r1 = torch.rand(1, 2, 1, 128).npu()
r2 = torch.rand(1, 2, 1, 128).npu()
out = torch_npu.npu_rotary_mul(x, r1, r2)