create_distill_config

Description

Scans the model's graph structure to find all distillable layers and structures, then automatically generates a distillation configuration file that records each layer's quantization configuration and distillation structure.

Prototype

create_distill_config(config_file, model, input_data, config_defination=None)

Parameters

config_file

Input. Path of the distillation configuration file, including the file name. Any existing file at this path is overwritten when this API is called.

Restriction: a string.

model

Input. Floating-point source model to be distilled, with weights loaded.

Restriction: a torch.nn.Module.

input_data

Input. Input data of the model: a tuple of torch.Tensor objects (a single torch.Tensor should be wrapped in a tuple).

Restriction: a tuple.

config_defination

Input. Simplified configuration file. The simplified configuration file distill.cfg is generated from the distill_config_pytorch.proto file, which is stored in /amct_pytorch/proto/distill_config_pytorch.proto under the AMCT installation directory. For details about the parameters in distill_config_pytorch.proto and the generated simplified quantization configuration file distill.cfg, see Simplified Distillation Configuration File.

Default: None

Restriction: a string.

Returns

None

Outputs

A distillation configuration file in JSON format. (Calling this API again overwrites the previously generated file.) The following is an example configuration file for INT8 quantization:

{
    "version":1,
    "batch_num":1,
    "group_size":1,
    "data_dump":false,
    "distill_group":[
        [
            "conv1",
            "bn",
            "relu"
        ],
        [
            "conv2",
            "bn2",
            "relu2"
        ]
    ],
    "conv1":{
        "quant_enable":true,
        "distill_data_config":{
            "algo":"ulq_quantize",
            "dst_type":"INT8"
        },
        "distill_weight_config":{
            "algo":"arq_distill",
            "channel_wise":true,
            "dst_type":"INT8"
        }
    },
    "conv2":{
        "quant_enable":true,
        "distill_data_config":{
            "algo":"ulq_quantize",
            "dst_type":"INT8"
        },
        "distill_weight_config":{
            "algo":"arq_distill",
            "channel_wise":true,
            "dst_type":"INT8"
        }
    }
...
}
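Because the output is plain JSON, it can be inspected with standard tooling before training. The sketch below is illustrative only: it embeds an excerpt of the example configuration above as a string (rather than reading a real file, whose path would depend on your call to create_distill_config) and lists the layers that have quantization enabled.

```python
import json

# Excerpt of a generated distillation configuration (mirrors the example above).
config_text = """
{
    "version": 1,
    "batch_num": 1,
    "group_size": 1,
    "data_dump": false,
    "distill_group": [["conv1", "bn", "relu"], ["conv2", "bn2", "relu2"]],
    "conv1": {"quant_enable": true,
              "distill_data_config": {"algo": "ulq_quantize", "dst_type": "INT8"},
              "distill_weight_config": {"algo": "arq_distill", "channel_wise": true, "dst_type": "INT8"}},
    "conv2": {"quant_enable": true,
              "distill_data_config": {"algo": "ulq_quantize", "dst_type": "INT8"},
              "distill_weight_config": {"algo": "arq_distill", "channel_wise": true, "dst_type": "INT8"}}
}
"""

config = json.loads(config_text)

# Per-layer entries are the keys that are not global settings.
top_level = {"version", "batch_num", "group_size", "data_dump", "distill_group"}
quantized_layers = [name for name, entry in config.items()
                    if name not in top_level and entry.get("quant_enable")]
print(quantized_layers)  # ['conv1', 'conv2']
```

In a real workflow you would replace `config_text` with `open("./configs/config.json")` and `json.load` after calling the API.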

Examples

import torch
import amct_pytorch as amct

# Build the network to be distilled and load its weights.
model = build_model()
model.load_state_dict(torch.load(state_dict_path))
input_data = tuple([torch.randn(input_shape)])

# Generate a distillation configuration file.
amct.create_distill_config(config_file="./configs/config.json",
                           model=model,
                           input_data=input_data,
                           config_defination="./configs/distill.cfg")