tf.compat.v1.train.Saver.save

功能描述

TensorFlow用于模型保存的接口。

函数原型

def save(self, sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix="meta", write_meta_graph=True, write_state=True, strip_default_attrs=False, save_debug_info=False)

参数说明

参数名

类型

可选/必选

说明

sess

Session

必选

需要保存模型的Session。

save_path

str

必选

模型保存路径。支持本地文件系统和HDFS文件系统,长度范围为[1,150]。

global_step

int, np.int64

可选

在checkpoint文件名补充训练步数,默认值为“None”,取值范围为[0, 2147483647]。

latest_filename

str

可选

protocol buffer文件的可选名称,该文件将包含最新checkpoint列表,默认为“None”。长度范围[1, 50]。

meta_graph_suffix

str

可选

MetaGraphDef文件的后缀,默认为“meta”,长度范围[1, 50]。

write_meta_graph

bool

可选

是否写入MetaGraph文件,默认为“True”

取值范围:

  • True:写入MetaGraph文件
  • False:不写入MetaGraph文件

write_state

bool

可选

是否写入CheckpointStateProto文件,默认为“True”

取值范围:

  • True:写入CheckpointStateProto文件
  • False:不写入CheckpointStateProto文件

strip_default_attrs

bool

可选

保存模型文件时,是否删除NodeDefs中的默认值属性,默认为“False”

  • 参数值为“True”,则默认值属性将在接口调用时从NodeDefs中删除。
  • 参数值为“False”,则不进行删除操作。

save_debug_info

bool

可选

是否保存Debug信息,默认为“False”

  • 参数值为“True”,则将图形调试信息保存到一个单独的文件中,该文件位于“save_path”对应的目录中,并在生成文件的扩展名之前添加“_debug”。仅当“write_meta_graph”为`“True”`时,此功能才会生效。
  • 参数值为“False”,则不保存Debug信息。

返回值说明

使用示例

具体使用方法可参考mxRec代码仓中的little demo,以下仅提供一个使用的流程示例。

# 1、导入需要的库
import tensorflow as tf
from mx_rec.util.initialize import init, get_rank_id
# 2、构建计算图
# ...
# 3、创建saver
saver = tf.compat.v1.train.Saver() 
# 4、获取rank_id
rank_id = get_rank_id()
# 5、设置需要保存模型时的训练步数
global_step = 200
with tf.compat.v1.Session() as sess:
    saver.save(sess, f"./saved-model/model-{rank_id}", global_step=global_step)

参考资源

接口调用流程及示例,参见模型迁移与训练