aclnnMoeDistributeCombineAddRmsNorm
支持的产品型号
Atlas A3 训练系列产品/Atlas A3 推理系列产品
功能说明
算子功能:当存在TP域通信时,先进行ReduceScatterV通信,再进行AlltoAllV通信,最后将接收的数据整合(乘权重再相加);当不存在TP域通信时,进行AlltoAllV通信,最后将接收的数据整合(乘权重再相加),之后完成Add + RmsNorm融合。
函数原型
每个算子分为两段式接口,必须先调用 “aclnnMoeDistributeCombineAddRmsNormGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnMoeDistributeCombineAddRmsNorm”接口执行计算。
aclnnStatus aclnnMoeDistributeCombineAddRmsNormGetWorkspaceSize(const aclTensor* expandX, const aclTensor* expertIds, const aclTensor* assitInfoForCombine, const aclTensor* epSendCounts, const aclTensor* expertScales, const aclTensor* tpSendCountsOptional, const aclTensor* xActiveMaskOptional, const aclTensor* activationScaleOptional, const aclTensor* weightScaleOptional, const aclTensor* groupListOptional, const aclTensor* expandScalesOptional, const aclTensor* sharedExpertXOptional, const char* groupEp, int64_t epWorldSize, int64_t epRankId, int64_t moeExpertNum, const char* groupTp, int64_t tpWorldSize, int64_t tpRankId, int64_t expertShardType, int64_t sharedExpertNum, int64_t sharedExpertRankNum, int64_t globalBs, int64_t outDtype, int64_t commQuantMode, int64_t groupListType, const char* commAlg, aclTensor* xOut, uint64_t* workspaceSize, aclOpExecutor** executor)
aclnnStatus aclnnMoeDistributeCombineAddRmsNorm(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
计算公式:
注意该接口必须与aclnnMoeDistributeDispatchV2配套使用,相当于按MoeDistributeDispatchV2算子收集数据的路径原路返还。
aclnnMoeDistributeCombineAddRmsNormGetWorkspaceSize
参数说明:
- expandX(aclTensor*,计算输入):根据expertIds进行扩展过的token特征,Device侧的aclTensor,要求为一个2D的Tensor,shape为 (max(tpWorldSize, 1) * A , H),数据类型支持BFLOAT16,数据格式要求为ND,支持非连续的Tensor。
- expertIds(aclTensor*,计算输入):每个token的topK个专家索引,Device侧的aclTensor,要求为一个2D的Tensor,shape为 (Bs, K)。数据类型支持INT32,数据格式要求为ND,支持非连续的Tensor。
- assitInfoForCombine(aclTensor*,计算输入):对应aclnnMoeDistributeDispatchV2中的assitInfoForCombineOut输出,Device侧的aclTensor,要求是一个1D的Tensor,数据类型支持INT32,数据格式要求为ND,支持非连续的Tensor。shape为 (A * 128, )。
- epSendCounts(aclTensor*,计算输入):对应aclnnMoeDistributeDispatchV2中的epRecvCounts输出,Device侧的aclTensor,要求是一个1D的Tensor。数据类型支持INT32,数据格式要求为ND,支持非连续的Tensor。shape为 (epWorldSize * max(tpWorldSize, 1) * localExpertNum, )。
- expertScales(aclTensor*,计算输入):每个token的topK个专家的权重,Device侧的aclTensor,要求是一个2D的Tensor,shape为 (Bs, K)。类型支持FLOAT32,数据格式要求为ND,支持非连续的Tensor。
- residualX(aclTensor*,计算输入):AddRmsNorm中Add的右矩阵,Device侧的aclTensor,要求是一个3D的Tensor,shape为(Bs,1,H)。类型为BFLOAT16,数据格式要求为ND,支持非连续的Tensor。
- gamma(aclTensor*,计算输入):RmsNorm中的gamma输入,Device侧的aclTensor,要求是一个1D的Tensor,shape为(H, )。类型为BFLOAT16,数据格式要求为ND,支持非连续的Tensor。
- tpSendCountsOptional(aclTensor*,计算输入):对应aclnnMoeDistributeDispatchV2中的tpRecvCounts输出,Device侧的aclTensor,若有TP域通信需要传参,若无TP域通信,传空指针即可。当有TP域通信时,要求是一个1D的Tensor,shape为 (tpWorldSize, )。数据类型支持INT32,数据格式要求为ND,支持非连续的Tensor。
- xActiveMaskOptional(aclTensor*,计算输入):表示token是否参与通信,Device侧的aclTensor,是一个1D的Tensor,数据类型为BOOL,shape为 (Bs, ),可传\可不传,不传时表示所有token都参与通信。数据格式要求为ND,支持非连续的Tensor。
- activationScaleOptional(aclTensor*,计算输入):Device侧的aclTensor,预留参数,当前版本不支持,传空指针即可。
- weightScaleOptional(aclTensor*,计算输入):Device侧的aclTensor,预留参数,当前版本不支持,传空指针即可。
- groupListOptional(aclTensor*,计算输入):Device侧的aclTensor,预留参数,当前版本不支持,传空指针即可。
- expandScalesOptional(aclTensor*,计算输入):对应aclnnMoeDistributeDispatchV2中的expandScales输出,Device侧的aclTensor。预留参数,当前版本不支持,传空指针即可。
- sharedExpertXOptional(aclTensor*, 计算输入):表示共享专家计算后的Token,Device侧的aclTensor。要求是一个2D或3D的Tensor,当Tensor为2D时,shape为 (Bs, H);当Tensor为3D时,shape为 (Bs, 1, H)。数据类型需跟expandX保持一致。可传/可不传。数据格式要求为ND,支持非连续的Tensor。
- groupEp(char*,计算输入):EP通信域名称,专家并行的通信域,string数据类型。字符串长度范围为[1, 128),不能和groupTp相同。
- epWorldSize(int64_t,计算输入):EP通信域size,数据类型支持INT64。取值支持(1, 384]。
- epRankId(int64_t,计算输入):EP域本卡Id,数据类型支持INT64,取值范围[0, epWorldSize)。同一个EP通信域中各卡的epRankId不重复。
- moeExpertNum(int64_t,计算输入): MoE专家数量,数据类型支持INT64,取值范围(0, 512],并且满足moeExpertNum % (epWorldSize - sharedExpertRankNum) = 0。
- groupTp(char*,计算输入):TP通信域名称,数据并行的通信域,string数据类型。若有TP域通信需要传参,若无TP域通信,传空字符即可。当有TP域通信时,字符串长度范围为[1, 128),不能和groupEp相同。
- tpWorldSize(int64_t,计算输入):TP通信域size,int数据类型。取值范围[0, 2],0和1表示无tp域通信,有tp域通信时仅支持2。
- tpRankId(int64_t,计算输入):TP域本卡Id,数据类型支持INT64。取值范围[0, 1],同一个TP通信域中各卡的tpRankId不重复。无TP域通信时,传0即可。
- expertShardType(int64_t,计算输入):表示共享专家卡分布类型,数据类型支持INT64。当前仅支持传0,表示共享专家卡排在MoE专家卡前面。
- sharedExpertNum(int64_t,计算输入):表示共享专家数量,当前版本不支持,传0即可。
- sharedExpertRankNum(int64_t,计算输入):表示共享专家卡数量,当前版本不支持,仅支持传入0。
- globalBs(int64_t,计算输入):EP域全局的batch size大小,数据类型支持INT64。当每个rank的Bs数一致场景下,globalBs = Bs * epWorldSize 或 globalBs = 0;当每个rank的Bs数不一致场景下,globalBs = maxBs * epWorldSize,其中maxBs表示单卡Bs最大值。
- outDtype(int64_t,计算输入):用于指定输出x的数据类型,预留参数,当前版本不支持,传0即可。
- commQuantMode(int64_t,计算输入):通信量化类型。当前版本不支持,传0即可。
- groupListType(int64_t,计算输入):group List格式,预留参数,当前版本不支持,传0即可。
- commAlg(char*,计算输入):表示通信亲和内存布局算法,string数据类型,预留字段,当前版本不支持,传入空指针即可。
- normEps(float*,计算输入):用于防止AddRmsNorm除0错误,数据类型为float,可以为1e-6。
- yOut(aclTensor*,计算输出):RmsNorm后的输出结果,Device侧的aclTensor,要求是一个3D的Tensor,shape为(Bs,1,H),数据类型、数据格式与residualX保持一致。
- rstdOut(aclTensor*,计算输出):RmsNorm后的输出结果,Device侧的aclTensor,要求是一个3D的Tensor,shape为(Bs,1,1),数据类型支持FLOAT32,数据格式支持ND。
- xOut(aclTensor*,计算输出):Add后的输出结果,Device侧的aclTensor,要求是一个3D的Tensor,shape为(Bs, 1,H),数据类型、数据格式与residualX保持一致。
- workspaceSize(uint64_t*,出参):返回需要在Device侧申请的workspace大小。
- executor(aclOpExecutor**,出参):返回op执行器,包含了算子计算流程。
返回值
返回aclnnStatus状态码,具体参见aclnn返回码。
第一段接口完成入参校验,出现以下场景时报错: 161001(ACLNN_ERR_PARAM_NULLPTR): 1. 输入和输出的必选参数Tensor是空指针。 161002(ACLNN_ERR_PARAM_INVALID): 1. 输入和输出的数据类型不在支持的范围内。 561002(ACLNN_ERR_INNER_TILING_ERROR): 1. 输入和输出的shape不在支持的范围内。 2. 参数的取值不在支持的范围。
aclnnMoeDistributeCombineAddRmsNorm
参数说明:
- workspace(void*,入参):在Device侧申请的workspace内存地址。
- workspaceSize(uint64_t,入参):在Device侧申请的workspace大小,由第一段接口aclnnMoeDistributeCombineAddRmsNormGetWorkspaceSize获取。
- executor(aclOpExecutor*,入参):op执行器,包含了算子计算流程。
- stream(aclrtStream,入参):指定执行任务的AscendCL stream流。
返回值:
返回aclnnStatus状态码,具体参见aclnn返回码。
约束说明
aclnnMoeDistributeDispatchV2接口与aclnnMoeDistributeCombineAddRmsNorm接口必须配套使用,具体参考调用示例。
调用接口过程中使用的groupEp、epWorldSize、moeExpertNum、groupTp、tpWorldSize、expertShardType、sharedExpertNum、sharedExpertRankNum、globalBs参数取值所有卡需保持一致,网络中不同层中也需保持一致,且和aclnnMoeDistributeDispatchV2对应参数也保持一致。
Atlas A3 训练系列产品/Atlas A3 推理系列产品 :该场景下单卡包含双DIE(简称为“晶粒”或“裸片”),因此参数说明里的“本卡”均表示单DIE。参数说明里shape格式说明:
- A:表示本卡需要分发的最大token数量,取值范围如下:
- 对于共享专家,要满足A = Bs * epWorldSize * sharedExpertNum / sharedExpertRankNum。
- 对于MoE专家,当globalBs为0时,要满足A >= Bs * epWorldSize * min(localExpertNum, K);当globalBs非0时,要满足A >= globalBs * min(localExpertNum, K)。
- H:表示hidden size隐藏层大小,取值范围为[1024, 7168],且需要保证32的整数倍。
- Bs:表示batch sequence size,即本卡最终输出的token数量,取值范围为0 < Bs ≤ 512。
- K:表示选取topK个专家,取值范围为0 < K ≤ 16同时满足0 < K ≤ moeExpertNum。
- localExpertNum:表示本卡专家数量。
- 对于共享专家卡,localExpertNum = 1
- 对于MoE专家卡,localExpertNum = moeExpertNum / (epWorldSize - sharedExpertRankNum),localExpertNum > 1时,不支持TP域通信。
- A:表示本卡需要分发的最大token数量,取值范围如下:
HCCL_BUFFSIZE: 调用本接口前需检查HCCL_BUFFSIZE环境变量取值是否合理,该环境变量表示单个通信域占用内存大小,单位MB,不配置时默认为200MB。要求 >= 2且满足1024 ^ 2 * (HCCL_BUFFSIZE - 2) / 2 >= BS * 2 * (H + 128) * (epWorldSize * localExpertNum + K + 1),localExpertNum需使用MoE专家卡的本卡专家数。
通信域使用约束:
- 一个模型中的aclnnMoeDistributeCombineAddRmsNorm和aclnnMoeDistributeDispatchV2仅支持相同EP通信域,且该通信域中不允许有其他算子。
- 一个模型中的aclnnMoeDistributeCombineAddRmsNorm和aclnnMoeDistributeDispatchV2仅支持相同TP通信域或都不支持TP通信域,有TP通信域时该通信域中不允许有其他算子。
调用示例
以
文件准备:
1.新建combineAddRmsNormDemo目录,按照下方指导在combineAddRmsNormDemo下新建aclnnCombineAddRmsNormDemo.cpp,buildCombineAddRmsNorm.sh,文件并修改。 2.将combineAddRmsNormDemo项目拷贝到服务器中。 3.安装cann包,并根据下方指导编译运行combineAddRmsNormDemo。编译脚本
#!/bin/bash cann_path="/path/to/cann_env" # 更改cann包环境的路径 g++ "aclnnCombineAddRmsNormDemo.cpp" -o combineAddRmsNormDemo -I"$cann_path/latest/include/" -I"$cann_path/latest/include/aclnnop/" \ -L="$cann_path/latest/lib64/" -lascendcl -lnnopbase -lopapi -lop_common -lpthread -lhccl
编译与运行:
# source cann环境 source /path/to/cann_env/latest/bin/setenv.bash # 编译aclnnCombineAddRmsNormDemo.cpp bash buildCombineAddRmsNorm.sh ./combineAddRmsNormDemo
示例代码如下,仅供参考
#include <thread> #include <iostream> #include <string> #include <vector> #include "acl/acl.h" #include "hccl/hccl.h" #include "aclnnop/aclnn_moe_distribute_dispatch_v2.h" #include "aclnnop/aclnn_moe_distribute_combine_add_rms_norm.h" #define CHECK_RET(cond, return_expr) \ do { \ if (!(cond)) { \ return_expr; \ } \ } while (0) #define LOG_PRINT(message, ...) \ do { \ printf(message, ##__VA_ARGS__); \ } while(0) struct Args { uint32_t rankId; uint32_t epRankId; uint32_t tpRankId; HcclComm hcclEpComm; HcclComm hcclTpComm; aclrtStream dispatchStream; aclrtStream combineStream; aclrtContext context; }; constexpr uint32_t EP_WORLD_SIZE = 8; constexpr uint32_t TP_WORLD_SIZE = 2; constexpr uint32_t DEV_NUM = EP_WORLD_SIZE * TP_WORLD_SIZE; int64_t GetShapeSize(const std::vector<int64_t> &shape) { int64_t shape_size = 1; for (auto i : shape) { shape_size *= i; } return shape_size; } template<typename T> int CreateAclTensor(const std::vector<T> &hostData, const std::vector<int64_t> &shape, void **deviceAddr, aclDataType dataType, aclTensor **tensor) { auto size = GetShapeSize(shape) * sizeof(T); auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc failed. ret: %d\n", ret); return ret); ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMemcpy failed. ret: %d\n", ret); return ret); std::vector<int64_t> strides(shape.size(), 1); for (int64_t i = shape.size() - 2; i >= 0; i--) { strides[i] = shape[i + 1] * strides[i + 1]; } *tensor = aclCreateTensor( shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND, shape.data(), shape.size(), *deviceAddr); return 0; } int LaunchOneProcessDispatchAndCombine(Args &args) { int ret = aclrtSetCurrentContext(args.context); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetCurrentContext failed, ret %d\n", ret); return ret); char hcomEpName[128] = {0}; ret = HcclGetCommName(args.hcclEpComm, hcomEpName); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclGetEpCommName failed, ret %d\n", ret); return -1); char hcomTpName[128] = {0}; ret = HcclGetCommName(args.hcclTpComm, hcomTpName); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclGetTpCommName failed, ret %d\n", ret); return -1); LOG_PRINT( "[INFO] rank = %d, hcomEpName = %s, hcomTpName = %s, dispatchStream = %p, combineStream = %p, context = %p\n", args.rankId, hcomEpName, hcomTpName, args.dispatchStream, args.combineStream, args.context ); int64_t BS = 8; int64_t H = 7168; int64_t K = 3; int64_t expertShardType = 0; int64_t sharedExpertNum = 0; int64_t sharedExpertRankNum = 0; int64_t moeExpertNum = 8; int64_t quantMode = 0; int64_t globalBS = BS * EP_WORLD_SIZE; int64_t expertTokenNumsType = 1; int64_t outDtype = 0; int64_t commQuantMode = 0; int64_t groupList_type = 1; int64_t localExpertNum; int64_t A; if (args.epRankId < sharedExpertRankNum) { // 共享专家卡 localExpertNum = 1; A = globalBS / sharedExpertRankNum; } else { // Moe专家卡 localExpertNum = moeExpertNum / (EP_WORLD_SIZE - sharedExpertRankNum); A = globalBS * (localExpertNum < K ? localExpertNum : K); } /* 根据当前场景,构造device侧输入输出变量 */ // 声明device侧输入输出变量 void *xDeviceAddr = nullptr; void *expertIdsDeviceAddr = nullptr; void *scalesDeviceAddr = nullptr; void *expertScalesDeviceAddr = nullptr; void *expandXDeviceAddr = nullptr; void *dynamicScalesDeviceAddr = nullptr; void *expandIdxDeviceAddr = nullptr; void *expertTokenNumsDeviceAddr = nullptr; void *epRecvCountsDeviceAddr = nullptr; void *tpRecvCountsDeviceAddr = nullptr; void *expandScalesDeviceAddr = nullptr; void *residualXDeviceAddr = nullptr; void *sharedExpertXDeviceAddr = nullptr; void *gammaDeviceAddr = nullptr; void *yOutDeviceAddr = nullptr; void *rstdOutDeviceAddr = nullptr; void *xOutDeviceAddr = nullptr; aclTensor *x = nullptr; aclTensor *expertIds = nullptr; aclTensor *scales = nullptr; aclTensor *expertScales = nullptr; aclTensor *expandX = nullptr; aclTensor *dynamicScales = nullptr; aclTensor *expandIdx = nullptr; aclTensor *expertTokenNums = nullptr; aclTensor *epRecvCounts = nullptr; aclTensor *tpRecvCounts = nullptr; aclTensor *expandScales = nullptr; aclTensor *residualX = nullptr; aclTensor *sharedExpertX = nullptr; aclTensor *gamma = nullptr; aclTensor *yOut = nullptr; aclTensor *rstdOut = nullptr; aclTensor *xOut = nullptr; // 定义当前场景下各变量维度 std::vector<int64_t> xShape{BS, H}; std::vector<int64_t> expertIdsShape{BS, K}; std::vector<int64_t> scalesShape{(sharedExpertRankNum > 0) ? moeExpertNum + 1 : moeExpertNum, H}; std::vector<int64_t> expertScalesShape{BS, K}; std::vector<int64_t> expandXShape{TP_WORLD_SIZE * A, H}; std::vector<int64_t> dynamicScalesShape{TP_WORLD_SIZE * A}; std::vector<int64_t> expandIdxShape{A * 128}; std::vector<int64_t> expertTokenNumsShape{localExpertNum}; std::vector<int64_t> epRecvCountsShape{TP_WORLD_SIZE * localExpertNum * EP_WORLD_SIZE}; std::vector<int64_t> tpRecvCountsShape{TP_WORLD_SIZE * localExpertNum}; std::vector<int64_t> expandScalesShape{A}; std::vector<int64_t> residualXShape{BS, 1, H}; std::vector<int64_t> sharedExpertXShape{BS, 1, H}; std::vector<int64_t> gammaShape{BS, 1, H}; std::vector<int64_t> yOutShape{BS, 1, H}; std::vector<int64_t> rstdOutShape{BS, 1, 1}; std::vector<int64_t> xOutShape{BS, 1, H}; int64_t xShapeSize = GetShapeSize(xShape); int64_t expertIdsShapeSize = GetShapeSize(expertIdsShape); int64_t scalesShapeSize = GetShapeSize(scalesShape); int64_t expertScalesShapeSize = GetShapeSize(expertScalesShape); int64_t expandXShapeSize = GetShapeSize(expandXShape); int64_t dynamicScalesShapeSize = GetShapeSize(dynamicScalesShape); int64_t expandIdxShapeSize = GetShapeSize(expandIdxShape); int64_t expertTokenNumsShapeSize = GetShapeSize(expertTokenNumsShape); int64_t epRecvCountsShapeSize = GetShapeSize(epRecvCountsShape); int64_t tpRecvCountsShapeSize = GetShapeSize(tpRecvCountsShape); int64_t expandScalesShapeSize = GetShapeSize(expandScalesShape); int64_t residualXShapeSize = GetShapeSize(residualXShape); int64_t sharedExpertXShapeSize = GetShapeSize(sharedExpertXShape); int64_t gammaShapeSize = GetShapeSize(gammaShape); int64_t yOutShapeSize = GetShapeSize(yOutShape); int64_t rstdOutShapeSize = GetShapeSize(rstdOutShape); int64_t xOutShapeSize = GetShapeSize(xOutShape); // 构造host侧变量 std::vector<int16_t> xHostData(xShapeSize, 1); std::vector<int32_t> expertIdsHostData; for (int32_t token_id = 0; token_id < expertIdsShape[0]; token_id++) { // 每个token发给moe专家{0, 1, ... k - 1} for (int32_t k_id = 0; k_id < expertIdsShape[1]; k_id++) { expertIdsHostData.push_back(k_id); } } std::vector<float> scalesHostData(scalesShapeSize, 0.1); std::vector<float> expertScalesHostData(expertScalesShapeSize, 0.1); std::vector<int16_t> expandXHostData(expandXShapeSize, 0); std::vector<float> dynamicScalesHostData(dynamicScalesShapeSize, 0); std::vector<int32_t> expandIdxHostData(expandIdxShapeSize, 0); std::vector<int64_t> expertTokenNumsHostData(expertTokenNumsShapeSize, 0); std::vector<int32_t> epRecvCountsHostData(epRecvCountsShapeSize, 0); std::vector<int32_t> tpRecvCountsHostData(tpRecvCountsShapeSize, 0); std::vector<float> expandScalesHostData(expandScalesShapeSize, 0); std::vector<int16_t> residualXHostData(residualXShapeSize, 1); std::vector<int16_t> sharedExpertXHostData(sharedExpertXShapeSize, 1); std::vector<int16_t> gammaHostData(gammaShapeSize, 1); std::vector<int16_t> yOutHostData(yOutShapeSize, 0); std::vector<float> rstdOutHostData(rstdOutShapeSize, 0); std::vector<int16_t> xOutHostData(xOutShapeSize, 0); // 构造device侧变量 ret = CreateAclTensor(xHostData, xShape, &xDeviceAddr, aclDataType::ACL_BF16, &x); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(expertIdsHostData, expertIdsShape, &expertIdsDeviceAddr, aclDataType::ACL_INT32, &expertIds); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(scalesHostData, scalesShape, &scalesDeviceAddr, aclDataType::ACL_FLOAT, &scales); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(expertScalesHostData, expertScalesShape, &expertScalesDeviceAddr, aclDataType::ACL_FLOAT, &expertScales); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(expandXHostData, expandXShape, &expandXDeviceAddr, (quantMode > 0) ? aclDataType::ACL_INT8 : aclDataType::ACL_BF16, &expandX); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(dynamicScalesHostData, dynamicScalesShape, &dynamicScalesDeviceAddr, aclDataType::ACL_FLOAT, &dynamicScales); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(expandIdxHostData, expandIdxShape, &expandIdxDeviceAddr, aclDataType::ACL_INT32, &expandIdx); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(expertTokenNumsHostData, expertTokenNumsShape, &expertTokenNumsDeviceAddr, aclDataType::ACL_INT64, &expertTokenNums); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(epRecvCountsHostData, epRecvCountsShape, &epRecvCountsDeviceAddr, aclDataType::ACL_INT32, &epRecvCounts); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(tpRecvCountsHostData, tpRecvCountsShape, &tpRecvCountsDeviceAddr, aclDataType::ACL_INT32, &tpRecvCounts); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(expandScalesHostData, expandScalesShape, &expandScalesDeviceAddr, aclDataType::ACL_FLOAT, &expandScales); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(residualXHostData, residualXShape, &residualXDeviceAddr, aclDataType::ACL_BF16, &residualX); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(sharedExpertXHostData, sharedExpertXShape, &sharedExpertXDeviceAddr, aclDataType::ACL_BF16, &sharedExpertX); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(gammaHostData, gammaShape, &gammaDeviceAddr, aclDataType::ACL_BF16, &gamma); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(yOutHostData, yOutShape, &yOutDeviceAddr, aclDataType::ACL_BF16, &yOut); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(rstdOutHostData, rstdOutShape, &rstdOutDeviceAddr, aclDataType::ACL_FLOAT, &rstdOut); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(xOutHostData, xOutShape, &xOutDeviceAddr, aclDataType::ACL_BF16, &xOut); CHECK_RET(ret == ACL_SUCCESS, return ret); /* 声明算子执行必需变量 */ uint64_t dispatchWorkspaceSize = 0; aclOpExecutor *dispatchExecutor = nullptr; void *dispatchWorkspaceAddr = nullptr; uint64_t combineAddRmsNormWorkspaceSize = 0; aclOpExecutor *combineAddRmsNormExecutor = nullptr; void *combineWorkspaceAddr = nullptr; /**************************************** 调用dispatch ********************************************/ ret = aclnnMoeDistributeDispatchV2GetWorkspaceSize(x, expertIds, (quantMode > 0 ? scales : nullptr), nullptr, expertScales, hcomEpName, EP_WORLD_SIZE, args.epRankId, moeExpertNum, hcomTpName, TP_WORLD_SIZE, args.tpRankId, expertShardType, sharedExpertNum,sharedExpertRankNum, quantMode, globalBS, expertTokenNumsType, nullptr, expandX, dynamicScales, expandIdx, expertTokenNums, epRecvCounts, tpRecvCounts, expandScales, &dispatchWorkspaceSize, &dispatchExecutor); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeDispatchV2GetWorkspaceSize failed. ret = %d \n", ret); return ret); if (dispatchWorkspaceSize > 0) { ret = aclrtMalloc(&dispatchWorkspaceAddr, dispatchWorkspaceSize, ACL_MEM_MALLOC_HUGE_FIRST); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc workspace failed. ret = %d \n", ret); return ret); } // 调用第二阶段接口 ret = aclnnMoeDistributeDispatchV2(dispatchWorkspaceAddr, dispatchWorkspaceSize, dispatchExecutor, args.dispatchStream); ret = aclrtSynchronizeStreamWithTimeout(args.dispatchStream, 10000); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeDispatchV2 failed. ret = %d \n", ret); \ return ret); /**************************************** 调用combineAddRmsNorm ********************************************/ // 调用第一阶段接口 ret = aclnnMoeDistributeCombineAddRmsNormGetWorkspaceSize( expandX, expertIds, expandIdx, epRecvCounts, expertScales, residualX, gamma, tpRecvCounts, nullptr, nullptr, nullptr, nullptr, nullptr, sharedExpertX, hcomEpName, EP_WORLD_SIZE, args.epRankId, moeExpertNum, hcomTpName, TP_WORLD_SIZE, args.tpRankId, expertShardType, sharedExpertNum, sharedExpertRankNum, globalBS, outDtype, commQuantMode, groupList_type, nullptr, 1e-6, yOut, rstdOut, xOut, &combineAddRmsNormWorkspaceSize, &combineAddRmsNormExecutor); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeCombineAddRmsNormGetWorkspaceSize failed. ret = %d \n", ret); return ret); // 根据第一阶段接口计算出的workspaceSize申请device内存 if (combineAddRmsNormWorkspaceSize > 0) { ret = aclrtMalloc(&combineWorkspaceAddr, combineAddRmsNormWorkspaceSize, ACL_MEM_MALLOC_HUGE_FIRST); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc workspace failed. ret = %d \n", ret); return ret); } // 调用第二阶段接口 ret = aclnnMoeDistributeCombineAddRmsNorm(combineWorkspaceAddr, combineAddRmsNormWorkspaceSize, combineAddRmsNormExecutor, args.combineStream); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeCombineAddRmsNorm failed. ret = %d \n", ret); return ret); // (固定写法)同步等待任务执行结束 ret = aclrtSynchronizeStreamWithTimeout(args.combineStream, 10000); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSynchronizeStreamWithTimeout failed. ret = %d \n", ret); return ret); LOG_PRINT("[INFO] device_%d aclnnMoeDistributeDispatchV2 and aclnnMoeDistributeCombineAddRmsNorm \ execute successfully.\n", args.rankId); // 释放device资源 if (dispatchWorkspaceSize > 0) { aclrtFree(dispatchWorkspaceAddr); } if (combineAddRmsNormWorkspaceSize > 0) { aclrtFree(combineWorkspaceAddr); } if (x != nullptr) { aclDestroyTensor(x); } if (expertIds != nullptr) { aclDestroyTensor(expertIds); } if (scales != nullptr) { aclDestroyTensor(scales); } if (expertScales != nullptr) { aclDestroyTensor(expertScales); } if (expandX != nullptr) { aclDestroyTensor(expandX); } if (dynamicScales != nullptr) { aclDestroyTensor(dynamicScales); } if (expandIdx != nullptr) { aclDestroyTensor(expandIdx); } if (expertTokenNums != nullptr) { aclDestroyTensor(expertTokenNums); } if (epRecvCounts != nullptr) { aclDestroyTensor(epRecvCounts); } if (tpRecvCounts != nullptr) { aclDestroyTensor(tpRecvCounts); } if (expandScales != nullptr) { aclDestroyTensor(expandScales); } if (residualX != nullptr) { aclDestroyTensor(residualX); } if (sharedExpertX != nullptr) { aclDestroyTensor(sharedExpertX); } if (gamma != nullptr) { aclDestroyTensor(gamma); } if (yOut != nullptr) { aclDestroyTensor(yOut); } if (rstdOut != nullptr) { aclDestroyTensor(rstdOut); } if (xOut != nullptr) { aclDestroyTensor(xOut); } if (xDeviceAddr != nullptr) { aclrtFree(xDeviceAddr); } if (expertIdsDeviceAddr != nullptr) { aclrtFree(expertIdsDeviceAddr); } if (scalesDeviceAddr != nullptr) { aclrtFree(scalesDeviceAddr); } if (expertScalesDeviceAddr != nullptr) { aclrtFree(expertScalesDeviceAddr); } if (expandXDeviceAddr != nullptr) { aclrtFree(expandXDeviceAddr); } if (dynamicScalesDeviceAddr != nullptr) { aclrtFree(dynamicScalesDeviceAddr); } if (expandIdxDeviceAddr != nullptr) { aclrtFree(expandIdxDeviceAddr); } if (expertTokenNumsDeviceAddr != nullptr) { aclrtFree(expertTokenNumsDeviceAddr); } if (epRecvCountsDeviceAddr != nullptr) { aclrtFree(epRecvCountsDeviceAddr); } if (expandScalesDeviceAddr != nullptr) { aclrtFree(expandScalesDeviceAddr); } if (tpRecvCountsDeviceAddr != nullptr) { aclrtFree(tpRecvCountsDeviceAddr); } if (residualXDeviceAddr != nullptr) { aclrtFree(residualXDeviceAddr); } if (sharedExpertXDeviceAddr != nullptr) { aclrtFree(sharedExpertXDeviceAddr); } if (gammaDeviceAddr != nullptr) { aclrtFree(gammaDeviceAddr); } if (yOutDeviceAddr != nullptr) { aclrtFree(yOutDeviceAddr); } if (rstdOutDeviceAddr != nullptr) { aclrtFree(rstdOutDeviceAddr); } if (xOutDeviceAddr != nullptr) { aclrtFree(xOutDeviceAddr); } HcclCommDestroy(args.hcclEpComm); HcclCommDestroy(args.hcclTpComm); aclrtDestroyStream(args.dispatchStream); aclrtDestroyStream(args.combineStream); aclrtDestroyContext(args.context); aclrtResetDevice(args.rankId); return 0; } int main(int argc, char *argv[]) { int ret = aclInit(nullptr); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclInit failed, ret = %d\n", ret); return ret); aclrtStream dispatchStream[DEV_NUM]; aclrtStream combineStream[DEV_NUM]; aclrtContext context[DEV_NUM]; for (uint32_t rankId = 0; rankId < DEV_NUM; rankId++) { ret = aclrtSetDevice(rankId); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetDevice failed, ret = %d\n", ret); return ret); ret = aclrtCreateContext(&context[rankId], rankId); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateContext failed, ret = %d\n", ret); return ret); ret = aclrtCreateStream(&dispatchStream[rankId]); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateStream failed, ret = %d\n", ret); return ret); ret = aclrtCreateStream(&combineStream[rankId]); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateStream failed, ret = %d\n", ret); return ret); } int32_t devicesEp[TP_WORLD_SIZE][EP_WORLD_SIZE]; for (int32_t tpId = 0; tpId < TP_WORLD_SIZE; tpId++) { for (int32_t epId = 0; epId < EP_WORLD_SIZE; epId++) { devicesEp[tpId][epId] = epId * TP_WORLD_SIZE + tpId; } } HcclComm commsEp[TP_WORLD_SIZE][EP_WORLD_SIZE]; for (int32_t tpId = 0; tpId < TP_WORLD_SIZE; tpId++) { ret = HcclCommInitAll(EP_WORLD_SIZE, devicesEp[tpId], commsEp[tpId]); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclCommInitAll ep %d failed, ret %d\n", tpId, ret); return ret); } int32_t devicesTp[EP_WORLD_SIZE][TP_WORLD_SIZE]; for (int32_t epId = 0; epId < EP_WORLD_SIZE; epId++) { for (int32_t tpId = 0; tpId < TP_WORLD_SIZE; tpId++) { devicesTp[epId][tpId] = epId * TP_WORLD_SIZE + tpId; } } HcclComm commsTp[EP_WORLD_SIZE][TP_WORLD_SIZE]; for (int32_t epId = 0; epId < EP_WORLD_SIZE; epId++) { ret = HcclCommInitAll(TP_WORLD_SIZE, devicesTp[epId], commsTp[epId]); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclCommInitAll tp %d failed, ret %d\n", epId, ret); return ret); } Args args[DEV_NUM]; // 各线程调用各卡执行算子 std::vector<std::unique_ptr<std::thread>> threads(DEV_NUM); for (uint32_t rankId = 0; rankId < DEV_NUM; rankId++) { uint32_t epRankId = rankId / TP_WORLD_SIZE; uint32_t tpRankId = rankId % TP_WORLD_SIZE; args[rankId].rankId = rankId; args[rankId].epRankId = epRankId; args[rankId].tpRankId = tpRankId; args[rankId].hcclEpComm = commsEp[tpRankId][epRankId]; args[rankId].hcclTpComm = commsTp[epRankId][tpRankId]; args[rankId].dispatchStream = dispatchStream[rankId]; args[rankId].combineStream = combineStream[rankId]; args[rankId].context = context[rankId]; threads[rankId].reset(new(std::nothrow) std::thread(&LaunchOneProcessDispatchAndCombine, std::ref(args[rankId]))); } for (uint32_t rankId = 0; rankId < DEV_NUM; rankId++) { threads[rankId]->join(); } aclFinalize(); LOG_PRINT("[INFO] aclFinalize success\n"); return 0; }