Training Recovery Principles

After a fault is rectified, the training process is restarted. To return to the state it was in when the job was interrupted, the training script must save and load model weights. During normal training, checkpoint files containing the model weights are saved at regular intervals. After a restart, the training process loads the most recently saved checkpoint and continues from the model weights stored there instead of starting over, which reduces the total training time. How checkpoints are saved and loaded varies by framework. The following examples show how to save and load checkpoints in PyTorch, MindSpore, and TensorFlow. You can modify your training script based on these examples.
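The save-periodically, resume-on-restart cycle is the same in every framework. The following framework-agnostic sketch uses only the Python standard library. It makes several assumptions: pickle stands in for the framework's serializer, a single float stands in for the model weights, and the helper names save_state, load_state, and train, as well as the simulated fault, are hypothetical. The checkpoint is first written to a temporary file and then renamed over the old one with os.replace, so a fault during a save should leave the previous checkpoint intact.

```python
import os
import pickle
import tempfile


def save_state(path, state):
    """Write the state to a temp file, then atomically rename it over the old checkpoint."""
    directory = os.path.dirname(path) or "."
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "wb") as f:
        pickle.dump(state, f)
    os.replace(tmp_path, path)


def load_state(path):
    """Return the saved state, or None if no checkpoint has been written yet."""
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        return pickle.load(f)


def train(ckpt_path, total_steps, save_every, fail_at=None):
    """Toy training loop (hypothetical): resumes from the checkpoint if one exists.

    fail_at simulates a fault by raising before that step runs.
    Returns (step the run started from, final state).
    """
    state = load_state(ckpt_path) or {"step": 0, "weight": 0.0}
    resumed_from = state["step"]
    while state["step"] < total_steps:
        if state["step"] == fail_at:
            raise RuntimeError(f"simulated fault at step {state['step']}")
        # One "training step": update the weight and advance the step counter.
        state = {"step": state["step"] + 1, "weight": state["weight"] + 0.5}
        if state["step"] % save_every == 0:
            save_state(ckpt_path, state)
    return resumed_from, state


with tempfile.TemporaryDirectory() as ckpt_dir:
    path = os.path.join(ckpt_dir, "ckpt.pkl")
    try:
        train(path, total_steps=10, save_every=4, fail_at=7)  # first run is interrupted
    except RuntimeError:
        pass
    saved_step = load_state(path)["step"]  # last checkpoint written before the fault
    # Restarted run: picks up from the saved step instead of step 0.
    resumed_from, final_state = train(path, total_steps=10, save_every=4)
```

In this run the checkpoint is written at step 4 and the fault occurs at step 7, so the restarted job resumes from step 4 and recomputes only the steps after the last save. A shorter save interval loses less work per fault but spends more time writing checkpoints.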

PyTorch

  1. Save a checkpoint.
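    import os
    import shutil
    import torch
    def save_checkpoint(state, is_best, args, filename='checkpoint.pth.tar'):
        # state: dict holding the epoch, model/optimizer state_dicts, best accuracy, etc.
        filename2 = os.path.join(args.save_ckpt_path, filename)
        torch.save(state, filename2)
        if is_best:
            # Keep a separate copy of the best-performing checkpoint.
            shutil.copyfile(filename2, os.path.join(args.save_ckpt_path, 'model_best.pth.tar'))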
  2. Load the checkpoint.
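    import torch

    # loc is the device to map the loaded tensors to (e.g. 'cpu'); define it before this call.
    # The keys must match those of the state dict passed to save_checkpoint().
    checkpoint = torch.load(args.checkpoint_path, map_location=loc)
    args.start_epoch = checkpoint['epoch']
    best_acc1 = checkpoint['best_acc1']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])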
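The two PyTorch steps can be exercised end to end as in the following sketch. It assumes a CPU-only setup; nn.Linear and SGD with momentum stand in for the real model and optimizer, save_checkpoint is adapted to take a directory instead of args, and the resume helper is hypothetical. The dictionary keys match the ones used in the steps above. In the PyTorch ImageNet example script that this pattern resembles, the value saved under 'epoch' is epoch + 1, so start_epoch points at the next epoch to run.

```python
import os
import shutil
import tempfile

import torch
from torch import nn


def save_checkpoint(state, is_best, save_dir, filename="checkpoint.pth.tar"):
    path = os.path.join(save_dir, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(save_dir, "model_best.pth.tar"))


def resume(model, optimizer, path, device="cpu"):
    """Hypothetical helper: restore model and optimizer in place, return (start_epoch, best_acc1)."""
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return checkpoint["epoch"], checkpoint["best_acc1"]


torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer has momentum buffers worth saving.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

with tempfile.TemporaryDirectory() as save_dir:
    save_checkpoint(
        {"epoch": 3, "best_acc1": 0.75,
         "state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        is_best=True, save_dir=save_dir)
    best_exists = os.path.exists(os.path.join(save_dir, "model_best.pth.tar"))

    # "Restart": build a fresh model and optimizer, then load the checkpoint into them.
    new_model = nn.Linear(4, 2)
    new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
    start_epoch, best_acc1 = resume(
        new_model, new_optimizer, os.path.join(save_dir, "checkpoint.pth.tar"))
```

Saving the optimizer state alongside the weights matters for recovery: without it, momentum buffers (and similar state in other optimizers) restart from zero after the job resumes.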

MindSpore

  1. Save a checkpoint.
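    import mindspore as ms
    # choice_func (optional) saves only parameters named conv* but not conv1*; omit it to save all.
    ms.save_checkpoint(net, "./lenet.ckpt",
                       choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))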
  2. Load the checkpoint and write the loaded parameters back into the network.
    param_dict = ms.load_checkpoint("./lenet.ckpt")
    ms.load_param_into_net(net, param_dict)

TensorFlow

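  1. Use the tf.train.CheckpointManager interface (also available as tf.compat.v1.train.CheckpointManager) to manage checkpoints.
    import tensorflow as tf

    # runnable.checkpoint is the tf.train.Checkpoint to manage; runnable.global_step is
    # the step counter that checkpoint_interval is measured against.
    checkpoint_manager = tf.train.CheckpointManager(
        runnable.checkpoint,
        directory=flags_obj.model_dir,
        max_to_keep=10,
        step_counter=runnable.global_step,
        checkpoint_interval=checkpoint_interval)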
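  2. Save a checkpoint.
    # With check_interval=True, a checkpoint is written only if the step counter has advanced by
    # at least checkpoint_interval since the last save; otherwise save() returns None.
    checkpoint_manager.save(checkpoint_number=None, check_interval=True, options=None)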
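  3. Load the saved checkpoint. restore_or_initialize() restores the latest checkpoint in the directory and returns its path. If no checkpoint exists, it runs the init_fn passed to CheckpointManager (if any) and returns None.
    checkpoint_manager.restore_or_initialize()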
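The three TensorFlow steps can be combined as in the following sketch. It makes several assumptions: an int64 variable stands in for runnable.global_step, one float variable stands in for the model, the interval of 5 steps is arbitrary, and the build helper is hypothetical. The first run "fails" after 12 steps; the restarted run builds new variables and a new manager on the same directory and restores the latest checkpoint into them.

```python
import tempfile

import tensorflow as tf


def build(ckpt_dir):
    """Hypothetical helper: create fresh state and a CheckpointManager for ckpt_dir."""
    step = tf.Variable(0, dtype=tf.int64)
    weight = tf.Variable(0.0)  # stand-in for the model variables
    ckpt = tf.train.Checkpoint(step=step, weight=weight)
    manager = tf.train.CheckpointManager(
        ckpt, directory=ckpt_dir, max_to_keep=3,
        step_counter=step, checkpoint_interval=5)
    return step, weight, manager


with tempfile.TemporaryDirectory() as ckpt_dir:
    step, weight, manager = build(ckpt_dir)
    first_restore = manager.restore_or_initialize()  # empty directory: nothing to restore
    for _ in range(12):  # the job is interrupted after 12 steps
        step.assign_add(1)
        weight.assign_add(0.5)
        # Writes only when at least 5 steps have passed since the last save.
        manager.save(check_interval=True)

    step, weight, manager = build(ckpt_dir)  # "restart" with fresh variables
    restored_path = manager.restore_or_initialize()
    resumed_step = int(step.numpy())
    resumed_weight = float(weight.numpy())
```

Because saves are rate-limited by checkpoint_interval, the restored step is the one recorded at the last save, not step 12; the steps after that save are repeated when training resumes.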