Script Adaptation
This section uses the ResNet-50 and Pangu-alpha models as examples to describe how to enable resumable training in the ModelArts + MindSpore scenario. If a fault occurs while a foundation model is being trained and the dying gasp checkpoint is saved to OBS, the checkpoint may fail to be saved or may be incomplete due to current ModelArts limitations. In that case, the most recent periodically saved checkpoint is loaded instead to resume training. For how the scripts handle this situation, see 2.d in Step 3 of the Pangu-alpha script adaptation (the restore_checkpoint modification for OBS-stored model files).
ResNet-50
- Download the ResNet code from the r1.7 branch of the MindSpore repository as the training code. The following steps use the CIFAR-10 dataset as an example to describe how to train the model.
- Modify the config.py file in official/cv/resnet/src/model_utils/ in the resnet directory.
# Modify the get_config function and set the value of config_path as follows. parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../config/resnet50_cifar10_config.yaml"), help="Config file path") - Modify the resnet50_cifar10_config.yaml file in official/cv/resnet/config/ of the resnet directory, or configure the following parameters in the Running Parameter window on the Create Training Job page.
enable_modelarts: True # Execute a training job on ModelArts. run_distribute: True # Multi-node training model device_num: 8 # Number of processors used by a single node checkpoint_path: "checkpoint/" # Sub-directory where the model is saved
- Modify the moxing_adapter.py file in official/cv/resnet/src/model_utils/ in the resnet directory.
# Comment out the following code. # if config.train_url: # print("Start to copy output directory") # sync_data(config.output_path, config.train_url) - Modify the train.py file in official/cv/resnet/ in the resnet directory. Perform different script adaptation based on the model file storage type.
- For the EFS-stored model file:
- Add a module to save the dying gasp checkpoint for resumable training.
# ForMA: Save the dying gasp checkpoint. from mindx_elastic.terminating_message import ExceptionCheckpoint
- Add functions.
# ForMA: Add the following functions.
def adaption_for_modelarts_efs():
    """Point the training job at the EFS storage configured on ModelArts.

    ForMA: Use the EFS path configured for the new training job on ModelArts
    (the ``CHECKPOINT_PATH`` environment variable) as the path to save and load
    the model. The dataset is read from ``config.data_url``, to which ModelArts
    automatically synchronizes data from OBS. The checkpoint directory is
    created (mode 0o755) if it does not exist yet.

    Raises:
        RuntimeError: if ``CHECKPOINT_PATH`` is unset or empty.
    """
    checkpoint_path = os.environ.get("CHECKPOINT_PATH")
    # Fail fast: an unset variable would otherwise surface later as an
    # unrelated TypeError from os.path.exists(None).
    if not checkpoint_path:
        raise RuntimeError(
            "CHECKPOINT_PATH is not set; configure an EFS path for the "
            "ModelArts training job.")
    config.output_path = checkpoint_path
    config.pre_trained = checkpoint_path
    # ForMA: ModelArts automatically synchronizes data from OBS to data_url.
    config.data_path = config.data_url
    if not os.path.exists(config.output_path):
        os.makedirs(config.output_path, 0o755, exist_ok=True)
