本节以Resnet50和Pangu-alpha模型为例,展示如何在ModelArts平台+MindSpore框架下开启断点续训功能。受限于ModelArts平台当前版本设计,大模型场景下发生故障时,若通过OBS保存临终checkpoint,可能存在文件保存失败或不完整的情况。在该情况下,会加载周期性checkpoint继续训练。该问题的解决方案可以参考pangu-alpha脚本适配步骤3中的2.d部分。
# 修改get_config函数,为config_path参数添加如下默认值 parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../config/resnet50_cifar10_config.yaml"), help="Config file path")
enable_modelarts: True          # run the training job on the ModelArts platform
run_distribute: True            # multi-node distributed training (fix: YAML requires a space after the colon)
device_num: 8                   # number of devices (NPUs) used per node
checkpoint_path: "checkpoint/"  # sub-directory for saving model checkpoints
# Comment out the following code (syncing the output directory back to OBS is no longer needed):
# if config.train_url:
#     print("Start to copy output directory")
#     sync_data(config.output_path, config.train_url)
# ForMA: 引入临终checkpoint保存 from mindx_elastic.terminating_message import ExceptionCheckpoint
# ForMA: add the following new function.
def adaption_for_modelarts_efs():
    """Point model save/load paths at the EFS path of the ModelArts training job.

    Reads the CHECKPOINT_PATH environment variable (configured when the
    ModelArts training job was created) and uses it both as the output
    directory and as the pre-trained checkpoint search root. ModelArts
    automatically syncs the dataset from OBS into ``config.data_url``.
    """
    # ForMA: use the EFS path configured for the ModelArts training job.
    config.output_path = os.environ.get("CHECKPOINT_PATH")
    config.pre_trained = os.environ.get("CHECKPOINT_PATH")
    # ForMA: ModelArts has already synced the data from OBS to data_url.
    config.data_path = config.data_url
    # Fix: exist_ok=True already tolerates a pre-existing directory, so the
    # former os.path.exists() pre-check was redundant and race-prone.
    os.makedirs(config.output_path, 0o755, exist_ok=True)
# 删除train_net函数的装饰器“moxing_wrapper” … set_parameter() # ForMA: 参数配置 adaption_for_modelarts_efs() ckpt_param_dict = load_pre_trained_checkpoint() …. if config.save_checkpoint: # ForMA: 添加临终checkpoint保存回调,开启mindspore异常保存功能 ckpt_append_info = [{"epoch_num": config.has_trained_epoch, "step_num": config.has_trained_step}] config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, keep_checkpoint_max=config.keep_checkpoint_max, append_info=ckpt_append_info, exception_save=True) ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck) excetion_ckpt_cb = ExceptionCheckpoint( prefix="resnet", directory=ckpt_save_dir, config=config_ck ) cb += [ckpt_cb, excetion_ckpt_cb]
def load_pre_trained_checkpoint():
    """
    Load checkpoint according to pre_trained path.

    Returns the loaded parameter dict, or None when no usable checkpoint
    is found. Terminating ("*_breakpoint.ckpt") checkpoints take precedence
    over periodic ones; among candidates the most recently modified wins.
    """
    if not config.pre_trained:
        return None

    if os.path.isfile(config.pre_trained):
        # A concrete checkpoint file was given: load it directly.
        return ms.load_checkpoint(config.pre_trained)

    if not os.path.isdir(config.pre_trained):
        print(f"Invalid pre_trained {config.pre_trained} parameter.")
        return None

    ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path, "ckpt_0")
    periodic_files = glob.glob(os.path.join(ckpt_save_dir, "*.ckpt"))
    # ForMA: prefer terminating checkpoints saved on failure, if any exist.
    terminating_files = glob.glob(os.path.join(ckpt_save_dir, "*_breakpoint.ckpt"))
    candidates = terminating_files if terminating_files else periodic_files
    if not candidates:
        logger.warning(f"There is no ckpt file in {ckpt_save_dir}, "
                       f"pre_trained is unsupported.")
        return None

    newest = max(candidates, key=os.path.getmtime)
    time_stamp = datetime.datetime.now()
    print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')}"
          f" pre trained ckpt model {newest} loading", flush=True)
    return ms.load_checkpoint(newest)
