quantize_model

Function Usage

Quantizes a graph based on the quantization configuration file: inserts the weight and activation quantization operators, generates the quantization factor record file (record_file), and returns a torch.nn.Module model ready for calibration.

Prototype

calibration_model = quantize_model(config_file, modfied_onnx_file, record_file, model, input_data, input_names=None, output_names=None, dynamic_axes=None)

Parameters

| Parameter | Input/Return | Meaning | Restriction |
|---|---|---|---|
| config_file | Input | User-defined quantization configuration file, which specifies the configuration of each layer to be quantized. | A string |
| modfied_onnx_file | Input | File name of the resulting ONNX model. | A string |
| record_file | Input | Path of the quantization factor record file, including the file name. | A string |
| model | Input | Source model, with weights loaded. | A torch.nn.Module |
| input_data | Input | Input data of the model. A single torch.Tensor is equivalent to a tuple(torch.Tensor). | A tuple |
| input_names | Input | Input names of the model, which are used in modfied_onnx_file. Default: None | A list of strings |
| output_names | Input | Output names of the model, which are used in modfied_onnx_file. Default: None | A list of strings |
| dynamic_axes | Input | Dynamic dimensions of the model's inputs and outputs. For example, if an input has format NCHW, where N, H and W are dynamic, and an output has format NL, where N is dynamic, then: {"inputs": [0,2,3], "outputs": [0]}. Default: None | A dict<string, dict<python:int, string>>, or dict<string, list(int)> |
| calibration_model | Returns | Resulting torch.nn.Module model ready for calibration. Default: None | A torch.nn.Module |
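The two accepted shapes of dynamic_axes (list of axis indices, or dict of axis index to axis name) can be related with a small, library-independent sketch. The helper name normalize_dynamic_axes and the labels it generates for the list form are assumptions made for illustration only; they are not part of amct_pytorch and do not describe how the library names axes internally.

```python
def normalize_dynamic_axes(dynamic_axes):
    """Convert every entry of dynamic_axes to the {axis_index: axis_name} form.

    Hypothetical helper, not part of amct_pytorch. The names generated for
    the list form ("<tensor>_dim_<index>") are an arbitrary choice.
    """
    if dynamic_axes is None:
        return {}
    normalized = {}
    for tensor_name, axes in dynamic_axes.items():
        if isinstance(axes, dict):
            # Already in dict form: keep the user-supplied axis names.
            normalized[tensor_name] = {int(i): str(n) for i, n in axes.items()}
        else:
            # List form: invent a label for each dynamic axis index.
            normalized[tensor_name] = {
                int(i): f"{tensor_name}_dim_{i}" for i in axes
            }
    return normalized


# List form: N, H and W of an NCHW input are dynamic, N of an NL output.
list_form = {"inputs": [0, 2, 3], "outputs": [0]}
# Dict form: only the batch dimension is dynamic, with an explicit name.
dict_form = {"input": {0: "batch_size"}, "output": {0: "batch_size"}}

print(normalize_dynamic_axes(list_form))
print(normalize_dynamic_axes(dict_form))
```

The dict form is the more explicit of the two, since each dynamic dimension carries a readable name such as batch_size.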

Returns

Resulting torch.nn.Module model ready for calibration.

Examples

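import os

import torch
import amct_pytorch as amct

# Build the network to be quantized and load its trained weights.
# build_model, state_dict_path and input_shape are user-supplied placeholders.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
input_data = tuple([torch.randn(input_shape)])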

scale_offset_record_file = os.path.join(TMP, 'scale_offset_record.txt')
modfied_model = os.path.join(TMP, 'modfied_model.onnx')
# Insert the quantization operators and obtain the model to be calibrated.
calibration_model = amct.quantize_model(config_json_file,
                                        modfied_model,
                                        scale_offset_record_file,
                                        model,
                                        input_data,
                                        input_names=['input'],
                                        output_names=['output'],
                                        dynamic_axes={'input':{0: 'batch_size'},
                                                      'output':{0: 'batch_size'}})
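Because record_file and modfied_onnx_file are both given as file paths, it can help to build them in one place before calling quantize_model. The following is a minimal sketch using only the standard library; the helper name prepare_quant_paths is hypothetical, and creating the directory in advance is a defensive assumption, since this page does not state whether quantize_model creates missing directories itself.

```python
import os
import tempfile


def prepare_quant_paths(work_dir):
    """Return (record_file, onnx_file) under work_dir, creating it if needed.

    Hypothetical helper, not part of amct_pytorch. The file names mirror
    the ones used in the example on this page.
    """
    os.makedirs(work_dir, exist_ok=True)
    record_file = os.path.join(work_dir, "scale_offset_record.txt")
    onnx_file = os.path.join(work_dir, "modfied_model.onnx")
    return record_file, onnx_file


# Use a fresh temporary directory so the sketch does not touch real data.
work_dir = os.path.join(tempfile.mkdtemp(), "amct_tmp")
record_file, onnx_file = prepare_quant_paths(work_dir)
print(record_file)
print(onnx_file)
```

The two returned strings can then be passed as the record_file and modfied_onnx_file arguments.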