本节介绍自动混合精度搜索场景的接口调用流程和调用示例。
接口调用流程如图1所示,蓝色部分为用户实现,灰色部分为用户调用AMCT提供的API实现。
用户准备好TensorFlow的原始模型、自动混合精度配置文件和推理评估器(Evaluator),调用auto_mixed_precision_search,根据压缩率、量化位宽、量化敏感度以及计算复杂度信息,执行自动混合精度搜索,得到混合精度配置文件与可用于量化感知训练的简易配置文件。
其中Evaluator模块需要用户自定义,用来执行模型的推理,获取量化因子,dump数据(每一层的输入数据)等信息。
本示例演示了使用AMCT进行自动混合精度搜索的流程,该过程需要用户实现一个模型推理和校准的评估器。
1 2 3 |
import amct_tensorflow as amct from amct_tensorflow.common.auto_calibration import AutoCalibrationEvaluatorBase amct.set_logging_level(print_level="info", save_level="info") |
上述回调函数的入参要和基类AutoCalibrationEvaluatorBase保持一致。其中:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
class ModelEvaluator(AutoCalibrationEvaluatorBase): # The evaluator for model def __init__(self, *args, **kwargs): # 做成员变量初始化 # 设置预期精度损失,此处请替换为具体的数值 self.diff = expected_acc_loss pass def calibration(self, graph, outputs, batch_num): """ 对一张图做量化校准前向推理 graph:tensorflow.Graph类型, 对图graph做前向推理 outputs: 列表类型,图graph的输出,在推理过程需要获取的输出 batch_num: int类型,前向推理的batch数目,与量化配置的batch_num一致 """ pass def evaluate(self, graph, outputs, iterations): # pylint: disable=R0914 """ 对一张图做量化校准前向推理 graph:tensorflow.Graph类型, 对图graph做前向推理 outputs: 列表类型,图graph的输出,在推理过程需要获取的输出 iterations: int类型,前向推理的batch数目 """ pass |
1 2 3 4 5 6 |
evaluator = ModelEvaluator() 或者 evaluator = amct.GraphEvaluator( data_dir="./data/input_bin/", input_shape="input:32,3,224,224", data_types="float32") |
1 2 3 4 5 6 7 8 9 10 11 12 |
model_file = './model/user_model.pb' outputs = ['user_model_outputs0', 'user_model_outputs1'] cfg_file = './configs/auto_mixed_precision.cfg' save_dir = './results/auto_mixed_precision' amct.auto_mixed_precision_search( model_file=model_file, outputs=outputs, amc_config=cfg_file, save_dir=save_dir, evaluator=evaluator, sensitivity='MseSimilarity') |