distill
Function Usage
Distills the input graph structure based on the given distillation configuration file and returns the modified torch.nn.Module model.
Prototype
distill_model = distill(model, compress_model, config_file, train_loader, epochs=1, lr=1e-3, sample_instance=None, loss=None, optimizer=None)
Parameters
| Parameter | Input/Return | Description | Restriction |
|---|---|---|---|
| model | Input | Floating-point source model to be distilled, with weights loaded. | A torch.nn.Module. |
| compress_model | Input | Modified model that can be used for distillation. | A torch.nn.Module. The model passed to this API must be a compressed model. |
| config_file | Input | User-generated distillation configuration file, which specifies the configuration and distillation structure of the quantization layers in the model network. | A string. The config.json file passed to this call must be the same as that passed to the create_distill_config call. |
| train_loader | Input | Training dataset. | A torch.utils.data.DataLoader. It must match the model input size. |
| epochs | Input | Maximum number of epochs. | An int. Default: 1. |
| lr | Input | Learning rate. | A float. Default: 1e-3. |
| sample_instance | Input | Instantiated object of a user-provided method for obtaining model input data. | A DistillSampleBase. Default: None. It must inherit from the DistillSampleBase class and implement the get_model_input_data method. For details, see the /amct_pytorch/distill/distll_sample.py file under the AMCT installation directory. |
| loss | Input | Instantiated object for computing the loss. | A torch.nn.modules.loss._Loss. Default: None. |
| optimizer | Input | Instantiated optimizer object. | A torch.optim.Optimizer. Default: None. |
| distill_model | Return | Modified torch.nn.Module model after distillation. | A torch.nn.Module. |
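The sample_instance contract described above can be sketched as follows. The real DistillSampleBase class ships with AMCT (see the file referenced in the table); the stand-in base class, the batch layout, and the tensor shapes below are assumptions for illustration only.

```python
import torch

# Stand-in for AMCT's DistillSampleBase, assumed here for illustration;
# in practice, import the real class from the AMCT installation.
class DistillSampleBase:
    def get_model_input_data(self, data):
        raise NotImplementedError

# User-defined sampler: unpack an (inputs, labels) batch and return only
# the tensors the model consumes during distillation.
class MySample(DistillSampleBase):
    def get_model_input_data(self, data):
        inputs, _labels = data
        return (inputs,)

sampler = MySample()
batch = (torch.randn(4, 3, 224, 224), torch.zeros(4))
(model_input,) = sampler.get_model_input_data(batch)
print(model_input.shape)  # torch.Size([4, 3, 224, 224])
```

An object like `sampler` would then be passed as `sample_instance=sampler` so that distillation can extract model inputs from each batch yielded by train_loader.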
Returns
A distilled model.
Outputs
None
Examples
```python
import torch
import amct_pytorch as amct

# Build a graph of the network to be distilled.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
compress_model = compress(model)
input_data = tuple([torch.randn(input_shape)])
train_loader = torch.utils.data.DataLoader(input_data)
loss = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(compress_model.parameters(), lr=0.1)
# Perform distillation.
distill_model = amct.distill(
    model,
    compress_model,
    config_json_file,
    train_loader,
    epochs=1,
    lr=1e-3,
    sample_instance=None,
    loss=loss,
    optimizer=optimizer)
```