tf.compat.v1.train.Saver.save

功能描述

TensorFlow用于模型保存的接口。

函数原型

1
def save(self, sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix="meta", write_meta_graph=True, write_state=True, strip_default_attrs=False, save_debug_info=False, is_incremental_checkpoint=False, save_delta=False)

参数说明

参数名

类型

可选/必选

说明

sess

Session

必选

需要保存模型的Session。

save_path

str

必选

模型保存路径。支持本地文件系统和HDFS文件系统,长度范围为[1,150]。

  • 使用HDFS路径保存数据时,日志中会打印“hdfsWrite: FSDataOutputStream#write error”的日志,该日志不影响数据保存,可忽略。
  • 多节点使用HDFS进行训练时,要使用相同的HDFS路径作为保存路径。
  • ​在使用多卡训练时,save_path不需要传入带卡ID的参数,多卡训练的结果将自动合并保存在save_path中。在后续多卡进行模型加载时,也会自动加载属于本卡的参数。

global_step

int, np.int64

可选

在checkpoint文件名补充训练步数,默认值为“None”,取值范围为[0, 2147483647]。

latest_filename

str

可选

protocol buffer文件的可选名称,该文件将包含最新checkpoint列表,默认为“None”。长度范围[1, 50]。

meta_graph_suffix

str

可选

MetaGraphDef文件的后缀,默认为“meta”,长度范围[1, 50]。

write_meta_graph

bool

可选

是否写入MetaGraph文件,默认为“True”

取值范围:

  • True:写入MetaGraph文件
  • False:不写入MetaGraph文件

write_state

bool

可选

是否写入CheckpointStateProto文件,默认为“True”

取值范围:

  • True:写入CheckpointStateProto文件
  • False:不写入CheckpointStateProto文件

strip_default_attrs

bool

可选

保存模型文件时,是否删除NodeDefs中的默认值属性,默认为“False”

  • 参数值为“True”,则默认值属性将在接口调用时从NodeDefs中删除。
  • 参数值为“False”,则不进行删除操作。

save_debug_info

bool

可选

是否保存Debug信息,默认为“False”

  • 参数值为“True”,则将图形调试信息保存到一个单独的文件中,该文件位于“save_path”对应的目录中,并在生成文件的扩展名之前添加“_debug”。仅当“write_meta_graph”为`“True”`时,此功能才会生效。
  • 参数值为“False”,则不保存Debug信息。

is_incremental_checkpoint

bool

可选

是否开启模型增量保存与加载,默认为False。

  • True:开启模型增量保存与加载。
  • False:关闭模型增量保存与加载。

save_delta

bool

可选

是否保存增量模型:

  • True:保存增量模型。
  • False:不保存增量模型,保存全量模型。

返回值说明

使用示例

具体使用方法可参考Rec SDK代码仓中的little demo,以下仅提供一个使用的流程示例。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
# 1、导入需要的库
import tensorflow as tf
from mx_rec.util.initialize import init, get_rank_id
# 2、构建计算图
# ...
# 3、创建saver
saver = tf.compat.v1.train.Saver() 
# 4、获取rank_id
rank_id = get_rank_id()
# 5、设置需要保存模型时的训练步数
global_step = 200
with tf.compat.v1.Session() as sess:
    saver.save(sess, f"./saved-model/model-{rank_id}", global_step=global_step)

参考资源

接口调用流程及示例,参见模型迁移与训练