下载
中文
注册
我要评分
文档获取效率
文档正确性
内容完整性
文档易理解
在线提单
论坛求助
昇腾小AI

权重更新(在线推理)

背景

推理过程的同时,训练服务器不停训练得到新的权重,而推理部分希望直接更新最新的权重,不希望再走一遍保存pb、编译成离线模型、然后执行离线模型的流程,这种场景可以采用在线推理模式直接更新权重。

本节介绍TensorFlow在线推理场景下,如何更新权重。

整体流程

图1 权重更新流程示意图

图1所示,支持循环地更新权重与执行推理。如果多次执行过程batch_size等不发生变化,理论上虚线部分只执行一次。主要流程:

  1. 获取在线推理模型和权重信息,例如从ckpt文件中加载,实际更新用的权重则来自于外部的key-value;
  2. 构造权重更新图,包含需要更新的变量和赋值算子,如Assigin;
  3. 执行权重更新图, 将第1步获取到的key-value更新到对应权重;
  4. 执行在线推理流程。

样例参考

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import tensorflow as tf
import time
import numpy as np
from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig
from npu_bridge.estimator import npu_ops

class TestPolicy(object):
    def __init__(self, ckpt_path):
        # NPU模型编译和优化配置
        # --------------------------------------------------------------------------------
        config = tf.compat.v1.ConfigProto()
        custom_op = config.graph_options.rewrite_options.custom_optimizers.add()
        custom_op.name = "NpuOptimizer"
        # 配置1: 选择在Ascend NPU上执行推理
        custom_op.parameter_map["use_off_line"].b = True

        # 配置2:在线推理场景下建议保持默认值force_fp16,使用float16精度推理,以获得较优的性能
        custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("force_fp16")

        # 配置3:关闭remapping
        config.graph_options.rewrite_options.remapping = RewriterConfig.OFF

        # 配置4:设置graph_run_mode为推理
        custom_op.parameter_map["graph_run_mode"].i = 0

        # 配置5:指定AICORE引擎的并行度为4
        custom_op.parameter_map["stream_max_parallel_num"].s = tf.compat.as_bytes("AIcoreEngine:4")
        # --------------------------------------------------------------------------------

        # 初始化操作
        self.sess = tf.compat.v1.Session(config=config)
        self.ckpt_path = ckpt_path
        # 加载推理模型
        self.load_graph()
        self.graph = self.sess.graph

    def load_graph(self):
        '''
        从ckpt加载推理模型,获取权重的信息,并构造权重更新图
        '''
        saver = tf.compat.v1.train.import_meta_graph(self.ckpt_path + '.meta')
        saver.restore(self.sess, self.ckpt_path)
        self.vars = tf.compat.v1.trainable_variables()
        self.var_placeholder_dict = {}
        self.var_id_to_name = {}
        self.update_op = []
        for id, var in enumerate(self.vars):
            self.var_placeholder_dict[var.name] = tf.compat.v1.placeholder(var.dtype, shape=var.get_shape(), name=("PlaceHolder_" + str(id)))
            self.var_id_to_name[id] = var.name
            self.update_op.append(tf.compat.v1.assign(var, self.var_placeholder_dict[var.name]))
        self.update_op = tf.group(*self.update_op)

        # 实际场景的权重key-value来自于训练服务器,这里保存一份ckpt中的权重仅用于示例
        self.key_value = self.get_dummy_weights_for_test()

    def unload(self):
        '''
        关闭session,释放资源
        '''
        print("====== start to unload ======")
        self.sess.close()

    def get_dummy_weights_for_test(self):
        '''
        从ckpt中获取权重的信息,并构造权重更新图
        :return: 权重 key-value
        :NOTES: 实际场景的权重key-value来自于训练服务器,这里返回的ckpt中的权重仅用于示例
        '''
        weights_data = self.sess.run(self.vars)
        weights_key_value = {}
        for id, var in enumerate(weights_data):
            weights_key_value[self.var_id_to_name[id]] = var
        return weights_key_value

    def get_weights_key_value(self):
        '''
        获取权重key-value
        :return: 权重 key-value
        :NOTES: 实际场景的权重key-value来自于训练服务器,这里返回保存的权重仅用于示例
        '''
        return self.key_value

    def update_weights(self):
        '''
        更新权重
        '''
        feed_dict = {}
        weights_key_value = self.get_weights_key_value()
        for key, weight in weights_key_value.items():
            feed_dict[self.var_placeholder_dict[key]] = weight
        self.sess.run(self.update_op, feed_dict=feed_dict)

    def infer(self, input_image):
        '''
        执行推理流程
        :param: input_image 输入的数据,示例中为图像数据
        :return: output 推理结果, 示例中为labels
        '''
        image = self.graph.get_operation_by_name('Placeholder').outputs[0]
        label_output = self.graph.get_operation_by_name('accuracy/ArgMax').outputs[0]
        output = self.sess.run([label_output], feed_dict={image: input_image})
        return output


def prepare_input_data(batch):
    '''
    推理的输入数据
    :param: batch 数据batch_size
    :return: 推理数据
    '''
    image = 255 * np.random.random([batch, 784]).astype('float32')
    return image

if __name__ == "__main__":
    batch_size = 16
    ckpt_path = "./mnist_deep_model/mnist_deep_model"
    policy = TestPolicy(ckpt_path)
    update_count = 10
    for i in range(update_count):
        update_start = time.time()
        policy.update_weights()
        update_consume = time.time() - update_start
        print("Update weight time cost: {} ms".format(update_consume * 1000))
        test_count = 20
        input_data = prepare_input_data(batch_size)
        start_time = time.time()
        for i in range(test_count):
            output = policy.infer(input_data)
            print("result is ", output)
        time_consume = (time.time() - start_time) / (test_count)
        print("Inference average time cost: {} ms \n".format(time_consume * 1000))

    policy.unload()
    print("====== end of test ======")
搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词