昇腾社区首页
中文
注册

aclnnMoeDistributeDispatchV2

支持的产品型号

  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品

功能说明

算子功能:对token数据进行量化(可选),当存在TP域通信时,先进行EP(Expert Parallelism)域的AllToAllV通信,再进行TP(Tensor Parallelism)域的AllGatherV通信;当不存在TP域通信时,进行EP(Expert Parallelism)域的AllToAllV通信。

agOut=AllGatherV(X)expandXOut=AllToAllV(agOut)agOut = AllGatherV(X)\\ expandXOut = AllToAllV(agOut)\\
  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:该接口必须与aclnnMoeDistributeCombineV2配套使用。
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品:该接口必须与aclnnMoeDistributeCombineV2或aclnnMoeDistributeCombineAddRmsNorm配套使用。

说明:

aclnnMoeDistributeCombineV2、aclnnMoeDistributeCombineAddRmsNorm算子在后续文档中统称为CombineV2系列算子。

相较于aclnnMoeDistributeDispatch接口,该接口变更如下:

  • 输出了更详细的token信息辅助CombineV2系列算子高效地进行全卡同步,因此原接口中shape为(Bs * K, 1)的expandIdx出参替换为shape为(A * 128, 1)的assitInfoForCombineOut参数;
  • 新增commAlg入参,预留字段,暂未使用。

详细说明请参考以下参数说明。

函数原型

每个算子分为两段式接口,必须先调用 “aclnnMoeDistributeDispatchV2GetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnMoeDistributeDispatchV2”接口执行计算。

  • aclnnStatus aclnnMoeDistributeDispatchV2GetWorkspaceSize(const aclTensor* x, const aclTensor* expertIds, const aclTensor* scalesOptional, const aclTensor* xActiveMaskOptional, const aclTensor* expertScalesOptional, const char* groupEp, int64_t epWorldSize, int64_t epRankId, int64_t moeExpertNum, const char* groupTp, int64_t tpWorldSize, int64_t tpRankId, int64_t expertShardType, int64_t sharedExpertNum, int64_t sharedExpertRankNum, int64_t quantMode, int64_t globalBs, int64_t expertTokenNumsType, const char* commAlg, aclTensor* expandXOut, aclTensor* dynamicScalesOut, aclTensor* assitInfoForCombineOut, aclTensor* expertTokenNumsOut, aclTensor* epRecvCountsOut, aclTensor* tpRecvCountsOut, aclTensor* expertScalesOut, uint64_t* workspaceSize, aclOpExecutor** executor)
  • aclnnStatus aclnnMoeDistributeDispatchV2(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)

