TopkToppSamplingOperation
功能
依据给定的词表概率以及top-p,设置随机种子及top-k保留词数,选择最合适的词及对应概率作为输出。
支持batch级别随机种子、top-k取样,支持exponential取样。
约束
probs必须是是两维张量。
定义
struct TopkToppSamplingParam { enum TopkToppSamplingType { SAMPLING_UNDEFINED = -1, SINGLE_TOPK_SAMPLING, BATCH_TOPK_MULTINOMIAL_SAMPLING, BATCH_TOPK_EXPONENTIAL_SAMPLING, SAMPLING_MAX, }; TopkToppSamplingType topkToppSamplingType = SINGLE_TOPK_SAMPLING; uint32_t randSeed = 0; uint32_t topk = 100; std::vector<uint32_t> randSeeds; };
成员
成员名称 |
描述 |
---|---|
TopkToppSamplingType |
取样处理类型。
|
topkToppSamplingType |
采样类型,默认为非batch级别随机种子、Topk的取样。 |
randSeed |
top-p阶段随机抽样使用的随机数种子。 当topktoppSamplingtype = SINGLE_TOPK_SAMPLING时使用。 |
topk |
top-k阶段保留的词的个数,需要小于词表的词数。 top-k必须大于0且小于或等于输入probs最后一维的大小。 当topktoppSamplingtype = SINGLE_TOPK_SAMPLING时使用。 |
randSeeds |
每个batch下top-p阶段随机抽样使用的随机数种子。 维度与batch大小一致。 当topktoppSamplingtype = BATCH_TOPK_MULTINOMIAL_SAMPLING时使用。 |
输入输出(非batch级随机种子,topk取样)
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
probs |
[batch, voc_size] |
float16 |
ND |
输入。 |
topp |
[batch, 1] |
float16 |
ND |
输入topp,topp截取的概率,batch的值需与probs的一致。 |
sampled_indices |
[batch, 1] |
int32 |
ND |
输出,取样的idx。 |
sampled_probs |
[batch, 1] |
float16 |
ND |
输出,取样的值。 |
输入输出(batch级随机种子,topk的multinomial取样)
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
probs |
[batch, voc_size] |
float16 |
ND |
输入。 |
topk |
[batch, 1] |
int32 |
ND |
输入,topk截取的位置,batch的值需与probs的一致。 |
topp |
[batch, 1] |
float16 |
ND |
输入,topp截取的概率,batch的值需与probs的一致。 |
sampled_indices |
[batch, 1] |
int32 |
ND |
输出,取样的idx。 |
sampled_probs |
[batch, 1] |
float16 |
ND |
输出,取样的值。 |
输入输出(batch级随机种子,topk的Exponential取样)
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
probs |
[batch, voc_size] |
float16 |
ND |
输入。 |
topk |
[batch, 1] |
int32 |
ND |
输入,topk截取的位置,batch的值需与probs的一致。 |
topp |
[batch, 1] |
float16 |
ND |
输入,topp截取的概率,batch的值需与probs的一致。 |
exp |
[batch, voc_size] |
float16 |
ND |
输入,所除的指数分布,维度需与probs的一致。 |
sampled_indices |
[batch, 1] |
int32 |
ND |
输出,取样的idx。 |
sampled_probs |
[batch, 1] |
float16 |
ND |
输出,取样的值。 |