QAT in Single-Operator Mode

Description

When basic quantization is performed as described in Performing Quantization Aware Training, the internal quantization logic must convert the source model into an ONNX model and modify the graph based on it. If the model contains PyTorch custom operators, the ONNX model may fail to be exported, causing quantization to fail.

QAT in single-operator mode provides custom QAT operators converted from native PyTorch operators. The quantization factors are retrained on the converted operators and stored in the QAT operators as parameters, so no ONNX model needs to be exported during training. After training, the torch.onnx.export mechanism maps each QAT operator to a native ONNX quantization operator, and the parameters computed by the QAT operator in the PyTorch model are passed to that ONNX operator when the model is exported. The following figure shows the process.

Figure 1 QAT in single-operator mode

This function can be used in either of the following ways:

  • Training from scratch: Build the model with QAT operators and train from scratch. Call the from-scratch construction API (provided in Single-operator mode) in the model building script to construct QAT operators and use them to build the model.
  • Fine-tuning: Replace the operators to be quantized in an existing network. This is the more common method. If you already have a trained model, call the native-operator construction API provided in Single-operator mode to construct a QAT operator, then replace the corresponding operator in the network by referring to the operator replacement approach in the following example.

QAT operator specifications are as follows. For the call example, see Sample List.

Table 1 QAT operator specifications

| Type of the Operator to Be Quantized | New Operator Type | Restrictions | Remarks |
| --- | --- | --- | --- |
| torch.nn.Conv2d | Conv2dQAT | padding_mode = zeros | Layers sharing the weight and bias parameters do not support quantization. |
| torch.nn.ConvTranspose2d | ConvTranspose2dQAT | padding_mode = zeros | |
| torch.nn.Conv3d | Conv3dQAT | padding_mode = zeros; dilation_d = 1 | |
| torch.nn.Linear | LinearQAT | Channel wise is not supported. | |
| torch.nn.LSTM | LSTMQAT | num_layers=1, bidirectional=False, dropout=0, proj_size=0; only 3D input is supported; input and initial_h quantization share retrain_data_config of the user input, and weight and recurrence_weight share retrain_weight_config; the hx input cannot be None, and the input cannot be a PackedSequence | - |
| torch.nn.GRU | GRUQAT | num_layers=1, bidirectional=False, dropout=0; only 3D input is supported; input and initial_h quantization share retrain_data_config of the user input, and weight and recurrence_weight share retrain_weight_config; the hx input cannot be None, and the input cannot be a PackedSequence | - |
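The LSTM and GRU restrictions above mirror constraints expressible on the native PyTorch modules. As a sanity check using plain torch.nn.LSTM (no AMCT dependency; shapes are illustrative), a configuration that satisfies them looks like:

```python
import torch

# A torch.nn.LSTM configured to satisfy the LSTMQAT restrictions:
# num_layers=1, bidirectional=False, dropout=0, proj_size=0.
lstm = torch.nn.LSTM(input_size=16, hidden_size=32, num_layers=1,
                     bidirectional=False, dropout=0.0)

x = torch.randn(5, 4, 16)    # 3D input only: (seq_len, batch, input_size)
h0 = torch.zeros(1, 4, 32)   # hx must be passed explicitly, not None
c0 = torch.zeros(1, 4, 32)

out, (hn, cn) = lstm(x, (h0, c0))  # plain tensor input, not PackedSequence
print(out.shape)                   # torch.Size([5, 4, 32])
```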

Samples

Code sample of training from scratch:

import torch
import torch.nn as nn

import amct_pytorch as amct
from amct_onnx.convert_model import convert_qat_model
from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT
from amct_pytorch.nn.module.quantization.linear import LinearQAT


# Quantization configuration items of a QAT operator
config = {
    "retrain_data_config": {
        "dst_type": "INT8",
        "batch_num": 10,
        "fixed_min": False,
        "clip_min": None,
        "clip_max": None
    },
    "retrain_weight_config": {
        "dst_type": "INT8",
        "weights_retrain_algo": "arq_retrain",
        "channel_wise": False
    }
}

# Use a QAT single-operator to construct LeNet.
net = torch.nn.Sequential(
    Conv2dQAT(1, 6, kernel_size=5, padding=2, config=config), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    Conv2dQAT(6, 16, kernel_size=5, config=config), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    LinearQAT(16 * 5 * 5, 120, config=config), nn.Sigmoid(),
    LinearQAT(120, 84, config=config), nn.Sigmoid(),
    LinearQAT(84, 10, config=config)
)

# Training (train is a user-defined training loop)
train(net, train_data, test_data)
# Export the intermediate model.
dummy_input = torch.randn(1, 1, 28, 28)  # LeNet expects a 1 x 28 x 28 input
torch.onnx.export(net, dummy_input, 'inter_model.onnx')
# Export the fake-quantized model and the deployable model:
# lenet_fake_quant_model.onnx can be used for accuracy simulation in the ONNX execution framework ONNX Runtime;
# lenet_deploy_model.onnx can be deployed on the AI processor.
convert_qat_model('inter_model.onnx', './outputs/lenet')
# Use the fake-quantized model to perform accuracy simulation (validate_onnx is user-defined).
validate_onnx('./outputs/lenet_fake_quant_model.onnx', val_data)
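The training function in the sample above is a user-supplied placeholder. A minimal sketch of such a loop in plain PyTorch (the function name, hyperparameters, and stand-in model here are illustrative, not part of the AMCT API) could be:

```python
import torch
import torch.nn as nn

def train(net, train_data, test_data=None, epochs=1, lr=0.01):
    # Illustrative training loop; train_data yields (input, label) batches.
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    net.train()
    for _ in range(epochs):
        for inputs, labels in train_data:
            optimizer.zero_grad()
            loss = loss_fn(net(inputs), labels)
            loss.backward()
            optimizer.step()

# Tiny synthetic smoke test with a stand-in model.
demo_net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
batches = [(torch.randn(8, 1, 28, 28), torch.randint(0, 10, (8,)))]
train(demo_net, batches)
```

When the QAT operators are in the graph, the same loop updates their quantization factors alongside the regular weights.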

Code sample of fine-tuning:

import torch
from torchvision.models.resnet import resnet101

import amct_pytorch as amct
from amct_onnx.convert_model import convert_qat_model
from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT

model = resnet101(pretrained=True)
# Quantization configuration items of a QAT operator
config = {
    "retrain_data_config": {
        "dst_type": "INT8",
        "batch_num": 10,
        "fixed_min": False,
        "clip_min": None,
        "clip_max": None
    },
    "retrain_weight_config": {
        "dst_type": "INT8",
        "weights_retrain_algo": "arq_retrain",
        "channel_wise": True
    }
}

def _set_module(model, submodule_key, module):
    # Replace the native operator in the model with the QAT operator.
    tokens = submodule_key.split('.')
    sub_tokens = tokens[:-1]
    cur_mod = model
    for s in sub_tokens:
        cur_mod = getattr(cur_mod, s)
    setattr(cur_mod, tokens[-1], module)

for name, module in model.named_modules():
    # Traverse each node in the source graph and replace the native Torch operator of the Conv2d type with a custom QAT single-operator.
    if isinstance(module, torch.nn.Conv2d):
        qat_module = Conv2dQAT.from_float(
            module, config=config)
        _set_module(model, name, qat_module)

# Training process (train_and_val is a user-defined training loop)
train_and_val(model, train_data, test_data)
# Export the intermediate model.
dummy_input = torch.randn(1, 3, 224, 224)  # ResNet-101 expects a 3 x 224 x 224 input
torch.onnx.export(model, dummy_input, 'inter_model.onnx')
# Export the fake-quantized model and the deployable model:
# resnet101_fake_quant_model.onnx can be used for accuracy simulation in the ONNX execution framework ONNX Runtime;
# resnet101_deploy_model.onnx can be deployed on the AI processor.
convert_qat_model('inter_model.onnx', './outputs/resnet101')
# Use the fake-quantized model to perform accuracy simulation (validate_onnx is user-defined).
validate_onnx('./outputs/resnet101_fake_quant_model.onnx', val_data)
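The `_set_module` helper in the sample works for any dotted submodule path reported by `named_modules()`. A self-contained check with plain torch modules (no AMCT dependency; the wider Conv2d here simply stands in for a QAT replacement):

```python
import torch.nn as nn

def _set_module(model, submodule_key, module):
    # Walk the dotted path down to the parent module, then rebind the last attribute.
    tokens = submodule_key.split('.')
    cur_mod = model
    for s in tokens[:-1]:
        cur_mod = getattr(cur_mod, s)
    setattr(cur_mod, tokens[-1], module)

toy_model = nn.Sequential(nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()))
# named_modules() reports the nested conv as '0.0'.
_set_module(toy_model, '0.0', nn.Conv2d(3, 16, 3))
print(toy_model[0][0].out_channels)  # 16
```

Because `setattr` rebinds the attribute on the parent, the replacement is visible everywhere the parent module is used, which is why the fine-tuning sample can swap operators without rebuilding the network.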