auto_channel_prune_search
Function Usage
This API calculates the sparsity sensitivity (which affects accuracy) and the sparsity gain (which affects performance) of each channel in the user model. The search policy then uses these values, based on the input, to find the optimal layer-by-layer channel sparsity ratio that balances accuracy and performance. Finally, the API generates a configuration file.
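The trade-off described above can be illustrated with a small, self-contained sketch. It is not AMCT's implementation: the `Channel` record, the `greedy_prune_search` and `per_layer_ratio` helpers, and the rule of ranking channels by sensitivity per unit of gain are all assumptions made for illustration. The sketch only shows how per-channel sensitivity and gain values could be turned into a per-layer sparsity ratio.

```python
from dataclasses import dataclass


@dataclass
class Channel:
    layer: str          # hypothetical layer name
    index: int          # channel index within the layer
    sensitivity: float  # estimated accuracy cost of removing the channel
    gain: float         # estimated performance gain (assumed > 0)


def greedy_prune_search(channels, target_gain_ratio):
    """Pick channels with the lowest sensitivity per unit of gain until the
    requested fraction of the total gain is reached (illustrative rule only)."""
    budget = target_gain_ratio * sum(c.gain for c in channels)
    ranked = sorted(channels, key=lambda c: c.sensitivity / c.gain)
    pruned, achieved = [], 0.0
    for c in ranked:
        if achieved >= budget:
            break
        pruned.append(c)
        achieved += c.gain
    return pruned


def per_layer_ratio(channels, pruned):
    """Convert the selected channels into a sparsity ratio for each layer."""
    totals, removed = {}, {}
    for c in channels:
        totals[c.layer] = totals.get(c.layer, 0) + 1
    for c in pruned:
        removed[c.layer] = removed.get(c.layer, 0) + 1
    return {layer: removed.get(layer, 0) / n for layer, n in totals.items()}


# Made-up sensitivity and gain values for two layers with two channels each.
channels = [
    Channel("conv1", 0, 0.90, 1.0),
    Channel("conv1", 1, 0.10, 1.0),
    Channel("conv2", 0, 0.05, 2.0),
    Channel("conv2", 1, 0.80, 2.0),
]
pruned = greedy_prune_search(channels, target_gain_ratio=0.5)
ratios = per_layer_ratio(channels, pruned)
print(ratios)
```

In this toy input, the two channels that are cheapest to remove per unit of gain are selected first, so each layer ends up with half of its channels marked for pruning.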
Constraints
None
Prototype
auto_channel_prune_search(model, config, input_data, output_cfg, sensitivity, search_alg)
Parameters
| Parameter | Input/Return | Meaning | Restriction |
|---|---|---|---|
| model | Input | PyTorch model to be sparsified. | A torch.nn.Module. |
| config | Input | Path of the auto channel pruning configuration file. This is a simplified configuration file generated based on AutoChannelPruneConfig in the basic_info.proto file, which is stored in amct_pytorch/proto/basic_info.proto under the installation directory. For details about the parameters in the basic_info.proto file and an example of the generated auto channel pruning search configuration file, see Simplified Configuration File for Auto Channel Pruning Search. | A string. |
| input_data | Input | Input data (including labels) provided by the user. | Data type: list[data, label]. Each list element is a torch.Tensor. |
| output_cfg | Input/Return | Path of the output channel pruning configuration file. | A string. |
| sensitivity | Input | Sensitivity calculation method. | A string or a subclass of SensitivityBase. A string must name an existing AMCT method; currently, the value can be TaylorLossSensitivity. Alternatively, users can define and instantiate their own subclass of SensitivityBase. |
| search_alg | Input | Method of searching for channels to be sparsified. | A string or a subclass of SearchChannelBase. A string must name an existing AMCT method; currently, the value can be GreedySearch. Alternatively, users can define and instantiate their own subclass of SearchChannelBase. |
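To give a feel for what a Taylor-based sensitivity measures, the following sketch computes a generic first-order Taylor criterion from the pruning literature: for each output channel of a convolution, the absolute value of the summed product of activation and loss gradient. How AMCT's TaylorLossSensitivity is computed internally is not stated here, so this is an assumption for illustration. The helper name `taylor_channel_sensitivity`, the toy model, and the choice of `nn.MSELoss` are hypothetical, and the code does not use the SensitivityBase interface.

```python
import torch
import torch.nn as nn


def taylor_channel_sensitivity(model, conv, data, label, loss_fn):
    """Return one score per output channel of `conv`:
    |sum over batch and spatial positions of activation * gradient|."""
    captured = {}

    def hook(_module, _inputs, output):
        output.retain_grad()      # keep the gradient of this non-leaf tensor
        captured["out"] = output

    handle = conv.register_forward_hook(hook)
    try:
        model.zero_grad()
        loss = loss_fn(model(data), label)
        loss.backward()
    finally:
        handle.remove()           # always detach the hook again

    out = captured["out"]
    # Reduce over batch (0) and spatial (2, 3) dimensions, keeping channels.
    return (out * out.grad).sum(dim=(0, 2, 3)).abs().detach()


torch.manual_seed(0)
# Toy model: an 8x8 input becomes 6x6 after the 3x3 convolution.
model = nn.Sequential(
    nn.Conv2d(3, 4, 3), nn.ReLU(), nn.Flatten(), nn.Linear(4 * 6 * 6, 2))
model.eval()
data = torch.randn(2, 3, 8, 8)
label = torch.randn(2, 2)
scores = taylor_channel_sensitivity(model, model[0], data, label, nn.MSELoss())
print(scores)
```

A low score suggests that removing the channel would change the loss only slightly on this batch, which is the kind of signal a search method such as GreedySearch could rank channels by.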
Return Value
None
Outputs
An auto channel pruning configuration file.
Pass this file to the channel sparsity API for subsequent processing.
Examples
    import torch
    import amct_pytorch as amct

    # Construct the input data.
    # `model` and `input_shape` are assumed to be defined by the user beforehand.
    input_data = torch.randn(input_shape)
    model.eval()
    output = model(input_data)
    labels = torch.randn(output.size())
    data = [input_data, labels]

    # Search for the layer-by-layer channel sparsity ratios and
    # write the resulting configuration to output_cfg.
    amct.auto_channel_prune_search(
        model=model,
        config='./tmp/sample.cfg',
        input_data=data,
        output_cfg='./tmp/output.cfg',
        sensitivity='TaylorLossSensitivity',
        search_alg='GreedySearch')