quantize_preprocess
Function Usage
Preprocesses a graph for quantization based on the quantization configuration file: inserts balanced quantization operators, generates the balanced quantization factor record file (record_file), and returns a torch.nn.Module model ready for calibration.
Prototype
calibration_model = quantize_preprocess(config_file, record_file, model, input_data)
Parameters
| Parameter | Input/Return | Description | Restriction |
|---|---|---|---|
| config_file | Input | User-defined quantization configuration file, which specifies the configuration of each layer to be quantized. | A string |
| record_file | Input | Path of the balanced-factor record file, including the file name. | A string |
| model | Input | Source model, with weights loaded. | A torch.nn.Module |
| input_data | Input | Input data of the model: a torch.Tensor, or equivalently a tuple of torch.Tensor. | A tuple |
| calibration_model | Return | Resulting torch.nn.Module model, ready for calibration. | A torch.nn.Module. Default: None |
Return Value
A torch.nn.Module model ready for calibration.
Outputs
None
Examples
```python
import os
import torch

import amct_pytorch as amct

# Build a graph of the network to be quantized and load its weights.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
input_data = tuple([torch.randn(input_shape)])

tensor_balance_factor_record_file = os.path.join(TMP, 'tensor_balance_factor_record.txt')
modified_model = os.path.join(TMP, 'modified_model.onnx')

# Insert the quantization API; the returned model is ready for calibration.
calibration_model = amct.quantize_preprocess(config_json_file,
                                             tensor_balance_factor_record_file,
                                             model,
                                             input_data)
```
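After preprocessing, calibration typically consists of running representative data through the returned model so the inserted operators can observe activation statistics. The sketch below illustrates that step with plain PyTorch only; the `torch.nn.Linear` network, shapes, and batch count are illustrative stand-ins, not part of the AMCT API.

```python
import torch

# Stand-in for the calibration_model returned by quantize_preprocess
# (hypothetical network; in practice use the model the API returned).
calibration_model = torch.nn.Linear(4, 2)
calibration_data = [torch.randn(1, 4) for _ in range(8)]

calibration_model.eval()
with torch.no_grad():  # calibration only observes activations; no gradients needed
    for batch in calibration_data:
        out = calibration_model(batch)
```

Each forward pass exercises the inserted operators on one calibration batch; no loss or optimizer is involved.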