GroupedMatmulInplaceAddOperation(部分代码开放)
产品支持情况
硬件型号  | 
是否支持  | 
|---|---|
√  | 
|
√  | 
|
x  | 
|
x  | 
|
x  | 
功能
算子Groupedmatmul和add融合,实现分组矩阵乘+InplaceAdd功能。
计算图
分组实现矩阵乘计算,并在输出地址原地累加。每组矩阵乘的维度中ki各不相同,x_i/weight_i可以在k_i上拼接。算子数学表达式为:
其中g为分组个数,m_i/k_i/n_i为对应shape。

定义
1 2 3 4 5  | struct GroupedMatmulInplaceAddParam { bool transposeA = false; bool transposeB = false; uint8_t rsv[22] = {0}; };  | 
参数列表
成员名称  | 
类型  | 
默认值  | 
取值范围  | 
描述  | 
|---|---|---|---|---|
transposeA  | 
bool  | 
False  | 
false/true  | 
代表矩阵A是否随路转置。  | 
transposeB  | 
bool  | 
False  | 
false  | 
代表矩阵B是否随路转置,当前仅支持False。  | 
rsv[22]  | 
uint8_t  | 
{0}  | 
-  | 
预留参数。  | 
输入
参数  | 
维度  | 
数据类型  | 
格式  | 
是否必选  | 
描述  | 
|---|---|---|---|---|---|
x  | 
[m, k]  | 
float16/bf16  | 
ND  | 
是  | 
矩阵乘的A矩阵,m取值范围(0,65535]。  | 
weight  | 
[k, n]  | 
float16/bf16  | 
ND  | 
是  | 
权重,矩阵乘的B矩阵,n取值范围(0,65535]。数据类型与x保持一致。  | 
groupList  | 
[numGroup]  | 
int64  | 
ND  | 
是  | 
1维Tensor,代表矩阵乘分为多少个组,最大长度为128。  | 
out  | 
[numGroup * m, n]或[m, numGroup * n]  | 
float  | 
ND  | 
是  | 
既是输入tensor又是输出tensor,支持原地写,numGroup最大为128。  | 
输出
参数  | 
维度  | 
数据类型  | 
格式  | 
是否必选  | 
描述  | 
|---|---|---|---|---|---|
out  | 
[numGroup * m, n]或[m, numGroup * n]  | 
float  | 
ND  | 
是  | 
输出。  | 
规格约束
- transposeB当前仅支持False,transposeA的取值以及矩阵A/B的维度需要满足矩阵乘的维度约束。
 - 输出tensor与最后一个输入tensor为同一个tensor。
 - groupList 为长度小于等于128的一维张量,代表A/B在k轴上分组的索引,groupList 中的后一个元素应大于等于前一个元素,最后一个元素大小等于k。
例如:矩阵A维度为[5, 10],矩阵B维度为[10, 4],groupList维度为1维张量,numGroup为3,值分别为[3, 6, 10],表示把k分为3组,分别为[0, 3)、[3, 6)、[6, 10),分组进行矩阵乘。