Batch Matmul basic functions
Overview
Batch Matmul processes multiple Matmul computations in a batch. The IterateBatch interface is provided for this scenario: a single IterateBatch call computes multiple C matrices, each of size singleCoreM x singleCoreN.
Every Matmul computation has to move data in and out. When many Matmul computations are performed and each one has a small input shape, this data movement overhead accounts for a large proportion of the total time. Processing these computations in batches through IterateBatch improves bandwidth utilization.
Currently, Batch Matmul supports four layout types: BSNGD, SBNGD, BNGS1S2, and NORMAL (BMNK data layout format). For details about the data layout format, see IterateBatch.
The following figure shows the Batch Matmul calculation for the NORMAL data layout format. The computation involves four matrix multiplications: mat_a1 * mat_b1, mat_a2 * mat_b2, mat_a3 * mat_b3, and mat_a4 * mat_b4, so four singleCoreM x singleCoreN results need to be computed on a single core. If the shapes are small, this can be treated as a Batch Matmul scenario and processed in a batch to improve performance. A single IterateBatch call can calculate mat_c1 = mat_a1 * mat_b1, mat_c2 = mat_a2 * mat_b2, mat_c3 = mat_a3 * mat_b3, and mat_c4 = mat_a4 * mat_b4 together.
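The result of such a batched call can be checked against a plain host-side reference. The following is a minimal sketch, not Ascend C kernel code: the function name BatchMatmulReference is hypothetical, and it assumes equal batch counts for A and B, row-major storage, no transposition, and matrices stored back to back in batch order.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference, not part of the Ascend C API.
// a: `batch` row-major M x K matrices stored back to back.
// b: `batch` row-major K x N matrices stored back to back.
// Returns `batch` row-major M x N matrices, where C[i] = A[i] * B[i].
std::vector<float> BatchMatmulReference(const std::vector<float>& a,
                                        const std::vector<float>& b,
                                        std::size_t batch, std::size_t m,
                                        std::size_t n, std::size_t k)
{
    assert(a.size() == batch * m * k);
    assert(b.size() == batch * k * n);
    std::vector<float> c(batch * m * n, 0.0f);
    for (std::size_t bi = 0; bi < batch; ++bi) {
        // Each batch entry starts at a fixed offset in its buffer.
        const float* aMat = a.data() + bi * m * k;
        const float* bMat = b.data() + bi * k * n;
        float* cMat = c.data() + bi * m * n;
        for (std::size_t i = 0; i < m; ++i) {
            for (std::size_t p = 0; p < k; ++p) {
                const float aVal = aMat[i * k + p];
                for (std::size_t j = 0; j < n; ++j) {
                    cMat[i * n + j] += aVal * bMat[p * n + j];
                }
            }
        }
    }
    return c;
}
```

Comparing the device output with this reference, within a tolerance suited to the data type, is one way to validate a batched kernel on small shapes.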
Use Case
Multiple C matrices of size singleCoreM x singleCoreN need to be computed, and each individual Matmul computation processes a small shape.
Restrictions
- Only Norm Template is supported.
- For the BSNGD, SBNGD, and BNGS1S2 layout formats, the total size of the multi-batch data of the input A and B matrices, after fractal alignment, must be less than the size of the L1 Buffer. The NORMAL layout format has no such restriction, but the batchMode parameter in MatmulConfig must be configured to describe the relationship between the size of the multi-batch data of the input A and B matrices and the size of the L1 Buffer.
- For the BSNGD, SBNGD, and BNGS1S2 layouts, if the G axes of the left and right matrices are ALayoutInfoG and BLayoutInfoG, respectively, the following equation must hold: ALayoutInfoG/batchA = BLayoutInfoG/batchB. For the NORMAL layout, one of batchA and batchB must be an integer multiple of the other. The batch in the shape (batch, n) of the bias must be the same as that of matrix C.
- If the output is transferred to the Unified Buffer, the size of matrix C (BaseM x BaseN) must be less than the allocated Unified Buffer memory size.
- For the BSNGD and SBNGD layouts, the input and output data must be in ND format. For the BNGS1S2 and NORMAL layouts, the input data can be in ND or NZ format.
- Batch Matmul does not support quantization/dequantization, that is, the SetQuantScalar and SetQuantVector APIs are not supported.
- In the BSNGD scenario, multiple rows of SDs cannot be computed at a time. Cyclic computation is required in the operator program.
- In asynchronous mode, data cannot be transferred to the Unified Buffer using the IterateBatch function.
- Batch Matmul is not supported in the MixDualMaster (dual-master mode) scenario, that is, when the template parameter enableMixDualMaster (default value: false) is set to true.
- In the batch scenario, matrices A and B support the half, float, bfloat16_t, and int8_t data types, but not the int4b_t data type.
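Two of the restrictions above, the multiple relationship between batchA and batchB and the comparison of the aligned multi-batch size with the L1 Buffer, can be sketched as host-side checks. This is an illustration under stated assumptions, not the library's own validation: all function names are hypothetical, padding every dimension to a single fractalDim value is a simplification of the real fractal alignment, and the L1 Buffer size must be supplied by the caller because it depends on the hardware.

```cpp
#include <cassert>
#include <cstdint>

// Rounds value up to the next multiple of align.
inline std::uint64_t AlignUp(std::uint64_t value, std::uint64_t align)
{
    assert(align > 0);
    return (value + align - 1) / align * align;
}

// NORMAL layout: one batch count must be an integer multiple of the other.
inline bool BatchCountsCompatible(std::uint32_t batchA, std::uint32_t batchB)
{
    return batchA != 0 && batchB != 0 &&
           (batchA % batchB == 0 || batchB % batchA == 0);
}

// Estimated bytes of batchA M x K matrices plus batchB K x N matrices.
// Assumption: every dimension is padded to a multiple of fractalDim elements.
inline std::uint64_t AlignedBatchBytes(std::uint32_t batchA, std::uint32_t batchB,
                                       std::uint64_t m, std::uint64_t n,
                                       std::uint64_t k, std::uint64_t elemBytes,
                                       std::uint64_t fractalDim)
{
    const std::uint64_t mAligned = AlignUp(m, fractalDim);
    const std::uint64_t nAligned = AlignUp(n, fractalDim);
    const std::uint64_t kAligned = AlignUp(k, fractalDim);
    const std::uint64_t aBytes = batchA * mAligned * kAligned * elemBytes;
    const std::uint64_t bBytes = batchB * kAligned * nAligned * elemBytes;
    return aBytes + bBytes;
}

// True if the aligned multi-batch input is strictly smaller than the L1 Buffer.
// l1Bytes is caller-supplied; no specific hardware value is assumed here.
inline bool FitsInL1(std::uint64_t alignedBytes, std::uint64_t l1Bytes)
{
    return alignedBytes < l1Bytes;
}
```

For the BSNGD, SBNGD, and BNGS1S2 layouts a check like FitsInL1 would have to pass. For the NORMAL layout its result would only inform the choice of batchMode in MatmulConfig.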
Examples
The following is an example of calling Batch Matmul with the NORMAL data layout. For a complete example of Batch Matmul with the BSNGD data layout, see BatchMatmul sample.
- Tiling Implementation
Use SetBatchInfoForNormal to set the M/N/K axis information of matrices A, B, and C and the BatchNum of matrices A and B.
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
matmul_tiling::MultiCoreMatmulTiling tiling(ascendcPlatform);
int32_t M = 32;
int32_t N = 256;
int32_t K = 64;
// tiling is an object, not a pointer, so its members are accessed with '.'.
tiling.SetDim(1);
tiling.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16);
tiling.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16);
tiling.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT);
tiling.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT);
tiling.SetShape(M, N, K);
tiling.SetOrgShape(M, N, K);
tiling.EnableBias(true);
tiling.SetBufferSpace(-1, -1, -1);

constexpr int32_t BATCH_NUM = 3;
tiling.SetBatchInfoForNormal(BATCH_NUM, BATCH_NUM, M, N, K); // Set the batch counts of A and B and the M/N/K axis information.

optiling::TCubeTiling tilingData;
int ret = tiling.GetTiling(tilingData);
- Kernel Implementation
- Create a Matmul object.
Set the input and output layout to NORMAL using MatmulType.
#include "lib/matmul_intf.h"

typedef AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, half, false, LayoutMode::NORMAL> aType;
typedef AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, half, true, LayoutMode::NORMAL> bType;
typedef AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float, false, LayoutMode::NORMAL> cType;
typedef AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float> biasType;
constexpr MatmulConfig MM_CFG = GetNormalConfig(false, false, false, BatchMode::BATCH_LESS_THAN_L1);
AscendC::Matmul<aType, bType, cType, biasType, MM_CFG> mm;
- Perform the initialization operation.
REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), mm, &tiling); // Initialize the matmul object.
- Set the left matrix A, right matrix B, and bias.
mm.SetTensorA(gm_a);    // Set the left matrix A.
mm.SetTensorB(gm_b);    // Set the right matrix B.
mm.SetBias(gm_bias);    // Set the bias.
- Execute the matrix multiplication operation. Each call processes batchA M x K matrices from the left matrix and batchB K x N matrices from the right matrix.
mm.IterateBatch(gm_c, batchA, batchB, false);
- End the matrix multiplication operation.
mm.End();