接口调用流程如图1所示,蓝色部分为用户实现,灰色部分为用户调用AMCT提供的API实现。稀疏示例请参见获取更多样例>resnet_v1_50。
用户准备好TensorFlow的训练模型、自动通道稀疏搜索配置文件和校准数据,调用auto_channel_prune_search,根据压缩率、各通道的稀疏敏感度以及稀疏收益,执行自动通道稀疏搜索,得到可用作通道稀疏的简易配置文件。其中,sensitivity模块与search_alg模块用户可以自定义或者使用接口内部默认方法。
本示例演示了使用AMCT进行自动通道稀疏搜索的流程,该过程需要用户传入tensorflow训练模式的图与校准数据,用户可选择自定义实现sensitivity模块与search_alg模块。
import amct_tensorflow as amct amct.set_logging_level(print_level="info", save_level="info")
from amct.common.auto_prune.sensitivity_base import SensitivityBase class Sensitivity(SensitivityBase) def __init__(self) pass def setup_initialization(self, graph_tuple, input_data, test_iteration, output_nodes=None): # 必要的初始化 # graph_tuple (graph, graph_info) pass def get_sensitivity(self, search_records): # 获取敏感度方法,计算后写到record中 pass
from amct.common.auto_prune.search_channel_base import SearchChannelBase class Search(SearchChannelBase) def __init__(self) # 初始化 pass def channel_prune_search(self, graph_info, search_records, prune_config): """ 输入: graph_info: dict,包含图中各算子的通道数量与比特复杂度信息,可用于计算压缩率 search_records: protobuf对象,包含待搜索的可稀疏层 prune_config: 三元组-目标压缩率(float)、昇腾亲和优化开关(bool)、单层最大稀疏率(float) 输出: dict,key为待搜索的可稀疏层层名,value为01组成的list,对应该通道是否应稀疏 """ pass
推荐执行该步骤,以确保原始模型可以完成推理且精度正常;执行该步骤时,可以使用部分测试集,减少运行时间。
user_test_evaluate_model(evaluate_model, test_data)
train_graph = user_load_train_graph() input_data = [] for _ in range(test_iteration): input_data.append(user_load_feed_dict())
output_prune_cfg = './prune.cfg' amct.auto_channel_prune_search( graph=train_graph, output_nodes=user_model_outputs, config=cfg_file, input_data=input_data, output_cfg=output_prune_cfg, sensitivity='TaylorLossSensitivity', search_alg='GreedySearch')