MmDeqSwigluQuantMmDeqOperation(代码开放)
产品支持情况
硬件型号 |
是否支持 |
---|---|
√ |
|
√ |
|
x |
|
x |
|
x |
功能说明
MmDeqSwigluQuantMmDeq 算子为整网中 gate up + gate down 两层的融合算子。

定义
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | struct MmDeqSwigluQuantMmDeqParam { enum OutputType { OUTPUT_FLOAT16 = 0, OUTPUT_BFLOAT16, OUTPUT_INVALID }; enum WeightUpPermuteType { PERMUTE_N256 = 0, PERMUTE_N128, PERMUTE_INVALID }; OutputType outputType = OUTPUT_FLOAT16; WeightUpPermuteType weightUpPermuteType = PERMUTE_N256; bool transposeWeightUp = false; bool transposeWeightDown = true; uint8_t rsv[46] = {0}; }; |
参数列表
成员名称 |
类型 |
默认值 |
取值范围 |
是否必选 |
描述 |
---|---|---|---|---|---|
outputType |
OutputType |
OUTPUT_FLOAT16 |
OUTPUT_FLOAT16 |
是 |
控制输出类型。
|
weightUpPermuteType |
WeightUpPermuteType |
PERMUTE_N256 |
PERMUTE_N256 PERMUTE_N128 |
是 |
控制重排的方式,目前只支持256和128。
|
transposeWeightUp |
bool |
false |
false |
是 |
控制前一个 GroupedMatmul 的 weight 是否转置,目前只支持不转置 |
transposeWeightDown |
bool |
true |
true |
是 |
控制后一个 GroupedMatmul 的 weight 是否转置,目前只支持转置 |
rsv[46] |
uint8_t |
{0} |
[0] |
否 |
预留参数 |
输入
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
x1 |
[m, 7168] |
int8 |
ND |
第一个mm的左矩阵。 |
permuteWeight1 |
[7168, 4096] |
int8 |
NZ |
重排处理后的第一个mm的右矩阵。 |
permuteScale1 |
[4096] |
float32 |
ND |
重排处理后的第一个mm的反量化per channel scale。 |
perTokenScale1 |
[m] |
float32 |
ND |
第一个mm的反量化per token scale。 |
weight2 |
[7168, 2048] |
int8 |
NZ |
第二个mm的右矩阵,转置后的NZ排布。 |
scale2 |
[7168] |
float32 |
ND |
第二个mm的反量化per channel scale。 |
重排处理逻辑:
1 2 3 4 | def permute_weight(w: torch.Tensor, tile_n=256): *dims, n = w.shape order = list(range(len(dims))) + [-2, -3, -1] return w.reshape(*dims, 2, n // tile_n, tile_n // 2).permute(order).reshape(*dims, n).contiguous() |
参数weightUpPermuteType的值应当与重排逻辑中的tile_n参数相对应,目前支持PERMUTE_N256对应tile_n=256,PERMUTE_N128对应tile_n=128两种模式。
输出
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
output |
[m, 7168] |
float16 |
ND |
第二个mm反量化后的结果。 |
约束说明
m表示token总数,不超过128K。