create_quant_retrain_model
Applicability
| Product | Supported |
|---|---|
Description
Quantizes a graph according to the given configuration file: inserts quantization-related layers (quantization-aware layers for activations and weights, plus layers that search for N), generates a quantization factor record file (record_file), and returns the resulting torch.nn.Module for quantization aware training (QAT).
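As rough intuition for what a quantization-aware layer does during retraining, the sketch below applies a generic quantize-then-dequantize ("fake quantization") step to a single value. It is not AMCT's implementation: the function name fake_quantize, the signed 8-bit range, and the scale/offset parameterization are all assumptions made for illustration.

```python
def fake_quantize(x, scale, offset=0, qmin=-128, qmax=127):
    """Quantize x onto an integer grid, then map it back to a float.

    Illustrative only: the INT8 range and scale/offset form are assumptions.
    """
    q = round(x / scale) + offset      # project onto the integer grid
    q = max(qmin, min(qmax, q))        # clamp to the assumed integer range
    return (q - offset) * scale        # dequantize back to a float


# A value inside the range snaps to the nearest multiple of scale.
print(fake_quantize(0.537, scale=0.1))
# A value far outside the range saturates at qmax * scale.
print(fake_quantize(100.0, scale=0.1))
```

Training with this rounding and clamping in the forward pass lets the weights adapt to the precision loss before the model is actually converted to low-precision arithmetic.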
Prototype
    quant_retrain_model = create_quant_retrain_model(config_file, model, record_file, input_data)
Parameters
| Parameter | Input/Output | Description |
|---|---|---|
| config_file | Input | User-generated QAT configuration file, which specifies the configuration of the quantization layers in the model network. A string. |
| model | Input | Original model for QAT, with weights loaded. A torch.nn.Module. |
| record_file | Input | Path (including the file name) of the quantization factor record file. A string. |
| input_data | Input | Input data of the model. A tuple; a single torch.Tensor is passed as the equivalent tuple(torch.Tensor). |
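Because input_data is documented as a tuple, a small wrapper can normalize whatever the caller has on hand before the API call. The helper as_input_tuple below is hypothetical and not part of amct_pytorch; it is written without importing torch so that it runs anywhere, and in practice the elements would be torch tensors.

```python
def as_input_tuple(input_data):
    """Return input_data as a tuple, wrapping a single object if needed.

    Hypothetical helper for illustration; not part of amct_pytorch.
    """
    if isinstance(input_data, tuple):
        return input_data              # already in the documented form
    if isinstance(input_data, list):
        return tuple(input_data)       # several inputs given as a list
    return (input_data,)               # a single input becomes a 1-tuple


single = as_input_tuple("tensor_a")              # stand-in for one tensor
several = as_input_tuple(["tensor_a", "tensor_b"])
print(single, several)
```

With real data this would be called as as_input_tuple(torch.randn(input_shape)), giving the same result as the tuple([...]) construction.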
Returns
Returns the resultant torch.nn.Module that can be used for QAT.
Example
    import os

    import torch
    import amct_pytorch as amct

    # Build the model for QAT and load its trained weights.
    model = build_model()
    model.load_state_dict(torch.load(state_dict_path))
    # input_data must be a tuple, so wrap the single tensor.
    input_data = tuple([torch.randn(input_shape)])

    scale_offset_record_file = os.path.join(TMP, 'scale_offset_record.txt')
    # Insert the quantization layers and get the model for retraining.
    quant_retrain_model = amct.create_quant_retrain_model(
        config_json_file,
        model,
        scale_offset_record_file,
        input_data)
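The example leaves TMP undefined. One way to prepare it is sketched below using only the standard library. Whether create_quant_retrain_model creates missing directories itself is not stated here, so creating the directory up front is a conservative assumption; make_record_path and the "amct" subdirectory name are hypothetical.

```python
import os
import tempfile


def make_record_path(directory, file_name="scale_offset_record.txt"):
    """Create directory if it is missing and return the record-file path."""
    os.makedirs(directory, exist_ok=True)
    return os.path.join(directory, file_name)


# Illustrative stand-in for TMP: a fresh temporary directory.
TMP = tempfile.mkdtemp()
scale_offset_record_file = make_record_path(os.path.join(TMP, "amct"))
print(scale_offset_record_file)
```

The resulting string can then be passed as the record_file argument.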