create_quant_retrain_config

Function Usage

Finds all quantizable layers in a graph, creates a quantization-aware training (QAT) configuration file, and writes the quantization configuration of each quantizable layer to that file.

Constraints

None

Prototype

create_quant_retrain_config(config_file, graph, config_defination=None)

Parameters

config_file
Input. Path of the QAT configuration file, including the file name. Any existing file at this path is overwritten when the API is called.
Restriction: a string.

graph
Input. The tf.Graph of the model to be quantized.
Restriction: a tf.Graph.

config_defination
Input. Simplified QAT configuration file: the path of a simplified configuration file (for example, quant.cfg) whose content follows the definition in retrain_config_tf.proto, located at /amct_tensorflow/proto/retrain_config_tf.proto under the AMCT installation path. For details about the parameters in retrain_config_tf.proto and the simplified configuration file quant.cfg, see Simplified QAT Configuration File.
Restriction: a string. Default: None.
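Because config_file is overwritten on every call, you may want to keep a copy of an earlier configuration before regenerating it. The following is a minimal sketch using only the Python standard library. backup_config is a hypothetical helper, not part of AMCT, and the timestamped .bak naming scheme is an arbitrary choice.

```python
import os
import shutil
import time


def backup_config(config_file):
    """Copy an existing configuration file aside before it is overwritten.

    Hypothetical helper (not part of AMCT). Returns the backup path,
    or None if config_file does not exist yet.
    """
    if not os.path.isfile(config_file):
        return None
    stamp = time.strftime("%Y%m%d-%H%M%S")
    backup_path = "{}.{}.bak".format(config_file, stamp)
    # copy2 copies the contents and also preserves metadata such as mtime
    shutil.copy2(config_file, backup_path)
    return backup_path
```

Call it right before amct.create_quant_retrain_config(config_path, graph, simple_config). Since the timestamp has one-second resolution, two backups taken within the same second share a name and the later one replaces the earlier one.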

Return Value

None

Outputs

A QAT configuration file in JSON format. If the API is called again with the same config_file path, the existing file is overwritten. The following is an example configuration file for INT8 quantization:

{
    "version":1,
    "batch_num":1,
    "conv1":{
        "retrain_enable":true,
        "retrain_data_config":{
            "algo":"ulq_quantize",
            "dst_type":"INT8"
        },
        "retrain_weight_config":{
            "algo":"arq_retrain",
            "channel_wise":true,
            "dst_type":"INT8"
        }
    },
    "conv2_1/expand":{
        "retrain_enable":true,
        "retrain_data_config":{
            "algo":"ulq_quantize",
            "dst_type":"INT8"
        },
        "retrain_weight_config":{
            "algo":"arq_retrain",
            "channel_wise":true,
            "dst_type":"INT8"
        }
    }
}
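To check which layers the generated file enables for retraining, you can load it with the standard json module. This sketch assumes the layout shown in the example above: every top-level key other than "version" and "batch_num" names a layer whose value is an object containing retrain_enable, retrain_data_config, and retrain_weight_config. Other AMCT versions or configurations may add further top-level keys, so treat that filter as an assumption. summarize_retrain_config is a hypothetical helper, not an AMCT API.

```python
import json

# Assumption: the only non-layer top-level keys are those seen in the example.
NON_LAYER_KEYS = {"version", "batch_num"}


def summarize_retrain_config(config_text):
    """Map each layer with retrain_enable true to its (data algo, weight algo)."""
    config = json.loads(config_text)
    summary = {}
    for name, entry in config.items():
        # Skip global settings and anything that is not a per-layer object.
        if name in NON_LAYER_KEYS or not isinstance(entry, dict):
            continue
        if not entry.get("retrain_enable", False):
            continue
        data_algo = entry.get("retrain_data_config", {}).get("algo")
        weight_algo = entry.get("retrain_weight_config", {}).get("algo")
        summary[name] = (data_algo, weight_algo)
    return summary


if __name__ == "__main__":
    example = """{
        "version": 1,
        "batch_num": 1,
        "conv1": {
            "retrain_enable": true,
            "retrain_data_config": {"algo": "ulq_quantize", "dst_type": "INT8"},
            "retrain_weight_config": {"algo": "arq_retrain", "channel_wise": true, "dst_type": "INT8"}
        }
    }"""
    for layer, algos in summarize_retrain_config(example).items():
        print(layer, algos)  # prints: conv1 ('ulq_quantize', 'arq_retrain')
```

For a real file, pass the text read from config_path, for example open(config_path).read().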

Examples

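import os
import tensorflow as tf
import amct_tensorflow as amct

PATH, _ = os.path.split(os.path.realpath(__file__))
config_path = os.path.join(PATH, 'resnet50_config.json')  # QAT config to generate (overwritten if present)
simple_config = './retrain/retrain.cfg'  # simplified QAT configuration file (config_defination)
graph = tf.compat.v1.get_default_graph()  # must already contain the model to be quantized
amct.create_quant_retrain_config(config_path, graph, simple_config)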