Search Process

This section describes the API call sequence and provides an example of auto channel pruning search.

API Call Sequence

Figure 1 shows the API call sequence. The user implements the operations shown in blue; those shown in gray are implemented by calling AMCT APIs. For a sparsity example, see Sample List.

Prepare a TensorFlow training model and an auto channel pruning search configuration file, then call auto_channel_prune_search to search for channels to prune based on the target compression ratio, the sparsity sensitivity of each channel, and the sparsity gain. The search produces a simplified configuration file for channel pruning. You can customize the sensitivity and search_alg modules or use the default implementations built into the API.

  • The sensitivity module calculates the sparsity sensitivity of each channel.
  • The search_alg module implements a process of searching for sparse channels based on the channel sensitivity and channel sparsity gain.
Figure 1 Calling Process
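The two modules cooperate around one trade-off: a channel is a good pruning candidate when its sparsity gain is high and its sensitivity is low. The following is a minimal sketch of that ranking idea in plain Python; the scores and the gain/sensitivity benefit ratio are assumptions for exposition, not the formula AMCT actually uses.

```python
# Rank channels by an assumed benefit score: sparsity gain over sensitivity.
# Higher benefit = more savings per unit of accuracy risk.
sensitivity = [0.02, 0.50, 0.10, 0.01]  # accuracy impact of pruning each channel
gain = [1.0, 1.0, 2.0, 0.8]             # e.g. compute saved if the channel is pruned

benefit = [g / s for g, s in zip(gain, sensitivity)]
ranked = sorted(range(len(benefit)), key=lambda i: benefit[i], reverse=True)
# ranked[0] is the most attractive channel to prune first
```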

Examples

This example demonstrates how to use AMCT to automatically search for sparse channels. During this process, you need to input a TensorFlow training graph and calibration data. You can customize the sensitivity and search_alg modules.

  • Take the following steps to get started. Update the sample code based on your situation.
  • Tweak the arguments passed to AMCT API calls as required.
  1. Import the AMCT package and set the log level.
    import amct_tensorflow as amct
    amct.set_logging_level(print_level="info", save_level="info")
    
  2. (Optional) Implement the sensitivity module to obtain the sensitivity of each channel at each layer, providing input for the subsequent search algorithm. For details, see the default sensitivity module in the AMCT installation directory.
    The TaylorLossSensitivity method in amct_tensorflow/interface/auto_channel_prune_search.py is structured as follows:
    from amct.common.auto_prune.sensitivity_base import SensitivityBase

    class Sensitivity(SensitivityBase):
        def __init__(self):
            pass

        def setup_initialization(self, graph_tuple, input_data, test_iteration, output_nodes=None):
            # Perform any necessary initialization.
            # graph_tuple: (graph, graph_info)
            pass

        def get_sensitivity(self, search_records):
            # Compute the sensitivity and write it to the records.
            pass
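    For intuition, a first-order Taylor estimate of a channel's loss impact can be computed as the magnitude of gradient × activation summed over the channel. The toy below is self-contained plain Python; the data layout and the reduction are illustrative assumptions, not the shipped TaylorLossSensitivity implementation.

```python
def taylor_channel_sensitivity(activations, gradients):
    """First-order Taylor estimate of per-channel sensitivity.

    activations, gradients: list of channels, each a list of values.
    Removing channel c changes the loss by roughly |sum(grad * act)|.
    """
    return [abs(sum(g * a for g, a in zip(grad, act)))
            for act, grad in zip(activations, gradients)]

acts = [[1.0, 2.0], [0.5, -0.5]]
grads = [[0.1, 0.1], [1.0, 1.0]]
scores = taylor_channel_sensitivity(acts, grads)
# channel 0: |0.1*1.0 + 0.1*2.0| ~= 0.3; channel 1: |0.5 - 0.5| = 0.0
```

    A channel whose gradient-activation products cancel out (channel 1 above) is estimated to barely affect the loss, so the search can prune it first.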
    
  3. (Optional) Implement the search_alg module. This requires you to implement the channel_prune_search callback, which searches for sparse channels based on the channel sensitivity and channel gain. For details, see the GreedySearch method in /amct_tensorflow/common/auto_prune/search_channel_base.py in the AMCT installation directory.
    from amct.common.auto_prune.search_channel_base import SearchChannelBase

    class Search(SearchChannelBase):
        def __init__(self):
            # Initialization
            pass

        def channel_prune_search(self, graph_info, search_records, prune_config):
            """
            Input:
                graph_info: dict containing the channel count and bit complexity
                    of each operator in the graph; can be used to compute the
                    compression ratio.
                search_records: protobuf object containing the sparse layers to
                    be searched.
                prune_config: triplet of target compression ratio (float), Ascend
                    affinity optimization switch (bool), and maximum single-layer
                    sparsity ratio (float).
            Output:
                dict whose keys are the names of the layers to be sparsified and
                whose values are lists of 0/1 flags indicating whether each
                channel should be pruned.
            """
            pass
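    The kind of greedy loop channel_prune_search might run can be sketched as follows: prune the least-sensitive channels first, stop once the target compression ratio is reached, and never exceed the per-layer sparsity cap. Everything below (the ratio model, the cap handling, the function name) is an illustrative assumption, not the shipped GreedySearch.

```python
def greedy_channel_search(sens, target_ratio, max_layer_ratio):
    """sens: {layer: [per-channel sensitivity]}.
    Returns {layer: [0/1]} where 1 marks a channel to prune."""
    masks = {layer: [0] * len(s) for layer, s in sens.items()}
    total = sum(len(s) for s in sens.values())
    # Global candidate list, least sensitive first.
    candidates = sorted(
        (score, layer, idx)
        for layer, s in sens.items() for idx, score in enumerate(s))
    pruned = 0
    for score, layer, idx in candidates:
        if pruned / total >= target_ratio:
            break  # target compression ratio reached
        # Respect the maximum single-layer sparsity ratio.
        if (sum(masks[layer]) + 1) / len(masks[layer]) > max_layer_ratio:
            continue
        masks[layer][idx] = 1
        pruned += 1
    return masks

sens = {"conv1": [0.9, 0.1, 0.5, 0.05], "conv2": [0.3, 0.2]}
masks = greedy_channel_search(sens, target_ratio=0.5, max_layer_ratio=0.5)
# conv1 -> [0, 1, 0, 1], conv2 -> [0, 1]
```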
    
  4. (Optional) Build a graph, read the trained parameters, and run inference on the graph in the TensorFlow environment to validate the inference script and environment setup. (Update the sample code based on your situation.)

    This step is recommended because it ensures that the source model runs properly and delivers acceptable inference accuracy. You can use a subset of the test dataset to improve efficiency.

    user_test_evaluate_model(evaluate_model, test_data)
    
  5. Create a training graph and construct calibration data. (Update the sample code based on your situation.)
    train_graph = user_load_train_graph()
    input_data = []
    for _ in range(test_iteration):
        input_data.append(user_load_feed_dict())
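    The calibration data is simply a list with one feed dict per iteration. The following is a hypothetical stand-in for user_load_feed_dict; the tensor name and shape are made up for illustration and must match your own model's inputs.

```python
import random

def load_feed_dict():
    # Stand-in for user_load_feed_dict(): map each input tensor name
    # to one calibration batch (name "input:0" and shape are assumptions).
    return {"input:0": [[random.random() for _ in range(8)]]}

test_iteration = 4
input_data = [load_feed_dict() for _ in range(test_iteration)]
# input_data holds test_iteration feed dicts, consumed one per iteration
```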
    
  6. Perform auto channel pruning search by calling AMCT.
    output_prune_cfg = './prune.cfg'
    amct.auto_channel_prune_search(
        graph=train_graph, 
        output_nodes=user_model_outputs, 
        config=cfg_file, 
        input_data=input_data,
        output_cfg=output_prune_cfg,
        sensitivity='TaylorLossSensitivity',
        search_alg='GreedySearch')