昇腾社区首页
中文
注册

MmDeqSwigluQuantMmDeqOperation(代码开放)

产品支持情况

硬件型号

是否支持

Atlas A3 推理系列产品/Atlas A3 训练系列产品

Atlas A2 训练系列产品/Atlas 800I A2 推理产品

Atlas 训练系列产品

x

Atlas 推理系列产品

x

Atlas 200I/500 A2 推理产品

x

功能说明

MmDeqSwigluQuantMmDeq 算子为整网中 gate up + gate down 两层的融合算子

图1 计算流程图

定义

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
struct MmDeqSwigluQuantMmDeqParam {
    enum OutputType {
        OUTPUT_FLOAT16 = 0,
        OUTPUT_BFLOAT16,
        OUTPUT_INVALID
    };

    enum WeightUpPermuteType {
        PERMUTE_N256 = 0,
        PERMUTE_N128,
        PERMUTE_INVALID
    };
    OutputType outputType = OUTPUT_FLOAT16;
    WeightUpPermuteType weightUpPermuteType = PERMUTE_N256;
    bool transposeWeightUp = false;
    bool transposeWeightDown = true;
    uint8_t rsv[46] = {0};
};

参数列表

成员名称

类型

默认值

取值范围

是否必选

描述

outputType

OutputType

OUTPUT_FLOAT16

OUTPUT_FLOAT16

控制输出类型。

  • OUTPUT_FLOAT16:默认值,输出数据类型为float16,目前只支持默认值。
  • OUTPUT_BFLOAT16:输出数据类型为bf16。
  • OUTPUT_INVALID:无效值。

weightUpPermuteType

WeightUpPermuteType

PERMUTE_N256

PERMUTE_N256

PERMUTE_N128

控制重排的方式,目前只支持256和128。

  • PERMUTE_N256:表示在 N 轴上重排的基本块宽度。
  • PERMUTE_N128:表示在 N 轴上重排的基本块宽度。
  • PERMUTE_INVALID:无效值。

transposeWeightUp

bool

false

false

控制前一个 GroupedMatmul 的 weight 是否转置,目前只支持不转置

transposeWeightDown

bool

true

true

控制后一个 GroupedMatmul 的 weight 是否转置,目前只支持转置

rsv[46]

uint8_t

{0}

[0]

预留参数

输入

参数

维度

数据类型

格式

描述

x1

[m, 7168]

int8

ND

第一个mm的左矩阵。

permuteWeight1

[7168, 4096]

int8

NZ

重排处理后的第一个mm的右矩阵。

permuteScale1

[4096]

float32

ND

重排处理后的第一个mm的反量化per channel scale。

perTokenScale1

[m]

float32

ND

第一个mm的反量化per token scale。

weight2

[7168, 2048]

int8

NZ

第二个mm的右矩阵,转置后的NZ排布。

scale2

[7168]

float32

ND

第二个mm的反量化per channel scale。

重排处理逻辑:

1
2
3
4
def permute_weight(w: torch.Tensor, tile_n=256):
    *dims, n = w.shape
    order = list(range(len(dims))) + [-2, -3, -1]
    return w.reshape(*dims, 2, n // tile_n, tile_n // 2).permute(order).reshape(*dims, n).contiguous()

参数weightUpPermuteType的值应当与重排逻辑中的tile_n参数相对应,目前支持PERMUTE_N256对应tile_n=256,PERMUTE_N128对应tile_n=128两种模式。

输出

参数

维度

数据类型

格式

描述

output

[m, 7168]

float16

ND

第二个mm反量化后的结果。

约束说明

m表示token总数,不超过128K。