昇腾社区首页
中文
注册

aclnnMoeInitRoutingV2

支持的产品型号

  • Atlas A2训练系列产品/Atlas 800I A2推理产品

接口原型

每个算子分为,必须先调用 “aclnnMoeInitRoutingV2GetWorkspaceSize”接口获取入参并根据计算流程计算所需workspace大小获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnMoeInitRoutingV2”接口执行计算。

  • aclnnStatus aclnnMoeInitRoutingV2GetWorkspaceSize(const aclTensor *x, const aclTensor *expertIdx, int64_t activeNum, int64_t expertCapacity, int64_t expertNum, int64_t dropPadMode, int64_t expertTokensCountOrCumsumFlag, bool expertTokensBeforeCapacityFlag, aclTensor *expandedXOut, aclTensor *expandedRowIdxOut, aclTensor *expertTokensCountOrCumsum, aclTensor *expertTokensBeforeCapacity, uint64_t *workspaceSize, aclOpExecutor **executor)
  • aclnnStatus aclnnMoeInitRoutingV2(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)

功能描述

  • 算子功能:MoE的routing计算,根据的计算结果做routing处理。

    本接口针对做了如下功能变更,请根据实际情况选择合适的接口:

    • 新增drop模式,在该模式下输出内容会根据每个专家的expertCapacity处理,超过expertCapacity不做处理,不足的会补0。
    • 新增dropless模式下expertTokensCountOrCumsum可选输出,drop场景下expertTokensBeforeCapacity可选输出。
    • 删除rowIdx输入。
  • 计算公式

    1.对输入expertIdx做排序,得出排序后的结果sortedExpertIdx和对应的序号sortedRowIdx

    sortedExpertIdx,sortedRowIdx=keyValueSort(expertIdx)sortedExpertIdx, sortedRowIdx=keyValueSort(expertIdx)

    2.以sortedRowIdx做位置映射得出expandedRowIdx

    expandedRowIdx[sortedRowIdx[i]]=iexpandedRowIdx[sortedRowIdx[i]]=i

    3.对x取前numRows个sortedRowIdx的对应位置的值,得出expandedX

    expandedX[i]=x[sortedRowIdx[i]%numRows]expandedX[i]=x[sortedRowIdx[i]\%numRows]

    4.对sortedExpertIdx的每个专家统计直方图结果,再进行Cumsum,得出expertTokensCountOrCumsum

    expertTokensCountOrCumsum[i]=Cumsum(Histogram(sortedExpertIdx))expertTokensCountOrCumsum[i]=Cumsum(Histogram(sortedExpertIdx))

    5.对sortedExpertIdx的每个专家统计直方图结果,得出expertTokensBeforeCapacity

    expertTokensBeforeCapacity[i]=Histogram(sortedExpertIdx)expertTokensBeforeCapacity[i]=Histogram(sortedExpertIdx)

aclnnMoeInitRoutingV2GetWorkspaceSize

  • 参数说明

    • x(aclTensor*,计算输入):MOE的输入即token特征输入,要求为一个2D的Tensor,shape为 (NUM_ROWS, H),H代表每个Token的长度,数据类型支持FLOAT16、BFLOAT16、FLOAT32,要求为ND。
    • expertIdx (aclTensor*,计算输入):的输出每一行特征对应的K个处理专家,要求是一个2D的shape (NUM_ROWS, K)。数据类型支持int32,要求为ND。
    • activeNum(int64_t,计算输入):表示总的最大处理row数,expandedXOut只有这么多行是有效的,0表示不使用该参数,该值必须大于等于0。
    • expertCapacity(int64_t, 计算输入):表示每个专家能够处理的row数,超过capacity的row不会处理, 0表示不使用该参数,该值必须大于等于0。
    • expertNum(int64_t, 计算输入):表示专家数,0表示不使用该参数,该值必须大于等于0。
    • dropPadMode(int64_t, 计算输入):表示使用不同的场景,取值为0和1,0代表dropless场景,该场景下不校验expertCapacity;1代表drop场景,需要校验expertNum和expertCapacity,对于每个专家处理的超过和不足expertCapacity的值会做相应的处理。默认值为0。
    • expertTokensCountOrCumsumFlag(int64_t, 计算输入):取值为0、1和2,0代表不输出expertTokensCountOrCumsum,1代表输出的值为各个专家处理的token数量的累计值,2代表输出的值为各个专家处理的token数量。
    • expertTokensBeforeCapacityFlag(bool,计算输入):取值为0和1,0代表不输出expertTokensBeforeCapacity,1代表输出的值为在drop之前各个专家处理的token数量。
    • expandedXOut(aclTensor*,计算输出):根据expertIdx进行扩展过的特征,在dropless场景下要求是一个2D的Tensor,shape (min(NUM_ROWS * k, activeNum), H),在drop场景下要求是一个3D的Tensor,shape(expertNum, expertCapacity, H)。数据类型同x,支持FLOAT16、BFLOAT16、FLOAT32,要求为ND。
    • expandedRowIdxOut(aclTensor*,计算输出):expandedXOut和x的映射关系, 要求是一个1D的Tensor,Shape为(NUM_ROWS*K, ),数据类型支持int32,要求为ND。
    • expertTokensCountOrCumsum(aclTensor*,计算输出):输出每个专家处理的token数量的统计结果及累加值,通过expertTokensCountOrCumsumFlag参数控制是否输出,该值仅在dropless场景下输出,要求是一个1D的Tensor,Shape为(expertNum, ),数据类型支持int32,要求为ND。
    • expertTokensBeforeCapacity(aclTensor*,计算输出):输出drop之前每个专家处理的token数量的统计结果,通过expertTokensBeforeCapacityFlag参数控制是否输出,该值仅在drop场景下输出,要求是一个1D的Tensor,Shape为(expertNum, ),数据类型支持int32,要求为ND。
    • workspaceSize(uint64_t*,出参):返回用户需要在Device侧申请的workspace大小。
    • executor(aclOpExecutor**,出参):返回op执行器,包含了算子计算流程。
  • 返回值

    返回aclnnStatus状态码,具体参见

    [object Object]

aclnnMoeInitRoutingV2

  • 参数说明:

    • workspace(void*,入参):在Device侧申请的workspace内存地址。
    • workspaceSize(uint64_t,入参):在Device侧申请的workspace大小,由第一段接口aclnnMoeInitRoutingV2GetWorkspaceSize获取。
    • executor(aclOpExecutor*,入参):op执行器,包含了算子计算流程。
    • stream(aclrtStream,入参):指定执行任务的AscendCL stream流。
  • 返回值:

    返回aclnnStatus状态码,具体参见

约束与限制

无。

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考

[object Object]