GroupTopkOperation

功能

GroupTopk算子超参数。将输入tensor0中维度1(输入tensor0有2个维度:维度0和维度1)数据分groupNum个组,每组取最大值,然后选出每组最大值中前k个,最后将非前k个组的数据全部置零。

定义

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
struct GroupTopkParam {
    int32_t groupNum = 1;
    int32_t k = 0;
    enum GroupMultiFlag : uint16_t {
        UNDEFINED = 0, 
        SUM_MULTI_MAX  
    };
    uint16_t n = 1; 
    uint8_t rsv[12] = {0};
};

参数列表

成员名称

类型

默认值

取值范围

是否必选

描述

groupNum

int32_t

1

[1, expert_num]

每个token分组数量。注:“expert_num”为输入参数token维度dim_1的值。

k

int32_t

0

[1, groupNum]

选择top K专家数量。需要大于等于1。

groupMultiFlag

uint16_t

UNDEFINED

[0,1]

枚举值,组内取值计算类型。

  • UNDEFINED:默认类型,每组内取最大值。
  • SUM_MULTI_MAX:每组内取n个最大值求和,需要设置参数n。

n

uint16_t

1

[1, expert_num/groupNum]

每组内取值的个数。

  • groupMultiFlag为1时,n需要大于0。
  • groupMultiFlag为0时不生效。

rsv[12]

uint8_t

{0}

-

-

预留参数。

输入

参数

维度

数据类型

格式

描述

token

[dim_0, dim_1]

float16/bf16

ND

输入tensor0, 二维Tensor,dim_0为token数,dim_1为专家总数。

idxArr

[1024]

int32

ND

输入tensor1, 一维Tensor,用于辅助计算,固定长度1024,[0,1,2,...,1023]的等差序列。

输出

参数

维度

数据类型

格式

描述

output

[dim_0, dim_1]

float16/bf16

ND

输出tensor,只有一个输出tensor,是对输入tensor0原地写的输出。数据类型与输入tensor0保持一致。

规格约束

基础功能

组内再排序和选取前n个值功能