昇腾社区首页
中文
注册

RotaryMulGrad算子使用指南

表1 RotaryMulGrad算子基础信息

算子名称

RotaryMulGrad

torch_npu API接口

通过torch_npu.npu_rotary_mul(x, r1, r2)调用反向,无直接调用接口

支持的torch_npu版本

1.11, 2.0, 2.1

支持的昇腾产品

Atlas 训练系列产品Atlas A2 训练系列产品

支持的数据类型

float16,bfloat16,float

算子IR及torch_npu接口参数

  • 算子IR:
    REG_OP(RotaryMulGrad)
        .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT32, DT_BFLOAT16}))
        .INPUT(r1, TensorType({DT_FLOAT16, DT_FLOAT32, DT_BFLOAT16}))
        .INPUT(r2, TensorType({DT_FLOAT16, DT_FLOAT32, DT_BFLOAT16}))
        .INPUT(dy, TensorType({DT_FLOAT16, DT_FLOAT32, DT_BFLOAT16}))
        .ATTR(need_backward, Bool, true)
        .OUTPUT(dx, TensorType({DT_FLOAT16, DT_FLOAT32, DT_BFLOAT16}))
        .OUTPUT(dr1, TensorType({DT_FLOAT16, DT_FLOAT32, DT_BFLOAT16}))
        .OUTPUT(dr2, TensorType({DT_FLOAT16, DT_FLOAT32, DT_BFLOAT16}))
        .OP_END_FACTORY_REG(RotaryMulGrad)
  • torch_npu接口:
    output= torch_npu.npu_rotary_mul(x, r1, r2)
    output.backward(torch.ones_like(output))
  • 参数说明:
    • x:q, k,shape要求输入为4维,一般为[B, N, S, D]或[B, S, N, D]或[S, B, N, D]。
    • r1: cos值,shape要求输入为4维,一般为[1, 1, S, D]或[1, S, 1, D]或[S, 1, 1, D]。
    • r2: sin值,shape要求输入为4维,一般为[1, 1, S, D]或[1, S, 1, D]或[S, 1, 1, D]。

模型中替换代码及算子计算逻辑

  • rotarymulgrad为rotarymul的反向,实现自动绑定,外部无法用接口调用,必须通过.backward调用,对应的正向代码为:
    模型中替换代码:
    q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

    替换为:

    q = torch_npu.npu_rotary_mul(q, cos, sin)
    k = torch_npu.npu_rotary_mul(k, cos, sin)
  • 算子的正向计算逻辑如下:
    x1, x2 = torch.chunk(x, 2, -1)
    x_new = torch.cat((-x2, x1), dim=-1)
    output = r1 * x + r2 * x_new
  • 反向对应的计算图为:
    图1 计算图

使用限制

目前算子仅支持r1, r2需要broadcast为x的shape的情形,广播轴的数据量不能超过1024,且算子shape中最后一维D必须是64的倍数。

已支持模型典型Case

  • case 1:
    • x: [2, 8192, 5, 128]
    • r1: [1, 8192, 1, 128]
    • r2: [1, 8192, 1, 128]
    • dy: [2, 8192, 5, 128]
  • case 2:
    • x: [8192, 2, 5, 128]
    • r1: [8192, 1, 1, 128]
    • r2: [8192, 1, 1, 128]
    • dy: [8192, 2, 5, 128]
  • case 3:
    • x: [2048, 4, 32, 64]
    • r1: [2048, 4, 1, 64]
    • r2: [2048, 4, 1, 64]
    • dy: [2048, 4, 32, 64]]