save_distill_model
Description
Based on the distilled model, generates a fake-quantized model for accuracy simulation and a deployable model.
Prototype
save_distill_model(model, save_path, input_data, record_file=None, input_names=None, output_names=None, dynamic_axes=None)
Parameters
| Parameter | Input/Return | Description | Restrictions |
|---|---|---|---|
| model | Input | Quantized model after distillation. | A torch.nn.Module. |
| save_path | Input | Path for saving the distilled model. Must include the prefix of the model name, for example, ./quantized_model/model. | A string. |
| input_data | Input | Input data of the model: one or more torch.Tensor objects packed in a tuple. | A tuple. |
| record_file | Input | Path of the quantization factor record file, including the file name. If None, the record file is stored in the amct_log folder. | A string. Default: None. |
| input_names | Input | Names of the model input nodes, which are displayed in the saved quantized ONNX model. | A list of strings. Default: None. |
| output_names | Input | Names of the model output nodes, which are displayed in the saved quantized ONNX model. | A list of strings. Default: None. |
| dynamic_axes | Input | Dynamic axes of the model inputs and outputs. For example, for an input (NCHW) with uncertain N, H, and W and an output (NL) with uncertain N, the value is {"inputs": [0, 2, 3], "outputs": [0]}, where 0, 2, and 3 are the indexes of N, H, and W respectively. | A dict<string, dict<python:int, string>> or dict<string, list(int)>. Default: None. |
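As a sketch of the two accepted shapes of dynamic_axes (node names here are illustrative; they must match the names used in input_names and output_names):

```python
# Form 1: dict<string, list(int)> -- only the indexes of the dynamic axes.
dynamic_axes_list = {"inputs": [0, 2, 3], "outputs": [0]}

# Form 2: dict<string, dict<int, string>> -- each dynamic axis index is
# mapped to a human-readable name that appears in the exported ONNX model.
dynamic_axes_named = {
    "inputs": {0: "batch", 2: "height", 3: "width"},
    "outputs": {0: "batch"},
}
```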
Return Value
None
Outputs
- A fake-quantized ONNX model for accuracy simulation on ONNX Runtime, with the file name containing the fake_quant keyword.
- A deployable ONNX model, with the file name containing the deploy keyword. The model can be deployed on the Ascend AI Processor after being converted by the ATC tool.
When distillation is performed again, the preceding files output by this API are overwritten.
Example
```python
import torch
import amct_pytorch as amct

# Build a graph of the network to be distilled.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
input_data = tuple([torch.randn(input_shape)])

# Insert the distillation API and save the distilled model as an ONNX file.
amct.save_distill_model(
    model,
    "./model/distilled",
    input_data,
    record_file="./results/records.txt",
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
```
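To sanity-check the fake-quantized output on ONNX Runtime, a minimal sketch follows. The file name used here assumes the common AMCT naming pattern (prefix followed by _fake_quant_model.onnx) and the input shape is a placeholder; verify both against the files actually written next to the save_path prefix.

```python
import os
import numpy as np

# Assumed output name for save_path "./model/distilled"; adjust to the
# file actually produced by save_distill_model.
model_path = "./model/distilled_fake_quant_model.onnx"

if os.path.exists(model_path):
    import onnxruntime as ort  # accuracy simulation runs on ONNX Runtime
    session = ort.InferenceSession(model_path)
    input_name = session.get_inputs()[0].name
    # Example NCHW input; replace with the model's real input shape.
    feed = {input_name: np.random.randn(1, 3, 224, 224).astype(np.float32)}
    outputs = session.run(None, feed)
```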