昇腾社区首页
中文
注册
开发者
下载

aclnnFusedLinearOnlineMaxSum

产品支持情况

[object Object]undefined

功能说明

  • 接口功能:

    功能等价Megatron的matmul与fused_vocab_parallel_cross_entropy的实现,支持vocabulary_size维度切卡融合matmul与celoss,中间根据通信拆分为,需要依次调用实现完整功能。

  • 计算公式:

    1. inputinputweightTweight^T做矩阵乘得到:

      vocabParallelLogitsOutOptional=input@weightTvocabParallelLogitsOutOptional = input @ weight^T
    2. 计算vocabParallelLogitsOutOptionalvocabParallelLogitsOutOptional每行的最大值:

      logitsMaxLocalOut=max(vocabParallelLogitsOutOptional,dim=1)logitsMaxLocalOut = max(vocabParallelLogitsOutOptional, dim=-1)
    3. 计算vocabParallelLogitsOutOptionalvocabParallelLogitsOutOptionallogitsMaxLocalOutlogitsMaxLocalOut的差值:

      subRes[b][n]=vocabParallelLogitsOutOptional[b][n]logitsMaxLocalOut[b]subRes[b][n] = vocabParallelLogitsOutOptional[b][n] - logitsMaxLocalOut[b]
    4. 计算subRessubRes经过指数运算后每行的和

      sumExpLogitsLocalOut=sum(exp(subRes),dim=1)sumExpLogitsLocalOut = sum(exp(subRes), dim=-1)
    5. 计算targettarget小于vocabStartIndexvocabStartIndextargettarget大于vocabEndIndexvocabEndIndex的mask

      targetMask=(target<vocabStartIndex)(target>vocabEndIndex)targetMask = (target < vocabStartIndex) | (target > vocabEndIndex)
    6. 计算maskedTargetOutmaskedTargetOut

      maskedTargetOut[b]={0targetMask[b]=truetarget[b]vocabStartIndextargetMask[b]=falsemaskedTargetOut[b] = \begin{cases} 0 & \text{targetMask[b]=true}\\ target[b] - vocabStartIndex & \text{targetMask[b]=false} \end{cases}
    7. 计算predictedLogitsLocalOutpredictedLogitsLocalOut

      predictedLogitsLocalOut[b]={0targetMask[b]=truesubRes[b][maskedTargetOut[b]]targetMask[b]=falsepredictedLogitsLocalOut[b] = \begin{cases} 0 & \text{targetMask[b]=true}\\ subRes[b][maskedTargetOut[b]] & \text{targetMask[b]=false} \end{cases}
    8. 计算targetMaskOuttargetMaskOut

      alignNum=(input.size(0)+7)/88maskBit[p]={uint8(targetMask[p])p < input.size(0)1input.size(0) <= p < alignNumtargetMaskOut[k]=0b(maskBit[8k:8k+8])alignNum = (input.size(0) + 7) / 8 * 8\\ maskBit[p] = \begin{cases} uint8(targetMask[p]) & \text{p < input.size(0)}\\ 1 & \text{input.size(0) <= p < alignNum} \end{cases} \\ targetMaskOut[k] = 0b(maskBit[8*k:8*k+8])

    其中0b<input.size(0),0n<weight.size(0),0p<alignNum,0k<alignNum/80 \le b \lt input.size(0), 0 \le n \lt weight.size(0), 0 \le p \lt alignNum, 0 \le k \lt alignNum / 8

函数原型

每个算子分为,必须先调用“aclnnFusedLinearOnlineMaxSumGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnFusedLinearOnlineMaxSum”接口执行计算。

[object Object]
[object Object]

aclnnFusedLinearOnlineMaxSumGetWorkspaceSize

  • 参数说明:

    [object Object]
  • 返回值:

    aclnnStatus:返回状态码,具体参见

    第一段接口会完成入参校验,出现以下场景时报错:

    [object Object]

aclnnFusedLinearOnlineMaxSum

  • 参数说明:

    [object Object]
  • 返回值:

    aclnnStatus:返回状态码,具体参见

约束说明

  • 确定性说明:
    • [object Object]Atlas 训练系列产品[object Object]、[object Object]Atlas 推理系列产品[object Object]:aclnnFusedLinearOnlineMaxSum默认确定性实现。

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考

[object Object]