仅在某个Device上保存Checkpoint数据
2022/09/20
235
问题信息
| 问题来源 | 产品大类 | 产品子类 | 关键字 |
|---|---|---|---|
| 官方 | 模型训练 | TensorFlow | 分布式训练、checkpoint数据 |
问题现象描述
分布式训练场景下,如果用户只希望在某个device上保存checkpoint数据,而不希望在其他device上保存checkpoint数据。
原因分析
修改训练脚本。
解决措施
按照如下方法修改训练脚本:
TensorFlow原始代码:
self._classifier=tf.estimator.Estimator(
model_fn=cnn_model_fn,
model_dir=self._model_dir,
config=tf.estimator.RunConfig(
save_checkpoints_steps=50 if hvd.rank() == 0 else None,
keep_checkpoint_max=1))迁移后的代码:
self._classifier=NPUEstimator(
model_fn=cnn_model_fn,
model_dir=self._model_dir,
config=tf.estimator.NPURunConfig(
save_checkpoints_steps=50 if get_rank_id() == 0 else 0,
keep_checkpoint_max=1))


