QAT in Single-Operator Mode
Description
When basic quantization aware training is performed (see Performs Quantization Aware Training), the internal quantization logic must convert the source model into an ONNX model and modify the graph based on that ONNX model. If the model contains PyTorch custom operators, the ONNX export may fail, causing the quantization to fail.
QAT in single-operator mode provides custom QAT operators converted from native PyTorch operators. The quantization factors are retrained on the converted operators and stored in each QAT single-operator as operator parameters, so no ONNX model needs to be exported during training. After training, the torch.onnx.export mechanism establishes the mapping between each QAT operator and the corresponding native ONNX quantization operator, and the parameters computed by the QAT operator in the PyTorch model are passed to that ONNX operator to export the model. The following figure illustrates the process.

This function can be used in either of the following ways:
- Training from scratch: Build the graph with QAT operators and train from scratch. In the model building script, call the from-scratch construction API (provided in Single-operator mode) to construct QAT operators and use them to build the model.
- Fine-tuning: Replace the operators to be quantized in an existing network. This method is more commonly used. If you have already built a model, call the native operator construction API provided in Single-operator mode to construct a QAT operator, and then replace the operators to be quantized in the network model by referring to the operator replacement solution in the following example.
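Both methods rely on the same underlying mechanism: during training, each QAT operator fake-quantizes values (quantize, clamp, dequantize) so the network learns to tolerate INT8 rounding, and the learned scale becomes an operator parameter. The following is a minimal pure-Python sketch of that idea only; `fake_quant_int8` and `scale` are illustrative names, not AMCT APIs.

```python
def fake_quant_int8(x, scale):
    """Quantize-dequantize a value the way an INT8 fake-quant step does:
    round to the nearest integer step, clamp to the INT8 range, and map
    back to floating point. The rounding/clipping error introduced here
    is what QAT retraining learns to compensate for."""
    q = round(x / scale)           # quantize onto the integer grid
    q = max(-128, min(127, q))     # clamp to the INT8 range
    return q * scale               # dequantize back to float

# With clip_max = 2.54, the symmetric INT8 grid has scale = 2.54 / 127.
scale = 2.54 / 127
print(fake_quant_int8(1.234, scale))  # snaps to the nearest grid step
print(fake_quant_int8(9.999, scale))  # values beyond clip_max saturate
```

This mirrors the role of the `clip_min`/`clip_max` fields in `retrain_data_config` below: they bound the representable range, and the scale is derived from that range.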
QAT operator specifications are as follows. For the call example, see Sample List.
| Type of the Operator to Be Quantized | New Operator Type | Restrictions | Remarks |
|---|---|---|---|
| torch.nn.Conv2d | Conv2dQAT | padding_mode = zeros | Layers sharing the weight and bias parameters do not support quantization. |
| torch.nn.ConvTranspose2d | ConvTranspose2dQAT | padding_mode = zeros | |
| torch.nn.Conv3d | Conv3dQAT | padding_mode = zeros; dilation_d = 1 | |
| torch.nn.Linear | LinearQAT | Channel-wise quantization is not supported. | |
| torch.nn.LSTM | LSTMQAT | | - |
| torch.nn.GRU | GRUQAT | | - |
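When replacing operators per the table above, layers that violate a restriction (for example, a Conv2d whose padding_mode is not zeros) should be left unquantized. A minimal sketch of such an eligibility filter, using a plain class as a stand-in for torch.nn.Conv2d; `can_quantize_conv2d` is a hypothetical helper, not an AMCT API.

```python
class Conv2d:
    """Stand-in for torch.nn.Conv2d, for illustration only."""
    def __init__(self, padding_mode='zeros'):
        self.padding_mode = padding_mode

def can_quantize_conv2d(module):
    # Per the table above, Conv2dQAT requires padding_mode = zeros.
    return isinstance(module, Conv2d) and module.padding_mode == 'zeros'

modules = {'conv1': Conv2d(), 'conv2': Conv2d(padding_mode='reflect')}
eligible = [name for name, m in modules.items() if can_quantize_conv2d(m)]
print(eligible)  # only conv1 qualifies
```

In a real script, the same check would gate the `Conv2dQAT.from_float` replacement loop shown in the fine-tuning sample below.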
Samples
Code sample of training from scratch:
```python
import torch
import torch.nn as nn
import amct_pytorch as amct
from amct_onnx.convert_model import convert_qat_model
from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT
from amct_pytorch.nn.module.quantization.linear import LinearQAT

# Quantization configuration items of a QAT operator
config = {
    "retrain_data_config": {
        "dst_type": "INT8",
        "batch_num": 10,
        "fixed_min": False,
        "clip_min": None,
        "clip_max": None
    },
    "retrain_weight_config": {
        "dst_type": "INT8",
        "weights_retrain_algo": "arq_retrain",
        "channel_wise": False
    }
}

# Use QAT single-operators to construct LeNet.
net = torch.nn.Sequential(
    Conv2dQAT(1, 6, kernel_size=5, padding=2, config=config),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    Conv2dQAT(6, 16, kernel_size=5, config=config),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    LinearQAT(16 * 5 * 5, 120, config=config),
    nn.Sigmoid(),
    LinearQAT(120, 84, config=config),
    nn.Sigmoid(),
    LinearQAT(84, 10, config=config)
)

# Training (train is a user-defined training loop).
train(net, train_data, test_data)

# Export the intermediate model.
torch.onnx.export(net, data, 'inter_model.onnx')

# Export the fake-quantized model and the deployable model.
# Generates lenet_fake_quant_model.onnx, which can be used for accuracy
# simulation in the ONNX Runtime execution framework, and
# lenet_deploy_model.onnx, which can be deployed on the AI processor.
convert_qat_model('inter_model.onnx', './outputs/lenet')

# Use the fake-quantized model to perform accuracy simulation
# (validate_onnx is a user-defined evaluation function).
validate_onnx('./outputs/lenet_fake_quant_model.onnx', val_data)
```
Code sample of fine-tuning:
```python
import torch
from torchvision.models.resnet import resnet101
import amct_pytorch as amct
from amct_onnx.convert_model import convert_qat_model
from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT

model = resnet101(pretrained=True)

# Quantization configuration items of a QAT operator
config = {
    "retrain_data_config": {
        "dst_type": "INT8",
        "batch_num": 10,
        "fixed_min": False,
        "clip_min": None,
        "clip_max": None
    },
    "retrain_weight_config": {
        "dst_type": "INT8",
        "weights_retrain_algo": "arq_retrain",
        "channel_wise": True
    }
}

def _set_module(model, submodule_key, module):
    # Replace the native operator in the model with the QAT operator.
    tokens = submodule_key.split('.')
    sub_tokens = tokens[:-1]
    cur_mod = model
    for s in sub_tokens:
        cur_mod = getattr(cur_mod, s)
    setattr(cur_mod, tokens[-1], module)

# Traverse each node in the source graph and replace every native
# torch.nn.Conv2d operator with a custom QAT single-operator.
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        qat_module = Conv2dQAT.from_float(module, config=config)
        _set_module(model, name, qat_module)

# Training process (train_and_val is a user-defined training loop).
train_and_val(model, train_data, test_data)

# Export the intermediate model.
torch.onnx.export(model, data, 'inter_model.onnx')

# Export the fake-quantized model and the deployable model.
# Generates resnet101_fake_quant_model.onnx, which can be used for accuracy
# simulation in the ONNX Runtime execution framework, and
# resnet101_deploy_model.onnx, which can be deployed on the AI processor.
convert_qat_model('inter_model.onnx', './outputs/resnet101')

# Use the fake-quantized model to perform accuracy simulation
# (validate_onnx is a user-defined evaluation function).
validate_onnx('./outputs/resnet101_fake_quant_model.onnx', val_data)
```