GroupedMatmulInplaceAddOperation

功能

算子Groupedmatmul和add融合,实现分组矩阵乘+InplaceAdd功能。

计算图

分组实现矩阵乘计算,并在输出地址原地累加。每组矩阵乘的维度中ki各不相同,x_i/weight_i可以在k_i上拼接。算子数学表达式为:

其中g为分组个数,m_i/k_i/n_i为对应shape。

图1 计算示意图

定义

1
2
3
4
5
struct GroupedMatmulInplaceAddParam {
    bool transposeA = false;
    bool transposeB = false;
    uint8_t rsv[22] = {0};
};

参数列表

成员名称

类型

默认值

取值范围

描述

transposeA

bool

False

false/true

代表矩阵A是否随路转置。

transposeB

bool

False

false

代表矩阵B是否随路转置,当前仅支持False。

rsv[22]

uint8_t

{0}

-

预留参数。

输入

参数

维度

数据类型

格式

是否必选

描述

x

[m, k]

float16/bf16

ND

矩阵乘的A矩阵,m取值范围(0,65535]。

weight

[k, n]

float16/bf16

ND

权重,矩阵乘的B矩阵,n取值范围(0,65535]。数据类型与x保持一致。

groupList

[numGroup]

int64

ND

1维Tensor,代表矩阵乘分为多少个组,最大长度为128。

out

[numGroup * m, n]或[m, numGroup * n]

float

ND

既是输入tensor又是输出tensor,支持原地写,numGroup最大为128。

输出

参数

维度

数据类型

格式

是否必选

描述

out

[numGroup * m, n]或[m, numGroup * n]

float

ND

输出。

规格约束