# - Add the following code to the train_net function.
# Delete the decorator moxing_wrapper of the train_net function. ... set_parameter() # ForMA: Configure parameters. adaption_for_modelarts_efs() ckpt_param_dict = load_pre_trained_checkpoint() .... if config.save_checkpoint: # ForMA: Add a callback to save the dying gasp checkpoint and enable MindSpore's saving function upon exceptions. ckpt_append_info = [{"epoch_num": config.has_trained_epoch, "step_num": config.has_trained_step}] config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, keep_checkpoint_max=config.keep_checkpoint_max, append_info=ckpt_append_info, exception_save=True) ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck) excetion_ckpt_cb = ExceptionCheckpoint( prefix="resnet", directory=ckpt_save_dir, config=config_ck ) cb += [ckpt_cb, excetion_ckpt_cb] - Modify the load_pre_trained_checkpoint function.
def load_pre_trained_checkpoint():
    """
    Load checkpoint according to pre_trained path.

    ForMA: when ``config.pre_trained`` is a directory, checkpoints are looked up
    in ``<config.output_path>/<config.checkpoint_path>/ckpt_0``. Dying gasp
    checkpoints (``*_breakpoint.ckpt``) take precedence over the periodically
    saved ones, and the most recently modified file of the chosen group is
    loaded. When ``config.pre_trained`` is a file, that file is loaded directly.

    Returns:
        The parameter dict returned by ``ms.load_checkpoint``, or None when no
        checkpoint was loaded.
    """
    if not config.pre_trained:
        return None

    if os.path.isdir(config.pre_trained):
        search_dir = os.path.join(config.output_path, config.checkpoint_path, "ckpt_0")
        # ForMA: the dying gasp checkpoints win whenever at least one exists.
        candidates = (glob.glob(os.path.join(search_dir, "*_breakpoint.ckpt"))
                      or glob.glob(os.path.join(search_dir, "*.ckpt")))
        if not candidates:
            logger.warning(f"There is no ckpt file in {search_dir}, "
                           f"pre_trained is unsupported.")
            return None
        # max() keeps the first of equally-new files, like a stable reverse sort.
        newest = max(candidates, key=os.path.getmtime)
        stamp = datetime.datetime.now().strftime('%Y.%m.%d-%H:%M:%S')
        print(f"time stamp {stamp}"
              f" pre trained ckpt model {newest} loading", flush=True)
        return ms.load_checkpoint(newest)

    if os.path.isfile(config.pre_trained):
        return ms.load_checkpoint(config.pre_trained)

    print(f"Invalid pre_trained {config.pre_trained} parameter.")
    return None
- Add a module to save the dying gasp checkpoint for resumable training.
- For the OBS-stored model file:
- Add a module to save the dying gasp checkpoint for resumable training.
# ForMA: Save the dying gasp checkpoint. from mindx_elastic.terminating_message import ExceptionCheckpoint
- Add the following code to the train_net function.
# Add the decorator moxing_wrapper to the train_net function. ... set_parameter() # ForMA: Configure parameters. config.pre_trained = config.checkpoint_url config.output_path = config.train_url ckpt_param_dict = load_pre_trained_checkpoint() .... if config.save_checkpoint: # ForMA: Add a callback to save the dying gasp checkpoint and enable MindSpore's saving function upon exceptions. ckpt_append_info = [{"epoch_num": config.has_trained_epoch, "step_num": config.has_trained_step}] config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, keep_checkpoint_max=config.keep_checkpoint_max, append_info=ckpt_append_info, exception_save=True) ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck) excetion_ckpt_cb = ExceptionCheckpoint( prefix="resnet", directory=ckpt_save_dir, config=config_ck ) cb += [ckpt_cb, excetion_ckpt_cb] - Modify the load_pre_trained_checkpoint function.
def load_pre_trained_checkpoint():
    """
    Load checkpoint according to pre_trained path.

    ForMA: when ``config.pre_trained`` (the local copy of checkpoint_url) is a
    directory, checkpoints are looked up in
    ``<config.pre_trained>/<config.checkpoint_path>/ckpt_0``. Dying gasp
    checkpoints (``*_breakpoint.ckpt``) take precedence over the periodically
    saved ones. Within the chosen group, the file with the highest
    ``<epoch>_<step>`` in its name is loaded. The mtime is not used for this
    choice, because files synchronized from OBS to the local host all share
    the same mtime. When ``config.pre_trained`` is a file, that file is loaded
    directly.

    Returns:
        The parameter dict returned by ``ms.load_checkpoint``, or None when no
        checkpoint was loaded.
    """
    import re

    def _newest_first(file_path):
        """Sort key: highest epoch/step first; unparsable names last (newest mtime first)."""
        suffix = os.path.basename(file_path).split("-")[-1]
        match = re.search(r"([0-9]+)_([0-9]+)", suffix)
        if match is None:
            return 1, -os.path.getmtime(file_path)
        return 0, -int(match.group(1)), -int(match.group(2))

    param_dict = None
    if config.pre_trained:
        if os.path.isdir(config.pre_trained):
            # ForMA:
            ckpt_save_dir = os.path.join(config.pre_trained, config.checkpoint_path, "ckpt_0")
            ckpt_files = glob.glob(os.path.join(ckpt_save_dir, "*.ckpt"))
            exception_ckpt_files = glob.glob(os.path.join(ckpt_save_dir, "*_breakpoint.ckpt"))
            # Prefer the dying gasp checkpoints whenever any exist.
            ckpt_files = exception_ckpt_files or ckpt_files
            if not ckpt_files:
                logger.warning(f"There is no ckpt file in {ckpt_save_dir}, "
                               f"pre_trained is unsupported.")
            else:
                ckpt_files.sort(key=_newest_first)
                time_stamp = datetime.datetime.now()
                print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')}"
                      f" pre trained ckpt model {ckpt_files[0]} loading", flush=True)
                param_dict = ms.load_checkpoint(ckpt_files[0])
        elif os.path.isfile(config.pre_trained):
            param_dict = ms.load_checkpoint(config.pre_trained)
        else:
            print(f"Invalid pre_trained {config.pre_trained} parameter.")
    return param_dict
