quantize_model
Applicability
| Product | Supported |
|---|---|
|  | √ |
|  | √ |
|  | √ |
|  | √ |
|  | √ |
Description
Quantizes the graph of a model based on the quantization configuration file. It inserts weight and activation quantization operators into the graph, generates the quantization factor record file (record_file), and returns a torch.nn.Module model that is ready for calibration.
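The quantization factors recorded in record_file are, in general terms, a scale and an offset per quantized tensor. The plain-Python sketch below shows generic asymmetric uniform quantization to make that idea concrete. It is an illustration only: the function names are hypothetical, and AMCT's own calibration algorithm and record-file format are not described on this page and may differ.

```python
def quant_range(num_bits=8):
    """Signed integer range for num_bits, e.g. [-128, 127] for 8 bits."""
    return -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1


def compute_scale_offset(min_val, max_val, num_bits=8):
    """Derive an asymmetric (scale, offset) pair from an observed float range."""
    qmin, qmax = quant_range(num_bits)
    # Widen the range to include 0.0 so that zero is exactly representable.
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)
    if max_val == min_val:
        return 1.0, 0  # degenerate all-zero range
    scale = (max_val - min_val) / (qmax - qmin)
    offset = qmin - round(min_val / scale)
    return scale, offset


def fake_quantize(x, scale, offset, num_bits=8):
    """Quantize x to an integer, then dequantize it back to float."""
    qmin, qmax = quant_range(num_bits)
    q = round(x / scale) + offset
    q = max(qmin, min(qmax, q))  # clamp to the integer range
    return (q - offset) * scale
```

For example, compute_scale_offset(0.0, 2.55) gives a scale of about 0.01 and an offset of -128, so each float in [0, 2.55] is mapped to one of 256 integer levels and values outside that range are clamped.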
Prototype
```python
calibration_model = quantize_model(config_file, modfied_onnx_file, record_file, model, input_data, input_names=None, output_names=None, dynamic_axes=None)
```
Parameters
| Parameter | Input/Output | Description |
|---|---|---|
| config_file | Input | User-generated quantization configuration file, which specifies the quantization configuration of the layers in the model network. A string. |
| modfied_onnx_file | Input | Name of the resultant ONNX model file. A string. |
| record_file | Input | Path (including the file name) of the quantization factor record file. A string. |
| model | Input | Original model to be quantized, with its weights loaded. A torch.nn.Module. |
| input_data | Input | Input data of the model. A single torch.Tensor is equivalent to the one-element tuple tuple(torch.tensor). A tuple. |
| input_names | Input | Names of the model inputs, as they appear in modfied_onnx_file. Default: None. A list of strings. |
| output_names | Input | Names of the model outputs, as they appear in modfied_onnx_file. Default: None. A list of strings. |
| dynamic_axes | Input | Dynamic axes of the model inputs and outputs, keyed by the names given in input_names and output_names. For example, if the input is named inputs and has format NCHW with N, H, and W dynamic, and the output is named outputs and has format NL with N dynamic, then dynamic_axes={"inputs": [0, 2, 3], "outputs": [0]}. Default: None. A dict<string, dict<python:int, string>> or dict<string, list(int)>. |
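dynamic_axes accepts two forms: a list of axis indices per name, or a dict mapping each axis index to a dimension name. The sketch below converts the list form into the dict form and checks that every key is one of the declared input or output names. It is a hypothetical helper, not part of amct_pytorch; the generated dimension names (for example inputs_dim0) are placeholders, not the names the exporter itself would choose, and raising on an unknown key is a deliberately strict design choice to catch typos early.

```python
def normalize_dynamic_axes(dynamic_axes, input_names, output_names):
    """Return dynamic_axes in dict<str, dict<int, str>> form (hypothetical helper)."""
    known_names = set(input_names or []) | set(output_names or [])
    normalized = {}
    for name, axes in dynamic_axes.items():
        # Each key must name a model input or output.
        if name not in known_names:
            raise ValueError(
                f"dynamic_axes key {name!r} is not in input_names or output_names")
        if isinstance(axes, dict):
            # Already in dict form: copy it, normalizing key and value types.
            normalized[name] = {int(axis): str(label) for axis, label in axes.items()}
        elif isinstance(axes, (list, tuple)):
            # List form: invent a placeholder dimension name for each axis.
            normalized[name] = {int(axis): f"{name}_dim{axis}" for axis in axes}
        else:
            raise TypeError(
                f"axes for {name!r} must be a list of ints or a dict of int to str")
    return normalized
```

Applied to the NCHW/NL example from the table, normalize_dynamic_axes({"inputs": [0, 2, 3], "outputs": [0]}, ["inputs"], ["outputs"]) yields {"inputs": {0: "inputs_dim0", 2: "inputs_dim2", 3: "inputs_dim3"}, "outputs": {0: "outputs_dim0"}}.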
Returns
Returns the resultant torch.nn.Module model for calibration.
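Calibration generally means running representative input batches through the model so that the value ranges of activations can be observed. A running min/max observer is one of the simplest ways to collect such ranges; the plain-Python sketch below shows the idea. It is a generic illustration only: the class name MinMaxObserver is hypothetical, and the calibration algorithm AMCT actually applies may be different.

```python
class MinMaxObserver:
    """Track the running min/max of values seen across calibration batches."""

    def __init__(self):
        self.min_val = None
        self.max_val = None

    def update(self, batch):
        """Fold one batch (any iterable of floats) into the running range."""
        values = list(batch)
        if not values:
            return  # nothing to observe in an empty batch
        lo, hi = min(values), max(values)
        self.min_val = lo if self.min_val is None else min(self.min_val, lo)
        self.max_val = hi if self.max_val is None else max(self.max_val, hi)

    def range(self):
        """Return (min, max) observed so far."""
        if self.min_val is None:
            raise RuntimeError("no calibration data observed yet")
        return self.min_val, self.max_val
```

For example, after update([0.1, -0.4]) and update([2.0, 0.3]), range() returns (-0.4, 2.0); a range like this is the kind of statistic from which a scale and offset can then be derived.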
Example
```python
import os

import torch
import amct_pytorch as amct

# build_model(), state_dict_path, input_shape, TMP and config_json_file
# are user-defined placeholders.

# Build the network to be quantized and load its trained weights.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
input_data = tuple([torch.randn(input_shape)])
scale_offset_record_file = os.path.join(TMP, 'scale_offset_record.txt')
modfied_model = os.path.join(TMP, 'modfied_model.onnx')

# Call the quantization API to insert quantization operators into the graph.
calibration_model = amct.quantize_model(
    config_json_file,
    modfied_model,
    scale_offset_record_file,
    model,
    input_data,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
```