SetQuantVector

Applicability

Product

Supported

Atlas A3 training products / Atlas A3 inference products

Atlas A2 training products / Atlas A2 inference products

Atlas 200I/500 A2 inference products

AI Core of Atlas inference products

Vector Core of Atlas inference products

x

Atlas training products

x

Function

Quantizes or dequantizes the output matrix using a vector. The input parameter is a vector with shape [1, N], where N corresponds to the N dimension in the Matmul computation (i.e., the number of columns). For each column of the output matrix, the coefficient from the corresponding element of the vector is applied to perform quantization or dequantization. For details about quantization and dequantization, see Quantization Scenarios.

Matmul dequantization scenario: During Matmul computation, the left and right input matrices are of the int8_t or int4b_t type and the output is of the half type. Alternatively, the left and right input matrices and the output are all of the int8_t type. In this scenario, dequantization is performed when the data of matrix C is moved from CO1 to the global memory, so that the final result is of the half or int8_t type.

Matmul quantization scenario: During Matmul computation, the left and right input matrices are of the half or bfloat16_t type and the output is of the int8_t type. In this scenario, quantization is performed when the data of matrix C is moved from CO1 to the global memory, so that the final result is of the int8_t type.
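The per-column behavior described above can be illustrated with a host-side reference model. The following is a minimal sketch in standard C++, not Ascend C device code. The function names DequantPerColumn and QuantPerColumn are hypothetical, float stands in for the half output type, and the rounding and saturation rules in the quantization path are assumptions made for illustration; the hardware behavior may differ.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Host-side reference model only (hypothetical helpers, not part of the Matmul API).
// Dequantization: every accumulator in column n is multiplied by scale[n].
// float is used here as a stand-in for the half output type.
std::vector<float> DequantPerColumn(const std::vector<int32_t>& acc,
                                    std::size_t rows, std::size_t cols,
                                    const std::vector<float>& scale)
{
    assert(acc.size() == rows * cols && scale.size() == cols);
    std::vector<float> out(acc.size());
    for (std::size_t m = 0; m < rows; ++m) {
        for (std::size_t n = 0; n < cols; ++n) {
            // The coefficient depends only on the column index n.
            out[m * cols + n] = static_cast<float>(acc[m * cols + n]) * scale[n];
        }
    }
    return out;
}

// Quantization: every accumulator in column n is multiplied by scale[n],
// then converted to int8_t.
// ASSUMPTION: round-half-away-from-zero and saturation to [-128, 127]
// are illustrative choices, not confirmed hardware behavior.
std::vector<int8_t> QuantPerColumn(const std::vector<float>& acc,
                                   std::size_t rows, std::size_t cols,
                                   const std::vector<float>& scale)
{
    assert(acc.size() == rows * cols && scale.size() == cols);
    std::vector<int8_t> out(acc.size());
    for (std::size_t m = 0; m < rows; ++m) {
        for (std::size_t n = 0; n < cols; ++n) {
            float scaled = acc[m * cols + n] * scale[n];
            float rounded = std::round(scaled);
            float clamped = std::min(127.0f, std::max(-128.0f, rounded));
            out[m * cols + n] = static_cast<int8_t>(clamped);
        }
    }
    return out;
}
```

In both helpers the scale vector has exactly one entry per output column, which mirrors the [1, N] shape of the vector passed to SetQuantVector.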

Prototype

__aicore__ inline void SetQuantVector(const GlobalTensor<uint64_t>& quantTensor)

Parameters

Parameter      Input/Output    Description

quantTensor    Input           Vector of coefficients used for quantization or dequantization, with shape [1, N], stored in global memory as uint64_t elements.

Returns

None

Restrictions

The use of this API must be consistent with the dequantization mode configured through SetDequantType.

This API must be called before Iterate or IterateAll.

Example

GlobalTensor<uint64_t> gmQuant;  // per-column coefficients, shape [1, N]
...
REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), mm, &tiling);
mm.SetQuantVector(gmQuant);  // must be called before Iterate or IterateAll
mm.SetTensorA(gm_a);
mm.SetTensorB(gm_b);
mm.SetBias(gm_bias);
mm.IterateAll(gm_c);