create_quant_retrain_model

Applicability

Product | Supported
Atlas A3 training series products/Atlas A3 inference series products | INT8 quantization: √
Atlas A2 training products/Atlas A2 inference products | INT8 quantization: √
Atlas 200I/500 A2 inference product | INT8 quantization: √
Atlas inference series products | INT8 quantization: √
Atlas training products | INT8 quantization: √

Description

Quantizes the model graph according to the given configuration file. The function inserts quantization-related layers (quantization-aware layers for activations and weights, and layers that search for N), generates a quantization factor record file (record_file), and returns the resulting torch.nn.Module for quantization-aware training (QAT).
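To illustrate what a quantization-aware layer computes in its forward pass, here is a minimal sketch of generic asymmetric INT8 fake quantization in plain Python: each value is mapped to an integer with a scale and offset, clamped to [-128, 127], and mapped back to a float, so training sees the rounding error that INT8 inference will introduce. The names int8_params and fake_quantize are hypothetical, the min/max calibration is an assumed scheme, and the sketch does not reproduce AMCT's internal layers or its search for N.

```python
def int8_params(x_min: float, x_max: float) -> tuple[float, int]:
    """Derive an asymmetric INT8 scale and offset from an observed value range."""
    # Widen the range to include 0.0 so that zero maps to an exact integer.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    span = x_max - x_min
    # 256 integer levels (-128..127) cover the range in 255 steps.
    scale = span / 255.0 if span > 0 else 1.0
    # Choose the offset so that x_min lands on -128.
    offset = -128 - round(x_min / scale)
    return scale, offset


def fake_quantize(values: list[float], scale: float, offset: int) -> list[float]:
    """Quantize each value to INT8, then dequantize it back to float."""
    result = []
    for v in values:
        q = round(v / scale) + offset
        q = max(-128, min(127, q))  # clamp to the signed INT8 range
        result.append((q - offset) * scale)
    return result


# Example: calibrate on the range [-1.0, 1.0] and fake-quantize a few values.
scale, offset = int8_params(-1.0, 1.0)
print(fake_quantize([-1.0, -0.3, 0.0, 0.5, 1.0], scale, offset))
```

With a calibrated range of [-1.0, 1.0] the step size is 2/255, so each reconstructed value stays within about half a step of the original; values outside the range are clamped to its ends.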

Prototype

quant_retrain_model = create_quant_retrain_model(config_file, model, record_file, input_data)

Parameters

Parameter | Input/Output | Description
config_file | Input | User-generated QAT configuration file that specifies the quantization configuration of the layers in the model network. Type: string.
model | Input | Original model for QAT, with its weights already loaded. Type: torch.nn.Module.
record_file | Input | Path, including the file name, of the quantization factor record file. Type: string.
input_data | Input | Input data of the model. If the model takes a single torch.Tensor, wrap it in a one-element tuple, (tensor,). Type: tuple.
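Because input_data must be a tuple, a single input has to be wrapped rather than converted: calling tuple() on a torch.Tensor would instead split it along its first dimension. The following minimal sketch shows the wrapping rule with a hypothetical helper, as_input_tuple; it is not part of the AMCT API, and placeholder strings stand in for tensors so that it runs without PyTorch.

```python
from typing import Any, Tuple


def as_input_tuple(data: Any) -> Tuple[Any, ...]:
    """Return model inputs in the tuple form that input_data expects.

    A tuple is passed through unchanged; any other single input
    (for example, one torch.Tensor) is wrapped as a 1-tuple.
    """
    if isinstance(data, tuple):
        return data
    return (data,)


# With PyTorch you would pass a tensor instead, for example
# as_input_tuple(torch.randn(1, 3, 224, 224)) for an assumed image input shape.
print(as_input_tuple("image_tensor"))
print(as_input_tuple(("image_tensor", "mask_tensor")))
```

Passing a tuple through unchanged keeps multi-input models working, while the 1-tuple case covers the common single-tensor model.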

Returns

Returns the modified model (a torch.nn.Module) with quantization layers inserted, ready for QAT.

Example

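import os

import torch
import amct_pytorch as amct

# build_model, state_dict_path, input_shape, TMP and config_json_file are
# user-supplied placeholders: replace them with your own model builder,
# checkpoint path, input shape, output directory and QAT config file.

# Build the model for QAT and load its pretrained weights.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
# input_data must be a tuple, so wrap the single sample tensor.
input_data = tuple([torch.randn(input_shape)])

scale_offset_record_file = os.path.join(TMP, 'scale_offset_record.txt')
# Insert the quantization layers and get the model to retrain.
quant_retrain_model = amct.create_quant_retrain_model(
    config_json_file,
    model,
    scale_offset_record_file,
    input_data)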