restore_quant_retrain_model
Function Usage
Description: Quantizes a graph based on the given configuration file, inserts quantization-related layers (quantization-aware layers for activations and weights, and layers for searching for N), generates a quantization factor record file (record_file), loads the checkpoint weights saved during training, and returns the resulting model as a torch.nn.Module.
Prototype
quant_retrain_model = restore_quant_retrain_model(config_file, model, record_file, input_data, pth_file, state_dict_name=None)
Parameters
| Option | Input/Return | Meaning | Restriction |
|---|---|---|---|
| config_file | Input | User-defined QAT configuration file, which specifies the configuration of each layer to be quantized. | A string. The config.json file passed to this call must be the same as that passed to the create_quant_retrain_model call. |
| model | Input | Source model, without loaded weights. | A torch.nn.Module. |
| record_file | Input | Path of the quantization factor record file, including the file name. | A string. |
| input_data | Input | Input data of the model. A single torch.Tensor is treated as equivalent to tuple(torch.Tensor). | A tuple. |
| pth_file | Input | Weight file saved during training. | A string. |
| state_dict_name | Input | Key corresponding to the weights in the weight file (see the sketch after this table). | A string. Default: None. |
| quant_retrain_model | Return | Resulting model of the torch.nn.Module type for QAT. | A torch.nn.Module. |
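Whether state_dict_name is needed depends on how the checkpoint referenced by pth_file was saved. The following is a minimal sketch, assuming plain PyTorch checkpointing; the model, the file names, and the 'state_dict' key are illustrative assumptions, not part of the AMCT API.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the network to be quantized.
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))

# Case 1: the checkpoint stores the weights under a key (here 'state_dict'),
# so restore_quant_retrain_model would need state_dict_name='state_dict'.
torch.save({'state_dict': model.state_dict()}, 'checkpoint.pth')

# Case 2: the checkpoint is the state_dict itself;
# state_dict_name can then be left as None.
torch.save(model.state_dict(), 'plain_weights.pth')
```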
Returns
Returns the result model of the torch.nn.Module type for quantization-aware training (QAT).
Outputs
None
Example
```python
import os

import torch

import amct_pytorch as amct

# Build a graph of the network to be quantized.
model = build_model()
input_data = tuple([torch.randn(input_shape)])
scale_offset_record_file = os.path.join(TMP, 'scale_offset_record.txt')

# Insert the quantization API.
quant_retrain_model = amct.restore_quant_retrain_model(
    config_json_file,
    model,
    scale_offset_record_file,
    input_data,
    pth_file)
```
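Because the returned quant_retrain_model is an ordinary torch.nn.Module, retraining can proceed with a standard PyTorch loop. The sketch below assumes a generic classification setup; data_loader, the loss function, and the optimizer choice are placeholder assumptions, not something mandated by AMCT.

```python
import torch
import torch.nn as nn

def finetune_one_epoch(quant_retrain_model, data_loader, lr=1e-4):
    """Placeholder QAT fine-tuning loop for the restored torch.nn.Module."""
    optimizer = torch.optim.SGD(quant_retrain_model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    quant_retrain_model.train()
    for images, labels in data_loader:
        optimizer.zero_grad()
        outputs = quant_retrain_model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
```

SGD with a small learning rate is only a placeholder here; any optimizer that works for the original model can be applied to the module returned by this call.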