Workflow

This section describes the layers that support QAT, the API call sequence, and an example.

Currently, QAT supports only float32 network models. Table 1 lists the layers that support QAT. For quantization samples, see Sample List.

Table 1 Layers that support QAT and their restrictions

  • torch.nn.Linear
    Restriction: none.
  • torch.nn.Conv2d
    Restrictions:
      • padding_mode = zeros
      • Due to hardware restrictions, do not perform QAT on the layer if its number of input channels (Cin) in the original model is 16 or fewer; doing so may degrade the inference accuracy of the deployable quantized model.
      • The shape of the input data must be (N, Cin, Hin, Win).
  • torch.nn.ConvTranspose2d
    Restrictions:
      • padding_mode = zeros
      • Due to hardware restrictions, do not perform QAT on the layer if its number of input channels (Cin) in the original model is 16 or fewer; doing so may degrade the inference accuracy of the deployable quantized model.
      • The shape of the input data must be (N, Cin, Hin, Win).

Remarks: Layers that share weight and bias parameters do not support quantization.
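The restrictions in Table 1 can be screened before QAT is configured. The following is a minimal sketch, not part of AMCT: LayerInfo, qat_skip_reasons, and the reason strings are hypothetical names, and the sketch treats the Remarks entry as applying to every listed layer (an assumption). In a real PyTorch model, the values would come from each module's type and its in_channels and padding_mode attributes. The input-shape restriction depends on runtime data and is not checked here.

```python
from dataclasses import dataclass
from typing import List


# Hypothetical, framework-agnostic layer descriptor. In PyTorch the values would
# come from type(module).__name__, module.in_channels and module.padding_mode.
@dataclass
class LayerInfo:
    name: str
    layer_type: str               # e.g. "Linear", "Conv2d", "ConvTranspose2d"
    in_channels: int = 0          # Cin; ignored for Linear
    padding_mode: str = "zeros"   # ignored for Linear
    shares_params: bool = False   # True if weight/bias are shared with another layer


SUPPORTED_TYPES = {"Linear", "Conv2d", "ConvTranspose2d"}
CONV_TYPES = {"Conv2d", "ConvTranspose2d"}
CIN_LIMIT = 16  # Table 1: avoid QAT on conv layers with Cin <= 16


def qat_skip_reasons(layer: LayerInfo) -> List[str]:
    """Return why a layer should be left out of QAT; an empty list means eligible."""
    if layer.layer_type not in SUPPORTED_TYPES:
        return ["layer type not supported by QAT"]
    reasons = []
    # Assumption: the Remarks entry applies to all supported layer types.
    if layer.shares_params:
        reasons.append("shares weight and bias parameters")
    if layer.layer_type in CONV_TYPES:
        if layer.padding_mode != "zeros":
            reasons.append("padding_mode is not 'zeros'")
        if layer.in_channels <= CIN_LIMIT:
            reasons.append("Cin <= 16 may hurt deployed accuracy")
    return reasons


if __name__ == "__main__":
    layers = [
        LayerInfo("stem", "Conv2d", in_channels=3),
        LayerInfo("body", "Conv2d", in_channels=64),
        LayerInfo("head", "Linear"),
    ]
    for layer in layers:
        reasons = qat_skip_reasons(layer)
        print(layer.name, "skip:" if reasons else "ok", "; ".join(reasons))
```

How flagged layers are then excluded from quantization depends on the AMCT configuration options and is not shown here.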

API Call Sequence

Figure 1 shows the API call sequence of QAT. Training runs in a CPU or GPU environment of the PyTorch framework (for Atlas A2 training products and Atlas A2 inference products, an NPU environment is also supported). AMCT APIs are called from the inference script of the open-source framework to compress the model. Before the compressed model can be used for inference on the Ascend AI Processor, it must be converted by ATC into an offline model adapted to the Ascend AI Processor.

Figure 1 API call sequence
The user implements the operations in blue, while those in gray are implemented by using AMCT APIs.
  1. Build an original PyTorch model and then generate a quantization configuration file by using the create_quant_retrain_config API.
  2. Call the create_quant_retrain_model API to modify the original model by inserting activation and weight quantization operators that calculate the quantization parameters.
  3. Train the modified model. If the training is not interrupted, perform inference on the trained model; during inference, the quantization factors are written into the record file.

    If the training is interrupted, call the restore_quant_retrain_model API with the saved .pth model parameters and the quantization configuration file to obtain the modified retraining network, and continue QAT. Then, perform inference.

  4. Call the save_quant_retrain_model API to insert quantization operators such as AscendQuant and AscendDequant and save the quantized model.
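For step 3, resuming an interrupted run requires the path of the saved .pth parameters, which restore_quant_retrain_model receives as pth_file (see the resume example later in this section). A minimal sketch for locating the newest checkpoint, assuming checkpoints are written to a single directory with a *.ckpt extension (a hypothetical naming scheme; adjust the pattern to match your own training script). find_latest_checkpoint is a hypothetical helper, not an AMCT API.

```python
import glob
import os
from typing import Optional


def find_latest_checkpoint(ckpt_dir: str, pattern: str = "*.ckpt") -> Optional[str]:
    """Return the most recently modified checkpoint in ckpt_dir, or None if none exist."""
    candidates = glob.glob(os.path.join(ckpt_dir, pattern))
    if not candidates:
        return None
    # Newest by modification time; ties are broken by path name for determinism.
    return max(candidates, key=lambda path: (os.path.getmtime(path), path))


if __name__ == "__main__":
    print(find_latest_checkpoint("./ckpt"))
```

If find_latest_checkpoint returns None, there is nothing to restore, and the run starts from create_quant_retrain_model instead.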

Example

  1. Training runs in the PyTorch environment. For multi-device training, only distributed mode (DistributedDataParallel) is supported. DataParallel mode is not supported; if it is used for training, an error is reported.
  2. Adjust the arguments passed to AMCT API calls as required. QAT builds on your own training process, so make sure you have a PyTorch training script that achieves satisfactory training accuracy.
  3. If the training process hangs while you are using the AMCT QAT feature, check whether other ONNX Runtime programs are running on the server (for example, by running the top command). If so, suspend those programs and perform QAT again.
  4. Perform quantization as described in this section. If the model contains PyTorch custom operators, exporting the ONNX model may fail, which causes quantization to fail with the following error: "Model cannot be quantized for it cannot be export to onnx!" In this case, perform quantization in single-operator mode as described in QAT in Single-Operator Mode.
  5. Take the following steps to get started. Update the sample code based on your situation.
  1. Import the AMCT package and set the log level using the environment variable in "AMCT (PyTorch)" in Post-installation Actions.
    import amct_pytorch as amct
    
  2. (Optional) Run inference on the original model in the PyTorch environment based on the test dataset to validate the environment setup and inference script. (Update the sample code based on your situation.)

    This step is recommended because it confirms that the original model runs inference properly and with acceptable accuracy. To save time, you can use a subset of the test dataset.

    ori_model.load()
    # Test the model.
    user_test_model(ori_model, test_data, test_iterations)
    
  3. Run AMCT to start quantization.
    1. Generate a quantization configuration file.

      Before performing this step, restore the trained parameters of the original model, for example, by calling ori_model.load() as in step 2.

      config_file = './tmp/config.json'
      simple_cfg = './retrain.cfg'
      amct.create_quant_retrain_config(config_file=config_file,
                                       model=ori_model,
                                       input_data=ori_model_input_data,
                                       config_defination=simple_cfg)
      
    2. Modify the model.
      Insert activation and weight quantization operators into ori_model to calculate the quantization parameters. The modified model for retraining is returned as quant_retrain_model.
      record_file = './tmp/record.txt'
      quant_retrain_model = amct.create_quant_retrain_model(config_file=config_file,
                                                            model=ori_model,
                                                            record_file=record_file,
                                                            input_data=ori_model_input_data)
      
    3. Set up gradient descent optimization for the modified model, train it on the training dataset, and calculate the quantization factors. (Update the sample code based on your situation.)
      1. Create an optimizer for gradient descent on the modified model. Perform this step after 3.b.
        optimizer = user_create_optimizer(quant_retrain_model)
        
      2. Restore the model from existing checkpoints and train the model.

        Note: If checkpoints exist, restore the model parameters from them before training. The parameters saved during training must include the quantization factors, which are generated only after the first batch_num training iterations. If the number of training iterations is less than batch_num, training fails.

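        # Path prefix under which your training loop saves checkpoints.
        # Make sure the saved parameters include the quantization factors (see the note above).
        quant_pth = './ckpt/user_model'
        user_train_model(optimizer, quant_retrain_model, train_data)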
        
      3. After the training is complete, run inference to calculate and save the quantization factors.
        user_infer_graph(quant_retrain_model)
        
    4. Save the quantized model.
      Call the save_quant_retrain_model API to insert operators such as AscendQuant and AscendDequant and save the quantized models based on the quantization factors and the retrained model.
      quant_model_path = './result/user_model'
      amct.save_quant_retrain_model(config_file=config_file,
                                    model=quant_retrain_model,
                                    record_file=record_file,
                                    save_path=quant_model_path,
                                    input_data=ori_model_input_data)
      

      Note: Calling this API to save the model in every round of retraining may cause retraining to behave abnormally. Call this API to save the model only after retraining is complete.

  4. (Optional) Run inference on the fake-quantized model (quant_model) in the ONNX Runtime environment based on the test dataset (test_data) to test the accuracy. (Update the sample code based on your situation.)

    Compare the accuracy of the fake-quantized model with that of the original model in step 2 to check the accuracy drop caused by quantization.

    quant_model = './result/user_model_fake_quant_model.onnx'
    user_do_inference_onnx(quant_model, test_data, test_iterations)
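The note in step 3.b states that quantization factors are generated only after the first batch_num training iterations and that training fails if there are fewer iterations. A pre-flight check can catch this before any training time is spent. This sketch assumes batch_num is stored as a top-level key in the generated config.json; that location is an assumption not confirmed by this document, so verify it against your generated file or pass the value in directly. read_batch_num and check_training_length are hypothetical helpers.

```python
import json


def read_batch_num(config_file: str, key: str = "batch_num") -> int:
    """Read batch_num from a JSON config file.

    ASSUMPTION: the value is stored as a top-level key; adjust if your file differs.
    """
    with open(config_file, "r", encoding="utf-8") as f:
        config = json.load(f)
    if key not in config:
        raise KeyError(f"'{key}' not found at the top level of {config_file}")
    return int(config[key])


def check_training_length(batch_num: int, planned_iterations: int) -> None:
    """Fail fast if training would stop before quantization factors are generated."""
    if batch_num <= 0:
        raise ValueError("batch_num must be positive")
    if planned_iterations < batch_num:
        raise ValueError(
            f"planned {planned_iterations} training iterations, but quantization "
            f"factors are only generated after the first {batch_num}"
        )
```

For example, call check_training_length(read_batch_num(config_file), num_epochs * len(train_loader)) before user_train_model, where num_epochs and train_loader are hypothetical names from your own training script.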
    

If the training is interrupted, restore data from the checkpoints to resume the training.

  1. Import the AMCT package and set the log level using the environment variable in "AMCT (PyTorch)" in Post-installation Actions.
    import amct_pytorch as amct
    
  2. Prepare an original model.
    ori_model = user_create_model()
    
  3. Run AMCT to resume the QAT process.
    1. Modify the model: insert quantization operators into ori_model, restore the parameters saved in the checkpoint, and obtain the modified model quant_retrain_model.
      config_file = './tmp/config.json'
      simple_cfg = './retrain.cfg'
      record_file = './tmp/record.txt'
      quant_pth_file = './ckpt/user_model_newest.ckpt'
      quant_retrain_model = amct.restore_quant_retrain_model(config_file=config_file,
                                                             model=ori_model,
                                                             record_file=record_file,
                                                             input_data=ori_model_input_data,
                                                             pth_file=quant_pth_file)
      
    2. Set up gradient descent optimization for the modified model, train it on the training dataset, and calculate the quantization factors. (Update the sample code based on your situation.)
      1. Create an optimizer for gradient descent on the modified model. Perform this step after 3.a.
        optimizer = user_create_optimizer(quant_retrain_model)
        
      2. Restore the model from existing checkpoints and train the model.

        Note: If checkpoints exist, restore the model parameters from them before training. The parameters saved during training must include the quantization factors, which are generated only after the first batch_num training iterations. If the number of training iterations is less than batch_num, training fails.

        user_train_model(optimizer, quant_retrain_model, train_data)
        
      3. After the training is complete, run inference to calculate and save the quantization factors.
        user_infer_graph(quant_retrain_model)
        
    3. Save the model.
      quant_model_path = './result/user_model'
      amct.save_quant_retrain_model(config_file=config_file,
                                    model=quant_retrain_model,
                                    record_file=record_file,
                                    save_path=quant_model_path,
                                    input_data=ori_model_input_data)
      
  4. (Optional) Run inference on the fake-quantized model (quant_model) in the ONNX Runtime environment based on the test dataset (test_data) to test the accuracy. (Update the sample code based on your situation.)

    Compare the accuracy of the fake-quantized model with that of the original model (step 2 of the preceding example) to check the accuracy drop caused by quantization.

    quant_model = './result/user_model_fake_quant_model.onnx'
    user_do_inference_onnx(quant_model, test_data, test_iterations)
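The accuracy comparison in the last step comes down to measuring the same metric on both models and checking the difference against a tolerance you choose. A minimal sketch, assuming your inference helpers (such as user_do_inference_onnx) can be adapted to return predicted class indices; top1_accuracy, compare_accuracy, and the max_drop threshold are hypothetical names, not AMCT APIs.

```python
from typing import Sequence, Tuple


def top1_accuracy(predictions: Sequence[int], labels: Sequence[int]) -> float:
    """Fraction of samples whose predicted class index equals the label."""
    if not labels or len(predictions) != len(labels):
        raise ValueError("predictions and labels must be non-empty and of equal length")
    correct = sum(1 for p, y in zip(predictions, labels) if p == y)
    return correct / len(labels)


def compare_accuracy(original_acc: float, quant_acc: float,
                     max_drop: float) -> Tuple[float, bool]:
    """Return (drop, acceptable), where drop = original accuracy - quantized accuracy."""
    drop = original_acc - quant_acc
    return drop, drop <= max_drop
```

Top-1 accuracy is only an example metric; use whichever metric your original inference script reports in step 2 so that both numbers are directly comparable.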