quantize_model

Applicability

Product

Supported

Atlas A3 training series products/Atlas A3 inference series products

Atlas A2 training products/Atlas A2 inference products

Atlas 200I/500 A2 inference product

Atlas inference series products

Atlas training products

Description

Quantizes a graph based on the quantization configuration file, inserts the weight and activation quantization operators, generates a quantization factor record file record_file, and returns a torch.nn.Module model ready for calibration.
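To make the idea of a quantization factor concrete, the following is a minimal sketch of generic INT8 affine quantization with a scale and an offset. It assumes the usual round-and-clamp formulation and is not taken from the AMCT implementation; the helper names quantize and dequantize are hypothetical.

```python
# Generic INT8 affine quantization sketch (an assumption for illustration,
# not the AMCT implementation). quantize/dequantize are hypothetical helpers.

def quantize(value: float, scale: float, offset: int,
             qmin: int = -128, qmax: int = 127) -> int:
    """Map a float to an integer code: round(value / scale) + offset, then clamp."""
    # Python's round() rounds halves to the nearest even integer.
    code = round(value / scale) + offset
    return max(qmin, min(qmax, code))


def dequantize(code: int, scale: float, offset: int) -> float:
    """Map an integer code back to an approximate float value."""
    return (code - offset) * scale


scale, offset = 0.05, 3
code = quantize(1.0, scale, offset)          # in-range value
restored = dequantize(code, scale, offset)   # close to the original value
saturated = quantize(100.0, scale, offset)   # out-of-range value is clamped to qmax
```

In this sketch, scale and offset play the role of the per-layer factors that a record file would store; values outside the representable range saturate instead of wrapping around.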

Prototype

calibration_model = quantize_model(config_file, modfied_onnx_file, record_file, model, input_data, input_names=None, output_names=None, dynamic_axes=None)

Parameters

| Parameter | Input/Output | Description |
|---|---|---|
| config_file | Input | User-generated quantization configuration file, which specifies the configuration of the quantization layers in the model network. Type: string. |
| modfied_onnx_file | Input | Name of the resultant ONNX model file. Type: string. |
| record_file | Input | Path (including the file name) of the quantization factor record file. Type: string. |
| model | Input | Original model to be quantized, with weights loaded. Type: torch.nn.Module. |
| input_data | Input | Input data of the model. A single torch.tensor is treated as the equivalent tuple(torch.tensor). Type: tuple. |
| input_names | Input | Input names of the model, as displayed in modfied_onnx_file. Default: None. Type: list of strings. |
| output_names | Input | Output names of the model, as displayed in modfied_onnx_file. Default: None. Type: list of strings. |
| dynamic_axes | Input | Dynamic axes of the model inputs and outputs. For example, if an input has format NCHW, where N, H, and W are dynamic, and an output has format NL, where N is dynamic, then: dynamic_axes={"inputs": [0,2,3], "outputs": [0]}. Default: None. Type: dict<string, dict<int, string>> or dict<string, list(int)>. |
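The two accepted shapes of dynamic_axes can be compared side by side. The sketch below converts the list form into the dict form so that both can be inspected in one format. The helper normalize_dynamic_axes and its generated axis names are hypothetical and only illustrate the two shapes; quantize_model does not require this conversion.

```python
from typing import Dict, List, Union

# One entry per tensor name: either a list of axis indices
# or a mapping from axis index to a readable axis name.
AxesSpec = Union[List[int], Dict[int, str]]


def normalize_dynamic_axes(
    dynamic_axes: Dict[str, AxesSpec],
) -> Dict[str, Dict[int, str]]:
    """Return every entry in dict form (hypothetical helper for illustration)."""
    normalized = {}
    for tensor_name, axes in dynamic_axes.items():
        if isinstance(axes, dict):
            # Already in dict form: keep the user-chosen axis names.
            normalized[tensor_name] = dict(axes)
        else:
            # List form: invent a placeholder name for each dynamic axis.
            normalized[tensor_name] = {
                axis: f"{tensor_name}_dim_{axis}" for axis in axes
            }
    return normalized


# List form: N, H, and W of an NCHW input and N of an NL output are dynamic.
list_form = {"inputs": [0, 2, 3], "outputs": [0]}
# Dict form: each dynamic axis carries an explicit name.
dict_form = {"input": {0: "batch_size"}, "output": {0: "batch_size"}}

normalized_list_form = normalize_dynamic_axes(list_form)
normalized_dict_form = normalize_dynamic_axes(dict_form)
```

The keys of dynamic_axes are tensor names, so they would normally match the entries passed in input_names and output_names.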

Returns

Returns the resultant torch.nn.Module model for calibration.

Example

import os

import torch
import amct_pytorch as amct

# Build the network to be quantized and load its weights.
# build_model, state_dict_path, input_shape, TMP, and config_json_file
# are placeholders that must be defined beforehand.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
input_data = (torch.randn(input_shape),)

scale_offset_record_file = os.path.join(TMP, 'scale_offset_record.txt')
modfied_model = os.path.join(TMP, 'modfied_model.onnx')
# Insert the quantization operators and get the model for calibration.
calibration_model = amct.quantize_model(config_json_file,
                                        modfied_model,
                                        scale_offset_record_file,
                                        model,
                                        input_data,
                                        input_names=['input'],
                                        output_names=['output'],
                                        dynamic_axes={'input': {0: 'batch_size'},
                                                      'output': {0: 'batch_size'}})
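
The returned model is meant for calibration. The following is a minimal sketch of a calibration pass, assuming that calibration amounts to running a few forward passes in evaluation mode without gradients. The helper run_calibration and the batches list are hypothetical, and a small torch.nn.Linear stands in for the model returned by quantize_model so that the sketch runs on its own.

```python
import torch
from torch import nn


def run_calibration(calibration_model: nn.Module, batches) -> int:
    """Run forward passes only and return the number of batches processed.

    Hypothetical helper: assumes calibration only needs inference on
    representative data, with no gradient computation.
    """
    calibration_model.eval()
    processed = 0
    with torch.no_grad():
        for batch in batches:
            # Accept a single tensor or a tuple of tensors per batch.
            inputs = batch if isinstance(batch, tuple) else (batch,)
            calibration_model(*inputs)
            processed += 1
    return processed


# Stand-in for the model returned by quantize_model (assumption for this sketch).
stand_in_model = nn.Linear(4, 2)
batches = [torch.randn(8, 4) for _ in range(3)]
processed = run_calibration(stand_in_model, batches)
```

In practice, the batches would come from a representative calibration dataset and their shapes would match the input_data passed to quantize_model.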