auto_channel_prune_search

Function Usage

This API calculates the sparsity sensitivity (which affects accuracy) and the sparsity gain (which affects performance) of each channel in the user model. Based on these values and the user input, the search policy then searches for the optimal layer-by-layer channel sparsity ratio to balance accuracy and performance. Finally, it generates a configuration file.
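The trade-off described above can be illustrated with a small, self-contained sketch. It is not the AMCT implementation: the function names, the cost-per-gain ranking rule, the `target_gain` stopping condition, and the toy scores are all assumptions made for illustration only.

```python
def greedy_prune_search(sensitivity, gain, target_gain):
    """Pick channels to prune, lowest accuracy cost per unit of gain first.

    sensitivity: dict mapping (layer, channel) -> estimated accuracy cost (>= 0)
    gain:        dict mapping (layer, channel) -> estimated performance gain (> 0)
    target_gain: total gain to reach before the search stops (hypothetical knob)
    """
    # Rank channels by how much accuracy is given up per unit of performance gained.
    ranked = sorted(sensitivity, key=lambda ch: sensitivity[ch] / gain[ch])
    chosen, total = [], 0.0
    for ch in ranked:
        if total >= target_gain:
            break
        chosen.append(ch)
        total += gain[ch]
    return chosen, total


def per_layer_ratio(chosen, channels_per_layer):
    """Convert the chosen channels into a sparsity ratio for each layer."""
    pruned = {layer: 0 for layer in channels_per_layer}
    for layer, _ in chosen:
        pruned[layer] += 1
    return {layer: pruned[layer] / n for layer, n in channels_per_layer.items()}


# Toy scores, invented for illustration only.
sensitivity = {("conv1", 0): 0.9, ("conv1", 1): 0.1,
               ("conv2", 0): 0.3, ("conv2", 1): 0.05}
gain = {("conv1", 0): 1.0, ("conv1", 1): 1.0,
        ("conv2", 0): 2.0, ("conv2", 1): 2.0}

chosen, total = greedy_prune_search(sensitivity, gain, target_gain=3.0)
ratios = per_layer_ratio(chosen, {"conv1": 2, "conv2": 2})
print(chosen, total, ratios)
```

In this toy run the two channels with the lowest cost per unit of gain are selected, which gives each layer a sparsity ratio of 0.5.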

Constraints

None

Prototype

auto_channel_prune_search(model, config, input_data, output_cfg, sensitivity, search_alg)

Parameters

model

    Input/Return: Input
    Meaning: PyTorch model to be sparsified.
    Restriction: A torch.nn.Module.

config

    Input/Return: Input
    Meaning: Path of the auto channel pruning configuration file. This is a simplified configuration file generated based on AutoChannelPruneConfig in the basic_info.proto file, which is located at /amct_pytorch/proto/basic_info.proto under the installation directory. For details about the parameters in the basic_info.proto file and an example of the generated auto channel pruning search configuration file, see Simplified Configuration File for Auto Channel Pruning Search.
    Restriction: A string.

input_data

    Input/Return: Input
    Meaning: Input data (including labels) provided by the user.
    Restriction: Data type: list[data, label]. Each list element is a torch.Tensor.

output_cfg

    Input/Return: Input/Return
    Meaning: Path of the output channel pruning configuration file.
    Restriction: A string.

sensitivity

    Input/Return: Input
    Meaning: Sensitivity calculation method.
    Restriction: A string or a subclass of SensitivityBase. A string must name an existing amct method; currently, the value can be TaylorLossSensitivity. Alternatively, users can define and instantiate their own subclass of SensitivityBase.

search_alg

    Input/Return: Input
    Meaning: Method of searching for channels to be sparsified.
    Restriction: A string or a subclass of SearchChannelBase. A string must name an existing amct method; currently, the value can be GreedySearch. Alternatively, users can define and instantiate their own subclass of SearchChannelBase.
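This page does not define how TaylorLossSensitivity is computed. As a rough intuition only, the sketch below assumes a first-order Taylor criterion, in which a channel's score is the absolute value of the sum of its activations multiplied by the loss gradients at those activations. The function name and the toy numbers are hypothetical, and the sketch does not use or reproduce the SensitivityBase interface.

```python
def taylor_channel_scores(activations, gradients):
    """Return one score per channel: |sum(activation * gradient)|.

    activations, gradients: lists holding one flat list of floats per channel,
    i.e. the channel's output values and the loss gradient at those values.
    A small score suggests that removing the channel changes the loss little.
    """
    scores = []
    for act, grad in zip(activations, gradients):
        scores.append(abs(sum(a * g for a, g in zip(act, grad))))
    return scores


# Toy values for three channels, invented for illustration only.
activations = [[1.0, 2.0], [0.5, -0.5], [3.0, 0.0]]
gradients = [[0.5, 0.25], [2.0, 1.0], [-1.0, 4.0]]

scores = taylor_channel_scores(activations, gradients)
# Index of the channel whose removal is estimated to hurt the loss least.
least_sensitive = min(range(len(scores)), key=scores.__getitem__)
print(scores, least_sensitive)
```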

Return Value

None

Outputs

Auto channel pruning configuration file.

Pass this file to the channel sparsity API in the subsequent pruning step.
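Because the only result of the call is a file on disk, it can help to prepare the output path beforehand and verify the file afterwards. The helpers below are a minimal sketch using only the Python standard library. Their names are hypothetical, and whether amct creates missing directories itself is not stated here, so creating the directory up front is a precaution rather than a documented requirement.

```python
from pathlib import Path


def prepare_output_cfg(path):
    """Create the parent directory of the output file and return the path as a string."""
    out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)
    return str(out)


def check_output_cfg(path):
    """Return True if a non-empty configuration file exists at the given path."""
    out = Path(path)
    return out.is_file() and out.stat().st_size > 0
```

A typical use would be to pass the result of prepare_output_cfg('./tmp/output.cfg') as output_cfg, then call check_output_cfg on the same path before handing the file to the next step.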

Examples

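import torch
import amct_pytorch as amct

# Construct the input data.
# model (a torch.nn.Module) and input_shape are assumed to be defined by the user.
input_data = torch.randn(input_shape)
model.eval()
output = model(input_data)
# Random labels with the same shape as the model output, for demonstration only.
labels = torch.randn(output.size())
data = [input_data, labels]

# Run the search; the resulting configuration is written to output_cfg.
amct.auto_channel_prune_search(
    model=model,
    config='./tmp/sample.cfg',
    input_data=data,
    output_cfg='./tmp/output.cfg',
    sensitivity='TaylorLossSensitivity',
    search_alg='GreedySearch')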