昇腾社区首页
中文
注册

aclnnAlltoAllvGroupedMatMul

产品支持情况

产品 是否支持
[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]
[object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object] ×
[object Object]Atlas 200I/500 A2 推理产品[object Object] ×
[object Object]Atlas 推理系列产品 [object Object] ×
[object Object]Atlas 训练系列产品[object Object] ×

功能说明

  • 算子功能:完成路由专家AlltoAllv、Permute、GroupedMatMul融合并实现与共享专家MatMul并行融合,先通信后计算

  • 计算公式

    • 路由专家:

      ataOut=AlltoAllv(gmmX)permuteOut=Permute(ataOut)gmmY=permuteOut×gmmWeightataOut = AlltoAllv(gmmX) \\ permuteOut = Permute(ataOut) \\ gmmY = permuteOut \times gmmWeight
    • 共享专家:

      mmY=mmX×mmWeightmmY = mmX \times mmWeight

函数原型

每个算子分为undefined,必须先调用“aclnnAlltoAllvGroupedMatMulGetWorkspaceSize”接口获取入参并根据计算流程计算所需workspace大小,再调用“aclnnAlltoAllvGroupedMatMul”接口执行计算。

  • aclnnStatus aclnnAlltoAllvGroupedMatMulGetWorkspaceSize(const aclTensor* gmmX, const aclTensor* gmmWeight, const aclTensor* sendCountsTensorOptional, const aclTensor* recvCountsTensorOptional, const aclTensor* mmXOptional, const aclTensor* mmWeightOptional, const char* group, int64_t epWorldSize, const aclIntArray* sendCounts, const aclIntArray* recvCounts, bool transGmmWeight, bool transMmWeight, bool permuteOutFlag, aclTensor* gmmY, aclTensor* mmYOptional, aclTensor* permuteOutOptional, uint64_t* workspaceSize, aclOpExecutor** executor)
  • aclnnStatus aclnnAlltoAllvGroupedMatMul(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)

aclnnAlltoAllvGroupedMatMulGetWorkspaceSize

  • 参数说明:

    • gmmX(aclTensor*,计算输入):该输入进行AlltoAllv通信与Permute操作后结果作为GroupedMatMul计算的左矩阵。数据类型支持FLOAT16、BFLOAT16,支持2维,shape为(BSK, H1),数据格式支持ND。
    • gmmWeight(aclTensor*,计算输入):GroupedMatMul计算的右矩阵。数据类型与gmmX保持一致,支持3维,shape为(e, H1, N1),数据格式支持ND。
    • sendCountsTensorOptional(aclTensor*,计算输入):可选输入,数据类型支持INT32、INT64,shape为(e * epWorldSize,),数据格式支持ND。当前版本暂不支持,传nullptr
    • recvCountsTensorOptional(aclTensor*,计算输入):可选输入,数据类型支持INT32、INT64,shape为(e * epWorldSize,),数据格式支持ND。当前版本暂不支持,传nullptr
    • mmXOptional(aclTensor*,计算输入):可选输入,共享专家MatMul计算中的左矩阵。当需要融合共享专家矩阵计算时,该参数必选,支持2维,shape为(BS, H2)。
    • mmWeightOptional(aclTensor*,计算输入):可选输入,共享专家MatMul计算中的右矩阵。当需要融合共享专家矩阵计算时,该参数必选,支持2维,shape为(H2, N2)。
    • group(char*,计算输入):专家并行的通信域名,字符串长度要求(0, 128)。
    • epWorldSize(int64_t,计算输入):
      • [object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:ep通信域size,取值支持8、16、32、64。
    • sendCounts(aclIntArray*,计算输入):表示发送给其他卡的token数,数据类型支持INT64,取值大小为e * epWorldSize,最大为256。
    • recvCounts(aclIntArray*,计算输入):表示接收其他卡的token数,数据类型支持INT64,取值大小为e * epWorldSize,最大为256。
    • transGmmWeight(bool, 计算输入):GroupedMatMul的右矩阵是否需要转置,true表示需要转置,false表示不转置。
    • transMmWeight(bool, 计算输入):共享专家MatMul的右矩阵是否需要转置,true表示需要转置,false表示不转置。
    • permuteOutFlag(bool, 计算输入):permuteOutOptional是否需要输出,true表明需要输出,false表明不需要输出。
    • gmmY(aclTensor*, 计算输出):表示最终的计算结果,数据类型与输入gmmX保持一致,支持2维,shape为(A, N1)。
    • mmYOptional(aclTensor*, 计算输出):共享专家MatMul的输出,数据类型与mmXOptional保持一致,支持2维,shape为(BS, N2)。仅当传入mmXOptional与mmWeightOptional才输出。
    • permuteOutOptional(aclTensor*, 计算输出):permute之后的输出,数据类型与gmmX保持一致。
    • workspaceSize(uint64_t*, 出参):返回需要在Device侧申请的workspace大小。
    • executor(aclOpExecutor**, 出参):返回op执行器,包含了算子计算流程。
  • 返回值:

    aclnnStatus:返回状态码,具体参见undefined

    [object Object]

aclnnAlltoAllvGroupedMatMul

  • 参数说明:

    • workspace(void*,入参):在Device侧申请的workspace内存地址。
    • workspaceSize(uint64_t,入参):在Device侧申请的workspace大小,由第一段接口aclnnAlltoAllvGroupedMatMulGetWorkspaceSize获取。
    • executor(aclOpExecutor*,入参):op执行器,包含了算子计算流程。
    • stream(aclrtStream,入参):指定执行任务的Stream。
  • 返回值:

    返回aclnnStatus状态码,具体参见undefined

约束说明

  • 参数说明里shape使用的变量:
    • BSK:本卡发送的token数,是sendCounts参数累加之和,取值范围(0, 52428800)。
    • H1:表示路由专家hidden size隐藏层大小,取值范围(0, 65536)。
    • H2:表示共享专家hidden size隐藏层大小,取值范围(0, 12288]。
    • e:表示单卡上专家个数,e<=32,e * epWorldSize最大支持256。
    • N1:表示路由专家的head_num,取值范围(0, 65536)。
    • N2:表示共享专家的head_num,取值范围(0, 65536)。
    • BS:batch sequence size。
    • K:表示选取TopK个专家,K的范围[2, 8]。
    • A:本卡收到的token数,是recvCounts参数累加之和。
    • ep通信域内所有卡的 A 参数的累加和等于所有卡上的 BSK 参数的累加和。
  • 单卡通信量要求在2MB到100MB范围内。

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考undefined