BatchMatMul2TransposeBatchMatMulFusionPass
融合模式
场景一:
当Transpose(可选)、BatchMatMul、Transpose这几个节点按下图所示顺序连接时,可融合为TransposeBatchMatmul算子节点。
场景二:
当BatchMatMul/BatchMatMulV2节点的x1输入shape为[B2 B1 1 K], x2输入shape为[1 B1 K N]/[1 B1 N K]时,会插入reshape节点,将输入合轴为x1[B2 B1 K],x2 [B1 K N]/ [B1 N K],然后调用TransposeBatchMatMul节点,生成计算结果为[B2 B1 N],再插入一个reshape节点,将输出的运算结果转化为[B2 B1 1 N]。
使用约束
场景一:
- x1的输入为[B, M, K],x2的输入为[B, K, N],对应的BatchMatmul的属性为adj_x1=false,adj_x2=false。
- 输入x1、x2与输出out(输入与输出要对应)支持的数据类型:BFLOAT16,FLOAT16,FLOAT32。
- 当输入数据类型为BFLOAT16,FLOAT16时,k和n向128对齐,需要满足B*K<65536或者B*K>=65536,k<65536。
- 当输入数据类型为FLOAT32且输入的transpose节点不为空时,无对齐限制,但仍需满足B*K<65536。
- 输入输出支持的数据格式:ND。
场景二:
- x1的输入为[B2, B1, 1, K],x2的输入为[1, B1, K, N] 或者 [1, B1, N, K],对应的BatchMatmul的属性为adj_x1=false,adj_x2=false/true。
- 输入需要满足B1*K<65536。
- 输入x1、x2与输出out(输入与输出要对应)支持的数据类型:FLOAT32。
- 输入输出支持的数据格式:ND。
支持的型号
父主题: 图融合规则说明