Operator Implementation
MC² operators use HCCL high-level APIs for communication and MatMul high-level APIs for matrix multiplication. For more information about collective communication and related concepts, see HCCL User Guide. MC² operators are developed in the same way as common operators. However, MC² operators currently do not support Kernel Launch or Graph (GE) Development; they support only Single-Operator API Calling.
The following uses the AllGatherMatmulCustom operator (AllGatherMatmul for short) as an example to describe the design and implementation of MC² operators, covering operator analysis, data flow analysis, operator project creation, prototype definition, tiling implementation, kernel implementation, and build and run. For the complete code of this example, see AllGatherMatmul sample. This example can run only on the
Operator Analysis
Operator analysis specifies the mathematical expression, inputs, outputs, and kernel function name of an operator.
- Specify the mathematical expression and the communication and compute logic of the operator. The AllGatherMatmul operator fuses AllGather communication with MatMul matrix multiplication. Its logic is as follows: perform AllGather communication on the input communication matrix a to obtain the left matrix of the MatMul compute, that is, the communication result gather_out; then multiply gather_out by the right matrix b to obtain the output c. The corresponding mathematical expressions are:
$$\text{gather\_out} = \mathrm{AllGather}(a), \qquad c = \text{gather\_out} \times b$$
- Specify the input, output, and attributes.
- a and b are the source operands. a is the input matrix for communication, with shape [M, K]. b is the right matrix of MatMul, with shape [K, N]. In this example, M, K, and N are fixed at 512, 5120, and 640, respectively.
- gather_out is the destination operand that stores the AllGather communication result. Its shape is [M * rankDim, K], where rankDim is the number of cards in the communicator and is fixed at 8 in this example.
- c is the destination operand, which stores the MatMul compute result. Its shape is [M * rankDim, N].
- All inputs and outputs are of type float16 and use the ND format.
- The attribute group is the name of the communicator in which the operator runs.
- Define the kernel function name and parameters.
- In this example, the kernel function is named all_gather_matmul_custom.
- Based on the analysis of the operator inputs and outputs, the kernel function has the following parameters: aGM, bGM, cGM, and gatherOutGM. aGM and bGM are the global-memory addresses of the inputs, and cGM and gatherOutGM are the global-memory addresses of the outputs. Note that the kernel function's parameter names differ from the input and output names of the single-operator API call, because the kernel function parameters are global-memory addresses, whereas the single-operator API call takes its inputs and outputs as aclTensor objects.
- Specify the APIs required for operator implementation.
- The operator involves AllGather communication. According to the communication-related APIs in the Ascend C API reference, HCCL high-level APIs are called to implement AllGather communication.
- The left and right MatMul matrices must be moved between global memory and local memory. For details about the data movement API, see the Ascend C API reference. Call DataCopy to move data.
- The compute process involves Cube compute. For details about the Cube compute API, see the Ascend C API reference. Call MatMul high-level APIs to implement matrix multiplication.
| OpType | AllGatherMatmulCustom | | | |
|---|---|---|---|---|
| Operator input and output | Name | Shape | Data Type | Format |
| | a | [512, 5120] | float16 | ND |
| | b | [5120, 640] | float16 | ND |
| | c | [4096, 640] | float16 | ND |
| | gather_out | [4096, 5120] | float16 | ND |
| Operator attribute | group (char*): host-side string indicating the communicator name | | | |
| Kernel function name | all_gather_matmul_custom | | | |
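To make the semantics in the table concrete, the following host-side C++ sketch computes AllGatherMatmul on a CPU. It is a reference model only, not the Ascend C implementation: it assumes row-major `std::vector` storage, uses `float` in place of float16, and the helpers `allGather`, `matmul`, and `allGatherMatmulReference` are hypothetical names introduced for illustration.

```cpp
#include <cstddef>
#include <vector>

using Matrix = std::vector<float>;  // row-major storage; float stands in for float16

// AllGather: concatenate the M x K input of every rank in rank-ID order,
// giving the (M * rankDim) x K matrix gather_out that every rank receives.
Matrix allGather(const std::vector<Matrix>& perRankA) {
    Matrix gatherOut;
    for (const Matrix& a : perRankA) {
        gatherOut.insert(gatherOut.end(), a.begin(), a.end());
    }
    return gatherOut;
}

// Plain matrix product: x is rows x k, y is k x n, the result is rows x n.
Matrix matmul(const Matrix& x, const Matrix& y, std::size_t rows, std::size_t k, std::size_t n) {
    Matrix z(rows * n, 0.0f);
    for (std::size_t i = 0; i < rows; ++i)
        for (std::size_t p = 0; p < k; ++p)
            for (std::size_t j = 0; j < n; ++j)
                z[i * n + j] += x[i * k + p] * y[p * n + j];
    return z;
}

// c = AllGather(a) * b; every rank computes the same (M * rankDim) x N result.
Matrix allGatherMatmulReference(const std::vector<Matrix>& perRankA, const Matrix& b,
                                std::size_t M, std::size_t K, std::size_t N) {
    Matrix gatherOut = allGather(perRankA);
    return matmul(gatherOut, b, M * perRankA.size(), K, N);
}
```

With the shapes in the table, M * rankDim = 512 * 8 = 4096, which is why both c and gather_out have 4096 rows.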
Data Flow Analysis
The AllGatherMatmul operator performs AllGather communication between cards and MatMul compute within each card. Communication and compute are performed in multiple rounds over the tiles and tail block produced by data tiling, and they are pipelined so that they overlap. For this analysis, assume that the communication matrix is tiled along the M axis into two tiles (tileCnt = 2) and one tail block (tailCnt = 1). The following figure shows the communication-compute overlapping.
AllGather collects the inputs of all cards in a communicator, orders them by card ID, concatenates them, and delivers the result to every card. The AllGather result therefore already contains the local card's own data, that is, the communication matrix a input by the local card. The operator does not need to wait for this part of the data to be communicated, nor does it need to tile it; it can perform MatMul compute directly on the complete local matrix a. The AllGatherMatmul operator therefore computes on the local card data first, so that this compute overlaps with the communication of tile 1. In addition, the compute of tile 1, tile 2, and tail block 1 no longer includes the local card data, which reduces the compute workload of the subsequent tiles and tail block, improves the communication-compute overlap ratio, and boosts overall performance.
Note that not all MC² operators benefit from computing on the local card data first. Because communication precedes compute in AllGatherMatmul, computing on the local card data first lets it overlap with the first communication round. For operators where compute precedes communication, such as MatmulAllReduce, it is recommended to compute on the local card data last, overlapping with the last communication round, as shown in the following figure.
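The benefit of computing the local card data first can be illustrated with a toy timing model. The C++ sketch below is an assumption-laden illustration, not a performance model of real hardware. It assumes that communication rounds run back to back, that the AI Core computes one round at a time, that a round's compute cannot start before that round's communication finishes, and that the durations are made-up numbers. `overlappedMakespan` and `serialMakespan` are hypothetical names.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// comm[i]: duration of communication round i (tiles first, then the tail block).
// compute[i]: AI Core time for the other cards' data received in round i.
// localCompute: AI Core time for the local card data, which needs no communication.
double overlappedMakespan(const std::vector<double>& comm, const std::vector<double>& compute,
                          double localCompute, bool localFirst) {
    double commDone = 0.0;                              // finish time of the latest communication round
    double coreFree = localFirst ? localCompute : 0.0;  // time at which the AI Core becomes idle
    for (std::size_t i = 0; i < comm.size(); ++i) {
        commDone += comm[i];
        // Round i can start only once its data has arrived and the core is idle.
        coreFree = std::max(coreFree, commDone) + compute[i];
    }
    return localFirst ? coreFree : coreFree + localCompute;
}

// No overlap at all: every communication round first, then every compute step.
double serialMakespan(const std::vector<double>& comm, const std::vector<double>& compute,
                      double localCompute) {
    return std::accumulate(comm.begin(), comm.end(), 0.0) +
           std::accumulate(compute.begin(), compute.end(), 0.0) + localCompute;
}
```

With three rounds (two tiles and one tail block), communication times {10, 10, 5}, compute times {7, 7, 3}, and a local compute time of 6, computing the local data first finishes at 30 time units. Computing it last finishes at 36, and the fully serial schedule finishes at 48. In this model the local-first schedule wins because the local compute fills the idle time while the first round is in flight.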
AllGatherMatmul operator logic analysis:
- The AI Core delivers a task by writing the communication information to be executed to the message area in global memory. The message area is a region of global memory at a specific address. The AI Core and AI CPU exchange messages by writing to and polling the message area. These operations are encapsulated in the HCCL high-level APIs.
Figure 3 Communication process of MC² operators
- The AI CPU reads all communication task information from the message area and starts the first round of AllGather collective communication over links such as Huawei Cache Coherency System (HCCS), which provides high-speed interconnection between CPUs and NPUs, or RDMA over Converged Ethernet (RoCE), which carries RDMA communication over Ethernet. At the same time, the AI Core starts MatMul compute on the local card data.
The following figure shows the first round of communication and the local card compute when the number of communication cards is 4. Tile 1 indicates the processing flow of the first round of communication and the matrix multiplication that overlaps with it. In the figure, a label of the form X-Y on a tiled sub-matrix indicates that the sub-matrix is the Yth data block of card X.
Figure 4 First round of AllGatherMatmul communication and matrix multiplication of the local card data on rank0
- After completing the first round of communication, the AI CPU writes a message to the message area indicating that the round is complete, and starts the second round of communication. Meanwhile, after finishing the MatMul compute on the local card data, the AI Core polls the message area until it sees that the first round of communication is complete, and then performs MatMul compute on the first-round communication result, that is, tile 1. The following figure shows the second round of communication and the rank0 compute when the number of communication cards is 4. Tile 2 indicates the processing flow of the second round of communication and the matrix multiplication that overlaps with it.
Figure 5 Second round of AllGatherMatmul communication and matrix multiplication of tile 1 on rank0
- Repeat the previous step round by round until the communication and compute of all remaining data blocks are complete.
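The message-area handshake in the steps above can be illustrated with an ordinary host-side analogy. This C++ sketch is not how HCCL is implemented. A `std::atomic` counter stands in for the message area. One thread plays the AI CPU and publishes "round r complete" after each round. The main thread plays the AI Core: it computes on the local data immediately and polls the counter before computing on each round's data. `ComputeEvent` and `simulateRounds` are hypothetical names.

```cpp
#include <atomic>
#include <string>
#include <thread>
#include <vector>

// One entry per compute step performed by the "AI Core".
struct ComputeEvent {
    std::string block;   // which data the step computed on
    int roundsDoneSeen;  // communication rounds known complete when the step started
};

// Host-side analogy only: an atomic counter plays the role of the message area.
std::vector<ComputeEvent> simulateRounds(int rounds) {
    std::atomic<int> roundsDone{0};
    std::vector<ComputeEvent> log;

    // "AI CPU": executes the communication rounds in order and reports each completion.
    std::thread aiCpu([&roundsDone, rounds] {
        for (int r = 1; r <= rounds; ++r) {
            std::this_thread::yield();                       // stand-in for moving round r's data
            roundsDone.store(r, std::memory_order_release);  // "round r complete"
        }
    });

    // "AI Core": local card data needs no communication, so compute on it immediately.
    log.push_back({"local", roundsDone.load(std::memory_order_acquire)});
    for (int r = 1; r <= rounds; ++r) {
        int seen = roundsDone.load(std::memory_order_acquire);
        while (seen < r) {  // poll until round r is reported complete
            std::this_thread::yield();
            seen = roundsDone.load(std::memory_order_acquire);
        }
        log.push_back({"round " + std::to_string(r), seen});
    }
    aiCpu.join();
    return log;
}
```

Thread timing varies from run to run, so the exact `roundsDoneSeen` values are not fixed. The invariant is fixed: the local compute comes first, and the compute for round r never starts before round r is reported complete.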
Operator Project Creation
The procedure for creating an operator project for an MC² operator is the same as that for a common operator. For details, see Operator Project Creation. This example uses the custom operator project generation tool msOpGen to create an operator project for the AllGatherMatmul operator based on the following prototype definition JSON file.
```json
[
    {
        "op": "AllGatherMatmulCustom",
        "input_desc": [
            {
                "name": "a",
                "param_type": "required",
                "format": ["ND"],
                "type": ["float16"]
            },
            {
                "name": "b",
                "param_type": "required",
                "format": ["ND"],
                "type": ["float16"]
            }
        ],
        "output_desc": [
            {
                "name": "c",
                "param_type": "required",
                "format": ["ND"],
                "type": ["float16"]
            },
            {
                "name": "gather_out",
                "param_type": "required",
                "format": ["ND"],
                "type": ["float16"]
            }
        ],
        "attr": [
            {
                "name": "group",
                "type": "string",
                "default_value": "",
                "param_type": "required"
            }
        ]
    }
]
```
Operator Prototype Definition
Compared with common operators, the prototype definition of an MC² operator has the following additional requirements:
- Define an attribute indicating the operator communicator name. The communicator is a context for implementing collective communication. It manages the corresponding communication entities (for example, NPUs) and the resources required for communication.
- In the prototype registration, register the operator as an MC² operator through MC2(), and set its communicator name through HcclGroup.
The AllGatherMatmul operator uses the "group" attribute to indicate its communicator name. The attribute is defined in the operator prototype as follows:
```cpp
// "group" is an attribute of the MC² operator, indicating the communicator name.
// The String type in the prototype definition corresponds to the char* type in the single-operator API.
this->Attr("group").AttrType(REQUIRED).String();
// ...
this->MC2().HcclGroup("group");  // Set the "group" attribute as the communicator name of the operator.
```
The complete prototype definition of the AllGatherMatmul operator is as follows:
```cpp
namespace ops {
class AllGatherMatmulCustom : public OpDef {
public:
    explicit AllGatherMatmulCustom(const char *name) : OpDef(name)
    {
        this->Input("a")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});
        this->Input("b")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND})
            .IgnoreContiguous();
        this->Output("c")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});
        this->Output("gather_out")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});
        this->Attr("group").AttrType(REQUIRED).String();

        this->AICore().SetTiling(AllGatherMatmulCustomTilingFunc);  // Register AllGatherMatmulCustomTilingFunc as the tiling entry function.
        this->AICore().AddConfig("ascendxxx");  // Replace ascendxxx with the actual Ascend AI Processor.
        this->MC2().HcclGroup("group");
    }
};

OP_ADD(AllGatherMatmulCustom);
}
```
Tiling Implementation
The tiling strategy of an MC² operator consists mainly of the communication tiling strategy and the MatMul multi-core and intra-core tiling strategies.
- Communication tiling strategy: The size of the data block communicated in each round has a significant impact on the performance of MC² operators. In this example, the M axis of communication matrix A is split into tiles whose M-axis length is 448. For details about how to choose a tiling strategy for a specific scenario, see MC² Operator Performance Tuning Cases.
- Matmul multi-core tiling and intra-core tiling:
- Multi-core tiling: Based on the number of available cores, split the input shape (M, K, N) across multiple cores, yielding the single-core shapes singleCoreM, singleCoreK, and singleCoreN.
- Intra-core tiling: Based on the local memory size limits, further split the single-core shapes to obtain baseM, baseN, and baseK, the shapes of the matrices A, B, and C involved in a single matrix multiplication instruction.
As described above, the communication matrix is split into tiles and a tail block, and MatMul compute must be performed separately on the communication results of the tiles, on the communication result of the tail block, and on the local card data. As shown in the following figure, the M-axis lengths of the tile, the tail block, and the local card data are tileM, tailM, and rankM, respectively, so the left matrix takes three different shapes during MatMul compute. Therefore, use the sizes of the tile, the tail block, and the local card data block as the original input shapes of the matrix multiplication, and call the Tiling APIs of the MatMul high-level APIs once for each to obtain the corresponding multi-core and intra-core tiling strategies. For details about the concepts and principles of singleCoreM and baseM, see Basics.
Figure 6 Matrix multiplication of the AllGatherMatmul operator on rank0
The procedure for tiling implementation is as follows:
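The arithmetic behind the three left-matrix shapes can be checked with a small host-side C++ sketch. It assumes the example's values (rankM = 512, K = 5120, TILE_M = 448) and 2-byte float16 elements. `CommTiling`, `computeCommTiling`, `eleCnt`, and `byteSize` are hypothetical helpers; the division and modulo logic mirrors the tiling code shown in the procedure.

```cpp
#include <cstdint>

// Tile length along M, as in this example.
constexpr uint32_t TILE_M = 448;

struct CommTiling {
    uint32_t tileNum;  // number of full tiles along M
    uint32_t tailM;    // M-axis length of the tail block (0 if none)
    uint32_t tailNum;  // number of tail blocks: 0 or 1
};

// Split one card's rankM rows into full tiles plus an optional tail block.
constexpr CommTiling computeCommTiling(uint32_t rankM, uint32_t tileM = TILE_M) {
    return CommTiling{rankM / tileM, rankM % tileM, (rankM % tileM == 0) ? 0u : 1u};
}

// Element and byte counts of an m x k float16 block.
constexpr uint64_t kHalfBytes = 2;
constexpr uint64_t eleCnt(uint64_t m, uint64_t k) { return m * k; }
constexpr uint64_t byteSize(uint64_t m, uint64_t k) { return m * k * kHalfBytes; }

// Example values: one 448-row tile and one 64-row tail block per card.
static_assert(computeCommTiling(512).tileNum == 1, "one full 448-row tile");
static_assert(computeCommTiling(512).tailM == 64, "512 - 448 = 64-row tail");
static_assert(computeCommTiling(512).tailNum == 1, "one tail block");
static_assert(eleCnt(448, 5120) + eleCnt(64, 5120) == eleCnt(512, 5120), "tile + tail cover rankM");
```

So with the example's real values, the three MatMul shapes for the left matrix have 512 (local data), 448 (tile), and 64 (tail) rows. The two-tile layout used in Data Flow Analysis is only an illustrative assumption. A tile of the communication matrix holds 448 * 5120 = 2,293,760 elements, or 4,587,520 bytes.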
- Define the tiling structure of the AllGatherMatmul operator.
The tiling structure of an MC² operator that fuses communication and MatMul generally includes the following three parts:
- Tiling structures of HCCL high-level APIs: define the Mc2InitTiling and Mc2CcTiling parameters. Mc2InitTiling initializes the communication task configuration and must be defined as the first member of the operator tiling structure. Mc2CcTiling holds the parameter configuration of a single communication task. Since the AllGatherMatmul operator contains only one communication task (AllGather), only one Mc2CcTiling parameter needs to be defined.
- TCubeTiling structures of MatMul high-level APIs. Generally, the shapes of the tile, tail block, and local card data block differ. Since a TCubeTiling structure can store the tiling result for only one input shape, separate tiling structures must be defined for the tile, tail block, and local card data block to store their respective multi-core and intra-core tiling strategies.
- Custom structure AllGatherMatmulTiling required by the AllGatherMatmul operator.
The complete tiling structure of the AllGatherMatmul operator is defined as follows:
```cpp
struct AllGatherMatmulTiling {
    uint32_t rankM;    // Length of the M-axis of matrix A
    uint32_t rankN;    // Length of the N-axis of matrix B
    uint32_t rankK;    // Length of the K-axis of matrices A and B
    uint32_t tileNum;  // Number of tiles
    uint32_t tailM;    // Length of the M-axis of the tail block
    uint32_t tailNum;  // Number of tail blocks (0 or 1)
};

class AllGatherMatmulCustomTilingData {
public:
    Mc2InitTiling mc2InitTiling;
    Mc2CcTiling mc2CcTiling;
    TCubeTiling localTiling;
    TCubeTiling tileTiling;
    TCubeTiling tailTiling;
    AllGatherMatmulTiling cfg;
};
```
- Obtain the object pointer to the tiling structure of the AllGatherMatmul operator.
```cpp
AllGatherMatmulCustomTilingData *tiling = context->GetTilingData<AllGatherMatmulCustomTilingData>();
```
context is the object pointer to TilingContext. The framework passes this pointer to the registered tiling entry function AllGatherMatmulCustomTilingFunc, and it holds the context of the operator's tiling compute. In the tiling implementation of the AllGatherMatmul operator, use context to obtain the parameters needed for tiling, such as the input and output shapes and the input attributes, and to save the tiling results (for example, TilingKey and TilingData) for subsequent operator execution.
- Set the custom tiling structure parameters of the operator.
```cpp
tiling->cfg.tileNum = rankM / TILE_M;  // TILE_M is the constant 448 in this example: the M-axis length of a communication tile.
tiling->cfg.tailM = rankM % TILE_M;
tiling->cfg.tailNum = (rankM % TILE_M == 0) ? 0 : 1;
tiling->cfg.rankM = rankM;
tiling->cfg.rankN = rankN;
tiling->cfg.rankK = rankK;
```
- Set the tiling structures of the MatMul high-level APIs. Use matmul_tiling::MultiCoreMatmulTiling to obtain the TCubeTiling structure: first, create a multi-core tiling object mmTiling; second, set the type information of A, B, and C and the M, N, and K shape information; finally, call the GetTiling API to obtain the tiling information. For details, see MatMul Tiling Class. In the AllGatherMatmul operator, this logic is encapsulated in the matmulTilingFunc function, which is then called with the shapes of the tile, the tail block, and the local card data to obtain the corresponding TCubeTiling parameters.
```cpp
// Encapsulate the setup of a TCubeTiling structure in matmulTilingFunc.
// aicCoreNum and L1_BUFFER_SIZE are defined outside this excerpt.
auto matmulTilingFunc = [&](int64_t m, int64_t n, int64_t k, TCubeTiling &cubeTiling) -> bool {
    matmul_tiling::MultiCoreMatmulTiling mmTiling;
    mmTiling.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16);
    mmTiling.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16);
    mmTiling.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16);
    mmTiling.SetBias(false);
    mmTiling.SetDim(aicCoreNum);
    mmTiling.SetShape(m, n, k);
    mmTiling.SetOrgShape(m, n, k);
    mmTiling.SetBufferSpace(L1_BUFFER_SIZE);
    if (mmTiling.GetTiling(cubeTiling) != 0) {
        return false;
    }
    return true;
};

// Set the MatMul TCubeTiling structure of the local card data.
if (!matmulTilingFunc(rankM, rankN, rankK, tiling->localTiling)) {
    ERROR_LOG("Get local matmul tiling failed");
    return ge::GRAPH_FAILED;
}
// Set the MatMul TCubeTiling structure of the tile.
if (!matmulTilingFunc(TILE_M, rankN, rankK, tiling->tileTiling)) {
    ERROR_LOG("Get tile matmul tiling failed");
    return ge::GRAPH_FAILED;
}
// Set the MatMul TCubeTiling structure of the tail block.
if (!matmulTilingFunc(rankM % TILE_M, rankN, rankK, tiling->tailTiling)) {
    ERROR_LOG("Get tail matmul tiling failed");
    return ge::GRAPH_FAILED;
}
```
- Set the tiling structures of the HCCL high-level APIs. Create an Mc2CcTilingConfig object based on the communication task type and algorithm configuration. Then pass references to the mc2InitTiling and mc2CcTiling members of the operator tiling structure to its GetTiling method to obtain the Mc2InitTiling and Mc2CcTiling parameters that are passed to the kernel. For details about how to use the tiling structures of HCCL high-level APIs, see HCCL Tiling Usage Description.
```cpp
uint32_t opType = HCCL_CMD_ALLGATHER;  // Communication task type.
std::string algConfig = "AllGather=level0:doublering";  // Communication algorithm. This parameter is reserved and has no effect if configured.
uint32_t reduceType = HCCL_REDUCE_SUM;  // Reduction type. Valid only for tasks with a reduction; AllGather can keep the default HCCL_REDUCE_SUM.
AscendC::Mc2CcTilingConfig mc2CcTilingConfig(group, opType, algConfig, reduceType);
mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling);
mc2CcTilingConfig.SetSkipLocalRankCopy(0);  // The output gatherOut must contain matrix A of the local card, so set this to 0.
mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling);
```
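The requirement that Mc2InitTiling be the first member matters because the kernel passes the address of the whole tiling structure to hccl.InitV2 and locates Mc2CcTiling with offsetof(AllGatherMatmulCustomTilingData, mc2CcTiling). Presumably InitV2 therefore expects the init tiling at the structure's base address. The following C++ sketch checks that layout at compile time. The `...StandIn` structs are hypothetical placeholders with arbitrary sizes, and `TilingDataSketch` only mirrors the member order; the real types come from the CANN headers.

```cpp
#include <cstddef>
#include <cstdint>
#include <type_traits>

// Hypothetical stand-ins with arbitrary sizes; the real Mc2InitTiling,
// Mc2CcTiling, and TCubeTiling are defined by CANN and differ in content.
struct Mc2InitTilingStandIn { uint8_t reserved[64]; };
struct Mc2CcTilingStandIn { uint8_t reserved[280]; };
struct TCubeTilingStandIn { int32_t fields[48]; };

struct AllGatherMatmulTiling {
    uint32_t rankM, rankN, rankK, tileNum, tailM, tailNum;
};

// Same member order as AllGatherMatmulCustomTilingData.
struct TilingDataSketch {
    Mc2InitTilingStandIn mc2InitTiling;  // must be first: read from the structure's base address
    Mc2CcTilingStandIn mc2CcTiling;      // located by the kernel through offsetof
    TCubeTilingStandIn localTiling;
    TCubeTilingStandIn tileTiling;
    TCubeTilingStandIn tailTiling;
    AllGatherMatmulTiling cfg;
};

// offsetof is only guaranteed to work on standard-layout types.
static_assert(std::is_standard_layout<TilingDataSketch>::value, "offsetof needs standard layout");
// Moving another member in front of mc2InitTiling would make this check fail.
static_assert(offsetof(TilingDataSketch, mc2InitTiling) == 0, "Mc2InitTiling must come first");
```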
Kernel Implementation
In the kernel implementation of the AllGatherMatmul operator, MatMul compute is performed on left matrices of three shapes: the local card data, the communication tiles, and the communication tail block. To avoid duplicated code, a general MatMul compute function that works for all of these input shapes is factored out. Designing this function requires the following basic information for MatMul compute:
- Addresses of input matrices A and B and output matrix C.
- TCubeTiling structure: contains information such as the shapes and data types of matrices A, B, and C, and the inter-core and intra-core tiling strategies for performing MatMul compute on matrices A and B.
The compute itself is performed by a MatMul object from the MatMul high-level APIs, which is the quickest way to implement matrix multiplication. If the MatMul object were defined inside the MatMul compute function, it would be instantiated and released on every call, adding significant runtime overhead. The object is therefore passed in as a parameter of the MatMul compute function so that it can be reused across calls.
In summary, the MatMul compute function shared by all input shapes in the kernel implementation is declared as follows. It is named MatmulKernel. The parameters aGM, bGM, and cGM are the addresses of the original input and output matrices, tiling is the TCubeTiling structure, and mm is an instance of the MatMul high-level API implementation class. MATMUL_TYPE is a type alias for a specialization of the MatmulType template.
```cpp
using MATMUL_TYPE = MatmulType<AscendC::TPosition::GM, CubeFormat::ND, half>;
__aicore__ inline void MatmulKernel(GM_ADDR aGM, GM_ADDR bGM, GM_ADDR cGM, TCubeTiling &tiling,
                                    Matmul<MATMUL_TYPE, MATMUL_TYPE, MATMUL_TYPE> &mm)
```
The MatmulKernel function is implemented as follows:
- TCubeTiling records the number of cores used for MatMul compute (usedCoreNum). Cores whose block index is not less than this value have nothing to compute and return immediately.
```cpp
if (GetBlockIdx() >= tiling.usedCoreNum) {
    return;
}
```
- MatMul high-level APIs require that GlobalTensor be used as the input and output matrices. Therefore, three GlobalTensors, aGlobal, bGlobal, and cGlobal, are defined based on the addresses of the input matrices A, B, and C in the global memory.
```cpp
GlobalTensor<half> aGlobal, bGlobal, cGlobal;
aGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(aGM), tiling.M * tiling.Ka);
bGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(bGM), tiling.Ka * tiling.N);
cGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(cGM), tiling.M * tiling.N);
```
- To improve efficiency through multi-core parallelism, the matrix data is tiled and distributed to different cores. The following figure shows a strategy that tiles only the M and N axes, not the K axis. In this scenario, each core computes the offsets of the data it processes relative to the original matrices and passes the offset matrices A, B, and C as input parameters. In addition, to handle tail blocks after core division, each core computes the singleCoreM and singleCoreN sizes it actually processes, which are set through the MatMul high-level API in the next step.
Figure 7 MatMul compute core division
```cpp
int mSingleBlocks = (tiling.M + tiling.singleCoreM - 1) / tiling.singleCoreM;
int mCoreIndex = GetBlockIdx() % mSingleBlocks;
int nCoreIndex = GetBlockIdx() / mSingleBlocks;
// Compute the offsets of the current core's data relative to the original matrices.
int offsetA = mCoreIndex * tiling.Ka * tiling.singleCoreM;
int offsetB = nCoreIndex * tiling.singleCoreN;
int offsetC = mCoreIndex * tiling.N * tiling.singleCoreM + nCoreIndex * tiling.singleCoreN;
// Compute singleCoreM/singleCoreN of the current core, used as input parameters of the subsequent SetTail call.
int tailM = Std::min(tiling.M - mCoreIndex * tiling.singleCoreM, tiling.singleCoreM);
int tailN = Std::min(tiling.N - nCoreIndex * tiling.singleCoreN, tiling.singleCoreN);
```
- Call Matmul high-level APIs to set the original complete shape for MatMul compute, the addresses of the input and output matrices processed by the current core, and the sizes of the computed singleCoreM and singleCoreN, and perform matrix multiplication.
```cpp
mm.SetOrgShape(tiling.M, tiling.N, tiling.Ka, tiling.Kb);
mm.SetTensorA(aGlobal[offsetA]);
mm.SetTensorB(bGlobal[offsetB]);
mm.SetTail(tailM, tailN);
mm.IterateAll(cGlobal[offsetC]);
```
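The core-division arithmetic in MatmulKernel can be replayed on the host to confirm that every element of C is written by exactly one core. This C++ sketch copies the offset and tail formulas above. It assumes that usedCoreNum equals the number of (M, N) blocks, whereas on the device that value comes from the MatMul tiling API. `Region`, `coreRegions`, and `coverage` are hypothetical helpers.

```cpp
#include <algorithm>
#include <vector>

struct Region {
    int offsetA, offsetB, offsetC;  // element offsets into A, B, and C
    int tailM, tailN;               // rows and columns this core actually computes
};

// Replays the per-core arithmetic of MatmulKernel for every participating core.
std::vector<Region> coreRegions(int M, int N, int Ka, int singleCoreM, int singleCoreN) {
    int mSingleBlocks = (M + singleCoreM - 1) / singleCoreM;
    int nSingleBlocks = (N + singleCoreN - 1) / singleCoreN;
    int usedCoreNum = mSingleBlocks * nSingleBlocks;  // assumption: one core per (M, N) block
    std::vector<Region> regions;
    for (int blockIdx = 0; blockIdx < usedCoreNum; ++blockIdx) {
        int mCoreIndex = blockIdx % mSingleBlocks;
        int nCoreIndex = blockIdx / mSingleBlocks;
        Region r;
        r.offsetA = mCoreIndex * Ka * singleCoreM;
        r.offsetB = nCoreIndex * singleCoreN;
        r.offsetC = mCoreIndex * N * singleCoreM + nCoreIndex * singleCoreN;
        r.tailM = std::min(M - mCoreIndex * singleCoreM, singleCoreM);
        r.tailN = std::min(N - nCoreIndex * singleCoreN, singleCoreN);
        regions.push_back(r);
    }
    return regions;
}

// Counts how many cores write each element of the row-major M x N output C.
std::vector<int> coverage(int M, int N, const std::vector<Region>& regions) {
    std::vector<int> hits(M * N, 0);
    for (const Region& r : regions)
        for (int i = 0; i < r.tailM; ++i)
            for (int j = 0; j < r.tailN; ++j)
                ++hits[r.offsetC + i * N + j];
    return hits;
}
```

For example, M = 100 and N = 70 with 32 x 32 single-core blocks give 4 x 3 = 12 cores. The last core handles only a 4 x 6 tail region, and the SetTail call exists to pass exactly those sizes.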
The kernel function of the AllGatherMatmul operator is defined as follows. The meanings of the aGM, bGM, cGM, and gatherOutGM parameters are described in Operator Analysis. workspaceGM and tilingGM indicate the addresses of the workspace and tiling data in the global memory, respectively.
```cpp
extern "C" __global__ __aicore__ void all_gather_matmul_custom(GM_ADDR aGM, GM_ADDR bGM, GM_ADDR cGM,
                                                               GM_ADDR gatherOutGM, GM_ADDR workspaceGM,
                                                               GM_ADDR tilingGM)
```
The following describes the implementation procedure of the AllGatherMatmul operator.
- MatMul compute runs on the AI Cube cores (AICs), so the operator's control logic runs only on AICs. Use the ASCEND_IS_AIV macro to check whether the current core is an AI Vector core (AIV); if it is, return immediately.
```cpp
if ASCEND_IS_AIV {
    return;
}
```
- Register an operator tiling structure, obtain data from the tiling structure, and initialize TPipe.
```cpp
REGISTER_TILING_DEFAULT(AllGatherMatmulCustomTilingData);
GET_TILING_DATA(tilingData, tilingGM);
TPipe pipe;
```
- Define and assign values to variables required for subsequent compute.
```cpp
auto &&localTiling = tilingData.localTiling;
auto &&tileTiling = tilingData.tileTiling;
auto &&tailTiling = tilingData.tailTiling;
const auto tileNum = tilingData.cfg.tileNum;  // Number of tiles
const auto tailNum = tilingData.cfg.tailNum;  // Number of tail blocks
const auto aTileEleCnt = tileTiling.M * tileTiling.Ka;                 // Number of elements in a tile of the communication matrix
const auto aTileSize = tileTiling.M * tileTiling.Ka * sizeof(half);    // Number of bytes in a tile of the communication matrix
const auto cTileSize = tileTiling.M * tileTiling.N * sizeof(half);     // Number of bytes in the output matrix corresponding to a tile
const auto aTailEleCnt = tailTiling.M * tailTiling.Ka;                 // Number of elements in the tail block of the communication matrix
const auto aRankEleCnt = localTiling.M * localTiling.Ka;               // Number of elements in the communication matrix of a single card
const auto aRankSize = localTiling.M * localTiling.Ka * sizeof(half);  // Number of bytes in the communication matrix of a single card
const auto cRankSize = localTiling.M * localTiling.N * sizeof(half);   // Number of bytes in the output matrix of a single card
```
- Initialize the HCCL object and deliver the AllGather communication task.
```cpp
Hccl hccl;
GM_ADDR contextGM = GetHcclContext<HCCL_GROUP_ID_0>();
hccl.InitV2(contextGM, &tilingData);
hccl.SetCcTilingV2(offsetof(AllGatherMatmulCustomTilingData, mc2CcTiling));
// Deliver tileNum rounds of tile communication, then the communication of the tail block.
auto handleId = hccl.AllGather<true>(aGM, gatherOutGM, aTileEleCnt, HcclDataType::HCCL_DATA_TYPE_FP16,
                                     aRankEleCnt, tileNum);
auto tailHandleId = hccl.AllGather<true>(aGM + tileNum * aTileSize, gatherOutGM + tileNum * aTileSize,
                                         aTailEleCnt, HcclDataType::HCCL_DATA_TYPE_FP16, aRankEleCnt, tailNum);
```
- Initialize the MatMul object and perform MatMul compute on the local card data.
```cpp
Matmul<MATMUL_TYPE, MATMUL_TYPE, MATMUL_TYPE> mm;
REGIST_MATMUL_OBJ(GetTPipePtr(), GetSysWorkSpacePtr(), mm);
mm.Init(&localTiling);
// The local card's block of c starts at rankId * cRankSize.
MatmulKernel(aGM, bGM, cGM + hccl.GetRankId() * cRankSize, localTiling, mm);
```
- Wait for the tile communication to complete round by round and perform MatMul compute on it.
```cpp
auto aAddr = gatherOutGM;
auto cAddr = cGM;
mm.Init(&tileTiling);
for (uint32_t i = 0; i < tileNum; i++) {
    hccl.Wait(handleId);
    for (uint32_t rankId = 0; rankId < hccl.GetRankDim(); rankId++) {
        if (rankId == hccl.GetRankId())
            continue;  // The local card data was already computed.
        MatmulKernel(aAddr + rankId * aRankSize, bGM, cAddr + rankId * cRankSize, tileTiling, mm);
    }
    aAddr += aTileSize;
    cAddr += cTileSize;
}
```
- Wait for the communication of the tail block to complete and perform MatMul compute on it.
```cpp
aAddr = gatherOutGM + tileNum * aTileSize;
cAddr = cGM + tileNum * cTileSize;
if (tailNum > 0) {
    mm.Init(&tailTiling);
    hccl.Wait(tailHandleId);
    for (uint32_t rankId = 0; rankId < hccl.GetRankDim(); rankId++) {
        if (rankId == hccl.GetRankId())
            continue;
        MatmulKernel(aAddr + rankId * aRankSize, bGM, cAddr + rankId * cRankSize, tailTiling, mm);
    }
}
```
- Release resources.
```cpp
mm.End();
hccl.Finalize();
```
Putting the preceding pieces together, the complete kernel code is as follows:
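The two AllGather calls in this procedure rely on their last two arguments (aRankEleCnt and tileNum/tailNum) to place every rank's rows correctly in gatherOutGM. The C++ model below encodes an assumption about those semantics, inferred from how the kernel indexes gatherOutGM (rank r's data starts at r * aRankSize). Rank r's block lands at recvOffset + r * strideCount, and each repeat advances both buffers by sendCount elements. Under that assumption, the model checks that the tile task plus the tail task reproduce the rank-ordered concatenation. `allGatherModel` is a hypothetical helper, not the HCCL API, and it works in element offsets rather than byte addresses.

```cpp
#include <cstddef>
#include <vector>

// sends[r] is the full local matrix a of rank r, flattened; recv models gatherOutGM.
// sendOffset/recvOffset play the role of aGM + x and gatherOutGM + x (in elements).
void allGatherModel(const std::vector<std::vector<int>>& sends, std::size_t sendOffset,
                    std::vector<int>& recv, std::size_t recvOffset,
                    std::size_t sendCount, std::size_t strideCount, std::size_t repeat) {
    for (std::size_t rep = 0; rep < repeat; ++rep) {
        const std::size_t advance = rep * sendCount;  // each repeat moves to the next tile
        for (std::size_t r = 0; r < sends.size(); ++r) {
            for (std::size_t e = 0; e < sendCount; ++e) {
                // Rank r's data goes to its own block, strideCount elements apart.
                recv[recvOffset + r * strideCount + advance + e] = sends[r][sendOffset + advance + e];
            }
        }
    }
}
```

For example, with 3 ranks, 5 elements per rank, 2-element tiles (tileNum = 2), and a 1-element tail, the calls allGatherModel(sends, 0, recv, 0, 2, 5, 2) and allGatherModel(sends, 4, recv, 4, 1, 5, 1) mirror the kernel's two AllGather calls and fill recv exactly in rank order.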
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
```cpp
#define ASCENDC_CUBE_ONLY
#include "kernel_operator.h"
#include "lib/matmul_intf.h"
#include "all_gather_matmul_custom_tiling.h"

using namespace AscendC;
using MATMUL_TYPE = MatmulType<AscendC::TPosition::GM, CubeFormat::ND, half>;

__aicore__ inline void MatmulKernel(GM_ADDR aGM, GM_ADDR bGM, GM_ADDR cGM, TCubeTiling &tiling,
                                    Matmul<MATMUL_TYPE, MATMUL_TYPE, MATMUL_TYPE> &mm)
{
    if (GetBlockIdx() >= tiling.usedCoreNum) {
        return;
    }
    GlobalTensor<half> aGlobal, bGlobal, cGlobal;
    aGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(aGM), tiling.M * tiling.Ka);
    bGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(bGM), tiling.Ka * tiling.N);
    cGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(cGM), tiling.M * tiling.N);

    // Map the block index to an (M, N) sub-block and compute its offsets and tail sizes.
    int mSingleBlocks = (tiling.M + tiling.singleCoreM - 1) / tiling.singleCoreM;
    int mCoreIndx = GetBlockIdx() % mSingleBlocks;
    int nCoreIndx = GetBlockIdx() / mSingleBlocks;
    int offsetA = mCoreIndx * tiling.Ka * tiling.singleCoreM;
    int offsetB = nCoreIndx * tiling.singleCoreN;
    int offsetC = mCoreIndx * tiling.N * tiling.singleCoreM + nCoreIndx * tiling.singleCoreN;
    int tailM = Std::min(tiling.M - mCoreIndx * tiling.singleCoreM, tiling.singleCoreM);
    int tailN = Std::min(tiling.N - nCoreIndx * tiling.singleCoreN, tiling.singleCoreN);

    mm.SetOrgShape(tiling.M, tiling.N, tiling.Ka, tiling.Kb);
    mm.SetTensorA(aGlobal[offsetA]);
    mm.SetTensorB(bGlobal[offsetB]);
    mm.SetTail(tailM, tailN);
    mm.IterateAll(cGlobal[offsetC]);
}

extern "C" __global__ __aicore__ void all_gather_matmul_custom(GM_ADDR aGM, GM_ADDR bGM, GM_ADDR cGM,
                                                               GM_ADDR gatherOutGM, GM_ADDR workspaceGM,
                                                               GM_ADDR tilingGM)
{
    // Only the Cube (AIC) cores do work in this kernel; vector (AIV) cores return immediately.
    if ASCEND_IS_AIV {
        return;
    }
    REGISTER_TILING_DEFAULT(AllGatherMatmulCustomTilingData);
    GET_TILING_DATA(tilingData, tilingGM);
    TPipe pipe;

    auto &&localTiling = tilingData.localTiling;
    auto &&tileTiling = tilingData.tileTiling;
    auto &&tailTiling = tilingData.tailTiling;
    const auto tileNum = tilingData.cfg.tileNum;  // Number of tiles
    const auto tailNum = tilingData.cfg.tailNum;  // Number of tail blocks
    const auto aTileEleCnt = tileTiling.M * tileTiling.Ka;                 // Number of elements in a tile of the communication matrix
    const auto aTileSize = tileTiling.M * tileTiling.Ka * sizeof(half);    // Number of bytes in a tile of the communication matrix
    const auto cTileSize = tileTiling.M * tileTiling.N * sizeof(half);     // Number of bytes in a tile of the output matrix
    const auto aTailEleCnt = tailTiling.M * tailTiling.Ka;                 // Number of elements in the tail block of the communication matrix
    const auto aRankEleCnt = localTiling.M * localTiling.Ka;               // Number of elements in the communication matrix of a single card
    const auto aRankSize = localTiling.M * localTiling.Ka * sizeof(half);  // Number of bytes in the communication matrix of a single card
    const auto cRankSize = localTiling.M * localTiling.N * sizeof(half);   // Number of bytes in the output matrix of a single card

    // Initialize HCCL and issue the AllGather tasks for the main tiles and for the tail block.
    Hccl hccl;
    GM_ADDR contextGM = GetHcclContext<HCCL_GROUP_ID_0>();
    hccl.InitV2(contextGM, &tilingData);
    hccl.SetCcTilingV2(offsetof(AllGatherMatmulCustomTilingData, mc2CcTiling));
    auto handleId = hccl.AllGather<true>(aGM, gatherOutGM, aTileEleCnt,
                                         HcclDataType::HCCL_DATA_TYPE_FP16, aRankEleCnt, tileNum);
    auto tailHandleId = hccl.AllGather<true>(aGM + tileNum * aTileSize, gatherOutGM + tileNum * aTileSize,
                                             aTailEleCnt, HcclDataType::HCCL_DATA_TYPE_FP16, aRankEleCnt, tailNum);

    Matmul<MATMUL_TYPE, MATMUL_TYPE, MATMUL_TYPE> mm;
    REGIST_MATMUL_OBJ(GetTPipePtr(), GetSysWorkSpacePtr(), mm);

    // Compute the local card's slice directly from a while the AllGather is still in progress.
    mm.Init(&localTiling);
    MatmulKernel(aGM, bGM, cGM + hccl.GetRankId() * cRankSize, localTiling, mm);

    // For each tile, wait for its communication to finish, then compute it for every other card.
    auto aAddr = gatherOutGM;
    auto cAddr = cGM;
    mm.Init(&tileTiling);
    for (uint32_t i = 0; i < tileNum; i++) {
        hccl.Wait(handleId);
        for (uint32_t rankId = 0; rankId < hccl.GetRankDim(); rankId++) {
            if (rankId == hccl.GetRankId())
                continue;
            MatmulKernel(aAddr + rankId * aRankSize, bGM, cAddr + rankId * cRankSize, tileTiling, mm);
        }
        aAddr += aTileSize;
        cAddr += cTileSize;
    }

    // Process the tail block of every other card in the same way.
    aAddr = gatherOutGM + tileNum * aTileSize;
    cAddr = cGM + tileNum * cTileSize;
    if (tailNum > 0) {
        mm.Init(&tailTiling);
        hccl.Wait(tailHandleId);
        for (uint32_t rankId = 0; rankId < hccl.GetRankDim(); rankId++) {
            if (rankId == hccl.GetRankId())
                continue;
            MatmulKernel(aAddr + rankId * aRankSize, bGM, cAddr + rankId * cRankSize, tailTiling, mm);
        }
    }
    mm.End();
    hccl.Finalize();
}
```
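To check the address arithmetic of the kernel without an NPU, its schedule can be mimicked on the host. The sketch below is a hypothetical CPU model, not part of the sample: it uses float instead of half, emulates the AllGather as a plain row-wise concatenation, assumes at most one tail block of tailM rows (0 meaning no tail), and the names TilingShape, MatmulBlock, SimulateKernelSchedule, and ReferenceAllGatherMatmul are illustrative. It follows the kernel's order: the local card's slice is computed from its own a, then each main tile and finally the tail block of every remote card are computed from gather_out, and the result should match a single c = gather_out * b.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical shape description (simplified stand-in for localTiling/tileTiling/tailTiling).
struct TilingShape {
    std::size_t m, k, n;  // per-card a is [m, k]; b is [k, n]
    std::size_t rankDim;  // number of cards in the communicator
    std::size_t tileM;    // rows per main tile
    std::size_t tileNum;  // number of main tiles
    std::size_t tailM;    // rows in the single tail block (0 = no tail)
};

// Row-major c[rows x n] = a[rows x k] * b[k x n].
inline void MatmulBlock(const float *a, const float *b, float *c,
                        std::size_t rows, std::size_t k, std::size_t n) {
    for (std::size_t i = 0; i < rows; ++i) {
        for (std::size_t j = 0; j < n; ++j) {
            float acc = 0.0f;
            for (std::size_t p = 0; p < k; ++p) {
                acc += a[i * k + p] * b[p * n + j];
            }
            c[i * n + j] = acc;
        }
    }
}

// Emulated AllGather: card r's a lands in rows [r*m, (r+1)*m) of gather_out.
inline std::vector<float> AllGatherRows(const std::vector<std::vector<float>> &aPerRank,
                                        const TilingShape &s) {
    std::vector<float> gatherOut(s.rankDim * s.m * s.k);
    for (std::size_t r = 0; r < s.rankDim; ++r) {
        std::copy(aPerRank[r].begin(), aPerRank[r].end(), gatherOut.begin() + r * s.m * s.k);
    }
    return gatherOut;
}

// Computes c as card myRank would, in the same order as the kernel.
inline std::vector<float> SimulateKernelSchedule(const std::vector<std::vector<float>> &aPerRank,
                                                 const std::vector<float> &b,
                                                 const TilingShape &s, std::size_t myRank) {
    assert(s.tileNum * s.tileM + s.tailM == s.m);  // tiles plus tail must cover every row
    const std::vector<float> gatherOut = AllGatherRows(aPerRank, s);
    std::vector<float> c(s.rankDim * s.m * s.n, 0.0f);
    const std::size_t aRankEle = s.m * s.k, cRankEle = s.m * s.n;          // per-card blocks
    const std::size_t aTileEle = s.tileM * s.k, cTileEle = s.tileM * s.n;  // per-tile blocks

    // 1. Local slice from the card's own a (no wait on communication).
    MatmulBlock(aPerRank[myRank].data(), b.data(), c.data() + myRank * cRankEle, s.m, s.k, s.n);

    // 2. Tile i of every remote card (where the kernel calls hccl.Wait(handleId)).
    for (std::size_t i = 0; i < s.tileNum; ++i) {
        for (std::size_t r = 0; r < s.rankDim; ++r) {
            if (r == myRank) continue;
            MatmulBlock(gatherOut.data() + r * aRankEle + i * aTileEle, b.data(),
                        c.data() + r * cRankEle + i * cTileEle, s.tileM, s.k, s.n);
        }
    }

    // 3. Tail block of every remote card (where the kernel calls hccl.Wait(tailHandleId)).
    if (s.tailM > 0) {
        const std::size_t aOff = s.tileNum * aTileEle, cOff = s.tileNum * cTileEle;
        for (std::size_t r = 0; r < s.rankDim; ++r) {
            if (r == myRank) continue;
            MatmulBlock(gatherOut.data() + r * aRankEle + aOff, b.data(),
                        c.data() + r * cRankEle + cOff, s.tailM, s.k, s.n);
        }
    }
    return c;
}

// Reference result: c = AllGather(a) * b computed in one pass.
inline std::vector<float> ReferenceAllGatherMatmul(const std::vector<std::vector<float>> &aPerRank,
                                                   const std::vector<float> &b, const TilingShape &s) {
    const std::vector<float> gatherOut = AllGatherRows(aPerRank, s);
    std::vector<float> c(s.rankDim * s.m * s.n);
    MatmulBlock(gatherOut.data(), b.data(), c.data(), s.rankDim * s.m, s.k, s.n);
    return c;
}
```

Because every output element is accumulated in the same order in both functions, the two results can be compared for exact equality on every simulated card.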
Build and run
The following briefly describes how to build, install, and run the AllGatherMatmul example.
- Build
Run the install.sh script to complete the build. For details, see the commands for generating a custom operator project and building the operator in the AllGatherMatmul sample.
The example directory structure is as follows. The AllGatherMatmulCustom directory contains the files required for the operator implementation. The install.sh script uses msOpGen to create a CustomOp directory in the 21_all_gather_matmul_custom directory, copies the operator implementation files to the corresponding directories, and calls build.sh, the build entry script generated by msOpGen, to build the operator.
```text
├── 21_all_gather_matmul_custom
│   ├── AclNNInvocation                // Call the AllGatherMatmulCustom operator using ACLNN.
│   ├── AllGatherMatmulCustom          // AllGatherMatmulCustom operator project
│   ├── all_gather_matmul_custom.json  // Prototype definition JSON file of the AllGatherMatmulCustom operator
│   ├── all_gather_matmul_demo_def.h   // AllGatherMatmulCustom operator parameter configuration
│   └── install.sh                     // Script for calling msOpGen to generate a custom operator project and build it
```
The following shows the directory structure of the CustomOp generated by msOpGen.
```text
├── CustomOp               // AllGatherMatmul custom operator project generated by msOpGen
│   ├── cmake
│   ├── op_host            // Implementation file on the host
│   ├── op_kernel          // Implementation file on the kernel
│   ├── scripts            // Directory of scripts used for custom operator project packing
│   ├── build.sh           // Build entry script
│   ├── CMakeLists.txt     // Build script of the operator project
│   └── CMakePresets.json  // Build configuration options
```
- Installation
Before deploying the custom operator package, ensure that the environment variable ASCEND_OPP_PATH, which specifies the default deployment path of custom operator packages, is set in the environment.
```bash
# View the output of the environment variable.
echo $ASCEND_OPP_PATH
# If no output is displayed, set the environment variable.
# ASCEND_INSTALL_PATH indicates the installation path of the CANN software package.
source [ASCEND_INSTALL_PATH]/set_env.sh
# For example, source /usr/local/Ascend/cann/set_env.sh.
```
Run the following command to switch to the directory where the built custom operator installation package is located and install the custom operator package:
```bash
cd CustomOp/build_out
./custom_opp_<target os>_<target architecture>.run
```
After the command is executed successfully, the files in the custom operator package are deployed to the vendors/customize directory under the path specified by the environment variable ASCEND_OPP_PATH.
- Run
Switch to the AclNNInvocation directory and run the run.sh script to run the single-operator sample.
```bash
cd ../../AclNNInvocation
bash run.sh
```
The AclNNInvocation directory in the example provides the complete sample code for calling single-operator APIs. After the custom operator is built and deployed in the first two steps, a single-operator API is automatically generated and can be called directly from an application. The operator API is generally defined as a two-phase API, as shown in the following example:
```cpp
// Obtain the size of the workspace used by the operator.
aclnnStatus aclnnAllGatherMatmulCustomGetWorkspaceSize(
    const aclTensor *a, const aclTensor *b, char *group,
    const aclTensor *cOut, const aclTensor *gatherOutOut,
    uint64_t *workspaceSize, aclOpExecutor **executor);

// Execute the operator.
aclnnStatus aclnnAllGatherMatmulCustom(
    void *workspace, uint64_t workspaceSize,
    aclOpExecutor *executor, const aclrtStream stream);
```
The first-phase API, aclnnAllGatherMatmulCustomGetWorkspaceSize, computes the workspace size required for the call and returns the executor used by the second phase. Allocate device memory of workspaceSize bytes, and then call the second-phase API, aclnnAllGatherMatmulCustom, to perform the computation. For details, see Single-Operator API Calling.
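The two-phase protocol can be captured in a small generic helper. This is a hypothetical sketch: RunTwoPhase and its callback types are not part of CANN, and status codes are plain ints with 0 meaning success by assumption. In a real application, the callbacks would wrap aclnnAllGatherMatmulCustomGetWorkspaceSize, aclrtMalloc, aclnnAllGatherMatmulCustom, aclrtSynchronizeStreamWithTimeout, and aclrtFree respectively. The point it illustrates is the ordering: the workspace is allocated only when phase 1 reports a non-zero size, and it is freed only after the stream has been synchronized, because the phase-2 launch is asynchronous.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>

// Hypothetical callback signatures modeled on the two-phase single-operator API.
using GetWorkspaceSizeFn = std::function<int(std::uint64_t *)>;
using AllocFn = std::function<int(void **, std::uint64_t)>;
using ExecuteFn = std::function<int(void *, std::uint64_t)>;
using SyncFn = std::function<int()>;
using FreeFn = std::function<void(void *)>;

// Returns the first non-zero status encountered (0 = success, by assumption).
inline int RunTwoPhase(const GetWorkspaceSizeFn &getWorkspaceSize, const AllocFn &alloc,
                       const ExecuteFn &execute, const SyncFn &sync, const FreeFn &release)
{
    std::uint64_t workspaceSize = 0;
    int ret = getWorkspaceSize(&workspaceSize);  // Phase 1: query the workspace size.
    if (ret != 0) {
        return ret;
    }

    void *workspace = nullptr;
    if (workspaceSize != 0) {                    // Allocate only when a workspace is required.
        ret = alloc(&workspace, workspaceSize);
        if (ret != 0) {
            return ret;
        }
        assert(workspace != nullptr);
    }

    ret = execute(workspace, workspaceSize);     // Phase 2: launch the operator on the stream.
    const int syncRet = sync();                  // The launch is asynchronous: wait before freeing.
    if (ret == 0) {
        ret = syncRet;
    }
    if (workspace != nullptr) {
        release(workspace);
    }
    return ret;
}
```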
In the MC² scenario, the application that calls the single-operator API must also call HCCL APIs (see HCCL API Reference) to create a communicator, and then execute the AllGatherMatmul operator on multiple threads, one per card. The following code example shows the key steps in the main function and the thread function, and is for reference only.
```cpp
int main(int argc, char **argv)
{
    // 1. Initialize AscendCL.
    if (aclInit(NULL) != ACL_SUCCESS) {
        ERROR_LOG("aclInit failed");
        return FAILED;
    }

    // 2. Create a communicator.
    HcclComm comms[RANK_DIM];  // RANK_DIM indicates the number of cards, which is 8 in this example.
    int32_t devices[RANK_DIM];
    for (int32_t i = 0; i < RANK_DIM; i++) {
        devices[i] = i;
    }
    if (HcclCommInitAll(RANK_DIM, devices, comms) != HCCL_SUCCESS) {
        ERROR_LOG("Hccl comm init failed.");
        (void)aclFinalize();
        return FAILED;
    }

    // 3. Create multiple threads to call the AllGatherMatmul operator on all cards in the communicator.
    std::vector<std::unique_ptr<std::thread>> threads(RANK_DIM);
    for (uint32_t rankId = 0; rankId < RANK_DIM; rankId++) {
        threads[rankId].reset(new (std::nothrow) std::thread(&RunOp, rankId, std::ref(comms[rankId])));
    }
    for (uint32_t rankId = 0; rankId < RANK_DIM; rankId++) {
        threads[rankId]->join();
    }

    // 4. Deinitialize AscendCL.
    (void)aclFinalize();
    return SUCCESS;
}
```
In the main function, the HcclCommInitAll API creates a communicator for RANK_DIM cards in the current process. Each card corresponds to one of the threads created afterward, and each thread calls the RunOp function, which allocates the runtime resources on its card and calls the two-phase single-operator API. The following is a code example of the RunOp function.
```cpp
bool RunOp(uint32_t rankId, HcclComm &comm)
{
    // 1. Allocate resources such as the context and stream of the current thread.
    aclrtContext context;
    aclrtCreateContext(&context, rankId);
    aclrtStream stream;
    aclrtCreateStream(&stream);
    aclrtSetCurrentContext(context);

    // 2. Obtain the communicator name of the card corresponding to the current thread.
    char group[128] = {0};
    HcclGetCommName(comm, group);

    // 3. Allocate device memory to store the input and output of the operator.
    // ......

    // 4. Compute the workspace size and allocate memory.
    uint64_t workspaceSize = 0;  // Matches the uint64_t * parameter of the first-phase API.
    aclOpExecutor *handle = nullptr;
    auto ret = aclnnAllGatherMatmulCustomGetWorkspaceSize(a, b, group, c, gatherOut, &workspaceSize, &handle);
    void *workspace = nullptr;
    if (workspaceSize != 0) {
        // aclrtMalloc also takes a memory allocation policy.
        aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
    }

    // 5. Run the operator.
    ret = aclnnAllGatherMatmulCustom(workspace, workspaceSize, handle, stream);

    // 6. Synchronize and wait.
    ret = aclrtSynchronizeStreamWithTimeout(stream, 10000);  // 10,000 ms stream synchronization timeout

    // 7. Free the device memory for operator inputs, outputs, and workspace.
    // ......

    // 8. Release resources such as the communicator, context, and stream.
    (void)HcclCommDestroy(comm);
    (void)aclrtDestroyStream(stream);
    (void)aclrtDestroyContext(context);
    (void)aclrtResetDevice(rankId);
    return true;
}
```
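The thread-per-card model used by main and RunOp can be illustrated without any NPU by simulating the collective on the host. In this hypothetical sketch, each std::thread plays one rank: it writes its own slice into a shared gather_out buffer at offset rankId * sliceSize, waits at a barrier until every rank has written, and only then reads the gathered result, much as the kernel reads gather_out only after hccl.Wait(). HostBarrier and SimulatedAllGather are illustrative names; HCCL moves data between devices rather than through shared host memory, and C++20's std::barrier could replace the hand-written barrier.

```cpp
#include <algorithm>
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <thread>
#include <vector>

// Single-use barrier: blocks until `count` threads have arrived.
class HostBarrier {
public:
    explicit HostBarrier(std::size_t count) : remaining_(count) {}

    void ArriveAndWait()
    {
        std::unique_lock<std::mutex> lock(mutex_);
        if (--remaining_ == 0) {
            cv_.notify_all();
        } else {
            cv_.wait(lock, [this] { return remaining_ == 0; });
        }
    }

private:
    std::mutex mutex_;
    std::condition_variable cv_;
    std::size_t remaining_;
};

// Each rank contributes one equally sized slice; returns the gathered buffer seen by each rank.
inline std::vector<std::vector<float>> SimulatedAllGather(const std::vector<std::vector<float>> &aPerRank)
{
    const std::size_t rankDim = aPerRank.size();
    const std::size_t slice = rankDim == 0 ? 0 : aPerRank[0].size();
    std::vector<float> gatherOut(rankDim * slice);
    std::vector<std::vector<float>> seenByRank(rankDim);
    HostBarrier barrier(rankDim);

    auto runRank = [&](std::size_t rankId) {
        // Each rank writes only its own, non-overlapping region.
        std::copy(aPerRank[rankId].begin(), aPerRank[rankId].end(), gatherOut.begin() + rankId * slice);
        barrier.ArriveAndWait();         // After this point every slice has been written.
        seenByRank[rankId] = gatherOut;  // Read-only access to the gathered data.
    };

    // One thread per card, joined at the end, as in main.
    std::vector<std::thread> threads;
    for (std::size_t r = 0; r < rankDim; ++r) {
        threads.emplace_back(runRank, r);
    }
    for (auto &t : threads) {
        t.join();
    }
    return seenByRank;
}
```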