昇腾社区首页
中文
注册
开发者
下载

BatchMatMul2TransposeBatchMatMulFusionPass

融合模式

场景一:

当Transpose(可选)、BatchMatMul、Transpose这几个节点按下图所示顺序连接时,可融合为TransposeBatchMatmul算子节点。

场景二:

当BatchMatMul/BatchMatMulV2节点的x1输入shape为[B2 B1 1 K], x2输入shape为[1 B1 K N]/[1 B1 N K]时,会插入reshape节点,将输入合轴为x1[B2 B1 K],x2 [B1 K N]/ [B1 N K],然后调用TransposeBatchMatMul节点,生成计算结果为[B2 B1 N],再插入一个reshape节点,将输出的运算结果转化为[B2 B1 1 N]。

使用约束

场景一:

  • x1的输入为[B, M, K],x2的输入为[B, K, N],对应的BatchMatmul的属性为adj_x1=false,adj_x2=false。
  • 输入x1、x2与输出out(输入与输出要对应)支持的数据类型:BFLOAT16,FLOAT16,FLOAT32。
  • 当输入数据类型为BFLOAT16,FLOAT16时,k和n向128对齐,需要满足B*K<65536或者B*K>=65536,k<65536。
  • 当输入数据类型为FLOAT32且输入的transpose节点不为空时,无对齐限制,但仍需满足B*K<65536。
  • 输入输出支持的数据格式:ND。

场景二:

  • x1的输入为[B2, B1, 1, K],x2的输入为[1, B1, K, N] 或者 [1, B1, N, K],对应的BatchMatmul的属性为adj_x1=false,adj_x2=false/true。
  • 输入需要满足B1*K<65536。
  • 输入x1、x2与输出out(输入与输出要对应)支持的数据类型:FLOAT32。
  • 输入输出支持的数据格式:ND。

支持的型号

Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件

Atlas A3 训练系列产品/Atlas A3 推理系列产品