Matrix-vector multiplication

Overview

General Matrix-Vector multiplication (GEMV) is the Matmul scenario in which M=1: the left matrix A has shape (1, K) and the right matrix B has shape (K, N). You enable GEMV mode by setting the data format of matrix A to VECTOR on both the tiling side and the kernel side, which lets Matmul handle the M=1 case efficiently. If GEMV mode is not enabled when M=1, Matmul processes the M direction as a non-aligned case. Compared with this non-aligned processing, GEMV mode transfers less data and provides better performance.
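For reference, the computation that GEMV mode accelerates can be written as a plain CPU loop. The following C++ sketch only illustrates the math, not how the Matmul API implements it. The function name ReferenceGemv is hypothetical, float stands in for half because standard C++ has no half type, and B is assumed to be stored row-major (ND).

```cpp
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the M=1 Matmul case (GEMV):
// c (1 x N) = a (1 x K) * b (K x N), with b stored row-major (ND).
// float is used for the inputs because standard C++ has no half type;
// results are accumulated in float, like the float C matrix in this section.
std::vector<float> ReferenceGemv(const std::vector<float>& a,
                                 const std::vector<float>& b,
                                 std::size_t k, std::size_t n) {
    std::vector<float> c(n, 0.0f);
    for (std::size_t kk = 0; kk < k; ++kk) {
        for (std::size_t nn = 0; nn < n; ++nn) {
            // Row kk of B is scaled by a[kk] and added into the output row.
            c[nn] += a[kk] * b[kk * n + nn];
        }
    }
    return c;
}
```

For example, with a = {1, 2, 3} and a 3 x 2 matrix b = {1, 0, 0, 1, 1, 1}, the result is c = {4, 5}.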

The following example uses a Matmul with M=1, K=256, N=32, and half-precision (half) left and right matrices to describe how the Matmul API processes data internally in GEMV mode, compared with non-GEMV mode.

  • GEMV mode

    When matrix A is moved from A1 to A2, the 1 x 256 vector is treated as a 16 x 16 matrix, so the whole vector is moved as a single 16 x 16 fractal with one call to the LoadData API. Matrix B is transferred and multiplied in the same way as in the basic scenario, as shown in the following figure.

    Figure 1 Matrix multiplication in GEMV mode (M=1)
  • Non-GEMV mode

    When matrix A is moved from A1 to A2, the 1 x 256 vector is processed as non-aligned matrix data: the M direction is first aligned to 32 bytes and then moved. Each LoadData call transfers one 16 x 16 fractal, so 16 calls (K/16) are needed in total. As a result, more data is transferred and performance is lower than in GEMV mode, as shown in the following figure.

    Figure 2 Matrix multiplication in non-GEMV mode (M=1)
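The difference in transferred data for this example can be estimated with a simple model. The following C++ sketch is a simplified model under stated assumptions, not the Matmul API's internal code: fractals are 16 x 16, half elements are 2 bytes, each LoadData call moves one fractal, and, for K larger than 256, GEMV mode is assumed to need one fractal per 256 elements. The names GemvCost, NonGemvCost, and TransferCost are hypothetical.

```cpp
#include <cstddef>

// Simplified model (an assumption for illustration, not the Matmul API's
// internal code) of how much of matrix A is moved from A1 to A2 when M = 1.
constexpr std::size_t kFractalRows = 16;
constexpr std::size_t kFractalCols = 16;
constexpr std::size_t kHalfBytes = 2;
constexpr std::size_t kFractalBytes = kFractalRows * kFractalCols * kHalfBytes;  // 512 bytes

struct TransferCost {
    std::size_t loadDataCalls;  // number of 16 x 16 fractals moved (one call each)
    std::size_t bytesMoved;     // total bytes moved for matrix A
};

// GEMV mode (assumed): the 1 x K vector is packed densely, 256 elements per fractal.
TransferCost GemvCost(std::size_t k) {
    const std::size_t elemsPerFractal = kFractalRows * kFractalCols;
    const std::size_t fractals = (k + elemsPerFractal - 1) / elemsPerFractal;
    return {fractals, fractals * kFractalBytes};
}

// Non-GEMV mode: M = 1 is padded up to 16 rows, so each fractal covers only
// 16 elements of K and K / 16 fractals are needed.
TransferCost NonGemvCost(std::size_t k) {
    const std::size_t fractals = (k + kFractalCols - 1) / kFractalCols;
    return {fractals, fractals * kFractalBytes};
}
```

For K=256, GemvCost gives 1 LoadData call and 512 bytes, while NonGemvCost gives 16 calls and 8192 bytes, because 15 of the 16 rows in each non-GEMV fractal are padding.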

Use Case

Matrix multiplication in which the left matrix A has shape (1, K) (M=1, K>1), that is, the input data of matrix A is a vector.

Restrictions

  • To enable GEMV mode in Matmul computation, the original M dimension of matrix A must be 1.
  • In GEMV mode, left matrix A cannot be transposed.
  • In GEMV mode, the left matrix data in global memory must be 16-byte aligned.
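Host code that should fall back to the basic scenario when these conditions are not met could check them up front. The following C++ sketch is a hypothetical helper (CanUseGemv and GemvInputDesc are illustrative names, not Matmul API types) that encodes the M=1 requirement, the K>1 use case, the no-transpose rule, and the 16-byte alignment rule for the global-memory data of matrix A.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical description of matrix A (not a Matmul API type).
struct GemvInputDesc {
    std::size_t m;           // original M dimension of matrix A
    std::size_t k;           // K dimension of matrix A
    bool transposeA;         // whether matrix A would be transposed
    std::uintptr_t aGmAddr;  // address of matrix A's data in global memory
};

// Hypothetical pre-check mirroring the GEMV use case and restrictions.
bool CanUseGemv(const GemvInputDesc& d) {
    if (d.m != 1) return false;      // GEMV mode requires M == 1
    if (d.k <= 1) return false;      // use case: A is a vector with K > 1
    if (d.transposeA) return false;  // matrix A must not be transposed
    return d.aGmAddr % 16 == 0;      // A's GM data must be 16-byte aligned
}
```

When CanUseGemv returns false, the caller would keep the ND format for matrix A, as in the basic scenario.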

Examples

For a complete operator example, see the matmul_gemv operator sample.

  • Tiling Implementation
    Call the SetAType API to set the data format of matrix A to CubeFormat::VECTOR. The rest of the tiling implementation is the same as in the basic scenario.
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
    matmul_tiling::MatmulApiTiling tiling(ascendcPlatform);
    // Set the format of matrix A to CubeFormat::VECTOR to enable GEMV mode.
    tiling.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::VECTOR, matmul_tiling::DataType::DT_FLOAT16);
    // Matrix B, matrix C, and the bias are configured as in the basic scenario.
    tiling.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16);
    tiling.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT);
    tiling.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT);
    ... // Other implementation content
    optiling::TCubeTiling tilingData;
    int ret = tiling.GetTiling(tilingData);  // Compute the tiling parameters into tilingData.
    
  • Kernel implementation
    Compared with the basic scenario, the only change in the GEMV scenario is that the data format of template parameter A_TYPE is set to CubeFormat::VECTOR when the Matmul object is created.
    #include "lib/matmul_intf.h"

    // Matrix A: vector input in GM; CubeFormat::VECTOR enables GEMV mode.
    using A_TYPE = AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::VECTOR, half>;
    // Matrix B, matrix C, and the bias are defined as in the basic scenario.
    using B_TYPE = AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, half>;
    using C_TYPE = AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float>;
    using BIAS_TYPE = AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float>;
    AscendC::Matmul<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE> mm;