开发者
资源

InplaceAddRmsNormFusionPass

融合模式

场景一:将AddRmsNorm的y输出地址复用x1输入地址,x输出地址复用x2输入地址,转换为InplaceAddRmsNorm算子。

场景二:当识别到AddRmsNorm的y输出仅接了2个输出且第一个输出接Cast算子(将输出y从float16/bfloat16 Cast到float32),将AddRmsNorm和Cast融合为AddRmsNormCast算子。

使用约束

  • 通用约束:仅限推理场景,不支持训练场景。
  • 融合成AddRmsNormCast约束:
    • 仅支持输入x1为float16/bfloat16。
    • 仅支持cast从float16/bfloat16转float32。
    • 融合前后的输出中都不能包含rstd。
  • 融合成InplaceAddRmsNorm约束:非Atlas 350 加速卡场景下,AddRmsNorm的第二个输出须无后继节点

支持的型号

Atlas A2 训练系列产品/Atlas A2 推理系列产品

Atlas A3 训练系列产品/Atlas A3 推理系列产品

Atlas 推理系列产品

Atlas 350 加速卡