昇腾社区首页
中文
注册

AddLayerNormFusionPass

融合模式

  • 场景一:基础场景,将Add + Cast(可选)+ LayerNorm融合为AddLayerNorm算子。

  • 场景二:支持将Add + Add + Cast(可选)+ LayerNorm融合为AddLayerNorm算子。

  • 场景三:支持将Cast + Add + LayerNorm融合为AddLayerNorm算子。

使用约束

  • 数据类型约束:
    • 场景一只支持如下输入数据类型组合:

      x1

      x2

      gamma

      beta

      FLOAT16

      FLOAT16

      FLOAT16

      FLOAT16

      BFLOAT16

      BFLOAT16

      BFLOAT16

      BFLOAT16

      FLOAT16

      FLOAT16

      FLOAT32

      FLOAT32

      BFLOAT16

      BFLOAT16

      FLOAT32

      FLOAT32

    • 场景二只支持如下输入数据类型组合:

      x1

      x2

      gamma

      beta

      bias(可选)

      FLOAT16

      FLOAT16

      FLOAT16

      FLOAT16

      FLOAT16

      BFLOAT16

      BFLOAT16

      BFLOAT16

      BFLOAT16

      BFLOAT16

      FLOAT16

      FLOAT16

      FLOAT32

      FLOAT32

      FLOAT16

      BFLOAT16

      BFLOAT16

      FLOAT32

      FLOAT32

      BFLOAT16

    • 场景三只支持如下输入数据类型组合:

      x1

      x2

      gamma

      beta

      FLOAT16

      FLOAT32

      FLOAT32

      FLOAT32

      FLOAT32

      FLOAT16

      FLOAT32

      FLOAT32

      BFLOAT16

      FLOAT32

      FLOAT32

      FLOAT32

      FLOAT32

      BFLOAT16

      FLOAT32

      FLOAT32

    • 其他数据类型约束:

      训练场景下,不支持输入x1或x2的数据类型为FLOAT16。

  • shape约束:
    • 输入x1和x2必须有相同shape。
    • 输入gamma和beta的shape必须为1维,并且shape取值和x1,x2输入shape的尾轴相同,即:gamma.shape = beta.shape = [x1.shape[-1]]
  • Cast算子和Add算子的约束:

    如果融合模式为场景三,Cast算子的输出类型必须是FLOAT32,且Add算子后必须有x输出。

  • 对于精度的影响:

    当融合模式为场景一或者场景二,且LayerNorm和Add之间存在Cast场景。

    如果Cast输入类型为FLOAT16/BFLOAT16,输出类型为FLOAT32时,使用此融合规则AddLayerNorm输出y的数据类型为FLOAT16/BFLOAT16,相比融合前存在精度降低的情况。

支持的型号

Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件