Tiling Data
功能说明
AI Cpu启动下发通信任务前,需获取固定的通信配置Mc2Msg。在算子实现中,由Tiling组装通信配置项,通过配置固定参数和固定参数顺序的Tiling Data,将通信配置信息在调用AI Cpu通信接口时传递给AI Cpu。
参数说明
|
参数名 |
描述 |
|---|---|
|
preparePosition |
设置服务端组装任务的方式,用户需要在Tiling中显式赋值,当前支持的取值如下: 1:AI Cpu与AI Core通过通信任务机制实现消息传递和任务下发;由AI Core侧通过消息通知时设置为1,即算子中使用Hccl时设置为1。 |
|
sendOff |
预留参数,不可配置。 |
|
recvOff |
预留参数,不可配置。 |
|
tailSendOff |
预留参数,不可配置。 |
|
tailRecvOff |
预留参数,不可配置。 |
|
sendCnt |
预留参数,不可配置。 |
|
recvCnt |
预留参数,不可配置。 |
|
tailSendCnt |
预留参数,不可配置。 |
|
tailRecvCnt |
预留参数,不可配置。 |
|
totalCnt |
预留参数,不可配置。 |
|
turnNum |
预留参数,不可配置。 |
|
tailNum |
预留参数,不可配置。 |
|
stride |
预留参数,不可配置。 |
|
workspaceOff |
预留参数,不可配置。 |
|
notifyOff |
预留参数,不可配置。 |
|
notifyBeginCnt |
预留参数,不可配置。 |
|
notifyEndCnt |
预留参数,不可配置。 |
|
useBufferType |
设置通信算法获取输入数据的位置,参数取值如下:
|
|
funID |
预留参数,不可配置。 |
|
dataType |
预留参数,不可配置。 |
|
groupNum |
预留参数,不可配置。 |
|
reuseMode |
预留参数,不可配置。 |
|
commType |
预留参数,不可配置。 |
|
reduceOp |
预留参数,不可配置。 |
|
commOrder |
预留参数,不可配置。 |
|
waitPolicy |
预留参数,不可配置。 |
|
rspPolicy |
预留参数,不可配置。 |
|
exitPolicy |
预留参数,不可配置。 |
|
commAlg |
设置具体通信算法,用户需要在Tiling中显示赋值,当前支持的取值如下: 1:FullMesh算法,即NPU之间的全连接,任意两个NPU之间可以直接进行数据收发。详细的算法内容可参见集合通信算法。 |
|
taskType |
预留参数,不可配置。 |
|
debugMode |
预留参数,不可配置。 |
|
stepSize |
预留参数,不可配置。 |
|
sendArgIndex |
预留参数,不可配置。 |
|
recvArgIndex |
预留参数,不可配置。 |
|
commOutArgIndex |
预留参数,不可配置。 |
|
hasCommOut |
本卡的通信算法的计算结果是否输出到recvBuf(目的数据buffer地址)。仅AllGather算法与AlltoAll算法支持配置该参数。参数取值如下:
|
|
reserve |
保留字段。 |
|
reserve2 |
保留字段。 |
支持的型号
Atlas A2训练系列产品/Atlas 800I A2推理产品
注意事项
- 算子的Tiling Data结构需要按顺序完整包含Mc2Msg参数。
- AI Cpu需获取固定数据结构的通信配置,算子注册Tiling Data时保持该结构的一致性。
调用示例
以自定义算子AllGatherMatmulCustom为例,如下为该算子的算子原型,"gather_out"为通信任务AllGather的输出。
[
{
"op": "AllGatherMatmulCustom",
"input_dsec": [
{
"name": "x1",
"param_type": "required",
"format": [
"ND"
],
"type": [
"float16",
"bfloat16"
]
},
{
"name": "x2",
"param_type": "required",
"format": [
"ND"
],
"type": [
"float16",
"bfloat16"
]
},
{
"name": "bias",
"param_type": "optional",
"format": [
"ND"
],
"type": [
"float16",
"bfloat16"
]
}
],
"output_desc":[
{
"name": "y",
"param_type": "required",
"format": [
"ND"
],
"type": [
"float16",
"bfloat16"
]
},
{
"name": "gather_out",
"param_type": "required",
"format": [
"ND"
],
"type": [
"float16",
"bfloat16"
]
}
],
"attr": [
{
"name": "group",
"dtype": "string",
"default_value":"",
"param_type":"required"
},
{
"name": "is_trans_a",
"dtype": "bool",
"default_value":false,
"param_type":"optional"
},
{
"name": "is_trans_b",
"dtype": "bool",
"default_value":false,
"param_type":"optional"
},
{
"name": "gather_index",
"dtype": "int",
"default_value":0,
"param_type":"optional"
},
{
"name": "comm_turn",
"dtype": "int",
"default_value":0,
"param_type":"optional"
},
{
"name": "rank_size",
"dtype": "int",
"default_value":0,
"param_type":"optional"
},
{
"name": "is_gather_out",
"dtype": "bool",
"default_value":true,
"param_type":"optional"
}
]
}
]
算子的Tiling Data结构需要按顺序完整包含Mc2Msg参数,如下为算子Tiling Data代码示例。
// 声明Mc2Msg结构
BEGIN_TILING_DATA_DEF(Mc2Msg)
TILING_DATA_FIELD_DEF(uint32_t, preparePosition);
TILING_DATA_FIELD_DEF(uint32_t, sendOff);
TILING_DATA_FIELD_DEF(uint32_t, recvOff);
TILING_DATA_FIELD_DEF(uint32_t, tailSendOff);
TILING_DATA_FIELD_DEF(uint32_t, tailRecvOff);
TILING_DATA_FIELD_DEF(uint64_t, sendCnt);
TILING_DATA_FIELD_DEF(uint32_t, recvCnt);
TILING_DATA_FIELD_DEF(uint32_t, tailSendCnt);
TILING_DATA_FIELD_DEF(uint32_t, tailRecvCnt);
TILING_DATA_FIELD_DEF(uint32_t, totalCnt);
TILING_DATA_FIELD_DEF(uint32_t, turnNum);
TILING_DATA_FIELD_DEF(uint32_t, tailNum);
TILING_DATA_FIELD_DEF(uint32_t, stride);
TILING_DATA_FIELD_DEF(uint32_t, workspaceOff);
TILING_DATA_FIELD_DEF(uint32_t, notifyOff);
TILING_DATA_FIELD_DEF(uint16_t, notifyBeginCnt);
TILING_DATA_FIELD_DEF(uint16_t, notifyEndCnt);
TILING_DATA_FIELD_DEF(uint8_t, useBufferType);
TILING_DATA_FIELD_DEF(uint8_t, funID);
TILING_DATA_FIELD_DEF(uint8_t, dataType);
TILING_DATA_FIELD_DEF(uint8_t, groupNum);
TILING_DATA_FIELD_DEF(uint8_t, reuseMode);
TILING_DATA_FIELD_DEF(uint8_t, commType);
TILING_DATA_FIELD_DEF(uint8_t, reduceOp);
TILING_DATA_FIELD_DEF(uint8_t, commOrder);
TILING_DATA_FIELD_DEF(uint8_t, waitPolicy);
TILING_DATA_FIELD_DEF(uint8_t, rspPolicy);
TILING_DATA_FIELD_DEF(uint8_t, exitPolicy);
TILING_DATA_FIELD_DEF(uint8_t, commAlg);
TILING_DATA_FIELD_DEF(uint8_t, taskType);
TILING_DATA_FIELD_DEF(uint8_t, debugMode);
TILING_DATA_FIELD_DEF(uint8_t, stepSize);
TILING_DATA_FIELD_DEF(uint8_t, sendArgIndex);
TILING_DATA_FIELD_DEF(uint8_t, recvArgIndex);
TILING_DATA_FIELD_DEF(uint8_t, commOutArgIndex);
TILING_DATA_FIELD_DEF(uint8_t, hasCommOut);
TILING_DATA_FIELD_DEF(uint8_t, reserve);
TILING_DATA_FIELD_DEF(uint32_t, reserve2);
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(Mc2MsgOp, Mc2Msg)
BEGIN_TILING_DATA_DEF(AllGatherMatmulCustomTilingData)
TILING_DATA_FIELD_DEF_STRUCT(Mc2Msg, msg);
END_TILING_DATA_DEF;// 配置Mc2Msg AllGatherMatmulCustomTilingData tiling; tiling.msg.set_preparePosition(1); tiling.msg.set_commAlg(1); tiling.msg.set_useBufferType(1); tiling.msg.set_hasCommOut(1);