GroupTopk算子超参数。将输入tensor0中维度1(输入tensor0有2个维度:维度0和维度1)数据分groupNum个组,每组取最大值,然后选出每组最大值中前k个,最后将非前k个组的数据全部置零。
1 2 3 4 5 6 7 8 9 10 | struct GroupTopkParam { int32_t groupNum = 1; int32_t k = 0; enum GroupMultiFlag : uint16_t { UNDEFINED = 0, SUM_MULTI_MAX }; uint16_t n = 1; uint8_t rsv[12] = {0}; }; |
成员名称 |
类型 |
默认值 |
取值范围 |
是否必选 |
描述 |
---|---|---|---|---|---|
groupNum |
int32_t |
1 |
[1, expert_num] |
否 |
每个token分组数量。注:“expert_num”为输入参数token维度dim_1的值。 |
k |
int32_t |
0 |
[1, groupNum] |
是 |
选择top K专家数量。需要大于等于1。 |
groupMultiFlag |
uint16_t |
UNDEFINED |
[0,1] |
否 |
枚举值,组内取值计算类型。
|
n |
uint16_t |
1 |
[1, expert_num/groupNum] |
否 |
每组内取值的个数。
|
rsv[12] |
uint8_t |
{0} |
- |
- |
预留参数。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
token |
[dim_0, dim_1] |
float16/bf16 |
ND |
输入tensor0, 二维Tensor,dim_0为token数,dim_1为专家总数。 |
idxArr |
[1024] |
int32 |
ND |
输入tensor1, 一维Tensor,用于辅助计算,固定长度1024,[0,1,2,...,1023]的等差序列。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
output |
[dim_0, dim_1] |
float16/bf16 |
ND |
输出tensor,只有一个输出tensor,是对输入tensor0原地写的输出。数据类型与输入tensor0保持一致。 |
将输入的各专家分数分组,每组内选取最大值,根据每组最大值大小选取topk组,其余组置零。
对于每个token,首先将token的数据均匀分groupNum组,
groupMultiFlag取0时,计算过程下图所示:
硬件型号 |
支持情况 |
---|---|
不支持 |
|
支持 |
|
支持 |
见参数列表,需要满足groupMultiFlag置为UNDEFINED。
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
token |
[num_tokens, expert_num] |
float16/bf16 |
ND |
二维tensor。
|
idxArr |
[1024] |
int32 |
ND |
一维tensor,用于辅助计算,固定长度1024,[0,1,2,...,1023]的等差序列。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
output |
[num_tokens, expert_num] |
float16/bf16 |
ND |
是对输入token原地写的输出。数据类型与输入token保持一致。 |
1 2 3 4 | // 参数构造 atb::infer::GroupTopkParam param; param.groupNum = 8; param.k = 1; |
将输入的各专家分数分组,每组内降序排列选取前n个值求和,根据和的大小选取topk组,其余组置零。
对于每个token,首先将token的数据均匀分groupNum组,
硬件型号 |
支持情况 |
---|---|
不支持 |
|
支持 |
|
支持 |
见参数列表,需要满足groupMultiFlag置为SUM_MULTI_MAX。
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
token |
[num_tokens, expert_num] |
float16/bf16 |
ND |
二维tensor。
|
idxArr |
[1024] |
int32 |
ND |
一维tensor,用于辅助计算,固定长度1024,[0,1,2,...,1023]的等差序列。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
output |
[num_tokens, expert_num] |
float16/bf16 |
ND |
是对输入tensor0原地写的输出。数据类型与输入tensor0保持一致。 |
1 2 3 4 5 6 | // 参数构造 atb::infer::GroupTopkParam param; param.groupNum = 8; param.k = 1; param.groupMultiFlag = 1; param.n = 2; |