RotaryMul算子使用指南
算子名称 |
RotaryMul |
---|---|
torch_npu API接口 |
torch_npu.npu_rotary_mul(x, r1, r2) |
支持的torch_npu版本 |
1.11, 2.0, 2.1 |
支持的昇腾产品 |
Atlas 训练系列产品、Atlas A2 训练系列产品 |
支持的数据类型 |
float16,bfloat16,float |
算子IR及torch_npu接口参数
- 算子IR:
REG_OP(RotaryMul) .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_BFLOAT16})) .INPUT(r1, TensorType({DT_FLOAT16, DT_FLOAT, DT_BFLOAT16})) .INPUT(r2, TensorType({DT_FLOAT16, DT_FLOAT, DT_BFLOAT16})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_BFLOAT16})) .OP_END_FACTORY_REG(RotaryMul)
- torch_npu接口:
torch_npu.npu_rotary_mul(x, r1, r2) 1
- 参数说明:
- x:q, k,shape要求输入为4维,一般为[B, N, S, D]或[B, S, N, D]或[S, B, N, D]。
- r1: cos值,shape要求输入为4维,一般为[1, 1, S, D]或[1, S, 1, D]或[S, 1, 1, D]。
- r2: sin值,shape要求输入为4维,一般为[1, 1, S, D]或[1, S, 1, D]或[S, 1, 1, D]。
模型中替换代码及算子计算逻辑
- 模型中替换代码:
def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2: ] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def apply_fused_rotary_pos_emb(q, k, cos, sin, offset: int = 0): return torch_npu.npu_rotary_mul(q, cos, sin), torch_npu.npu_rotary_mul(k, cos, sin)
将以下代码
q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
替换为:
q = torch_npu.npu_rotary_mul(q, cos, sin) k = torch_npu.npu_rotary_mul(k, cos, sin)
- 算子的计算逻辑如下:
x1, x2 = torch.chunk(x, 2, -1) x_new = torch.cat((-x2, x1), dim=-1) output = r1 * x + r2 * x_new
- 计算流程图为:
图1 流程图
算子替换的模型中小算子
使用限制
目前算子仅支持r1, r2需要broadcast为x的shape的情形,且算子shape中最后一维D必须是128的倍数。
已支持模型典型Case
case 1:
- x: [1, 13, 2048, 128]
- r1: [1, 1, 2048, 128]
- r2: [1, 1, 2048, 128]
父主题: 融合算子调优