Layer-wise Distillation
This section describes the layers that support layer-wise distillation, the API call sequence, and an example.
- Operators that support distillation and quantization:
  - torch.nn.Linear: Layers that share weight and bias parameters do not support quantization.
  - torch.nn.Conv2d: Quantization is supported only when padding_mode is set to zeros. Layers that share weight and bias parameters do not support quantization.
- Activation operators that support distillation:
  - torch.nn.ReLU
  - torch.nn.LeakyReLU
  - torch.nn.Sigmoid
  - torch.nn.Tanh
  - torch.nn.Softmax
- Normalization operator that supports distillation:
  - torch.nn.BatchNorm2d: Layers that share weight and bias parameters do not support distillation.
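The support rules above can be restated as a small lookup. The sketch below is plain Python and does not call AMCT. The function name `supported_compression` and its arguments are hypothetical, and it assumes that a Linear or Conv2d layer excluded from quantization can still be distilled, which the list implies but does not state outright.

```python
# Hypothetical helper that restates the support rules from this section.
# It is not part of AMCT and only inspects plain values.

QUANT_AND_DISTILL = {"Linear", "Conv2d"}
ACTIVATIONS = {"ReLU", "LeakyReLU", "Sigmoid", "Tanh", "Softmax"}
NORMALIZATION = {"BatchNorm2d"}


def supported_compression(type_name, shares_params=False, padding_mode="zeros"):
    """Return the set of features supported for one layer.

    type_name:     class name of the layer, for example "Conv2d"
    shares_params: True if weight/bias are shared with another layer
    padding_mode:  only consulted for Conv2d
    """
    if type_name in QUANT_AND_DISTILL:
        features = {"distillation", "quantization"}
        if shares_params:
            # Shared weight/bias rules out quantization.
            features.discard("quantization")
        if type_name == "Conv2d" and padding_mode != "zeros":
            # Conv2d is quantized only with zero padding.
            features.discard("quantization")
        return features
    if type_name in ACTIVATIONS:
        return {"distillation"}
    if type_name in NORMALIZATION:
        # Shared weight/bias rules out distillation for BatchNorm2d.
        return set() if shares_params else {"distillation"}
    # Anything not listed in this section is treated as unsupported.
    return set()


if __name__ == "__main__":
    print(sorted(supported_compression("Linear")))
    print(sorted(supported_compression("Conv2d", padding_mode="reflect")))
    print(sorted(supported_compression("BatchNorm2d", shares_params=True)))
```

With a real model, the class name and `padding_mode` could be read from each module yielded by `model.named_modules()`.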
For the layer-wise distillation example, see Sample List.
API Call Sequence
Figure 1 shows the API call sequence.
- Creating the distillation configuration: Construct the source PyTorch model, then call the create_distill_config API. The API combines the user-defined distillation configuration with the distillation configuration defined by the AMCT algorithm and outputs the distillation configuration of each network layer.
- Creating a distillation model: Call the create_distill_model API to modify the source model and generate a distillation model based on the distillation configuration. The model is quantized before distillation.
- Distilling the quantized model: Call the distill API to distill the network block by block, based on the user-configured model inference and optimization methods and the distillation configuration. After distillation, the quantized model can achieve better performance.
- Saving the distilled model: Call the save_distill_model API to save the distilled quantized model, including the fake-quantized model file for the ONNX Runtime environment or the deployable model file for the Ascend AI Processor.
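To give a feel for what distilling the network "by block" means, the toy sketch below fits each student block to reproduce the output of the matching teacher block, one block at a time. It is plain Python with scalar weights, not AMCT code: the function names, the mean squared error loss, and the gradient-descent update are all illustrative assumptions, and the algorithm AMCT uses internally may differ.

```python
# Toy illustration of block-by-block distillation (not the AMCT implementation).
# Each "block" computes y = w * x with a single scalar weight w.

def mse(outputs, targets):
    """Mean squared error between two equally long lists."""
    return sum((o - t) ** 2 for o, t in zip(outputs, targets)) / len(outputs)


def distill_block(teacher_w, student_w, inputs, lr=0.1, epochs=100):
    """Fit one student block to one teacher block by gradient descent on MSE."""
    for _ in range(epochs):
        # d/dw mean((w*x - t*x)^2) = mean(2 * (w - t) * x^2)
        grad = sum(2 * (student_w - teacher_w) * x * x for x in inputs) / len(inputs)
        student_w -= lr * grad
    return student_w


def distill_layerwise(teacher_ws, student_ws, inputs, lr=0.1, epochs=100):
    """Distill blocks one at a time, feeding teacher activations forward."""
    activations = list(inputs)
    distilled = []
    for teacher_w, student_w in zip(teacher_ws, student_ws):
        distilled.append(distill_block(teacher_w, student_w, activations, lr, epochs))
        # The next block is trained on the teacher's output of this block.
        activations = [teacher_w * x for x in activations]
    return distilled


if __name__ == "__main__":
    teacher = [0.8, 1.2]
    student = [1.0, 1.0]  # coarsely rounded copies stand in for quantized weights
    data = [0.5, -1.0, 2.0]
    before = mse([student[0] * x for x in data], [teacher[0] * x for x in data])
    result = distill_layerwise(teacher, student, data)
    after = mse([result[0] * x for x in data], [teacher[0] * x for x in data])
    print(result, before, after)
```

Because every block is matched against the teacher independently, an error in one student block does not leak into the training signal of the next. That is the main appeal of working block by block in this toy setting.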
Examples
- Take the following steps to get started. Update the sample code based on your situation.
- Tweak the arguments passed to AMCT API calls as required.
- Import the AMCT package and set the log level (see Post-installation Actions for details).
import amct_pytorch as amct
- (Optional) Validate the inference script and environment setup in the source PyTorch environment. Update the sample code based on your situation.
Use the source model to be quantized and its test dataset to run inference in the PyTorch environment.
This step is recommended because it confirms that the source model runs inference correctly with acceptable accuracy. To save time, you can use a subset of the test dataset.
user_do_inference_torch(ori_model, test_data, test_iterations)
- Call AMCT to distill the model.
- Call the create_distill_config API to create the distillation configuration. The configuration contains both user-defined and automatically searched distillation structures.
config_file = './tmp/config.json'
simple_cfg = './distill.cfg'
amct.create_distill_config(config_file=config_file,
                           model=ori_model,
                           input_data=ori_model_input_data,
                           config_defination=simple_cfg)
- Call the create_distill_model API to create a distillation model. The API quantizes the floating-point model to be distilled and replaces each operator to be compressed with a CANN quantization operator.
compress_model = amct.create_distill_model(
    config_file=config_file,
    model=ori_model,
    input_data=ori_model_input_data)
- Distill the model layer by layer.
Call the distill API to perform layer-by-layer distillation. Distillation is performed on the distillation structures defined in the configuration.
# ori_model: source model; compress_model: model returned by create_distill_model
distill_model = amct.distill(
    ori_model,
    compress_model,
    config_file,
    train_loader,
    epochs=1,
    lr=1e-3,
    sample_instance=None,
    loss=loss,
    optimizer=optimizer)
- Save the distilled model. Call the save_distill_model API to insert operators such as AscendQuant and AscendDequant into the model and save it as the distilled model.
amct.save_distill_model(
    model,
    "./model/distilled",
    input_data,
    record_file="./results/records.txt",
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'},
                  'output': {0: 'batch_size'}})
- (Optional) Run inference on the distilled model (quant_model) in the ONNX Runtime environment based on the test dataset (test_data) to analyze the accuracy. Update the sample code based on your situation.
Compare the accuracy of the fake-quantized model with that of the source model (see 2).
quant_model = './results/user_model_fake_quant_model.onnx'
user_do_inference_onnx(quant_model, test_data, test_iterations)
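The accuracy comparison in the last step can be as simple as computing top-1 accuracy for both models on the same samples. The sketch below assumes, hypothetically, that the user-defined inference functions return one list of class scores per sample. `top1_accuracy` and `accuracy_drop` are illustrative names, not AMCT APIs.

```python
# Hypothetical helpers for comparing the source and fake-quantized models.

def top1_accuracy(scores, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    if not labels:
        raise ValueError("labels must not be empty")
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    correct = 0
    for sample_scores, label in zip(scores, labels):
        # Index of the largest score is the predicted class.
        predicted = max(range(len(sample_scores)), key=sample_scores.__getitem__)
        correct += int(predicted == label)
    return correct / len(labels)


def accuracy_drop(source_scores, quant_scores, labels):
    """Accuracy of the source model minus accuracy of the quantized model."""
    return top1_accuracy(source_scores, labels) - top1_accuracy(quant_scores, labels)


if __name__ == "__main__":
    labels = [1, 0, 1, 1]
    source = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.4, 0.6]]
    quant = [[0.2, 0.8], [0.7, 0.3], [0.6, 0.4], [0.45, 0.55]]
    print(top1_accuracy(source, labels))
    print(top1_accuracy(quant, labels))
    print(accuracy_drop(source, quant, labels))
```

A positive drop means the fake-quantized model is less accurate than the source model on the evaluated samples. How large a drop is acceptable depends on the application.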
