auto_channel_prune_search

Description

This API calculates the sparsity sensitivity (which affects accuracy) and the sparsity gain (which affects performance) of each channel in the user model. A search policy then uses these values, together with the input configuration, to find the layer-by-layer channel sparsity ratio that best balances accuracy and performance. Finally, the result is written to a configuration file.
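To make the trade-off between sensitivity and gain concrete, the following toy sketch greedily assigns per-layer sparsity ratios. It is not the AMCT GreedySearch implementation. The function name greedy_prune_ratios, the scoring rule (gain divided by a sensitivity that grows with the amount already pruned), the step size, and the sample numbers are all assumptions made for illustration.

```python
def greedy_prune_ratios(sensitivity, gain, target_gain, step=0.1, max_steps=9):
    """Toy greedy search over per-layer sparsity ratios (illustrative only).

    sensitivity: dict of layer name -> accuracy cost of pruning one step.
    gain:        dict of layer name -> performance gain of pruning one step.
    target_gain: stop once the accumulated gain reaches this value.
    Returns a dict of layer name -> chosen sparsity ratio.
    """
    steps = {name: 0 for name in sensitivity}
    total_gain = 0.0
    while total_gain < target_gain:
        # Layers that can still be pruned further.
        candidates = [n for n in steps if steps[n] < max_steps]
        if not candidates:
            break
        # Assumption: the accuracy cost of a layer grows linearly with how
        # much of it has already been pruned.
        best = max(
            candidates,
            key=lambda n: gain[n] / (sensitivity[n] * (steps[n] + 1)),
        )
        steps[best] += 1
        total_gain += gain[best]
    return {name: round(count * step, 2) for name, count in steps.items()}


if __name__ == "__main__":
    ratios = greedy_prune_ratios(
        sensitivity={"conv1": 4.0, "conv2": 1.0, "conv3": 2.5},
        gain={"conv1": 1.0, "conv2": 1.0, "conv3": 1.0},
        target_gain=3.0,
    )
    print(ratios)
```

In this sketch, layers with low sensitivity are pruned first, and a layer becomes less attractive the more it has already been pruned, so the sparsity spreads across layers instead of concentrating in one.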

Constraints

None

Prototype

auto_channel_prune_search(graph, output_nodes, config, input_data, output_cfg, sensitivity, search_alg)

Parameter Description

| Option | Input/Return | Meaning | Restriction |
| --- | --- | --- | --- |
| graph | Input | tf.Graph training graph that contains automatic differentiation. | A tf.Graph. |
| output_nodes | Input | Names of the model output nodes. | A list of strings. |
| config | Input | Path of the auto channel pruning search configuration file. This is a simplified configuration file generated based on AutoChannelPruneConfig in the basic_info.proto file, which is stored in the installation directory at amct_tensorflow/proto/basic_info.proto. For details about the parameters in the basic_info.proto file and an example of the generated auto channel pruning search configuration file, see Simplified Configuration File for Auto Channel Pruning Search. | A string. |
| input_data | Input | Calibration data provided by the user. | A list. Each element is the corresponding feed_dict data. |
| output_cfg | Input | Path of the output channel pruning configuration file. | A string. |
| sensitivity | Input | Sensitivity calculation method. | A string or an instance of a subclass of SensitivityBase. A string selects an existing method; currently, the only supported value is TaylorLossSensitivity. Alternatively, users can define their own subclass of SensitivityBase and pass an instance of it. The default value is TaylorLossSensitivity. |
| search_alg | Input | Method of searching for channels to be sparsified. | A string or an instance of a subclass of SearchChannelBase. A string selects an existing method; currently, the only supported value is GreedySearch. Alternatively, users can define their own subclass of SearchChannelBase and pass an instance of it. The default value is GreedySearch. |
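The type restrictions above can be checked before calling the API. The following is a minimal sketch of such a pre-flight check; check_search_args is a hypothetical helper, not part of amct_tensorflow, and the requirement that the lists be non-empty is an assumption rather than a documented rule.

```python
def check_search_args(output_nodes, config, input_data, output_cfg):
    """Return a list of problems; an empty list means the basic checks passed."""
    problems = []
    # output_nodes: a list of strings (non-empty is an assumption).
    if not (isinstance(output_nodes, list) and output_nodes
            and all(isinstance(name, str) for name in output_nodes)):
        problems.append("output_nodes must be a non-empty list of strings")
    # config and output_cfg: string paths.
    if not isinstance(config, str):
        problems.append("config must be a string path")
    if not isinstance(output_cfg, str):
        problems.append("output_cfg must be a string path")
    # input_data: a list whose elements are feed_dict dictionaries.
    if not (isinstance(input_data, list) and input_data
            and all(isinstance(item, dict) for item in input_data)):
        problems.append("input_data must be a non-empty list of feed_dict dictionaries")
    return problems


if __name__ == "__main__":
    issues = check_search_args(
        output_nodes="softmax",            # wrong: a bare string, not a list
        config="./tmp/sample.cfg",
        input_data=[{"input:0": [0.0]}],
        output_cfg="./tmp/output.cfg",
    )
    print(issues)
```

The check covers only the basic types listed in the table. It does not validate graph, sensitivity, or search_alg, and it does not confirm that the node names exist in the graph.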

Returns

None

Outputs

Auto channel pruning configuration file.

Pass this file to the channel sparsity API in the subsequent step.
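Because the API returns None, one way to confirm that the search produced its output is to check the file at the output_cfg path afterwards. The sketch below assumes that a successful run leaves a non-empty file at that path; output_cfg_generated is a hypothetical helper name, not part of amct_tensorflow.

```python
import os


def output_cfg_generated(output_cfg):
    """Return True if the file at output_cfg exists and is not empty."""
    return os.path.isfile(output_cfg) and os.path.getsize(output_cfg) > 0


if __name__ == "__main__":
    # Run this after auto_channel_prune_search has finished.
    print(output_cfg_generated("./tmp/output.cfg"))
```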

Examples

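import numpy as np
import amct_tensorflow as amct

# graph is the user's tf.Graph, and operation_name_1 and operation_name_2 are
# the names (strings) of its output nodes. They must be defined beforehand.

# Construct feed_dict data. float32 is assumed here as the input dtype.
input_data = np.random.uniform(-10, 10, (2, 14, 14, 64)).astype(np.float32)
feed_dict = [{'input:0': input_data}]

# Run the search and write the channel pruning configuration to output_cfg.
amct.auto_channel_prune_search(
    graph=graph,
    output_nodes=[operation_name_1, operation_name_2],
    config='./tmp/sample.cfg',
    input_data=feed_dict,
    output_cfg='./tmp/output.cfg',
    sensitivity='TaylorLossSensitivity',
    search_alg='GreedySearch')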