- Add a module to save the dying gasp checkpoint for resumable training.
- For the EFS-stored model file:
Pangu-alpha
- Download the pangu-alpha code of the r1.7 branch from the MindSpore code repository as the training code.
- Modify training parameters. Modify the utils.py file in the pangu_alpha directory official/nlp/pangu_alpha/src/. Change the default values of the following parameters or add new parameters.
per_batch_size: 8 save_checkpoint: True device_num: 8 run_type: train strategy_load_ckpt_path: "./strategy.ckpt" strategy_save_ckpt_path: "./strategy.ckpt" # Add this parameter. param_init_type: "fp16" checkpoint_url: "" # Add the path for storing the OBS-based model file.
- Modify the train.py file in official/nlp/pangu_alpha/ in the pangu_alpha directory. Perform different script adaptation based on the model file storage type.
- For the EFS-stored model file:
- Add a module to save the dying gasp checkpoint for resumable training.
# ForMA: Save the dying gasp checkpoint. from mindx_elastic.terminating_message import ExceptionCheckpoint
- Configure the path for saving the model.
# Add the following code to the script main entry: ... set_parse(opt) # ForMA: Use the EFS path configured for the new training job on ModelArts as the path to save and load the model. opt.save_checkpoint_path = os.environ.get("CHECKPOINT_PATH") opt.pre_trained = os.environ.get("CHECKPOINT_PATH") if not os.path.exists(opt.save_checkpoint_path): os.makedirs(opt.save_checkpoint_path, 0o755, exist_ok=True) ... - Modify the run_train function. (If pipeline parallelism is enabled, that is, opt.stage_num > 1, modify the run_train_pipeline function.)
... if args_opt.parallel_mode == "data_parallel": # in avoid of the loop call depth context.set_context(max_call_depth=10000) # ForMA: Comment out the following code. # env variable prepare # group_info_file = os.getenv("GROUP_INFO_FILE") # if group_info_file: # with open(os.path.expanduser("job/code/group_info_env"), "a") as outfile: # outfile.write(f"export GROUP_INFO_FILE_REFLECT={group_info_file}\n") ... # loss_callback = LossCallBack(step_per_epoch, rank, 0, 0, micro_size=micro_batch_interleaved) callback = [TimeMonitor(args_opt.sink_size)] ... if args_opt.pre_trained: restore_checkpoint(args_opt, args_opt.sink_size, ds, model, pangu_alpha_with_grads, epoch=actual_epoch_num) callback += [LossCallBack(step_per_epoch, rank, args_opt.has_trained_epoches, args_opt.has_trained_steps, micro_size=micro_batch_interleaved)] add_checkpoint_callback_policy(args_opt, callback, rank) ... - Modify the restore_checkpoint function.
... ckpt_exp_pattern = os.path.join(args_param.save_checkpoint_path, "rank_{}".format(D.get_rank()), f"{ckpt_name}*_breakpoint.ckpt") ckpt_exp_files = glob.glob(ckpt_exp_pattern) # ForMA: Comment out the following code. # ckpt_files = [] # for file in ckpt_all_files: # if file not in ckpt_exp_files: # ckpt_files.append(file) # ForMA: If the dying gasp model file exists, use it. Otherwise, use the model file that is periodically saved. ckpt_files = ckpt_exp_files if ckpt_exp_files else ckpt_all_files ... - Modify the add_checkpoint_callback_policy function.
if args_param.save_checkpoint: # ForMA: Add a callback to save the dying gasp checkpoint and enable MindSpore's saving function upon exceptions. # checkpoint store epoch_num and step_num info ckpt_append_info = [{"epoch_num": args_param.has_trained_epoches, "step_num": args_param.has_trained_steps}] ckpt_config = CheckpointConfig(save_checkpoint_steps=args_param.save_checkpoint_steps, keep_checkpoint_max=args_param.keep_checkpoint_max, integrated_save=False, append_info=ckpt_append_info, exception_save=True ) # save checkpoint into rank directory ckpoint_cb = ModelCheckpoint(prefix=args_param.ckpt_name_prefix + str(rank_id), directory=os.path.join(args_param.save_checkpoint_path, f"rank_{rank_id}"), config=ckpt_config) ckpoint_exp = ExceptionCheckpoint( prefix=args_param.ckpt_name_prefix + str(rank_id), directory=os.path.join(args_param.save_checkpoint_path, f"rank_{rank_id}"), config=ckpt_config) callback.append(ckpoint_cb) callback.append(ckpoint_exp) - Modify the set_parallel_context function.
def set_parallel_context(args_opt):
    r"""Set parallel context.

    Calls ``D.init()``, prints this process's rank and the group size, and
    configures semi-auto-parallel training. Two fields of ``args_opt`` are
    adjusted in place beforehand:

    * ``strategy_load_ckpt_path`` is cleared when that path does not exist;
    * ``strategy_save_ckpt_path`` gets a ``_new`` suffix when it equals the
      (possibly cleared) load path, presumably so that the loaded strategy file
      is not overwritten.

    Returns:
        tuple: ``(rank, device_num)``.
    """
    D.init()
    device_num = D.get_group_size()
    rank = D.get_rank()
    print("rank_id is {}, device_num is {}".format(rank, device_num))

    if not os.path.exists(args_opt.strategy_load_ckpt_path):
        args_opt.strategy_load_ckpt_path = ""
    same_file = args_opt.strategy_load_ckpt_path == args_opt.strategy_save_ckpt_path
    if same_file:
        args_opt.strategy_save_ckpt_path = "{}_new".format(args_opt.strategy_save_ckpt_path)

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(
        parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
        gradients_mean=False,
        full_batch=bool(args_opt.full_batch),
        enable_parallel_optimizer=bool(args_opt.optimizer_shard),
        strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path,
        strategy_ckpt_save_file=args_opt.strategy_save_ckpt_path,
    )
    set_algo_parameters(elementwise_op_strategy_follow=True)
    _set_multi_subgraphs()
    return rank, device_num
