create_distill_config
Description
Identifies all distillable layers and structures from the model's graph, automatically generates a distillation configuration file, and writes each layer's quantization configuration and distillation structure into that file.
Prototype
create_distill_config(config_file, model, input_data, config_defination=None)
Parameters
| Parameter | Input/Return | Description | Restriction |
|---|---|---|---|
| config_file | Input | Path of the distillation configuration file to generate, including the file name. An existing file at this path is overwritten by this API call. | A string. |
| model | Input | Floating-point source model to be distilled, with weights loaded. | A torch.nn.Module. |
| input_data | Input | Input data of the model: a torch.Tensor, or a tuple of torch.Tensor for multiple inputs. | A tuple. |
| config_defination | Input | Simplified configuration file. The simplified configuration file distill.cfg is generated from the distill_config_pytorch.proto file, which is stored in /amct_pytorch/proto/distill_config_pytorch.proto under the AMCT installation directory. For details about the parameters in distill_config_pytorch.proto and the generated simplified configuration file distill.cfg, see Simplified Distillation Configuration File. | A string. Default: None. |
Returns
None
Outputs
A distillation configuration file in JSON format. (When distillation is performed again, the configuration file output by this API will be overwritten.) The following is an example configuration file of INT8 quantization:
```json
{
    "version": 1,
    "batch_num": 1,
    "group_size": 1,
    "data_dump": false,
    "distill_group": [
        [
            "conv1",
            "bn",
            "relu"
        ],
        [
            "conv2",
            "bn2",
            "relu2"
        ]
    ],
    "conv1": {
        "quant_enable": true,
        "distill_data_config": {
            "algo": "ulq_quantize",
            "dst_type": "INT8"
        },
        "distill_weight_config": {
            "algo": "arq_distill",
            "channel_wise": true,
            "dst_type": "INT8"
        }
    },
    "conv2": {
        "quant_enable": true,
        "distill_data_config": {
            "algo": "ulq_quantize",
            "dst_type": "INT8"
        },
        "distill_weight_config": {
            "algo": "arq_distill",
            "channel_wise": true,
            "dst_type": "INT8"
        }
    },
    ...
}
```
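Because the generated file is plain JSON, it can be inspected with Python's standard `json` module, for example to verify which layer groups were selected for distillation before training. The snippet below is an illustrative sketch, not part of the AMCT API: it parses an inline config string shaped like the example above (normally you would read the file written to `config_file` instead).

```python
import json

# Minimal config mirroring the example output above (normally this would
# be loaded from the file passed as config_file).
config_text = '''
{
  "version": 1,
  "batch_num": 1,
  "group_size": 1,
  "data_dump": false,
  "distill_group": [["conv1", "bn", "relu"]],
  "conv1": {
    "quant_enable": true,
    "distill_data_config": {"algo": "ulq_quantize", "dst_type": "INT8"},
    "distill_weight_config": {"algo": "arq_distill",
                              "channel_wise": true,
                              "dst_type": "INT8"}
  }
}
'''

config = json.loads(config_text)

# List the layer groups selected for joint distillation.
for group in config["distill_group"]:
    print("distill group:", group)

# Inspect the quantization settings of one layer entry.
layer = config["conv1"]
print("conv1 quant_enable:", layer["quant_enable"])
print("conv1 weight dst_type:", layer["distill_weight_config"]["dst_type"])
```

To edit the configuration (for example, disabling quantization for a layer), modify the parsed dictionary and write it back with `json.dump` before passing the file to the distillation step.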
Examples
```python
import torch
import amct_pytorch as amct

# Build a graph of the network to be distilled.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
input_data = tuple([torch.randn(input_shape)])

# Generate a distillation configuration file.
amct.create_distill_config(config_file="./configs/config.json",
                           model=model,
                           input_data=input_data,
                           config_defination="./configs/distill.cfg")
```