昇腾社区首页
中文
注册

TopkToppSamplingOperation

功能

依据给定的词表概率以及top-p,设置随机种子及top-k保留词数,选择最合适的词及对应概率作为输出。

支持batch级别随机种子、top-k取样,支持exponential取样。

约束

probs必须是是两维张量。

定义

struct TopkToppSamplingParam {
    enum TopkToppSamplingType {
        SAMPLING_UNDEFINED = -1,
        SINGLE_TOPK_SAMPLING, 
        BATCH_TOPK_MULTINOMIAL_SAMPLING,
        BATCH_TOPK_EXPONENTIAL_SAMPLING, 
        SAMPLING_MAX, 
    };
    TopkToppSamplingType topkToppSamplingType = SINGLE_TOPK_SAMPLING;
    uint32_t randSeed = 0;
    uint32_t topk = 100;
    std::vector<uint32_t> randSeeds;
};

成员

成员名称

描述

TopkToppSamplingType

取样处理类型。

  • SAMPLING_UNDEFINED:未定义。
  • SINGLE_TOPK_SAMPLING:非batch级别随机种子,Topk的取样。
  • BATCH_TOPK_MULTINOMIAL_SAMPLING:batch级别随机种子,Topk的multinomial取样。
  • BATCH_TOPK_EXPONENTIAL_SAMPLING:batch级别随机种子,Topk的exponential取样。
  • SAMPLING_MAX:枚举最大值。

topkToppSamplingType

采样类型,默认为非batch级别随机种子、Topk的取样。

randSeed

top-p阶段随机抽样使用的随机数种子。

topktoppSamplingtype = SINGLE_TOPK_SAMPLING时使用。

topk

top-k阶段保留的词的个数,需要小于词表的词数。

top-k必须大于0且小于或等于输入probs最后一维的大小。

topktoppSamplingtype = SINGLE_TOPK_SAMPLING时使用。

randSeeds

每个batch下top-p阶段随机抽样使用的随机数种子。

维度与batch大小一致。

topktoppSamplingtype = BATCH_TOPK_MULTINOMIAL_SAMPLING时使用。

输入输出(非batch级随机种子,topk取样)

参数

维度

数据类型

格式

描述

probs

[batch, voc_size]

float16

ND

输入。

topp

[batch, 1]

float16

ND

输入topp,topp截取的概率,batch的值需与probs的一致。

sampled_indices

[batch, 1]

int32

ND

输出,取样的idx。

sampled_probs

[batch, 1]

float16

ND

输出,取样的值。

输入输出(batch级随机种子,topk的multinomial取样)

参数

维度

数据类型

格式

描述

probs

[batch, voc_size]

float16

ND

输入。

topk

[batch, 1]

int32

ND

输入,topk截取的位置,batch的值需与probs的一致。

topp

[batch, 1]

float16

ND

输入,topp截取的概率,batch的值需与probs的一致。

sampled_indices

[batch, 1]

int32

ND

输出,取样的idx。

sampled_probs

[batch, 1]

float16

ND

输出,取样的值。

输入输出(batch级随机种子,topk的Exponential取样)

参数

维度

数据类型

格式

描述

probs

[batch, voc_size]

float16

ND

输入。

topk

[batch, 1]

int32

ND

输入,topk截取的位置,batch的值需与probs的一致。

topp

[batch, 1]

float16

ND

输入,topp截取的概率,batch的值需与probs的一致。

exp

[batch, voc_size]

float16

ND

输入,所除的指数分布,维度需与probs的一致。

sampled_indices

[batch, 1]

int32

ND

输出,取样的idx。

sampled_probs

[batch, 1]

float16

ND

输出,取样的值。