aclnnTopKTopPSample
产品支持情况
功能说明
算子功能: 根据输入词频logits、topK/topP采样参数、随机采样权重分布q,进行topK-topP-sample采样计算,输出每个batch的最大词频logitsSelectIdx,以及topK-topP采样后的词频分布logitsTopKPSelect。
算子包含三个可单独使能,但上下游处理关系保持不变的采样算法(从原始输入到最终输出):TopK采样、TopP采样、指数采样(本文档中Sample所指)。它们可以构成八种计算场景。如下表所示:
[object Object]undefined
计算公式: 输入logits为大小为[batch, voc_size]的词频表,其中每个batch对应一条输入序列,而voc_size则是约定每个batch的统一长度。[object Object] logits中的每一行logits[batch][:]根据相应的topK[batch]、topP[batch]、q[batch, :],执行不同的计算场景。[object Object] 下述公式中使用b和v来分别表示batch和voc_size方向上的索引。
TopK采样
- 按分段长度v采用分段topk归并排序,用{s-1}块的topK对当前{s}块的输入进行预筛选,渐进更新单batch的topK,减少冗余数据和计算。
- topK[batch]对应当前batch采样的k值,有效范围为1≤topK[batch]≤min(voc_size[batch], 1024),如果top[k]超出有效范围,则视为跳过当前batch的topK采样阶段,也同样会则跳过当前batch的排序,将输入logits[batch]直接传入下一模块。
- 对当前batch分割为若干子段,滚动计算topKValue[b]:
其中:
v表示预设的滚动topK时固定的分段长度:
- 生成需要过滤的mask
- 将小于阈值的部分通过mask置为-inf
- 通过softmax将经过topK过滤后的logits按最后一轴转换为概率分布
- 按最后一轴计算累积概率(从最小的概率开始累加)
TopP采样
如果前序topK采样已有排序输出结果,则根据topK采样输出计算累积词频,并根据topP截断采样:
如果topK采样被跳过,则先对输入logits[b]进行softmax处理:
- 尝试使用topKGuess,对logits进行滚动排序,获取计算topP的mask:
- 如果在访问到logitsValue[b]的第1e4个元素之前,下条件得到满足,则视为topKGuess成功,
- 如果topKGuess失败,则对当前序logitsValue[b]进行全排序和cumsum,按topP[b]截断采样:
- 将需要过滤的位置设置为-inf,得到sortedValue[b][v]:
取过滤后sortedValue[b][v]每行中前topK个元素,查找这些元素在输入中的原始索引,整合为
[object Object]:
指数采样(Sample)
- 如果
[object Object],则根据[object Object],选取采样后结果输出到[object Object]:
对
[object Object]进行指数分布采样:从
[object Object]中取出每个batch的最大元素,从[object Object]中gather相应元素的输入索引,作为输出[object Object]:
其中0≤b<sortedValue.size(0),0≤v<sortedValue.size(1)。
函数原型
每个算子分为,必须先调用[object Object]接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用[object Object]接口执行计算。
aclnnTopKTopPSampleGetWorkspaceSize
aclnnTopKTopPSample
约束说明
- 对于所有参数,它们的尺寸必须满足,batch>0,0<vocSize<=2^20。
- logits、q、logitsTopKPselect的尺寸和维度必须完全一致。
- logits、topK、topP、logitsSelectIdx除最后一维以外的所有维度必须顺序和大小完全一致。目前logits只能是2维,topK、topP、logitsSelectIdx必须是1维非空Tensor。logits、topK、topP不允许空Tensor作为输入,如需跳过相应模块,需按相应规则设置输入。
- 如果需要单独跳过topK模块,请传入[batch, 1]大小的Tensor,并使每个元素均为无效值。
- 如果1024<topK[batch]<vocSize[batch],则视为选择当前batch的全部有效元素并跳过topK采样。
- 如果需要单独跳过topP模块,请传入[batch, 1]大小的Tensor,并使每个元素均≥1。
- 如果需要单独跳过sample模块,传入
[object Object]即可;如需使用sample模块,则必须传入尺寸为[batch, vocSize]的Tensor。