训练重启

在完成故障处理后,训练进程会被重新拉起,拉起的训练进程需要回到断点时的训练状态,因此需要完成模型权重的保存和加载。在正常训练中,每隔一段时间保存训练模型权重的checkpoint文件,在断点后新拉起的进程可以加载之前保存的checkpoint文件,从而恢复到之前保存点的模型权重状态,减少训练时间。对于不同框架,保存和加载checkpoint的方法不一样,以下给出了TensorFlowPyTorchMindSpore保存和加载checkpoint的示例,用户需按照示例修改自己的训练模型脚本

PyTorch

  1. 保存checkpoint。
    def save_checkpoint(state, is_best, args, filename='checkpoint.pth.tar'):
        filename2 = os.path.join(args.save_ckpt_path, filename)
        torch.save(state, filename2)
        if is_best:
            shutil.copyfile(filename2, os.path.join(args.save_ckpt_path, 'model_best.pth.tar'))
  2. 加载checkpoint。
    checkpoint = torch.load(args.checkpoint_path, map_location=loc)
                args.start_epoch = checkpoint['epoch']
                best_acc1 = checkpoint['best_acc1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])

MindSpore

  1. 保存checkpoint。
    ms.save_checkpoint(net, "./lenet.ckpt",
                       choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
  2. 加载checkpoint。
    param_dict = ms.load_checkpoint("./lenet.ckpt")

TensorFlow

  1. 使用tf.compat.v1.train.CheckpointManager接口进行checkpoint管理。
      checkpoint_manager = tf.train.CheckpointManager(
          runnable.checkpoint,
          directory=flags_obj.model_dir,
          max_to_keep=10,
          step_counter=runnable.global_step,
          checkpoint_interval=checkpoint_interval)
  2. 保存checkpoint(创建一个新的checkpoint)。
    Save(
        Checkpoint_number=None, check_internal=True, options=None
    )
  3. 加载保存的checkpoint(尝试加载从目录中的最新的checkpoint)。
    Restore_or_initialize()