Tiling Data
功能说明
AI Cpu启动下发通信任务前,需获取固定的通信配置Mc2Msg。在算子实现中,由Tiling组装通信配置项,通过配置固定参数和固定参数顺序的Tiling Data,将通信配置信息在调用AI Cpu通信接口时传递给AI Cpu。
参数说明
参数名 |
描述 |
---|---|
preparePosition |
设置服务端组装任务的方式,用户需要在Tiling中显示赋值,当前支持的取值如下: 1:AI Cpu与AI Core通过通信任务机制实现消息传递和任务下发;由AI Core侧通过消息通知时设置为1,即算子中使用Hccl时设置为1。 |
sendOff |
预留参数,不可配置。 |
recvOff |
预留参数,不可配置。 |
tailSendOff |
预留参数,不可配置。 |
tailRecvOff |
预留参数,不可配置。 |
sendCnt |
预留参数,不可配置。 |
recvCnt |
预留参数,不可配置。 |
tailSendCnt |
预留参数,不可配置。 |
tailRecvCnt |
预留参数,不可配置。 |
totalCnt |
预留参数,不可配置。 |
turnNum |
预留参数,不可配置。 |
tailNum |
预留参数,不可配置。 |
stride |
预留参数,不可配置。 |
workspaceOff |
预留参数,不可配置。 |
notifyOff |
预留参数,不可配置。 |
notifyBeginCnt |
预留参数,不可配置。 |
notifyEndCnt |
预留参数,不可配置。 |
useBufferType |
设置通信算法获取输入数据的位置,参数取值如下:
|
funID |
预留参数,不可配置。 |
dataType |
预留参数,不可配置。 |
groupNum |
预留参数,不可配置。 |
reuseMode |
预留参数,不可配置。 |
commType |
预留参数,不可配置。 |
reduceOp |
预留参数,不可配置。 |
commOrder |
预留参数,不可配置。 |
waitPolicy |
预留参数,不可配置。 |
rspPolicy |
预留参数,不可配置。 |
exitPolicy |
预留参数,不可配置。 |
commAlg |
设置具体通信算法,用户需要在Tiling中显示赋值,当前支持的取值如下: 1:FullMesh算法,即NPU之间的全连接,任意两个NPU之间可以直接进行数据收发。详细的算法内容可参见集合通信算法。 |
taskType |
预留参数,不可配置。 |
debugMode |
预留参数,不可配置。 |
stepSize |
预留参数,不可配置。 |
sendArgIndex |
预留参数,不可配置。 |
recvArgIndex |
预留参数,不可配置。 |
commOutArgIndex |
预留参数,不可配置。 |
hasCommOut |
本卡的通信算法的计算结果是否输出到recvBuf(目的数据buffer地址)。仅AllGather算法与AlltoAll算法支持配置该参数。参数取值如下:
|
reserve |
保留字段。 |
reserve2 |
保留字段。 |
支持的型号
Atlas A2训练系列产品/Atlas 800I A2推理产品
注意事项
- 算子的Tiling Data结构需要按顺序完整包含Mc2Msg参数。
- AI Cpu需获取固定数据结构的通信配置,算子注册Tiling Data时保持该结构的一致性。
调用示例
以自定义算子AllGatherMatmulCustom为例,如下为该算子的算子原型,"gather_out"为通信任务AllGather的输出。
[ { "op": "AllGatherMatmulCustom", "input_dsec": [ { "name": "x1", "param_type": "required", "format": [ "ND" ], "type": [ "float16", "bfloat16" ] }, { "name": "x2", "param_type": "required", "format": [ "ND" ], "type": [ "float16", "bfloat16" ] }, { "name": "bias", "param_type": "optional", "format": [ "ND" ], "type": [ "float16", "bfloat16" ] } ], "output_desc":[ { "name": "y", "param_type": "required", "format": [ "ND" ], "type": [ "float16", "bfloat16" ] }, { "name": "gather_out", "param_type": "required", "format": [ "ND" ], "type": [ "float16", "bfloat16" ] } ], "attr": [ { "name": "group", "dtype": "string", "default_value":"", "param_type":"required" }, { "name": "is_trans_a", "dtype": "bool", "default_value":false, "param_type":"optional" }, { "name": "is_trans_b", "dtype": "bool", "default_value":false, "param_type":"optional" }, { "name": "gather_index", "dtype": "int", "default_value":0, "param_type":"optional" }, { "name": "comm_turn", "dtype": "int", "default_value":0, "param_type":"optional" }, { "name": "rank_size", "dtype": "int", "default_value":0, "param_type":"optional" }, { "name": "is_gather_out", "dtype": "bool", "default_value":true, "param_type":"optional" } ] } ]
算子的Tiling Data结构需要按顺序完整包含Mc2Msg参数,如下为算子Tiling Data代码示例。
// 声明Mc2Msg结构 BEGIN_TILING_DATA_DEF(Mc2Msg) TILING_DATA_FIELD_DEF(uint32_t, preparePosition); TILING_DATA_FIELD_DEF(uint32_t, sendOff); TILING_DATA_FIELD_DEF(uint32_t, recvOff); TILING_DATA_FIELD_DEF(uint32_t, tailSendOff); TILING_DATA_FIELD_DEF(uint32_t, tailRecvOff); TILING_DATA_FIELD_DEF(uint64_t, sendCnt); TILING_DATA_FIELD_DEF(uint32_t, recvCnt); TILING_DATA_FIELD_DEF(uint32_t, tailSendCnt); TILING_DATA_FIELD_DEF(uint32_t, tailRecvCnt); TILING_DATA_FIELD_DEF(uint32_t, totalCnt); TILING_DATA_FIELD_DEF(uint32_t, turnNum); TILING_DATA_FIELD_DEF(uint32_t, tailNum); TILING_DATA_FIELD_DEF(uint32_t, stride); TILING_DATA_FIELD_DEF(uint32_t, workspaceOff); TILING_DATA_FIELD_DEF(uint32_t, notifyOff); TILING_DATA_FIELD_DEF(uint16_t, notifyBeginCnt); TILING_DATA_FIELD_DEF(uint16_t, notifyEndCnt); TILING_DATA_FIELD_DEF(uint8_t, useBufferType); TILING_DATA_FIELD_DEF(uint8_t, funID); TILING_DATA_FIELD_DEF(uint8_t, dataType); TILING_DATA_FIELD_DEF(uint8_t, groupNum); TILING_DATA_FIELD_DEF(uint8_t, reuseMode); TILING_DATA_FIELD_DEF(uint8_t, commType); TILING_DATA_FIELD_DEF(uint8_t, reduceOp); TILING_DATA_FIELD_DEF(uint8_t, commOrder); TILING_DATA_FIELD_DEF(uint8_t, waitPolicy); TILING_DATA_FIELD_DEF(uint8_t, rspPolicy); TILING_DATA_FIELD_DEF(uint8_t, exitPolicy); TILING_DATA_FIELD_DEF(uint8_t, commAlg); TILING_DATA_FIELD_DEF(uint8_t, taskType); TILING_DATA_FIELD_DEF(uint8_t, debugMode); TILING_DATA_FIELD_DEF(uint8_t, stepSize); TILING_DATA_FIELD_DEF(uint8_t, sendArgIndex); TILING_DATA_FIELD_DEF(uint8_t, recvArgIndex); TILING_DATA_FIELD_DEF(uint8_t, commOutArgIndex); TILING_DATA_FIELD_DEF(uint8_t, hasCommOut); TILING_DATA_FIELD_DEF(uint8_t, reserve); TILING_DATA_FIELD_DEF(uint32_t, reserve2); END_TILING_DATA_DEF; REGISTER_TILING_DATA_CLASS(Mc2MsgOp, Mc2Msg) BEGIN_TILING_DATA_DEF(AllGatherMatmulCustomTilingData) TILING_DATA_FIELD_DEF_STRUCT(Mc2Msg, msg); END_TILING_DATA_DEF;
// 配置Mc2Msg AllGatherMatmulCustomTilingData tiling; tiling.msg.set_preparePosition(1); tiling.msg.set_commAlg(1); tiling.msg.set_useBufferType(1); tiling.msg.set_hasCommOut(1);