distill

Function Usage

Distills the input graph structure based on the given distillation configuration file and returns the distilled torch.nn.Module model.

Prototype

distill_model = distill(model, compress_model, config_file, train_loader, epochs=1, lr=1e-3, sample_instance=None, loss=None, optimizer=None)

Parameters

| Parameter | Input/Return | Description | Restriction |
| --- | --- | --- | --- |
| model | Input | Floating-point source model to be distilled, with weights loaded. | A torch.nn.Module. |
| compress_model | Input | Modified torch.nn.Module model to be used for distillation. | A torch.nn.Module. The model passed to this API must be a compressed model. |
| config_file | Input | User-generated distillation configuration file, which specifies the configuration and distillation structure of the quantization layers in the model network. | A string. The config.json file passed to this call must be the same as that passed to the create_distill_config call. |
| train_loader | Input | Training dataset. | A torch.utils.data.DataLoader. It must match the model input size. |
| epochs | Input | Maximum number of epochs. Default: 1 | An int. |
| lr | Input | Learning rate. Default: 1e-3 | A float. |
| sample_instance | Input | Instantiated object of a user-provided method for obtaining model input data. Default: None | A DistillSampleBase. It must inherit from the DistillSampleBase class and implement the get_model_input_data method. For details, see the /amct_pytorch/distill/distll_sample.py file under the AMCT installation directory. |
| loss | Input | Instantiated object for computing the loss. Default: None | A torch.nn.modules.loss._Loss. |
| optimizer | Input | Instantiated object of the optimizer. Default: None | A torch.optim.Optimizer. |
| distill_model | Return | Modified torch.nn.Module model produced by distillation. | A torch.nn.Module. |
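The sample_instance parameter expects a subclass of DistillSampleBase that implements get_model_input_data. The sketch below illustrates the expected shape of such a subclass; note that DistillSampleBase ships with AMCT, so the stand-in base class, the ImageSample subclass, and the batch format are all assumptions made here for illustration only.

```python
# Stand-in for amct_pytorch's DistillSampleBase, defined locally so this
# sketch is self-contained. The real class lives under
# /amct_pytorch/distill/ in the AMCT installation directory.
class DistillSampleBase:
    """Assumed interface: subclasses return the model input from a batch."""
    def get_model_input_data(self, data):
        raise NotImplementedError


class ImageSample(DistillSampleBase):
    """Hypothetical subclass for a DataLoader that yields (image, label)
    pairs: it discards the label and returns only the model input."""
    def get_model_input_data(self, data):
        image, _label = data
        return image


# A batch of the assumed (image, label) form yields just the image.
sample = ImageSample()
image = sample.get_model_input_data(("image_tensor", 0))
```

An instance of such a class would then be passed as sample_instance so that distill can extract model inputs from each batch produced by train_loader.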

Returns

A distilled model.

Outputs

None

Examples

import torch
import amct_pytorch as amct

# Build the network to be distilled and load its weights.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
compress_model = compress(model)

# Prepare the training data loader, loss, and optimizer.
input_data = tuple([torch.randn(input_shape)])
train_loader = torch.utils.data.DataLoader(input_data)
loss = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(compress_model.parameters(), lr=0.1)

# Perform distillation.
distill_model = amct.distill(
                model,
                compress_model,
                config_json_file,
                train_loader,
                epochs=1,
                lr=1e-3,
                sample_instance=None,
                loss=loss,
                optimizer=optimizer)
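Conceptually, distillation trains the compressed model to reproduce the outputs of the frozen floating-point source model. The generic PyTorch sketch below illustrates that idea only; it is not AMCT's implementation, and the tiny Linear teacher/student models are placeholders chosen for brevity.

```python
import torch

# Placeholder teacher (source model) and student (compressed model).
teacher = torch.nn.Linear(4, 2)
student = torch.nn.Linear(4, 2)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-3)

teacher.eval()
for _ in range(3):  # a few illustrative training steps
    batch = torch.randn(8, 4)
    with torch.no_grad():
        target = teacher(batch)  # teacher outputs serve as the target
    optimizer.zero_grad()
    loss = loss_fn(student(batch), target)  # student mimics the teacher
    loss.backward()
    optimizer.step()
```

In the distill API, model plays the teacher role, compress_model the student role, and the loss and optimizer arguments fill the corresponding slots in a loop of this shape.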