create_distill_model

Function Usage

Quantizes the input graph based on the specified distillation configuration file, inserts quantization-related operators (distillation layers and searching-N layers for weights and activations) into the graph, and returns the modified torch.nn.Module model that can be used for distillation.

Prototype

compress_model = create_distill_model(config_file, model, input_data)

Parameters

config_file

Input. User-generated distillation configuration file, which specifies the configuration of the quantization layers and the distillation structure of the model network.

Restriction: a string. The config.json file passed to this call must be the same as the one passed to the create_distill_config call.

model

Input. Floating-point source model to be distilled, with its weights loaded.

Restriction: a torch.nn.Module.

input_data

Input. Input data of the model, passed as a tuple of torch.Tensor objects; a single input tensor must still be wrapped in a tuple (see the sketch after this list).

Restriction: a tuple.

compress_model

Return. Modified torch.nn.Module model that can be used for distillation.

Restriction: a torch.nn.Module. Default: None.
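Because input_data must be a tuple even when the model takes a single tensor, the call site typically wraps the tensor explicitly. The following is a minimal illustrative sketch; the shapes are placeholders, not values required by the toolkit.

import torch

# Single-input model: wrap the one tensor in a one-element tuple.
input_data = (torch.randn(1, 3, 224, 224),)

# Multi-input model: one tensor per argument of the model's forward(),
# in the order forward() expects them.
input_data = (torch.randn(1, 3, 224, 224), torch.randn(1, 10))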

Return Value

A quantized model.

Outputs

None

Example

import torch
import amct_pytorch as amct

# Build the network to be distilled and load its trained weights.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
input_data = tuple([torch.randn(input_shape)])

# Insert quantization operators and generate a compressed model.
compress_model = amct.create_distill_model(
    config_json_file,
    model,
    input_data)
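After create_distill_model returns, the compressed model can be trained against the original floating-point model acting as the teacher. The following is a minimal sketch and not part of the AMCT API: it assumes the returned compress_model trains like an ordinary torch.nn.Module, and the MSE loss, SGD optimizer, and single random batch are illustrative placeholders only.

import torch

teacher = model.eval()            # original floating-point model as teacher
student = compress_model.train()  # compressed model returned above

optimizer = torch.optim.SGD(student.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()      # placeholder distillation loss

# One illustrative distillation step on a random batch.
batch = torch.randn(input_shape)
with torch.no_grad():
    teacher_out = teacher(batch)
loss = loss_fn(student(batch), teacher_out)

optimizer.zero_grad()
loss.backward()
optimizer.step()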