TensorFlow网络模型由于AMCT导致输出节点改变,如何通过修改量化脚本进行后续的量化动作

问题描述

使用AMCT调用quantize_model接口对用户的原始TensorFlow模型进行图修改时,由于插入了searchN层导致尾层输出节点发生改变。该场景下,需要用户根据提示信息,将推理时的输出节点替换为新的输出节点的名称;AMCT量化过程中的日志信息给出了网络输出节点变化前后的节点名称,需要用户根据提示信息, 自行修改量化脚本。

进行图修改时,导致尾层输出节点发生改变的场景有如下几种情况:

脚本修改

如果调用quantize_model接口对用户的原始TensorFlow模型进行图修改时,由于在网络最后插入了searchN层导致尾层输出节点发生改变,需要用户根据日志信息,修改量化脚本,将网络推理过程中的输出节点替换为新的节点名称,修改方法如下:

修改前的量化脚本(如下脚本只是样例,请以实际量化的模型为准):

import tensorflow as tf
import amct_tensorflow as amct

def load_pb(model_name):
    with tf.gfile.GFile(model_name, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

def main():
    # 网络pb文件的名字
    model_name = './pb_model/case_1_1.pb'
    # 网络量化推理输出节点的名字
    infer_output_name = 'Add:0'
    # 网络保存量化模型的输出节点的名字
    save_output_name = 'Add:0'


    # 载入网络的pb文件
    load_pb(model_name)
    # 获取网络的图结构
    graph = tf.get_default_graph()


    # 生成量化配置文件
    amct.create_quant_config(
        config_file='./configs/config.json',
        graph=graph)
    # 插入量化相关算子
    amct.quantize_model(
        graph=graph,
        config_file='./configs/config.json',
        record_file='./configs/record_scale_offset.txt')


    # 执行网络的推理过程
    with tf.Session() as sess:
        output_tensor = graph.get_tensor_by_name(infer_output_name)
        sess.run(tf.global_variables_initializer())
        sess.run(output_tensor)


    # 保存量化后的pb模型文件
    amct.save_model(
        pb_model=model_name,
        outputs=[save_output_name[:-2]],
        record_file='./configs/record_scale_offset.txt',
        save_path='./pb_model/case_1_1')


if __name__ == '__main__':
    main()

修改后的量化脚本:

import tensorflow as tf
import amct_tensorflow as amct

def load_pb(model_name):
    with tf.gfile.GFile(model_name, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

def main():
    # 网络pb文件的名字
    model_name = './pb_model/case_1_1.pb'
    # 网络量化推理输出节点的名字,需要替换为日志打印的网络输出节点变化后的节点名称
    infer_output_name = 'search_n_quant/search_n_quant_SEARCHN/Identity:0' 
    # 网络保存量化模型的输出节点的名字
    save_output_name = 'Add:0'


    # 载入网络的pb文件
    load_pb(model_name)
    # 获取网络的图结构
    graph = tf.get_default_graph()


    # 生成量化配置文件
    amct.create_quant_config(
        config_file='./configs/config.json',
        graph=graph)
    # 插入量化相关算子
    amct.quantize_model(
        graph=graph,
        config_file='./configs/config.json',
        record_file='./configs/record_scale_offset.txt')


    # 执行网络的推理过程
    with tf.Session() as sess:
        output_tensor = graph.get_tensor_by_name(infer_output_name)
        sess.run(tf.global_variables_initializer())
        sess.run(output_tensor)


    # 保存量化后的pb模型文件
    amct.save_model(
        pb_model=model_name,
        outputs=[save_output_name[:-2]],
        record_file='./configs/record_scale_offset.txt',
        save_path='./pb_model/case_1_1')


if __name__ == '__main__':
    main()