算子Groupedmatmul和add融合,实现分组矩阵乘+InplaceAdd功能。
分组实现矩阵乘计算,并在输出地址原地累加。每组矩阵乘的维度中ki各不相同,x_i/weight_i可以在k_i上拼接。算子数学表达式为:
其中g为分组个数,m_i/k_i/n_i为对应shape。
1 2 3 4 5 | struct GroupedMatmulInplaceAddParam { bool transposeA = false; bool transposeB = false; uint8_t rsv[22] = {0}; }; |
成员名称 |
类型 |
默认值 |
取值范围 |
描述 |
---|---|---|---|---|
transposeA |
bool |
False |
false/true |
代表矩阵A是否随路转置。 |
transposeB |
bool |
False |
false |
代表矩阵B是否随路转置,当前仅支持False。 |
rsv[22] |
uint8_t |
{0} |
- |
预留参数。 |
参数 |
维度 |
数据类型 |
格式 |
是否必选 |
描述 |
---|---|---|---|---|---|
x |
[m, k] |
float16/bf16 |
ND |
是 |
矩阵乘的A矩阵,m取值范围(0,65535]。 |
weight |
[k, n] |
float16/bf16 |
ND |
是 |
权重,矩阵乘的B矩阵,n取值范围(0,65535]。数据类型与x保持一致。 |
groupList |
[numGroup] |
int64 |
ND |
是 |
1维Tensor,代表矩阵乘分为多少个组,最大长度为128。 |
out |
[numGroup * m, n]或[m, numGroup * n] |
float |
ND |
是 |
既是输入tensor又是输出tensor,支持原地写,numGroup最大为128。 |
参数 |
维度 |
数据类型 |
格式 |
是否必选 |
描述 |
---|---|---|---|---|---|
out |
[numGroup * m, n]或[m, numGroup * n] |
float |
ND |
是 |
输出。 |
例如:矩阵A维度为[5, 10],矩阵B维度为[10, 4],groupList维度为1维张量,numGroup为3,值分别为[3, 6, 10],表示把k分为3组,分别为[0, 3)、[3, 6)、[6, 10),分组进行矩阵乘。