Quantization/Dequantization of Matrix Multiplication Outputs

Overview

For specific input and output data types, Matmul can perform data quantization or dequantization on the elements of the output matrix C when moving the computation result from CO1 to the global memory.

  • Matmul quantization scenario: During Matmul computation, the left matrix A and right matrix B are of the half or bfloat16_t data type, and the output matrix C is of the int8_t data type. In this scenario, when the data of matrix C is moved from CO1 to the global memory, the quantization operation is performed to quantize the final result to the int8_t type, as shown in the following figure.
    Figure 1 Matmul quantization scenario
  • Matmul dequantization scenario: During Matmul computation, the left matrix A and right matrix B are of the int8_t or int4b_t data type, and the output matrix C is of the half data type. Alternatively, the left matrix A and right matrix B are of the int8_t data type, and the output matrix C is of the int8_t data type. In this scenario, when the data of matrix C is moved from CO1 to the global memory, the dequantization operation is performed to dequantize the final result to the corresponding half or int8_t type, as shown in the following figure.
    Figure 2 Matmul dequantization scenario
There are two Matmul quantization/dequantization modes: quantization/dequantization of the same coefficient and vector quantization/dequantization. You can call the SetDequantType API on the operator tiling side to set the quantization/dequantization mode. The differences between the two modes are as follows:
  • Quantization/dequantization mode of the same coefficient (PER_TENSOR mode): The entire matrix C shares one quantization parameter, so the shape of the quantization parameter is [1]. Call the SetQuantScalar API in the operator kernel to set the quantization parameter.
  • Vector quantization/dequantization mode (PER_CHANNEL mode): For a matrix C of shape [m, n], each channel (that is, each column of matrix C) has its own quantization parameter, so the shape of the quantization parameter is [n]. Call the SetQuantVector API in the operator kernel to set the quantization parameters.
Table 1 API configuration in quantization/dequantization mode

Mode                                                | Tiling APIs                          | Kernel APIs
----------------------------------------------------|--------------------------------------|--------------------------
Quantization/Dequantization of the same coefficient | SetDequantType(DequantType::SCALAR)  | SetQuantScalar(gmScalar)
Vector quantization/dequantization                  | SetDequantType(DequantType::TENSOR)  | SetQuantVector(gmTensor)
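
The difference between the two modes can be illustrated with a host-side reference model in plain C++. This is a minimal sketch, not the hardware data path: it assumes that dequantization multiplies each int32_t accumulator by a float coefficient (consistent with the kernel example's comment about multiplying by 0.1), and it omits rounding and the conversion to half. The function names DequantPerTensor and DequantPerChannel are hypothetical and are not part of the Matmul API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference model of PER_TENSOR mode:
// a single coefficient (shape [1]) scales every element of C.
std::vector<float> DequantPerTensor(const std::vector<int32_t>& acc, float scale) {
    std::vector<float> out(acc.size());
    for (std::size_t i = 0; i < acc.size(); ++i) {
        out[i] = static_cast<float>(acc[i]) * scale;
    }
    return out;
}

// Hypothetical reference model of PER_CHANNEL mode:
// C is row-major with shape [m, n]; column col uses scales[col].
std::vector<float> DequantPerChannel(const std::vector<int32_t>& acc,
                                     std::size_t m, std::size_t n,
                                     const std::vector<float>& scales) {
    assert(acc.size() == m * n);
    assert(scales.size() == n);  // the quantization parameter has shape [n]
    std::vector<float> out(acc.size());
    for (std::size_t row = 0; row < m; ++row) {
        for (std::size_t col = 0; col < n; ++col) {
            out[row * n + col] = static_cast<float>(acc[row * n + col]) * scales[col];
        }
    }
    return out;
}
```

With a 2 x 2 accumulator matrix, DequantPerTensor applies the same factor to all four elements, whereas DequantPerChannel applies scales[0] to the first column and scales[1] to the second.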

Use Case

Use this feature when the matrix computation result must be quantized or dequantized. The following table lists the data types that the Matmul input and output matrices support in this case.

Table 2 Data types supported by Matmul quantization/dequantization

Matrix A   | Matrix B   | Matrix C | Supporting Platform
-----------|------------|----------|------------------------------------------------------------------------------------------------------------------
half       | half       | int8_t   | Atlas A3 training products / Atlas A3 inference products; Atlas A2 training products / Atlas A2 inference products
bfloat16_t | bfloat16_t | int8_t   | Atlas A3 training products / Atlas A3 inference products; Atlas A2 training products / Atlas A2 inference products
int8_t     | int8_t     | half     | Atlas A3 training products / Atlas A3 inference products; Atlas A2 training products / Atlas A2 inference products
int4b_t    | int4b_t    | half     | Atlas A3 training products / Atlas A3 inference products; Atlas A2 training products / Atlas A2 inference products
int8_t     | int8_t     | int8_t   | Atlas A3 training products / Atlas A3 inference products; Atlas A2 training products / Atlas A2 inference products
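
The combinations in Table 2 can be expressed as a small host-side lookup. The following is a sketch for illustration only: the Scenario enum and the ClassifyMatmulQuant function are hypothetical helpers, not part of the Matmul API, and data types are passed as strings purely for brevity. The scenario labels follow the Overview, which treats int8_t inputs with int8_t output as a dequantization scenario.

```cpp
#include <string>

enum class Scenario { Quantization, Dequantization, Unsupported };

// Hypothetical helper that mirrors Table 2.
Scenario ClassifyMatmulQuant(const std::string& a, const std::string& b,
                             const std::string& c) {
    if (a != b) {
        return Scenario::Unsupported;  // Table 2 lists only matching A and B types
    }
    if ((a == "half" || a == "bfloat16_t") && c == "int8_t") {
        return Scenario::Quantization;
    }
    if ((a == "int8_t" || a == "int4b_t") && c == "half") {
        return Scenario::Dequantization;
    }
    if (a == "int8_t" && c == "int8_t") {
        return Scenario::Dequantization;
    }
    return Scenario::Unsupported;
}
```

A check of this kind could run before tiling is generated, so that an unsupported combination (for example, int8_t inputs with int32_t output) is rejected early.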

Restrictions

  • The quantization/dequantization mode set on the kernel side must be the same as that set on the tiling side.
    • If the SetQuantScalar API is called on the kernel side (quantization/dequantization mode of the same coefficient), the SetDequantType API must be called on the tiling side with DequantType::SCALAR.
    • If the SetQuantVector API is called on the kernel side (vector quantization/dequantization mode), the SetDequantType API must be called on the tiling side with DequantType::TENSOR.
  • If matrices A and B are of type int8_t or int4b_t and matrix C is of type half, the output of this feature does not support the INF_NAN mode. If the result needs to be output in INF_NAN mode, you are advised to output the result to TPosition::VECIN when calling the Matmul API, set the output data type to int32_t, and then use the high-level API AscendDequant based on the AIV core to dequantize the result to the half type.

Examples

For a complete operator example, see the matmul_quant operator sample.

  • Tiling Implementation
    Call the SetDequantType API to set the quantization or dequantization mode. Other implementation details are the same as those in the basic scenario.
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
    matmul_tiling::MatmulApiTiling tiling(ascendcPlatform); 
    tiling.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_INT8);
    tiling.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_INT8);   
    tiling.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16); // half output for int8_t inputs, as listed in Table 2
    tiling.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_INT32);   
    tiling.SetShape(M, N, K);   
    tiling.SetOrgShape(M, N, K);  
    tiling.EnableBias(true);
    tiling.SetDequantType(DequantType::SCALAR); // Select the quantization/dequantization mode of the same coefficient (PER_TENSOR)
    // tiling.SetDequantType(DequantType::TENSOR); // Alternatively, select the vector quantization/dequantization mode (PER_CHANNEL)
    ... // Perform other configurations.
    
  • Kernel Implementation
    Call the SetQuantScalar or SetQuantVector API to set quantization parameters based on the specific quantization mode. Other implementation details are the same as those in the basic scenario.
    • Quantization/Dequantization mode of the same coefficient
      REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), mm, &tiling);
      float tmp = 0.1f;  // Each output element is multiplied by 0.1 when it is written to GM
      uint64_t ans = static_cast<uint64_t>(*reinterpret_cast<int32_t*>(&tmp)); // Reinterpret the bit pattern of the float coefficient as an integer and widen it to uint64_t, the type used to pass the coefficient
      mm.SetQuantScalar(ans);
      mm.SetTensorA(gm_a);
      mm.SetTensorB(gm_b);
      mm.SetBias(gm_bias);
      mm.IterateAll(gm_c);
      
    • Vector quantization/dequantization mode
      GlobalTensor<uint64_t> gmQuant;  // Holds one quantization parameter per column of matrix C (shape [n])
      ...
      REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), mm, &tiling);
      mm.SetQuantVector(gmQuant);
      mm.SetTensorA(gm_a);
      mm.SetTensorB(gm_b);
      mm.SetBias(gm_bias);
      mm.IterateAll(gm_c);
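
The conversion of a floating-point coefficient into the uint64_t value passed to SetQuantScalar can also be written with std::memcpy, which avoids the pointer cast used in the scalar example. The following host-side sketch shows that conversion. The helper names PackQuantScale and PackQuantScales are hypothetical, and applying the same encoding to each element of the per-channel parameter vector is an assumption extrapolated from the scalar example, not something this section confirms. Unlike the int32_t cast, this sketch zero-extends the bit pattern; the two approaches give the same value for positive coefficients.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical helper: place the IEEE 754 bit pattern of a float
// coefficient in the low 32 bits of a uint64_t.
uint64_t PackQuantScale(float scale) {
    static_assert(sizeof(uint32_t) == sizeof(float), "float must be 32 bits");
    uint32_t bits = 0;
    std::memcpy(&bits, &scale, sizeof(bits));  // copy the raw bits without a pointer cast
    return static_cast<uint64_t>(bits);        // zero-extend; the upper 32 bits are 0
}

// Hypothetical helper: pack one coefficient per column of matrix C.
// Assumes each element uses the same encoding as the scalar case.
std::vector<uint64_t> PackQuantScales(const std::vector<float>& scales) {
    std::vector<uint64_t> packed;
    packed.reserve(scales.size());
    for (float s : scales) {
        packed.push_back(PackQuantScale(s));
    }
    return packed;
}
```

On the host, the vector returned by PackQuantScales could be copied to the global memory buffer behind gmQuant before the kernel is launched, while PackQuantScale(0.1f) yields the value that the scalar example computes inline.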