restore_quant_retrain_model

Function Usage

Description: Quantizes a graph based on the given configuration file, inserts quantization-related layers (quantization-aware layers for activations and weights, and layers for searching for N), generates a quantization factor record file (record_file), loads the checkpoint weights saved during training, and returns the resulting model of the torch.nn.Module type.

Prototype

quant_retrain_model = restore_quant_retrain_model(config_file, model, record_file, input_data, pth_file, state_dict_name=None)

Parameters

config_file (Input)
    User-defined QAT configuration file, which specifies the configuration of each layer to be quantized.
    Restriction: a string. The config.json file passed to this call must be the same as the one passed to the create_quant_retrain_model call.

model (Input)
    Source model, before the trained weights are loaded.
    Restriction: a torch.nn.Module.

record_file (Input)
    Path of the quantization factor record file, including the file name.
    Restriction: a string.

input_data (Input)
    Input data of the model.
    Restriction: a tuple of torch.Tensor objects; a single torch.Tensor should be wrapped in a tuple.

pth_file (Input)
    Weight file saved during training.
    Restriction: a string.

state_dict_name (Input)
    Key corresponding to the weights in the weight file. Default: None.
    Restriction: a string.

quant_retrain_model (Return)
    Resulting model of the torch.nn.Module type for QAT.
    Restriction: a torch.nn.Module.
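
The state_dict_name argument must match how the checkpoint was saved. Below is a minimal sketch of the two common cases, assuming standard torch.save usage; the 'state_dict' key and the surrounding checkpoint layout are illustrative assumptions, not part of the AMCT API:

import torch

# Case 1: checkpoint stores the weights under a user-chosen key
# ('state_dict' here is an assumed, illustrative name):
torch.save({'epoch': 10, 'state_dict': model.state_dict()}, 'model.pth')
# Restore with:
#   restore_quant_retrain_model(..., pth_file='model.pth',
#                               state_dict_name='state_dict')

# Case 2: checkpoint stores the state dict directly; presumably
# state_dict_name can then stay at its default of None:
torch.save(model.state_dict(), 'model_plain.pth')
#   restore_quant_retrain_model(..., pth_file='model_plain.pth',
#                               state_dict_name=None)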

Returns

Returns the resulting model of the torch.nn.Module type for quantization-aware training (QAT).

Outputs

None

Example

import os

import torch
import amct_pytorch as amct

# Build the network to be quantized (build_model, input_shape,
# config_json_file, pth_file, and TMP are user-defined).
model = build_model()
input_data = tuple([torch.randn(input_shape)])

scale_offset_record_file = os.path.join(TMP, 'scale_offset_record.txt')
# Insert the quantization API and load the weights saved during training.
quant_retrain_model = amct.restore_quant_retrain_model(
    config_json_file,
    model,
    scale_offset_record_file,
    input_data,
    pth_file)
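
After restoring, the returned model can be fine-tuned like any other PyTorch module. A minimal sketch of such a QAT fine-tuning loop follows; loader, the loss, and the optimizer settings are illustrative placeholders, not part of the AMCT API:

import torch

# Hypothetical fine-tuning loop for the restored QAT model; `loader`
# is assumed to yield (images, labels) batches.
optimizer = torch.optim.SGD(quant_retrain_model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()
quant_retrain_model.train()
for images, labels in loader:
    optimizer.zero_grad()
    loss = criterion(quant_retrain_model(images), labels)
    loss.backward()
    optimizer.step()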