auto_channel_prune_search
Description
This API calculates the sparsity sensitivity (which affects accuracy) and the sparsity gain (which affects performance) of each channel in the user model. The search policy then uses these values, together with the inputs, to find the layer-by-layer channel sparsity ratios that best balance accuracy and performance. Finally, the result is written to a configuration file.
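The trade-off described above can be illustrated with a small, self-contained sketch. It is not AMCT's implementation: the `Candidate` class, the `greedy_prune_plan` function, the `max_ratio` cap, and all numbers are hypothetical, chosen only to show how a greedy policy might rank channels by sensitivity per unit of gain and derive a per-layer sparsity ratio.

```python
from dataclasses import dataclass


@dataclass
class Candidate:
    layer: str          # hypothetical layer name
    channel: int        # channel index within the layer
    sensitivity: float  # estimated accuracy cost of pruning this channel
    gain: float         # estimated performance benefit; must be > 0


def greedy_prune_plan(candidates, target_gain, channels_per_layer, max_ratio=0.5):
    """Prune the cheapest channels (sensitivity per unit gain) until target_gain is met."""
    ranked = sorted(candidates, key=lambda c: c.sensitivity / c.gain)
    pruned = {layer: 0 for layer in channels_per_layer}
    total_gain = 0.0
    for cand in ranked:
        if total_gain >= target_gain:
            break
        # Never prune more than max_ratio of any single layer.
        limit = int(channels_per_layer[cand.layer] * max_ratio)
        if pruned[cand.layer] >= limit:
            continue
        pruned[cand.layer] += 1
        total_gain += cand.gain
    # Convert pruned-channel counts into per-layer sparsity ratios.
    return {layer: pruned[layer] / channels_per_layer[layer] for layer in pruned}


# Made-up sensitivities and gains for two 4-channel layers.
candidates = [
    Candidate("conv1", 0, 0.9, 1.0), Candidate("conv1", 1, 0.1, 1.0),
    Candidate("conv1", 2, 0.5, 1.0), Candidate("conv1", 3, 0.2, 1.0),
    Candidate("conv2", 0, 0.05, 2.0), Candidate("conv2", 1, 0.8, 2.0),
    Candidate("conv2", 2, 0.3, 2.0), Candidate("conv2", 3, 0.6, 2.0),
]
plan = greedy_prune_plan(candidates, target_gain=5.0,
                         channels_per_layer={"conv1": 4, "conv2": 4})
print(plan)
```

With these made-up inputs the sketch prunes one channel of `conv1` and two of `conv2`, which shows how the per-layer ratios can differ when layers differ in sensitivity and gain.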
Constraints
None
Prototype
auto_channel_prune_search(graph, output_nodes, config, input_data, output_cfg, sensitivity, search_alg)
Parameter Description
| Option | Input/Return | Meaning | Restriction |
|---|---|---|---|
| graph | Input | tf.Graph training graph that contains automatic differentiation. | A tf.Graph. |
| output_nodes | Input | Names of the model output nodes. | A list of strings. |
| config | Input | Path of the auto channel pruning search configuration file. This is a simplified configuration file generated based on AutoChannelPruneConfig in the basic_info.proto file, which is stored in the installation directory at /amct_tensorflow/proto/basic_info.proto. For details about the parameters in the basic_info.proto file and an example of the generated auto channel pruning search configuration file, see Simplified Configuration File for Auto Channel Pruning Search. | A string. |
| input_data | Input | Calibration data provided by the user. | A list. Each element is the corresponding feed_dict data. |
| output_cfg | Input/Return | Path of the output channel pruning configuration file. | A string. |
| sensitivity | Input | Sensitivity calculation method. | A string or an instance of a subclass of SensitivityBase. A string selects an existing method; currently, the value can be TaylorLossSensitivity. Alternatively, users can inherit from SensitivityBase, define their own subclass, and pass an instance of it. The default value is TaylorLossSensitivity. |
| search_alg | Input | Method of searching for channels to be sparsified. | A string or an instance of a subclass of SearchChannelBase. A string selects an existing method; currently, the value can be GreedySearch. Alternatively, users can inherit from SearchChannelBase, define their own subclass, and pass an instance of it. The default value is GreedySearch. |
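As background for the sensitivity option, the sketch below shows one common form of the first-order Taylor criterion for channel importance, computed with NumPy on NHWC tensors. This is the general technique, not necessarily what TaylorLossSensitivity does internally; the function name `taylor_channel_sensitivity` and the toy tensors are assumptions made for illustration. Because the criterion needs gradients of the loss, it is plausibly related to the requirement that graph contain automatic differentiation.

```python
import numpy as np


def taylor_channel_sensitivity(activations, gradients):
    """Per-channel first-order Taylor score for NHWC tensors.

    For each sample, average activation * gradient over the spatial
    dimensions, take the absolute value, then average over the batch.
    """
    if activations.shape != gradients.shape:
        raise ValueError("activations and gradients must have the same shape")
    contribution = activations * gradients            # shape (N, H, W, C)
    per_sample = np.abs(contribution.mean(axis=(1, 2)))  # shape (N, C)
    return per_sample.mean(axis=0)                    # shape (C,)


# Toy data: unit activations, and a gradient that is constant per channel.
acts = np.ones((2, 2, 2, 3), dtype=np.float32)
grads = np.broadcast_to(np.array([0.0, 1.0, -2.0], dtype=np.float32), acts.shape)
scores = taylor_channel_sensitivity(acts, grads)
print(scores)
```

A channel with a score near zero is one whose removal is estimated to change the loss the least, so it would be the first candidate for pruning under this criterion.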
Returns
None
Outputs
An auto channel pruning configuration file, written to the path specified by output_cfg.
This file must be passed to the channel sparsity API in the subsequent pruning step.
Examples
```python
import numpy as np
import amct_tensorflow as amct

# Construct feed_dict data.
input_data = np.random.uniform(-10, 10, (2, 14, 14, 64)).astype(np.float32)
feed_dict = [{'input:0': input_data}]

# graph, operation_name_1, and operation_name_2 are assumed to be defined by the caller.
amct.auto_channel_prune_search(
    graph=graph,
    output_nodes=[operation_name_1, operation_name_2],
    config='./tmp/sample.cfg',
    input_data=feed_dict,
    output_cfg='./tmp/output.cfg',
    sensitivity='TaylorLossSensitivity',
    search_alg='GreedySearch')
```
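Since input_data is a list of feed_dict entries, a small helper can build several entries at once. The sketch below is illustrative only: `make_calibration_feeds` is a hypothetical helper, the tensor name 'input:0' and the shape are taken from the example above, and whether the API makes use of more than one entry is an assumption. Random values stand in for real calibration samples, which should be used in practice.

```python
import numpy as np


def make_calibration_feeds(input_name, shape, num_batches, seed=0):
    """Return a list of feed_dict-style dicts, one per calibration batch."""
    rng = np.random.default_rng(seed)  # seeded so the data is reproducible
    return [
        {input_name: rng.uniform(-10, 10, shape).astype(np.float32)}
        for _ in range(num_batches)
    ]


feeds = make_calibration_feeds('input:0', (2, 14, 14, 64), num_batches=4)
print(len(feeds), feeds[0]['input:0'].shape, feeds[0]['input:0'].dtype)
```

The resulting list has the same structure as feed_dict in the example and could be passed as input_data in the same way.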