1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <memory>
#include <sstream>
#include <thread>
#include <vector>

#include "hccl/hccl.h"
#include "hccl/hccl_types.h"
#include "mpi.h"
// Checks the status of an AscendCL call.
// The argument expression is evaluated exactly once and its value is kept in a
// local, so a failing call is not executed again for the log line and for the
// return statement. When the status differs from ACL_SUCCESS, the call site and
// the status are printed and the status is returned from the enclosing function.
#define ACLCHECK(ret) do { \
    const auto aclCheckRet_ = (ret); \
    if (aclCheckRet_ != ACL_SUCCESS) { \
        printf("acl interface return err %s:%d, retcode: %d \n", __FILE__, __LINE__, static_cast<int>(aclCheckRet_)); \
        return aclCheckRet_; \
    } \
} while(0)
// Checks the status of an HCCL call.
// The argument expression is evaluated exactly once and its value is kept in a
// local, so a failing call is not executed again for the log line and for the
// return statement. When the status differs from HCCL_SUCCESS, the call site and
// the status are printed and the status is returned from the enclosing function.
#define HCCLCHECK(ret) do { \
    const auto hcclCheckRet_ = (ret); \
    if (hcclCheckRet_ != HCCL_SUCCESS) { \
        printf("hccl interface return err %s:%d, retcode: %d \n", __FILE__, __LINE__, static_cast<int>(hcclCheckRet_)); \
        return hcclCheckRet_; \
    } \
} while(0)
// Per-thread arguments handed to Sample() through its void* parameter.
struct ThreadContext {
    // Communicator handle that Sample() passes to the HCCL send/receive calls
    // and finally to HcclCommDestroy.
    HcclComm comm;

    // Device id that Sample() selects with aclrtSetDevice; Sample() also uses
    // this value as the rank when computing the Send/Recv peer.
    int32_t device;
};
int Sample(void *arg)
{
ThreadContext* ctx = (ThreadContext *)arg;
// 申请通信用device、sendBuf,recvBuf内存、stream等资源
ACLCHECK(aclrtSetDevice(ctx->device));
aclrtStream stream;
ACLCHECK(aclrtCreateStream(&stream));
void* sendBuff;
void* recvBuff;
void* hostBuff;
uint64_t count = 8;
int mallocSize = count * sizeof(float);
//初始化输入内存
ACLCHECK(aclrtMallocHost((void**)&hostBuff, mallocSize));
float* tmpHostBuff = static_cast<float*>(hostBuff);
for (uint32_t i = 0; i < count; ++i) {
tmpHostBuff[i] = 2;
}
ACLCHECK(aclrtMalloc((void**)&sendBuff, mallocSize, ACL_MEM_MALLOC_HUGE_FIRST));
ACLCHECK(aclrtMemcpy((void*)sendBuff, mallocSize, (void*)hostBuff, mallocSize, ACL_MEMCPY_HOST_TO_DEVICE));
ACLCHECK(aclrtMalloc((void**)&recvBuff, mallocSize, ACL_MEM_MALLOC_HUGE_FIRST));
// 执行SendRecv操作
if (ctx->device / 4 == 0) {
HCCLCHECK(HcclSend(sendBuff, count, HCCL_DATA_TYPE_FP32, ctx->device + 4, ctx->comm, stream));
} else {
HCCLCHECK(HcclRecv(recvBuff, count, HCCL_DATA_TYPE_FP32, ctx->device - 4, ctx->comm, stream));
}
ACLCHECK(aclrtSynchronizeStream(stream));
if (ctx->device / 4 == 1) {
void* resultBuff;
ACLCHECK(aclrtMallocHost((void**)&resultBuff, mallocSize));
ACLCHECK(aclrtMemcpy((void*)resultBuff, mallocSize, (void*)recvBuff, mallocSize, ACL_MEMCPY_DEVICE_TO_HOST));
float* tmpResBuff = static_cast<float*>(resultBuff);
for (uint32_t i = 0; i < count; ++i) {
std::cout << "rankId:" << ctx->device << ",i" << i << " " << tmpResBuff[i] << std::endl;
}
ACLCHECK(aclrtFreeHost(resultBuff));
}
// 释放通信用sendBuf、recvBuf内存,stream等资源
ACLCHECK(aclrtFreeHost(hostBuff));
ACLCHECK(aclrtFree(recvBuff));
ACLCHECK(aclrtFree(sendBuff));
ACLCHECK(aclrtDestroyStream(stream));
ACLCHECK(aclrtResetDevice(ctx->device));
HCCLCHECK(HcclCommDestroy(ctx->comm));
return 0;
}
int main() {
// 设备资源初始化
ACLCHECK(aclInit(NULL));
uint32_t ndev = 8;
int32_t devices[8] = {0, 1, 2, 3, 4, 5, 6, 7};
HcclComm comms[ndev];
for (int32_t i = 0; i < ndev; i++) {
ACLCHECK(aclrtSetDevice(devices[i]));
}
// 初始化通信域
HCCLCHECK(HcclCommInitAll(ndev, devices, comms));
// 启动线程执行集合通信操作
std::vector<std::unique_ptr<std::thread> > threads(ndev);
struct ThreadContext args[ndev];
for (uint32_t i = 0; i < ndev; i++) {
args[i].device = i;
args[i].comm = comms[i];
threads[i].reset(new (std::nothrow) std::thread(&Sample, (void *)&args[i]));
}
for (uint32_t i = 0; i < ndev; i++) {
threads[i]->join();
}
// 设备资源去初始化
ACLCHECK(aclFinalize());
return 0;
}
|