create_quant_retrain_config

Description

Finds all quantizable layers in the model graph, creates a quantization configuration file, and writes the quantization configuration of those layers to it.

Prototype

create_quant_retrain_config(config_file, model, input_data, config_defination=None)

Parameters

Option: config_file
Input/Return: Input
Description: Path of the QAT configuration file, including the file name. Any existing file at this path is overwritten when the API is called.
Restriction: A string.

Option: model
Input/Return: Input
Description: Source model, with weights loaded.
Restriction: A torch.nn.Module.

Option: input_data
Input/Return: Input
Description: Input data of the model. A torch.Tensor is equivalent to a tuple(torch.Tensor) with one element; see the sketch after this table.
Restriction: A tuple.

Option: config_defination
Input/Return: Input
Description: Simplified configuration file. The simplified configuration file quant.cfg is generated from the retrain_config_pytorch.proto file. The retrain_config_pytorch.proto template is stored in /amct_pytorch/proto/retrain_config_pytorch.proto in the AMCT installation path. For details about the parameters in retrain_config_pytorch.proto and the generated simplified quantization configuration file quant.cfg, see Simplified QAT Configuration File. Default: None.
Restriction: A string.
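
As a concrete illustration of the input_data restriction, the tuple carries one tensor per model input. A minimal sketch; the shapes below are illustrative assumptions, not requirements of the API:

import torch

# Single-input model: a one-element tuple.
input_data = (torch.randn(1, 3, 224, 224),)

# Multi-input model: one tensor per forward() argument (hypothetical shapes).
input_data = (torch.randn(1, 3, 224, 224), torch.randn(1, 10))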

Returns

None

Outputs

A QAT configuration file in JSON format. (When QAT is performed again, the file output by this API is overwritten.) The following is an example configuration file for INT8 quantization:

{
    "version":1,
    "batch_num":1,
    "conv1":{
        "retrain_enable":true,
        "retrain_data_config":{
            "algo":"ulq_quantize",
            "dst_type":"INT8"
        },
        "retrain_weight_config":{
            "algo":"arq_retrain",
            "channel_wise":true,
            "dst_type":"INT8"
        }
    },
    "layer1.0.conv1":{
        "retrain_enable":true,
        "retrain_data_config":{
            "algo":"ulq_quantize",
            "dst_type":"INT8"
        },
        "retrain_weight_config":{
            "algo":"arq_retrain",
            "channel_wise":true,
            "dst_type":"INT8"
        }
    },
    "fc":{
        "retrain_enable":true,
        "retrain_data_config":{
            "algo":"ulq_quantize",
            "dst_type":"INT8"
        },
        "retrain_weight_config":{
            "algo":"arq_retrain",
            "channel_wise":false,
            "dst_type":"INT8"
        }
    }
...
}
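
Because the output is plain JSON, it can be inspected programmatically before retraining. A minimal sketch using Python's standard json module; the path and layer names follow the example above:

import json

with open("./configs/config.json") as f:
    qat_config = json.load(f)

# List the layers selected for quantization (keys other than the
# "version" and "batch_num" metadata fields shown above).
layers = [k for k in qat_config if k not in ("version", "batch_num")]
print(layers)

# Check how the first convolution's weights will be quantized.
print(qat_config["conv1"]["retrain_weight_config"]["dst_type"])  # "INT8"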

Example

import amct_pytorch as amct
import torch

# Build the network to be quantized and load its trained weights.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
input_data = tuple([torch.randn(input_shape)])

# Create a quantization configuration file.
amct.create_quant_retrain_config(config_file="./configs/config.json",
                                 model=model,
                                 input_data=input_data)
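
To drive the generated configuration from a simplified configuration file instead of the defaults, pass the config_defination parameter documented above. A sketch, assuming a quant.cfg has already been generated from retrain_config_pytorch.proto; the ./configs/quant.cfg path is an illustrative assumption:

# config_defination points to the simplified configuration file (path assumed).
amct.create_quant_retrain_config(config_file="./configs/config.json",
                                 model=model,
                                 input_data=input_data,
                                 config_defination="./configs/quant.cfg")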