TensorFlow用于模型加载的接口。
def restore(self, sess, save_path)
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
sess |
Session |
必选 |
需要导入模型TensorFlow的Session。 |
save_path |
str |
必选 |
模型checkpoint文件的保存路径。 支持本地文件系统和HDFS文件系统,长度范围为[1,150]。 说明:
当前加载文件单个大小上限为500G,并发读取可能会引发系统OOM。 |
具体使用方法可参考mxRec代码仓中的little demo,以下仅提供一个使用的流程示例。
# 1、导入需要的库 import tensorflow as tf from mx_rec.util.initialize import init, get_rank_id # 2、构建计算图 # ... # 3、创建saver saver = tf.compat.v1.train.Saver() # 4、获取rank_id rank_id = get_rank_id() # 5、设置需要加载的模型保存时的训练步数,比如: latest_step = 200 with tf.compat.v1.Session() as sess: saver.restore(sess, f"./saved-model/model-{rank_id}-{latest_step}")
接口调用流程及示例,参见模型迁移与训练。