TensorFlow用于模型加载的接口。
1 | def restore(self, sess, save_path) |
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
sess |
Session |
必选 |
需要导入模型TensorFlow的Session。 |
save_path |
str |
必选 |
说明:
当前加载文件单个大小上限为500G,并发读取可能会引发系统OOM。 |
具体使用方法可参考Rec SDK中的little demo,以下仅提供一个使用的流程示例。
1 2 3 4 5 6 7 8 9 10 11 12 13 | # 1、导入需要的库 import tensorflow as tf from mx_rec.util.initialize import init, get_rank_id # 2、构建计算图 # ... # 3、创建saver saver = tf.compat.v1.train.Saver() # 4、获取rank_id rank_id = get_rank_id() # 5、设置需要加载的模型保存时的训练步数,比如: latest_step = 200 with tf.compat.v1.Session() as sess: saver.restore(sess, f"./saved-model/model-{rank_id}-{latest_step}") |
接口调用流程及示例,参见迁移与训练。