create_quant_retrain_model

Function Usage

Description: Quantizes a graph based on the given configuration file, inserts quantization-related layers (quantization-aware layers for activations and weights, and layers for searching for N), generates a quantization factor record file (record_file), and returns the resulting torch.nn.Module model for QAT.

Prototype

quant_retrain_model = create_quant_retrain_model(config_file, model, record_file, input_data)

Parameters

Parameter | Input/Return | Description | Restriction
--- | --- | --- | ---
config_file | Input | User-defined QAT configuration file, which specifies the configuration of each layer to be quantized. | A string.
model | Input | Source model, with weights loaded. | A torch.nn.Module.
record_file | Input | Path of the quantization factor record file, including the file name. | A string.
input_data | Input | Input data of the model. A single torch.Tensor is equivalent to tuple(torch.Tensor). | A tuple.
quant_retrain_model | Return | Resulting torch.nn.Module model for QAT. Default: None. | A torch.nn.Module.
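The restrictions in the table can be checked before the API is called. The following is a minimal sketch, not part of amct_pytorch: prepare_qat_inputs is a hypothetical helper that verifies config_file and record_file are strings, that the configuration file exists, and that record_file includes a file name. It then wraps a single input into a one-element tuple, mirroring the statement that a tensor is equivalent to tuple(tensor). Whether the tool creates the record file's parent directory itself is not stated here, so the sketch assumes it may not and creates it defensively.

```python
import os
from typing import Any, Tuple


def prepare_qat_inputs(config_file: str, record_file: str,
                       input_data: Any) -> Tuple[str, str, tuple]:
    """Hypothetical pre-flight check for create_quant_retrain_model arguments."""
    # config_file and record_file must both be strings.
    if not isinstance(config_file, str):
        raise TypeError("config_file must be a string")
    if not isinstance(record_file, str):
        raise TypeError("record_file must be a string")
    # The QAT configuration file must already exist.
    if not os.path.isfile(config_file):
        raise FileNotFoundError(config_file)
    # record_file is a path that includes the file name, so it must not
    # end with a directory separator.
    if not os.path.basename(record_file):
        raise ValueError("record_file must include a file name")
    # Assumption: create the parent directory in case the tool does not.
    parent = os.path.dirname(os.path.abspath(record_file))
    os.makedirs(parent, exist_ok=True)
    # A single input (for example one tensor) becomes a one-element tuple.
    if not isinstance(input_data, tuple):
        input_data = (input_data,)
    return config_file, record_file, input_data
```

The returned values can be passed straight to create_quant_retrain_model together with the model.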

Return Value

Returns the resulting torch.nn.Module model for quantization-aware training (QAT).

Outputs

None

Example

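import os

import torch
import amct_pytorch as amct

# Build the source model and load its trained weights.
# build_model, state_dict_path, input_shape, TMP and config_json_file
# are user-provided placeholders.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
input_data = tuple([torch.randn(input_shape)])

# Path of the quantization factor record file, including the file name.
scale_offset_record_file = os.path.join(TMP, 'scale_offset_record.txt')
# Insert the quantization-related layers and get the model for QAT.
quant_retrain_model = amct.create_quant_retrain_model(
    config_json_file,
    model,
    scale_offset_record_file,
    input_data)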