tf.compat.v1.train.Saver.save
Function
Saves a model.
Prototype
```python
def save(self, sess, save_path, global_step=None, latest_filename=None,
         meta_graph_suffix="meta", write_meta_graph=True, write_state=True,
         strip_default_attrs=False, save_debug_info=False,
         is_incremental_checkpoint=False, save_delta=False)
```
Parameters
| Parameter | Type | Mandatory/Optional | Description |
|---|---|---|---|
| sess | Session | Mandatory | Session of the model to be saved. |
| save_path | str | Mandatory | Path to save the model to. Both the local file system and HDFS are supported. Length range: [1, 1024]. |
| global_step | int, np.int64 | Optional | Training step number appended to the checkpoint file name. Default: None. Value range: [0, 2147483647]. |
| latest_filename | str | Optional | Name of the protocol buffer file that contains the list of the latest checkpoints. Default: None. Length range: [1, 50]. |
| meta_graph_suffix | str | Optional | Suffix of the MetaGraphDef file. Default: meta. Length range: [1, 50]. |
| write_meta_graph | bool | Optional | Whether to write the MetaGraph file. Default: True. Values: True (write the file) or False (skip it). |
| write_state | bool | Optional | Whether to write the CheckpointStateProto file. Default: True. Values: True (write the file) or False (skip it). |
| strip_default_attrs | bool | Optional | Whether to remove default-valued attributes from NodeDefs when saving the model. Default: False. |
| save_debug_info | bool | Optional | Whether to save debugging information. Default: False. |
| is_incremental_checkpoint | bool | Optional | Whether to save and load incremental models. Default: False. |
| save_delta | bool | Optional | Whether to save the incremental model. |
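The value ranges documented above can be checked before calling `save()`. The following is a minimal sketch; `validate_save_args` is a hypothetical helper, not part of the Rec SDK or TensorFlow, and it only mirrors the ranges listed in the table. `np.int64` values pass the integer check because NumPy registers its integer types with `numbers.Integral`.

```python
import numbers

# Upper bound documented for global_step (2**31 - 1).
INT32_MAX = 2147483647


def _check_length(name, value, low, high):
    """Raise ValueError unless value is a str whose length is in [low, high]."""
    if not isinstance(value, str) or not low <= len(value) <= high:
        raise ValueError(f"{name} must be a str of length [{low}, {high}]")


def validate_save_args(save_path, global_step=None, latest_filename=None,
                       meta_graph_suffix="meta"):
    """Hypothetical pre-flight check mirroring the documented ranges of save()."""
    _check_length("save_path", save_path, 1, 1024)
    if global_step is not None:
        # bool is a subclass of int in Python, so reject it explicitly.
        if isinstance(global_step, bool) or not isinstance(global_step, numbers.Integral):
            raise ValueError("global_step must be an integer")
        if not 0 <= global_step <= INT32_MAX:
            raise ValueError(f"global_step must be in [0, {INT32_MAX}]")
    if latest_filename is not None:
        _check_length("latest_filename", latest_filename, 1, 50)
    _check_length("meta_graph_suffix", meta_graph_suffix, 1, 50)
```

For example, `validate_save_args("./saved-model/model-0", global_step=200)` returns without error, while `global_step=-1` or an empty `save_path` raises `ValueError`.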
Return Value
- Success: model_checkpoint_path, the path at which the model checkpoint was saved.
- Failure: An exception is thrown.
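The following sketch shows how `model_checkpoint_path` relates to the arguments. It assumes that the Rec SDK keeps stock TensorFlow `Saver` naming: the step is appended to `save_path` after a hyphen, the V2 checkpoint format is used with a single data shard, the MetaGraph file is named `<prefix>.<meta_graph_suffix>`, and the state file defaults to `checkpoint` in the directory of `save_path`. `expected_checkpoint_files` is a hypothetical helper; any additional files the Rec SDK itself writes (for example, for incremental checkpoints) are not modeled.

```python
import os


def expected_checkpoint_files(save_path, global_step=None, latest_filename=None,
                              meta_graph_suffix="meta", write_meta_graph=True,
                              write_state=True):
    """Hypothetical: return (model_checkpoint_path, files) under stock TF naming."""
    # With global_step, the step number is appended to save_path after a hyphen.
    prefix = save_path if global_step is None else f"{save_path}-{global_step}"
    # Assumption: V2 checkpoint format with exactly one data shard.
    files = [f"{prefix}.index", f"{prefix}.data-00000-of-00001"]
    if write_meta_graph:
        files.append(f"{prefix}.{meta_graph_suffix}")
    if write_state:
        # The CheckpointStateProto file sits next to the checkpoint;
        # "checkpoint" is the default name when latest_filename is None.
        files.append(os.path.join(os.path.dirname(save_path),
                                  latest_filename or "checkpoint"))
    return prefix, files
```

Under these assumptions, the call in the example below on rank 0 (`save_path="./saved-model/model-0"`, `global_step=200`) would return `./saved-model/model-0-200`.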
Example
The following example briefly shows the usage. For a complete example, see the little demo in the Rec SDK code repository.
```python
# 1. Import required libraries.
import tensorflow as tf
from mx_rec.util.initialize import init, get_rank_id

# 2. Build a computational graph.
# ...

# 3. Create a saver.
saver = tf.compat.v1.train.Saver()

# 4. Obtain the rank ID so that each rank saves to its own path.
rank_id = get_rank_id()

# 5. Set the training step number to record when the model is saved.
global_step = 200
with tf.compat.v1.Session() as sess:
    # Variables must be initialized (or trained/restored) before they can be saved.
    sess.run(tf.compat.v1.global_variables_initializer())
    # On success, save() returns model_checkpoint_path.
    model_checkpoint_path = saver.save(sess, f"./saved-model/model-{rank_id}", global_step=global_step)
```
See Also
For the API call sequence and a complete example, see Porting and Training.