Activation Quantization Balance Preprocessing

In scenarios where activations are unevenly distributed, per-tensor quantization of activations incurs a large error due to outliers, whereas per-channel quantization incurs a small error. However, the current hardware supports per-channel quantization only for weights, not for activations. To reduce the quantization error in this case, this section introduces a balance preprocessing method based on AMCT.

Use the quantize_preprocess API to calculate the balance factor, apply a mathematically equivalent transformation between the model's activations and weights to balance their distributions, and thereby migrate some of the quantization difficulty from activations to weights (see the illustrative sketch after Table 1). The layers that support this feature are listed as follows.

Table 1 Supported layers and their restrictions

| Supported Layer Type | Restriction | Remarks |
| --- | --- | --- |
| torch.nn.Linear | - | Layers sharing the weight and bias parameters do not support quantization. |
| torch.nn.Conv2d | padding_mode = zeros | - |
| torch.nn.Conv3d | dilation_d = 1, dilation_h/dilation_w ≥ 1; padding_mode = zeros | - |
| torch.nn.ConvTranspose2d | padding_mode = zeros | - |
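The effect of the balance factor can be illustrated with a small, self-contained sketch. The per-channel factor s below is hypothetical (AMCT computes the actual factor from the calibration data); the point is that scaling activations down and weights up by the same per-channel factor leaves the layer output mathematically unchanged while shrinking the activation outliers:

    import torch

    torch.manual_seed(0)
    # Activations with two outlier channels that would dominate per-tensor quantization.
    x = torch.randn(4, 8) * torch.tensor([10.0] * 2 + [0.1] * 6)
    linear = torch.nn.Linear(8, 16)

    # Hypothetical per-input-channel balance factor; AMCT derives the real one during calibration.
    s = x.abs().amax(dim=0).clamp(min=1e-5).sqrt()

    balanced_x = x / s              # flatter activation range: easier per-tensor quantization
    balanced_w = linear.weight * s  # weights absorb the factor per input channel

    y_original = linear(x)
    y_balanced = balanced_x @ balanced_w.t() + linear.bias
    print(torch.allclose(y_original, y_balanced, atol=1e-5))  # True: equivalent transformation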

API Call Sequence

Figure 1 shows the API call sequence for balance preprocessing.

Figure 1 API call sequence for balance preprocessing
Operations in blue are implemented by the user; operations in gray are implemented by calling AMCT APIs.
  1. Construct the original PyTorch model and set the DMQ parameters in the simplified configuration file dmp_quant.cfg (for details, see the description of the simplified configuration file parameters), then import the configuration file.
  2. Optimize the source PyTorch model using the quantize_preprocess API based on the quantization configuration file. The optimized model contains the operators that compute the balance factor.
  3. Calibrate the model by running a single forward pass on the calibration dataset in the PyTorch environment to obtain the balanced quantization factor and save it to the record file.
  4. Optimize the source PyTorch model again using the quantize_model API based on the PyTorch model, the quantization configuration file, and the record file. The optimized model contains quantization operators.
  5. Calibrate the model by running forward passes on the calibration dataset in the PyTorch environment to obtain the quantization factors and save them to the record file.
  6. Call the save_model API to save the quantized model, including a fake-quantized model file for the ONNX Runtime environment and a deployable model file for the Ascend AI Processor.

Examples

  • Take the following steps to get started. Update the sample code based on your situation.
  • Tweak the arguments passed to AMCT API calls as required.
  1. Import the AMCT package and set the log level (see Post-installation Actions for details).
    import amct_pytorch as amct
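    If you prefer to set the log level in the script itself, the call below is an assumption based on AMCT's other framework variants; verify the exact interface in Post-installation Actions:

      # Assumed API (verify in Post-installation Actions): print_level controls console
      # output, save_level controls the log file.
      amct.set_logging_level(print_level='info', save_level='info')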
    
  2. (Optional) Validate the inference script and environment setup in the source PyTorch environment. Update the sample code based on your situation.

    You are advised to run inference using the source model to be quantized and its associated test dataset.

    This step is recommended because it verifies that the source model runs correctly and delivers acceptable accuracy before quantization. You can use a subset of the test dataset to improve efficiency.

    user_do_inference_torch(ori_model, test_data, test_iterations)
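    Note that user_do_inference_torch is a user-implemented helper, not an AMCT API. A minimal sketch of what it might look like, assuming a classification model and an iterable of (inputs, labels) batches:

      import torch

      def user_do_inference_torch(model, data_loader, test_iterations):
          # Run forward passes and report top-1 accuracy (illustrative sketch).
          model.eval()
          correct, total = 0, 0
          with torch.no_grad():
              for i, (inputs, labels) in enumerate(data_loader):
                  if i >= test_iterations:
                      break
                  outputs = model(inputs)
                  correct += (outputs.argmax(dim=1) == labels).sum().item()
                  total += labels.numel()
          print(f'top-1 accuracy: {correct / max(total, 1):.4f}')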
    
  3. Run AMCT to quantize the model.
    1. Generate a quantization configuration file.
      You can set the DMQ parameters in the simplified configuration file dmp_quant.cfg and import the configuration file through the config_defination parameter.
      # PATH is the directory containing the simplified configuration file dmp_quant.cfg
      config_defination = os.path.join(PATH, 'dmp_quant.cfg')
      config_file = './tmp/config.json'
      skip_layers = []
      batch_num = 1
      amct.create_quant_config(config_file=config_file,
                               model=ori_model,
                               input_data=ori_model_input_data,
                               skip_layers=skip_layers,
                               batch_num=batch_num,
                               config_defination=config_defination)
      
    2. Modify the graph. Call the quantize_preprocess API to insert the operators that calculate the balanced quantization factor.
      record_file = './tmp/record.txt'
      modified_onnx_model = './tmp/modified_model.onnx'
      calibration_model = amct.quantize_preprocess(config_file=config_file,
                                                   record_file=record_file,
                                                   model=ori_model,
                                                   input_data=ori_model_input_data)
      
    3. Run inference on the modified model (calibration_model) in the PyTorch environment based on the calibration dataset (calibration_data) to determine the balanced quantization factor. Update the sample code based on your situation.

      Pay attention to the following points:

      1. Ensure that the calibration dataset and the preprocessed data match the model to preserve the accuracy.
      2. Ensure that exactly one forward pass is run. If more than one is run, the subsequent process will fail because the balanced quantization factor is recorded on every inference pass.
      user_do_inference_torch(calibration_model, calibration_data, test_iterations=1)
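      One way to guarantee exactly one forward pass is to restrict the calibration run to the first batch. A brief sketch, assuming calibration_data is an iterable of batches:

        import itertools

        # Take only the first batch so the balance factor is recorded exactly once.
        single_batch = list(itertools.islice(calibration_data, 1))
        user_do_inference_torch(calibration_model, single_batch, test_iterations=1)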
      
    4. Modify the graph to insert activation and weight quantization operators for quantization parameter calculation.
      modified_onnx_model = './tmp/modified_model.onnx'
      calibration_model = amct.quantize_model(config_file=config_file,
                                              modified_onnx_model=modified_onnx_model,
                                              record_file=record_file,
                                              model=ori_model,
                                              input_data=ori_model_input_data)
      
    5. Run inference on the modified model (calibration_model) in the PyTorch environment based on the calibration dataset (calibration_data) to determine the quantization factors. Update the sample code based on your situation.

      Pay attention to the following points:

      1. Ensure that the calibration dataset and the preprocessed data match the model to preserve the accuracy.
      2. Ensure that the number of forward passes (specified by batch_num) is large enough.
      user_do_inference_torch(calibration_model, calibration_data, batch_num)
      
    6. Save the model.
      Call save_model to insert operators such as AscendQuant and AscendDequant and save the quantized models based on the quantization factors.
      quant_model_path = './results/user_model'
      # modified_onnx_model holds the path to the ONNX file written by quantize_model above.
      amct.save_model(modified_onnx_file=modified_onnx_model,
                      record_file=record_file,
                      save_path=quant_model_path)
      
  4. (Optional) Run inference on the quantized model (quant_model) in the ONNX Runtime environment based on the test dataset (test_data) to analyze the accuracy. Update the sample code based on your situation.
    Compare the accuracy of the fake-quantized model with that of the source model (see step 2).
    quant_model = './results/user_model_fake_quant_model.onnx'
    user_do_inference_onnx(quant_model, test_data, test_iterations)
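  Like its PyTorch counterpart, user_do_inference_onnx is user code rather than an AMCT API. A minimal sketch using ONNX Runtime, assuming a single-input classification model and NumPy batches of (inputs, labels):

    import numpy as np
    import onnxruntime as ort

    def user_do_inference_onnx(model_path, data_loader, test_iterations):
        # Run the fake-quantized ONNX model and report top-1 accuracy (illustrative sketch).
        session = ort.InferenceSession(model_path)
        input_name = session.get_inputs()[0].name
        correct, total = 0, 0
        for i, (inputs, labels) in enumerate(data_loader):
            if i >= test_iterations:
                break
            outputs = session.run(None, {input_name: np.asarray(inputs, dtype=np.float32)})[0]
            correct += int((outputs.argmax(axis=1) == np.asarray(labels)).sum())
            total += len(labels)
        print(f'top-1 accuracy: {correct / max(total, 1):.4f}')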