昇腾社区首页
中文
注册

GroupedMatmulInplaceAddOperation(部分代码开放)

产品支持情况

硬件型号

是否支持

Atlas A3 推理系列产品/Atlas A3 训练系列产品

Atlas A2 训练系列产品/Atlas 800I A2 推理产品

Atlas 训练系列产品

x

Atlas 推理系列产品

x

Atlas 200I/500 A2 推理产品

x

功能说明

算子Groupedmatmul和add融合,实现分组矩阵乘+InplaceAdd功能。

计算图

分组实现矩阵乘计算,并在输出地址原地累加。每组矩阵乘的维度中ki各不相同,x_i/weight_i可以在k_i上拼接。算子数学表达式为:

其中g为分组个数,m_i/k_i/n_i为对应shape。

图1 计算示意图

定义

1
2
3
4
5
struct GroupedMatmulInplaceAddParam {
    bool transposeA = false;
    bool transposeB = false;
    uint8_t rsv[22] = {0};
};

参数列表

成员名称

类型

默认值

取值范围

描述

transposeA

bool

False

false/true

代表矩阵A是否随路转置。

transposeB

bool

False

false

代表矩阵B是否随路转置,当前仅支持False。

rsv[22]

uint8_t

{0}

-

预留参数。

输入

参数

维度

数据类型

格式

是否必选

描述

x

[m, k]

float16/bf16

ND

矩阵乘的A矩阵,m取值范围(0,65535]。

weight

[k, n]

float16/bf16

ND

权重,矩阵乘的B矩阵,n取值范围(0,65535]。数据类型与x保持一致。

groupList

[numGroup]

int64

ND

1维Tensor,代表矩阵乘分为多少个组,最大长度为128。

out

[numGroup * m, n]或[m, numGroup * n]

float

ND

既是输入tensor又是输出tensor,支持原地写,numGroup最大为128。

输出

参数

维度

数据类型

格式

是否必选

描述

out

[numGroup * m, n]或[m, numGroup * n]

float

ND

输出。

约束说明

  • transposeB当前仅支持False,transposeA的取值以及矩阵A/B的维度需要满足矩阵乘的维度约束。
  • 输出tensor与最后一个输入tensor为同一个tensor。
  • groupList 为长度小于等于128的一维张量,代表A/B在k轴上分组的索引,groupList 中的后一个元素应大于等于前一个元素,最后一个元素大小等于k。

    例如:矩阵A维度为[5, 10],矩阵B维度为[10, 4],groupList维度为1维张量,numGroup为3,值分别为[3, 6, 10],表示把k分为3组,分别为[0, 3)、[3, 6)、[6, 10),分组进行矩阵乘。