- Add a module to save the dying gasp checkpoint for resumable training.
- For the OBS-stored model file:
- Add a module to save the dying gasp checkpoint for resumable training.
# ForMA: Save the dying gasp checkpoint. from mindx_elastic.terminating_message import ExceptionCheckpoint
- Configure the path for saving the model.
# Add the following code to the script main entry: ... set_parse(opt) # ForMA: Add checkpoint_url to Training Input on ModelArts. The platform automatically downloads the model file to a local path. opt.save_checkpoint_path = opt.checkpoint_url opt.pre_trained = opt.checkpoint_url ...
- Modify the run_train function. (If pipeline parallelism is enabled, that is, opt.stage_num > 1, modify the run_train_pipeline function.)
... if args_opt.parallel_mode == "data_parallel": # in avoid of the loop call depth context.set_context(max_call_depth=10000) # ForMA: Comment out the following code. # env variable prepare # group_info_file = os.getenv("GROUP_INFO_FILE") # if group_info_file: # with open(os.path.expanduser("job/code/group_info_env"), "a") as outfile: # outfile.write(f"export GROUP_INFO_FILE_REFLECT={group_info_file}\n") ... # loss_callback = LossCallBack(step_per_epoch, rank, 0, 0, micro_size=micro_batch_interleaved) callback = [TimeMonitor(args_opt.sink_size)] ... if args_opt.pre_trained: restore_checkpoint(args_opt, args_opt.sink_size, ds, model, pangu_alpha_with_grads, epoch=actual_epoch_num) callback += [LossCallBack(step_per_epoch, rank, args_opt.has_trained_epoches, args_opt.has_trained_steps, micro_size=micro_batch_interleaved)] add_checkpoint_callback_policy(args_opt, callback, rank) ... - Modify the restore_checkpoint function and add the checkpoint_file_sort_key function.
import re


