开发者
资源

BatchMatMul2TransposeBatchMatMulFusionPass

融合模式

模式一:

当Transpose(可选)、BatchMatMul、Transpose这几个节点按下图所示顺序连接时,可融合为TransposeBatchMatmul算子节点。

如下融合模式支持的型号为:
  • Atlas A2 训练系列产品/Atlas A2 推理系列产品
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品

Atlas 350 加速卡场景下,融合模式如下。

模式二:

当BatchMatMul/BatchMatMulV2节点的x1输入shape为[B2,B1,1,K], x2输入shape为[1,B1,K,N]/[1,B1,N,K]时,会插入reshape节点,将输入合轴为x1[B2,B1,K],x2 [B1,K,N]/ [B1,N,K],然后调用TransposeBatchMatMul节点,生成计算结果为[B2,B1,N],再插入一个reshape节点,将输出的运算结果转化为[B2,B1,1,N]。

模式三:

当Transpose(可选)、Transpose(可选)、BatchMatMul、Transpose、Reshape、Reshape、Transpose这几个节点按下图所示顺序连接时,可融合为TransposeBatchMatmul算子节点。

使用约束

模式一:

  • 输入x1、x2与输出out(输入与输出要对应)支持的数据类型:BFLOAT16,FLOAT16,FLOAT32。
  • 输入输出支持的数据格式:ND。
  • x1输入支持[B,M,K]或者[M,B,K],x2输入只支持[B,K,N],对应的TransposeBatchMatMul的属性为perm_x1=[0,1,2]/[1,0,2],perm_x2=[0,1,2]。该约束仅适用于如下型号:
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品
  • x1的输入为[B,M,K]或者[M,B,K],x2的输入为[B,K,N]或者[B,N,K],对应的TransposeBatchMatMul的属性为perm_x1=[0,1,2]/[1,0,2],perm_x2=[0,1,2]/[0,2,1]。该约束仅适用于如下型号:
    • Atlas 350 加速卡

模式二:

  • x1的输入为[B2,B1,1,K],x2的输入为[1,B1,K,N] 或者 [1,B1,N,K],对应的BatchMatmul的属性为adj_x1=false,adj_x2=false/true。
  • 输入x1、x2与输出out(输入与输出要对应)支持的数据类型:FLOAT32。
  • 输入输出支持的数据格式:ND。
  • 该融合规则不支持使能HFLOAT32。该约束仅适用于如下型号:
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品
  • 输入需要满足B1*K<65536。该约束仅适用于如下型号:
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品
  • 当输入数据类型为BFLOAT16,FLOAT16时,k和n向128对齐,需要满足B*K<65536或者B*K>=65536,k<65536。该约束仅适用于如下型号:
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品
  • 当输入数据类型为FLOAT32且输入的transpose节点不为空时,无对齐限制,但仍需满足B*K<65536。该约束仅适用于如下型号:
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品

模式三:

  • out的输出为[B1,M,B/B1*N],B%B1=0。
  • 输入x1、x2与输出out(输入与输出要对应)支持的数据类型:BFLOAT16,FLOAT16,FLOAT32。
  • 输入输出支持的数据格式:ND。
  • x1输入支持[B, M, K]或者[M,B,K],x2输入只支持[B,K,N],对应的TransposeBatchMatMul的属性为perm_x1=[0,1,2]/[1,0,2],perm_x2=[0,1,2]。该约束仅适用于如下型号:
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品
  • x1的输入为[B,M,K]或者[M,B,K],x2的输入为[B,K,N]或者[B,N,K],对应的TransposeBatchMatMul的属性为perm_x1=[0,1,2]/[1,0,2],perm_x2=[0,1,2]/[0,2,1]。该约束仅适用于如下型号:

    Atlas 350 加速卡

  • 当输入数据类型为BFLOAT16,FLOAT16时,k和n向128对齐,需要满足B*K<65536或者B*K>=65536,k<65536。该约束仅适用于如下型号:
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品
  • 当输入数据类型为FLOAT32且输入的transpose节点不为空时,无对齐限制,但仍需满足B*K<65536。该约束仅适用于如下型号:
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品

支持的型号

Atlas A2 训练系列产品/Atlas A2 推理系列产品

Atlas A3 训练系列产品/Atlas A3 推理系列产品

Atlas 350 加速卡