GroupTopkOperation

功能

GroupTopk算子超参数。将输入tensor0中维度1(输入tensor0有2个维度:维度0和维度1)数据分groupNum个组,每组取最大值,然后选出每组最大值中前k个,最后将非前k个组的数据全部置零。

定义

1
2
3
4
5
struct GroupTopkParam {
    int32_t groupNum = 1;
    int32_t k = 0;
    uint8_t rsv[16] = {0};
};

参数列表

成员名称

类型

默认值

取值范围

描述

是否必选

groupNum

int32_t

1

[1, expert_num]

每个token分组数量。注:“expert_num”为inTensor0.dims[1]的值。

k

int32_t

0

[1, groupNum]

选择top K专家数量。

rsv[16]

uint8_t

{0}

-

预留参数。

-

输入

参数

维度

数据类型

格式

描述

token

[d0,d1]

float16/bf16

ND

输入tensor0, 二维Tensor,维度0为token数,维度1为专家总数。

idxArr

[1024]

int32

ND

输入tensor1, 一维Tensor,用于辅助计算,固定长度1024,[0,1,2,...,1023]的等差序列。

输出

参数

维度

数据类型

格式

描述

output

[d0,d1]

float16/bf16

ND

输出tensor,只有一个输出tensor,是对输入tensor0原地写的输出。数据类型与输入tensor0保持一致。

规格约束