GroupTopkOperation(代码开放)
产品支持情况
硬件型号 |
是否支持 |
---|---|
√ |
|
√ |
|
x |
|
x |
|
x |
功能说明
GroupTopk算子超参数。将输入tensor0中维度1(输入tensor0有2个维度:维度0和维度1)数据分groupNum个组,每组取最大值,然后选出每组最大值中前k个,最后将非前k个组的数据全部置零。
定义
1 2 3 4 5 6 7 8 9 10 | struct GroupTopkParam { int32_t groupNum = 1; int32_t k = 0; enum GroupMultiFlag : uint16_t { UNDEFINED = 0, SUM_MULTI_MAX }; uint16_t n = 1; uint8_t rsv[12] = {0}; }; |
参数列表
成员名称 |
类型 |
默认值 |
取值范围 |
是否必选 |
描述 |
---|---|---|---|---|---|
groupNum |
int32_t |
1 |
[1, expert_num] |
否 |
每个token分组数量。注:“expert_num”为输入参数token维度dim_1的值。 |
k |
int32_t |
0 |
[1, groupNum] |
是 |
选择top K专家数量。需要大于等于1。 |
groupMultiFlag |
uint16_t |
UNDEFINED |
[0,1] |
否 |
枚举值,组内取值计算类型。
|
n |
uint16_t |
1 |
[1, expert_num/groupNum] |
否 |
每组内取值的个数。
|
rsv[12] |
uint8_t |
{0} |
- |
- |
预留参数。 |
输入
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
token |
[dim_0, dim_1] |
float16/bf16 |
ND |
输入tensor0, 二维Tensor,dim_0为token数,dim_1为专家总数。 |
idxArr |
[1024] |
int32 |
ND |
输入tensor1, 一维Tensor,用于辅助计算,固定长度1024,[0,1,2,...,1023]的等差序列。 |
输出
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
output |
[dim_0, dim_1] |
float16/bf16 |
ND |
输出tensor,只有一个输出tensor,是对输入tensor0原地写的输出。数据类型与输入tensor0保持一致。 |
约束说明
- 1≤expert_num≤1024,expert_num≥groupNum≥k≥1,expert_num能被groupNum整除。
基础功能
- 功能概述
将输入的各专家分数分组,每组内选取最大值,根据每组最大值大小选取topk组,其余组置零。
- 计算公式
对于每个token,首先将token的数据均匀分groupNum组,
- 每组取最大值;
- 对每组最大值降序排序;
- 获取前k个组的index,然后将其余的index对应的组的数据置零。
- 计算图
groupMultiFlag取0时,计算过程下图所示:
图1 功能示例 - 参数列表
见参数列表,需要满足groupMultiFlag置为UNDEFINED。
- 输入
参数
维度
数据类型
格式
描述
token
[num_tokens, expert_num]
float16/bf16
ND
二维tensor。
- 维度0为token数。
- 维度1为专家总数。
idxArr
[1024]
int32
ND
一维tensor,用于辅助计算,固定长度1024,[0,1,2,...,1023]的等差序列。
- 输出
参数
维度
数据类型
格式
描述
output
[num_tokens, expert_num]
float16/bf16
ND
是对输入token原地写的输出。数据类型与输入token保持一致。
- 使用示例
1 2 3 4
// 参数构造 atb::infer::GroupTopkParam param; param.groupNum = 8; param.k = 1;
组内再排序和选取前n个值功能
- 功能概述
将输入的各专家分数分组,每组内降序排列选取前n个值求和,根据和的大小选取topk组,其余组置零。
- 计算公式
对于每个token,首先将token的数据均匀分groupNum组,
- 每组降序排列,取前n个值;
- 对这n个值求和;
- 各个组根据求得的和大小降序排序;
- 获取前k个组的index,然后将其余的index对应的组的数据置零。
- 参数列表
见参数列表,需要满足groupMultiFlag置为SUM_MULTI_MAX。
- 输入
参数
维度
数据类型
格式
描述
token
[num_tokens, expert_num]
float16/bf16
ND
二维tensor。
- 维度0为token数。
- 维度1为专家总数。
idxArr
[1024]
int32
ND
一维tensor,用于辅助计算,固定长度1024,[0,1,2,...,1023]的等差序列。
- 输出
参数
维度
数据类型
格式
描述
output
[num_tokens, expert_num]
float16/bf16
ND
是对输入tensor0原地写的输出。数据类型与输入tensor0保持一致。
- 使用示例
1 2 3 4 5 6
// 参数构造 atb::infer::GroupTopkParam param; param.groupNum = 8; param.k = 1; param.groupMultiFlag = 1; param.n = 2;