昇腾社区首页
中文
注册

GatingOperation

功能

Gating处理。

约束

该算子仅支持Atlas 800I A2推理产品

定义

struct GatingParam {
    int32_t topkExpertNum = 0;
    int32_t cumSumNum = 0;
};

成员

成员名称

描述

topkExpertNum

每个token选中的专家数。默认值为0。

  • “cumSumNum”为0时,取值为1。
  • “cumSumNum”不为0时,取值范围为(0, cumSumNum]。

cumSumNum

专家总数。默认值为0。取值范围为[0, 127]。

输入

参数

维度

数据类型

格式

描述

input_topk

[batch*seqlen*topk]

int32

ND

输入tensor。每个token选中的专家的index。

input_idx_arr

[batch*seqlen*topk]

int32

ND

输入tensor。每个token原始的index,具体的值为[0,1,2,3,...]。

注:batch*seqlen代表token个数,topk代表一个token选择多少个专家。

输出

参数

维度

数据类型

格式

描述

token_index

[batch*seqlen*topk]

int32

ND

输出tensor。token重排以后原始的index值。

cum_sum

[expertNum]

int32

ND

输出tensor。每个专家被选中的次数。

  • “cumSumNum”为0时,expertNum值为1。
  • “cumSumNum”不为0时,expertNum值为cumSumNum。

original_index

[batch*seqlen*topk]

int32

ND

输出tensor。token重排以后token的index值。

注:batch*seqlen代表token个数,topk代表一个token选择多少个专家。