save_quant_retrain_model

Applicability

Product

Supported

Atlas A3 training series products/Atlas A3 inference series products

  • INT8 quantization: √

Atlas A2 training products/Atlas A2 inference products

  • INT8 quantization: √

Atlas 200I/500 A2 inference product

  • INT8 quantization: √

Atlas inference series products

  • INT8 quantization: √

Atlas training products

  • INT8 quantization: √

Description

Inserts operators such as AscendQuant and AscendDequant into the retrained model and generates a fake-quantized model for accuracy simulation and a deployable model.
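As a rough illustration of what a fake-quantized model simulates, the sketch below rounds a float value onto an INT8 grid and maps it back to float. It uses the generic affine scheme (a scale plus an integer offset). The function names are hypothetical, and the exact arithmetic of AscendQuant and AscendDequant is not given in this section, so treat the formula as an assumption rather than AMCT's implementation.

```python
INT8_MIN, INT8_MAX = -128, 127


def quantize_int8(x, scale, offset=0):
    """Map a float to an INT8 code: round(x / scale) + offset, then clip."""
    q = round(x / scale) + offset
    return max(INT8_MIN, min(INT8_MAX, q))


def dequantize_int8(q, scale, offset=0):
    """Map an INT8 code back to a float value."""
    return (q - offset) * scale


def fake_quantize(x, scale, offset=0):
    """Quantize then dequantize, so the result carries the INT8 rounding error."""
    return dequantize_int8(quantize_int8(x, scale, offset), scale, offset)


if __name__ == "__main__":
    # Values outside the representable range saturate at the INT8 limits.
    for value in (1.234, 100.0, -100.0):
        print(value, "->", fake_quantize(value, scale=0.1))
```

A fake-quantized model stays in floating point but carries this rounding and saturation error, which is why it can be used to estimate accuracy, whereas the deployable model holds the actual INT8 operators.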

Prototype

save_quant_retrain_model(config_file, model, record_file, save_path, input_data, input_names=None, output_names=None, dynamic_axes=None)

Parameters

Parameter

Input/Output

Description

config_file

Input

User-generated QAT configuration file, which specifies the configuration of the quantization layers in the model.

A string.

model

Input

Model generated after QAT.

A torch.nn.Module.

record_file

Input

Path (including the file name) of the quantization factor record file.

A string.

save_path

Input

Model save path. Must include the prefix of the model name, for example, ./quantized_model/*model.

A string.

input_data

Input

Input data of the model. A single torch.tensor is treated as the equivalent tuple(torch.tensor).

A tuple.

input_names

Input

Names of the model input, which are displayed in the saved quantized ONNX model.

Default: None

A list of strings.

output_names

Input

Names of the model output, which are displayed in the saved quantized ONNX model.

Default: None

A list of strings.

dynamic_axes

Input

Dynamic axes of the model inputs and outputs. For example, if the inputs are in NCHW format with N, H, and W not fixed, and the outputs are in NL format with N not fixed, pass:

{"inputs": [0,2,3], "outputs": [0]}, where 0, 2, and 3 indicate the indexes of N, H, and W, respectively.

Default: None

A dict<string, dict<int, string>> or dict<string, list(int)>.

Restriction: Each int value must be non-negative.
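The input_data and dynamic_axes rules above can be checked before calling the API. The following is a minimal sketch in plain Python: as_input_tuple and check_dynamic_axes are hypothetical helper names, not part of amct_pytorch, and the checks only mirror the types and the non-negative restriction listed in the parameter descriptions.

```python
def as_input_tuple(input_data):
    """Wrap a single input in a 1-tuple; leave an existing tuple unchanged."""
    return input_data if isinstance(input_data, tuple) else (input_data,)


def check_dynamic_axes(dynamic_axes):
    """Validate the two documented forms of dynamic_axes.

    Accepted forms:
      {"name": [0, 2, 3]}          -> dict<string, list(int)>
      {"name": {0: "batch_size"}}  -> dict<string, dict<int, string>>
    """
    if dynamic_axes is None:
        return None
    if not isinstance(dynamic_axes, dict):
        raise TypeError("dynamic_axes must be a dict or None")
    for name, axes in dynamic_axes.items():
        if not isinstance(name, str):
            raise TypeError("input/output names must be strings")
        if isinstance(axes, dict):
            indexes = list(axes.keys())
            if not all(isinstance(label, str) for label in axes.values()):
                raise TypeError("axis labels must be strings")
        elif isinstance(axes, list):
            indexes = axes
        else:
            raise TypeError("axes must be a list of ints or a dict of int to str")
        for index in indexes:
            # bool is a subclass of int in Python, so reject it explicitly.
            if isinstance(index, bool) or not isinstance(index, int):
                raise TypeError("axis indexes must be ints")
            if index < 0:
                raise ValueError("axis indexes must be non-negative")
    return dynamic_axes


# Both documented forms pass the check.
list_form = check_dynamic_axes({"inputs": [0, 2, 3], "outputs": [0]})
dict_form = check_dynamic_axes({"input": {0: "batch_size"},
                                "output": {0: "batch_size"}})
```

With PyTorch, as_input_tuple(torch.randn(1, 3, 224, 224)) would yield a one-element tuple suitable for input_data; the shape here is only an example.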

Returns

None

Example

import torch
import amct_pytorch as amct

# Build the network to be quantized and load its trained weights.
# build_model, state_dict_path, and input_shape are user-defined.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
input_data = tuple([torch.randn(input_shape)])

# quant_retrain_model is the QAT-modified version of model, for example
# the one returned by amct.create_quant_retrain_model.
# Train the retrained model to calculate quantization factors.
train_model(quant_retrain_model, input_batch)
# Run inference on the retrained model to export the quantization factors.
infer_model(quant_retrain_model, input_batch)

# Insert the quantization operators and save the quantization-aware trained model as ONNX files.
amct.save_quant_retrain_model(
    config_json_file,
    quant_retrain_model,
    record_file,
    "./results/model",
    input_data,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'},
                  'output': {0: 'batch_size'}})

Output files:

  • A fake-quantized ONNX model file for accuracy simulation on ONNX Runtime with the file name containing the fake_quant keyword.
  • A deployable ONNX model file with the file name containing the deploy keyword. The model can be deployed on the Ascend AI Processor after being converted by ATC.
  • (Optional) *.external files, including *deploy.external and *fakequant.external:

    These files are generated only when a saved fake-quantized or deployable model file is 2 GB or larger. Each *.external file is created in the same directory as the corresponding compressed *.onnx model file and stores tensor data. The data of each tensor is saved in a separate .external file named after the tensor, for example, conv1.weight_deploy.external and conv1.weight_fakequant.external.

    When ATC is used to load the compressed *.onnx deployable model file for model conversion, the tensor data in the *.external file in the same directory is automatically read.

When QAT is performed again, the preceding files output by the API will be overwritten.
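To check what a call produced, the output directory can be scanned for the files described above. The sketch below is an assumption-laden helper, not part of amct_pytorch: find_saved_files is a hypothetical name, and it relies only on the documented facts that the ONNX file names contain the fake_quant or deploy keyword, start with the save_path prefix, and sit next to any *.external files. The exact file names are not stated in this section.

```python
from pathlib import Path


def find_saved_files(save_path):
    """Group the files found next to the save_path prefix by kind."""
    prefix = Path(save_path)
    directory = prefix.parent
    found = {"fake_quant": [], "deploy": [], "external": []}
    if not directory.is_dir():
        return found
    for path in sorted(directory.iterdir()):
        if path.suffix == ".external":
            # Tensor data stored outside the ONNX file for large models.
            found["external"].append(path)
        elif path.suffix == ".onnx" and path.name.startswith(prefix.name):
            if "fake_quant" in path.name:
                found["fake_quant"].append(path)
            elif "deploy" in path.name:
                found["deploy"].append(path)
    return found


if __name__ == "__main__":
    for kind, paths in find_saved_files("./results/model").items():
        print(kind, [p.name for p in paths])
```

Because a later QAT run overwrites these files, copying them elsewhere (or using a different save_path prefix per run) is one way to keep earlier results.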