quantize_preprocess

Function Usage

Preprocesses graph quantization based on the quantization configuration file: inserts the balanced quantization operators, generates the balanced quantization factor record file (record_file), and returns a torch.nn.Module model ready for calibration.

Prototype

calibration_model = quantize_preprocess(config_file, record_file, model, input_data)

Parameters

| Parameter | Input/Return | Description | Restriction |
|---|---|---|---|
| config_file | Input | User-defined quantization configuration file, which specifies the configuration of each layer to be quantized. | A string |
| record_file | Input | Path of the balanced quantization factor record file, including the file name. | A string |
| model | Input | Source model, with weights loaded. | A torch.nn.Module |
| input_data | Input | Input data of the model. A single torch.Tensor is equivalent to a one-element tuple (torch.Tensor,). | A tuple |
| calibration_model | Return | Model ready for calibration. Default: None | A torch.nn.Module |
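
For illustration, the input_data argument is a tuple of tensors matching the model's forward signature; the sketch below builds a one-element tuple from a random tensor (the input shape is an assumption for a typical image model, not part of the API).

import torch

# Hypothetical input shape for an image classification model (assumption).
input_shape = (1, 3, 224, 224)
# A single tensor wrapped in a one-element tuple, as required for input_data.
input_data = (torch.randn(input_shape),)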

Return Value

A torch.nn.Module model ready for calibration.

Outputs

None

Examples

import os
import torch
import amct_pytorch as amct

# Build the network to be quantized and load its weights.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
# input_data must be a tuple of tensors matching the model's forward signature.
input_data = tuple([torch.randn(input_shape)])

tensor_balance_factor_record_file = os.path.join(TMP, 'tensor_balance_factor_record.txt')
modified_model = os.path.join(TMP, 'modified_model.onnx')
# Insert the quantization API.
calibration_model = amct.quantize_preprocess(config_json_file,
                                             tensor_balance_factor_record_file,
                                             model,
                                             input_data)
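
After quantize_preprocess returns, the calibration model is typically run on representative data so that calibration statistics and the balanced quantization factors can be collected. The sketch below uses plain PyTorch forward passes; the DataLoader name calib_loader and the batch count are assumptions, not part of the amct_pytorch API.

# Run a few calibration batches through the returned model (calib_loader is hypothetical).
calibration_model.eval()
with torch.no_grad():
    for i, (images, _) in enumerate(calib_loader):
        calibration_model(images)
        if i >= 16:  # a small, representative number of batches
            break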