API功能: 根据输入词频
[object Object]、[object Object]/[object Object]采样参数、随机采样权重分布[object Object],进行topK-topP-Sample采样计算,输出每个batch的最大词频[object Object],以及topK-topP采样后的词频分布[object Object]。算子包含三个可单独使能,但上下游处理关系保持不变的采样算法(从原始输入到最终输出):topK采样、topP采样、指数采样(Sample)。它们可以构成八种计算场景。如下表所示:
[object Object]undefined
计算公式: 输入
[object Object]为大小是[batch, voc_size]的词频表,其中每个batch对应一条输入序列,而voc_size则是约定每个batch的统一长度。[object Object][object Object]中的每一行logits[batch][:]根据相应的top_k[batch]、top_p[batch]、q[batch, :],执行不同的计算场景。[object Object] 下述公式中使用b和v来分别表示batch和voc_size方向上的索引。topK采样
- 按分段长度v采用分段topK归并排序,用{s-1}块的topK对当前{s}块的输入进行预筛选,渐进更新单batch的topK,减少冗余数据和计算。
- top_k[batch]对应当前batch采样的k值,有效范围为1≤top_k[batch]≤min(voc_size[batch], 1024),如果top_k[batch]超出有效范围,则视为跳过当前batch的topK采样阶段,也同样会则跳过当前batch的排序,将输入logits[batch]直接传入下一模块。[object Object]
- 具体计算流程如下所示:
- 对当前batch分割为若干子段,滚动计算top_k_value[b]: 其中: v表示预设的滚动topK时的固定分段长度:
- 生成需要过滤的mask:
- 将小于阈值的部分通过mask置为-inf:
- 通过softmax将经过topK过滤后的
[object Object]按最后一轴转换为概率分布: - 按最后一轴计算累积概率(从最小的概率开始累加): topP采样
- 如果前序topK采样已有排序输出结果,则根据topK采样输出计算累积词频,并根据topP截断采样:
- 如果topK采样被跳过,则先对输入logits[b]进行softmax处理:
- 尝试使用
[object Object],对[object Object]进行滚动排序,获取计算topP采样的mask: - 如果在访问到logits_value[b]的第1e4个元素之前,如下条件得到满足,则视为
[object Object]成功: - 如果
[object Object]失败,则对当前序logits_value[b]进行全排序和cumsum,按top_p[b]截断采样: - 将需要过滤的位置设置为-inf,得到sorted_value[b][v]:
- 取过滤后sorted_value[b][v]每行中前topK个元素,查找这些元素在输入中的原始索引,整合为logits_idx: 指数采样(Sample)
- 如果
[object Object]设置为True,则根据logits_idx,选取采样后结果输出到[object Object]: - 对sorted_value进行指数分布采样:
- 从probs_opt中取出每个batch的最大元素,从logits_idx中gather相应元素的输入索引,作为输出
[object Object]: 其中0≤b<,0≤v<。
[object Object]
- logits(
[object Object]):必选参数,表示待采样的输入词频,目前支持2维,词频索引固定为最后一维。数据类型支持[object Object]和[object Object],数据格式支持,支持非连续Tensor。 - top_k(
[object Object]):必选参数,表示每个batch采样的k值,有效范围为1≤top_k[batch]≤min(voc_size[batch], 1024),目前支持1维。数据类型支持[object Object],数据格式支持,支持非连续Tensor。 - top_p(
[object Object]):必选参数,表示每个batch采样的p值,有效范围为0<,目前支持1维。数据类型和数据格式与[object Object]保持一致,支持非连续Tensor。 - q(
[object Object]):可选参数,topK-topP采样输出的随机采样权重分布矩阵,数据类型支持[object Object],数据格式支持,支持非连续Tensor,默认值为None。 - eps(
[object Object]):可选参数,在softmax和权重采样中防止除零,默认值为1e-8。 - is_need_logits(
[object Object]):可选参数,控制[object Object]的输出条件,默认值为False。 - top_k_guess(
[object Object]):可选参数,表示每个batch在尝试topP部分遍历采样时的候选[object Object]大小,必须为正整数,默认值为32。
- logits_select_idx(
[object Object]):表示经过topK-topP-sample计算流程后,每个batch中词频最大元素max(probs_opt[batch, :])在输入[object Object]中的位置索引。数据类型支持[object Object],数据格式支持。 - logits_top_kp_select(
[object Object]):表示经过topK-topP计算流程后,输入[object Object]中剩余未被过滤的[object Object]。数据类型支持[object Object],数据格式支持。
- 该接口支持推理场景下使用。
- 该接口目前不支持图模式。
[object Object]、[object Object]、[object Object]的尺寸和维度必须完全一致。[object Object]、[object Object]、[object Object]、[object Object]除最后一维以外的所有维度必须顺序和大小完全一致。目前[object Object]只能是2维,[object Object]、[object Object]、[object Object]必须是1维非空Tensor。[object Object]、[object Object]、[object Object]不允许空Tensor作为输入,如需跳过相应模块,需按相应规则设置输入。- 如果需要单独跳过topK模块,请传入[batch, 1]大小的Tensor,并使每个元素均为无效值。
- 如果1024<,则视为选择当前batch的全部有效元素并跳过topK环节。
- 如果需要单独跳过topP模块,请传入[batch, 1]大小的Tensor,并使每个元素均≥1。
- 如果需要单独跳过Sample模块,使用其默认值或设置
[object Object]为None;如需使用Sample模块,则必须传入尺寸为[batch, voc_size]的Tensor。
[object Object]