create_quant_retrain_config
Description
Searches the model graph for all quantizable layers, creates a quantization configuration file, and writes the quantization configuration of those layers to the file.
Prototype
create_quant_retrain_config(config_file, model, input_data, config_defination=None)
Parameters
| Parameter | Input/Return | Description | Restriction |
|---|---|---|---|
| config_file | Input | Path of the QAT configuration file, including the file name. Any existing file at this path is overwritten by this API call. | A string |
| model | Input | Source model, with weights loaded. | A torch.nn.Module |
| input_data | Input | Input data of the model. A single torch.Tensor is treated as equivalent to a one-element tuple of tensors. | A tuple |
| config_defination | Input | Simplified configuration file. The simplified configuration file quant.cfg is generated from the retrain_config_pytorch.proto file. The retrain_config_pytorch.proto template is stored in /amct_pytorch/proto/retrain_config_pytorch.proto in the AMCT installation path. For details about the parameters in retrain_config_pytorch.proto and the generated simplified quantization configuration file quant.cfg, see Simplified QAT Configuration File. | A string. Default: None |
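As noted above, input_data must be a tuple, and a single tensor input is simply wrapped in a one-element tuple. A minimal sketch of that wrapping (as_input_tuple is a hypothetical helper for illustration, not part of AMCT):

```python
def as_input_tuple(data):
    """Wrap model input(s) into the tuple form expected by input_data.

    Hypothetical helper (not part of AMCT): a single input object is
    wrapped in a one-element tuple; a tuple is passed through unchanged.
    """
    return data if isinstance(data, tuple) else (data,)
```

For example, `as_input_tuple(torch.randn(1, 3, 224, 224))` yields a one-element tuple suitable for the input_data argument.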
Returns
None
Outputs
A QAT configuration file in JSON format. (If QAT is performed again, the file output by this API is overwritten.) The following is an example configuration file for INT8 quantization:
```json
{
    "version":1,
    "batch_num":1,
    "conv1":{
        "retrain_enable":true,
        "retrain_data_config":{
            "algo":"ulq_quantize",
            "dst_type":"INT8"
        },
        "retrain_weight_config":{
            "algo":"arq_retrain",
            "channel_wise":true,
            "dst_type":"INT8"
        }
    },
    "layer1.0.conv1":{
        "retrain_enable":true,
        "retrain_data_config":{
            "algo":"ulq_quantize",
            "dst_type":"INT8"
        },
        "retrain_weight_config":{
            "algo":"arq_retrain",
            "channel_wise":true,
            "dst_type":"INT8"
        }
    },
    "fc":{
        "retrain_enable":true,
        "retrain_data_config":{
            "algo":"ulq_quantize",
            "dst_type":"INT8"
        },
        "retrain_weight_config":{
            "algo":"arq_retrain",
            "channel_wise":false,
            "dst_type":"INT8"
        }
    }
    ...
}
```
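Because the generated file is ordinary JSON, it can be inspected or post-edited with standard tools before retraining. A minimal sketch (list_quantized_layers is a hypothetical helper, assuming the layout shown above, where every top-level key other than "version" and "batch_num" maps to a per-layer configuration object):

```python
import json

def list_quantized_layers(config_path):
    """Return the names of layers whose retrain_enable flag is set.

    Hypothetical helper: assumes the JSON layout shown above, in which
    top-level object entries hold the per-layer quantization config.
    """
    with open(config_path) as f:
        cfg = json.load(f)
    return [name for name, entry in cfg.items()
            if isinstance(entry, dict) and entry.get("retrain_enable")]
```

This can be useful, for example, to verify that a layer you expected to be quantized was actually found by the API.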
Example
```python
import torch

import amct_pytorch as amct

# Build a graph of the network to be quantized.
# build_model, state_dict_path, and input_shape are user-supplied.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
input_data = tuple([torch.randn(input_shape)])

# Create a quantization configuration file.
amct.create_quant_retrain_config(config_file="./configs/config.json",
                                 model=model,
                                 input_data=input_data)
```