aclnnMoeDistributeDispatchV2GetWorkspaceSize

  • 参数说明

    • x(aclTensor*,计算输入):表示本卡发送的token数据,Device侧的aclTensor。要求为一个2D的Tensor,shape为 (Bs, H),其中Bs为batch size,H为hidden size,即隐藏层大小,数据类型支持FLOAT16、BFLOAT16,数据格式要求为ND,支持非连续的Tensor
    • expertIds(aclTensor*,计算输入):每个token的topK个专家索引,Device侧的aclTensor,要求为一个2D的Tensor,shape为 (Bs, K)。数据类型支持INT32,数据格式要求为ND,支持非连续的Tensor
    • scalesOptional(aclTensor*,计算输入):每个专家的量化平滑参数,Device侧的aclTensor,要求是一个2D的Tensor,shape (sharedExpertNum + moeExpertNum, H)。非量化场景传空指针,动态量化可选择传入有效数据或传入空指针。数据类型支持FLOAT32,数据格式要求为ND,支持非连续的Tensor
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:当HCCL_INTRA_PCIE_ENABLE为1且HCCL_INTRA_ROCE_ENABLE为0时,要求传nullptr。
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:无特殊要求。
    • xActiveMaskOptional(aclTensor*,计算输入):表示token是否参与通信,Device侧的aclTensor。
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:当前版本不支持,传空指针即可。
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:要求是一个1D的Tensor,shape为 (Bs, ),数据类型支持BOOL;可选择传入有效数据或传入空指针,传入空指针时是表示所有token都会参与通信。数据格式要求为ND,支持非连续的Tensor
    • expertScalesOptional(aclTensor*,计算输入):每个token的topK个专家权重,Device侧的aclTensor。
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:要求是一个2D的shape (Bs, K)。数据类型支持FLOAT32,数据格式要求为ND,支持非连续的Tensor
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:当前版本不支持,传空指针即可。
    • groupEp(char*,计算输入):EP通信域名称,专家并行的通信域,string数据类型。字符串长度范围为[1, 128),不能和groupTp相同。
    • epWorldSize(int64_t,计算输入):EP通信域size,数据类型支持INT64。
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:取值支持16、32、64。
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值区间[2, 384]。
    • epRankId(int64_t,计算输入): EP域本卡Id,数据类型支持INT64,取值范围[0, epWorldSize)。同一个EP通信域中各卡的epRankId不重复。
    • moeExpertNum(int64_t,计算输入): MoE专家数量,数据类型支持INT64,取值范围(0, 512],并且满足moeExpertNum % (epWorldSize-sharedExpertRankNum)=0。
    • groupTp(char*,计算输入): TP通信域名称,数据并行的通信域,string数据类型。若有TP域通信需要传参,若无TP域通信,传空字符即可。
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:当前版本不支持,传空字符即可。
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:当有TP域通信时,字符串长度范围为[1, 128),不能和groupEp相同。
    • tpWorldSize(int64_t,计算输入):TP通信域size,int数据类型。
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:当前版本不支持,传0即可。
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围[0, 2],0和1表示无TP域通信,有TP域通信时仅支持2。
    • tpRankId(int64_t,计算输入):TP域本卡Id,数据类型支持INT64。
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:当前版本不支持,传0即可。
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围[0, 1],同一个TP通信域中各卡的tpRankId不重复。无TP域通信时,传0即可。
    • expertShardType(int64_t,计算输入):表示共享专家卡分布类型,数据类型支持INT64。
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:当前版本不支持,传0即可。
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:当前仅支持传0,表示共享专家卡排在MoE专家卡前面。
    • sharedExpertNum(int64_t,计算输入):表示共享专家数量,一个共享专家可以复制部署到多个卡上,数据类型支持INT64。
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:当前版本不支持,传0即可。
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:当前取值范围[0, 4],0表示无共享专家。
    • sharedExpertRankNum(int64_t,计算输入):表示共享专家卡数量,数据类型支持INT64。
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:当前版本不支持,传0即可。
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:当前取值范围[0, epWorldSize - 1),不为0时需满足epWorldSize % (sharedExpertRankNum / sharedExpertNum) = 0。
    • quantMode(int64_t,计算输入):表示量化模式,支持0:非量化,2:动态量化。
    • globalBs(int64_t,计算输入):EP域全局的batch size大小,数据类型支持INT64。
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:当每个rank的Bs不同时,传入256 * epWorldSize;当每个rank的Bs相同时,支持取值0或Bs * epWorldSize。
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:当每个rank的Bs数一致场景下,globalBs = Bs * epWorldSize 或 globalBs = 0;当每个rank的Bs数不一致场景下,globalBs = maxBs * epWorldSize,其中maxBs表示单卡Bs最大值。
    • expertTokenNumsType(int64_t,计算输入):输出expertTokenNums中值的语义类型。支持0:expertTokenNums中的输出为每个专家处理的token数的前缀和,1:expertTokenNums中的输出为每个专家处理的token数量。
    • commAlg(char*,计算输入):表示通信亲和内存布局算法,string数据类型。预留字段,当前版本不支持,传空指针即可。
    • expandXOut(aclTensor*,计算输出):根据expertIds进行扩展过的token特征,Device侧的aclTensor,要求为一个2D的Tensor,shape为 (max(tpWorldSize, 1) * A, H),数据类型支持FLOAT16、BFLOAT16、INT8,数据格式要求为ND,支持非连续的Tensor
    • dynamicScalesOut(aclTensor*,计算输出):数据类型FLOAT32,要求为一个1D的Tensor,shape为 (A, ),数据格式要求为ND,支持非连续的Tensor。当quantMode为2时,才有该输出。
    • assitInfoForCombineOut(aclTensor*,计算输出):表示给同一专家发送的token个数,对应CombineV2系列算子中的assitInfoForCombine,Device侧的aclTensor,要求是一个1D的Tensor。数据类型支持INT32,数据格式要求为ND,支持非连续的Tensor
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:shape要求为 (Bs*K, )
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:shape要求为 (A*128, )
    • expertTokenNumsOut(aclTensor*,计算输出):表示每个专家收到的token个数,Device侧的aclTensor,数据类型INT64,要求为一个1D的Tensor,shape为 (localExpertNum, ),数据格式要求为ND,支持非连续的Tensor
    • epRecvCountsOut(aclTensor*,计算输出):从EP通信域各卡接收的token数,对应CombineV2系列算子中的epSendCounts,Device侧的aclTensor,数据类型INT32,要求为一个1D的Tensor,数据格式要求为ND,支持非连续的Tensor
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:要求shape为 (moeExpertNum + 2 * globalBs * K * serverNum, ),前moeExpertNum个数表示从EP通信域各卡接收的token数,2 * globalBs * K * serverNum存储了机间机内做通信前combine可以提前做reduce的token个数和token在通信区中的偏移。
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:要求shape为 (epWorldSize * max(tpWorldSize, 1) * localExpertNum, )。
    • tpRecvCountsOut(aclTensor*,计算输出):从TP通信域各卡接收的token数,对应CombineV2系列算子中的tpSendCounts,Device侧的aclTensor。若有TP域通信则有该输出,若无TP域通信则无该输出。
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:当前不支持TP域通信。
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:当有TP域通信时,要求是一个1D的Tensor,shape为 (tpWorldSize, )。数据类型支持INT32,数据格式要求为ND,支持非连续的Tensor
    • expertScalesOut(aclTensor*,计算输出):表示本卡输出token的权重,对应CombineV2系列算子中的expertScalesOptional,Device侧的aclTensor。
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:要求是一个1D的Tensor,shape为 (A, ),数据类型支持FLOAT32,数据格式要求为ND,支持非连续的Tensor
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:当前版本不支持该输出。
    • workspaceSize(uint64_t*,出参):返回需要在Device侧申请的workspace大小。
    • executor(aclOpExecutor**,出参):返回op执行器,包含了算子计算流程。
  • 返回值

    返回aclnnStatus状态码,具体参见aclnn返回码

    第一段接口完成入参校验,出现以下场景时报错:
    161001(ACLNN_ERR_PARAM_NULLPTR): 1. 输入和输出的必选参数Tensor是空指针。
    161002(ACLNN_ERR_PARAM_INVALID): 1. 输入和输出的数据类型不在支持的范围内。
    561002(ACLNN_ERR_INNER_TILING_ERROR): 1. 输入和输出的shape不在支持的范围内。
                                          2. 参数的取值不在支持的范围。 

