Model Freezing
Model Conversion and Storage with sess.run()
For TensorFlow training with sess.run, saver = tf.train.Saver() and saver.save() are used to save models. The following files are generated after each saver.save() call:
- checkpoint: lists the latest checkpoint file and other checkpoint files.
- model.ckpt.data-00000-of-00001: saves the current parameters, that is, the weight data.
- model.ckpt.index: saves the current parameter names, that is, an index that maps each variable name to its data in the .data file.
- model.ckpt.meta: saves the current graph structure.
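The checkpoint file in the list above is a small text file. The pure-Python sketch below reads it and lists the companion files written for one prefix, to make the file layout concrete. In real scripts, tf.train.latest_checkpoint does this job. The helper names (parse_checkpoint_state, read_checkpoint_state, companion_files) are hypothetical. The sketch also assumes two things: the text layout that tf.train.Saver writes (model_checkpoint_path and all_model_checkpoint_paths lines), and a single data shard (-00000-of-00001).

```python
import os
import re

# Matches lines such as: model_checkpoint_path: "model.ckpt-100"
_FIELD = re.compile(r'^\s*(model_checkpoint_path|all_model_checkpoint_paths):\s*"(.*)"\s*$')


def parse_checkpoint_state(text, ckpt_dir=""):
    """Parse the text of a `checkpoint` file into (latest_prefix, all_prefixes)."""
    latest, all_paths = None, []
    for line in text.splitlines():
        match = _FIELD.match(line)
        if not match:
            continue
        field, path = match.groups()
        # Relative prefixes are resolved against the checkpoint directory.
        if ckpt_dir and not os.path.isabs(path):
            path = os.path.join(ckpt_dir, path)
        if field == "model_checkpoint_path":
            latest = path
        else:
            all_paths.append(path)
    return latest, all_paths


def read_checkpoint_state(ckpt_dir):
    """Read ckpt_dir/checkpoint; return (None, []) if the file does not exist."""
    state_file = os.path.join(ckpt_dir, "checkpoint")
    if not os.path.isfile(state_file):
        return None, []
    with open(state_file, encoding="utf-8") as f:
        return parse_checkpoint_state(f.read(), ckpt_dir)


def companion_files(prefix):
    """Files one saver.save() call writes for a prefix (single data shard assumed)."""
    return [prefix + ".data-00000-of-00001", prefix + ".index", prefix + ".meta"]
```

The prefix returned here (for example model.ckpt-0) is the form that freeze_graph's input_checkpoint parameter takes.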
In this mode, the model weight data and model graph are saved separately. In the inference scenario, the TensorFlow freeze_graph function is used to combine the weight data and model graph into a .pb file, as shown in the dotted box in the following workflow diagram.
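Conceptually, freezing does two things:

- It keeps only the nodes that the requested outputs depend on.
- It replaces each variable with a constant holding its checkpoint value.

The toy sketch below illustrates that idea on a plain dictionary graph. The graph representation, node names, and op strings are hypothetical. The real freeze_graph operates on GraphDef protocol buffers.

```python
def extract_subgraph(graph, output_names):
    """Keep only the nodes the outputs depend on (walks inputs backwards)."""
    keep, stack = set(), list(output_names)
    while stack:
        name = stack.pop()
        if name in keep:
            continue
        if name not in graph:
            raise KeyError(f"node {name!r} is not in the graph")
        keep.add(name)
        stack.extend(graph[name]["inputs"])
    return {name: node for name, node in graph.items() if name in keep}


def freeze(graph, checkpoint, output_names):
    """Toy freeze: prune to the outputs, then turn variables into constants."""
    frozen = {}
    for name, node in extract_subgraph(graph, output_names).items():
        if node["op"] == "VariableV2":
            # Weight values come from the checkpoint, not from the graph file.
            frozen[name] = {"op": "Const", "inputs": [], "value": checkpoint[name]}
        else:
            frozen[name] = dict(node)
    return frozen


# Hypothetical example: training-only nodes (labels, loss) are pruned away.
graph = {
    "input": {"op": "Placeholder", "inputs": []},
    "w": {"op": "VariableV2", "inputs": []},
    "matmul": {"op": "MatMul", "inputs": ["input", "w"]},
    "output": {"op": "ArgMax", "inputs": ["matmul"]},
    "labels": {"op": "Placeholder", "inputs": []},
    "loss": {"op": "SparseSoftmaxCrossEntropyWithLogits", "inputs": ["matmul", "labels"]},
}
checkpoint = {"w": [[0.1, 0.2], [0.3, 0.4]]}
frozen = freeze(graph, checkpoint, ["output"])
```

The pruning step is also why output_node_names must name the inference output rather than the loss. Anything not reachable from the listed outputs is dropped from the frozen graph.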

Briefly, the workflow of generating a .pb file by using TensorFlow freeze_graph is as follows:
- Specify the model and checkpoint file path.
- Define the input node. For example, the input node for training is IteratorV2, whereas inference requires a placeholder as the input node.
- Define the output node. The output node for training is the loss value, whereas inference requires a node that precedes the loss computation, such as ArgMax or BiasAdd.
- The same operator may be processed differently in the training graph and the inference graph (for example, the BatchNorm and dropout operators). Therefore, you must call the model's inference path to generate a dedicated inference graph.
- BatchNorm: During training, the mean and variance are calculated from the current batch of samples. During inference, the moving averages of the mean and variance accumulated during training are used instead. Therefore, BatchNorm computes these statistics differently in the training and inference scenarios.
- Dropout: During inference, you must disable dropout by setting keep_prob (the second argument of npu_ops.dropout) to 1.0 so that every element is kept.
```python
if is_training:
    x = npu_ops.dropout(x, 0.65)
else:
    x = npu_ops.dropout(x, 1.0)
```
Find the entry point function of the inference test logic in the training script and set is_training to False to generate an inference graph.
```python
# Call the network to generate an inference graph. alexnet.inference is the entry point
# function of the inference test logic in the training script.
logits = alexnet.inference(inputs, version="he_uniform", num_classes=1000, is_training=False)
```
- Call tf.train.write_graph to write the preceding inference graph to a .pb file, which is then fed to the freeze_graph call.
- Call freeze_graph to merge the .pb graph file generated by tf.train.write_graph with the checkpoint file, producing a .pb graph file for inference.
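The BatchNorm and dropout differences described in the steps above can be illustrated with a small pure-Python sketch. This is not the NPU implementation.

- The dropout function uses the common "inverted dropout" formulation: kept elements are scaled by 1/keep_prob, so keep_prob=1.0 makes it an identity.
- batch_norm omits the learned scale and offset.
- The defaults momentum=0.99 and eps=1e-3 are assumptions chosen to mirror tf.layers.batch_normalization.

```python
import random


def dropout(xs, keep_prob, rng=None):
    """Inverted dropout: keep each element with probability keep_prob, scale kept ones by 1/keep_prob."""
    if not 0.0 < keep_prob <= 1.0:
        raise ValueError("keep_prob must be in (0, 1]")
    if keep_prob == 1.0:
        return list(xs)  # Inference setting: dropout becomes a no-op.
    rng = rng if rng is not None else random
    return [x / keep_prob if rng.random() < keep_prob else 0.0 for x in xs]


def batch_norm(xs, moving_mean, moving_var, is_training, momentum=0.99, eps=1e-3):
    """Normalize xs; return (outputs, new_moving_mean, new_moving_var)."""
    if is_training:
        # Training: statistics come from the current batch...
        mean = sum(xs) / len(xs)
        var = sum((x - mean) ** 2 for x in xs) / len(xs)
        # ...and the moving averages used later by inference are updated.
        moving_mean = momentum * moving_mean + (1 - momentum) * mean
        moving_var = momentum * moving_var + (1 - momentum) * var
    else:
        # Inference: reuse the moving averages; nothing is updated.
        mean, var = moving_mean, moving_var
    outputs = [(x - mean) / (var + eps) ** 0.5 for x in xs]
    return outputs, moving_mean, moving_var
```

In inference mode, each sample's output depends only on the stored statistics, not on the other samples in the batch. This is why the inference graph must be built with is_training=False.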
A code sample is provided as follows.
```python
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from npu_bridge.npu_init import *

# Import the model file.
import alexnet

# Specify the checkpoint path.
ckpt_path = "/opt/npu/model_ckpt/alexnet/model_8p/model.ckpt-0"

def main():
    tf.reset_default_graph()
    # Define the input node of the network.
    inputs = tf.placeholder(tf.float32, shape=[None, 224, 224, 3], name="input")
    # Call the network to generate an inference graph.
    logits = alexnet.inference(inputs, version="he_uniform", num_classes=1000, is_training=False)
    # Define the output node of the network.
    predict_class = tf.argmax(logits, axis=1, output_type=tf.int32, name="output")

    with tf.Session() as sess:
        # Save the graph to the model.pb file in the ./pb_model directory.
        # The model.pb file will be provided as input_graph to the following freeze_graph call.
        tf.train.write_graph(sess.graph_def, './pb_model', 'model.pb')
        # Generate a model file.
        freeze_graph.freeze_graph(
            input_graph='./pb_model/model.pb',  # Pass the model file generated using write_graph.
            input_saver='',
            input_binary=False,
            input_checkpoint=ckpt_path,  # Pass the checkpoint file generated in training.
            output_node_names='output',  # Consistent with the output node of the inference network.
            restore_op_name='save/restore_all',
            filename_tensor_name='save/Const:0',
            output_graph='./pb_model/alexnet.pb',  # Set to the name of the inference network to be generated.
            clear_devices=False,
            initializer_nodes='')
    print("done")

if __name__ == '__main__':
    main()
```
- input_graph: model file generated by write_graph.
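- input_binary: used together with input_graph. If it is set to True, input_graph is read as a binary GraphDef. If it is set to False, input_graph is read as a text-format GraphDef. It is False here because tf.train.write_graph writes a text-format graph by default.
- input_checkpoint: checkpoint path prefix, for example model.ckpt-0, without the .data-*, .index, or .meta suffix.
- output_node_names: name of the output node. This is a node name such as output, not a tensor name such as output:0. Use commas (,) to separate multiple names.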
- output_graph: path of the converted .pb file.
After the script finishes, the alexnet.pb file is generated in the ./pb_model/ directory. This file is the frozen .pb graph file, ready for inference.
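Because freeze_graph looks up output_node_names as node names, a tensor name such as output:0 is rejected: no node carries that name. The hypothetical helper below normalizes names before the call. It is a convenience sketch, not part of TensorFlow.

```python
def normalize_output_node_names(names):
    """Return the comma-separated node-name string that freeze_graph expects.

    Accepts a comma-separated string or an iterable of names. Strips whitespace
    and any ':<index>' tensor suffix (e.g. 'output:0' -> 'output'), and drops
    duplicates while keeping the original order.
    """
    if isinstance(names, str):
        names = names.split(",")
    result = []
    for name in names:
        name = name.strip()
        if not name:
            continue
        node = name.split(":", 1)[0]  # Node names never contain ':'.
        if node not in result:
            result.append(node)
    return ",".join(result)
```

For example, normalize_output_node_names("output:0") gives "output", which matches the name="output" given to tf.argmax in the sample above.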
Model Conversion and Storage with Estimator
Estimator can save models in ckpt or saved_model format. The ckpt method is similar to the sess.run method. Saving the model in saved_model format is recommended because it saves memory and avoids some possible errors. A saved_model is generally exported with estimator.export_savedmodel and has the following structure:
```
|--- saved_model.pb    # Network structure to be saved.
|--- variables         # Parameter weights, including all model variables (tf.Variable objects).
|---|--- variables.data-00000-of-00001
|---|--- variables.index
```
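export_savedmodel does not write this layout directly into the directory you pass. It creates a timestamped subdirectory inside it and returns that subdirectory's path. The input_saved_model_dir later given to freeze_graph must point at the subdirectory that contains saved_model.pb. The sketch below locates the newest complete export when the return value was not kept. The helper names are hypothetical, and the sketch assumes that export subdirectories are named with an all-digit timestamp.

```python
import os


def pick_latest_export(entries):
    """Return the largest all-digit (timestamp) name in entries, or None."""
    stamps = [e for e in entries if e.isdigit()]
    return max(stamps, key=int) if stamps else None


def is_saved_model_dir(path):
    """True if path holds saved_model.pb and a variables/ directory."""
    return (os.path.isfile(os.path.join(path, "saved_model.pb"))
            and os.path.isdir(os.path.join(path, "variables")))


def latest_saved_model_dir(export_dir_base):
    """Newest complete SavedModel under export_dir_base/<timestamp>/, or None."""
    if not os.path.isdir(export_dir_base):
        return None
    complete = [e for e in os.listdir(export_dir_base)
                if is_saved_model_dir(os.path.join(export_dir_base, e))]
    latest = pick_latest_export(complete)
    return os.path.join(export_dir_base, latest) if latest is not None else None
```

Names that are not all digits (for example, an unfinished temporary directory) are ignored, so only a finished export is returned.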
To convert saved_model into a .pb model for inference, perform the following steps:
- Define the input node. The input received by Estimator during training is in Iterator format, which facilitates iteration across epochs. Before saving a model for inference, use a placeholder to define a concrete input.
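```python
def serving_input_fn():
    # The placeholder's dtype, shape, and name describe the inference input.
    input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')
    # build_raw_serving_input_receiver_fn returns a function; calling it yields the
    # ServingInputReceiver that export_savedmodel expects from serving_input_fn.
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'input_ids': input_ids,
    })()
    return input_fn
```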
- Save the model in saved_model format.
Estimator can directly call the export_savedmodel function to save the model. This function builds the graph in inference (PREDICT) mode and restores the latest checkpoint, so you do not need to switch the mode manually.
```python
if FLAGS.do_export:
    estimator.evaluate(input_fn=eval_input_fn)  # eval_input_fn stands in for your own evaluation input_fn.
    estimator.export_savedmodel(FLAGS.output_dir, serving_input_fn)
```
- Freeze the .pb model.
Use TensorFlow's freeze_graph function to freeze the saved_model into a .pb model. Note that if your model contains NPU custom operators, import the NPU operator module in the code that calls freeze_graph (or in the freeze_graph source file if you run it as a standalone tool).
```python
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from npu_bridge.npu_init import *

freeze_graph.freeze_graph(
    input_saved_model_dir='savedModel',
    output_node_names='output',  # Consistent with the output node of the inference network.
    output_graph='test.pb',  # Name of the inference network to be generated.
    initializer_nodes='',
    input_graph=None,
    input_saver=False,
    input_binary=False,
    input_checkpoint=None,
    restore_op_name=None,
    filename_tensor_name=None,
    clear_devices=False,
    input_meta_graph=False)
```
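The input_ids placeholder defined in serving_input_fn has the shape [None, FLAGS.max_seq_length]. Every row fed to the exported or frozen model must therefore contain exactly max_seq_length token IDs. The hypothetical helper below pads or truncates a batch accordingly. The pad ID of 0 is an assumption (BERT-style vocabularies use 0 for [PAD]); use whatever your tokenizer defines.

```python
def pad_input_ids(batch, max_seq_length, pad_id=0):
    """Pad or truncate each token-ID list to exactly max_seq_length entries."""
    if max_seq_length <= 0:
        raise ValueError("max_seq_length must be positive")
    padded = []
    for ids in batch:
        row = list(ids[:max_seq_length])                      # Truncate rows that are too long.
        row.extend([pad_id] * (max_seq_length - len(row)))    # Pad rows that are too short.
        padded.append(row)
    return padded
```

The resulting nested list can then be fed to the input_ids input as a [batch_size, max_seq_length] int32 array.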
For details about the complete code example, see ModelZoo-TensorFlow.
