save_distill_model

Description

Generates a fake-quantized model for accuracy simulation and a deployable model based on the distilled model.

Prototype

save_distill_model(model, save_path, input_data, record_file=None, input_names=None, output_names=None, dynamic_axes=None)

Parameters

model (input): Quantized model after distillation. Restriction: a torch.nn.Module.

save_path (input): Path for saving the distilled model. It must include the model name prefix, for example, ./quantized_model/model. Restriction: a string.

input_data (input): Input data of the model, passed as a tuple of torch.Tensor; a single torch.Tensor is wrapped in a one-element tuple. Restriction: a tuple.

record_file (input): Path of the quantization factor record file, including the file name. If None, the record file is stored in the amct_log folder. Default: None. Restriction: a string.

input_names (input): Names of the model input nodes, which appear in the saved quantized ONNX model. Default: None. Restriction: a list of strings.

output_names (input): Names of the model output nodes, which appear in the saved quantized ONNX model. Default: None. Restriction: a list of strings.

dynamic_axes (input): Dynamic axes of the model inputs and outputs. For example, for inputs in NCHW format with dynamic N, H, and W, and outputs in NL format with dynamic N, the format is {"inputs": [0, 2, 3], "outputs": [0]}, where 0, 2, and 3 are the indexes of N, H, and W. Default: None. Restriction: a dict<string, dict<int, string>> or a dict<string, list(int)>.
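
Both accepted forms of dynamic_axes can be written as plain Python dicts; a minimal sketch (the node names and axis labels below are illustrative, not mandated by the API):

```python
# Form 1: dict<string, dict<int, string>> - label each dynamic axis by name.
dynamic_axes_named = {
    "input":  {0: "batch_size", 2: "height", 3: "width"},  # NCHW with dynamic N, H, W
    "output": {0: "batch_size"},                           # NL with dynamic N
}

# Form 2: dict<string, list(int)> - list only the indexes of the dynamic axes.
dynamic_axes_indexed = {
    "inputs":  [0, 2, 3],
    "outputs": [0],
}
```

Either dict can be passed as the dynamic_axes argument of save_distill_model.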

Return Value

None

Outputs

  • A fake-quantized ONNX model for accuracy simulation on ONNX Runtime; its file name contains the fake_quant keyword.
  • A deployable ONNX model; its file name contains the deploy keyword. This model can be deployed on an Ascend AI Processor after being converted by the ATC tool.

If distillation is performed again, the preceding files output by this API are overwritten.

Example

import torch
import amct_pytorch as amct

# Build the network to be distilled and load its weights.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
input_data = tuple([torch.randn(input_shape)])

# Call the API to save the distilled model as ONNX files.
amct.save_distill_model(
               model,
               "./model/distilled",
               input_data,
               record_file="./results/records.txt",
               input_names=['input'],
               output_names=['output'],
               dynamic_axes={'input': {0: 'batch_size'},
                             'output': {0: 'batch_size'}})
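
To estimate the accuracy loss from quantization, the fake-quantized model can be run on ONNX Runtime and its outputs compared with those of the original model. A minimal sketch, assuming ONNX Runtime is installed; run_fake_quant is an illustrative helper, not part of the AMCT API, and the model path must be the actual file written by save_distill_model (its name contains the fake_quant keyword):

```python
def run_fake_quant(onnx_path, input_array):
    """Run one forward pass of a fake-quantized ONNX model on ONNX Runtime."""
    # Deferred import so the helper can be defined without onnxruntime present.
    import onnxruntime as ort
    session = ort.InferenceSession(onnx_path)
    input_name = session.get_inputs()[0].name
    # run(None, ...) returns all model outputs; keep the first one.
    return session.run(None, {input_name: input_array})[0]
```

Feeding the same batch to run_fake_quant and to the original PyTorch model, then comparing the outputs (for example, top-1 agreement), gives a quick estimate of the simulated quantization accuracy.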