# ForMA: 引入临终checkpoint保存 from mindx_elastic.terminating_message import ExceptionCheckpoint
# 为train_net函数添加装饰器“moxing_wrapper” … set_parameter() # ForMA: 参数配置 config.pre_trained = config.checkpoint_url config.output_path = config.train_url ckpt_param_dict = load_pre_trained_checkpoint() …. if config.save_checkpoint: # ForMA: 添加临终checkpoint保存回调,开启mindspore异常保存功能 ckpt_append_info = [{"epoch_num": config.has_trained_epoch, "step_num": config.has_trained_step}] config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, keep_checkpoint_max=config.keep_checkpoint_max, append_info=ckpt_append_info, exception_save=True) ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck) excetion_ckpt_cb = ExceptionCheckpoint( prefix="resnet", directory=ckpt_save_dir, config=config_ck ) cb += [ckpt_cb, excetion_ckpt_cb]
def load_pre_trained_checkpoint():
    """
    Load checkpoint according to pre_trained path.
    """
    param_dict = None
    if config.pre_trained:
        if os.path.isdir(config.pre_trained):
            # ForMA: search inside the directory ModelArts downloaded from the
            # checkpoint_url training input.
            ckpt_save_dir = os.path.join(config.pre_trained, config.checkpoint_path, "ckpt_0")
            ckpt_pattern = os.path.join(ckpt_save_dir, "*.ckpt")
            ckpt_files = glob.glob(ckpt_pattern)
            # Terminating ("breakpoint") checkpoints take precedence over
            # periodic ones.
            exception_ckpt_pattern = os.path.join(ckpt_save_dir, "*_breakpoint.ckpt")
            exception_ckpt_files = glob.glob(exception_ckpt_pattern)
            ckpt_files = exception_ckpt_files if exception_ckpt_files else ckpt_files
            if not ckpt_files:
                logger.warning(f"There is no ckpt file in {ckpt_save_dir}, "
                               f"pre_trained is unsupported.")
            else:
                # Most recently modified checkpoint first.
                ckpt_files.sort(key=os.path.getmtime, reverse=True)
                time_stamp = datetime.datetime.now()
                print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')}"
                      f" pre trained ckpt model {ckpt_files[0]} loading",
                      flush=True)
                param_dict = ms.load_checkpoint(ckpt_files[0])
        elif os.path.isfile(config.pre_trained):
            param_dict = ms.load_checkpoint(config.pre_trained)
        else:
            print(f"Invalid pre_trained {config.pre_trained} parameter.")
    return param_dict
per_batch_size: 8
save_checkpoint: True
device_num: 8
run_type: train
strategy_load_ckpt_path: "./strategy.ckpt"
strategy_save_ckpt_path: "./strategy.ckpt"  # newly added parameter
param_init_type: "fp16"
checkpoint_url: ""  # newly added parameter: OBS path of the stored model files
# ForMA: 引入临终checkpoint保存 from mindx_elastic.terminating_message import ExceptionCheckpoint
# 脚本main入口添加如下代码 … set_parse(opt) # ForMA: 使用ModelArts新建训练任务所配置的efs路径作为模型保存与加载路径 opt.save_checkpoint_path = os.environ.get("CHECKPOINT_PATH") opt.pre_trained = os.environ.get("CHECKPOINT_PATH") if not os.path.exists(opt.save_checkpoint_path): os.makedirs(opt.save_checkpoint_path, 0o755, exist_ok=True) …
…
if args_opt.parallel_mode == "data_parallel":
    # in avoid of the loop call depth
    context.set_context(max_call_depth=10000)
# ForMA: comment out the following code.
# env variable prepare
# group_info_file = os.getenv("GROUP_INFO_FILE")
# if group_info_file:
#     with open(os.path.expanduser("job/code/group_info_env"), "a") as outfile:
#         outfile.write(f"export GROUP_INFO_FILE_REFLECT={group_info_file}\n")
…
# loss_callback = LossCallBack(step_per_epoch, rank, 0, 0, micro_size=micro_batch_interleaved)
callback = [TimeMonitor(args_opt.sink_size)]
…
if args_opt.pre_trained:
    # Resume: restore model/optimizer state before building the callbacks so
    # the resumed epoch/step counts feed into LossCallBack.
    restore_checkpoint(args_opt, args_opt.sink_size, ds, model,
                       pangu_alpha_with_grads, epoch=actual_epoch_num)
