Quantization Process

This section describes the layer types supported by QAT, the API call sequence, and usage examples.

Currently, QAT supports quantization only for FP32 network models. For a quantization example, see Examples below. The layers that support QAT are listed in the following table.

Table 1 Layers that support QAT and restrictions

torch.nn.Linear
  • Restriction: none
  • Remarks: Layers sharing the weight and bias parameters do not support quantization.

torch.nn.Conv2d
  • Restrictions:
    • padding_mode = zeros
    • Given hardware restrictions, do not perform QAT when the number of input channels (Cin) in the source model is less than or equal to 16, as this may hurt the quantized model's inference accuracy.
    • The shape of the input data must be (N, Cin, Hin, Win).

torch.nn.ConvTranspose2d
  • Restrictions:
    • padding_mode = zeros
    • Given hardware restrictions, do not perform QAT when the number of input channels (Cin) in the source model is less than or equal to 16, as this may hurt the quantized model's inference accuracy.
    • The shape of the input data must be (N, Cin, Hin, Win).
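
To make these restrictions concrete, the following is a minimal sketch of a source model built only from QAT-supported layer types; the module name SampleNet, the layer sizes (chosen so that Cin stays above 16), and the input shape are hypothetical.

    import torch
    import torch.nn as nn

    class SampleNet(nn.Module):
        """Hypothetical source model using only QAT-supported layer types."""
        def __init__(self, num_classes=10):
            super().__init__()
            # Cin = 32 (> 16) and padding_mode='zeros' satisfy the Conv2d restrictions.
            self.conv = nn.Conv2d(32, 64, kernel_size=3, padding=1, padding_mode='zeros')
            self.deconv = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
            # The Linear layer does not share its weight or bias with any other layer.
            self.fc = nn.Linear(32 * 56 * 56, num_classes)

        def forward(self, x):  # x has shape (N, Cin, Hin, Win)
            x = torch.relu(self.conv(x))
            x = torch.relu(self.deconv(x))
            return self.fc(x.flatten(1))

    # Example 4-D input of shape (N, Cin, Hin, Win).
    dummy_input = torch.randn(1, 32, 28, 28)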

API Call Sequence

Figure 1 shows the API call sequence for QAT.

Figure 1 API Call Sequence
Operations in blue are implemented by the user; operations in gray are implemented by calling AMCT APIs.
  1. Build a source PyTorch model and then generate a quantization configuration file by using the create_quant_retrain_config API.
  2. Call the create_quant_retrain_model API to modify the source model by inserting activation-quantization and weight-quantization layers used to calculate the quantization parameters.
  3. Train the modified model. If the training is not interrupted, perform inference on the trained model. During the inference, the quantization factors are written into the record file.

    If the training process is interrupted, call the restore_quant_retrain_model API again based on the saved .pth model parameters and quantization configuration file to output a modified retrained network for further QAT. Then, perform inference.

  4. Call save_quant_retrain_model to insert quantization operators such as AscendQuant and AscendDequant and save the quantized model.
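
The following is a minimal sketch of this call sequence in code form, condensed from the full walkthrough in Examples below; ori_model, ori_model_input_data, and the training and inference helpers are user-defined placeholders.

    import amct_pytorch as amct

    # 1. Generate the quantization configuration file.
    amct.create_quant_retrain_config(config_file='./tmp/config.json',
                                     model=ori_model,
                                     input_data=ori_model_input_data,
                                     config_defination='./retrain.cfg')

    # 2. Insert activation- and weight-quantization layers into the source model.
    quant_retrain_model = amct.create_quant_retrain_model(config_file='./tmp/config.json',
                                                          model=ori_model,
                                                          record_file='./tmp/record.txt',
                                                          input_data=ori_model_input_data)

    # 3. Retrain the modified model, then run inference so that the quantization
    #    factors are written to the record file.
    user_train_model(user_create_optimizer(quant_retrain_model), quant_retrain_model, train_data)
    user_infer_graph(quant_retrain_model)

    # 4. Insert AscendQuant/AscendDequant operators and save the quantized model.
    amct.save_quant_retrain_model(config_file='./tmp/config.json',
                                  model=quant_retrain_model,
                                  record_file='./tmp/record.txt',
                                  save_path='./result/user_model',
                                  input_data=ori_model_input_data)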

Examples

  1. Training is performed in the PyTorch environment. Currently, only multi-device training in distributed mode (DistributedDataParallel) is supported. Multi-device training in DataParallel mode is not supported; if DataParallel is used for training, an error is reported.
  2. Adjust the arguments passed to the AMCT API calls as required. QAT relies on the user's training result, so ensure that a PyTorch training script that yields satisfactory training accuracy is available.
  3. When the QAT feature of AMCT is used, if the training process stalls, check whether other ONNX Runtime programs are running on the current server (for example, by running the top command). If so, suspend the other ONNX Runtime programs and then perform QAT again.
  4. When quantization is performed by referring to this section, if the model contains PyTorch custom operators, the ONNX model may fail to be exported and quantization therefore fails, with the error "Model cannot be quantized for it cannot be export to onnx!" In this case, you can perform quantization in single-operator mode by referring to QAT in Single-Operator Mode.
  5. Take the following steps to get started. Update the sample code based on your situation.
  1. Import the AMCT package and set the log level (see Post-installation Actions for details).
    import amct_pytorch as amct
    
  2. (Optional) Run inference on the source model in the PyTorch environment based on the test dataset to validate the inference script and environment setup. (Update the sample code based on your situation.)

    This step is recommended because it confirms that the source model runs inference correctly with acceptable accuracy. You can use a subset of the test dataset to improve efficiency.

    ori_model.load()
    # Test the model.
    user_test_model(ori_model, test_data, test_iterations)
    
  3. Run AMCT to start quantization.
    1. Generate a quantization configuration file.

      Before performing this step, restore the trained parameters of the source model, for example, by calling ori_model.load() as in 2.

      config_file = './tmp/config.json'
      simple_cfg = './retrain.cfg'
      amct.create_quant_retrain_config(config_file=config_file,
                                       model=ori_model,
                                       input_data=ori_model_input_data,
                                       config_defination=simple_cfg)
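
      In these calls, ori_model_input_data is assumed to be example input data for the model, typically a tuple of torch.Tensor objects matching the model's input shape; the shape below is only an illustration.

       import torch
       # Hypothetical example input: one tuple element per model input, shape (N, Cin, Hin, Win).
       ori_model_input_data = (torch.randn(1, 32, 28, 28),)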
      
    2. Modify the model.
      Insert operators related to activation quantization and weight quantization into the ori_model model to calculate quantization parameters, and save the result as a new trainable model quant_retrain_model.
      record_file = './tmp/record.txt'
      quant_retrain_model = amct.create_quant_retrain_model(config_file=config_file,
                                                             model=ori_model,
                                                             record_file=record_file,
                                                             input_data=ori_model_input_data)
      
    3. Implement gradient descent optimization on the modified graph, train the graph on the training dataset, and calculate quantization factors. (Update the sample code based on your situation.)
      1. Create the optimizer used for gradient descent on the modified model. Perform this step after 3.b.
        optimizer = user_create_optimizer(quant_retrain_model)
        
      2. Restore the model from existing checkpoints and train the model.

        Note: Restore the model parameters from existing checkpoints and then train the model. The parameters saved during training should include the quantization factors. Quantization factors are generated only after the first batch_num training iterations, so if the number of training iterations is less than batch_num, the training fails. A sketch of what the user-defined training and inference helpers might look like is provided after 3.c.

        # Path prefix under which the user training script saves checkpoints
        # (the saved parameters include the quantization factors).
        quant_pth = './ckpt/user_model'
        user_train_model(optimizer, quant_retrain_model, train_data)
        
      3. After the training is complete, run inference to calculate and save the quantization factors.
        user_infer_graph(quant_retrain_model)
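
      The helpers above are user-defined. The following is a minimal sketch of what user_train_model and user_infer_graph might look like; it assumes a DistributedDataParallel setup (see the notes at the start of Examples), uses quant_pth from above as the checkpoint path prefix, and treats the loss function, dataset, and number of iterations as placeholders. It is a sketch under these assumptions, not a required implementation.

        import torch
        from torch.nn.parallel import DistributedDataParallel as DDP

        def user_train_model(optimizer, model, train_data, epochs=1):
            # QAT supports multi-device training only in DistributedDataParallel mode;
            # this sketch assumes the process group was initialized beforehand (e.g. via torchrun).
            ddp_model = DDP(model)
            loss_fn = torch.nn.CrossEntropyLoss()
            for _ in range(epochs):
                # Run at least batch_num iterations so that the quantization factors are generated.
                for images, labels in train_data:
                    optimizer.zero_grad()
                    loss = loss_fn(ddp_model(images), labels)
                    loss.backward()
                    optimizer.step()
            # The checkpoint saved from the modified model includes the quantization factors.
            torch.save(model.state_dict(), quant_pth + '_newest.ckpt')

        def user_infer_graph(model):
            # A forward pass on the trained model computes the quantization factors,
            # which are written to the record file during inference.
            model.eval()
            with torch.no_grad():
                for images, _ in train_data:
                    model(images)
                    break  # one batch is enough for this sketch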
        
    4. Save the quantized model.
      Call the save_quant_retrain_model API to insert operators such as AscendQuant and AscendDequant and save the resultant model based on the quantization factors and the retrained model.
      quant_model_path = './result/user_model'
      amct.save_quant_retrain_model(config_file=config_file,
                                    model=quant_retrain_model,
                                    record_file=record_file,
                                    save_path=quant_model_path,
                                    input_data=ori_model_input_data)
      
  4. (Optional) Run inference on the fake-quantized model (quant_model) in the ONNX Runtime environment based on the test dataset (test_data) to analyze the accuracy. (Update the sample code based on your situation.)

    Check the accuracy loss of the fake-quantized model by comparing its accuracy with that of the source model (see 2).

    quant_model = './result/user_model_fake_quant_model.onnx'
    user_do_inference_onnx(quant_model, test_data, test_iterations)
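
    The helper user_do_inference_onnx is user-defined. A minimal sketch using the ONNX Runtime Python API might look as follows; the input handling and the accuracy metric are placeholders.

    import numpy as np
    import onnxruntime as ort

    def user_do_inference_onnx(model_path, test_data, test_iterations):
        # Hypothetical accuracy check of the fake-quantized ONNX model with ONNX Runtime.
        session = ort.InferenceSession(model_path)
        input_name = session.get_inputs()[0].name
        correct, total = 0, 0
        for i, (images, labels) in enumerate(test_data):
            if i >= test_iterations:
                break
            outputs = session.run(None, {input_name: np.asarray(images, dtype=np.float32)})
            correct += int((outputs[0].argmax(axis=1) == np.asarray(labels)).sum())
            total += len(labels)
        print('fake-quantized model accuracy: {:.4f}'.format(correct / total))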
    

If the training process is interrupted, restore the data from the checkpoints and resume the training as follows.

  1. Import the AMCT package and set the log level (see Post-installation Actions for details).
    import amct_pytorch as amct
    
  2. Prepare a source model.
    ori_model = user_create_model()
    
  3. Call AMCT to resume the QAT process.
    1. Modify the model, insert quantization operators into the ori_model model, and save the new model as quant_retrain_model.
      config_file = './tmp/config.json'
      simple_cfg = './retrain.cfg'
      record_file = './tmp/record.txt'
      quant_pth_file = './ckpt/user_model_newest.ckpt'
      quant_retrain_model = amct.restore_quant_retrain_model(config_file=config_file,
                                                              model=ori_model,
                                                              record_file=record_file,
                                                              input_data=ori_model_input_data,
                                                              pth_file=quant_pth_file)
      
    2. Implement gradient descent optimization on the modified graph, train the graph on the training dataset, and calculate quantization factors. (Update the sample code based on your situation.)
      1. Create the optimizer used for gradient descent on the modified model. Perform this step after 3.a.
        optimizer = user_create_optimizer(quant_retrain_model)
        
      2. Restore the model from existing checkpoints and train the model.

        The quantization factors are saved to the checkpoints.

        user_train_model(optimizer, quant_retrain_model, train_data)
        
      3. After the training is complete, run inference to calculate and save the quantization factors.
        user_infer_graph(quant_retrain_model)
        
    3. Save the model.
      quant_model_path = './result/user_model'
      amct.save_quant_retrain_model(config_file=config_file,
                                    model=quant_retrain_model,
                                    record_file=record_file,
                                    save_path=quant_model_path,
                                    input_data=ori_model_input_data)
      
  4. (Optional) Run inference on the fake-quantized model (quant_model) in the ONNX Runtime environment based on the test dataset (test_data) to analyze the accuracy. (Update the sample code based on your situation.)

    Check the accuracy loss of the fake-quantized model by comparing its accuracy with that of the source model (see 2).

    quant_model = './result/user_model_fake_quant_model.onnx'
    user_do_inference_onnx(quant_model, test_data, test_iterations)