Search Process
This section describes the API call sequence of auto channel pruning search and provides a calling example.
API Call Sequence
Figure 1 shows the API call sequence. You implement the operations shown in blue. The operations shown in gray are implemented by AMCT APIs. For a sparsity sample, see Sample List.
Prepare a PyTorch model and an auto channel pruning search configuration file, and then call auto_channel_prune_search. The API searches for channels to prune based on the target compression ratio, the sparsity sensitivity of each channel, and the sparsity gain. It outputs a simplified configuration file for channel pruning. You can customize the sensitivity and search_alg modules or use the default implementations provided by the API.
- The sensitivity module calculates the sparsity sensitivity of each channel.
- The search_alg module searches for the channels to sparsify based on each channel's sensitivity and sparsity gain.
Calling Example
This example demonstrates how to use AMCT to automatically search for sparse channels. During this process, you need to input a PyTorch model and calibration data. You can customize the sensitivity and search_alg modules.
- Update the sample code based on your situation.
- Adjust the arguments passed to AMCT API calls as required.
To get started, take the following steps:
- Import the AMCT package and set the log level (see Post-installation Actions for details).
```python
import amct_pytorch as amct
```
- (Optional) Implement the sensitivity module to calculate the sensitivity of each channel in each layer. The search algorithm then uses these values. For reference, see the default sensitivity module, the TaylorLossSensitivity method in amct_pytorch/auto_channel_prune_search.py under the AMCT installation directory. A custom module has the following structure:
```python
# SensitivityBase ships with AMCT. Verify this module path against your installed version.
from amct_pytorch.common.auto_channel_prune.sensitivity_base import SensitivityBase


class Sensitivity(SensitivityBase):
    def __init__(self):
        pass

    def setup_initialization(self, graph_tuple, input_data, test_iteration):
        # Necessary initialization.
        # graph_tuple is a (graph, graph_info) tuple.
        pass

    def get_sensitivity(self, search_records):
        # Calculate the sensitivity of each channel and write the results to search_records.
        pass
```
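To illustrate what a per-channel sensitivity score can look like, the following standalone sketch computes a first-order Taylor importance for each output channel. This is the absolute value of the sum of weight × gradient over the channel's weights, which approximates the loss change caused by zeroing that channel. This common criterion from the pruning literature is assumed here for illustration only. It is not necessarily what AMCT's TaylorLossSensitivity computes, and it does not use the SensitivityBase interface or search_records. The function name `taylor_channel_sensitivity` and the nested-list weight layout are hypothetical.

```python
from typing import List


def taylor_channel_sensitivity(
    weights: List[List[float]], grads: List[List[float]]
) -> List[float]:
    """Hypothetical helper: first-order Taylor importance per output channel.

    weights[c] and grads[c] hold the flattened weights of output channel c and
    the gradients of the loss with respect to those weights. Zeroing channel c
    changes the loss by approximately -sum(w * dL/dw), so its magnitude is used
    as the sensitivity.
    """
    if len(weights) != len(grads):
        raise ValueError("weights and grads must have the same number of channels")
    scores = []
    for w_c, g_c in zip(weights, grads):
        if len(w_c) != len(g_c):
            raise ValueError("channel weight and gradient sizes differ")
        scores.append(abs(sum(w * g for w, g in zip(w_c, g_c))))
    # Normalize within the layer so that scores sum to 1 (an illustrative choice).
    total = sum(scores)
    return [s / total for s in scores] if total > 0 else scores
```

A channel whose weight-gradient products cancel out (score 0) is estimated to have no first-order effect on the loss, which makes it a natural pruning candidate under this criterion.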
- (Optional) Implement the search_alg module. To do so, implement the channel_prune_search callback, which selects the channels to sparsify based on each channel's sensitivity and sparsity gain. For reference, see the default search_alg module, the GreedySearch method in /amct_pytorch/common/auto_prune/search_channel_base.py under the AMCT installation directory. A custom module has the following structure:
```python
# Module path follows /amct_pytorch/common/auto_prune/search_channel_base.py; verify it against your installation.
from amct_pytorch.common.auto_prune.search_channel_base import SearchChannelBase


class Search(SearchChannelBase):
    def __init__(self):
        # Initialization
        pass

    def channel_prune_search(self, graph_info, search_records, prune_config):
        """
        Input:
            graph_info: dict containing the number of channels and the bit complexity
                of each operator in the graph; use it to calculate the compression ratio.
            search_records: protobuf object containing the layers to be searched.
            prune_config: triplet of the target compression ratio (float), the Ascend
                affinity optimization switch (bool), and the maximum single-layer
                sparsity ratio (float).
        Output:
            dict. Each key is the name of a layer to be searched; each value is a list
            of 0s and 1s indicating whether each channel should be sparsified.
        """
        pass
```
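The following standalone sketch shows one way a greedy search could turn per-channel sensitivity and per-channel cost into the output format documented above (layer name mapped to a list of 0s and 1s). It is not AMCT's GreedySearch. It does not read graph_info or search_records, and it ignores the Ascend affinity switch. It also makes two assumptions: the target compression ratio means the fraction of total cost to remove, and 1 in a mask marks a channel to prune. The function and parameter names are hypothetical.

```python
from typing import Dict, List


def greedy_channel_search(
    sensitivity: Dict[str, List[float]],
    channel_cost: Dict[str, List[float]],
    target_removed_fraction: float,
    max_layer_sparsity: float,
) -> Dict[str, List[int]]:
    """Hypothetical greedy search.

    Prunes the channels with the lowest sensitivity per unit of cost saved
    until the target fraction of total cost is removed. Returns
    {layer: [0 or 1 per channel]}, where 1 marks a channel to prune
    (an assumed convention).
    """
    total_cost = sum(sum(costs) for costs in channel_cost.values())
    budget = target_removed_fraction * total_cost

    # Rank every channel by sensitivity divided by the cost its removal saves.
    candidates = []
    for layer, scores in sensitivity.items():
        for idx, score in enumerate(scores):
            cost = channel_cost[layer][idx]
            if cost > 0:
                candidates.append((score / cost, layer, idx))
    candidates.sort()

    masks = {layer: [0] * len(scores) for layer, scores in sensitivity.items()}
    removed = 0.0
    for _, layer, idx in candidates:
        if removed >= budget:
            break
        mask = masks[layer]
        # Respect the per-layer sparsity cap and always keep at least one channel.
        limit = min(int(max_layer_sparsity * len(mask)), len(mask) - 1)
        if sum(mask) >= limit:
            continue
        mask[idx] = 1
        removed += channel_cost[layer][idx]
    return masks
```

Dividing sensitivity by cost favors channels that save a lot of computation at little accuracy risk. The per-layer cap prevents the search from emptying a single insensitive layer.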
- (Optional) Load the PyTorch model and run inference to check that the environment and the inference script work correctly. (Update the sample code based on your situation.) To reduce the running time, you can run this check on a subset of the test set.
```python
# Load the original PyTorch model (user-defined).
ori_model.load()
# Test the model.
user_test_model(ori_model, test_data, test_iterations)
```
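user_test_model is user-defined. The following framework-agnostic sketch shows one hypothetical implementation that reports top-1 accuracy over at most test_iterations batches. It assumes that test_data yields (inputs, labels) pairs and that calling the model on a batch returns one list of class scores per sample. With PyTorch, you would typically also call model.eval() and run the loop under torch.no_grad().

```python
from itertools import islice
from typing import Callable, Iterable, List, Sequence, Tuple

Batch = Tuple[Sequence, Sequence[int]]  # (inputs, labels), an assumed layout


def user_test_model(
    model: Callable[[Sequence], List[Sequence[float]]],
    test_data: Iterable[Batch],
    test_iterations: int,
) -> float:
    """Hypothetical evaluation helper: top-1 accuracy over at most
    test_iterations batches of test_data."""
    correct = total = 0
    for inputs, labels in islice(test_data, test_iterations):
        scores = model(inputs)
        for sample_scores, label in zip(scores, labels):
            # Index of the highest score is the predicted class.
            predicted = max(range(len(sample_scores)), key=sample_scores.__getitem__)
            correct += int(predicted == label)
            total += 1
    if total == 0:
        raise ValueError("test_data yielded no samples")
    return correct / total
```

Limiting the loop with islice means that a small test_iterations value is enough for the quick environment check described in this step.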
- Construct the calibration data. (Update the sample code based on your situation.)
```python
input_data = []
for _ in range(test_iteration):
    # user_load_feed_dict() is user-defined and returns one batch of calibration data.
    input_data.append(user_load_feed_dict())
```
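If your calibration batches already come from an iterable, such as a PyTorch DataLoader, you can build input_data without a separate loader function. The following sketch uses a hypothetical helper, build_calibration_data. It collects exactly test_iteration batches and fails early if the source runs out. The batch format (for example, a single input tensor) must match what the model's forward pass expects, and it is an assumption here.

```python
from itertools import islice
from typing import Any, Iterable, List


def build_calibration_data(batches: Iterable[Any], test_iteration: int) -> List[Any]:
    """Hypothetical helper: collect exactly test_iteration calibration batches."""
    input_data = list(islice(batches, test_iteration))
    if len(input_data) < test_iteration:
        raise ValueError(
            f"needed {test_iteration} calibration batches, got {len(input_data)}"
        )
    return input_data
```

For example, with a DataLoader that yields (images, labels) pairs, you could call `build_calibration_data((images for images, _ in loader), test_iteration)`. The helper reads only as many batches as it needs.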
- Perform auto channel pruning search by calling AMCT.
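```python
# Placeholder: path to the auto channel pruning search configuration file you prepared.
cfg_file = './auto_channel_prune_search.cfg'
# Path of the simplified channel pruning configuration file to generate.
output_prune_cfg = './prune.cfg'
amct.auto_channel_prune_search(
    model=ori_model,
    config=cfg_file,
    input_data=input_data,
    output_cfg=output_prune_cfg,
    # Default modules; see the API reference for passing custom implementations.
    sensitivity='TaylorLossSensitivity',
    search_alg='GreedySearch')
```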
