```c++ #ifndef __HCCL_TEST_COMMON_H_ #define __HCCL_TEST_COMMON_H_ #include <stdio.h> #include <string.h> #include <getopt.h> #include <stdlib.h> #include <unistd.h> #include "mpi.h" #include <chrono> #include "hccl/hccl.h" #include "acl/acl.h" #define ACLCHECK(ret) do { \ if(ret != ACL_SUCCESS)\ {\ printf("acl interface return err %s:%d, retcode: %d \n", __FILE__, __LINE__, ret);\ exit(EXIT_FAILURE);\ }\ } while(0) #define HCCLCHECK(ret) do { \ if(ret != HCCL_SUCCESS) \ { \ printf("hccl interface return errreturn err %s:%d, retcode: %d \n", __FILE__, __LINE__, ret); \ exit(EXIT_FAILURE);\ } \ } while(0) #endif bool g_isDevice = false; bool setup(int device_id) { ACLCHECK(aclInit(nullptr)); ACLCHECK(aclrtSetDevice(device_id)); aclrtRunMode run_mode; ACLCHECK(aclrtGetRunMode(&run_mode)); bool g_is_device_return = (run_mode == ACL_DEVICE); return g_is_device_return; } // 请插入调用的通信算子执行函数 // HcclAllReduce 算子sample,通过main函数调用 int hccl_allreduce_sample(HcclComm &hcom) { // 指定通信算子调用所需的资源 int count = 256*256*10; aclrtStream stream; ACLCHECK(aclrtCreateStream(&stream)); aclFloat16 *bufer_host; ACLCHECK(aclrtMalloc((void **)&bufer_host, count*sizeof(aclFloat16), ACL_MEM_MALLOC_HUGE_FIRST)); aclFloat16 *bufer_dev; ACLCHECK(aclrtMalloc((void **)&bufer_dev, count*sizeof(aclFloat16), ACL_MEM_MALLOC_HUGE_FIRST)); // 获取参与集合通信的rank数量 unsigned int rankSize = 0; HCCLCHECK(HcclGetRankSize(hcom, &rankSize)); printf("Get rank size is %u \n", rankSize); // 获取参与集合通信的rank序号 unsigned int rank_id = 0; HCCLCHECK(HcclGetRankId(hcom, &rank_id)); printf("Get rank id is %u \n", rank_id); // 调用HcclAllReduce接口 HCCLCHECK(HcclAllReduce(bufer_host, bufer_dev, count, HCCL_DATA_TYPE_FP16, HCCL_REDUCE_SUM, hcom, stream)); aclrtSynchronizeStream(stream); // 资源析构 ACLCHECK(aclrtFree(bufer_dev)); ACLCHECK(aclrtDestroyStream(stream)); return 0; } // sample主函数 int main(int argc, char **argv) { int rank, size, device_id; MPI_Init(NULL, NULL); MPI_Comm_rank(MPI_COMM_WORLD, &rank); MPI_Comm_size(MPI_COMM_WORLD, &size); device_id = rank; g_isDevice = setup(device_id); HcclComm hcom; HcclRootInfo root_info; if (rank == 0) { // 获取root节点,root节点用户可指定,并非只可以设置为0节点 HCCLCHECK(HcclGetRootInfo(&root_info)); } MPI_Bcast(&root_info, HCCL_ROOT_INFO_BYTES, MPI_CHAR, 0, MPI_COMM_WORLD); // 调用Hccl初始化接口 HCCLCHECK(HcclCommInitRootInfo(size, &root_info, rank, &hcom)); int ret = 0; // 通信算子执行函数调用,用户可根据需要切换 ret = hccl_allreduce_sample(hcom); if (ret != 0) { printf("return error"); exit(EXIT_FAILURE); } printf("allreduce_sample executored suc \n"); HCCLCHECK(HcclCommDestroy(hcom)); ACLCHECK(aclrtResetDevice(device_id)); ACLCHECK(aclFinalize()); MPI_Finalize(); return 0; } ```