aclnnMoeDistributeDispatchV2

  • 参数说明:

    • workspace(void*,入参):在Device侧申请的workspace内存地址。
    • workspaceSize(uint64_t,入参):在Device侧申请的workspace大小,由第一段接口aclnnMoeDistributeDispatchV2GetWorkspaceSize获取。
    • executor(aclOpExecutor*,入参):op执行器,包含了算子计算流程。
    • stream(aclrtStream,入参):指定执行任务的AscendCL stream流。
  • 返回值:

    返回aclnnStatus状态码,具体参见aclnn返回码

约束说明

  • aclnnMoeDistributeDispatchV2接口与CombineV2系列算子接口必须配套使用,具体参考调用示例

  • 调用接口过程中使用的groupEp、epWorldSize、moeExpertNum、groupTp、tpWorldSize、expertShardType、sharedExpertNum、sharedExpertRankNum、globalBs参数取值所有卡需保持一致,网络中不同层中也需保持一致,且和CombineV2系列算子对应参数也保持一致。

  • Atlas A3 训练系列产品/Atlas A3 推理系列产品:该场景下单卡包含双DIE(简称为“晶粒”或“裸片”),因此参数说明里的“本卡”均表示单DIE。

  • 参数说明里shape格式说明:

    • A:表示本卡可能接收的最大token数量,取值范围如下:
      • 对于共享专家,要满足A = Bs * epWorldSize * sharedExpertNum / sharedExpertRankNum。
      • 对于MoE专家,当globalBs为0时,要满足A >= Bs * epWorldSize * min(localExpertNum, K);当globalBs非0时,要满足A >= globalBs * min(localExpertNum, K)。
    • H:表示hidden size隐藏层大小。
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:取值范围(0, 7168],且保证是32的整数倍。
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围[1024, 7168]。
    • Bs:表示batch sequence size,即本卡最终输出的token数量。
      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:取值范围为0 < Bs ≤ 256。
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围为0 < Bs ≤ 512。
    • K:表示选取topK个专家,取值范围为0 < K ≤ 16同时满足0 < K ≤ moeExpertNum。
    • serverNum:表示服务器的节点数,取值只支持2、4、8。
    • localExpertNum:表示本卡专家数量。
      • 对于共享专家卡,localExpertNum = 1
      • 对于MoE专家卡,localExpertNum = moeExpertNum / (epWorldSize - sharedExpertRankNum),localExpertNum > 1时,不支持TP域通信。
  • HCCL_BUFFSIZE: 调用本接口前需检查HCCL_BUFFSIZE环境变量取值是否合理,该环境变量表示单个通信域占用内存大小,单位MB,不配置时默认为200MB。

    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:要求 >= 2 * (Bs * epWorldSize * min(localExpertNum, K) * H * sizeof(uint16) + 2MB)。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:要求 >= 2且满足>= 2 * (localExpertNum * maxBs * epWorldSize * Align512(Align32(2 * H) + 64) + (K + sharedExpertNum) * maxBs * Align512(2 * H)),localExpertNum需使用MoE专家卡的本卡专家数, 其中,Align512(x) = ((x + 512 - 1) / 512) * 512, Align32(x) = ((x + 32 - 1) / 32) * 32。
  • HCCL_INTRA_PCIE_ENABLE和HCCL_INTRA_ROCE_ENABLE:

    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:当K = 8 且 Bs ≤ 128时,设置环境变量HCCL_INTRA_PCIE_ENABLE = 1和HCCL_INTRA_ROCE_ENABLE = 0可以减少跨机通信数据量,可能提升算子性能。 此时,HCCL_BUFFSIZE要求 >= moeExpertNum * Bs * (H * sizeof(dtypeX) + 4 * K * sizeof(uint32)) + 4MB + 100MB。
  • 通信域使用约束:

    • 一个模型中的CombineV2系列算子和aclnnMoeDistributeDispatchV2仅支持相同EP通信域,且该通信域中不允许有其他算子。
    • 一个模型中的CombineV2系列算子和aclnnMoeDistributeDispatchV2仅支持相同TP通信域或都不支持TP通信域,有TP通信域时该通信域中不允许有其他算子。

