tf.compat.v1.train.Saver.restore

Function

Loads a saved model from a checkpoint file into a session.

Prototype

def restore(self, sess, save_path)

Parameters

| Parameter | Type | Mandatory/Optional | Description |
|---|---|---|---|
| sess | Session | Mandatory | TensorFlow session into which the model is loaded. |
| save_path | str | Mandatory | Path of the model checkpoint file to load. The local file system and HDFS are supported. The path length must be in the range [1, 1024]. When loading a model for multi-device training, you can set save_path to the same path on every device (the path that stores the multi-device training result); each device automatically loads its own parameters. |

NOTE:

The maximum size of a single file to be loaded is 500 GB. Concurrent reads may cause the system to run out of memory (OOM).
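The documented length range for save_path can be checked before restore is called. The following is a minimal sketch, not part of the Rec SDK or TensorFlow: `validate_save_path` and `MAX_PATH_LEN` are hypothetical names, and the sketch assumes the range [1, 1024] is counted in characters. It does not check whether the checkpoint exists.

```python
MAX_PATH_LEN = 1024  # upper bound of the documented length range [1, 1024]


def validate_save_path(save_path):
    """Return save_path unchanged if it is a str within the documented length range.

    Hypothetical helper: only the type and the length are checked here.
    Whether the checkpoint actually exists is left to restore() itself.
    """
    if not isinstance(save_path, str):
        raise TypeError("save_path must be a str")
    if not 1 <= len(save_path) <= MAX_PATH_LEN:
        raise ValueError(
            f"save_path length must be in [1, {MAX_PATH_LEN}], got {len(save_path)}"
        )
    return save_path
```

A call such as `saver.restore(sess, validate_save_path(path))` would then reject an empty or overlong path before any file is read.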

Return Value

  • Success: None
  • Failure: An exception is thrown.
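Because a failure is reported through an exception rather than a return code, callers may want to record which path failed before the exception propagates. The following is a minimal sketch under stated assumptions: `restore_or_report` is a hypothetical wrapper, and `saver` is any object that exposes `restore(sess, save_path)`, such as a `tf.compat.v1.train.Saver`. The exception types are not listed in this reference, so the sketch catches broadly and re-raises.

```python
import logging

logger = logging.getLogger(__name__)


def restore_or_report(saver, sess, save_path):
    """Call saver.restore and log the failing path before re-raising.

    Hypothetical wrapper: `saver` only needs a restore(sess, save_path) method.
    """
    try:
        result = saver.restore(sess, save_path)
    except Exception:
        # The concrete exception types are an unknown here, so catch broadly,
        # record the path, and let the caller decide how to recover.
        logger.exception("Failed to restore checkpoint from %s", save_path)
        raise
    # On success, restore returns None, which is passed through unchanged.
    return result
```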

Example

The following example briefly shows the usage. For details, see the little demo in the Rec SDK code repository.

# 1. Import required libraries.
import tensorflow as tf
from mx_rec.util.initialize import init, get_rank_id
# 2. Build a computational graph.
# ...
# 3. Create a saver.
saver = tf.compat.v1.train.Saver() 
# 4. Obtain the rank ID.
rank_id = get_rank_id()
# 5. Set the training step at which the checkpoint to be loaded was saved.
latest_step = 200
with tf.compat.v1.Session() as sess:
    saver.restore(sess, f"./saved-model/model-{rank_id}-{latest_step}")
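The example above hard-codes latest_step. As a rough sketch, the step could instead be derived from the file names in the checkpoint directory. This assumes the checkpoints sit on the local file system, follow the `model-{rank_id}-{step}` prefix used above, and include a `<prefix>.index` file as in the standard TensorFlow checkpoint layout; whether the Rec SDK writes the same files is an assumption. `find_latest_step` is a hypothetical helper, not an SDK function.

```python
import os
import re


def find_latest_step(model_dir, rank_id):
    """Return the largest step found for rank_id in model_dir, or None.

    Hypothetical helper: it matches local files named
    model-<rank_id>-<step>.index and ignores everything else.
    """
    pattern = re.compile(rf"^model-{rank_id}-(\d+)\.index$")
    steps = []
    for name in os.listdir(model_dir):
        match = pattern.match(name)
        if match:
            steps.append(int(match.group(1)))
    return max(steps) if steps else None
```

With such a helper, the path passed to restore would be built from the returned step instead of a fixed number, after checking that the result is not None.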

See Also

For details about the API call sequence and examples, see Porting and Training.