Weight Update (Online Inference)
Background
During the inference phase, the training server continuously trains the model to update the weights. Repeating the complete offline procedure for every update — saving the .pb file, compiling it into an offline model, and then executing the offline model — is costly. In this scenario, the online inference mode can be used to update the latest weights directly during inference.
This section describes how to update weights during online inference in TensorFlow.
Basic Workflow
As shown in Figure 1, weight update and inference execution can be implemented using loops. If the value of batch_size does not change across executions, the operations shown in dashed lines do not need to be repeated. The major steps are:
- Obtain the online inference model and weight information from the .ckpt file, for example. The actual weights to be updated come from the external key-value pairs.
- Construct a weight update graph, including the variables to be updated and the corresponding update operators, for example, Assign.
- Execute the weight update graph to apply the key-value pairs obtained in step 1 to the corresponding weights.
- Start online inference.
Samples
import time

import numpy as np
import tensorflow as tf
from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig
from npu_bridge.estimator import npu_ops


class TestPolicy(object):
    """Online-inference policy for the Ascend NPU.

    Restores a model from a checkpoint, builds a weight-update graph
    (one placeholder feeding an ``Assign`` op per trainable variable)
    so that new weights received as key-value pairs can be written into
    the running session without saving a .pb file and recompiling an
    offline model, and runs inference in the same session.
    """

    def __init__(self, ckpt_path):
        """Configure the NPU session and load the checkpoint.

        :param ckpt_path: checkpoint path prefix; ``ckpt_path + '.meta'``
            must exist alongside the checkpoint data files.
        """
        # Set the NPU configurations for model compilation and optimization.
        # --------------------------------------------------------------------------------
        config = tf.compat.v1.ConfigProto()
        custom_op = config.graph_options.rewrite_options.custom_optimizers.add()
        custom_op.name = "NpuOptimizer"
        # Configuration 1: Perform inference on the Ascend NPU.
        custom_op.parameter_map["use_off_line"].b = True
        # Configuration 2: In the online inference scenario, retaining the default
        # precision selection force_fp16 achieves better performance.
        custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("force_fp16")
        # Configuration 3: Disable remapping.
        config.graph_options.rewrite_options.remapping = RewriterConfig.OFF
        # Configuration 4: Set graph_run_mode to inference.
        custom_op.parameter_map["graph_run_mode"].i = 0
        # Configuration 5: Set the AI Core parallelism degree to 4.
        custom_op.parameter_map["stream_max_parallel_num"].s = tf.compat.as_bytes("AIcoreEngine:4")
        # --------------------------------------------------------------------------------

        # Initialization
        self.sess = tf.compat.v1.Session(config=config)
        self.ckpt_path = ckpt_path
        # Load the model.
        self.load_graph()
        self.graph = self.sess.graph

    def load_graph(self):
        """Load the model from the .ckpt file, obtain the weight information,
        and construct the weight update graph.

        For every trainable variable a placeholder and an ``Assign`` op are
        created; all assigns are grouped into ``self.update_op`` so one
        ``session.run`` call updates every weight.
        """
        saver = tf.compat.v1.train.import_meta_graph(self.ckpt_path + '.meta')
        saver.restore(self.sess, self.ckpt_path)
        self.vars = tf.compat.v1.trainable_variables()
        self.var_placeholder_dict = {}
        self.var_id_to_name = {}
        update_ops = []
        # NOTE: the loop index is named `idx` (not `id`) to avoid shadowing
        # the Python builtin.
        for idx, var in enumerate(self.vars):
            self.var_placeholder_dict[var.name] = tf.compat.v1.placeholder(
                var.dtype, shape=var.get_shape(), name=("PlaceHolder_" + str(idx)))
            self.var_id_to_name[idx] = var.name
            update_ops.append(tf.compat.v1.assign(var, self.var_placeholder_dict[var.name]))
        self.update_op = tf.group(*update_ops)
        # The actual key-value weight pairs come from the training server.
        self.key_value = self.get_dummy_weights_for_test()

    def unload(self):
        """Close the session to destroy allocations."""
        print("====== start to unload ======")
        self.sess.close()

    def get_dummy_weights_for_test(self):
        """Read the current weights from the session, keyed by variable name.

        :return: key-value weight pairs
        :NOTES: The actual key-value weight pairs come from the training
            server; this helper only produces stand-in data for testing.
        """
        weights_data = self.sess.run(self.vars)
        weights_key_value = {}
        for idx, weight in enumerate(weights_data):
            weights_key_value[self.var_id_to_name[idx]] = weight
        return weights_key_value

    def get_weights_key_value(self):
        """Obtain the key-value weight pairs.

        :return: key-value weight pairs
        :NOTES: The actual key-value weight pairs come from the training server.
        """
        return self.key_value

    def update_weights(self):
        """Update the weights.

        Feeds each weight array through its matching placeholder and runs
        the grouped assign op, writing the new values into the session.
        """
        feed_dict = {}
        weights_key_value = self.get_weights_key_value()
        for key, weight in weights_key_value.items():
            feed_dict[self.var_placeholder_dict[key]] = weight
        self.sess.run(self.update_op, feed_dict=feed_dict)

    def infer(self, input_image):
        """Start inference.

        :param input_image: image data; random MNIST-shaped data in the example.
        :return: output inference result, which is labels in the example.
        """
        image = self.graph.get_operation_by_name('Placeholder').outputs[0]
        label_output = self.graph.get_operation_by_name('accuracy/ArgMax').outputs[0]
        output = self.sess.run([label_output], feed_dict={image: input_image})
        return output


def prepare_input_data(batch):
    """Generate random input data for inference.

    :param batch: batch size (number of images to generate).
    :return: inference data of shape (batch, 784), float32 values in [0, 255).
    """
    image = 255 * np.random.random([batch, 784]).astype('float32')
    return image


if __name__ == "__main__":
    batch_size = 16
    ckpt_path = "./mnist_deep_model/mnist_deep_model"
    policy = TestPolicy(ckpt_path)

    # Repeatedly update the weights, timing each update.
    update_count = 10
    for _ in range(update_count):
        update_start = time.time()
        policy.update_weights()
        update_consume = time.time() - update_start
        print("Update weight time cost: {} ms".format(update_consume * 1000))

    # Run inference repeatedly and report the average latency.
    test_count = 20
    input_data = prepare_input_data(batch_size)
    start_time = time.time()
    for _ in range(test_count):
        output = policy.infer(input_data)
        print("result is ", output)
    time_consume = (time.time() - start_time) / (test_count)
    print("Inference average time cost: {} ms \n".format(time_consume * 1000))

    policy.unload()
    print("====== end of test ======")
