tf.compat.v1.train.Saver.restore

功能描述

TensorFlow用于模型加载的接口。

函数原型

1
def restore(self, sess, save_path)

参数说明

参数名

类型

可选/必选

说明

sess

Session

必选

需要导入模型TensorFlow的Session。

save_path

str

必选

  • 模型checkpoint文件的保存路径。
  • 支持本地文件系统和HDFS文件系统,长度范围为[1,1024]。
  • 在使用多卡训练加载模型时,多卡save_path可以输入同一加载路径(该路径下保存了多卡训练的结果),各卡会自动加载属于本卡的参数。
说明:

当前加载文件单个大小上限为500G,并发读取可能会引发系统OOM。

返回值说明

使用示例

具体使用方法可参考Rec SDK中的little demo,以下仅提供一个使用的流程示例。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
# 1、导入需要的库
import tensorflow as tf
from mx_rec.util.initialize import init, get_rank_id
# 2、构建计算图
# ...
# 3、创建saver
saver = tf.compat.v1.train.Saver() 
# 4、获取rank_id
rank_id = get_rank_id()
# 5、设置需要加载的模型保存时的训练步数,比如:
latest_step = 200
with tf.compat.v1.Session() as sess:
    saver.restore(sess, f"./saved-model/model-{rank_id}-{latest_step}")

参考资源

接口调用流程及示例,参见迁移与训练