callback += [LossCallBack(step_per_epoch, rank,
                          args_opt.has_trained_epoches,
                          args_opt.has_trained_steps,
                          micro_size=micro_batch_interleaved)]
add_checkpoint_callback_policy(args_opt, callback, rank)
…
…
ckpt_exp_pattern = os.path.join(args_param.save_checkpoint_path,
                                "rank_{}".format(D.get_rank()),
                                f"{ckpt_name}*_breakpoint.ckpt")
ckpt_exp_files = glob.glob(ckpt_exp_pattern)
# ForMA: comment out the following code.
# ckpt_files = []
# for file in ckpt_all_files:
#     if file not in ckpt_exp_files:
#         ckpt_files.append(file)
# ForMA: if terminating checkpoint files exist use them, otherwise fall back
# to the periodically saved checkpoint files.
ckpt_files = ckpt_exp_files if ckpt_exp_files else ckpt_all_files
…
if args_param.save_checkpoint:
    # ForMA: add the terminating-checkpoint callback and enable MindSpore
    # exception saving (exception_save=True).
    # checkpoint store epoch_num and step_num info
    ckpt_append_info = [{"epoch_num": args_param.has_trained_epoches,
                         "step_num": args_param.has_trained_steps}]
    ckpt_config = CheckpointConfig(save_checkpoint_steps=args_param.save_checkpoint_steps,
                                   keep_checkpoint_max=args_param.keep_checkpoint_max,
                                   integrated_save=False,
                                   append_info=ckpt_append_info,
                                   exception_save=True
                                   )
    # save checkpoint into rank directory
    ckpoint_cb = ModelCheckpoint(prefix=args_param.ckpt_name_prefix + str(rank_id),
                                 directory=os.path.join(args_param.save_checkpoint_path,
                                                        f"rank_{rank_id}"),
                                 config=ckpt_config)
    # NOTE(review): "ckpoint" is a typo for "ckpt"/"checkpoint" — rename when
    # convenient.
    ckpoint_exp = ExceptionCheckpoint(
        prefix=args_param.ckpt_name_prefix + str(rank_id),
        directory=os.path.join(args_param.save_checkpoint_path,
                               f"rank_{rank_id}"),
        config=ckpt_config)
    callback.append(ckpoint_cb)
    callback.append(ckpoint_exp)
def set_parallel_context(args_opt):
    r"""Set parallel context.

    Initializes the communication group, configures MindSpore semi-auto
    parallel mode and returns ``(rank, device_num)``.
    """
    D.init()
    rank_id = D.get_rank()
    world_size = D.get_group_size()
    print("rank_id is {}, device_num is {}".format(rank_id, world_size))

    # Fall back to "no strategy file" when the configured one does not exist.
    if not os.path.exists(args_opt.strategy_load_ckpt_path):
        args_opt.strategy_load_ckpt_path = ""
    # Never load and save through the same file: redirect the save path.
    if args_opt.strategy_load_ckpt_path == args_opt.strategy_save_ckpt_path:
        args_opt.strategy_save_ckpt_path = f"{args_opt.strategy_save_ckpt_path}_new"

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(
        parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
        gradients_mean=False,
        full_batch=bool(args_opt.full_batch),
        strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path,
        enable_parallel_optimizer=bool(args_opt.optimizer_shard),
        strategy_ckpt_save_file=args_opt.strategy_save_ckpt_path,
    )
    set_algo_parameters(elementwise_op_strategy_follow=True)
    _set_multi_subgraphs()
    return rank_id, world_size
# ForMA: 引入临终checkpoint保存 from mindx_elastic.terminating_message import ExceptionCheckpoint
# 脚本main入口添加如下代码 … set_parse(opt) # ForMA: ModelArts添加“训练输入”参数checkpoint_url,平台会自动下载模型文件至本地路径 opt.save_checkpoint_path = opt.checkpoint_url opt.pre_trained = opt.checkpoint_url …
…
if args_opt.parallel_mode == "data_parallel":
    # in avoid of the loop call depth
    context.set_context(max_call_depth=10000)
