GmmDeqSwigluQuantGmmDeqOperation(代码开放)
产品支持情况
硬件型号  | 
是否支持  | 
|---|---|
√  | 
|
√  | 
|
x  | 
|
x  | 
|
x  | 
功能
GmmDeqSwigluQuantGmmDeq 算子为整网中 gate up + gate down 两层的融合算子。

定义
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23  | struct GmmDeqSwigluQuantGmmDeqParam { enum OutputType { OUTPUT_FLOAT16 = 0, OUTPUT_BFLOAT16, OUTPUT_INVALID }; enum GroupListType { GROUP_LIST_CUMSUM = 0, GROUP_LIST_SINGLE, GROUP_LIST_INVALID }; enum WeightUpPermuteType { PERMUTE_N256 = 0, PERMUTE_N128, PERMUTE_INVALID }; OutputType outputType = OUTPUT_FLOAT16; GroupListType groupListType = GROUP_LIST_CUMSUM; WeightUpPermuteType weightUpPermuteType = PERMUTE_N256; bool transposeWeightUp = false; bool transposeWeightDown = true; uint8_t rsv[42] = {0}; };  | 
参数列表
成员名称  | 
类型  | 
默认值  | 
取值范围  | 
是否必选  | 
描述  | 
|---|---|---|---|---|---|
outputType  | 
OutputType  | 
OUTPUT_FLOAT16  | 
OUTPUT_FLOAT16  | 
是  | 
控制输出类型。 
  | 
groupListType  | 
GroupListType  | 
GROUP_LIST_CUMSUM  | 
GROUP_LIST_CUMSUM  | 
是  | 
控制groupList的类型。 
  | 
weightUpPermuteType  | 
WeightUpPermuteType  | 
PERMUTE_N256  | 
PERMUTE_N256 PERMUTE_N128  | 
是  | 
控制重排的方式,目前只支持256和128。 
  | 
transposeWeightUp  | 
bool  | 
false  | 
false  | 
是  | 
控制前一个 GroupedMatmul 的 weight 是否转置,目前只支持不转置  | 
transposeWeightDown  | 
bool  | 
true  | 
true  | 
是  | 
控制后一个 GroupedMatmul 的 weight 是否转置,目前只支持转置  | 
rsv[42]  | 
uint8_t  | 
{0}  | 
[0]  | 
否  | 
预留参数  | 
输入
参数  | 
维度  | 
数据类型  | 
格式  | 
描述  | 
|---|---|---|---|---|
x1  | 
[m, 7168]  | 
int8  | 
ND  | 
第一个 gmm 的左矩阵。  | 
permuteWeight1  | 
[G, 7168, 4096]  | 
int8  | 
NZ  | 
重排处理后的第一个gmm的右矩阵。  | 
permuteScale1  | 
[G, 4096]  | 
float32  | 
ND  | 
重排处理后的第一个gmm的反量化 per channel scale。  | 
perTokenScale1  | 
[m]  | 
float32  | 
ND  | 
第一个gmm的反量化per token scale。  | 
groupList  | 
[G]  | 
int64  | 
ND  | 
Grouped matmul的group list,为前缀和模式。  | 
weight2  | 
[G, 7168, 2048]  | 
int8  | 
NZ  | 
第二个gmm的右矩阵,转置后的NZ排布。  | 
scale2  | 
[G, 7168]  | 
float32  | 
ND  | 
第二个gmm的反量化 per channel scale。  | 
重排处理逻辑:
1 2 3 4  | def permute_weight(w: torch.Tensor, tile_n=256): *dims, n = w.shape order = list(range(len(dims))) + [-2, -3, -1] return w.reshape(*dims, 2, n // tile_n, tile_n // 2).permute(order).reshape(*dims, n).contiguous()  | 
参数weightUpPermuteType的值应当与重排逻辑中的tile_n参数相对应,目前支持PERMUTE_N256对应tile_n=256,PERMUTE_N128对应tile_n=128两种模式。
输出
参数  | 
维度  | 
数据类型  | 
格式  | 
描述  | 
|---|---|---|---|---|
output  | 
[m, 7168]  | 
float16  | 
ND  | 
第二个gmm反量化后的结果。  | 
规格约束
- m表示token总数,不超过128K。
 - G表示专家数,即groupList的长度,不超过 32。
 - CUMSUM模式下,groupList为一个单调不减,长度为G的列表,其中的第 i 位的值表示第 i 个 group及其之前的所有group在m轴上的长度之和,且需要确保该列表中最大的值不超过m。