aclnnMoeTokenUnpermuteWithRoutingMap

Supported Products

Function Description

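  • Function: Accumulates the permutedTokens produced by aclnnMoeTokenPermuteWithRoutingMap back into the original unpermutedTokens. The operator uses the indices stored in sortedIndices to fetch the input rows from permutedTokens. If probs is provided, permutedTokens is multiplied by probs. The rows are then summed, and the result is output.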

  • Formulas:

    $$
    \begin{aligned}
    topK\_num &= permutedTokens.size(0) \mathbin{//} routingMapOptional.size(0) \\
    numExperts &= probs.size(1) \\
    numTokens &= probs.size(0) \\
    capacity &= sortedIndices.size(0) \mathbin{//} numExperts
    \end{aligned}
    $$
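The size relations above are plain integer arithmetic on tensor shapes. The sketch below mirrors them with hypothetical helper names (`probs_dims`, `topk_num`, `capacity`) and assumed example shapes; it is an illustration, not part of the aclnn API.

```python
# Hypothetical helpers mirroring the size formulas; shapes are plain ints.

def probs_dims(probs_shape):
    # numTokens = probs.size(0), numExperts = probs.size(1)
    num_tokens, num_experts = probs_shape
    return num_tokens, num_experts


def topk_num(permuted_rows, routing_map_rows):
    # topK_num = permutedTokens.size(0) // routingMapOptional.size(0)
    return permuted_rows // routing_map_rows


def capacity(sorted_indices_len, num_experts):
    # capacity = sortedIndices.size(0) // numExperts
    return sorted_indices_len // num_experts


num_tokens, num_experts = probs_dims((4, 8))   # probs: 4 tokens x 8 experts
print(topk_num(8, num_tokens))                 # top-2 routing: 8 permuted rows -> 2
print(capacity(24, num_experts))               # 24 padded slots / 8 experts -> 3
```

topK_num is meaningful in non-padded mode and capacity in padded mode, so a given call would normally use only one of them.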

    (1) probs is not None and paddedMode is true:

    $$
    \begin{aligned}
    permuteProbs[i \mathbin{//} capacity,\ sortedIndices[i]] &= probs[i] \\
    permutedTokens &= permutedTokens * permuteProbs \\
    unpermutedTokens &= zeros(restoreShape,\ dtype=permutedTokens.dtype,\ device=permutedTokens.device) \\
    permuteTokenId,\ outIndex &= sortedIndices.sort(dim=-1) \\
    unpermutedTokens[permuteTokenId[i]] &\mathrel{+}= permutedTokens[outIndex[i]]
    \end{aligned}
    $$
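Case (1) can be sketched with nested Python lists standing in for tensors. The function name is hypothetical, and the gather step encodes an assumed reading of the formula: slot i belongs to expert i // capacity and holds token sortedIndices[i], so it is scaled by that token's probability for that expert (probs[sortedIndices[i]][i // capacity]). This is a reference sketch under those assumptions, not the operator's kernel.

```python
def unpermute_padded_with_probs(permuted_tokens, sorted_indices, probs, restore_shape):
    """Sketch of case (1): paddedMode=True, probs given (lists, not tensors)."""
    num_tokens, hidden = restore_shape
    num_experts = len(probs[0])                      # probs.size(1)
    cap = len(sorted_indices) // num_experts         # capacity
    # Assumed reading: slot i -> expert i // cap, token sorted_indices[i].
    scaled = [
        [x * probs[sorted_indices[i]][i // cap] for x in permuted_tokens[i]]
        for i in range(len(sorted_indices))
    ]
    # sortedIndices.sort(dim=-1): out_index lists slots in ascending token
    # order; permute_token_id lists the matching token ids.
    out_index = sorted(range(len(sorted_indices)), key=lambda i: sorted_indices[i])
    permute_token_id = [sorted_indices[i] for i in out_index]
    unpermuted = [[0.0] * hidden for _ in range(num_tokens)]   # zeros(restoreShape)
    for tok, src in zip(permute_token_id, out_index):
        for h in range(hidden):
            unpermuted[tok][h] += scaled[src][h]
    return unpermuted


# 2 tokens, 2 experts, capacity 2, hidden size 1.
# Expert 0 slots hold tokens [0, 1]; expert 1 slots hold tokens [1, 0].
result = unpermute_padded_with_probs(
    permuted_tokens=[[1.0], [2.0], [3.0], [4.0]],
    sorted_indices=[0, 1, 1, 0],
    probs=[[0.5, 0.5], [0.25, 0.75]],
    restore_shape=(2, 1),
)
print(result)  # [[2.5], [2.75]]
```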

    (2) probs is not None and paddedMode is false (T denotes transposition):

    $$
    \begin{aligned}
    permuteProbs &= probs.T.maskedSelect(routingMap.T) \\
    permutedTokens &= permutedTokens * permuteProbs \\
    unpermutedTokens &= zeros(restoreShape,\ dtype=permutedTokens.dtype,\ device=permutedTokens.device) \\
    unpermutedTokens[i \mathbin{//} topK\_num] &\mathrel{+}= permutedTokens[sortedIndices[i]]
    \end{aligned}
    $$
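A list-based sketch of case (2) follows; the function name is hypothetical. It assumes, as the row-wise multiplication in the formula implies, that permuted row j pairs with the j-th probability produced by the expert-major masked select, and that entry i of sortedIndices refers to token i // topK_num.

```python
def unpermute_unpadded_with_probs(permuted_tokens, sorted_indices, probs,
                                  routing_map, restore_shape):
    """Sketch of case (2): paddedMode=False, probs given (lists, not tensors)."""
    num_tokens, hidden = restore_shape
    num_experts = len(routing_map[0])
    topk = len(permuted_tokens) // len(routing_map)   # topK_num
    # probs.T.maskedSelect(routingMap.T): walk expert-major and keep the
    # probabilities of routed (token, expert) pairs.
    permute_probs = [
        probs[t][e]
        for e in range(num_experts)
        for t in range(num_tokens)
        if routing_map[t][e]
    ]
    # permutedTokens * permuteProbs: one scale factor per permuted row.
    scaled = [[x * p for x in row] for row, p in zip(permuted_tokens, permute_probs)]
    unpermuted = [[0.0] * hidden for _ in range(num_tokens)]
    for i, src in enumerate(sorted_indices):
        for h in range(hidden):
            unpermuted[i // topk][h] += scaled[src][h]
    return unpermuted


# 2 tokens, 3 experts, top-2: token 0 -> experts {0, 2}, token 1 -> experts {1, 2}.
print(unpermute_unpadded_with_probs(
    permuted_tokens=[[1.0], [2.0], [3.0], [4.0]],
    sorted_indices=[0, 2, 1, 3],
    probs=[[0.5, 0.0, 0.5], [0.0, 0.25, 0.75]],
    routing_map=[[True, False, True], [False, True, True]],
    restore_shape=(2, 1),
))  # [[2.0], [3.5]]
```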

    (3) probs is None and paddedMode is true:

    $$
    \begin{aligned}
    permuteTokenId,\ outIndex &= sortedIndices.sort(dim=-1) \\
    unpermutedTokens[permuteTokenId[i]] &\mathrel{+}= permutedTokens[outIndex[i]]
    \end{aligned}
    $$
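Without probs, case (3) reduces to adding each padded slot into the token it came from; the sort only fixes the order in which rows reach each token. Because floating-point addition is not associative, a fixed order is one plausible reason for the sort and fits the determinism note under the constraints, though the source does not state the reason. The function name below is hypothetical.

```python
def unpermute_padded_no_probs(permuted_tokens, sorted_indices, restore_shape):
    """Sketch of case (3): paddedMode=True, probs=None (lists, not tensors)."""
    num_tokens, hidden = restore_shape
    # sortedIndices.sort(dim=-1). Python's sort is stable, so slots that map to
    # the same token are always added in the same (ascending slot) order.
    out_index = sorted(range(len(sorted_indices)), key=lambda i: sorted_indices[i])
    unpermuted = [[0.0] * hidden for _ in range(num_tokens)]
    for src in out_index:
        tok = sorted_indices[src]   # permuteTokenId[i]
        for h in range(hidden):
            unpermuted[tok][h] += permuted_tokens[src][h]
    return unpermuted


print(unpermute_padded_no_probs([[1.0], [2.0], [3.0], [10.0]], [0, 1, 1, 0], (2, 1)))
# [[11.0], [5.0]]
```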

    (4) probs is None and paddedMode is false:

    $$
    unpermutedTokens[i \mathbin{//} topK\_num] \mathrel{+}= permutedTokens[sortedIndices[i]]
    $$
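Case (4) is a pure gather-and-sum: the topK_num consecutive entries of sortedIndices that belong to a token point at that token's permuted copies, which are added together. In this hypothetical sketch, routingMapOptional.size(0) is taken from restore_shape[0], assuming both equal numTokens.

```python
def unpermute_unpadded_no_probs(permuted_tokens, sorted_indices, restore_shape):
    """Sketch of case (4): paddedMode=False, probs=None (lists, not tensors)."""
    num_tokens, hidden = restore_shape
    # Assumption: routingMapOptional.size(0) == numTokens == restore_shape[0].
    topk = len(permuted_tokens) // num_tokens        # topK_num
    unpermuted = [[0.0] * hidden for _ in range(num_tokens)]
    for i, src in enumerate(sorted_indices):
        # Entries i = t*topK ... t*topK + topK - 1 all belong to token t.
        for h in range(hidden):
            unpermuted[i // topk][h] += permuted_tokens[src][h]
    return unpermuted


print(unpermute_unpadded_no_probs(
    [[1.0, 0.0], [2.0, 1.0], [3.0, 0.0], [4.0, 1.0]], [0, 2, 1, 3], (2, 2)))
# [[4.0, 0.0], [6.0, 2.0]]
```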

Function Prototype

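Each operator uses a two-phase interface: first call aclnnMoeTokenUnpermuteWithRoutingMapGetWorkspaceSize to obtain the workspace size required for the computation and an executor that contains the operator's computation flow, then call aclnnMoeTokenUnpermuteWithRoutingMap to perform the computation.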

aclnnMoeTokenUnpermuteWithRoutingMapGetWorkspaceSize

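  • Parameters

  • Return value

    aclnnStatus: the returned status code. For details, see the aclnn return codes.

    The first-phase interface validates the input parameters and reports an error if validation fails.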

aclnnMoeTokenUnpermuteWithRoutingMap

  • Parameters

  • Return value

    aclnnStatus: the returned status code. For details, see the aclnn return codes.

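Constraints

  • Deterministic computation:

    • aclnnMoeTokenUnpermuteWithRoutingMap uses a deterministic implementation by default.
  • topkNum <= 512. When pad mode is false, every row of routingMap must contain the same number of 1 (true) entries, and that number must be less than …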
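The two checkable constraints above can be expressed as a small pre-check on the host side. This is a hypothetical sketch (`check_constraints` is not an aclnn function): it tests topkNum <= 512 and, in non-pad mode, that all routingMap rows select the same number of experts. It does not check the per-row upper bound, whose value is not given above.

```python
MAX_TOPK_NUM = 512  # topkNum <= 512


def check_constraints(routing_map, topk_num, padded_mode):
    """Hypothetical pre-check; returns a list of violated constraints."""
    errors = []
    if topk_num > MAX_TOPK_NUM:
        errors.append(f"topkNum {topk_num} exceeds {MAX_TOPK_NUM}")
    if not padded_mode:
        # The count of 1/true entries must be identical across routingMap rows.
        counts = {sum(1 for v in row if v) for row in routing_map}
        if len(counts) > 1:
            errors.append(f"routingMap rows have differing counts: {sorted(counts)}")
    return errors


print(check_constraints([[True, False, True], [False, True, False]], 2, False))
# ['routingMap rows have differing counts: [1, 2]']
```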

Invocation Example

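Sample code for this operator is provided for reference only. For the detailed compilation and execution process, see the CANN guide on compiling and running samples.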