Tiling Data
功能说明
AI Cpu启动下发通信任务前,需获取固定的通信配置。在算子实现中,由Tiling组装通信配置项,通过配置固定参数和固定参数顺序的Tiling Data,将通信配置信息在调用AI Cpu通信接口时传递给AI Cpu。
参数说明
参数名 |
描述 |
---|---|
preparePosition |
设置服务端组装任务的方式,参数取值如下:
|
sendOff |
预留参数,无需配置。Host下发通信任务使用。 |
recvOff |
预留参数,无需配置。Host下发通信任务使用。 |
tailSendOff |
预留参数,无需配置。Host下发通信任务使用。 |
tailRecvOff |
预留参数,无需配置。Host下发通信任务使用。 |
sendCnt |
预留参数,无需配置。Host下发通信任务使用。 |
recvCnt |
预留参数,无需配置。Host下发通信任务使用。 |
tailSendCnt |
预留参数,无需配置。Host下发通信任务使用。 |
tailRecvCnt |
预留参数,无需配置。Host下发通信任务使用。 |
totalCnt |
预留参数,无需配置。Host下发通信任务使用。 |
turnNum |
预留参数,无需配置。Host下发通信任务使用。 |
tailNum |
预留参数,无需配置。Host下发通信任务使用。 |
stride |
预留参数,无需配置。Host下发通信任务使用。 |
workspaceOff |
预留参数,无需配置。Host下发通信任务使用。 |
notifyOff |
预留参数,无需配置。Host下发通信任务使用。 |
notifyBeginCnt |
预留参数,无需配置。Host下发通信任务使用。 |
notifyEndCnt |
预留参数,无需配置。Host下发通信任务使用。 |
useBufferType |
预留参数,无需配置。Host下发通信任务使用。 |
funID |
预留参数,无需配置。Host下发通信任务使用。 |
dataType |
预留参数,无需配置。Host下发通信任务使用。 |
groupNum |
预留参数,无需配置。Host下发通信任务使用。 |
reuseMode |
预留参数,无需配置。Host下发通信任务使用。 |
commType |
预留参数,无需配置。Host下发通信任务使用。 |
reduceOp |
预留参数,无需配置。Host下发通信任务使用。 |
commOrder |
预留参数,无需配置。Host下发通信任务使用。 |
waitPolicy |
预留参数,无需配置。Host下发通信任务使用。 |
rspPolicy |
预留参数,无需配置。Host下发通信任务使用。 |
exitPolicy |
预留参数,无需配置。Host下发通信任务使用。 |
commAlg |
设置具体通信算法,参数取值如下: |
taskType |
预留参数,无需配置。Host下发通信任务使用。 |
debugMode |
预留参数,无需配置。Host下发通信任务使用。 |
stepSize |
预留参数,无需配置。Host下发通信任务使用。 |
sendArgIndex |
算子中第一个输入的参数索引。 |
recvArgIndex |
算子中第一个输出的参数索引。 |
commOutArgIndex |
通信任务的输出在算子原型中的参数索引。 |
hasCommOut |
预留参数,无需配置。Host下发通信任务使用。 |
reserve |
保留字段。 |
reserve2 |
保留字段。 |
支持的型号
Atlas A2训练系列产品/Atlas 800I A2推理产品
注意事项
AI Cpu需获取固定数据结构的通信配置,算子注册Tiling Data时保持该结构的一致性。
调用示例
以自定义算子AllGatherMatmulCustom为例,如下为该算子的算子原型,"gather_out"为通信任务AllGather的输出。
[ { "op": "AllGatherMatmulCustom", "input_dsec": [ { "name": "x1", "param_type": "required", "format": [ "ND" ], "type": [ "float16", "bfloat16" ] }, { "name": "x2", "param_type": "required", "format": [ "ND" ], "type": [ "float16", "bfloat16" ] }, { "name": "bias", "param_type": "optional", "format": [ "ND" ], "type": [ "float16", "bfloat16" ] } ], "output_desc":[ { "name": "y", "param_type": "required", "format": [ "ND" ], "type": [ "float16", "bfloat16" ] }, { "name": "gather_out", "param_type": "required", "format": [ "ND" ], "type": [ "float16", "bfloat16" ] } ], "attr": [ { "name": "group", "dtype": "string", "default_value":"", "param_type":"required" }, { "name": "is_trans_a", "dtype": "bool", "default_value":false, "param_type":"optional" }, { "name": "is_trans_b", "dtype": "bool", "default_value":false, "param_type":"optional" }, { "name": "gather_index", "dtype": "int", "default_value":0, "param_type":"optional" }, { "name": "comm_turn", "dtype": "int", "default_value":0, "param_type":"optional" }, { "name": "rank_size", "dtype": "int", "default_value":0, "param_type":"optional" }, { "name": "is_gather_out", "dtype": "bool", "default_value":true, "param_type":"optional" } ] } ]
// 声明Mc2Msg结构 BEGIN_TILING_DATA_DEF(Mc2Msg) TILING_DATA_FIELD_DEF(uint32_t, preparePosition); TILING_DATA_FIELD_DEF(uint32_t, sendOff); TILING_DATA_FIELD_DEF(uint32_t, recvOff); TILING_DATA_FIELD_DEF(uint32_t, tailSendOff); TILING_DATA_FIELD_DEF(uint32_t, tailRecvOff); TILING_DATA_FIELD_DEF(uint64_t, sendCnt); TILING_DATA_FIELD_DEF(uint32_t, recvCnt); TILING_DATA_FIELD_DEF(uint32_t, tailSendCnt); TILING_DATA_FIELD_DEF(uint32_t, tailRecvCnt); TILING_DATA_FIELD_DEF(uint32_t, totalCnt); TILING_DATA_FIELD_DEF(uint32_t, turnNum); TILING_DATA_FIELD_DEF(uint32_t, tailNum); TILING_DATA_FIELD_DEF(uint32_t, stride); TILING_DATA_FIELD_DEF(uint32_t, workspaceOff); TILING_DATA_FIELD_DEF(uint32_t, notifyOff); TILING_DATA_FIELD_DEF(uint16_t, notifyBeginCnt); TILING_DATA_FIELD_DEF(uint16_t, notifyEndCnt); TILING_DATA_FIELD_DEF(uint8_t, useBufferType); TILING_DATA_FIELD_DEF(uint8_t, funID); TILING_DATA_FIELD_DEF(uint8_t, dataType); TILING_DATA_FIELD_DEF(uint8_t, groupNum); TILING_DATA_FIELD_DEF(uint8_t, reuseMode); TILING_DATA_FIELD_DEF(uint8_t, commType); TILING_DATA_FIELD_DEF(uint8_t, reduceOp); TILING_DATA_FIELD_DEF(uint8_t, commOrder); TILING_DATA_FIELD_DEF(uint8_t, waitPolicy); TILING_DATA_FIELD_DEF(uint8_t, rspPolicy); TILING_DATA_FIELD_DEF(uint8_t, exitPolicy); TILING_DATA_FIELD_DEF(uint8_t, commAlg); TILING_DATA_FIELD_DEF(uint8_t, taskType); TILING_DATA_FIELD_DEF(uint8_t, debugMode); TILING_DATA_FIELD_DEF(uint8_t, stepSize); TILING_DATA_FIELD_DEF(uint8_t, sendArgIndex); TILING_DATA_FIELD_DEF(uint8_t, recvArgIndex); TILING_DATA_FIELD_DEF(uint8_t, commOutArgIndex); TILING_DATA_FIELD_DEF(uint8_t, hasCommOut); TILING_DATA_FIELD_DEF(uint8_t, reserve); TILING_DATA_FIELD_DEF(uint32_t, reserve2); END_TILING_DATA_DEF; REGISTER_TILING_DATA_CLASS(Mc2MsgOp, Mc2Msg) BEGIN_TILING_DATA_DEF(AllGatherMatmulCustomTilingData) TILING_DATA_FIELD_DEF_STRUCT(Mc2Msg, msg); END_TILING_DATA_DEF; // 配置Mc2Msg AllGatherMatmulCustomTilingData tiling; tiling.msg.set_preparePosition(1); tiling.msg.set_commAlg(1); tiling.msg.set_sendArgIndex(0); // 设置算子原型中第一个输入数据的参数索引 tiling.msg.set_recvArgIndex (3); // 设置算子原型中第一个输出数据的参数索引 tiling.msg.set_commOutArgIndex(4); // 设置算子原型中通信输出数据的参数索引