def checkpoint_file_sort_key(file_path):
    """Sort key that orders checkpoint files from newest to oldest.

    ForMA: When the OBS files are synchronized to the local host, the mtime of
    each file is the same and cannot be used for file sorting. Use the epoch and
    step information contained in the file name
    (``<prefix>-<epoch>_<step>[_breakpoint].ckpt``) to sort files in descending
    order, i.e. ``sorted(files, key=checkpoint_file_sort_key)[0]`` is the latest.

    Every key is a tuple whose first element ranks the kind of name, so keys of
    parsable and unparsable names are always comparable with each other
    (mixing a bare float with a tuple would make ``sort`` raise TypeError):

    * ``(0, -epoch, -step)`` for names containing ``<epoch>_<step>``;
    * ``(1, -mtime)`` for any other name. These sort after all parsable names,
      newest modification time first.
    """
    _, file_name = os.path.split(file_path)
    epoch_step_suffix = file_name.split("-")[-1]
    result = re.search(r"([0-9]+)_([0-9]+)", epoch_step_suffix)
    if result is None:
        return 1, -os.path.getmtime(file_path)
    epoch, step = result.groups()
    return 0, -int(epoch), -int(step)


def _exception_checkpoints_complete(save_checkpoint_path):
    """Return True when the dying gasp checkpoints of all ranks look complete.

    Each direct sub-directory of ``save_checkpoint_path`` (one per rank) must
    contain at least one ``*_breakpoint.ckpt`` file. Its latest breakpoint file
    must be either the oldest checkpoint in that directory or the same size as
    that oldest checkpoint. In addition, the latest breakpoint files of all
    ranks must share a single size. A missing directory, or one without
    sub-directories, counts as incomplete.
    """
    if not os.path.isdir(save_checkpoint_path):
        return False
    latest_exp_sizes = []
    # Only the direct sub-directories are inspected. An os.walk loop that joins
    # nested directory names to the top-level path would wrongly mark the
    # checkpoints incomplete whenever a rank directory has a sub-directory.
    for entry in sorted(os.listdir(save_checkpoint_path)):
        rank_dir = os.path.join(save_checkpoint_path, entry)
        if not os.path.isdir(rank_dir):
            continue
        exp_files = glob.glob(os.path.join(rank_dir, "*_breakpoint.ckpt"))
        if not exp_files:
            return False
        exp_files.sort(key=checkpoint_file_sort_key)
        latest_exp = exp_files[0]
        latest_exp_size = os.path.getsize(latest_exp)
        # Size comparison against a fully written checkpoint is the heuristic
        # used to detect a partially written dying gasp file.
        all_files = sorted(glob.glob(os.path.join(rank_dir, "*.ckpt")),
                           key=checkpoint_file_sort_key)
        oldest_ckpt = all_files[-1]
        if latest_exp != oldest_ckpt and latest_exp_size != os.path.getsize(oldest_ckpt):
            return False
        latest_exp_sizes.append(latest_exp_size)
    return len(set(latest_exp_sizes)) == 1


