算子名称 |
RotaryMul & RotaryMulGrad |
---|---|
torch_npu api接口 |
torch_npu.npu_rotary_mul(x, r1, r2) |
支持的torch_npu版本 |
1.11.0, 2.1.0, 2.2.0 |
支持的芯片类型 |
Atlas 训练系列产品,Atlas A2 训练系列产品 |
支持的数据类型 |
float16, bfloat16, float |
REG_OP(RotaryMul) .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_BFLOAT16})) .INPUT(r1, TensorType({DT_FLOAT16, DT_FLOAT, DT_BFLOAT16})) .INPUT(r2, TensorType({DT_FLOAT16, DT_FLOAT, DT_BFLOAT16})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_BFLOAT16})) .OP_END_FACTORY_REG(RotaryMul)
torch_npu.npu_rotary_mul(x, r1, r2)
模型中替换代码截图参见图1,上方红框里的内容为模型源码,下方红框里的内容为替换的新接口。
q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
q = torch_npu.npu_rotary_mul(q, cos, sin) k = torch_npu.npu_rotary_mul(k, cos, sin)
x1, x2 = torch.chunk(x, 2, -1) x_new = torch.cat((-x2, x1), dim=-1) output = r1 * x + r2 * x_new
算子shape中最后一维必须是128的倍数。