Script Adaptation

This section uses the ResNet-50 and Pangu-alpha models as examples to describe how to enable resumable training for MindSpore training jobs on ModelArts. If a fault occurs during foundation model training and the dying gasp checkpoint is saved to OBS, current ModelArts limitations may prevent the checkpoint from being saved or leave it incomplete. In that case, training resumes from the most recent periodically saved checkpoint instead. For how the scripts handle this, see sub-step 4 (the restore_checkpoint modification) under item 2 of step 3 in the Pangu-alpha script adaptation.
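The fallback rule above can be expressed as a short decision function. This is a minimal sketch, not code from ModelArts, MindSpore, or mindx_elastic: `choose_resume_checkpoint` is a hypothetical helper, the file names are made up, and it assumes the candidate names are already sorted newest first and that dying gasp checkpoints end with `_breakpoint.ckpt`, the suffix the adapted scripts below search for.

```python
from typing import List, Optional

BREAKPOINT_SUFFIX = "_breakpoint.ckpt"  # suffix searched for by the adapted scripts


def choose_resume_checkpoint(newest_first: List[str],
                             breakpoint_is_complete: bool) -> Optional[str]:
    """Return the checkpoint to resume from, or None (hypothetical helper).

    newest_first: candidate file names, already sorted from newest to oldest.
    breakpoint_is_complete: whether the dying gasp checkpoint was verified as complete.
    """
    for name in newest_first:
        is_breakpoint = name.endswith(BREAKPOINT_SUFFIX)
        # Complete dying gasp checkpoint: take the newest dying gasp file.
        # Otherwise: take the newest periodically saved file.
        if is_breakpoint == breakpoint_is_complete:
            return name
    return None


files = ["net-3_10_breakpoint.ckpt", "net-3_5.ckpt", "net-2_10.ckpt"]
print(choose_resume_checkpoint(files, breakpoint_is_complete=True))   # net-3_10_breakpoint.ckpt
print(choose_resume_checkpoint(files, breakpoint_is_complete=False))  # net-3_5.ckpt
```

This mirrors the selection loop in the Pangu-alpha OBS restore_checkpoint function: the completeness flag decides which kind of file is acceptable, and the first acceptable file in newest-first order wins.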

ResNet-50

  1. Download the ResNet code from the r1.7 branch of the MindSpore repository to use as the training code. The following steps use the CIFAR-10 dataset as an example.
  2. Modify the config.py file in official/cv/resnet/src/model_utils/.
    # Modify the get_config function and set the value of config_path as follows.
    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../config/resnet50_cifar10_config.yaml"), help="Config file path")
  3. Modify the resnet50_cifar10_config.yaml file in official/cv/resnet/config/, or set the following parameters in the Running Parameter window on the Create Training Job page.
    enable_modelarts: True  # Run the training job on ModelArts.
    run_distribute: True  # Enable distributed (multi-node) training mode.
    device_num: 8  # Number of processors used on each node
    checkpoint_path: "checkpoint/"  # Subdirectory in which the model is saved
  4. Modify the moxing_adapter.py file in official/cv/resnet/src/model_utils/.
    # Comment out the following code.
    # if config.train_url:
    #     print("Start to copy output directory")
    #     sync_data(config.output_path, config.train_url)
  5. Modify the train.py file in official/cv/resnet/. The adaptation depends on where the model files are stored.
    1. If the model files are stored on EFS:
      1. Import the module that saves the dying gasp checkpoint for resumable training.
        # ForMA: Save the dying gasp checkpoint.
        from mindx_elastic.terminating_message import ExceptionCheckpoint
      2. Add the following function.
        # ForMA: Add the following function.
        def adaption_for_modelarts_efs():
            # ForMA: Use the EFS path configured for the new training job on ModelArts as the path to save and load the model.
            config.output_path = os.environ.get("CHECKPOINT_PATH")
            config.pre_trained = os.environ.get("CHECKPOINT_PATH")
            # ForMA: ModelArts automatically synchronizes data from OBS to data_url.
            config.data_path = config.data_url
            if not os.path.exists(config.output_path):
                os.makedirs(config.output_path, 0o755, exist_ok=True)
      3. Add the following code to the train_net function.
        # Remove the moxing_wrapper decorator from the train_net function.
        ...
        set_parameter()
         
        # ForMA: Configure parameters.
        adaption_for_modelarts_efs()
         
        ckpt_param_dict = load_pre_trained_checkpoint()
        ...
            if config.save_checkpoint:
                # ForMA: Add a callback to save the dying gasp checkpoint and enable MindSpore's saving function upon exceptions.
         
                ckpt_append_info = [{"epoch_num": config.has_trained_epoch, "step_num": config.has_trained_step}]
                config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                             keep_checkpoint_max=config.keep_checkpoint_max,
                                             append_info=ckpt_append_info,
                                             exception_save=True)
                ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)
                
                exception_ckpt_cb = ExceptionCheckpoint(
                    prefix="resnet", directory=ckpt_save_dir, config=config_ck
                )

                cb += [ckpt_cb, exception_ckpt_cb]
      4. Modify the load_pre_trained_checkpoint function.
        def load_pre_trained_checkpoint():
            """
            Load checkpoint according to pre_trained path.
            """
            param_dict = None
            if config.pre_trained:
                if os.path.isdir(config.pre_trained):
                    ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path, "ckpt_0")
                    ckpt_pattern = os.path.join(ckpt_save_dir, "*.ckpt")
                    ckpt_files = glob.glob(ckpt_pattern)
         
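                    # ForMA: If dying gasp checkpoints (*_breakpoint.ckpt) exist, load from them; otherwise, use the periodically saved checkpoints.
                    exception_ckpt_pattern = os.path.join(ckpt_save_dir, "*_breakpoint.ckpt")
                    exception_ckpt_files = glob.glob(exception_ckpt_pattern)
                    ckpt_files = exception_ckpt_files if exception_ckpt_files else ckpt_files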
         
                    if not ckpt_files:
                        logger.warning(f"There is no ckpt file in {ckpt_save_dir}, "
                                       f"pre_trained is unsupported.")
                    else:
                        ckpt_files.sort(key=os.path.getmtime, reverse=True)
                        time_stamp = datetime.datetime.now()
                        print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')}"
                              f" pre trained ckpt model {ckpt_files[0]} loading",
                              flush=True)
                        param_dict = ms.load_checkpoint(ckpt_files[0])
                elif os.path.isfile(config.pre_trained):
                    param_dict = ms.load_checkpoint(config.pre_trained)
                else:
                    print(f"Invalid pre_trained {config.pre_trained} parameter.")
            return param_dict
    2. If the model files are stored in OBS:
      1. Import the module that saves the dying gasp checkpoint for resumable training.
        # ForMA: Save the dying gasp checkpoint.
        from mindx_elastic.terminating_message import ExceptionCheckpoint
      2. Add the following code to the train_net function.
        # Make sure the train_net function is decorated with moxing_wrapper.
        ...
        set_parameter()
         
        # ForMA: Configure parameters.
        config.pre_trained = config.checkpoint_url
        config.output_path = config.train_url
         
        ckpt_param_dict = load_pre_trained_checkpoint()
         
        ...
            if config.save_checkpoint:
                # ForMA: Add a callback to save the dying gasp checkpoint and enable MindSpore's saving function upon exceptions.
         
                ckpt_append_info = [{"epoch_num": config.has_trained_epoch, "step_num": config.has_trained_step}]
                config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                             keep_checkpoint_max=config.keep_checkpoint_max,
                                             append_info=ckpt_append_info,
                                             exception_save=True)
                ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)
                
                exception_ckpt_cb = ExceptionCheckpoint(
                    prefix="resnet", directory=ckpt_save_dir, config=config_ck
                )

                cb += [ckpt_cb, exception_ckpt_cb]
      3. Modify the load_pre_trained_checkpoint function.
        def load_pre_trained_checkpoint():
            """
            Load checkpoint according to pre_trained path.
            """
            param_dict = None
            if config.pre_trained:
                if os.path.isdir(config.pre_trained):
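                    # ForMA: Search the pre_trained path (checkpoint_url) instead of output_path.
                    ckpt_save_dir = os.path.join(config.pre_trained, config.checkpoint_path, "ckpt_0")
                    ckpt_pattern = os.path.join(ckpt_save_dir, "*.ckpt")
                    ckpt_files = glob.glob(ckpt_pattern)

                    # ForMA: If dying gasp checkpoints (*_breakpoint.ckpt) exist, load from them; otherwise, use the periodically saved checkpoints.
                    exception_ckpt_pattern = os.path.join(ckpt_save_dir, "*_breakpoint.ckpt")
                    exception_ckpt_files = glob.glob(exception_ckpt_pattern)
                    ckpt_files = exception_ckpt_files if exception_ckpt_files else ckpt_files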
         
                    if not ckpt_files:
                        logger.warning(f"There is no ckpt file in {ckpt_save_dir}, "
                                       f"pre_trained is unsupported.")
                    else:
                        ckpt_files.sort(key=os.path.getmtime, reverse=True)
                        time_stamp = datetime.datetime.now()
                        print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')}"
                              f" pre trained ckpt model {ckpt_files[0]} loading",
                              flush=True)
                        param_dict = ms.load_checkpoint(ckpt_files[0])
                elif os.path.isfile(config.pre_trained):
                    param_dict = ms.load_checkpoint(config.pre_trained)
                else:
                    print(f"Invalid pre_trained {config.pre_trained} parameter.")
            return param_dict
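In ResNet-50 step 2, the default config_path is built relative to the directory that contains config.py. The sketch below resolves that relative path with posixpath to check that it points at the YAML file edited in step 3. The `/home/work/models` prefix is a hypothetical checkout location, and `current_dir` is assumed to be the directory of config.py.

```python
import posixpath

# Hypothetical location of config.py inside the downloaded repository.
current_dir = "/home/work/models/official/cv/resnet/src/model_utils"

config_path = posixpath.normpath(
    posixpath.join(current_dir, "../../config/resnet50_cifar10_config.yaml"))
print(config_path)
# /home/work/models/official/cv/resnet/config/resnet50_cifar10_config.yaml
```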
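In ResNet-50 step 5, the EFS variant (adaption_for_modelarts_efs) takes the save and load path from the CHECKPOINT_PATH environment variable. If that variable is not set, os.environ.get returns None and os.path.exists(None) raises a TypeError. The sketch below is a hypothetical helper, `resolve_checkpoint_dir`, that fails early with a clearer message and then creates the directory with the same permissions. It takes the environment as a mapping so that it can be exercised without touching os.environ.

```python
import os
import tempfile
from typing import Mapping


def resolve_checkpoint_dir(env: Mapping[str, str]) -> str:
    """Return the EFS checkpoint directory named by CHECKPOINT_PATH (hypothetical helper)."""
    path = env.get("CHECKPOINT_PATH")
    if not path:
        # Fail early instead of passing None to os.path.exists / os.makedirs.
        raise RuntimeError("CHECKPOINT_PATH is not set; configure an EFS path for the training job.")
    # Same mode and exist_ok behavior as adaption_for_modelarts_efs.
    os.makedirs(path, 0o755, exist_ok=True)
    return path


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "checkpoints")
    print(resolve_checkpoint_dir({"CHECKPOINT_PATH": target}) == target)  # True
    print(os.path.isdir(target))  # True
```

In train.py, such a helper would be called as `resolve_checkpoint_dir(os.environ)`, with the result assigned to config.output_path and config.pre_trained.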
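The checkpoint selection in load_pre_trained_checkpoint (ResNet-50 step 5) prefers dying gasp files and otherwise takes the newest file by modification time. The following sketch reproduces only that selection on a temporary directory so it can be run without MindSpore; `pick_checkpoint`, `touch`, and the file names are hypothetical, and the files are empty placeholders.

```python
import glob
import os
import tempfile
from typing import Optional


def pick_checkpoint(ckpt_save_dir: str) -> Optional[str]:
    """Reproduce the file selection of load_pre_trained_checkpoint (sketch)."""
    ckpt_files = glob.glob(os.path.join(ckpt_save_dir, "*.ckpt"))
    exception_ckpt_files = glob.glob(os.path.join(ckpt_save_dir, "*_breakpoint.ckpt"))
    # Prefer dying gasp checkpoints whenever at least one exists.
    ckpt_files = exception_ckpt_files if exception_ckpt_files else ckpt_files
    if not ckpt_files:
        return None
    # Newest first, by modification time.
    ckpt_files.sort(key=os.path.getmtime, reverse=True)
    return ckpt_files[0]


def touch(path: str, mtime: float) -> None:
    """Create an empty placeholder file with a fixed modification time."""
    open(path, "w").close()
    os.utime(path, (mtime, mtime))


with tempfile.TemporaryDirectory() as tmp:
    touch(os.path.join(tmp, "resnet-1_100.ckpt"), 1000)
    touch(os.path.join(tmp, "resnet-2_100.ckpt"), 2000)
    print(os.path.basename(pick_checkpoint(tmp)))  # resnet-2_100.ckpt
    # An older dying gasp file still wins over newer periodic files.
    touch(os.path.join(tmp, "resnet-1_50_breakpoint.ckpt"), 500)
    print(os.path.basename(pick_checkpoint(tmp)))  # resnet-1_50_breakpoint.ckpt
```

Because the order depends on modification time, this works only when mtimes reflect the save order. The Pangu-alpha OBS variant below sorts by the epoch and step in the file name for that reason.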

Pangu-alpha

  1. Download the pangu-alpha code of the r1.7 branch from the MindSpore code repository as the training code.
  2. Modify the training parameters in official/nlp/pangu_alpha/src/utils.py: change the default values of the following parameters, and add the parameters that do not exist yet.
    per_batch_size: 8
    save_checkpoint: True
    device_num: 8
    run_type: train
    strategy_load_ckpt_path: "./strategy.ckpt"
    strategy_save_ckpt_path: "./strategy.ckpt"  # Add this parameter.
    param_init_type: "fp16"
    checkpoint_url: ""  # Add this parameter: path of the model file stored in OBS.
  3. Modify the train.py file in official/nlp/pangu_alpha/. The adaptation depends on where the model files are stored.
    1. If the model files are stored on EFS:
      1. Import the module that saves the dying gasp checkpoint for resumable training.
        # ForMA: Save the dying gasp checkpoint.
        from mindx_elastic.terminating_message import ExceptionCheckpoint
      2. Configure the paths for saving and loading the model.
        # Add the following code to the script main entry:
        ...
        set_parse(opt)
         
        # ForMA: Use the EFS path configured for the new training job on ModelArts as the path to save and load the model.
        opt.save_checkpoint_path = os.environ.get("CHECKPOINT_PATH")
        opt.pre_trained = os.environ.get("CHECKPOINT_PATH")
        if not os.path.exists(opt.save_checkpoint_path):
            os.makedirs(opt.save_checkpoint_path, 0o755, exist_ok=True)
        ...
      3. Modify the run_train function. If pipeline parallelism is enabled (opt.stage_num > 1), modify the run_train_pipeline function instead.
        ...
        if args_opt.parallel_mode == "data_parallel":
            # Raise the maximum call depth to avoid hitting the limit in loops.
            context.set_context(max_call_depth=10000)
         
        # ForMA: Comment out the following code.
        # env variable prepare
        # group_info_file = os.getenv("GROUP_INFO_FILE")
        # if group_info_file:
        #     with open(os.path.expanduser("job/code/group_info_env"), "a") as outfile:
        #         outfile.write(f"export GROUP_INFO_FILE_REFLECT={group_info_file}\n")
        ...
        # ForMA: Comment out the original loss callback; it is added again below, after the checkpoint is restored.
        # loss_callback = LossCallBack(step_per_epoch, rank, 0, 0, micro_size=micro_batch_interleaved)
        callback = [TimeMonitor(args_opt.sink_size)]
        ...
        if args_opt.pre_trained:
            restore_checkpoint(args_opt, args_opt.sink_size, ds, model, pangu_alpha_with_grads, epoch=actual_epoch_num)

        # ForMA: Create LossCallBack with the epoch and step counters restored from the checkpoint.
        callback += [LossCallBack(step_per_epoch, rank, args_opt.has_trained_epoches, args_opt.has_trained_steps,
                                  micro_size=micro_batch_interleaved)]
         
        add_checkpoint_callback_policy(args_opt, callback, rank)
        ...
      4. Modify the restore_checkpoint function.
        ...
        ckpt_exp_pattern = os.path.join(args_param.save_checkpoint_path, "rank_{}".format(D.get_rank()), f"{ckpt_name}*_breakpoint.ckpt")
        ckpt_exp_files = glob.glob(ckpt_exp_pattern)
         
        # ForMA: Comment out the following code.
        # ckpt_files = []
        # for file in ckpt_all_files:
        #     if file not in ckpt_exp_files:
        #         ckpt_files.append(file)
         
        # ForMA: If the dying gasp model file exists, use it. Otherwise, use the model file that is periodically saved.
        ckpt_files = ckpt_exp_files if ckpt_exp_files else ckpt_all_files
        ...
      5. Modify the add_checkpoint_callback_policy function.
        if args_param.save_checkpoint:
            # ForMA: Add a callback to save the dying gasp checkpoint and enable MindSpore's saving function upon exceptions.

            # Store epoch_num and step_num in the checkpoint.
            ckpt_append_info = [{"epoch_num": args_param.has_trained_epoches, "step_num": args_param.has_trained_steps}]
            ckpt_config = CheckpointConfig(save_checkpoint_steps=args_param.save_checkpoint_steps,
                                           keep_checkpoint_max=args_param.keep_checkpoint_max,
                                           integrated_save=False,
                                           append_info=ckpt_append_info,
                                           exception_save=True)

            # Save checkpoints into a per-rank directory.
            ckpoint_cb = ModelCheckpoint(prefix=args_param.ckpt_name_prefix + str(rank_id),
                                         directory=os.path.join(args_param.save_checkpoint_path, f"rank_{rank_id}"),
                                         config=ckpt_config)

            ckpoint_exp = ExceptionCheckpoint(
                prefix=args_param.ckpt_name_prefix + str(rank_id),
                directory=os.path.join(args_param.save_checkpoint_path, f"rank_{rank_id}"),
                config=ckpt_config)

            callback.append(ckpoint_cb)
            callback.append(ckpoint_exp)
      6. Modify the set_parallel_context function.
        def set_parallel_context(args_opt):
            r"""Set parallel context"""
            D.init()
            device_num = D.get_group_size()
            rank = D.get_rank()
            print("rank_id is {}, device_num is {}".format(rank, device_num))
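            # ForMA: If no strategy file exists yet (for example, on the first run), do not load one.
            if not os.path.exists(args_opt.strategy_load_ckpt_path):
                args_opt.strategy_load_ckpt_path = ""

            # ForMA: Do not overwrite the strategy file being loaded; save the new strategy to a separate file.
            if args_opt.strategy_load_ckpt_path == args_opt.strategy_save_ckpt_path:
                args_opt.strategy_save_ckpt_path = f"{args_opt.strategy_save_ckpt_path}_new"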
         
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(
                parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False,
                full_batch=bool(args_opt.full_batch),
                strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path,
                enable_parallel_optimizer=bool(args_opt.optimizer_shard),
                strategy_ckpt_save_file=args_opt.strategy_save_ckpt_path)
            set_algo_parameters(elementwise_op_strategy_follow=True)
            _set_multi_subgraphs()
            return rank, device_num
    2. If the model files are stored in OBS:
      1. Import the module that saves the dying gasp checkpoint for resumable training.
        # ForMA: Save the dying gasp checkpoint.
        from mindx_elastic.terminating_message import ExceptionCheckpoint
      2. Configure the paths for saving and loading the model.
        # Add the following code to the script main entry:
        ...
        set_parse(opt)
         
        # ForMA: Add checkpoint_url to Training Input on ModelArts. The platform automatically downloads the model file to a local path.
        opt.save_checkpoint_path = opt.checkpoint_url
        opt.pre_trained = opt.checkpoint_url
        ...
      3. Modify the run_train function. If pipeline parallelism is enabled (opt.stage_num > 1), modify the run_train_pipeline function instead.
        ...
        if args_opt.parallel_mode == "data_parallel":
            # Raise the maximum call depth to avoid hitting the limit in loops.
            context.set_context(max_call_depth=10000)
         
        # ForMA: Comment out the following code.
        # env variable prepare
        # group_info_file = os.getenv("GROUP_INFO_FILE")
        # if group_info_file:
        #     with open(os.path.expanduser("job/code/group_info_env"), "a") as outfile:
        #         outfile.write(f"export GROUP_INFO_FILE_REFLECT={group_info_file}\n")
        ...
        # ForMA: Comment out the original loss callback; it is added again below, after the checkpoint is restored.
        # loss_callback = LossCallBack(step_per_epoch, rank, 0, 0, micro_size=micro_batch_interleaved)
        callback = [TimeMonitor(args_opt.sink_size)]
        ...
        if args_opt.pre_trained:
            restore_checkpoint(args_opt, args_opt.sink_size, ds, model, pangu_alpha_with_grads, epoch=actual_epoch_num)

        # ForMA: Create LossCallBack with the epoch and step counters restored from the checkpoint.
        callback += [LossCallBack(step_per_epoch, rank, args_opt.has_trained_epoches, args_opt.has_trained_steps,
                                  micro_size=micro_batch_interleaved)]
         
        add_checkpoint_callback_policy(args_opt, callback, rank)
        ...
      4. Modify the restore_checkpoint function and add the checkpoint_file_sort_key function.
        import re
        def checkpoint_file_sort_key(file_path):
            # ForMA: When OBS files are synchronized to the local host, all files have the same mtime, so mtime cannot be used for sorting. Instead, sort files in descending order by the epoch and step in the file name.
            pattern = re.compile("([0-9]+)_([0-9]+)")
            _, file_name = os.path.split(file_path)
            epoch_step_suffix = file_name.split("-")[-1]
            result = pattern.search(epoch_step_suffix)

            # Always return a 3-tuple so that keys stay comparable: names with epoch/step information
            # come first (newest first), followed by any other files (newest mtime first).
            if result is None:
                return 1, -os.path.getmtime(file_path), 0

            epoch, step = result.groups()
            return 0, -int(epoch), -int(step)
        
        def restore_checkpoint(args_param, sink_size, dataset, model, network, epoch):
            r"""
            Load checkpoint process.
            """
            # ForMA: Check whether the dying gasp checkpoint files are complete.
            is_complete = True
            latest_exp_ckpt_size_list = []

            for _, dirnames, _ in os.walk(args_param.save_checkpoint_path):
                if not is_complete:
                    break
                for subdir in dirnames:
                    # Each subdirectory (rank_<id>) holds the checkpoints of one rank.
                    candi_dirs = os.path.join(args_param.save_checkpoint_path, subdir)
                    ckpt_exp_pattern = os.path.join(candi_dirs, "*_breakpoint.ckpt")
                    ckpt_pattern = os.path.join(candi_dirs, "*.ckpt")
                    ckpt_all_files = glob.glob(ckpt_pattern)
                    ckpt_exp_files = glob.glob(ckpt_exp_pattern)
                    # A rank without a dying gasp checkpoint means the save did not finish.
                    if not ckpt_exp_files:
                        is_complete = False
                        break

                    ckpt_exp_files.sort(key=checkpoint_file_sort_key)
                    latest_exp_ckpt = ckpt_exp_files[0]
                    latest_exp_ckpt_size = os.path.getsize(latest_exp_ckpt)
                    latest_exp_ckpt_size_list.append(latest_exp_ckpt_size)

                    # Compare the size of the latest dying gasp checkpoint with the oldest checkpoint in the same directory.
                    if ckpt_all_files:
                        ckpt_all_files.sort(key=checkpoint_file_sort_key)
                        oldest_ckpt = ckpt_all_files[-1]
                        if latest_exp_ckpt != oldest_ckpt and latest_exp_ckpt_size != os.path.getsize(oldest_ckpt):
                            is_complete = False

            # All ranks must have written dying gasp checkpoints of the same size.
            if is_complete and len(set(latest_exp_ckpt_size_list)) != 1:
                is_complete = False
        
        
            print("======start single checkpoint", flush=True)
            ckpt_name = args_param.ckpt_name_prefix
            ckpt_pattern = os.path.join(args_param.save_checkpoint_path, "rank_{}".format(D.get_rank()), f"{ckpt_name}*.ckpt")
            ckpt_all_files = glob.glob(ckpt_pattern)
        
            if not ckpt_all_files:
                print(f"There is no ckpt file in {args_param.save_checkpoint_path}, "
                      f"current ckpt_files found is {ckpt_all_files} "
                      f"with pattern {ckpt_pattern}, so skip the loading.")
                return
        
            ckpt_all_files.sort(key=checkpoint_file_sort_key)
            ckpt_file = ""
            for ckpt in ckpt_all_files:
                if is_complete and ckpt.endswith("breakpoint.ckpt"):
                    ckpt_file = ckpt
                    break
                if not is_complete and "breakpoint" not in ckpt:
                    ckpt_file = ckpt
                    break
        
            if not ckpt_file:
                print(f"There is no ckpt file in {args_param.save_checkpoint_path}, "
                      f"current ckpt_files found is {ckpt_all_files} "
                      f"with pattern {ckpt_pattern}, so skip the loading.")
                return
        
            time_stamp = datetime.datetime.now()
            print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} pre trained ckpt model {ckpt_file} loading", flush=True)
            # Load the selected checkpoint file.
            print(f'Start to load from {ckpt_file}')
            param_dict = load_checkpoint(ckpt_file)
            if param_dict.get("epoch_num") and param_dict.get("step_num"):
                args_param.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy())
                args_param.has_trained_steps = int(param_dict["step_num"].data.asnumpy())
            model.build(train_dataset=dataset, sink_size=sink_size, epoch=epoch)
            load_param_into_net(network, param_dict)
      5. Modify the add_checkpoint_callback_policy function.
        if args_param.save_checkpoint:
            # ForMA: Add a callback to save the dying gasp checkpoint and enable MindSpore's saving function upon exceptions.

            # Store epoch_num and step_num in the checkpoint.
            ckpt_append_info = [{"epoch_num": args_param.has_trained_epoches, "step_num": args_param.has_trained_steps}]
            ckpt_config = CheckpointConfig(save_checkpoint_steps=args_param.save_checkpoint_steps,
                                           keep_checkpoint_max=args_param.keep_checkpoint_max,
                                           integrated_save=False,
                                           append_info=ckpt_append_info,
                                           exception_save=True)

            # Save checkpoints into a per-rank directory under train_url.
            ckpoint_cb = ModelCheckpoint(prefix=args_param.ckpt_name_prefix + str(rank_id),
                                         directory=os.path.join(args_param.train_url, f"rank_{rank_id}"),
                                         config=ckpt_config)

            ckpoint_exp = ExceptionCheckpoint(
                prefix=args_param.ckpt_name_prefix + str(rank_id),
                directory=os.path.join(args_param.train_url, f"rank_{rank_id}"),
                config=ckpt_config)

            callback.append(ckpoint_cb)
            callback.append(ckpoint_exp)
      6. Modify the set_parallel_context function.
        def set_parallel_context(args_opt):
            r"""Set parallel context"""
            D.init()
            device_num = D.get_group_size()
            rank = D.get_rank()
            print("rank_id is {}, device_num is {}".format(rank, device_num))
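            # ForMA: If no strategy file exists yet (for example, on the first run), do not load one.
            if not os.path.exists(args_opt.strategy_load_ckpt_path):
                args_opt.strategy_load_ckpt_path = ""

            # ForMA: Do not overwrite the strategy file being loaded; save the new strategy to a separate file.
            if args_opt.strategy_load_ckpt_path == args_opt.strategy_save_ckpt_path:
                args_opt.strategy_save_ckpt_path = f"{args_opt.strategy_save_ckpt_path}_new"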
         
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(
                parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False,
                full_batch=bool(args_opt.full_batch),
                strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path,
                enable_parallel_optimizer=bool(args_opt.optimizer_shard),
                strategy_ckpt_save_file=args_opt.strategy_save_ckpt_path)
            set_algo_parameters(elementwise_op_strategy_follow=True)
            _set_multi_subgraphs()
            return rank, device_num
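In the Pangu-alpha EFS restore_checkpoint change (step 3), the original code filtered dying gasp files out of ckpt_all_files because the `{ckpt_name}*.ckpt` pattern also matches `*_breakpoint.ckpt`. The sketch below uses fnmatch, whose shell-style wildcards are the ones glob uses internally, on hypothetical file names from one rank directory; `pangu` stands in for args_param.ckpt_name_prefix, and the `_breakpoint` naming of the dying gasp file is assumed from the scripts' glob patterns.

```python
import fnmatch

ckpt_name = "pangu"  # hypothetical value of args_param.ckpt_name_prefix
# Hypothetical contents of rank_0: two periodic checkpoints and one dying gasp checkpoint.
names = ["pangu0-1_100.ckpt", "pangu0-1_200.ckpt", "pangu0-1_150_breakpoint.ckpt"]

ckpt_all_files = fnmatch.filter(names, f"{ckpt_name}*.ckpt")
ckpt_exp_files = fnmatch.filter(names, f"{ckpt_name}*_breakpoint.ckpt")
print(ckpt_all_files)  # all three names: "*.ckpt" also matches the dying gasp file
print(ckpt_exp_files)  # ['pangu0-1_150_breakpoint.ckpt']

# Adapted rule: use the dying gasp files if there are any, otherwise all files.
ckpt_files = ckpt_exp_files if ckpt_exp_files else ckpt_all_files
print(ckpt_files)  # ['pangu0-1_150_breakpoint.ckpt']
```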
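The strategy-file handling added to set_parallel_context (Pangu-alpha step 3) follows two rules. With the defaults from step 2, both strategy_load_ckpt_path and strategy_save_ckpt_path are "./strategy.ckpt". The sketch below restates the rules as a pure function; `resolve_strategy_paths` is a hypothetical name used only for illustration.

```python
import os
import tempfile
from typing import Tuple


def resolve_strategy_paths(load_path: str, save_path: str) -> Tuple[str, str]:
    """Apply the strategy-file rules from set_parallel_context (sketch)."""
    # No strategy file to load yet, for example on the first run: load nothing.
    if not os.path.exists(load_path):
        load_path = ""
    # Never write the new strategy over the file that is being loaded.
    if load_path == save_path:
        save_path = f"{save_path}_new"
    return load_path, save_path


with tempfile.TemporaryDirectory() as tmp:
    strategy = os.path.join(tmp, "strategy.ckpt")
    print(resolve_strategy_paths(strategy, strategy) == ("", strategy))  # True: first run
    open(strategy, "w").close()  # the previous run saved a strategy file
    print(resolve_strategy_paths(strategy, strategy) == (strategy, strategy + "_new"))  # True
```

On the first run nothing is loaded and the strategy is saved to the configured path; on a resumed run the existing file is loaded and the new strategy is written to a `_new` sibling, so the file being read is never overwritten.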
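checkpoint_file_sort_key (Pangu-alpha OBS sub-step 4) orders files by the epoch and step in their names rather than by mtime. The worked example below uses a simplified, hypothetical key, `epoch_step_key`, that assumes every name contains `<epoch>_<step>` after its last '-'. The file names follow MindSpore's `<prefix>-<epoch>_<step>.ckpt` naming, with an assumed `_breakpoint` suffix for the dying gasp file.

```python
import re

EPOCH_STEP = re.compile("([0-9]+)_([0-9]+)")


def epoch_step_key(file_name: str):
    """Newest first: larger epoch, then larger step, sorts earlier (hypothetical key)."""
    epoch, step = EPOCH_STEP.search(file_name.split("-")[-1]).groups()
    return -int(epoch), -int(step)


names = ["pangu0-1_100.ckpt", "pangu0-2_50.ckpt", "pangu0-2_100_breakpoint.ckpt", "pangu0-1_20.ckpt"]
print(sorted(names, key=epoch_step_key))
# ['pangu0-2_100_breakpoint.ckpt', 'pangu0-2_50.ckpt', 'pangu0-1_100.ckpt', 'pangu0-1_20.ckpt']
```

Comparing the numbers as integers matters: compared as strings, "1_100" would sort before "1_20".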
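The completeness check at the start of the OBS restore_checkpoint (Pangu-alpha sub-step 4) is hard to follow inside the directory walk. The sketch below restates it over precomputed file sizes: every rank directory must contain a dying gasp checkpoint, each rank's latest dying gasp file must match the size of the oldest checkpoint in that directory, and all ranks' latest dying gasp files must have the same size. `RankFiles` and `breakpoint_is_complete` are hypothetical names; sizes are in bytes.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class RankFiles:
    """Sizes (in bytes) observed in one rank_<id> directory (hypothetical structure)."""
    latest_breakpoint_size: Optional[int]  # None if the rank has no *_breakpoint.ckpt file
    oldest_ckpt_size: Optional[int]        # None if the oldest file is the latest breakpoint itself


def breakpoint_is_complete(ranks: List[RankFiles]) -> bool:
    """Restate the completeness rules from restore_checkpoint (sketch)."""
    if not ranks:
        return False  # No rank directories: the size set below would be empty.
    for rank in ranks:
        # Every rank must have written a dying gasp checkpoint.
        if rank.latest_breakpoint_size is None:
            return False
        # Its size must match the oldest checkpoint kept in the same directory.
        if rank.oldest_ckpt_size is not None and rank.oldest_ckpt_size != rank.latest_breakpoint_size:
            return False
    # All ranks must have written dying gasp checkpoints of the same size.
    return len({rank.latest_breakpoint_size for rank in ranks}) == 1


ranks = [RankFiles(latest_breakpoint_size=1024, oldest_ckpt_size=1024),
         RankFiles(latest_breakpoint_size=1024, oldest_ckpt_size=1024)]
print(breakpoint_is_complete(ranks))  # True
ranks.append(RankFiles(latest_breakpoint_size=512, oldest_ckpt_size=1024))  # truncated upload on one rank
print(breakpoint_is_complete(ranks))  # False
```

The size comparisons are a heuristic taken from the script. They appear to assume that, with integrated_save=False, every rank's checkpoint slice has the same size, so a truncated upload shows up as a size mismatch.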