GatingOperation
功能
Gating处理。
约束
该算子仅支持Atlas 800I A2推理产品。
定义
struct GatingParam { int32_t topkExpertNum = 0; int32_t cumSumNum = 0; };
成员
成员名称 |
描述 |
---|---|
topkExpertNum |
每个token选中的专家数。默认值为0。
|
cumSumNum |
专家总数。默认值为0。取值范围为[0, 127]。 |
输入
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
input_topk |
[batch*seqlen*topk] |
int32 |
ND |
输入tensor。每个token选中的专家的index。 |
input_idx_arr |
[batch*seqlen*topk] |
int32 |
ND |
输入tensor。每个token原始的index,具体的值为[0,1,2,3,...]。 |
注:batch*seqlen代表token个数,topk代表一个token选择多少个专家。 |
输出
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
token_index |
[batch*seqlen*topk] |
int32 |
ND |
输出tensor。token重排以后原始的index值。 |
cum_sum |
[expertNum] |
int32 |
ND |
输出tensor。每个专家被选中的次数。
|
original_index |
[batch*seqlen*topk] |
int32 |
ND |
输出tensor。token重排以后token的index值。 |
注:batch*seqlen代表token个数,topk代表一个token选择多少个专家。 |