算子名称 |
RotaryMulGrad |
---|---|
torch_npu API接口 |
通过torch_npu.npu_rotary_mul(x, r1, r2)调用反向,无直接调用接口 |
支持的torch_npu版本 |
1.11, 2.0, 2.1 |
支持的昇腾产品 |
Atlas 训练系列产品、Atlas A2 训练系列产品 |
支持的数据类型 |
float16,bfloat16,float |
REG_OP(RotaryMulGrad) .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT32, DT_BFLOAT16})) .INPUT(r1, TensorType({DT_FLOAT16, DT_FLOAT32, DT_BFLOAT16})) .INPUT(r2, TensorType({DT_FLOAT16, DT_FLOAT32, DT_BFLOAT16})) .INPUT(dy, TensorType({DT_FLOAT16, DT_FLOAT32, DT_BFLOAT16})) .ATTR(need_backward, Bool, true) .OUTPUT(dx, TensorType({DT_FLOAT16, DT_FLOAT32, DT_BFLOAT16})) .OUTPUT(dr1, TensorType({DT_FLOAT16, DT_FLOAT32, DT_BFLOAT16})) .OUTPUT(dr2, TensorType({DT_FLOAT16, DT_FLOAT32, DT_BFLOAT16})) .OP_END_FACTORY_REG(RotaryMulGrad)
output= torch_npu.npu_rotary_mul(x, r1, r2) output.backward(torch.ones_like(output))
替换为:
q = torch_npu.npu_rotary_mul(q, cos, sin) k = torch_npu.npu_rotary_mul(k, cos, sin)
x1, x2 = torch.chunk(x, 2, -1) x_new = torch.cat((-x2, x1), dim=-1) output = r1 * x + r2 * x_new
目前算子仅支持r1, r2需要broadcast为x的shape的情形,广播轴的数据量不能超过1024,且算子shape中最后一维D必须是64的倍数。