昇腾社区首页
中文
注册

GmmDeqSwigluQuantGmmDeqOperation(代码开放)

产品支持情况

硬件型号

是否支持

Atlas A3 推理系列产品/Atlas A3 训练系列产品

Atlas A2 训练系列产品/Atlas 800I A2 推理产品

Atlas 训练系列产品

x

Atlas 推理系列产品

x

Atlas 200I/500 A2 推理产品

x

功能说明

GmmDeqSwigluQuantGmmDeq 算子为整网中 gate up + gate down 两层的融合算子

图1 计算流程图

定义

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
struct GmmDeqSwigluQuantGmmDeqParam {
    enum OutputType {
        OUTPUT_FLOAT16 = 0,
        OUTPUT_BFLOAT16,
        OUTPUT_INVALID
    };
    enum GroupListType {
        GROUP_LIST_CUMSUM = 0,
        GROUP_LIST_SINGLE,
        GROUP_LIST_INVALID
    };
    enum WeightUpPermuteType {
        PERMUTE_N256 = 0,
        PERMUTE_N128,
        PERMUTE_INVALID
    };
    OutputType outputType = OUTPUT_FLOAT16;
    GroupListType groupListType = GROUP_LIST_CUMSUM;
    WeightUpPermuteType weightUpPermuteType = PERMUTE_N256;
    bool transposeWeightUp = false;
    bool transposeWeightDown = true;
    uint8_t rsv[42] = {0};
};

参数列表

成员名称

类型

默认值

取值范围

是否必选

描述

outputType

OutputType

OUTPUT_FLOAT16

OUTPUT_FLOAT16

控制输出类型。

  • OUTPUT_FLOAT16:默认值,输出数据类型为float16,目前只支持默认值。
  • OUTPUT_BFLOAT16:输出数据类型为bf16。
  • OUTPUT_INVALID:无效值。

groupListType

GroupListType

GROUP_LIST_CUMSUM

GROUP_LIST_CUMSUM

控制groupList的类型。

  • GROUP_LIST_CUMSUM:默认值,累加和,目前只支持默认值。
  • GROUP_LIST_SINGLE:单组值模式。
  • GROUP_LIST_INVALID:无效值。

weightUpPermuteType

WeightUpPermuteType

PERMUTE_N256

PERMUTE_N256

PERMUTE_N128

控制重排的方式,目前只支持256和128。

  • PERMUTE_N256:表示在 N 轴上重排的基本块宽度。
  • PERMUTE_N128:表示在 N 轴上重排的基本块宽度。
  • PERMUTE_INVALID:无效值。

transposeWeightUp

bool

false

false

控制前一个 GroupedMatmul 的 weight 是否转置,目前只支持不转置

transposeWeightDown

bool

true

true

控制后一个 GroupedMatmul 的 weight 是否转置,目前只支持转置

rsv[42]

uint8_t

{0}

[0]

预留参数

输入

参数

维度

数据类型

格式

描述

x1

[m, 7168]

int8

ND

第一个 gmm 的左矩阵。

permuteWeight1

[G, 7168, 4096]

int8

NZ

重排处理后的第一个gmm的右矩阵。

permuteScale1

[G, 4096]

float32

ND

重排处理后的第一个gmm的反量化 per channel scale。

perTokenScale1

[m]

float32

ND

第一个gmm的反量化per token scale。

groupList

[G]

int64

ND

Grouped matmul的group list,为前缀和模式。

weight2

[G, 7168, 2048]

int8

NZ

第二个gmm的右矩阵,转置后的NZ排布。

scale2

[G, 7168]

float32

ND

第二个gmm的反量化 per channel scale。

重排处理逻辑:

1
2
3
4
def permute_weight(w: torch.Tensor, tile_n=256):
    *dims, n = w.shape
    order = list(range(len(dims))) + [-2, -3, -1]
    return w.reshape(*dims, 2, n // tile_n, tile_n // 2).permute(order).reshape(*dims, n).contiguous()

参数weightUpPermuteType的值应当与重排逻辑中的tile_n参数相对应,目前支持PERMUTE_N256对应tile_n=256,PERMUTE_N128对应tile_n=128两种模式。

输出

参数

维度

数据类型

格式

描述

output

[m, 7168]

float16

ND

第二个gmm反量化后的结果。

约束说明

  • m表示token总数,不超过128K。
  • G表示专家数,即groupList的长度,不超过 32。
  • CUMSUM模式下,groupList为一个单调不减,长度为G的列表,其中的第 i 位的值表示第 i 个 group及其之前的所有group在m轴上的长度之和,且需要确保该列表中最大的值不超过m。