create_distill_model
Function Usage
Quantizes the input graph based on the specified distillation configuration file, inserts quantization-related operators (distillation layers and searching-N layers for activations and weights) into the graph, and returns a modified torch.nn.Module model that can be used for distillation.
Prototype
compress_model = create_distill_model(config_file, model, input_data)
Parameters
| Parameter | Input/Return | Meaning | Restriction |
|---|---|---|---|
| config_file | Input | User-generated distillation configuration file, which specifies the configuration of the quantization layers and the distillation structure in the model network. | A string. Must be the same config.json file as the one passed to the create_distill_config call (see the sketch after this table). |
| model | Input | Floating-point source model to be distilled, with weights loaded. | A torch.nn.Module. |
| input_data | Input | Input data of the model: one or more torch.Tensor objects wrapped in a tuple. | A tuple. |
| compress_model | Return | Modified torch.nn.Module model that can be used for distillation. | A torch.nn.Module. Default: None. |
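Because config_file must be the same file passed to create_distill_config, the two calls are typically made back to back. The sketch below illustrates this pairing; the create_distill_config signature shown here is an assumption inferred from the parameter description above, not a verbatim reference, so consult that function's documentation for the exact arguments.

```python
import torch
import amct_pytorch as amct

model = build_model()  # user-defined network builder (placeholder)
model.load_state_dict(torch.load(state_dict_path))
input_data = tuple([torch.randn(input_shape)])

# Assumed call: generate the config.json describing the quantization and
# distillation structure. Signature is an assumption; see the
# create_distill_config documentation for the authoritative prototype.
amct.create_distill_config(config_json_file, model, input_data)

# Pass the SAME config file to create_distill_model.
compress_model = amct.create_distill_model(config_json_file, model, input_data)
```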
Return Value
The quantized model (compress_model), which can be used for distillation.
Outputs
None
Example
```python
import torch
import amct_pytorch as amct

# Build the network to be distilled and load its weights.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
input_data = tuple([torch.randn(input_shape)])

# Generate a compressed model for distillation.
compress_model = amct.create_distill_model(
    config_json_file, model, input_data)
```
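Once created, compress_model can serve as the student in a standard knowledge-distillation loop, with the original floating-point model as the teacher. The following is a generic, hypothetical sketch of such a loop in plain PyTorch; the loss formulation, optimizer, and data_loader are illustrative assumptions and are not part of the AMCT API.

```python
import torch
import torch.nn.functional as F

def distill_step(teacher, student, batch, optimizer, temperature=4.0):
    """One hypothetical distillation step: match the student's softened
    logits to the teacher's (standard KD loss; not AMCT-specific)."""
    teacher.eval()
    with torch.no_grad():
        teacher_logits = teacher(batch)
    student_logits = student(batch)
    # KL divergence between softened distributions, scaled by T^2.
    loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature ** 2
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Usage: the original model acts as the teacher, compress_model as the student.
optimizer = torch.optim.Adam(compress_model.parameters(), lr=1e-4)
for batch in data_loader:  # data_loader is a placeholder
    distill_step(model, compress_model, batch, optimizer)
```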