QAT in Single-Operator Mode

Description

When basic quantization is performed as described in Quantization Aware Training, the quantization process internally converts the original model into an ONNX model and modifies the graph based on that ONNX model. If the model contains PyTorch custom operators, the ONNX export may fail, which causes quantization to fail.

QAT in single-operator mode provides custom QAT operators converted from native PyTorch operators. The quantization factors are retrained on the converted operators and stored in each QAT operator as operator parameters, so the model does not need to be exported to ONNX for graph modification. After training, the torch.onnx.export mechanism is used to map each QAT operator to native ONNX quantization operators, and the parameters computed by the QAT operators in the PyTorch model are passed to those operators when the model is exported. Figure 1 shows the process.

Figure 1 QAT in single-operator mode

The following figure shows the implementation principle (torch.nn.Conv2d is used as an example):

ARQ weight quantization and ULQ activation quantization operators are inserted into the torch.nn.Conv2d operator to retrain the quantization factors, which are saved as operator parameters in the Conv2dQAT operator. When the torch.onnx.export API is called, the custom QAT operator is exported as the native ONNX quantization operators QuantizeLinear and DequantizeLinear based on the mapping. A model that contains these two quantization operators is called a Quantize and DeQuantize (QDQ) ONNX model. For details about the QDQ model, click here.
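The arithmetic behind the QuantizeLinear and DequantizeLinear pair can be illustrated with a minimal pure-Python sketch. It assumes per-tensor INT8 quantization with a single scale and zero point, following the ONNX operator definitions (round to nearest even, then saturate). The function names and sample values are hypothetical, and the code is not part of AMCT.

```python
# Illustrative only: per-tensor INT8 QuantizeLinear / DequantizeLinear arithmetic.
# Function names are hypothetical; this is not AMCT code.

INT8_MIN, INT8_MAX = -128, 127


def quantize_linear(values, scale, zero_point=0):
    """Map floats to INT8: saturate(round(x / scale) + zero_point)."""
    quantized = []
    for x in values:
        # Python's round() rounds halves to the nearest even integer,
        # which matches the rounding mode of ONNX QuantizeLinear.
        q = round(x / scale) + zero_point
        quantized.append(max(INT8_MIN, min(INT8_MAX, q)))
    return quantized


def dequantize_linear(quantized, scale, zero_point=0):
    """Map INT8 values back to floats: (q - zero_point) * scale."""
    return [(q - zero_point) * scale for q in quantized]


weights = [0.26, -0.49, 1.0, 100.0]
scale = 0.5  # in QAT, the scale is one of the retrained quantization factors

q = quantize_linear(weights, scale)
restored = dequantize_linear(q, scale)
print(q)         # [1, -1, 2, 127]  (100.0 saturates at the INT8 maximum)
print(restored)  # [0.5, -0.5, 1.0, 63.5]
```

The gap between `weights` and `restored` is the quantization error that retraining the quantization factors tries to keep small.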

The QDQ ONNX model cannot be converted directly by ATC into an offline model adapted to the Ascend AI Processor. First use the QAT Model Adaptation to CANN Format feature in AMCT (ONNX) to adapt the model to a CANN model, and then use ATC to convert the adapted model.

This feature supports two methods: training from scratch and fine-tuning.

  • Training from scratch: Use QAT operators to build the graph and train the model from scratch. In the model building script, call the Single-Operator Mode API for construction from scratch to create QAT operators, and then build the model with them.
  • Fine-tuning: Replace the operators to be quantized in an existing network. This method is more commonly used. If you have already built a model, call the Single-Operator Mode API for construction from native operators to create QAT operators, and then replace the operators to be quantized in the network by following the operator replacement approach in the fine-tuning sample below.

QAT operator specifications are as follows. For the call example, see Sample List.

Table 1 QAT operator specifications

| Type of the Operator to Be Quantized | New Operator Type | Restriction | Remarks |
|---|---|---|---|
| torch.nn.Conv2d | Conv2dQAT | padding_mode = zeros; only 4D inputs are supported. | Layers sharing the weight and bias parameters do not support quantization. |
| torch.nn.ConvTranspose2d | ConvTranspose2dQAT | padding_mode = zeros; only 4D inputs are supported. | |
| torch.nn.Conv3d | Conv3dQAT | padding_mode = zeros; dilation_d = 1; only 5D inputs are supported. | |
| torch.nn.Linear | LinearQAT | Channel-wise quantization is not supported. | |
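Before replacing operators, it can help to check a model against the restrictions in Table 1. The following is a hypothetical pre-check, not an AMCT API. It matches layers by class name so that it runs without importing torch, assumes that dilation_d refers to the first (depth) entry of the Conv3d dilation tuple, and detects sharing only when two layers hold the same weight object. The input-dimension restrictions (4D or 5D) depend on the data and are not covered.

```python
# Hypothetical pre-check derived from Table 1; not part of AMCT.
# Layers are matched by class name so the sketch runs without importing torch.

QUANTIZABLE = {"Conv2d", "ConvTranspose2d", "Conv3d", "Linear"}


def check_qat_restrictions(named_modules, channel_wise=False):
    """Return {layer_name: [reasons]} for layers that break a Table 1 restriction."""
    layers = [(n, m) for n, m in named_modules if type(m).__name__ in QUANTIZABLE]
    problems = {}

    # Group layer names by the identity of their weight object to spot sharing.
    users = {}
    for name, module in layers:
        weight = getattr(module, "weight", None)
        if weight is not None:
            users.setdefault(id(weight), []).append(name)

    for name, module in layers:
        kind = type(module).__name__
        reasons = []
        if kind != "Linear" and getattr(module, "padding_mode", "zeros") != "zeros":
            reasons.append("padding_mode must be 'zeros'")
        # Assumption: dilation_d is the first (depth) entry of Conv3d.dilation.
        if kind == "Conv3d" and tuple(getattr(module, "dilation", (1, 1, 1)))[0] != 1:
            reasons.append("dilation_d must be 1")
        if kind == "Linear" and channel_wise:
            reasons.append("channel-wise quantization is not supported")
        weight = getattr(module, "weight", None)
        if weight is not None and len(users[id(weight)]) > 1:
            reasons.append("weight is shared with another layer")
        if reasons:
            problems[name] = reasons
    return problems


# Stand-in classes so the demo runs without torch; real layers expose the same attributes.
Conv3d = type("Conv3d", (), {})
Linear = type("Linear", (), {})

conv = Conv3d()
conv.padding_mode, conv.dilation, conv.weight = "zeros", (2, 1, 1), object()
fc = Linear()
fc.weight = object()

print(check_qat_restrictions([("conv", conv), ("fc", fc)], channel_wise=True))
```

With a real PyTorch model, the same function could be called with `model.named_modules()` and the `channel_wise` value from `retrain_weight_config`.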

Samples

Code sample of training from scratch:

import torch
import torch.nn as nn

import amct_pytorch as amct
from amct_onnx.convert_model import convert_qat_model
from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT
from amct_pytorch.nn.module.quantization.linear import LinearQAT


# Quantization configuration items of a QAT operator
config = {
    "retrain_data_config": {
        "dst_type": "INT8",
        "batch_num": 10,
        "fixed_min": False,
        "clip_min": None,
        "clip_max": None
    },
    "retrain_weight_config": {
        "dst_type": "INT8",
        "weights_retrain_algo": "arq_retrain",
        "channel_wise": False
    }
}

# Use a QAT single-operator to construct LeNet.
net = torch.nn.Sequential(
    Conv2dQAT(1, 6, kernel_size=5, padding=2, config=config), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    Conv2dQAT(6, 16, kernel_size=5, config=config), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    LinearQAT(16 * 5 * 5, 120, config=config), nn.Sigmoid(),
    LinearQAT(120, 84, config=config), nn.Sigmoid(),
    LinearQAT(84, 10, config=config)
)

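# Training. train(), train_data, and test_data stand for your own training loop and datasets.
train(net, train_data, test_data)
# Export the intermediate model. data is a sample input tensor whose shape matches
# the network input; for this LeNet the shape is (N, 1, 28, 28).
torch.onnx.export(net, data, 'inter_model.onnx')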

# Export the fake-quantized model and deployable model.
# Generate a lenet_fake_quant_model.onnx model that can be used for accuracy simulation in the ONNX execution framework ONNX Runtime.
# Generate lenet_deploy_model.onnx, which is a model file that can be deployed on the AI processor.
convert_qat_model('inter_model.onnx', './outputs/lenet')

# Use the fake-quantized model to perform accuracy simulation.
# validate_onnx() and val_data stand for your own evaluation function and validation dataset.
validate_onnx('./outputs/lenet_fake_quant_model.onnx', val_data)

Code sample of fine-tuning:

import torch
from torchvision.models.resnet import resnet101

import amct_pytorch as amct
from amct_onnx.convert_model import convert_qat_model
from amct_pytorch.nn.module.quantization.conv2d import Conv2dQAT

model = resnet101(pretrained=True)
# Quantization configuration items of a QAT operator
config = {
    "retrain_data_config": {
        "dst_type": "INT8",
        "batch_num": 10,
        "fixed_min": False,
        "clip_min": None,
        "clip_max": None
    },
    "retrain_weight_config": {
        "dst_type": "INT8",
        "weights_retrain_algo": "arq_retrain",
        "channel_wise": True
    }
}

def _set_module(model, submodule_key, module):
    # Replace the native operator in the model with the QAT operator.
    tokens = submodule_key.split('.')
    sub_tokens = tokens[:-1]
    cur_mod = model
    for s in sub_tokens:
        cur_mod = getattr(cur_mod, s)
    setattr(cur_mod, tokens[-1], module)

# Traverse the modules of the original model and replace each native torch.nn.Conv2d with a Conv2dQAT operator.
# list() takes a snapshot first so that the model is not modified while it is being iterated.
for name, module in list(model.named_modules()):
    if isinstance(module, torch.nn.Conv2d):
        qat_module = Conv2dQAT.from_float(module, config=config)
        _set_module(model, name, qat_module)

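# Training process. train_and_val(), train_data, and test_data stand for your own training loop and datasets.
train_and_val(model, train_data, test_data)
# Export the intermediate model. data is a sample input tensor whose shape matches the network input.
torch.onnx.export(model, data, 'inter_model.onnx')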

# Export the fake-quantized model and deployable model.
# Generate a resnet101_fake_quant_model.onnx model that can be used for accuracy simulation in the ONNX execution framework ONNX Runtime.
# Generate resnet101_deploy_model.onnx, which is a model file that can be deployed on the AI processor.
convert_qat_model('inter_model.onnx', './outputs/resnet101')

# Use the fake-quantized model to perform accuracy simulation.
# validate_onnx() and val_data stand for your own evaluation function and validation dataset.
validate_onnx('./outputs/resnet101_fake_quant_model.onnx', val_data)
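The samples call validate_onnx without defining it. One possible shape for such a function is sketched below, assuming that ONNX Runtime and NumPy are installed, that the model has a single input, and that its first output holds the class scores. All helper names are hypothetical, and the accuracy metric (top-1) is an assumption rather than something the samples specify.

```python
# Hypothetical stand-in for validate_onnx; the helper names are not AMCT APIs.

def top1_accuracy(predict, samples):
    """Fraction of (input, label) pairs whose highest-scoring class equals the label."""
    samples = list(samples)
    if not samples:
        raise ValueError("samples is empty")
    correct = 0
    for inputs, label in samples:
        scores = list(predict(inputs))
        predicted = max(range(len(scores)), key=scores.__getitem__)
        correct += int(predicted == label)
    return correct / len(samples)


def make_onnx_predict(model_path):
    """Wrap an ONNX file as predict(sample) -> class scores for one sample."""
    # Imported lazily so that top1_accuracy can be used without these packages.
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name  # assumes a single-input model

    def predict(sample):
        batch = np.asarray(sample, dtype=np.float32)[None, ...]  # add a batch dimension
        outputs = session.run(None, {input_name: batch})
        return outputs[0][0]  # scores of the only sample in the batch

    return predict


def validate_onnx(model_path, val_data):
    """val_data: iterable of (input, integer label) pairs. Returns top-1 accuracy."""
    return top1_accuracy(make_onnx_predict(model_path), val_data)
```

Keeping the metric separate from the ONNX Runtime wrapper allows the same `top1_accuracy` function to score the original PyTorch model, which makes it easy to compare accuracy before and after quantization.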