save_quant_retrain_model
Applicability
| Product | Supported |
|---|---|
Description
Inserts operators such as AscendQuant and AscendDequant into the retrained model and generates a fake-quantized model for accuracy simulation and a deployable model.
Prototype
```python
save_quant_retrain_model(config_file, model, record_file, save_path, input_data, input_names=None, output_names=None, dynamic_axes=None)
```
Parameters
| Parameter | Input/Output | Description |
|---|---|---|
| config_file | Input | User-generated QAT configuration file, which specifies the configuration of the quantization layers in the model network. A string. |
| model | Input | Model generated after QAT. A torch.nn.Module. |
| record_file | Input | Path (including the file name) of the quantization factor record file. A string. |
| save_path | Input | Model save path. Must include the prefix of the model name, for example, ./quantized_model/*model. A string. |
| input_data | Input | Input data of the model. A tuple. A single torch.tensor is treated as the equivalent tuple(torch.tensor). |
| input_names | Input | Names of the model inputs, which are displayed in the saved quantized ONNX model. A list of strings. Default: None. |
| output_names | Input | Names of the model outputs, which are displayed in the saved quantized ONNX model. A list of strings. Default: None. |
| dynamic_axes | Input | Dynamic axes of the model inputs and outputs. For example, if the inputs have format NCHW, where N, H, and W are uncertain, and the outputs have format NL, where N is uncertain, pass {"inputs": [0,2,3], "outputs": [0]}, where 0, 2, and 3 are the indexes of N, H, and W, respectively. A dict<string, dict<python:int, string>> or dict<string, list(int)>. Default: None. Restrictions: Values of the int type must be non-negative. |
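The two accepted shapes of dynamic_axes and the non-negative restriction can be checked before the call. The following is a minimal sketch only: `check_dynamic_axes` is a hypothetical helper that is not part of AMCT, and AMCT's own validation may behave differently.

```python
from typing import Dict, List, Union

# Either a list of axis indexes or a dict mapping axis index -> axis label.
AxesSpec = Union[List[int], Dict[int, str]]


def check_dynamic_axes(dynamic_axes: Dict[str, AxesSpec]) -> Dict[str, List[int]]:
    """Hypothetical pre-check (not an AMCT API).

    Returns the sorted dynamic axis indexes for each tensor name and raises
    if a name is not a string or an index is not a non-negative int.
    """
    checked = {}
    for name, axes in dynamic_axes.items():
        if not isinstance(name, str):
            raise TypeError(f"tensor name must be a string, got {name!r}")
        # In the dict form the keys are the axis indexes.
        indexes = list(axes.keys()) if isinstance(axes, dict) else list(axes)
        for axis in indexes:
            if isinstance(axis, bool) or not isinstance(axis, int):
                raise TypeError(f"axis index must be an int, got {axis!r}")
            if axis < 0:
                raise ValueError(f"axis index must be non-negative, got {axis}")
        checked[name] = sorted(indexes)
    return checked


# List form: N, H, and W of an NCHW input and N of an NL output are dynamic.
list_form = {"inputs": [0, 2, 3], "outputs": [0]}
# Dict form: each dynamic axis index is given a label.
dict_form = {"input": {0: "batch_size"}, "output": {0: "batch_size"}}
```

Both forms name the same kind of information; the dict form additionally attaches a label to each dynamic axis.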
Returns
None
Example
```python
import torch
import amct_pytorch as amct

# build_model, train_model, infer_model, quant_retrain_model, and the path and
# shape variables are placeholders defined elsewhere in the user's script.

# Build a graph of the network to be quantized.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
input_data = tuple([torch.randn(input_shape)])

# Train the retrained model to calculate quantization factors.
train_model(quant_retrain_model, input_batch)

# Run inference on the retrained model to export the quantization factors.
infer_model(quant_retrain_model, input_batch)

# Insert the quantization operators and save the quantization-aware trained model as ONNX files.
amct.save_quant_retrain_model(
    config_json_file,
    model,
    record_file,
    save_path="./results/model",
    input_data=input_data,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'},
                  'output': {0: 'batch_size'}})
```
Output files:
- A fake-quantized ONNX model file for accuracy simulation on ONNX Runtime with the file name containing the fake_quant keyword.
- A deployable ONNX model file with the file name containing the deploy keyword. The model can be deployed on the Ascend AI Processor after being converted by ATC.
- (Optional) *.external files, including *deploy.external and *fakequant.external:
  These files are generated only when the size of the saved fake-quantized model and deployable model file is greater than or equal to 2 GB. They are written to the same directory as the compressed *.onnx model file and store the tensor data. Each tensor's data is saved in a separate .external file named after the tensor, for example, conv1.weight_deploy.external and conv1.weight_fakequant.external.
  When ATC loads the compressed *.onnx deployable model file for model conversion, it automatically reads the tensor data from the *.external files in the same directory.
When QAT is performed again, the preceding files output by the API will be overwritten.
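To see what a run produced, the output directory can be grouped by the keywords and extension described above. This is a minimal sketch: `classify_outputs` is a hypothetical helper, not an AMCT API, and it relies only on the fake_quant and deploy keywords and the .external suffix, not on exact file names.

```python
from pathlib import Path
from typing import Dict, List


def classify_outputs(output_dir: str) -> Dict[str, List[str]]:
    """Hypothetical helper: group file names in output_dir by output type."""
    groups: Dict[str, List[str]] = {"fake_quant": [], "deploy": [], "external": []}
    for path in sorted(Path(output_dir).iterdir()):
        if not path.is_file():
            continue
        name = path.name
        if name.endswith(".external"):
            # Per-tensor data files, present only for large models.
            groups["external"].append(name)
        elif name.endswith(".onnx") and "fake_quant" in name:
            groups["fake_quant"].append(name)
        elif name.endswith(".onnx") and "deploy" in name:
            groups["deploy"].append(name)
    return groups
```

Because a repeated QAT run overwrites these files, a listing like this can also be used to decide which files to copy elsewhere before retraining.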