Script Adaptation

This section provides fault recovery script adaptation examples. Select the adaptation example that matches your actual situation.

The sample model code below may differ from the actual version; the code of the actual version takes precedence.

PyTorch Fault Recovery Example

  1. Download “ResNet50_ID4149_for_PyTorch” from the master branch of the PyTorch model repository as the training code.
  2. Prepare the dataset for ResNet-50 yourself and comply with the applicable terms when using it.
  3. As the administrator, upload the dataset to the storage node.

    1. Go to the “/data/atlas_dls/public” directory and upload the dataset to any location, for example, “/data/atlas_dls/public/dataset/resnet50/imagenet”.
      root@ubuntu:/data/atlas_dls/public/dataset/resnet50/imagenet# pwd
      Example output:
      /data/atlas_dls/public/dataset/resnet50/imagenet
      
    2. Run the du -sh command to check the dataset size.
      root@ubuntu:/data/atlas_dls/public/dataset/resnet50/imagenet# du -sh
      Example output:
      11G
      

  4. Extract the training code downloaded in 1 locally, and upload the “ModelZoo-PyTorch/PyTorch/built-in/cv/classification/ResNet50_ID4149_for_PyTorch” directory from the extracted training code to the environment, for example, to the “/data/atlas_dls/public/code/” directory.
  5. Go to the “MindXDL-deploy” repository, switch to the branch matching your version according to the MindXDL-deploy release notes, and obtain the train_start.sh, utils.sh, and rank_table.sh files from the “samples/train/resumable-training/fault-rescheduling/withRanktable/pytorch/resnet50” directory. Create a “scripts” directory in the training code and build the following directory structure on the management node.

    root@ubuntu:/data/atlas_dls/public/code/ResNet50_ID4149_for_PyTorch/scripts/#
    scripts/
    ├── rank_table.sh
    ├── utils.sh
    └── train_start.sh

  6. Modify main.py under the “/data/atlas_dls/public/code/ResNet50_ID4149_for_PyTorch” path as shown below; the changes adjust the model saving and loading logic.

    import argparse
    import glob
    import os
    ...
        if args.resume:
            candidate_ckpt_path = ""
            for p in glob.glob("./rank*"):
                best_ckpt_path = os.path.join(p, "model_best.pth.tar")
                if os.path.exists(best_ckpt_path):
                    candidate_ckpt_path = best_ckpt_path
                    break
            if candidate_ckpt_path:
                print("[gpu id:", args.gpu, "]", "=> loading checkpoint '{}'".format(candidate_ckpt_path))
                # Map model to be loaded to specified single gpu.
                loc = 'npu:{}'.format(args.gpu)
                checkpoint = torch.load(candidate_ckpt_path, map_location=loc)
                print(f"load checkpoint to : {loc}")
                args.start_epoch = checkpoint['epoch']
                best_acc1 = checkpoint['best_acc1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("[gpu id:", args.gpu, "]", "=> loaded checkpoint '{}' (epoch {})".format(candidate_ckpt_path, checkpoint['epoch']))
            else:
                print("no valid ckpt found to resume.")
    ...
            if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
                save_path = f"./rank_{args.rank}"
                if not os.path.exists(save_path):
                    os.makedirs(save_path, exist_ok=True)
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, save_path=save_path)
    ...
    ...
    # Modify the original save_checkpoint function
    def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', save_path="./"):
        if is_best:
            target_path = os.path.join(save_path, 'model_best.pth.tar')
            torch.save(state, target_path)
            print(f"save ckpt to {target_path} done. Best epoch for now is :{state['epoch']}")

MindSpore Fault Recovery Example

  1. Download the master branch code of the MindSpore models repository, rename the “models/official/cv/ResNet” directory to “resnet”, and use it as the training code.
  2. Run the following command to create the code directory on the management node, then upload the training code to that directory.

    mkdir /data/atlas_dls/code

  3. Go to the “MindXDL-deploy” repository, switch to the branch matching your version according to the MindXDL-deploy release notes, and obtain the “train_start.sh” and “main.sh” files from the “samples/train/resumable-training/fault-rescheduling/withRanktable/mindspore/resnet50” directory. Combine them with the “resnet/scripts” directory in the training code to build the following directory structure on the management node.

    root@ubuntu:/data/atlas_dls/public/code/resnet/scripts/#
    scripts/
    ├── main.sh
     ...
    ├── run_distribute_train.sh
    ├── run_distribute_train_gpu.sh
    └── train_start.sh

  4. Modify the “train_start.sh” file in the “/data/atlas_dls/public/code/resnet/scripts” directory.

    1. Change “dataset_path” to the actual dataset directory inside the container.
    2. Change “config_yaml_path” to the actual configuration file path inside the container.
    Modify these according to the actual situation. Global configuration parameters: the dataset path and the configuration file path. When adapting other models, add or remove parameters as required.
    dataset_path=/job/data/imagenet/train
    config_yaml_path=/job/code/resnet/resnet50_imagenet2012_config.yaml

    The train_start.sh script starts the training job by calling the main.sh script. When adapting another model, adjust the environment variable configuration, the launch script path, and the launch script arguments in main.sh according to the usage guide of that model's training launch script (train.py in this example).

    # main.sh: For this example (the ResNet-50 model), no further changes to this script are needed. For other models, add, delete, or modify the environment variable configuration as required, then update the training launch script path and its arguments, i.e., the Python invocation in main.sh.
    # In this example, the single-node single-card Python command is:
    python ${ROOT_PATH}/../train.py --data_path=${DATA_PATH} --config_path=${CONFIG_PATH}
    # In this example, the single-node multi-card and distributed command is:
    python ${ROOT_PATH}/../train.py --run_distribute=True --device_num=${RANK_SIZE} --data_path=${DATA_PATH} --config_path=${CONFIG_PATH}

  5. Modify the configuration file “resnet50_imagenet2012_config.yaml” in the “/data/atlas_dls/public/code/resnet/config/” directory. This covers the model saving and loading settings and the graph-compilation saving and loading settings.

    ...
    run_distribute: False
    enable_profiling: False
    data_path: "/cache/data"
    output_dir: "/job/code/output" # Checkpoint save path; modify according to the actual situation
    load_path: "/cache/checkpoint_path/"
    device_target: "Ascend"
    checkpoint_path: "./checkpoint/"
    checkpoint_file_path: ""
    ...
    net_name: "resnet50"
    dataset: "imagenet2012"
    device_num: 1
    pre_trained: "/job/code/output/resnet50/imagenet2012/ckpt" # Path (directory or file) inside the container from which the pre-trained model is loaded. Fuzzy matching of .ckpt files under the given directory is supported: the most recent .ckpt file found is loaded. Modify according to the training YAML and the actual situation
    run_eval: False
    eval_dataset_path: ""
    parameter_server: False
    filter_weight: False
    save_best_ckpt: True
    eval_start_epoch: 40
    ...
    network_dataset: "resnet50_imagenet2012"
    
    
    # Retraining options
    save_graphs: False  # Whether to save graph compilation results
    save_graphs_path: "./graphs" # Save path for graph compilation results
    has_trained_epoch: 0 # Number of epochs already trained, default 0
    has_trained_step: 0 # Number of steps already trained, default 0
    ---
    # Help description for each configuration item
    enable_modelarts: "Whether training on modelarts, default: False"
    ...
    batch_size: "Batch size for training and evaluation"
    epoch_size: "Total training epochs."
    checkpoint_path: "The location of the checkpoint file."
    checkpoint_file_path: "The location of the checkpoint file."
    save_graphs: "Whether save graphs during training, default: False."
    save_graphs_path: "Path to save graphs."

  6. The launch script of the resnet code is “train.py”. Check whether “train.py” contains code that saves checkpoints, such as the sample code below.

    • If it does, skip this step.
    • If it does not, add the following checkpoint-saving code sample; the parameters it uses must be defined and set by the user in the configuration file. When adapting other models, refer to the snippet below and add checkpoint-saving code according to the specifics of the launch script. If necessary, refer to the tutorials on the MindSpore official website to make the changes.
    ...
        # Checkpoint saving code
        if config.save_checkpoint:
            ckpt_append_info = [{"epoch_num": 0, "step_num": 0}]
            config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                         keep_checkpoint_max=config.keep_checkpoint_max,
                                         append_info=ckpt_append_info)
            ckpt_cb = ModelCheckpoint(prefix=config.net_name, directory=config.save_ckpt_dir+"_"+str(config.rank_id), config=config_ck)
            cb += [ckpt_cb]
    ...
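
    For reference, the append_info entries written here are what the loading code in 8 reads back. A minimal sketch (the .ckpt file name below is a hypothetical example) of how the embedded epoch_num is recovered with MindSpore's load_checkpoint:

    import mindspore as ms

    # The ModelCheckpoint callback configured above embeds epoch_num/step_num
    # into every saved .ckpt file via append_info; on resume they come back
    # as entries of the loaded parameter dict.
    param_dict = ms.load_checkpoint("resnet50-90_625.ckpt")  # hypothetical file name
    start_epoch = int(param_dict.get("epoch_num", ms.Tensor(0, ms.int32)).asnumpy().item())
    print("training will resume from epoch", start_epoch)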

  7. The launch script of the resnet code is train.py. Check whether train.py contains code that loads checkpoints. If it does, the configuration is complete; proceed to the next section. Otherwise, go to 8.
  8. Add checkpoint-loading code to train.py. The following is a checkpoint-loading sample; the parameters it uses must be defined and set by the user in the configuration file. When adapting other models, refer to the snippets below and add checkpoint-loading code according to the specifics of the launch script. If necessary, refer to the tutorials on the MindSpore official website to make the changes.

    1. Modify “src/utils.py” to add code that reads the epoch. After the ckpt is loaded, the training log will start printing from the epoch at which the ckpt was saved.
      ...
      def init_weight(net, cfg):
          """init_weight"""
          if cfg.pre_trained:
              if not os.path.isfile(cfg.pre_trained):
                  cfg.logger.warning("There is no ckpt file: %s", cfg.pre_trained)
              else:
                  param_dict = ms.load_checkpoint(cfg.pre_trained)
                  if cfg.filter_weight:
                      filter_list = [x.name for x in net.end_point.get_parameters()]
                      filter_checkpoint_parameter_by_list(param_dict, filter_list)
                  ms.load_param_into_net(net, param_dict)
                  cfg.start_epoch = int(param_dict.get('epoch_num', ms.Tensor(0, ms.int32)).asnumpy().item())
                  cfg.logger.info("Pre trained ckpt model: %s loading", cfg.pre_trained)
      ...
    2. Modify train.py to replace the original init_weight call with _try_to_init_weight, which attempts to load a ckpt file and avoids training failures caused by loading an incomplete ckpt.
      import glob
      ...
      # Find the latest *.ckpt file under the pre_trained directory
      def _find_latest_ckpt():
          ckpt_files = glob.glob(config.pre_trained + "*/*.ckpt")
          if ckpt_files:
              ckpt_files.sort(key=os.path.getmtime, reverse=True)
          return ckpt_files
      
      # Try to load a ckpt file, with at most INIT_WEIGHT_MAX_ATTEMPTS attempts
      def _try_to_init_weight(net, config):
          if os.path.isfile(config.pre_trained):
              latest_ckpt = [config.pre_trained]
          else:
              latest_ckpt = _find_latest_ckpt()
      
          if not latest_ckpt:
              config.logger.warning("There is no ckpt file: %s", config.pre_trained)
              return
      
          init_weight_attempts = 0
          INIT_WEIGHT_MAX_ATTEMPTS = 5
          while latest_ckpt and init_weight_attempts < INIT_WEIGHT_MAX_ATTEMPTS:
              try:
                  config.pre_trained = latest_ckpt[0]
                  init_weight(net, config)
                  break
              except Exception:
                  config.logger.warning("Pre trained ckpt %s format is incorrect, try to load the next most recent ckpt", config.pre_trained)
                  if latest_ckpt[1:]:
                      latest_ckpt = latest_ckpt[1:]
                      init_weight_attempts += 1
                      continue
                  else:
                      config.logger.error("no more ckpt to load: %s", config.pre_trained)
                      raise ValueError("ckpt format is incorrect, no more ckpt to load, load ckpt failed.")
      
      ...
      @moxing_wrapper()
      def train_net():
          """train net"""
          target = config.device_target
          set_parameter()
          set_output_dir(config)
          config.logger = get_logger(config.log_dir, config.rank_id, config.parameter_server)
          dataset = create_dataset(dataset_path=config.data_path, do_train=True,
                                   batch_size=config.batch_size, train_image_size=config.train_image_size,
                                   eval_image_size=config.eval_image_size, target=target,
                                   distribute=config.run_distribute)
          step_size = dataset.get_dataset_size()
          net = resnet(class_num=config.class_num)
          if config.parameter_server:
              net.set_param_ps()
          # Replace the original init_weight call with _try_to_init_weight, which attempts to load a ckpt file and avoids training failures caused by loading an incomplete ckpt
          _try_to_init_weight(net, config)
      
          if config.resume_ckpt:
              resume_param = ms.load_checkpoint(config.resume_ckpt,
                                                choice_func=lambda x: not x.startswith(('learning_rate', 'global_step')))
              config.start_epoch = int(resume_param.get('epoch_num', ms.Tensor(0, ms.int32)).asnumpy().item())
          lr = ms.Tensor(init_lr(step_size=step_size))
      ...

TensorFlow Fault Recovery Example

  1. Download “ResNet50_ID0360_for_TensorFlow2.X” from the master branch of the TensorFlow model repository as the training code. Choose the TensorFlow version package in the training image according to the TensorFlow version of the model code.
  2. As the administrator, upload the dataset to the storage node.

    1. Go to the “/data/atlas_dls/public” directory and upload the dataset to any location, for example, “/data/atlas_dls/public/dataset/resnet50/imagenet_TF”.
      root@ubuntu:/data/atlas_dls/public/dataset/resnet50/imagenet_TF# pwd
      /data/atlas_dls/public/dataset/resnet50/imagenet_TF
    2. Run the du -sh command to check the dataset size.
      root@ubuntu:/data/atlas_dls/public/dataset/resnet50/imagenet_TF# du -sh
      42G

  3. Extract the training code downloaded in 1 locally, and rename the “ResNet50_ID0360_for_TensorFlow2.X” directory under “ModelZoo-TensorFlow-master/TensorFlow2/built-in/cv/image_classification/” to “ResNet50_for_TensorFlow_2.6_code/”.
  4. Go to the “MindXDL-deploy” repository, switch to the branch matching your version according to the MindXDL-deploy release notes, and obtain the train_start.sh, utils.sh, and rank_table.sh files from the “samples/train/basic-training/ranktable” directory. Create a “scripts” directory in the training code and build the following directory structure on the management node.

    /data/atlas_dls/public/code/ResNet50_for_TensorFlow_2.6_code/
    ├──  scripts
    │   ├──  train_start.sh
    │   ├──  utils.sh
    │   ├──  rank_table.sh
    │    ...

  5. Modify the training code to add log printing when a ckpt file is loaded. Modify "tensorflow/tf2_common/training/controller.py".

    class Controller(object):
      """Class that facilitates training and evaluation of models."""
      def __init__(
        ...
        # Restore Model if needed.
        if self.checkpoint_manager is not None:
          model_restored = self._restore_model()
          logging.info("loading checkpoint %s", model_restored)
          if not model_restored and self.checkpoint_manager.checkpoint_interval:
            # If the model is not restored from a checkpoint, save an initial
            # checkpoint.
            ckpt_path = self.checkpoint_manager.save(
                checkpoint_number=self.global_step)
            logging.info("Saved checkpoints in %s", ckpt_path)
        # Create and initialize the interval triggers.
        self.eval_trigger = utils.IntervalTrigger(self.eval_interval,
                                                  self.eval_offset)
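
    The controller above delegates saving and restoring to self.checkpoint_manager. For orientation, here is a minimal sketch of the underlying TensorFlow 2.x mechanism (tf.train.Checkpoint / tf.train.CheckpointManager); the model, optimizer, and directory below are illustrative assumptions, not part of the ModelZoo code:

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    optimizer = tf.keras.optimizers.SGD()

    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    manager = tf.train.CheckpointManager(ckpt, directory="./ckpt_dir", max_to_keep=3)

    # Restore the latest checkpoint if one exists; otherwise save an initial
    # checkpoint, as the controller above does.
    restored_path = manager.restore_or_initialize()
    if restored_path:
        print("loading checkpoint", restored_path)
    else:
        print("no checkpoint found, saving an initial one")
        manager.save(checkpoint_number=0)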

Pangu_alpha Model Adaptation Example

  1. Download the master branch code of the MindSpore models repository, rename the “models/official/nlp/Pangu_alpha” directory to “pangu_alpha”, and use it as the training code. To use this version of the model script, ensure that the MindSpore version installed in the image is not earlier than 2.0.0 and that the mindformers component is installed.
  2. Run the following command to create the code directory on the management node.

    mkdir /data/atlas_dls/code

  3. Go to the “MindXDL-deploy” repository, switch to the branch matching your version according to the MindXDL-deploy release notes, and obtain the “train_start.sh” and “main.sh” files from the “samples/train/resumable-training/fault-rescheduling/withRanktable/mindspore/pangu_alpha” directory. Combine them with the “pangu_alpha/scripts” directory in the training code to build the following directory structure on the management node. For the Pangu model at the tens-of-billions scale, use the corresponding files from the “samples/train/resumable-training/fault-rescheduling/withRanktable/mindspore/pangu_alpha_13B” directory.

    root@ubuntu:/data/atlas_dls/code/pangu_alpha/scripts/# 
    scripts/
    ├── main.sh
    ├── run_cluster_export.sh
    ├── run_distribute_eval_gpu.sh
    ├── run_distribute_eval.sh
     ...
    ├── run_distribute_train.sh
    ├── run_standalone_eval.sh
    ├── run_standalone_export.sh
    ├── run_standalone_predict.sh
    └── train_start.sh

  4. Modify the “train_start.sh” file in the “/data/atlas_dls/code/pangu_alpha/scripts” directory, setting “dataset” to the actual dataset directory inside the container.

    ...
    # Training dataset path; modify according to the actual situation
    # Security note: paths and input arguments should be validated
    dataset="/job/data/train_data"
    
    # Set the training environment variables
    set_env
    
    # Single-node training scenario
    if [[ "$server_count" == "1" ]]; then
        server_id=0
        if [ ${device_count} -lt 8 ]; then
            echo "Less than 8 card training is not supported for pangu alpha model." | tee log
        fi
        if [ ${device_count} -eq 8 ]; then
            bash main.sh ${device_count} ${server_count} ${RANK_TABLE_FILE} ${server_id} ${dataset}
        fi
    
    # Distributed training scenario
    else
        server_id=$(get_server_id)
        if [ $? -eq 1 ];then
            echo "get server id failed."
            exit 1
        fi
        echo "server id is: "${server_id}
        bash main.sh ${device_count} ${server_count} ${RANK_TABLE_FILE} ${server_id} ${dataset}
    fi
    

  5. Models at the tens-of-billions scale or below can skip this step. When training a 100-billion-parameter model, additional script adaptation is required to achieve an expected recovery time of less than 5 minutes. The following uses the master branch of pangu_alpha in the MindSpore models repository as an example (elastic training job configuration and script adaptation have been completed).

    1. Modify the “src/pangu_alpha_config.py” file. Three parameters are involved: args_opt.num_layers, args_opt.stage_num, and args_opt.micro_size.
      def set_parse_200B(args_opt):
          """
              Set config for 200B mode
          """
          args_opt.embedding_size = 16384
          args_opt.num_layers = 32                 # number of model layers
          args_opt.num_heads = 128
          if args_opt.per_batch_size == 0:
              args_opt.per_batch_size = 1
          args_opt.word_emb_dp = 0
          if args_opt.run_type == "train":
              args_opt.start_lr = 6e-5
              args_opt.end_lr = 6e-6
              args_opt.stage_num = 8               # number of pipeline stages
              args_opt.micro_size = 16             # micro-batch size in pipeline parallel mode; its value should be greater than args_opt.stage_num
              args_opt.op_level_model_parallel_num = 16
              if args_opt.optimizer_shard == 1:
                  args_opt.op_level_model_parallel_num = 8
          elif args_opt.run_type == "predict":
              args_opt.stage_num = 4
              args_opt.micro_size = 1
              args_opt.op_level_model_parallel_num = 16
              if args_opt.optimizer_shard == 1:
                  args_opt.op_level_model_parallel_num = 8
    2. In addition, specify or directly change the “micro_batch_interleaved” parameter in “src/utils.py” to “1” (refer to the calculation relationship among “stage_device_num”, “data_parallel_num”, “batch_size”, and “micro_batch_interleaved” in the “run_train_pipeline” function of the “train.py” script; the final result must satisfy that the “batch_size” of “PanguAlphaConfig” is a multiple of the “data_parallel” of “TransformerOpParallelConfig”; see the sketch below).
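
      A worked sketch of that relationship (the derivation below is an assumption based on typical run_train_pipeline logic; verify the exact formulas against the actual train.py):

      # Illustrative numbers for a hypothetical 1024-NPU job:
      device_num = 1024                   # total NPUs in the job
      stage_num = 8                       # args_opt.stage_num (pipeline stages)
      op_level_model_parallel_num = 16    # args_opt.op_level_model_parallel_num
      micro_batch_interleaved = 1         # value set in src/utils.py
      per_batch_size = 1                  # args_opt.per_batch_size

      stage_device_num = device_num // stage_num                                  # 128
      data_parallel_num = stage_device_num // op_level_model_parallel_num         # 8
      batch_size = per_batch_size * data_parallel_num * micro_batch_interleaved   # 8

      # PanguAlphaConfig's batch_size must be a multiple of
      # TransformerOpParallelConfig's data_parallel:
      assert batch_size % data_parallel_num == 0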

  6. The launch script of the pangu code is train.py. Check whether train.py contains code that saves checkpoints, such as the code sample below.

    • If it does, skip this step.
    • If it does not, add the following checkpoint-saving code sample; the parameters it uses can be defined and set in the configuration file “src/utils.py” as described in 9.
    ...
    
        # Invocation of the checkpoint-saving code
        add_checkpoint_callback_policy(args_opt, callback, rank)
    ...
    # Definition of the checkpoint-saving code
    def add_checkpoint_callback_policy(args_param, callback, rank_id):
        r"""
        Add checkpoint policy to callback.
        """
        # Security note: paths and input arguments should be validated
        if args_param.save_checkpoint:
            # The checkpoint stores epoch_num and step_num info
            ckpt_append_info = [{"epoch_num": args_param.has_trained_epoches, "step_num": args_param.has_trained_steps}]
            ckpt_config = CheckpointConfig(save_checkpoint_steps=args_param.save_checkpoint_steps,
                                           keep_checkpoint_max=args_param.keep_checkpoint_max,
                                           integrated_save=False,
                                           append_info=ckpt_append_info
                                           )
    
    
            ckpoint_cb = ModelCheckpoint(prefix=args_param.ckpt_name_prefix + str(rank_id),
                                         directory=os.path.join(args_param.save_checkpoint_path, f"rank_{rank_id}"),
                                         config=ckpt_config)
    
    
            callback.append(ckpoint_cb)
    ...

  7. The launch script of the pangu code is train.py. Check whether train.py contains code that loads checkpoints. If it does, go to 10; otherwise, go to 8.
  8. Add checkpoint-loading code to train.py. The following is a checkpoint-loading sample. Some checkpoint-loading code may already exist; what must be added is the checkpoint-loading code related to the elastic training feature. The parameters it uses can be defined and set in the configuration file “src/utils.py” as described in 9.

    ...
    # If the model does not enable pipeline parallelism, make the changes in the following function
    def set_parallel_context(args_opt):
    # If the model enables pipeline parallelism, make the changes in the following function
    # Security note: paths and input arguments should be validated
    def set_pipeline_parallel_context(args_opt):
    # Add the following code before mindspore.set_auto_parallel_context; for the usage of the set_auto_parallel_context parameters, see the distributed parallelism interface description in the MindSpore documentation

            # Added for elastic training
            if not os.path.exists(args_opt.strategy_load_ckpt_path):
                args_opt.strategy_load_ckpt_path = ""

            # Added for elastic training; the strategy_ckpt_save_file_path parameter can be set to a path inside the container
            strategy_ckpt_save_file_path = '/job/data/code/fault_torlence/pangu_alpha/strategy.ckpt'
            if args_opt.strategy_load_ckpt_path == strategy_ckpt_save_file_path:
                strategy_ckpt_save_file_path = '/job/data/code/fault_torlence/pangu_alpha/strategy_new.ckpt'

            # Change strategy_ckpt_save_file='strategy.ckpt' to strategy_ckpt_save_file=strategy_ckpt_save_file_path. If set_auto_parallel_context does not already specify the strategy_ckpt_save_file parameter, add strategy_ckpt_save_file=strategy_ckpt_save_file_path manually, as shown below.
            mindspore.set_auto_parallel_context(
                parallel_mode=args_opt.parallel_mode, gradients_mean=False, search_mode=args_opt.search_mode,
                full_batch=bool(args_opt.full_batch), loss_repeated_mean=True,
                device_num=device_num, enable_parallel_optimizer=bool(args_opt.optimizer_shard),
                pipeline_stages=args_opt.stage_num, enable_alltoall=bool(args_opt.enable_alltoall),
                strategy_ckpt_save_file=strategy_ckpt_save_file_path)
           
    ...
    ...
    # Definition of the checkpoint-loading code
    # Security note: paths and input arguments should be validated
    def restore_checkpoint(args_param, sink_size, dataset, model, network, epoch):
        r"""
        Load checkpoint process.
        """
        print("======start single checkpoint", flush=True)
        ckpt_name = args_param.ckpt_name_prefix
        # For brevity, validation of the save_checkpoint_path and ckpt_name command-line arguments is omitted here; please add such validation yourself
        ckpt_pattern = os.path.join(args_param.save_checkpoint_path, "rank_{}".format(D.get_rank()),
                                    f"{ckpt_name}*.ckpt")
        ckpt_all_files = glob.glob(ckpt_pattern)
        if not ckpt_all_files:
            print(f"There is no ckpt file in {args_param.save_checkpoint_path}, "
                  f"current ckpt_files found is {ckpt_files} "
                  f"with pattern {ckpt_pattern}, so skip the loading.")
            return
        ckpt_exp_pattern = os.path.join(
            args_param.save_checkpoint_path,
            "rank_{}".format(D.get_rank()),
            f"{ckpt_name}*_breakpoint.ckpt",
        )
        ckpt_exp_files = glob.glob(ckpt_exp_pattern)
        ckpt_files = []
        for file in ckpt_all_files:
            if file not in ckpt_exp_files:
                ckpt_files.append(file)
    
        if not ckpt_files:
            print(
                f"There is no ckpt file in {args_param.save_checkpoint_path}, "
                f"current ckpt_files found is {ckpt_files} "
                f"with pattern {ckpt_pattern}, so skip the loading."
            )
            return
        ckpt_files.sort(key=os.path.getmtime, reverse=True)
        time_stamp = datetime.datetime.now()
        print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} pre trained ckpt model {ckpt_files} loading",
              flush=True)
        # Load the latest checkpoint file
        print(f'Start to load from {ckpt_files[0]}')
        param_dict = load_checkpoint(ckpt_files[0])
        if param_dict.get("epoch_num") and param_dict.get("step_num"):
            args_param.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy())
            args_param.has_trained_steps = int(param_dict["step_num"].data.asnumpy())
        model.build(train_dataset=dataset, sink_size=sink_size, epoch=epoch)
        load_param_into_net(network, param_dict)
    ...
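
    The file-selection part of restore_checkpoint can be read in isolation. A minimal sketch (directory layout and prefix are hypothetical) that reproduces it: collect the rank's ckpt files, drop the *_breakpoint.ckpt files that the loading logic skips, and pick the most recently modified file that remains.

    import glob
    import os

    def pick_latest_ckpt(ckpt_dir, prefix="pangu"):
        all_files = glob.glob(os.path.join(ckpt_dir, f"{prefix}*.ckpt"))
        # Exclude *_breakpoint.ckpt files, as restore_checkpoint does above.
        exp_files = set(glob.glob(os.path.join(ckpt_dir, f"{prefix}*_breakpoint.ckpt")))
        candidates = [f for f in all_files if f not in exp_files]
        if not candidates:
            return None
        # Newest checkpoint first, matching the mtime sort above.
        candidates.sort(key=os.path.getmtime, reverse=True)
        return candidates[0]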

  9. Modify the parameters in the “src/utils.py” file.

    ...
        opt.add_argument("--vocab_size",
                          type=int,
                          default=50304, # 根据训练数据集进行修改,此处已修改为样例数据集的取值
                          help="vocabulary size, default is 40000.")
    ...
        opt.add_argument("--data_column_name",
                         type=str,
                         default="text", # 根据数据集定义的字段进行修改,此处已修改为样例数据集的取值
                         help="Column name of datasets")
    ...
        parser.add_argument("--strategy_load_ckpt_path",
                            type=str,
                            default="/job/data/code/fault_torlence/pangu_alpha/strategy/strategy.ckpt", # 弹性训练中,根据用户习惯指定容器内路径,且路径不会被训练覆盖。
                            help="The training prallel strategy for the model.")
        parser.add_argument("--tokenizer_path",
                            type=str,
                            default="./tokenizer_path",
                            help="The path where stores vocab and vocab model file")
    ...
    def add_retrain_params(opt):
        """
        Add parameters about retrain.
        """
        opt.add_argument("--pre_trained",
                         type=str,
                         default="/job/data/code/fault_torlence/pangu_alpha/8p", # 指定预训练模型路径,
                         help="Pretrained checkpoint path.")
        opt.add_argument("--save_checkpoint_path",  
                         type=str,
                         default="/job/data/code/fault_torlence/pangu_alpha/8p",   # 指定模型保存路径
                         help="Save checkpoint path.")
        opt.add_argument("--keep_checkpoint_max", # 指定模型保存策略:最大数量
                         type=int,
                         default=1,
                         help="Max checkpoint save number.")
        opt.add_argument("--save_checkpoint_steps", # 指定模型保存策略:保存间隔
                         type=int,
                         default=20,
                         help="Save checkpoint step number.")
        opt.add_argument("--save_checkpoint", # 指定当次训练是否保存模型
                         type=ast.literal_eval,
                         default=True,
                         help="Whether save checkpoint in local disk.")
        opt.add_argument("--ckpt_name_prefix", # 指定模型保存策略:文件名前缀
                         type=str,
                         default="pangu",
                         help="Saving checkpoint name prefix.")
    ...

  10. Create an empty file named “group_info_env” in the “/data/atlas_dls/code/pangu_alpha” directory (for example, with the touch command).

    root@ubuntu:/data/atlas_dls/code/pangu_alpha/# 
    pangu_alpha/
    ├── README.md
    ├── README_CN.md
    ├── group_info_env
     ...
    ├── scripts
    ├── serving_increment
    ├── src
    ├── tasks.py
    └── train.py

  11. Modify the “group_info_env” path in the train.py file.

    ...
        # env variable prepare
        group_info_file = os.getenv("GROUP_INFO_FILE")
        if group_info_file:
            with open(os.path.expanduser("/job/code/group_info_env"), "a") as outfile:
                outfile.write(f"export GROUP_INFO_FILE_REFLECT={group_info_file}\n")
    ...