# ForMA: comment out the following code.
# env variable prepare
# group_info_file = os.getenv("GROUP_INFO_FILE")
# if group_info_file:
#     with open(os.path.expanduser("job/code/group_info_env"), "a") as outfile:
#         outfile.write(f"export GROUP_INFO_FILE_REFLECT={group_info_file}\n")
…
# loss_callback = LossCallBack(step_per_epoch, rank, 0, 0, micro_size=micro_batch_interleaved)
callback = [TimeMonitor(args_opt.sink_size)]
…
if args_opt.pre_trained:
    # Resume: restore model/optimizer state before building the callbacks so
    # the resumed epoch/step counts feed into LossCallBack.
    restore_checkpoint(args_opt, args_opt.sink_size, ds, model,
                       pangu_alpha_with_grads, epoch=actual_epoch_num)
callback += [LossCallBack(step_per_epoch, rank,
                          args_opt.has_trained_epoches,
                          args_opt.has_trained_steps,
                          micro_size=micro_batch_interleaved)]
add_checkpoint_callback_policy(args_opt, callback, rank)
…
import re


def checkpoint_file_sort_key(file_path):
    """Sort key that orders checkpoint files newest-first.

    ForMA: files synced from OBS to local disk all share the same mtime, so
    mtime alone cannot order them. Prefer the "<epoch>_<step>" numbers
    embedded in the file name (descending); fall back to descending mtime
    when the name carries no epoch/step information.

    Fix: both branches now return tuples, so a list mixing parseable and
    unparseable names sorts without a TypeError (the original returned a
    bare float in the fallback branch). Parseable names sort first.
    """
    pattern = re.compile("([0-9]+)_([0-9]+)")
    _, file_name = os.path.split(file_path)
    epoch_step_suffix = file_name.split("-")[-1]
    result = pattern.search(epoch_step_suffix)
    if result is None:
        return (1, -os.path.getmtime(file_path))
    epoch, step = result.groups()
    return (0, -int(epoch), -int(step))


def restore_checkpoint(args_param, sink_size, dataset, model, network, epoch):
    r"""
    Load checkpoint process.

    Checks whether the terminating ("*_breakpoint.ckpt") checkpoints written
    by all ranks are complete; if so, resumes from the terminating checkpoint,
    otherwise from the newest periodic checkpoint. Updates
    ``args_param.has_trained_epoches`` / ``has_trained_steps`` from the
    checkpoint's appended info, then loads the parameters into ``network``.
    """
    # ForMA: verify the terminating checkpoint files are complete.
    is_complete = True
    latest_exp_ckpt_size_list = []
    for _dirpath, dirnames, _filenames in os.walk(args_param.save_checkpoint_path):
        if not is_complete:
            break
        # NOTE(review): subdirs are joined against the walk root, so only the
        # first directory level (the rank_* dirs) is examined meaningfully —
        # confirm no deeper nesting is expected.
        for subdir in dirnames:
            candi_dirs = os.path.join(args_param.save_checkpoint_path, subdir)
            ckpt_exp_pattern = os.path.join(candi_dirs, "*_breakpoint.ckpt")
            ckpt_pattern = os.path.join(candi_dirs, "*.ckpt")
            ckpt_all_files = glob.glob(ckpt_pattern)
            ckpt_exp_files = glob.glob(ckpt_exp_pattern)
            if not ckpt_exp_files:
                # A rank without a terminating checkpoint means the final save
                # was interrupted.
                is_complete = False
                break
            ckpt_exp_files.sort(key=checkpoint_file_sort_key)
            latest_exp_ckpt = ckpt_exp_files[0]
            latest_exp_ckpt_size = os.path.getsize(latest_exp_ckpt)
            latest_exp_ckpt_size_list.append(latest_exp_ckpt_size)
            if ckpt_all_files:
                ckpt_all_files.sort(key=checkpoint_file_sort_key)
                oldest_ckpt = ckpt_all_files[-1]
                # Heuristic completeness check: the terminating checkpoint
                # should be the same size as a fully written periodic one.
                if latest_exp_ckpt != oldest_ckpt and latest_exp_ckpt_size != os.path.getsize(oldest_ckpt):
                    is_complete = False
    # All ranks must have produced terminating checkpoints of identical size.
    if is_complete and len(set(latest_exp_ckpt_size_list)) != 1:
        is_complete = False

    print("======start single checkpoint", flush=True)
    ckpt_name = args_param.ckpt_name_prefix
    ckpt_pattern = os.path.join(args_param.save_checkpoint_path,
                                "rank_{}".format(D.get_rank()),
                                f"{ckpt_name}*.ckpt")
    ckpt_all_files = glob.glob(ckpt_pattern)
    if not ckpt_all_files:
        print(f"There is no ckpt file in {args_param.save_checkpoint_path}, "
              f"current ckpt_files found is {ckpt_all_files} "
              f"with pattern {ckpt_pattern}, so skip the loading.")
        return

    ckpt_all_files.sort(key=checkpoint_file_sort_key)
    ckpt_file = ""
    for ckpt in ckpt_all_files:
        # Complete terminating checkpoints take precedence; otherwise use the
        # newest periodic (non-breakpoint) checkpoint.
        if is_complete and ckpt.endswith("breakpoint.ckpt"):
            ckpt_file = ckpt
            break
        if not is_complete and "breakpoint" not in ckpt:
            ckpt_file = ckpt
            break
    if not ckpt_file:
        print(f"There is no ckpt file in {args_param.save_checkpoint_path}, "
              f"current ckpt_files found is {ckpt_all_files} "
              f"with pattern {ckpt_pattern}, so skip the loading.")
        return

    time_stamp = datetime.datetime.now()
    print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} pre trained ckpt model {ckpt_file} loading",
          flush=True)
    # Load checkpoint files latest file
    print(f'Start to load from {ckpt_file}')
    param_dict = load_checkpoint(ckpt_file)
    if param_dict.get("epoch_num") and param_dict.get("step_num"):
        args_param.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy())
        # Fix: was "has_trained_step" (singular); every reader in this script
        # (LossCallBack, add_checkpoint_callback_policy) uses the plural
        # "has_trained_steps", so the resumed step count was silently lost.
        args_param.has_trained_steps = int(param_dict["step_num"].data.asnumpy())
    model.build(train_dataset=dataset, sink_size=sink_size, epoch=epoch)
    load_param_into_net(network, param_dict)
