代码样例
```c++
#ifndef __HCCL_TEST_COMMON_H_
#define __HCCL_TEST_COMMON_H_
#include <stdio.h>
#include <string.h>
#include <getopt.h>
#include <stdlib.h>
#include <unistd.h>
#include "mpi.h"
#include <chrono>
#include "hccl/hccl.h"
#include "acl/acl.h"
// Terminate the process if an ACL runtime call does not return ACL_SUCCESS.
// The argument is evaluated exactly once, so the failing API call is not
// invoked a second time while formatting the error message.
#define ACLCHECK(ret) do { \
    const int aclcheck_ret_ = (int)(ret); \
    if (aclcheck_ret_ != ACL_SUCCESS) { \
        printf("acl interface return err %s:%d, retcode: %d \n", __FILE__, __LINE__, aclcheck_ret_); \
        exit(EXIT_FAILURE); \
    } \
} while (0)
// Terminate the process if an HCCL call does not return HCCL_SUCCESS.
// As with ACLCHECK, the argument is evaluated exactly once.
#define HCCLCHECK(ret) do { \
    const int hcclcheck_ret_ = (int)(ret); \
    if (hcclcheck_ret_ != HCCL_SUCCESS) { \
        printf("hccl interface return err %s:%d, retcode: %d \n", __FILE__, __LINE__, hcclcheck_ret_); \
        exit(EXIT_FAILURE); \
    } \
} while (0)
#endif
bool g_isDevice = false;
bool setup(int device_id) {
ACLCHECK(aclInit(nullptr));
ACLCHECK(aclrtSetDevice(device_id));
aclrtRunMode run_mode;
ACLCHECK(aclrtGetRunMode(&run_mode));
bool g_is_device_return = (run_mode == ACL_DEVICE);
return g_is_device_return;
}
// Insert the collective-communication sample function(s) to be called here.
// HcclAllReduce sample, invoked from main().
//
// Runs one FP16 sum all-reduce over `hcom` on a newly created stream. It also
// prints the communicator's rank count and this process's rank id.
//   hcom:    initialized HCCL communicator (main() builds it with
//            HcclCommInitRootInfo).
//   returns: 0. Any ACL/HCCL failure terminates the process via
//            ACLCHECK/HCCLCHECK.
int hccl_allreduce_sample(HcclComm &hcom) {
    // Element count and byte size of each buffer used by the collective.
    const size_t count = 256 * 256 * 10;
    const size_t buf_bytes = count * sizeof(aclFloat16);

    aclrtStream stream = nullptr;
    ACLCHECK(aclrtCreateStream(&stream));

    // Both buffers come from aclrtMalloc. HcclAllReduce reads send_buf and
    // writes the reduced result into recv_buf.
    // NOTE(review): send_buf is never filled, so the reduced values are
    // arbitrary; copy real data into it first if the result matters.
    void *send_buf = nullptr;
    ACLCHECK(aclrtMalloc(&send_buf, buf_bytes, ACL_MEM_MALLOC_HUGE_FIRST));
    void *recv_buf = nullptr;
    ACLCHECK(aclrtMalloc(&recv_buf, buf_bytes, ACL_MEM_MALLOC_HUGE_FIRST));

    // Number of ranks participating in the collective.
    unsigned int rank_size = 0;
    HCCLCHECK(HcclGetRankSize(hcom, &rank_size));
    printf("Get rank size is %u \n", rank_size);

    // This process's rank id within the communicator.
    unsigned int rank_id = 0;
    HCCLCHECK(HcclGetRankId(hcom, &rank_id));
    printf("Get rank id is %u \n", rank_id);

    // Enqueue the all-reduce on `stream`, then wait for it to finish before
    // the buffers are released.
    HCCLCHECK(HcclAllReduce(send_buf, recv_buf, count, HCCL_DATA_TYPE_FP16, HCCL_REDUCE_SUM, hcom, stream));
    ACLCHECK(aclrtSynchronizeStream(stream));

    // Release every resource acquired above (both buffers and the stream).
    ACLCHECK(aclrtFree(recv_buf));
    ACLCHECK(aclrtFree(send_buf));
    ACLCHECK(aclrtDestroyStream(stream));
    return 0;
}
// sample主函数
int main(int argc, char **argv) {
int rank, size, device_id;
MPI_Init(NULL, NULL);
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
device_id = rank;
g_isDevice = setup(device_id);
HcclComm hcom;
HcclRootInfo root_info;
if (rank == 0) {
// 获取root节点,root节点用户可指定,并非只可以设置为0节点
HCCLCHECK(HcclGetRootInfo(&root_info));
}
MPI_Bcast(&root_info, HCCL_ROOT_INFO_BYTES, MPI_CHAR, 0, MPI_COMM_WORLD);
// 调用Hccl初始化接口
HCCLCHECK(HcclCommInitRootInfo(size, &root_info, rank, &hcom));
int ret = 0;
// 通信算子执行函数调用,用户可根据需要切换
ret = hccl_allreduce_sample(hcom);
if (ret != 0) {
printf("return error");
exit(EXIT_FAILURE);
}
printf("allreduce_sample executored suc \n");
HCCLCHECK(HcclCommDestroy(hcom));
ACLCHECK(aclrtResetDevice(device_id));
ACLCHECK(aclFinalize());
MPI_Finalize();
return 0;
}
```
父主题: 集合通信opbase接口参考