调用示例

Atlas A3 训练系列产品/Atlas A3 推理系列产品为例,调起MoeDistributeCombineV2和MoeDistributeDispatchV2算子。

  • 文件准备:
    1.新建combineDemo目录,按照下方指导在combineDemo下新建aclnnCombineDemo.cpp,buildCombine.sh文件并参考如下代码修改。

    2.安装cann包,并根据下方指导编译运行combineDemo。

  • 编译脚本

    #!/bin/bash
    cann_path="/path/to/cann_env" # 更改cann包环境的路径
    g++ "aclnnCombineDemo.cpp" -o combineDemo -I"$cann_path/latest/include/" -I"$cann_path/latest/include/aclnnop/" \
                        -L="$cann_path/latest/lib64/" -lascendcl -lnnopbase -lopapi -lop_common -lpthread -lhccl
  • 编译与运行:

    # source cann环境
    source /path/to/cann_env/latest/bin/setenv.bash
    
    # 编译aclnnCombineDemo.cpp
    bash buildCombine.sh
    
    ./combineDemo
  • 示例代码如下,仅供参考

    #include <thread>
    #include <iostream>
    #include <string>
    #include <vector>
    #include "acl/acl.h"
    #include "hccl/hccl.h"
    #include "aclnnop/aclnn_moe_distribute_dispatch_v2.h"
    #include "aclnnop/aclnn_moe_distribute_combine_v2.h"
    
    #define CHECK_RET(cond, return_expr) \
        do {                             \
            if (!(cond)) {               \
                return_expr;             \
            }                            \
        } while (0)
    
    #define LOG_PRINT(message, ...)         \
        do {                                \
            printf(message, ##__VA_ARGS__); \
        } while(0)
    
    struct Args {
        uint32_t rankId;
        uint32_t epRankId;
        uint32_t tpRankId;
        HcclComm hcclEpComm;
        HcclComm hcclTpComm;
        aclrtStream dispatchStream;
        aclrtStream combineStream;
        aclrtContext context;
    };
    
    constexpr uint32_t EP_WORLD_SIZE = 8;
    constexpr uint32_t TP_WORLD_SIZE = 2;
    constexpr uint32_t DEV_NUM = EP_WORLD_SIZE * TP_WORLD_SIZE;
    
    int64_t GetShapeSize(const std::vector<int64_t> &shape)
    {
        int64_t shape_size = 1;
        for (auto i : shape) {
            shape_size *= i;
        }
        return shape_size;
    }
    
    template<typename T>
    int CreateAclTensor(const std::vector<T> &hostData, const std::vector<int64_t> &shape, void **deviceAddr,
        aclDataType dataType, aclTensor **tensor)
    {
        auto size = GetShapeSize(shape) * sizeof(T);
        auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc failed. ret: %d\n", ret); return ret);
        ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMemcpy failed. ret: %d\n", ret); return ret);
        std::vector<int64_t> strides(shape.size(), 1);
        for (int64_t i = shape.size() - 2; i >= 0; i--) {
            strides[i] = shape[i +1] * strides[i + 1];
        }
        *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
            shape.data(), shape.size(), *deviceAddr);
        return 0;
    }
    
    int LaunchOneProcessDispatchAndCombine(Args &args)
    {
        int ret = aclrtSetCurrentContext(args.context);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetCurrentContext failed, ret %d\n", ret); return ret);
    
        char hcomEpName[128] = {0};
        ret = HcclGetCommName(args.hcclEpComm, hcomEpName);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclGetEpCommName failed, ret %d\n", ret); return -1);
        char hcomTpName[128] = {0};
        ret = HcclGetCommName(args.hcclTpComm, hcomTpName);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclGetTpCommName failed, ret %d\n", ret); return -1);
        LOG_PRINT("[INFO] rank = %d, hcomEpName = %s, hcomTpName = %s, dispatchStream = %p, combineStream = %p, \
                  context = %p\n", args.rankId, hcomEpName, hcomTpName, args.dispatchStream, args.combineStream,                 \
                  args.context);
    
        int64_t Bs = 8;
        int64_t H = 7168;
        int64_t K = 3;
        int64_t expertShardType = 0;
        int64_t sharedExpertNum = 1;
        int64_t sharedExpertRankNum = 1;
        int64_t moeExpertNum = 7;
        int64_t quantMode = 0;
        int64_t globalBs = Bs * EP_WORLD_SIZE;
        int64_t expertTokenNumsType = 1;
        int64_t outDtype = 0;
        int64_t commQuantMode = 0;
        int64_t groupList_type = 1;
        int64_t localExpertNum;
        int64_t A;
        if (args.epRankId < sharedExpertRankNum) {
            localExpertNum = 1;
            A = globalBs / sharedExpertRankNum;
        } else {
            localExpertNum = moeExpertNum / (EP_WORLD_SIZE - sharedExpertRankNum);
            A = globalBs * (localExpertNum < K ? localExpertNum : K);
        }
    
        void *xDeviceAddr = nullptr;
        void *expertIdsDeviceAddr = nullptr;
        void *scalesDeviceAddr = nullptr;
        void *expertScalesDeviceAddr = nullptr;
        void *expandXDeviceAddr = nullptr;
        void *dynamicScalesDeviceAddr = nullptr;
        void *expandIdxDeviceAddr = nullptr;
        void *expertTokenNumsDeviceAddr = nullptr;
        void *epRecvCountsDeviceAddr = nullptr;
        void *tpRecvCountsDeviceAddr = nullptr;
        void *expandScalesDeviceAddr = nullptr;
    
        aclTensor *x = nullptr;
        aclTensor *expertIds = nullptr;
        aclTensor *scales = nullptr;
        aclTensor *expertScales = nullptr;
        aclTensor *expandX = nullptr;
        aclTensor *dynamicScales = nullptr;
        aclTensor *expandIdx = nullptr;
        aclTensor *expertTokenNums = nullptr;
        aclTensor *epRecvCounts = nullptr;
        aclTensor *tpRecvCounts = nullptr;
        aclTensor *expandScales = nullptr;
    
        std::vector<int64_t> xShape{Bs, H};
        std::vector<int64_t> expertIdsShape{Bs, K};
        std::vector<int64_t> scalesShape{moeExpertNum + 1, H};
        std::vector<int64_t> expertScalesShape{Bs, K};
        std::vector<int64_t> expandXShape{TP_WORLD_SIZE * A, H};
        std::vector<int64_t> dynamicScalesShape{TP_WORLD_SIZE * A};
        std::vector<int64_t> expandIdxShape{A * 128};
        std::vector<int64_t> expertTokenNumsShape{localExpertNum};
        std::vector<int64_t> epRecvCountsShape{TP_WORLD_SIZE * localExpertNum * EP_WORLD_SIZE};
        std::vector<int64_t> tpRecvCountsShape{TP_WORLD_SIZE * localExpertNum};
        std::vector<int64_t> expandScalesShape{A};
    
        int64_t xShapeSize = GetShapeSize(xShape);
        int64_t expertIdsShapeSize = GetShapeSize(expertIdsShape);
        int64_t scalesShapeSize = GetShapeSize(scalesShape);
        int64_t expertScalesShapeSize = GetShapeSize(expertScalesShape);
        int64_t expandXShapeSize = GetShapeSize(expandXShape);
        int64_t dynamicScalesShapeSize = GetShapeSize(dynamicScalesShape);
        int64_t expandIdxShapeSize = GetShapeSize(expandIdxShape);
        int64_t expertTokenNumsShapeSize = GetShapeSize(expertTokenNumsShape);
        int64_t epRecvCountsShapeSize = GetShapeSize(epRecvCountsShape);
        int64_t tpRecvCountsShapeSize = GetShapeSize(tpRecvCountsShape);
        int64_t expandScalesShapeSize = GetShapeSize(expandScalesShape);
    
        std::vector<int16_t> xHostData(xShapeSize, 1);
        std::vector<int32_t> expertIdsHostData;
        for (int32_t token_id = 0; token_id < expertIdsShape[0]; token_id++) {
            for (int32_t k_id = 0; k_id < expertIdsShape[1]; k_id++) {
                expertIdsHostData.push_back(k_id);
            }
        }
    
        std::vector<float> scalesHostData(scalesShapeSize, 0.1);
        std::vector<float> expertScalesHostData(expertScalesShapeSize, 0.1);
        std::vector<int16_t> expandXHostData(expandXShapeSize, 0);
        std::vector<float> dynamicScalesHostData(dynamicScalesShapeSize, 0);
        std::vector<int32_t> expandIdxHostData(expandIdxShapeSize, 0);
        std::vector<int64_t> expertTokenNumsHostData(expertTokenNumsShapeSize, 0);
        std::vector<int32_t> epRecvCountsHostData(epRecvCountsShapeSize, 0);
        std::vector<int32_t> tpRecvCountsHostData(tpRecvCountsShapeSize, 0);
        std::vector<float> expandScalesHostData(expandScalesShapeSize, 0);
    
        ret = CreateAclTensor(xHostData, xShape, &xDeviceAddr, aclDataType::ACL_BF16, &x);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(expertIdsHostData, expertIdsShape, &expertIdsDeviceAddr, aclDataType::ACL_INT32, &expertIds);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(scalesHostData, scalesShape, &scalesDeviceAddr, aclDataType::ACL_FLOAT, &scales);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(expertScalesHostData, expertScalesShape, &expertScalesDeviceAddr, aclDataType::ACL_FLOAT, &expertScales);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(expandXHostData, expandXShape, &expandXDeviceAddr, (quantMode > 0) ? aclDataType::ACL_INT8 : aclDataType::ACL_BF16, &expandX);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(dynamicScalesHostData, dynamicScalesShape, &dynamicScalesDeviceAddr, aclDataType::ACL_FLOAT, &dynamicScales);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
         ret = CreateAclTensor(expandIdxHostData, expandIdxShape, &expandIdxDeviceAddr, aclDataType::ACL_INT32, &expandIdx);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(expertTokenNumsHostData, expertTokenNumsShape, &expertTokenNumsDeviceAddr, aclDataType::ACL_INT64, &expertTokenNums);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(epRecvCountsHostData, epRecvCountsShape, &epRecvCountsDeviceAddr, aclDataType::ACL_INT32, &epRecvCounts);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(tpRecvCountsHostData, tpRecvCountsShape, &tpRecvCountsDeviceAddr, aclDataType::ACL_INT32, &tpRecvCounts);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(expandScalesHostData, expandScalesShape, &expandScalesDeviceAddr, aclDataType::ACL_FLOAT, &expandScales);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        
        uint64_t dispatchWorkspaceSize = 0;
        aclOpExecutor *dispatchExecutor = nullptr;
        void *dispatchWorkspaceAddr = nullptr;
    
        uint64_t combineWorkspaceSize = 0;
        aclOpExecutor *combineExecutor = nullptr;
        void *combineWorkspaceAddr = nullptr;
    
        /**************************************** 调用dispatch ********************************************/
    
        ret = aclnnMoeDistributeDispatchV2GetWorkspaceSize(x, expertIds, (quantMode > 0 ? scales : nullptr), nullptr, 
                expertScales, hcomEpName, EP_WORLD_SIZE, args.epRankId, moeExpertNum, hcomTpName, TP_WORLD_SIZE,
                args.tpRankId, expertShardType, sharedExpertNum,sharedExpertRankNum, quantMode, globalBs,
                expertTokenNumsType, nullptr, expandX, dynamicScales, expandIdx, expertTokenNums, epRecvCounts,
                tpRecvCounts, expandScales, &dispatchWorkspaceSize, &dispatchExecutor);
        
        CHECK_RET(ret == ACL_SUCCESS,
            LOG_PRINT("[ERROR] aclnnMoeDistributeDispatchV2GetWorkspaceSize failed. ret = %d \n", ret); return ret);
    
        if (dispatchWorkspaceSize > 0) {
            ret = aclrtMalloc(&dispatchWorkspaceAddr, dispatchWorkspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
            CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc workspace failed. ret = %d \n", ret); return ret);
        }
        // 调用第二阶段接口
        ret = aclnnMoeDistributeDispatchV2(dispatchWorkspaceAddr, dispatchWorkspaceSize,
                                           dispatchExecutor, args.dispatchStream);
        ret = aclrtSynchronizeStreamWithTimeout(args.dispatchStream, 10000);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeDispatchV2 failed. ret = %d \n", ret);  \
            return ret);
    
        /**************************************** 调用combine ********************************************/
        // 调用第一阶段接口
        ret = aclnnMoeDistributeCombineV2GetWorkspaceSize(expandX, expertIds,
                                                         expandIdx, epRecvCounts,
                                                         expertScales, tpRecvCounts,
                                                         nullptr, nullptr, nullptr,
                                                         nullptr, nullptr, nullptr, 
                                                         hcomEpName, EP_WORLD_SIZE, args.epRankId, moeExpertNum,
                                                         hcomTpName, TP_WORLD_SIZE, args.tpRankId, expertShardType,
                                                         sharedExpertNum, sharedExpertRankNum, globalBs, outDtype,
                                                         commQuantMode, groupList_type, nullptr, x,
                                                         &combineWorkspaceSize, &combineExecutor);
        CHECK_RET(ret == ACL_SUCCESS,
            LOG_PRINT("[ERROR] aclnnMoeDistributeCombineV2GetWorkspaceSize failed. ret = %d \n", ret); return ret);
        // 根据第一阶段接口计算出的workspaceSize申请device内存
        if (combineWorkspaceSize > 0) {
            ret = aclrtMalloc(&combineWorkspaceAddr, combineWorkspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
            CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc workspace failed. ret = %d \n", ret); return ret);
        }
    
        // 调用第二阶段接口
        ret = aclnnMoeDistributeCombineV2(combineWorkspaceAddr, combineWorkspaceSize, combineExecutor, args.combineStream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeCombineV2 failed. ret = %d \n", ret);
            return ret);
        // (固定写法)同步等待任务执行结束
        ret = aclrtSynchronizeStreamWithTimeout(args.combineStream, 10000);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSynchronizeStreamWithTimeout failed. ret = %d \n", ret);
            return ret);
        LOG_PRINT("[INFO] device_%d aclnnMoeDistributeDispatchV2 and aclnnMoeDistributeCombineV2                      \
                   execute successfully.\n", args.rankId);
    
        // 释放device资源
        if (dispatchWorkspaceSize > 0) {
            aclrtFree(dispatchWorkspaceAddr);
        }
        if (combineWorkspaceSize > 0) {
            aclrtFree(combineWorkspaceAddr);
        }
        if (x != nullptr) {
            aclDestroyTensor(x);
        }
        if (expertIds != nullptr) {
            aclDestroyTensor(expertIds);
        }
        if (scales != nullptr) {
            aclDestroyTensor(scales);
        }
        if (expertScales != nullptr) {
            aclDestroyTensor(expertScales);
        }
        if (expandX != nullptr) {
            aclDestroyTensor(expandX);
        }
        if (dynamicScales != nullptr) {
            aclDestroyTensor(dynamicScales);
        }
        if (expandIdx != nullptr) {
            aclDestroyTensor(expandIdx);
        }
        if (expertTokenNums != nullptr) {
            aclDestroyTensor(expertTokenNums);
        }
        if (epRecvCounts != nullptr) {
            aclDestroyTensor(epRecvCounts);
        }
        if (tpRecvCounts != nullptr) {
            aclDestroyTensor(tpRecvCounts);
        }
        if (expandScales != nullptr) {
            aclDestroyTensor(expandScales);
        }
        if (xDeviceAddr != nullptr) {
            aclrtFree(xDeviceAddr);
        }
        if (expertIdsDeviceAddr != nullptr) {
            aclrtFree(expertIdsDeviceAddr);
        }
        if (scalesDeviceAddr != nullptr) {
            aclrtFree(scalesDeviceAddr);
        }
        if (expertScalesDeviceAddr != nullptr) {
            aclrtFree(expertScalesDeviceAddr);
        }
        if (expandXDeviceAddr != nullptr) {
            aclrtFree(expandXDeviceAddr);
        }
        if (dynamicScalesDeviceAddr != nullptr) {
            aclrtFree(dynamicScalesDeviceAddr);
        }
        if (expandIdxDeviceAddr != nullptr) {
            aclrtFree(expandIdxDeviceAddr);
        }
        if (expertTokenNumsDeviceAddr != nullptr) {
            aclrtFree(expertTokenNumsDeviceAddr);
        }
        if (epRecvCountsDeviceAddr != nullptr) {
            aclrtFree(epRecvCountsDeviceAddr);
        }
        if (expandScalesDeviceAddr != nullptr) {
            aclrtFree(expandScalesDeviceAddr);
        }
        if (tpRecvCountsDeviceAddr != nullptr) {
            aclrtFree(tpRecvCountsDeviceAddr);
        }
        
        HcclCommDestroy(args.hcclEpComm);
        HcclCommDestroy(args.hcclTpComm);
        aclrtDestroyStream(args.dispatchStream);
        aclrtDestroyStream(args.combineStream);
        aclrtDestroyContext(args.context);
        aclrtResetDevice(args.rankId);
    
        return 0;
    }
    
    int main(int argc, char *argv[])
    {
        int ret = aclInit(nullptr);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtInit failed, ret = %d\n", ret); return ret);
    
        aclrtStream dispatchStream[DEV_NUM];
        aclrtStream combineStream[DEV_NUM];
        aclrtContext context[DEV_NUM];
        for (uint32_t rankId = 0; rankId < DEV_NUM; rankId++) {
            ret = aclrtSetDevice(rankId);
            CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetDevice failed, ret = %d\n", ret); return ret);
            ret = aclrtCreateContext(&context[rankId], rankId);
            CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateContext failed, ret = %d\n", ret); return ret);
            ret = aclrtCreateStream(&dispatchStream[rankId]);
            CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateStream failed, ret = %d\n", ret); return ret);
            ret = aclrtCreateStream(&combineStream[rankId]);
            CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateStream failed, ret = %d\n", ret); return ret);
        }
    
        int32_t devicesEp[TP_WORLD_SIZE][EP_WORLD_SIZE];
        for (int32_t tpId = 0; tpId < TP_WORLD_SIZE; tpId++) {
            for (int32_t epId = 0; epId < EP_WORLD_SIZE; epId++) {
                devicesEp[tpId][epId] = epId * TP_WORLD_SIZE + tpId;
            }
        }
    
        HcclComm commsEp[TP_WORLD_SIZE][EP_WORLD_SIZE];
        for (int32_t tpId = 0; tpId < TP_WORLD_SIZE; tpId++) {
            ret = HcclCommInitAll(EP_WORLD_SIZE, devicesEp[tpId], commsEp[tpId]);
            CHECK_RET(ret == ACL_SUCCESS,
                      LOG_PRINT("[ERROR] HcclCommInitAll ep %d failed, ret %d\n", tpId, ret); return ret);
        }
    
        int32_t devicesTp[EP_WORLD_SIZE][TP_WORLD_SIZE];
        for (int32_t epId = 0; epId < EP_WORLD_SIZE; epId++) {
            for (int32_t tpId = 0; tpId < TP_WORLD_SIZE; tpId++) {
                devicesTp[epId][tpId] = epId * TP_WORLD_SIZE + tpId;
            }
        }
    
        HcclComm commsTp[EP_WORLD_SIZE][TP_WORLD_SIZE];
        for (int32_t epId = 0; epId < EP_WORLD_SIZE; epId++) {
            ret = HcclCommInitAll(TP_WORLD_SIZE, devicesTp[epId], commsTp[epId]);
            CHECK_RET(ret == ACL_SUCCESS,
                      LOG_PRINT("[ERROR] HcclCommInitAll tp %d failed, ret %d\n", epId, ret); return ret);
        }
    
        Args args[DEV_NUM];
        std::vector<std::unique_ptr<std::thread>> threads(DEV_NUM);
        for (uint32_t rankId = 0; rankId < DEV_NUM; rankId++) {
            uint32_t epRankId = rankId / TP_WORLD_SIZE;
            uint32_t tpRankId = rankId % TP_WORLD_SIZE;
    
            args[rankId].rankId = rankId;
            args[rankId].epRankId = epRankId;
            args[rankId].tpRankId = tpRankId;
            args[rankId].hcclEpComm = commsEp[tpRankId][epRankId];
            args[rankId].hcclTpComm = commsTp[epRankId][tpRankId];
            args[rankId].dispatchStream = dispatchStream[rankId];
            args[rankId].combineStream = combineStream[rankId];
            args[rankId].context = context[rankId];
            threads[rankId].reset(new(std::nothrow) std::thread(&LaunchOneProcessDispatchAndCombine, std::ref(args[rankId])));
        }
    
        for(uint32_t rankId = 0; rankId < DEV_NUM; rankId++) {
            threads[rankId]->join();
        }
    
        aclFinalize();
        LOG_PRINT("[INFO] aclFinalize success\n");
    
        return 0;
    }