def restore_checkpoint(args_param, sink_size, dataset, model, network, epoch):
    r"""
    Load checkpoint process.

    ForMA: If the dying gasp checkpoints of all ranks look complete, load the
    latest dying gasp checkpoint of the current rank. Otherwise, fall back to
    the latest periodically saved checkpoint. Epoch/step counters stored in the
    checkpoint (if present) are written to ``args_param.has_trained_epoches``
    and ``args_param.has_trained_steps``.
    """
    # ForMA: Check whether the dying gasp checkpoint files are complete.
    is_complete = _exception_checkpoints_complete(args_param.save_checkpoint_path)
    print("======start single checkpoint", flush=True)
    ckpt_name = args_param.ckpt_name_prefix
    ckpt_pattern = os.path.join(args_param.save_checkpoint_path, "rank_{}".format(D.get_rank()),
                                f"{ckpt_name}*.ckpt")
    ckpt_all_files = glob.glob(ckpt_pattern)
    if not ckpt_all_files:
        print(f"There is no ckpt file in {args_param.save_checkpoint_path}, "
              f"current ckpt_files found is {ckpt_all_files} "
              f"with pattern {ckpt_pattern}, so skip the loading.")
        return

    ckpt_all_files.sort(key=checkpoint_file_sort_key)
    if is_complete:
        candidates = [ckpt for ckpt in ckpt_all_files if ckpt.endswith("breakpoint.ckpt")]
    else:
        candidates = [ckpt for ckpt in ckpt_all_files if "breakpoint" not in ckpt]
    if not candidates:
        print(f"There is no ckpt file in {args_param.save_checkpoint_path}, "
              f"current ckpt_files found is {ckpt_all_files} "
              f"with pattern {ckpt_pattern}, so skip the loading.")
        return
    ckpt_file = candidates[0]

    time_stamp = datetime.datetime.now()
    print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} pre trained ckpt model {ckpt_file} loading",
          flush=True)
    # Load checkpoint files latest file
    print(f'Start to load from {ckpt_file}')
    param_dict = load_checkpoint(ckpt_file)
    if param_dict.get("epoch_num") and param_dict.get("step_num"):
        args_param.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy())
        # has_trained_steps is the attribute read by LossCallBack and
        # add_checkpoint_callback_policy in this script.
        args_param.has_trained_steps = int(param_dict["step_num"].data.asnumpy())
    model.build(train_dataset=dataset, sink_size=sink_size, epoch=epoch)
    load_param_into_net(network, param_dict)
# - Modify the add_checkpoint_callback_policy function.
if args_param.save_checkpoint: # ForMA: Add a callback to save the dying gasp checkpoint and enable MindSpore's saving function upon exceptions. # checkpoint store epoch_num and step_num info ckpt_append_info = [{"epoch_num": args_param.has_trained_epoches, "step_num": args_param.has_trained_steps}] ckpt_config = CheckpointConfig(save_checkpoint_steps=args_param.save_checkpoint_steps, keep_checkpoint_max=args_param.keep_checkpoint_max, integrated_save=False, append_info=ckpt_append_info, exception_save=True ) # save checkpoint into rank directory ckpoint_cb = ModelCheckpoint(prefix=args_param.ckpt_name_prefix + str(rank_id), directory=os.path.join(args_param.train_url, f"rank_{rank_id}"), config=ckpt_config) ckpoint_exp = ExceptionCheckpoint( prefix=args_param.ckpt_name_prefix + str(rank_id), directory=os.path.join(args_param.train_url, f"rank_{rank_id}"), config=ckpt_config) callback.append(ckpoint_cb) callback.append(ckpoint_exp) - Modify the set_parallel_context function.
def set_parallel_context(args_opt):
    r"""Set parallel context.

    Calls ``D.init()``, prints this process's rank and the group size, and
    configures semi-auto-parallel training. Two fields of ``args_opt`` are
    adjusted in place beforehand:

    * ``strategy_load_ckpt_path`` is cleared when that path does not exist;
    * ``strategy_save_ckpt_path`` gets a ``_new`` suffix when it equals the
      (possibly cleared) load path, presumably so that the loaded strategy file
      is not overwritten.

    Returns:
        tuple: ``(rank, device_num)``.
    """
    D.init()
    device_num = D.get_group_size()
    rank = D.get_rank()
    print("rank_id is {}, device_num is {}".format(rank, device_num))

    if not os.path.exists(args_opt.strategy_load_ckpt_path):
        args_opt.strategy_load_ckpt_path = ""
    same_file = args_opt.strategy_load_ckpt_path == args_opt.strategy_save_ckpt_path
    if same_file:
        args_opt.strategy_save_ckpt_path = "{}_new".format(args_opt.strategy_save_ckpt_path)

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(
        parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
        gradients_mean=False,
        full_batch=bool(args_opt.full_batch),
        enable_parallel_optimizer=bool(args_opt.optimizer_shard),
        strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path,
        strategy_ckpt_save_file=args_opt.strategy_save_ckpt_path,
    )
    set_algo_parameters(elementwise_op_strategy_follow=True)
    _set_multi_subgraphs()
    return rank, device_num
- Add a module to save the dying gasp checkpoint for resumable training.
- For the EFS-stored model file:
Parent topic: ModelArts Scenario