TopkToppSamplingOperation

功能

依据输入词表的概率,根据topk和topp算法结合,选择一个词并输出。

约束

probs必须是是两维张量。

定义

struct TopkToppSamplingParam {
    uint32_t randSeed = 0;
    uint32_t topk = 100;
};

成员

成员名称

描述

randSeed

topp阶段随机抽样使用的随机数种子。

topk

topk阶段保留的词的个数。

topk必须大于0且小于或等于输入probs最后一维的大小。

输入

参数

维度

数据类型

格式

probs

[batch, voc_size]

float16

ND

topp

[batch, 1]

float16

ND

输出

参数

维度

数据类型

格式

sampled_indices

[batch, 1]

int32

ND

sampled_probs

[batch, 1]

float16

ND