BatchWrite

功能说明

集合通信BatchWrite的任务下发接口,返回该任务的标识handleId给用户。

BatchWrite实现了一种不同昇腾AI Server(通常是8卡或16卡的昇腾NPU设备组成的服务器形态的统称)间批量的点对点通信,即多卡环境下,不同AI Server间的卡直接传输数据的通信模式。本接口支持将本卡的多份数据同时发送到不同AI Server的多张卡上。

函数原型

1
2
template <bool commit = false>
__aicore__ inline HcclHandle BatchWrite(GM_ADDR batchWriteInfo, uint32_t itemNum)

参数说明

表1 模板参数说明

参数名

输入/输出

描述

commit

输入

bool类型。参数取值如下:

  • true:在调用Prepare接口时,Commit同步通知服务端可以执行该通信任务。
  • false:在调用Prepare接口时,不通知服务端执行该通信任务。
表2 接口参数说明

参数名

输入/输出

描述

batchWriteInfo

输入

通信任务信息的Global Memory地址。一组通信数据的相关信息必须按如下格式保存,在执行通信任务时,可以同时指定多组通信任务信息,执行通信任务时批量发送数据。

1
2
3
4
5
6
7
struct BatchWriteItem {
    uint64_t localBuf;
    uint64_t remoteBuf;
    uint64_t count;
    uint32_t dataType;
    uint32_t remoteRankId;
};
  • localBuf:本端发送数据的window地址。
  • remoteBuf:对端接收数据的window地址。
  • count:该通信任务发送的数据个数。
  • dataType:该通信任务发送的数据类型,支持的类型可参考HcclDataType参数说明
  • remoteRankId:该通信任务发送数据的目的卡卡号。

itemNum

输入

批量任务的个数。该参数取值必须与batchWriteInfo中通信任务信息的组数一致。

返回值

返回该任务的标识handleId,handleId大于等于0。调用失败时,返回 -1。

支持的型号

Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件

约束说明

调用示例

假设本卡要将不同的数据分别发送到其它AI Server的2卡、3卡的指定位置,通过调用一次Batchwrite接口,实现批量点对点通信。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
struct BatchWriteItem {
    uint64_t localBuf;     // 本端发送数据的window地址
    uint64_t remoteBuf;    // 对端接收数据的window地址
    uint64_t count;        // 发送的数据个数
    uint32_t dataType;     // 发送的数据类型
    uint32_t remoteRankId; // 发送数据的目的卡号 
}; // 按接口的约定定义格式

extern "C" __global__ __aicore__ void BatchWrite_custom(GM_ADDR inputGM, GM_ADDR workspace, GM_ADDR tilingGM) 
{
    GM_ADDR userWS = GetUserWorkspace(workspace);
    if (userWS == nullptr) {
        return;
    }
    REGISTER_TILING_DEFAULT(BatchWriteCustomTilingData); // BatchWriteCustomTilingData为对应算子头文件定义的结构体
    auto tiling = (__gm__ BatchWriteCustomTilingData *)tilingGM;
    GM_ADDR contextGM = AscendC::GetHcclContext<0>();
    __gm__ void *mc2InitTiling = (__gm__ void *)(&tiling->mc2InitTiling);
    __gm__ void *batchWriteTiling = (__gm__ void *)(&(tiling->mc2CcTiling));

    if constexpr (g_coreType == AscendC::AIV) {
        Hccl hccl;
        hccl.Init(contextGM, mc2InitTiling);
        hccl.SetCcTiling(batchWriteTiling);

        __gm__ BatchWriteItem *sendInfo = reinterpret_cast<__gm__ BatchWriteItem *>(workspace);
        
        // 需要提前将待发送的数据从inputGM搬运到localBuf所填的window地址上
        sendInfo->localBuf = hccl.GetWindowsOutAddr(hccl.GetRankId());
        // 对端的接收地址也要是window地址,接收端需要考虑是否搬运到输出或者workspace上
        sendInfo->remoteBuf = hccl.GetWindowsInAddr(2U);
        sendInfo->count = 16U;
        sendInfo->dataType = HcclDataType::HCCL_DATA_TYPE_FP16;
        sendInfo->remoteRankId = 2U;
        
        // 可以组装多个通信任务,实现批量发送
        (sendInfo + 1)->localBuf = hccl.GetWindowsOutAddr(hccl.GetRankId());
        (sendInfo + 1)->remoteBuf = hccl.GetWindowsInAddr(3U);
        (sendInfo + 1)->count = 32U;
        (sendInfo + 1)->dataType = HcclDataType::HCCL_DATA_TYPE_BFP16;
        (sendInfo + 1)->remoteRankId = 3U;
         
        // 确保cache中的数据已刷新到GM地址上
        GlobalTensor<int64_t> tempTensor;
        tempTensor.SetGlobalBuffer((__gm__ int64_t *)sendInfo);
        DataCacheCleanAndInvalid<int64_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(tempTensor);

        auto handleId = hccl.BatchWrite<true>(sendInfo, 2U);
        // wait仅表示本端发送完毕,对端是否接收到数据需要在对端判断
        hccl.Wait(handleId);    
        AscendC::SyncAll();
        hccl.Finalize();
    }
}