TensorFlow用于模型保存的接口。
1 | def save(self, sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix="meta", write_meta_graph=True, write_state=True, strip_default_attrs=False, save_debug_info=False, is_incremental_checkpoint=False, save_delta=False) |
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
sess |
Session |
必选 |
需要保存模型的Session。 |
save_path |
str |
必选 |
模型保存路径。支持本地文件系统和HDFS文件系统,长度范围为[1,150]。
|
global_step |
int, np.int64 |
可选 |
在checkpoint文件名补充训练步数,默认值为“None”,取值范围为[0, 2147483647]。 |
latest_filename |
str |
可选 |
protocol buffer文件的可选名称,该文件将包含最新checkpoint列表,默认为“None”。长度范围[1, 50]。 |
meta_graph_suffix |
str |
可选 |
MetaGraphDef文件的后缀,默认为“meta”,长度范围[1, 50]。 |
write_meta_graph |
bool |
可选 |
是否写入MetaGraph文件,默认为“True”。 取值范围:
|
write_state |
bool |
可选 |
是否写入CheckpointStateProto文件,默认为“True”。 取值范围:
|
strip_default_attrs |
bool |
可选 |
保存模型文件时,是否删除NodeDefs中的默认值属性,默认为“False”。
|
save_debug_info |
bool |
可选 |
是否保存Debug信息,默认为“False”。
|
is_incremental_checkpoint |
bool |
可选 |
是否开启模型增量保存与加载,默认为False。
|
save_delta |
bool |
可选 |
是否保存增量模型:
|
具体使用方法可参考Rec SDK代码仓中的little demo,以下仅提供一个使用的流程示例。
1 2 3 4 5 6 7 8 9 10 11 12 13 | # 1、导入需要的库 import tensorflow as tf from mx_rec.util.initialize import init, get_rank_id # 2、构建计算图 # ... # 3、创建saver saver = tf.compat.v1.train.Saver() # 4、获取rank_id rank_id = get_rank_id() # 5、设置需要保存模型时的训练步数 global_step = 200 with tf.compat.v1.Session() as sess: saver.save(sess, f"./saved-model/model-{rank_id}", global_step=global_step) |
接口调用流程及示例,参见模型迁移与训练。