After a fault has been handled, the training processes are restarted, and the restarted processes need to return to the training state they had at the point of interruption, which requires saving and loading the model weights. During normal training, a checkpoint file of the model weights is saved at regular intervals; after an interruption, the newly started process loads a previously saved checkpoint file and restores the model weights to the state of that save point, reducing the amount of training time lost. Because each framework saves and loads checkpoints differently, examples for PyTorch, MindSpore, and TensorFlow are given below; modify your own training script accordingly.
PyTorch: saving a checkpoint. The state dictionary passed in holds the epoch counter, model weights, best accuracy, and optimizer state; the latest checkpoint is written to the configured directory and additionally copied to model_best.pth.tar when it is the best one so far.

import os
import shutil

import torch

def save_checkpoint(state, is_best, args, filename='checkpoint.pth.tar'):
    # Write the checkpoint into the configured save directory.
    filename2 = os.path.join(args.save_ckpt_path, filename)
    torch.save(state, filename2)
    # Keep a separate copy of the best-performing checkpoint.
    if is_best:
        shutil.copyfile(filename2, os.path.join(args.save_ckpt_path, 'model_best.pth.tar'))
PyTorch: loading a checkpoint to resume training. Here loc is the map_location target (for example the device of the current process), and model and optimizer are the objects being restored.

checkpoint = torch.load(args.checkpoint_path, map_location=loc)
# Resume from the epoch recorded in the checkpoint and restore the best metric.
args.start_epoch = checkpoint['epoch']
best_acc1 = checkpoint['best_acc1']
# Restore model weights and optimizer state.
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
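The following is a minimal sketch, not part of the original script, showing how the two PyTorch snippets above are typically wired into a training loop. save_checkpoint and the args object (args.start_epoch, args.epochs, args.save_ckpt_path) are the ones from the snippets; the model, optimizer, and validation metric are placeholder assumptions.

import torch
import torch.nn as nn

model = nn.Linear(10, 2)                                 # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # stand-in optimizer
best_acc1 = 0.0

for epoch in range(args.start_epoch, args.epochs):
    # ... run the user's training and validation for this epoch here ...
    acc1 = 0.0                                           # placeholder validation metric
    is_best = acc1 > best_acc1
    best_acc1 = max(acc1, best_acc1)
    # Persist everything needed to resume: the epoch counter, model weights,
    # best metric so far, and optimizer state.
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_acc1': best_acc1,
        'optimizer': optimizer.state_dict(),
    }, is_best, args)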
MindSpore: saving a checkpoint. The optional choice_func filters which parameters are written; in this example only parameters whose names start with "conv", excluding those starting with "conv1", are saved.

import mindspore as ms

# Save the selected parameters of `net` to ./lenet.ckpt.
ms.save_checkpoint(net, "./lenet.ckpt",
                   choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
MindSpore: loading a checkpoint. load_checkpoint returns a dictionary mapping parameter names to their saved values.

# Read the checkpoint file back into a parameter dictionary.
param_dict = ms.load_checkpoint("./lenet.ckpt")
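load_checkpoint only returns the parameter dictionary; the values still have to be copied into the network before training can resume. A minimal sketch using MindSpore's load_param_into_net, assuming net is the same network definition that produced the checkpoint:

import mindspore as ms

param_dict = ms.load_checkpoint("./lenet.ckpt")
# Copy the loaded values into the network's parameters; parameters that do not
# appear in the checkpoint keep their current values.
ms.load_param_into_net(net, param_dict)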
TensorFlow: managing checkpoints with tf.train.CheckpointManager. The manager keeps at most max_to_keep checkpoint files and, together with step_counter and checkpoint_interval, controls how often a new checkpoint is actually written.

checkpoint_manager = tf.train.CheckpointManager(
    runnable.checkpoint,                       # the tf.train.Checkpoint to manage
    directory=flags_obj.model_dir,             # where checkpoint files are written
    max_to_keep=10,                            # keep only the 10 most recent checkpoints
    step_counter=runnable.global_step,         # step counter used to throttle saving
    checkpoint_interval=checkpoint_interval)   # minimum number of steps between checkpoints
Checkpoints are written with the manager's save() method; with check_interval=True (the default) it only writes a new checkpoint once checkpoint_interval steps have elapsed since the last one:

save(
    checkpoint_number=None,
    check_interval=True,
    options=None
)
Before training starts, restore_or_initialize() restores the state from the most recent checkpoint in the directory if one exists, and otherwise initializes from scratch:

restore_or_initialize()
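A minimal sketch, with a placeholder model and directory, of how the three TensorFlow calls above fit together; in the real script runnable.checkpoint, runnable.global_step, and flags_obj.model_dir come from the user's own code.

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])     # stand-in model
optimizer = tf.keras.optimizers.SGD()
global_step = tf.Variable(0, dtype=tf.int64)
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer, step=global_step)

checkpoint_manager = tf.train.CheckpointManager(
    checkpoint,
    directory="./ckpt",            # placeholder; use flags_obj.model_dir in practice
    max_to_keep=10,
    step_counter=global_step,
    checkpoint_interval=100)       # write at most one checkpoint every 100 steps

# Resume from the latest checkpoint if one exists, otherwise start fresh.
checkpoint_manager.restore_or_initialize()

for _ in range(1000):
    # ... run one training step here, then advance the step counter ...
    global_step.assign_add(1)
    # save() is a no-op until checkpoint_interval steps have passed since the
    # last written checkpoint (check_interval=True by default).
    checkpoint_manager.save(checkpoint_number=global_step)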