aclnnAlltoAllvGroupedMatMul
产品支持情况
产品 | 是否支持 |
---|---|
[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object] | √ |
[object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object] | × |
[object Object]Atlas 200I/500 A2 推理产品[object Object] | × |
[object Object]Atlas 推理系列产品 [object Object] | × |
[object Object]Atlas 训练系列产品[object Object] | × |
功能说明
算子功能:完成路由专家AlltoAllv、Permute、GroupedMatMul融合并实现与共享专家MatMul并行融合,先通信后计算。
计算公式:
路由专家:
共享专家:
函数原型
每个算子分为undefined,必须先调用“aclnnAlltoAllvGroupedMatMulGetWorkspaceSize”接口获取入参并根据计算流程计算所需workspace大小,再调用“aclnnAlltoAllvGroupedMatMul”接口执行计算。
aclnnStatus aclnnAlltoAllvGroupedMatMulGetWorkspaceSize(const aclTensor* gmmX, const aclTensor* gmmWeight, const aclTensor* sendCountsTensorOptional, const aclTensor* recvCountsTensorOptional, const aclTensor* mmXOptional, const aclTensor* mmWeightOptional, const char* group, int64_t epWorldSize, const aclIntArray* sendCounts, const aclIntArray* recvCounts, bool transGmmWeight, bool transMmWeight, bool permuteOutFlag, aclTensor* gmmY, aclTensor* mmYOptional, aclTensor* permuteOutOptional, uint64_t* workspaceSize, aclOpExecutor** executor)
aclnnStatus aclnnAlltoAllvGroupedMatMul(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)
aclnnAlltoAllvGroupedMatMulGetWorkspaceSize
参数说明:
- gmmX(aclTensor*,计算输入):该输入进行AlltoAllv通信与Permute操作后结果作为GroupedMatMul计算的左矩阵。数据类型支持FLOAT16、BFLOAT16,支持2维,shape为(BSK, H1),数据格式支持ND。
- gmmWeight(aclTensor*,计算输入):GroupedMatMul计算的右矩阵。数据类型与gmmX保持一致,支持3维,shape为(e, H1, N1),数据格式支持ND。
- sendCountsTensorOptional(aclTensor*,计算输入):可选输入,数据类型支持INT32、INT64,shape为(e * epWorldSize,),数据格式支持ND。当前版本暂不支持,传nullptr。
- recvCountsTensorOptional(aclTensor*,计算输入):可选输入,数据类型支持INT32、INT64,shape为(e * epWorldSize,),数据格式支持ND。当前版本暂不支持,传nullptr。
- mmXOptional(aclTensor*,计算输入):可选输入,共享专家MatMul计算中的左矩阵。当需要融合共享专家矩阵计算时,该参数必选,支持2维,shape为(BS, H2)。
- mmWeightOptional(aclTensor*,计算输入):可选输入,共享专家MatMul计算中的右矩阵。当需要融合共享专家矩阵计算时,该参数必选,支持2维,shape为(H2, N2)。
- group(char*,计算输入):专家并行的通信域名,字符串长度要求(0, 128)。
- epWorldSize(int64_t,计算输入):
- [object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]:ep通信域size,取值支持8、16、32、64。
- sendCounts(aclIntArray*,计算输入):表示发送给其他卡的token数,数据类型支持INT64,取值大小为e * epWorldSize,最大为256。
- recvCounts(aclIntArray*,计算输入):表示接收其他卡的token数,数据类型支持INT64,取值大小为e * epWorldSize,最大为256。
- transGmmWeight(bool, 计算输入):GroupedMatMul的右矩阵是否需要转置,true表示需要转置,false表示不转置。
- transMmWeight(bool, 计算输入):共享专家MatMul的右矩阵是否需要转置,true表示需要转置,false表示不转置。
- permuteOutFlag(bool, 计算输入):permuteOutOptional是否需要输出,true表明需要输出,false表明不需要输出。
- gmmY(aclTensor*, 计算输出):表示最终的计算结果,数据类型与输入gmmX保持一致,支持2维,shape为(A, N1)。
- mmYOptional(aclTensor*, 计算输出):共享专家MatMul的输出,数据类型与mmXOptional保持一致,支持2维,shape为(BS, N2)。仅当传入mmXOptional与mmWeightOptional才输出。
- permuteOutOptional(aclTensor*, 计算输出):permute之后的输出,数据类型与gmmX保持一致。
- workspaceSize(uint64_t*, 出参):返回需要在Device侧申请的workspace大小。
- executor(aclOpExecutor**, 出参):返回op执行器,包含了算子计算流程。
返回值:
aclnnStatus:返回状态码,具体参见undefined。
[object Object]
aclnnAlltoAllvGroupedMatMul
参数说明:
- workspace(void*,入参):在Device侧申请的workspace内存地址。
- workspaceSize(uint64_t,入参):在Device侧申请的workspace大小,由第一段接口aclnnAlltoAllvGroupedMatMulGetWorkspaceSize获取。
- executor(aclOpExecutor*,入参):op执行器,包含了算子计算流程。
- stream(aclrtStream,入参):指定执行任务的Stream。
返回值:
返回aclnnStatus状态码,具体参见undefined。
约束说明
- 参数说明里shape使用的变量:
- BSK:本卡发送的token数,是sendCounts参数累加之和,取值范围(0, 52428800)。
- H1:表示路由专家hidden size隐藏层大小,取值范围(0, 65536)。
- H2:表示共享专家hidden size隐藏层大小,取值范围(0, 12288]。
- e:表示单卡上专家个数,e<=32,e * epWorldSize最大支持256。
- N1:表示路由专家的head_num,取值范围(0, 65536)。
- N2:表示共享专家的head_num,取值范围(0, 65536)。
- BS:batch sequence size。
- K:表示选取TopK个专家,K的范围[2, 8]。
- A:本卡收到的token数,是recvCounts参数累加之和。
- ep通信域内所有卡的 A 参数的累加和等于所有卡上的 BSK 参数的累加和。
- 单卡通信量要求在2MB到100MB范围内。
调用示例
示例代码如下,仅供参考,具体编译和执行过程请参考undefined。