if args_param.save_checkpoint:
    # ForMA: add the terminating-checkpoint callback and enable MindSpore
    # exception saving (exception_save=True).
    # checkpoint store epoch_num and step_num info
    ckpt_append_info = [{"epoch_num": args_param.has_trained_epoches,
                         "step_num": args_param.has_trained_steps}]
    ckpt_config = CheckpointConfig(save_checkpoint_steps=args_param.save_checkpoint_steps,
                                   keep_checkpoint_max=args_param.keep_checkpoint_max,
                                   integrated_save=False,
                                   append_info=ckpt_append_info,
                                   exception_save=True
                                   )
    # save checkpoint into rank directory
    # NOTE(review): train_url appears to be the OBS output path configured on
    # ModelArts — confirm against the job configuration.
    ckpoint_cb = ModelCheckpoint(prefix=args_param.ckpt_name_prefix + str(rank_id),
                                 directory=os.path.join(args_param.train_url,
                                                        f"rank_{rank_id}"),
                                 config=ckpt_config)
    ckpoint_exp = ExceptionCheckpoint(
        prefix=args_param.ckpt_name_prefix + str(rank_id),
        directory=os.path.join(args_param.train_url,
                               f"rank_{rank_id}"),
        config=ckpt_config)
    callback.append(ckpoint_cb)
    callback.append(ckpoint_exp)
def set_parallel_context(args_opt):
    r"""Set parallel context.

    Initializes the communication group, configures MindSpore semi-auto
    parallel mode and returns ``(rank, device_num)``.
    """
    D.init()
    device_num = D.get_group_size()
    rank = D.get_rank()
    print("rank_id is {}, device_num is {}".format(rank, device_num))
    # Fall back to "no strategy file" when the configured one does not exist.
    if not os.path.exists(args_opt.strategy_load_ckpt_path):
        args_opt.strategy_load_ckpt_path = ""
    # Never load and save through the same file: redirect the save path.
    if args_opt.strategy_load_ckpt_path == args_opt.strategy_save_ckpt_path:
        args_opt.strategy_save_ckpt_path = f"{args_opt.strategy_save_ckpt_path}_new"
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(
        parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False,
        full_batch=bool(args_opt.full_batch),
        strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path,
        enable_parallel_optimizer=bool(args_opt.optimizer_shard),
        strategy_ckpt_save_file=args_opt.strategy_save_ckpt_path)
    set_algo_parameters(elementwise_op_strategy_follow=True)
    _set_multi_subgraphs()
    return rank, device_num