昇腾社区首页
中文
注册
开发者
下载

aclnnTopKTopPSampleV2

产品支持情况

[object Object]undefined

功能说明

  • 接口功能: 根据输入词频logits、topK/topP/minP采样参数、随机采样权重分布q,进行topK-topP-minP-sample采样计算。当输入isNeedSampleResult为false时,输出每个batch的最大词频logitsSelectIdx,以及topK-topP-minP采样后的词频分布logitsTopKPSelect;当输入isNeedSampleResult为true时,输出topK-topP-minP采样后的中间计算结果logitsIdx和logitsSortMasked,其中logitsSortMasked为词频logits经过topK-topP-minP采样计算后的中间结果,logitsIdx为logitsSortMasked在logits中对应的索引。

    算子包含四个可单独使能,但上下游处理关系保持不变的采样算法(从原始输入到最终输出):TopK采样、TopP采样、MinP采样、指数采样(本文档中Sample所指)。目前支持以下计算场景。如下表所示:

    [object Object]undefined
  • 计算公式: 输入logits为大小为[batch, voc_size]的词频表,其中每个batch对应一条输入序列,而voc_size则是约定每个batch的统一长度。[object Object] logits中的每一行logits[batch][:]根据相应的topK[batch]、topP[batch]、minP[batch, :]、q[batch, :],执行不同的计算场景。[object Object] 下述公式中使用b和v来分别表示batch和voc_size方向上的索引。

    TopK采样

    1. 按分段长度v采用分段topk归并排序,用{s-1}块的topK对当前{s}块的输入进行预筛选,渐进更新单batch的topK,减少冗余数据和计算。
    2. topK[batch]对应当前batch采样的k值,有效范围为1≤topK[batch]≤min(voc_size[batch], ks_max),如果top[k]超出有效范围,则视为跳过当前batch的topK采样阶段,也同样会则跳过当前batch的排序,将输入logits[batch]直接传入下一模块。
    • 对当前batch分割为若干子段,滚动计算topKValue[b]:
    topKValue[b]=Max(topK[b])s=1Sv{topKValue[b]{s1}{logits[b][v]topKMin[b][s1]}}Card(topKValue[b])=topK[b]topKValue[b] = {Max(topK[b])}_{s=1}^{\left \lceil \frac{S}{v} \right \rceil }\left \{ topKValue[b]\left \{s-1 \right \} \cup \left \{ logits[b][v] \ge topKMin[b][s-1] \right \} \right \}\\ Card(topKValue[b])=topK[b]

    其中:

    topKMin[b][s]=Min(topKValue[b]{s})topKMin[b][s] = Min(topKValue[b]\left \{ s \right \})

    v表示预设的滚动topK时固定的分段长度:

    v=8ks_maxv = 8 * \text{ks\_max}

    ks_max有效取值范围[1,1024],默认为1024,并且需要向上对齐到8的整数倍。

    • 生成需要过滤的mask
    sortedValue[b]=sort(topKValue[b],descendant)sortedValue[b] = sort(topKValue[b], descendant) topKMask=sortedValuetopKValuetopKMask = sortedValue \geq topKValue
    • 将小于阈值的部分通过mask置为defLogit:
    sortedValue[b][v]={defLogittopKMask[b][v] = falsesortedValue[b][v]topKMask[b][v] = truesortedValue[b][v]= \begin{cases} defLogit & \text{topKMask[b][v] = false} \\ sortedValue[b][v] & \text{topKMask[b][v] = true} & \end{cases}
    • 其中defLogit取决于入参约束属性input_is_logits,该属性控制输入Logits和输出logits_top_kp_select的归一化:
    defLogit={inf,inputIsLogits=true0,inputIsLogits=false\text{defLogit} = \begin{cases} -inf, & \text{inputIsLogits} = \text{true} \\ 0, & \text{inputIsLogits} = \text{false} \end{cases}

    TopP采样

    • 根据入参约束属性inputIsLogits,如果该属性为True,则对排序后结果进行归一化:

      logit_sortProb={softmax(logits_sort),inputIsLogits=Truelogits_sort,inputIsLogits=False\text{logit\_sortProb} = \begin{cases} \text{softmax}(\text{logits\_sort}), & \text{inputIsLogits} = \text{True} \\ \text{logits\_sort}, & \text{inputIsLogits} = \text{False} \end{cases}
    • 根据输入[object Object]的数值,本模块的处理策略如下:

      [object Object]undefined
    • 如果执行常规topP采样,且如果前序topK环节已有排序输出结果,则根据topK采样输出计算累积词频,并根据top_p截断采样:

      topPMask[b]={0,topKMask[b]logits_sortProb[b][]>p[b]1,topKMask[b]logits_sortProb[b][]p[b]topPMask[b] = \begin{cases} 0, & \sum_{\text{topKMask}[b]}^{} \text{logits\_sortProb}[b][*] > p[b] \\ 1, & \sum_{\text{topKMask}[b]}^{} \text{logits\_sortProb}[b][*] \leq p[b] \end{cases}
    • 如果执行常规topP采样,但前序topK环节被跳过,则计算top-p的mask:

      topPMask[b]={topKMask[b][0:GuessK],GuessKprobValue[b][]p[b]probSum[b][v]1p[b],otherstopPMask[b] = \begin{cases} topKMask[b][0:GuessK], & \sum_{\text{GuessK}}^{} probValue[b][*] \ge p[b] \\ probSum[b][v] \le 1 - p[b], & \text{others} \end{cases}
    • 将需要过滤的位置设置为默认无效值defLogit,得到logits_sort,记为sortedValue[b][v]:

    sortedValue[b][v]={defLogittopPMask[b][v]=falselogit_sortProb[b][v]topPMask[b][v]=truesortedValue[b][v] = \begin{cases} defLogit & \quad \text{topPMask}[b][v] = \text{false} \\ logit\_sortProb[b][v] & \quad \text{topPMask}[b][v] = \text{true} \end{cases}
    • 取过滤后sortedValue[b][v]每行中前topK个元素,查找这些元素在输入中的原始索引,整合为logits_idx:
    logitsIdx[b][v]=Index(sortedValue[b][v]Logits)logitsIdx[b][v] = Index(sortedValue[b][v] \in Logits)
    • 使用截断后的sortedValue作为logitsSortMasked:
    logitsSortMasked[b,:]=sortedValue[b]logitsSortMasked[b,:] = sortedValue[b]

    minP采样

    • 如果min_ps[b]∈(0, 1),则执行min_p采样:

      logitsMax[b]=Max(logitsSortMasked[b])\text{logitsMax}[b] = \text{Max}(\text{logitsSortMasked}[b]) minPThd=logitsMax[b]minPs[b]\text{minPThd} = \text{logitsMax}[b] * \text{minPs}[b] minPMask[b]={0,logitsSortMasked[b]<minPThd1,logitsSortMasked[b]minPThd\text{minPMask}[b] = \begin{cases} 0, & \text{logitsSortMasked}[b] < \text{minPThd} \\ 1, & \text{logitsSortMasked}[b] \geq \text{minPThd} \end{cases} logitsSortMasked[b,:]={defLogit,minPMask[b]=0logitsSortMasked[b,:],minPMask[b]=1\text{logitsSortMasked}[b,:] = \begin{cases} \text{defLogit}, & \text{minPMask}[b] = 0 \\ \text{logitsSortMasked}[b,:], & \text{minPMask}[b] = 1 \end{cases}
    • 其他情况:

      logitsSortMasked[b,:]={logitsSortMasked[b,:],if minPs[b]0max(logitsSortMasked[b,:]),if minPs[b]1\text{logitsSortMasked}[b, :] = \begin{cases} \text{logitsSortMasked}[b, :], & \text{if } minPs[b] \leq 0 \\ \max(\text{logitsSortMasked}[b, :]), & \text{if } minPs[b] \geq 1 \end{cases}

      min_ps[b]≥1时,每个batch仅取1个最大token,其余位置填充defLogit。

    可选输出

    • 如果​入参属性IsNeedLogits=True,则使用topK-topP-minP联合采样后的logitsIndexMasked,进行[object Object]输出。logitsIndex[b][v]=Index(logitsSortMasked[b][v]Logits)\text{logitsIndex}[b][v] = \text{Index}(\text{logitsSortMasked}[b][v] \in \text{Logits}) logitsIndexMasked[b,:]=logitsIndex[b,:]topKMask[b]topPMask[b]minPMask[b]\text{logitsIndexMasked}[b,:] = \text{logitsIndex}[b,:] * \text{topKMask}[b] * \text{topPMask}[b] * \text{minPMask}[b] 其中,topK、topP、minP采样环节如果被跳过,则相应mask为全1。
    • 接下来使用logitsIndexMasked对输入Logits进行Select,过滤输入Logits中的高频token作为[object Object]输出:logitsTopKpSelect[b][v]={logits[b][v],if logitsIndexMasked[b,v]=TruedefLogit,if logitsIndexMasked[b,v]=False\text{logitsTopKpSelect}[b][v] = \begin{cases} \text{logits}[b][v], & \text{if } logitsIndexMasked[b,v] = \text{True} \\ \text{defLogit}, & \text{if } logitsIndexMasked[b,v] = \text{False} \end{cases}

    后继处理

    • 此阶段输入为前序对前序topK-topP-minP采样的联合结果logitsSortMasked。

    • 此处输入须要确保logitsSortMasked∈(0,1),根据输入Logits的实际情况,配置入参约束属性inputIsLogits,即:

      inputIsLogits={True,Logits[0,1]False,Logits[0,1]\text{inputIsLogits} = \begin{cases} True, & \text{Logits} \notin [0,1] \\ False, & \text{Logits} \in [0,1] \end{cases}

      使得

      probs[b]=logitsSortMasked[b,:]\text{probs}[b] = \text{logitsSortMasked}[b, :]

      接下来有三种模式:None,QSample,输出中间结果,通过入参约束属性isNeedSampleResult和是否输入q加以控制。

    • None:

    • isNeedSampleResult为false,且不输入q时为该模式。该模式下直接对每个batch通过Argmax取最大元素和索引,并通过gatherOut输出。

      logitsSelectIdx[b]=LogitsIdx[b][ArgMax(probs[b][:])]\text{logitsSelectIdx}[b] = \text{LogitsIdx}[b]\left[\text{ArgMax}(\text{probs}[b][:])\right]
    • QSample:

    • isNeedSampleResult为false,且输入q时为该模式。该模式先对probs进行指数分布采样:

      qCnt=Sum(MinPMask==1)qCnt = \text{Sum}(\text{MinPMask} == 1) probsOpt[b]=probs[b]q[b,:qCnt]+eps\text{probsOpt}[b] = \frac{\text{probs}[b]}{q[b, :qCnt] + \text{eps}}
    • 再进行Argmax-GatherOut输出结果:

      logitsSelectIdx[b]=LogitsIdx[b][ArgMax(probsOpt[b][:])]\text{logitsSelectIdx}[b] = \text{LogitsIdx}[b][\text{ArgMax}(\text{probsOpt}[b][:])]
    • 输出中间结果:

    • isNeedSampleResult为true时,为该模式。此时会输出经过采样后的logitsSortMasked及其在输入中的原始索引logitsIdx:

      logitsSortMasked[b,v]={logitsSortMasked[b,v],if minPMask[b,v]=10,if minPMask[b,v]=0\text{logitsSortMasked}[b, v] = \begin{cases} \text{logitsSortMasked}[b, v], & \text{if } \text{minPMask}[b, v] = 1 \\ 0, & \text{if } \text{minPMask}[b, v] = 0 \end{cases} logitsIdx[b][v]=Index(logitsSortMasked[b][v])logitsIdx[b][v] = Index(logitsSortMasked[b][v])

函数原型

每个算子分为,必须先调用[object Object]接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用[object Object]接口执行计算。

[object Object]
[object Object]

aclnnTopKTopPSampleV2GetWorkspaceSize

  • 参数说明:

    [object Object]
  • 返回值:

    aclnnStatus:返回状态码,具体参见

    第一段接口完成入参校验,出现以下场景时报错:

    [object Object]

aclnnTopKTopPSampleV2

  • 参数说明:

    [object Object]
  • 返回值:

    aclnnStatus:返回状态码,具体参见

约束说明

  • 确定性计算:
    • aclnnTopKTopPSampleV2默认确定性实现。
  • 对于所有采样参数,它们的尺寸必须满足,batch>0,0<vocSize<=2^20。
  • topK只接受非负值作为合法输入;传入0和负数会跳过相应batch的采样。
  • logits、q、logitsTopKPselect、logitsIdx、logitsSortMasked的尺寸和维度必须完全一致。
  • logits、topK、topP、minPs、logitsSelectIdx、logitsIdx、logitsSortMasked除最后一维以外的所有维度必须顺序和大小完全一致。目前logits只能是2维,topK、topP、logitsSelectIdx必须是1维非空Tensor。logits、topK、topP不允许空Tensor作为输入,如需跳过相应模块,需按相应规则设置输入。
  • 如果需要单独跳过topK模块,请传入[batch, 1]大小的Tensor,并使每个元素均为无效值。
  • 如果min(ksMaxAligned, 1024)<topK[batch]<vocSize[batch],则视为选择当前batch的全部有效元素并跳过topK采样。其中ksMaxAligned为ksMax向上对齐到8的整数倍,ksMax的值域为[1, 1024]。
  • 如果需要单独跳过topP模块,请传入[batch, 1]大小的Tensor,并使每个元素均≥1。
  • 如果需要单独跳过minP模块,请传入[object Object]或者传入[batch, 1]大小的Tensor,并使每个元素均≤0。
  • 如果需要单独跳过sample模块,传入[object Object]即可;如需使用sample模块,则必须传入尺寸为[batch, vocSize]的Tensor。
  • 如果需要输出中间结果,isNeedSampleResult设为true,并且传入[object Object],此时logitsSelectIdx不输出。

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考

[object Object]