RotaryMulGrad算子使用指南

表1 RotaryMulGrad算子基础信息

算子名称

RotaryMulGrad

torch_npu API接口

通过torch_npu.npu_rotary_mul(x, r1, r2)调用反向,无直接调用接口

支持的torch_npu版本

1.11, 2.0, 2.1

支持的昇腾产品

Atlas 训练系列产品Atlas A2 训练系列产品

支持的数据类型

float16,bfloat16,float

算子IR及torch_npu接口参数

模型中替换代码及算子计算逻辑

使用限制

目前算子仅支持r1, r2需要broadcast为x的shape的情形,广播轴的数据量不能超过1024,且算子shape中最后一维D必须是64的倍数。

已支持模型典型Case