def __init__(self,
checkpoint_dir,
save_secs=None,
save_steps=None,
saver=None,
checkpoint_basename="model.ckpt",
scaffold=None,
listeners=None)
NPUCheckpointSaverHook类的构造函数,用于保存checkpoint文件。NPUCheckpointSaverHook类继承了CheckpointSaverHook类,可以调用基类的原生接口。
在使用NPUEstimator并且配置iteration_per_loop>1时,该Hook可能不生效。
参数名 |
输入/输出 |
描述 |
---|---|---|
checkpoint_dir |
输入 |
保存checkpoint文件的路径。 |
save_secs |
输入 |
每隔多少秒保存一次。 |
save_steps |
输入 |
每隔多少step保存一次。 |
saver |
输入 |
Saver对象。 |
checkpoint_basename |
输入 |
checkpoint文件的basename。 |
scaffold |
输入 |
获取saver对象的Scaffold。 |
listeners |
输入 |
CheckpointSaverListener子类示例,用于保存checkpoint。 |
返回NPUCheckpointSaverHook类对象。
from npu_bridge.npu_init import * checkpoint_hook = NPUCheckpointSaverHook(checkpoint_dir='./ckpt', save_steps=2000) ... mnist_classifier.train( input_fn=train_input_fn, steps=2000, hooks=[checkpoint_hook])