tf.compat.v1.train.Saver.restore
Function
Loads a saved model into a session by restoring its previously saved variables.
Prototype
```python
def restore(self, sess, save_path)
```
Parameters
| Parameter | Type | Mandatory/Optional | Description |
|---|---|---|---|
| sess | Session | Mandatory | TensorFlow session into which the model is loaded. |
| save_path | str | Mandatory | Path of the saved model to be loaded. NOTE: The maximum size of a single file to be loaded is 500 GB. Reading files concurrently may cause the system to run out of memory (OOM). |
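The size limit in the note on `save_path` can be checked before calling `restore`. The minimal sketch below lists the checkpoint files for a given prefix that exceed the limit. It assumes the files follow TensorFlow's usual `<save_path>.index` / `<save_path>.data-*` / `<save_path>.meta` naming and interprets 500 GB as 500 GiB; `files_over_limit` is a hypothetical helper, and any additional files that mx_rec may write elsewhere are not covered.

```python
import glob
import os

# Assumption: "500 GB" is interpreted as 500 GiB (500 * 1024**3 bytes).
MAX_FILE_BYTES = 500 * 1024 ** 3


def files_over_limit(save_path, limit_bytes=MAX_FILE_BYTES):
    """Return the files named <save_path>.* whose size exceeds limit_bytes.

    Assumption: the checkpoint consists of TensorFlow-style files such as
    <save_path>.index and <save_path>.data-00000-of-00001. Matching on
    "<save_path>." (with the dot) avoids picking up, e.g., model-0-2000.*
    when the prefix is model-0-200.
    """
    oversized = []
    for path in sorted(glob.glob(glob.escape(save_path) + ".*")):
        if os.path.isfile(path) and os.path.getsize(path) > limit_bytes:
            oversized.append(path)
    return oversized
```

If the returned list is non-empty, the checkpoint can be rejected before `restore` starts reading it.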
Return Value
- Success: None
- Failure: An exception is thrown.
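Because `restore` returns `None` on success and reports failure only by raising, a caller that prefers to fall back (for example, to start training from scratch) rather than abort can wrap the call. The following is a minimal sketch; `try_restore` is a hypothetical helper, and because this page does not name the exception type, it catches `Exception` broadly.

```python
import logging

logger = logging.getLogger(__name__)


def try_restore(saver, sess, save_path):
    """Call saver.restore(sess, save_path) and report the outcome as a bool.

    restore() returns None on success, so the only failure signal is an
    exception escaping the call.
    """
    try:
        saver.restore(sess, save_path)
    except Exception:  # Exception type is not documented here; catch broadly.
        logger.exception("Failed to restore model from %s", save_path)
        return False
    return True
```

Catching broadly keeps the sketch independent of the specific error types raised; code that only wants to tolerate a missing checkpoint should narrow the `except` clause once those types are known.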
Example
The following example briefly illustrates the usage. For a complete example, see the little demo in the Rec SDK code repository.
```python
# 1. Import required libraries.
import tensorflow as tf
from mx_rec.util.initialize import init, get_rank_id

# 2. Build a computational graph.
# ...

# 3. Create a saver.
saver = tf.compat.v1.train.Saver()

# 4. Obtain the rank ID.
rank_id = get_rank_id()

# 5. Set the training step at which the model to be loaded was saved.
latest_step = 200

# 6. Restore the saved model into the session.
with tf.compat.v1.Session() as sess:
    saver.restore(sess, f"./saved-model/model-{rank_id}-{latest_step}")
```
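The example hard-codes `latest_step = 200`. As a sketch, the step can instead be derived from the files on disk. This assumes each rank's checkpoints were written with a prefix such as `./saved-model/model-{rank_id}` plus a `-{step}` suffix (as `Saver.save` produces when given `global_step`), and that each checkpoint includes a `.index` file, as TensorFlow's V2 checkpoint format does. `find_latest_step` is a hypothetical helper, not part of mx_rec.

```python
import os
import re


def find_latest_step(save_dir, rank_id, prefix="model"):
    """Return the largest N such that <save_dir>/<prefix>-<rank_id>-N.index exists.

    Returns None when no matching checkpoint is found. Assumption: file names
    follow the "<prefix>-<rank_id>-<step>.index" pattern used in the example.
    """
    pattern = re.compile(
        rf"^{re.escape(prefix)}-{re.escape(str(rank_id))}-(\d+)\.index$"
    )
    steps = []
    for name in os.listdir(save_dir):
        match = pattern.match(name)
        if match:
            steps.append(int(match.group(1)))
    return max(steps) if steps else None
```

With this helper, `latest_step = find_latest_step("./saved-model", rank_id)` could replace the hard-coded value; a `None` result means no checkpoint was found for that rank. The anchored pattern keeps rank `1` from matching files that belong to rank `11`.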
See Also
For the API call sequence and a complete example, see Porting and Training.