Script Adaptation

This section uses the ResNet50 and Pangu-alpha models as examples to show how to enable resume-from-breakpoint training on the ModelArts platform with the MindSpore framework. Owing to the current design of the ModelArts platform, when a fault occurs in a large-model scenario and the terminal checkpoint is saved through OBS, the file may fail to be saved or may be saved incompletely. In that case, a periodically saved checkpoint is loaded to continue training. The handling of this problem is described in item 4 of the OBS storage scenario (scenario 2) under step 3 of the Pangu-alpha script adaptation below.

ResNet50

  1. Download the resnet code from the r1.7 branch of the MindSpore model repository as the training code. This example trains the model on the CIFAR-10 dataset.
  2. Modify the config.py file in the resnet directory "official/cv/resnet/src/model_utils/".

    # In the get_config function, add the following default value for the config_path argument
    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../config/resnet50_cifar10_config.yaml"), help="Config file path")
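    # A hypothetical quick check, run from official/cv/resnet: with this default in
    # place, get_config() resolves the ResNet50/CIFAR-10 configuration even when no
    # --config_path argument is passed:
    #   from src.model_utils.config import get_config
    #   config = get_config()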

  3. Modify the "resnet50_cifar10_config.yaml" file in the resnet directory "official/cv/resnet/config/", or configure the following parameters directly in "Run Parameters" when creating the training job.

    enable_modelarts: True  # run the training job on the ModelArts platform
    run_distribute: True  # multi-node model training
    device_num: 8  # number of chips used per node
    checkpoint_path: "checkpoint/"  # subdirectory for saving the model

  4. Modify the "moxing_adapter.py" file in the resnet directory "official/cv/resnet/src/model_utils/".

    # Comment out the following code
    # if config.train_url:
    #     print("Start to copy output directory")
    #     sync_data(config.output_path, config.train_url)

  5. Modify the "train.py" file in the resnet directory "official/cv/resnet/". Adapt the script separately for the scenario where model files are stored on EFS and the scenario where they are stored on OBS.

    1. Store model files on EFS.
      1. Import the terminal-checkpoint saving module for resume-from-breakpoint training.
        # ForMA: import terminal checkpoint saving
        from mindx_elastic.terminating_message import ExceptionCheckpoint
      2. Add a new function.
        # ForMA: add the following function
        def adaption_for_modelarts_efs():
            # ForMA: use the EFS path configured for the ModelArts training job as the model save/load path
            config.output_path = os.environ.get("CHECKPOINT_PATH")
            config.pre_trained = os.environ.get("CHECKPOINT_PATH")
            # ForMA: ModelArts automatically syncs the data from OBS to data_url
            config.data_path = config.data_url
            if not os.path.exists(config.output_path):
                os.makedirs(config.output_path, 0o755, exist_ok=True)
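        # A minimal local sanity check of this function (hypothetical path; on ModelArts,
        # CHECKPOINT_PATH is expected to point at the EFS path configured for the job):
        #   os.environ["CHECKPOINT_PATH"] = "/mnt/efs/resnet50_ckpt"
        #   adaption_for_modelarts_efs()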
      3. Add the following code to the train_net function.
        # Remove the "moxing_wrapper" decorator from the train_net function
        …
        set_parameter()

        # ForMA: parameter configuration
        adaption_for_modelarts_efs()

        ckpt_param_dict = load_pre_trained_checkpoint()
        …
            if config.save_checkpoint:
                # ForMA: add the terminal-checkpoint save callback and enable the MindSpore exception-save feature
         
                ckpt_append_info = [{"epoch_num": config.has_trained_epoch, "step_num": config.has_trained_step}]
                config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                             keep_checkpoint_max=config.keep_checkpoint_max,
                                             append_info=ckpt_append_info,
                                             exception_save=True)
                ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)
                
                exception_ckpt_cb = ExceptionCheckpoint(
                    prefix="resnet", directory=ckpt_save_dir, config=config_ck
                )

                cb += [ckpt_cb, exception_ckpt_cb]
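                # Note: the assembled `cb` list is then handed to training by the
                # existing script; a minimal sketch, assuming the surrounding r1.7
                # train.py names:
                #   model.train(config.epoch_size - config.has_trained_epoch, dataset,
                #               callbacks=cb, sink_size=step_size, dataset_sink_mode=True)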
      4. Modify the load_pre_trained_checkpoint function.
        def load_pre_trained_checkpoint():
            """
            Load checkpoint according to pre_trained path.
            """
            param_dict = None
            if config.pre_trained:
                if os.path.isdir(config.pre_trained):
                    ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path, "ckpt_0")
                    ckpt_pattern = os.path.join(ckpt_save_dir, "*.ckpt")
                    ckpt_files = glob.glob(ckpt_pattern)
         
                    # ForMA: prefer the terminal (breakpoint) checkpoint when one exists
                    exception_ckpt_pattern = os.path.join(ckpt_save_dir, "*_breakpoint.ckpt")
                    exception_ckpt_files = glob.glob(exception_ckpt_pattern)
                    ckpt_files = exception_ckpt_files if exception_ckpt_files else ckpt_files
         
                    if not ckpt_files:
                        logger.warning(f"There is no ckpt file in {ckpt_save_dir}, "
                                       f"pre_trained is unsupported.")
                    else:
                        ckpt_files.sort(key=os.path.getmtime, reverse=True)
                        time_stamp = datetime.datetime.now()
                        print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')}"
                              f" pre trained ckpt model {ckpt_files[0]} loading",
                              flush=True)
                        param_dict = ms.load_checkpoint(ckpt_files[0])
                elif os.path.isfile(config.pre_trained):
                    param_dict = ms.load_checkpoint(config.pre_trained)
                else:
                    print(f"Invalid pre_trained {config.pre_trained} parameter.")
            return param_dict
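        # The returned param_dict is consumed by the existing training flow; the idea,
        # as a minimal sketch (assuming a constructed network `net`):
        #   param_dict = load_pre_trained_checkpoint()
        #   if param_dict:
        #       ms.load_param_into_net(net, param_dict)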
    2. Store model files on OBS.
      1. Import the terminal-checkpoint saving module for resume-from-breakpoint training.
        # ForMA: import terminal checkpoint saving
        from mindx_elastic.terminating_message import ExceptionCheckpoint
      2. Add the following code to the train_net function.
        # Add the "moxing_wrapper" decorator to the train_net function
        …
        set_parameter()
         
        # ForMA: parameter configuration
        config.pre_trained = config.checkpoint_url
        config.output_path = config.train_url
         
        ckpt_param_dict = load_pre_trained_checkpoint()
         
        …
            if config.save_checkpoint:
                # ForMA: add the terminal-checkpoint save callback and enable the MindSpore exception-save feature
         
                ckpt_append_info = [{"epoch_num": config.has_trained_epoch, "step_num": config.has_trained_step}]
                config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                             keep_checkpoint_max=config.keep_checkpoint_max,
                                             append_info=ckpt_append_info,
                                             exception_save=True)
                ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)
                
                exception_ckpt_cb = ExceptionCheckpoint(
                    prefix="resnet", directory=ckpt_save_dir, config=config_ck
                )

                cb += [ckpt_cb, exception_ckpt_cb]
      3. Modify the load_pre_trained_checkpoint function.
        def load_pre_trained_checkpoint():
            """
            Load checkpoint according to pre_trained path.
            """
            param_dict = None
            if config.pre_trained:
                if os.path.isdir(config.pre_trained):
                    # ForMA: look for checkpoints under the pre_trained directory (the model files downloaded from OBS)
                    ckpt_save_dir = os.path.join(config.pre_trained, config.checkpoint_path, "ckpt_0")
                    ckpt_pattern = os.path.join(ckpt_save_dir, "*.ckpt")
                    ckpt_files = glob.glob(ckpt_pattern)

                    # ForMA: prefer the terminal (breakpoint) checkpoint when one exists
                    exception_ckpt_pattern = os.path.join(ckpt_save_dir, "*_breakpoint.ckpt")
                    exception_ckpt_files = glob.glob(exception_ckpt_pattern)
                    ckpt_files = exception_ckpt_files if exception_ckpt_files else ckpt_files
         
                    if not ckpt_files:
                        logger.warning(f"There is no ckpt file in {ckpt_save_dir}, "
                                       f"pre_trained is unsupported.")
                    else:
                        ckpt_files.sort(key=os.path.getmtime, reverse=True)
                        time_stamp = datetime.datetime.now()
                        print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')}"
                              f" pre trained ckpt model {ckpt_files[0]} loading",
                              flush=True)
                        param_dict = ms.load_checkpoint(ckpt_files[0])
                elif os.path.isfile(config.pre_trained):
                    param_dict = ms.load_checkpoint(config.pre_trained)
                else:
                    print(f"Invalid pre_trained {config.pre_trained} parameter.")
            return param_dict

Pangu-alpha

  1. Download the pangu-alpha code from the r1.7 branch of the MindSpore model repository as the training code.
  2. Modify the training parameters. In the pangu_alpha directory "official/nlp/pangu_alpha/src/", modify the "utils.py" file: change the default values of the following parameters, or add the ones marked as new (a sketch of the corresponding argparse additions follows the list).

    per_batch_size: 8
    save_checkpoint: True
    device_num: 8
    run_type: train
    strategy_load_ckpt_path: "./strategy.ckpt"
    strategy_save_ckpt_path: "./strategy.ckpt"  # new parameter
    param_init_type: "fp16"
    checkpoint_url: ""  # new parameter: OBS path of the stored model files
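
    A sketch of the corresponding argparse additions in utils.py, following the surrounding add_argument style (the help texts here are illustrative, not from the original file):

    parser.add_argument("--strategy_save_ckpt_path", type=str, default="./strategy.ckpt",
                        help="The save path of the parallel strategy checkpoint.")
    parser.add_argument("--checkpoint_url", type=str, default="",
                        help="The OBS path of the stored model files.")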

  3. Modify the "train.py" file in the pangu_alpha directory "official/nlp/pangu_alpha/". Adapt the script separately for the scenario where model files are stored on EFS and the scenario where they are stored on OBS.

    1. Store model files on EFS.
      1. Import the terminal-checkpoint saving module for resume-from-breakpoint training.
        # ForMA: import terminal checkpoint saving
        from mindx_elastic.terminating_message import ExceptionCheckpoint
      2. Configure the model save path.
        # Add the following code at the main entry of the script
        …
        set_parse(opt)

        # ForMA: use the EFS path configured for the ModelArts training job as the model save/load path
        opt.save_checkpoint_path = os.environ.get("CHECKPOINT_PATH")
        opt.pre_trained = os.environ.get("CHECKPOINT_PATH")
        if not os.path.exists(opt.save_checkpoint_path):
            os.makedirs(opt.save_checkpoint_path, 0o755, exist_ok=True)
      3. Modify the run_train function as follows (if pipeline parallelism is enabled, i.e. opt.stage_num > 1, modify the run_train_pipeline function in the same way).
        …
        if args_opt.parallel_mode == "data_parallel":
            # in avoid of the loop call depth
            context.set_context(max_call_depth=10000)

        # ForMA: comment out the following code
        # env variable prepare
        # group_info_file = os.getenv("GROUP_INFO_FILE")
        # if group_info_file:
        #     with open(os.path.expanduser("job/code/group_info_env"), "a") as outfile:
        #         outfile.write(f"export GROUP_INFO_FILE_REFLECT={group_info_file}\n")
        # loss_callback = LossCallBack(step_per_epoch, rank, 0, 0, micro_size=micro_batch_interleaved)

        callback = [TimeMonitor(args_opt.sink_size)]

        if args_opt.pre_trained:
            restore_checkpoint(args_opt, args_opt.sink_size, ds, model, pangu_alpha_with_grads, epoch=actual_epoch_num)

        callback += [LossCallBack(step_per_epoch, rank, args_opt.has_trained_epoches, args_opt.has_trained_steps, micro_size=micro_batch_interleaved)]

        add_checkpoint_callback_policy(args_opt, callback, rank)
        …
      4. Modify the restore_checkpoint function.
        ckpt_exp_pattern = os.path.join(args_param.save_checkpoint_path, "rank_{}".format(D.get_rank()), f"{ckpt_name}*_breakpoint.ckpt")
        ckpt_exp_files = glob.glob(ckpt_exp_pattern)

        # ForMA: comment out the following code
        # ckpt_files = []
        # for file in ckpt_all_files:
        #     if file not in ckpt_exp_files:
        #         ckpt_files.append(file)

        # ForMA: if terminal checkpoint files exist, use them; otherwise use the periodically saved checkpoints
        ckpt_files = ckpt_exp_files if ckpt_exp_files else ckpt_all_files
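        # Illustration with hypothetical file names: given
        #   ckpt_all_files = ["pangu0-2_1000.ckpt", "pangu0-3_400_breakpoint.ckpt"]
        #   ckpt_exp_files = ["pangu0-3_400_breakpoint.ckpt"]
        # the terminal checkpoint is selected; when no *_breakpoint.ckpt file exists,
        # the periodic checkpoints are used unchanged.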
      5. Modify the add_checkpoint_callback_policy function.
        if args_param.save_checkpoint:
            # ForMA: add the terminal-checkpoint save callback and enable the MindSpore exception-save feature

            # checkpoint store epoch_num and step_num info
            ckpt_append_info = [{"epoch_num": args_param.has_trained_epoches, "step_num": args_param.has_trained_steps}]
            ckpt_config = CheckpointConfig(save_checkpoint_steps=args_param.save_checkpoint_steps,
                                           keep_checkpoint_max=args_param.keep_checkpoint_max,
                                           integrated_save=False,
                                           append_info=ckpt_append_info,
                                           exception_save=True)

            # save checkpoint into rank directory
            ckpoint_cb = ModelCheckpoint(prefix=args_param.ckpt_name_prefix + str(rank_id),
                                         directory=os.path.join(args_param.save_checkpoint_path, f"rank_{rank_id}"),
                                         config=ckpt_config)

            ckpoint_exp = ExceptionCheckpoint(prefix=args_param.ckpt_name_prefix + str(rank_id),
                                              directory=os.path.join(args_param.save_checkpoint_path, f"rank_{rank_id}"),
                                              config=ckpt_config)

            callback.append(ckpoint_cb)
            callback.append(ckpoint_exp)
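        # The populated callback list is then handed to training by the existing
        # run_train flow; a minimal sketch (variable names assumed from the script):
        #   model.train(actual_epoch_num, ds, callbacks=callback,
        #               sink_size=args_opt.sink_size, dataset_sink_mode=True)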
      6. Modify the set_parallel_context function.
        def set_parallel_context(args_opt):
            r"""Set parallel context"""
            D.init()
            device_num = D.get_group_size()
            rank = D.get_rank()
            print("rank_id is {}, device_num is {}".format(rank, device_num))
            # ForMA: if no strategy file has been saved yet, do not attempt to load one
            if not os.path.exists(args_opt.strategy_load_ckpt_path):
                args_opt.strategy_load_ckpt_path = ""

            # ForMA: never read and write the same strategy file; when resuming, save the strategy to a new path
            if args_opt.strategy_load_ckpt_path == args_opt.strategy_save_ckpt_path:
                args_opt.strategy_save_ckpt_path = f"{args_opt.strategy_save_ckpt_path}_new"
         
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(
                parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False,
                full_batch=bool(args_opt.full_batch),
                strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path,
                enable_parallel_optimizer=bool(args_opt.optimizer_shard),
                strategy_ckpt_save_file=args_opt.strategy_save_ckpt_path)
            set_algo_parameters(elementwise_op_strategy_follow=True)
            _set_multi_subgraphs()
            return rank, device_num
    2. Store model files on OBS.
      1. Import the terminal-checkpoint saving module for resume-from-breakpoint training.
        # ForMA: import terminal checkpoint saving
        from mindx_elastic.terminating_message import ExceptionCheckpoint
      2. Configure the model save path.
        # Add the following code at the main entry of the script
        …
        set_parse(opt)

        # ForMA: add a "training input" parameter checkpoint_url on ModelArts; the platform automatically downloads the model files to a local path
        opt.save_checkpoint_path = opt.checkpoint_url
        opt.pre_trained = opt.checkpoint_url
      3. Modify the run_train function as follows (if pipeline parallelism is enabled, i.e. opt.stage_num > 1, modify the run_train_pipeline function in the same way).
        …
        if args_opt.parallel_mode == "data_parallel":
            # in avoid of the loop call depth
            context.set_context(max_call_depth=10000)

        # ForMA: comment out the following code
        # env variable prepare
        # group_info_file = os.getenv("GROUP_INFO_FILE")
        # if group_info_file:
        #     with open(os.path.expanduser("job/code/group_info_env"), "a") as outfile:
        #         outfile.write(f"export GROUP_INFO_FILE_REFLECT={group_info_file}\n")
        # loss_callback = LossCallBack(step_per_epoch, rank, 0, 0, micro_size=micro_batch_interleaved)

        callback = [TimeMonitor(args_opt.sink_size)]

        if args_opt.pre_trained:
            restore_checkpoint(args_opt, args_opt.sink_size, ds, model, pangu_alpha_with_grads, epoch=actual_epoch_num)

        callback += [LossCallBack(step_per_epoch, rank, args_opt.has_trained_epoches, args_opt.has_trained_steps, micro_size=micro_batch_interleaved)]

        add_checkpoint_callback_policy(args_opt, callback, rank)
        …
      4. Modify the restore_checkpoint function and add a checkpoint_file_sort_key function.
        import re

        def checkpoint_file_sort_key(file_path):
            # ForMA: files synced from OBS to local storage share the same mtime, so mtime cannot be
            # used for ordering; sort in descending order by the epoch and step encoded in the file name
            pattern = re.compile("([0-9]+)_([0-9]+)")
            _, file_name = os.path.split(file_path)
            epoch_step_suffix = file_name.split("-")[-1]
            result = pattern.search(epoch_step_suffix)

            if result is None:
                # return a tuple so the key stays comparable with the (epoch, step) branch below
                return (-os.path.getmtime(file_path), 0)

            epoch, step = result.groups()
            return (-int(epoch), -int(step))
        
        def restore_checkpoint(args_param, sink_size, dataset, model, network, epoch):
            r"""
            Load checkpoint process.
            """
            # ForMA: check whether the terminal checkpoint files are complete
            is_complete = True
            latest_exp_ckpt_size_list = []

            for dirpath, dirnames, _ in os.walk(args_param.save_checkpoint_path):
                if not is_complete:
                    break
                for subdir in dirnames:
                    candi_dirs = os.path.join(dirpath, subdir)
                    ckpt_exp_pattern = os.path.join(candi_dirs, "*_breakpoint.ckpt")
                    ckpt_pattern = os.path.join(candi_dirs, "*.ckpt")
                    ckpt_all_files = glob.glob(ckpt_pattern)
                    ckpt_exp_files = glob.glob(ckpt_exp_pattern)
                    if not ckpt_exp_files:
                        is_complete = False
                        break

                    ckpt_exp_files.sort(key=checkpoint_file_sort_key)
                    latest_exp_ckpt = ckpt_exp_files[0]
                    latest_exp_ckpt_size = os.path.getsize(latest_exp_ckpt)
                    latest_exp_ckpt_size_list.append(latest_exp_ckpt_size)

                    if ckpt_all_files and ckpt_exp_files:
                        ckpt_all_files.sort(key=checkpoint_file_sort_key)
                        oldest_ckpt = ckpt_all_files[-1]
                        if latest_exp_ckpt != oldest_ckpt and latest_exp_ckpt_size != os.path.getsize(oldest_ckpt):
                            is_complete = False

            # ForMA: the terminal checkpoints of all ranks must have the same size; otherwise treat them as incomplete
            if is_complete:
                if len(set(latest_exp_ckpt_size_list)) != 1:
                    is_complete = False
        
        
            print("======start single checkpoint", flush=True)
            ckpt_name = args_param.ckpt_name_prefix
            ckpt_pattern = os.path.join(args_param.save_checkpoint_path, "rank_{}".format(D.get_rank()), f"{ckpt_name}*.ckpt")
            ckpt_all_files = glob.glob(ckpt_pattern)
        
            if not ckpt_all_files:
                print(f"There is no ckpt file in {args_param.save_checkpoint_path}, "
                      f"current ckpt_files found is {ckpt_all_files} "
                      f"with pattern {ckpt_pattern}, so skip the loading.")
                return
        
            ckpt_all_files.sort(key=checkpoint_file_sort_key)
            ckpt_file = ""
            for ckpt in ckpt_all_files:
                if is_complete and ckpt.endswith("breakpoint.ckpt"):
                    ckpt_file = ckpt
                    break
                if not is_complete and "breakpoint" not in ckpt:
                    ckpt_file = ckpt
                    break
        
            if not ckpt_file:
                print(f"There is no ckpt file in {args_param.save_checkpoint_path}, "
                      f"current ckpt_files found is {ckpt_all_files} "
                      f"with pattern {ckpt_pattern}, so skip the loading.")
                return
        
            time_stamp = datetime.datetime.now()
            print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} pre trained ckpt model {ckpt_file} loading", flush=True)
            # Load checkpoint files latest file
            print(f'Start to load from {ckpt_file}')
            param_dict = load_checkpoint(ckpt_file)
            if param_dict.get("epoch_num") and param_dict.get("step_num"):
                args_param.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy())
                args_param.has_trained_steps = int(param_dict["step_num"].data.asnumpy())
            model.build(train_dataset=dataset, sink_size=sink_size, epoch=epoch)
            load_param_into_net(network, param_dict)
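        # Usage illustration for checkpoint_file_sort_key (hypothetical file names);
        # the newest epoch/step sorts first:
        #   sorted(["pangu0-1_500.ckpt", "pangu0-3_200.ckpt", "pangu0-2_800.ckpt"],
        #          key=checkpoint_file_sort_key)
        #   -> ["pangu0-3_200.ckpt", "pangu0-2_800.ckpt", "pangu0-1_500.ckpt"]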
      5. Modify the add_checkpoint_callback_policy function.
        if args_param.save_checkpoint:
            # ForMA: add the terminal-checkpoint save callback and enable the MindSpore exception-save feature

            # checkpoint store epoch_num and step_num info
            ckpt_append_info = [{"epoch_num": args_param.has_trained_epoches, "step_num": args_param.has_trained_steps}]
            ckpt_config = CheckpointConfig(save_checkpoint_steps=args_param.save_checkpoint_steps,
                                           keep_checkpoint_max=args_param.keep_checkpoint_max,
                                           integrated_save=False,
                                           append_info=ckpt_append_info,
                                           exception_save=True)

            # save checkpoint into rank directory
            ckpoint_cb = ModelCheckpoint(prefix=args_param.ckpt_name_prefix + str(rank_id),
                                         directory=os.path.join(args_param.train_url, f"rank_{rank_id}"),
                                         config=ckpt_config)

            ckpoint_exp = ExceptionCheckpoint(prefix=args_param.ckpt_name_prefix + str(rank_id),
                                              directory=os.path.join(args_param.train_url, f"rank_{rank_id}"),
                                              config=ckpt_config)

            callback.append(ckpoint_cb)
            callback.append(ckpoint_exp)
      6. Modify the set_parallel_context function.
        def set_parallel_context(args_opt):
            r"""Set parallel context"""
            D.init()
            device_num = D.get_group_size()
            rank = D.get_rank()
            print("rank_id is {}, device_num is {}".format(rank, device_num))
            # ForMA: if no strategy file has been saved yet, do not attempt to load one
            if not os.path.exists(args_opt.strategy_load_ckpt_path):
                args_opt.strategy_load_ckpt_path = ""

            # ForMA: never read and write the same strategy file; when resuming, save the strategy to a new path
            if args_opt.strategy_load_ckpt_path == args_opt.strategy_save_ckpt_path:
                args_opt.strategy_save_ckpt_path = f"{args_opt.strategy_save_ckpt_path}_new"
         
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(
                parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False,
                full_batch=bool(args_opt.full_batch),
                strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path,
                enable_parallel_optimizer=bool(args_opt.optimizer_shard),
                strategy_ckpt_save_file=args_opt.strategy_save_ckpt_path)
            set_algo_parameters(elementwise_op_strategy_follow=True)
            _set_multi_subgraphs()
            return rank, device_num