1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <memory>
#include <thread>
#include <vector>

#include "hccl/hccl.h"
#include "hccl/hccl_types.h"
#include "mpi.h"
// Evaluate an ACL call exactly once; if it does not return ACL_SUCCESS, log the
// source location and the error code, then return that code from the enclosing
// function. The result is cached in a local so a failing call is not re-invoked
// by the log statement or the return.
#define ACLCHECK(ret) do {\
    const auto aclCheckRet_ = (ret);\
    if (aclCheckRet_ != ACL_SUCCESS) {\
        printf("acl interface return err %s:%d, retcode: %d \n", __FILE__, __LINE__, static_cast<int>(aclCheckRet_));\
        return aclCheckRet_;\
    }\
} while (0)

// Same as ACLCHECK, but for HCCL calls checked against HCCL_SUCCESS.
#define HCCLCHECK(ret) do {\
    const auto hcclCheckRet_ = (ret);\
    if (hcclCheckRet_ != HCCL_SUCCESS) {\
        printf("hccl interface return err %s:%d, retcode: %d \n", __FILE__, __LINE__, static_cast<int>(hcclCheckRet_));\
        return hcclCheckRet_;\
    }\
} while (0)
/**
 * @brief Arguments handed to Sample() through its void* parameter.
 *
 * Members are value-initialized so a partially filled context never carries
 * indeterminate values.
 */
struct ThreadContext {
    HcclComm comm{};    // HCCL communicator; Sample() destroys it before returning successfully.
    int32_t device{0};  // Device id; Sample() also uses it as this process's HCCL rank id.
};
int Sample(void *arg)
{
ThreadContext* ctx = (ThreadContext *)arg;
// 申请通信用device、sendBuf,recvBuf内存、stream等资源
ACLCHECK(aclrtSetDevice(ctx->device));
aclrtStream stream;
ACLCHECK(aclrtCreateStream(&stream));
void* sendBuff;
void* recvBuff;
void* hostBuff;
uint64_t count = 8;
int mallocSize = count * sizeof(float);
//初始化输入内存
ACLCHECK(aclrtMallocHost((void**)&hostBuff, mallocSize));
float* tmpHostBuff = static_cast<float*>(hostBuff);
for (uint32_t i = 0; i < count; ++i) {
tmpHostBuff[i] = 2;
}
ACLCHECK(aclrtMalloc((void**)&sendBuff, mallocSize, ACL_MEM_MALLOC_HUGE_FIRST));
ACLCHECK(aclrtMemcpy((void*)sendBuff, mallocSize, (void*)hostBuff, mallocSize, ACL_MEMCPY_HOST_TO_DEVICE));
ACLCHECK(aclrtMalloc((void**)&recvBuff, mallocSize, ACL_MEM_MALLOC_HUGE_FIRST));
//执行SendRecv操作
if (ctx->device / 4 == 0) {
HCCLCHECK(HcclSend(sendBuff, count, HCCL_DATA_TYPE_FP32, ctx->device + 4, ctx->comm, stream));
} else {
HCCLCHECK(HcclRecv(recvBuff, count, HCCL_DATA_TYPE_FP32, ctx->device - 4, ctx->comm, stream));
}
ACLCHECK(aclrtSynchronizeStream(stream));
if (ctx->device / 4 == 1) {
void* resultBuff;
ACLCHECK(aclrtMallocHost((void**)&resultBuff, mallocSize));
ACLCHECK(aclrtMemcpy((void*)resultBuff, mallocSize, (void*)recvBuff, mallocSize, ACL_MEMCPY_DEVICE_TO_HOST));
float* tmpResBuff = static_cast<float*>(resultBuff);
for (uint32_t i = 0; i < count; ++i) {
std::cout << "rankId:" << ctx->device << ",i" << i << " " << tmpResBuff[i] << std::endl;
}
ACLCHECK(aclrtFreeHost(resultBuff));
}
// 释放通信用sendBuf、recvBuf内存,stream等资源
ACLCHECK(aclrtFreeHost(hostBuff));
ACLCHECK(aclrtFree(recvBuff));
ACLCHECK(aclrtFree(sendBuff));
ACLCHECK(aclrtDestroyStream(stream));
ACLCHECK(aclrtResetDevice(ctx->device));
HCCLCHECK(HcclCommDestroy(ctx->comm));
}
int main() {
MPI_Init(NULL, NULL);
int procSize = 0;
int procRank = 0;
// 获取当前进程在所属进程组的编号
MPI_Comm_size(MPI_COMM_WORLD, &procSize);
MPI_Comm_rank(MPI_COMM_WORLD, &procRank);
int devId = procRank;
int devCount = procSize;
//设备资源初始化
ACLCHECK(aclInit(NULL));
// 指定集合通信操作使用的设备
ACLCHECK(aclrtSetDevice(devId));
// 在 rootRank 获取 rootInfo
HcclRootInfo rootInfo;
int32_t rootRank = 0;
if(devId == rootRank) {
HCCLCHECK(HcclGetRootInfo(&rootInfo));
}
// 将root_info广播到通信域内的其他rank
MPI_Bcast(&rootInfo, HCCL_ROOT_INFO_BYTES, MPI_CHAR, rootRank, MPI_COMM_WORLD);
MPI_Barrier(MPI_COMM_WORLD);
// 初始化集合通信域
HcclComm hcclComm;
HCCLCHECK(HcclCommInitRootInfo(devCount, &rootInfo, devId, &hcclComm));
struct ThreadContext args;
args.comm = hcclComm;
args.device = devId;
Sample((void *)&args);
// 设备资源去初始化
ACLCHECK(aclFinalize());
MPI_Finalize();
return 0;
}
|