tf.compat.v1.train.Saver.save
功能描述
TensorFlow用于模型保存的接口。
函数原型
def save(self, sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix="meta", write_meta_graph=True, write_state=True, strip_default_attrs=False, save_debug_info=False)
参数说明
参数名  | 
类型  | 
可选/必选  | 
说明  | 
|---|---|---|---|
sess  | 
Session  | 
必选  | 
需要保存模型的Session。  | 
save_path  | 
str  | 
必选  | 
模型保存路径。支持本地文件系统和HDFS文件系统,长度范围为[1,150]。  | 
global_step  | 
int, np.int64  | 
可选  | 
在checkpoint文件名补充训练步数,默认值为“None”,取值范围为[0, 2147483647]。  | 
latest_filename  | 
str  | 
可选  | 
protocol buffer文件的可选名称,该文件将包含最新checkpoint列表,默认为“None”。长度范围[1, 50]。  | 
meta_graph_suffix  | 
str  | 
可选  | 
MetaGraphDef文件的后缀,默认为“meta”,长度范围[1, 50]。  | 
write_meta_graph  | 
bool  | 
可选  | 
是否写入MetaGraph文件,默认为“True”。 取值范围: 
  | 
write_state  | 
bool  | 
可选  | 
是否写入CheckpointStateProto文件,默认为“True”。 取值范围: 
  | 
strip_default_attrs  | 
bool  | 
可选  | 
保存模型文件时,是否删除NodeDefs中的默认值属性,默认为“False”。 
  | 
save_debug_info  | 
bool  | 
可选  | 
是否保存Debug信息,默认为“False”。 
  | 
返回值说明
- 成功:返回“model_checkpoint_path”,即模型保存路径。
 - 失败:抛出异常。
 
使用示例
具体使用方法可参考mxRec代码仓中的little demo,以下仅提供一个使用的流程示例。
# 1、导入需要的库
import tensorflow as tf
from mx_rec.util.initialize import init, get_rank_id
# 2、构建计算图
# ...
# 3、创建saver
saver = tf.compat.v1.train.Saver() 
# 4、获取rank_id
rank_id = get_rank_id()
# 5、设置需要保存模型时的训练步数
global_step = 200
with tf.compat.v1.Session() as sess:
    saver.save(sess, f"./saved-model/model-{rank_id}", global_step=global_step)
参考资源
接口调用流程及示例,参见模型迁移与训练。
父主题: TensorFlow相关接口