Script Adaptation

Adaptation Overview

This chapter provides script adaptation examples for fault recovery, the last-words feature, and hybrid parallel models. Select the adaptation example that matches your actual scenario; the structure of the examples is shown below. For the task yaml parameters used in some examples, see the yaml parameter description.

Fault Recovery Code Adaptation Example Based on TensorFlow

  1. Download "ResNet50_ID0360_for_TensorFlow2.X" from the master branch of the TensorFlow code repository as the training code. Choose the TensorFlow version package in the training image according to the TensorFlow version used by the model code.
  2. The administrator uploads the dataset to the storage node.

    1. Go to the "/data/atlas_dls/public" directory and upload the dataset to any location, for example "/data/atlas_dls/public/dataset/resnet50/imagenet_TF".
      root@ubuntu:/data/atlas_dls/public/dataset/resnet50/imagenet_TF# pwd
      /data/atlas_dls/public/dataset/resnet50/imagenet_TF
    2. Run the du -sh command to check the dataset size.
      root@ubuntu:/data/atlas_dls/public/dataset/resnet50/imagenet_TF# du -sh
      42G

  3. Decompress the training code downloaded in step 1 locally, and rename the "ResNet50_ID0360_for_TensorFlow2.X" directory under "ModelZoo-TensorFlow-master/TensorFlow2/built-in/cv/image_classification/" to "ResNet50_for_TensorFlow_2.6_code/".
  4. Go to the "MindXDL-deploy" repository and select the "5.0.RC1" branch. Obtain the "train_start.sh", "rank_table.sh", and "utils.sh" files from the "samples/train" directory, and together with the "ResNet50_for_TensorFlow_2.6_code" directory from step 3, construct the following directory structure under "/data/atlas_dls/public/code" on the host.

    /data/atlas_dls/public/code/ResNet50_for_TensorFlow_2.6_code/
    ├──  scripts
    │   ├──  train_start.sh
    │   ├──  utils.sh
    │   ├──  rank_table.sh
    │    ...
    │        ...
    ├──  tensorflow
    │   ├──  resnet_ctl_imagenet_main.py
    │   ├──  resnet_model.py
    │   ├──  resnet_runnable.py
    │    ...
    │        ...
    ├──  benchmark.sh
    ├──  modelzoo_level.txt
     ...
    └──  requirements.txt
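
    Assuming the training code has been unpacked locally and the three scripts obtained from "MindXDL-deploy" sit in a local download directory, the layout above can be assembled with a short script. The sketch below runs in a temporary sandbox so it stays self-contained; on a real host the target would be "/data/atlas_dls/public/code" and the source would be your actual download location (both stand-in paths here are hypothetical).

```python
import shutil
import tempfile
from pathlib import Path

# Sandbox stand-in for the host paths named in the document (hypothetical;
# on a real deployment use /data/atlas_dls/public/code and your download dir).
root = Path(tempfile.mkdtemp())
downloaded = root / "MindXDL-deploy" / "samples" / "train"
downloaded.mkdir(parents=True)
for name in ("train_start.sh", "utils.sh", "rank_table.sh"):
    (downloaded / name).write_text("#!/bin/bash\n")  # placeholder contents

code_dir = root / "code" / "ResNet50_for_TensorFlow_2.6_code"
scripts_dir = code_dir / "scripts"
scripts_dir.mkdir(parents=True, exist_ok=True)

# Copy the three helper scripts into the scripts/ subdirectory.
for name in ("train_start.sh", "utils.sh", "rank_table.sh"):
    shutil.copy2(downloaded / name, scripts_dir / name)

layout = sorted(p.name for p in scripts_dir.iterdir())
print(layout)  # ['rank_table.sh', 'train_start.sh', 'utils.sh']
```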

  5. Create a yaml file; it will be used to start the task. Referring to the TensorFlow yaml template, replace the startup command with:

    ...
    command:
    - "/bin/bash"
    - "-c"
    - "cd /job/code/ResNet50_for_TensorFlow_2.6_code/scripts;chmod +x train_start.sh;bash train_start.sh /job/code/ResNet50_for_TensorFlow_2.6_code/ /job/output/ tensorflow/resnet_ctl_imagenet_main.py --data_dir=/job/data/resnet50/imagenet_TF/  --distribution_strategy=one_device --use_tf_while_loop=true  --steps_per_loop=1 --enable_checkpoint_and_export ..." # 此处省略部分参数
    ...

Fault Recovery Code Adaptation Example Based on PyTorch

  1. Download "ResNet50_for_PyTorch" from the master branch of the PyTorch code repository as the training code.
  2. Prepare the dataset for ResNet-50 yourself, and comply with the applicable specifications when using it.
  3. The administrator uploads the dataset to the storage node.

    1. Go to the "/data/atlas_dls/public" directory and upload the dataset to any location, for example "/data/atlas_dls/public/dataset/resnet50/imagenet".
      root@ubuntu:/data/atlas_dls/public/dataset/resnet50/imagenet# pwd
      /data/atlas_dls/public/dataset/resnet50/imagenet
    2. Run the du -sh command to check the dataset size.
      root@ubuntu:/data/atlas_dls/public/dataset/resnet50/imagenet# du -sh
      11G 

  4. Decompress the training code downloaded in step 1 locally, and rename the "ModelZoo-PyTorch/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch" directory in the decompressed code to "ResNet50_for_PyTorch_1.8_code/".
  5. Go to the "MindXDL-deploy" repository, select the "5.0.RC1" branch, and obtain the "train_start.sh", "utils.sh", and "rank_table.sh" files from the "samples/pytorch/resnet50" directory. Create a "scripts" directory in the training code and construct the following directory structure on the management node.

    root@ubuntu:/data/atlas_dls/public/code/ResNet50_for_PyTorch_1.8_code/scripts/#
    scripts/
    ├── rank_table.sh
    ├── utils.sh
    └── train_start.sh

  6. Modify the training code. There are single-card and distributed scenarios.

    • For single-card training, modify the "pytorch_resnet50_apex.py" file in the training code directory; the changes adjust the model saving and loading logic.
      import argparse
      import glob
      import os
      ...
          if args.resume:
              candidate_ckpt_path = ""
              for p in glob.glob(f"./rank*"):
                  best_ckpt_path = os.path.join(p, "model_best.pth.tar")
                  if os.path.exists(best_ckpt_path):
                      candidate_ckpt_path = best_ckpt_path
                      break
      
              if candidate_ckpt_path:
                  print("[gpu id:", args.gpu, "]", "=> loading checkpoint '{}'".format(candidate_ckpt_path))
                  loc = ""
                  if args.gpu is None:
                      checkpoint = torch.load(candidate_ckpt_path)
                  else:
                      # Map model to be loaded to specified single gpu.
                      if args.device == 'npu':
                          loc = 'npu:{}'.format(args.gpu)
                      else:
                          loc = 'cuda:{}'.format(args.gpu)
                      checkpoint = torch.load(candidate_ckpt_path, map_location=loc)
                  print(f"load checkpoint to : {loc}")
      
                  args.start_epoch = checkpoint['epoch']
                  best_acc1 = checkpoint['best_acc1']
                  if args.gpu is not None:
                      # best_acc1 may be from a checkpoint from a different GPU
                      best_acc1 = best_acc1.to(args.gpu)
      
                  model.load_state_dict(checkpoint['state_dict'])
                  optimizer.load_state_dict(checkpoint['optimizer'])
                  print("[gpu id:", args.gpu, "]",
                        "=> loaded checkpoint '{}' (epoch {})".format(candidate_ckpt_path, checkpoint['epoch']))
              else:
                  print("no valid ckpt found to resume.")
          cudnn.benchmark = True
      ...
              # remember best acc@1 and save checkpoint
              is_best = acc1 > best_acc1
              best_acc1 = max(acc1, best_acc1)  
              file_name = "checkpoint_npu{}".format(args.npu)
      
              save_path = f"./rank_{args.rank}"
               if not os.path.exists(save_path):
                  os.makedirs(save_path, exist_ok=True)
      
              modeltmp = model.cpu()
              save_checkpoint({
                  'epoch': epoch + 1,
                  'arch': args.arch,
                  'state_dict': modeltmp.state_dict(),
                  'best_acc1': best_acc1,
                  'optimizer' : optimizer.state_dict(),
              }, is_best, save_path=save_path)
              modeltmp.to(CALCULATE_DEVICE)
      ...
      # Modify the original save_checkpoint function
      def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', save_path="./"):
          if is_best:
              target_path = os.path.join(save_path, 'model_best.pth.tar')
              torch.save(state, target_path)
              print(f"save ckpt to {target_path} done. Best epoch for now is :{state['epoch']}")
    • For distributed training, modify the "main_apex_d76_npu.py" file in the "DistributedResnet50" directory; the changes adjust the model saving and loading logic.
      import argparse
      import glob
      import os
      ...
      if args.resume:
          candidate_ckpt_path = ""
          for p in glob.glob(f"./rank*"):
              best_ckpt_path = os.path.join(p, "model_best.pth.tar")
              if os.path.exists(best_ckpt_path):
                  candidate_ckpt_path = best_ckpt_path
                  break
      
          if candidate_ckpt_path:
              print("[gpu id:", args.gpu, "]", "=> loading checkpoint '{}'".format(candidate_ckpt_path))
              loc = ""
              if args.gpu is None:
                  checkpoint = torch.load(candidate_ckpt_path)
              else:
                  # Map model to be loaded to specified single gpu.
                  if args.device == 'npu':
                      loc = 'npu:{}'.format(args.gpu)
                  else:
                      loc = 'cuda:{}'.format(args.gpu)
                  checkpoint = torch.load(candidate_ckpt_path, map_location=loc)
              print(f"load checkpoint to : {loc}")
      
              args.start_epoch = checkpoint['epoch']
              best_acc1 = checkpoint['best_acc1']
              if args.gpu is not None:
                  # best_acc1 may be from a checkpoint from a different GPU
                  best_acc1 = best_acc1.to(args.gpu)
      
              model.load_state_dict(checkpoint['state_dict'])
              optimizer.load_state_dict(checkpoint['optimizer'])
              print("[gpu id:", args.gpu, "]",
                    "=> loaded checkpoint '{}' (epoch {})".format(candidate_ckpt_path, checkpoint['epoch']))
          else:
              print("no valid ckpt found to resume.")
      cudnn.benchmark = True
      ...
      if args.device == 'gpu':
          if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                      and args.rank % ngpus_per_node == 0):
              save_checkpoint({
                  'epoch': epoch + 1,
                  'arch': args.arch,
                  'state_dict': model.state_dict(),
                  'best_acc1': best_acc1,
                  'optimizer': optimizer.state_dict(),
              }, is_best)
      elif args.device == 'npu':
          # Card 0 on each node saves the model file; the model is saved only when the accuracy improves
          if args.rank % ngpus_per_node == 0:
              save_path = f"./rank_{args.rank}"
              if not os.path.exists(save_path):
                  os.makedirs(save_path, exist_ok=True)
      
              modeltmp = model.cpu()
              save_checkpoint({
                  'epoch': epoch + 1,
                  'arch': args.arch,
                  'state_dict': modeltmp.state_dict(),
                  'best_acc1': best_acc1,
                  'optimizer': optimizer.state_dict(),
              }, is_best, save_path=save_path)
      
              loc = 'npu:{}'.format(args.gpu)
              modeltmp.to(loc)
      ...
      # Modify the original save_checkpoint function
      def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', save_path="./"):
          if is_best:
              target_path = os.path.join(save_path, 'model_best.pth.tar')
              torch.save(state, target_path)
              print(f"save ckpt to {target_path} done. Best epoch for now is :{state['epoch']}")

  7. Create a yaml file; it will be used to start the task. Referring to the PyTorch yaml template, add the --resume parameter to the startup command.

    ...
          command:
          - "/bin/bash"
          - "-c"
          - "cd /job/code/ResNet50_for_PyTorch_1.8_code/scripts;chmod +x train_start.sh;bash train_start.sh /job/code/ResNet50_for_PyTorch_1.8_code/ /job/output/ --data=/job/data/resnet50/imagenet --seed=49 --worker=128 --learning-rate=1.6 --warmup=8 --label-smoothing=0.1 --mom=0.9 --weight-decay=1.0e-04 --static-loss-scale=128 --print-freq=1 --dist-url='tcp://127.0.0.1:50000' --dist-backend='hccl' --multiprocessing-distributed --benchmark=0 --device='npu' --epoch=90 --batch-size=1024 --resume=true;"
    ...

Fault Recovery Code Adaptation Example Based on MindSpore

  1. Download the code from the r2.0.0-alpha branch of the MindSpore code repository, rename the "models/official/cv/ResNet" directory to "resnet", and use it as the training code.
  2. Run the following command to create the code directory on the management node, then upload the training code to that directory.

    mkdir /data/atlas_dls/code

  3. Go to the "MindXDL-deploy" repository, select the "5.0.RC1" branch, and obtain the "train_start.sh", "main.sh", and "pre_stop.sh" files from the "samples/mindspore/resnet50" directory. Together with the "resnet/scripts" directory in the training code, construct the following directory structure on the management node.

    root@ubuntu:/data/atlas_dls/public/code/resnet/scripts/#
    scripts/
    ├── pre_stop.sh
    ├── main.sh
     ...
    ├── run_distribute_train.sh
    ├── run_distribute_train_gpu.sh
    └── train_start.sh

  4. Modify the "train_start.sh" file in the "/data/atlas_dls/public/code/resnet/scripts" directory.

    1. Change "dataset_path" to the actual dataset directory inside the container.
    2. Change "config_yaml_path" to the actual configuration file path inside the container.
    # train_start.sh: modify according to the actual situation. Global configuration parameters: dataset path and configuration file path. When adapting other models, add or remove parameters as needed.
    dataset_path=/job/data/resnet50/imagenet/train
    config_yaml_path=/job/code/resnet/config/resnet50_imagenet2012_config.yaml
    
    # main.sh: for this example (the ResNet-50 model), no further changes to this script are needed. When adapting other models, add, remove, or modify the environment variable settings as needed, then modify the training startup script path and its parameters, i.e. the part of main.sh that invokes the Python command.
    # In this example, the single-node single-card Python command is:
    python ${ROOT_PATH}/../train.py --data_path=${DATA_PATH} --config_path=${CONFIG_PATH} --output_path=${OUTPUT_PATH} --pre_trained=${OUTPUT_PATH}
    # In this example, the single-node multi-card and distributed command is:
    python ${ROOT_PATH}/../train.py --run_distribute=True --device_num=${RANK_SIZE} --data_path=${DATA_PATH} --config_path=${CONFIG_PATH} --output_path=${OUTPUT_PATH} --pre_trained=${OUTPUT_PATH}

    The "train_start.sh" script starts the training task by invoking the "main.sh" script. When adapting other models, adjust the environment variable settings, the startup script path, and the startup script parameters in "main.sh" according to the usage guide of that model's training startup script ("train.py" in this example).

  5. Modify the "resnet50_imagenet2012_config.yaml" configuration file in the "/data/atlas_dls/public/code/resnet/config/" directory. This covers the model save/load settings and the graph compilation save/load settings.

    ...
    run_distribute: False
    enable_profiling: False
    data_path: "/cache/data"
    output_path: "/cache/train" # Checkpoint save path; modify according to the actual situation
    load_path: "/cache/checkpoint_path/"
    device_target: "Ascend"
    checkpoint_path: "./checkpoint/"
    checkpoint_file_path: ""
    ...
    net_name: "resnet50"
    dataset: "imagenet2012"
    device_num: 1
    pre_trained: "/job/code/output/checkpoint/ckpt_0" # In-container path for loading the pretrained model (supports a directory or a file); refer to the training yaml and modify according to the actual situation
    run_eval: False
    eval_dataset_path: ""
    parameter_server: False
    filter_weight: False
    save_best_ckpt: True
    eval_start_epoch: 40
    ...
    network_dataset: "resnet50_imagenet2012"
    
    
    # Retraining options
    save_graphs: False  # Whether to save graph compilation results
    save_graphs_path: "./graphs" # Path for saving graph compilation results
    has_trained_epoch: 0 # Number of epochs already pretrained; default 0
    has_trained_step: 0 # Number of steps already pretrained; default 0
    ---
    # Help description for each configuration item
    enable_modelarts: "Whether training on modelarts, default: False"
    ...
    batch_size: "Batch size for training and evaluation"
    epoch_size: "Total training epochs."
    checkpoint_path: "The location of the checkpoint file."
    checkpoint_file_path: "The location of the checkpoint file."
    save_graphs: "Whether save graphs during training, default: False."
    save_graphs_path: "Path to save graphs."

  6. The startup script of the resnet code is "train.py". Check whether "train.py" contains code that saves checkpoints. If it does, go to step 8; otherwise, go to step 7.
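
    One quick way to perform this check is to scan the startup script for MindSpore's checkpoint-saving constructs (ModelCheckpoint and CheckpointConfig, which the sample in the next step uses). A minimal sketch; the marker list and helper name are illustrative, not part of the model code:

```python
import tempfile
from pathlib import Path

# Names that typically indicate checkpoint-saving code in a MindSpore script.
SAVE_MARKERS = ("ModelCheckpoint", "CheckpointConfig")

def has_checkpoint_save_code(script_path) -> bool:
    """Return True if the script appears to contain checkpoint-saving code."""
    text = Path(script_path).read_text(encoding="utf-8")
    return any(marker in text for marker in SAVE_MARKERS)

# Demonstration on a synthetic file; on a real setup point this at train.py.
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write("ckpt_cb = ModelCheckpoint(prefix='resnet', directory=d, config=c)\n")
print(has_checkpoint_save_code(f.name))  # True
```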
  7. Add the checkpoint saving code. The following is a checkpoint saving sample; the parameters used must be defined and set by the user in the configuration file. When adapting other models, refer to the following snippet and add the checkpoint saving code according to the content of the startup script. If necessary, refer to the tutorials on the MindSpore official website.

    ...
        # Model saving code
        if config.save_checkpoint:
            ckpt_append_info = [{"epoch_num": config.has_trained_epoch, "step_num": config.has_trained_step}]
            config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                         keep_checkpoint_max=config.keep_checkpoint_max,
                                         append_info=ckpt_append_info)
            ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)
            cb += [ckpt_cb]
    ...
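
    Note that the save interval in the sample above is expressed in steps: save_checkpoint_steps is computed as config.save_checkpoint_epochs * step_size. A quick illustration with made-up numbers (an ImageNet-sized dataset and a hypothetical global batch size of 256):

```python
# Illustrative numbers only: 1,281,167 training images, global batch size 256.
num_images = 1281167
global_batch_size = 256
step_size = num_images // global_batch_size      # steps per epoch

save_checkpoint_epochs = 5                       # hypothetical config value
save_checkpoint_steps = save_checkpoint_epochs * step_size

print(step_size, save_checkpoint_steps)  # 5004 25020
```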

  8. The startup script of the resnet code is "train.py". Check whether "train.py" contains code that loads checkpoints. If it does, go to step 10; otherwise, go to step 9.
  9. Add the checkpoint loading code. The following is a checkpoint loading sample; the parameters used must be defined and set by the user in the configuration file. When adapting other models, refer to the following snippet and add the checkpoint loading code according to the content of the startup script. If necessary, refer to the tutorials on the MindSpore official website.

    ...
    def load_pre_trained_checkpoint():
        """
        Load checkpoint according to pre_trained path.
        """
        param_dict = None
        if config.pre_trained:
            if os.path.isdir(config.pre_trained):
                # For brevity, validation of configuration parameters such as config.output_path is omitted here; add the relevant validation yourself
                ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path, "ckpt_0")
                ckpt_pattern = os.path.join(ckpt_save_dir, "*.ckpt")
                ckpt_files = glob.glob(ckpt_pattern)
                if not ckpt_files:
                    logger.warning(f"There is no ckpt file in {ckpt_save_dir}, "
                                   f"pre_trained is unsupported.")
                else:
                    ckpt_files.sort(key=os.path.getmtime, reverse=True)
                    time_stamp = datetime.datetime.now()
                    print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')}"
                          f" pre trained ckpt model {ckpt_files[0]} loading",
                          flush=True)
                    param_dict = load_checkpoint(ckpt_files[0])
            elif os.path.isfile(config.pre_trained):
                # Invoke the checkpoint loading code
                param_dict = ms.load_checkpoint(config.pre_trained)
            else:
                print(f"Invalid pre_trained {config.pre_trained} parameter.")
        return param_dict
    ...

  10. Create the yaml file for delivering the task.

    For yaml parameter modifications, see the yaml parameter description.

Fault Recovery Code Adaptation Example Based on the Pangu Model

  1. Download the code from the master branch of the MindSpore code repository, rename the "models/official/nlp/Pangu_alpha" directory to "pangu_alpha", and use it as the training code. To use this version of the model scripts, the MindSpore version installed in the image must be no earlier than 2.0.0, and the mindformers component must be installed.
  2. Run the following command to create the code directory on the management node.

    mkdir /data/atlas_dls/code

  3. Go to the "MindXDL-deploy" repository, select the "5.0.RC1" branch, and obtain the "train_start.sh", "main.sh", and "pre_stop.sh" files from the "samples/mindspore/pangu_alpha" directory. Together with the "pangu_alpha/scripts" directory in the training code, construct the following directory structure on the management node. For the Pangu 13B model, use the corresponding files in the "samples/mindspore/pangu_alpha_13B" directory.

    root@ubuntu:/data/atlas_dls/code/pangu_alpha/scripts/# 
    scripts/
    ├── main.sh
    ├── pre_stop.sh
    ├── run_cluster_export.sh
    ├── run_distribute_eval_gpu.sh
    ├── run_distribute_eval.sh
     ...
    ├── run_distribute_train.sh
    ├── run_standalone_eval.sh
    ├── run_standalone_export.sh
    ├── run_standalone_predict.sh
    └── train_start.sh

  4. Modify the "train_start.sh" file in the "/data/atlas_dls/code/pangu_alpha/scripts" directory, and change "dataset" to the actual dataset directory inside the container.

    ...
    # Training dataset path; modify according to the actual situation
    # Security note: involves validation of paths and input parameters
    dataset="/job/data/dataset/train_data"
    
    # Set the training environment variables
    set_env
    
    # Single-node training scenario
    if [[ "$server_count" == "1" ]]; then
        server_id=0
        if [ ${device_count} -lt 8 ]; then
            echo "Less than 8 card training is not supported for pangu alpha model." | tee log
        fi
        if [ ${device_count} -eq 8 ]; then
            bash main.sh ${device_count} ${server_count} ${RANK_TABLE_FILE} ${server_id} ${dataset}
        fi
    
    # Distributed training scenario
    else
        server_id=$(get_server_id)
        if [ $? -eq 1 ];then
            echo "get server id failed."
            exit 1
        fi
        echo "server id is: "${server_id}
        bash main.sh ${device_count} ${server_count} ${RANK_TABLE_FILE} ${server_id} ${dataset}
    fi

  5. Models at the tens-of-billions scale or below can skip this step. When training a 100-billion-scale model where the expected recovery time is under 5 minutes, additional script adaptation is required. The following uses the master branch of pangu_alpha in the MindSpore code repository as an example (resumable-training task configuration and script adaptation already completed).

    1. Modify the "src/pangu_alpha_config.py" file; the change involves three parameters: args_opt.num_layers, args_opt.stage_num, and args_opt.micro_size.
      def set_parse_200B(args_opt):
          r"""
              Set config for 200B mode
          """
          args_opt.embedding_size = 16384
          args_opt.num_layers = 32
          args_opt.num_heads = 128
          if args_opt.per_batch_size == 0:
              args_opt.per_batch_size = 1
          args_opt.word_emb_dp = 0
          if args_opt.run_type == "train":
              args_opt.start_lr = 6e-5
              args_opt.end_lr = 6e-6
              args_opt.stage_num = 8
              args_opt.micro_size = 16
              args_opt.op_level_model_parallel_num = 16
              if args_opt.optimizer_shard == 1:
                  args_opt.op_level_model_parallel_num = 8
          elif args_opt.run_type == "predict":
              args_opt.stage_num = 4
              args_opt.micro_size = 1
              args_opt.op_level_model_parallel_num = 16
              if args_opt.optimizer_shard == 1:
                  args_opt.op_level_model_parallel_num = 8
    2. In addition, set (or directly modify) the "micro_batch_interleaved" parameter in "src/utils.py" to "1". (Refer to the computation relationship among "stage_device_num", "data_parallel_num", "batch_size", and "micro_batch_interleaved" in the "run_train_pipeline" function of the "train.py" script. The final result must satisfy: the "batch_size" of "PanguAlphaConfig" is a multiple of the "data_parallel" of "TransformerOpParallelConfig".)
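
    The divisibility constraint described above can be sketched as a small check. The formulas below mirror the relationship as the text describes it (verify them against the actual run_train_pipeline code); micro_batch_interleaved is left out of this sketch, and the numbers in the example call are purely illustrative:

```python
def pipeline_batch_config_ok(device_num, stage_num, op_level_model_parallel_num,
                             batch_size):
    """Illustrative check of the stated constraint: the PanguAlphaConfig
    batch_size must be a multiple of the TransformerOpParallelConfig
    data_parallel. Confirm the formulas against run_train_pipeline in train.py.
    """
    stage_device_num = device_num // stage_num            # devices per pipeline stage
    data_parallel_num = stage_device_num // op_level_model_parallel_num
    if data_parallel_num < 1:
        return False                                      # not enough devices per stage
    return batch_size % data_parallel_num == 0

# Illustrative numbers only, not the model's real configuration:
print(pipeline_batch_config_ok(device_num=1024, stage_num=8,
                               op_level_model_parallel_num=16, batch_size=512))  # True
```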

  6. The startup script of the pangu code is "train.py". Check whether "train.py" contains code that saves checkpoints. If it does, go to step 8; otherwise, go to step 7.
  7. Add the checkpoint saving code. The following is a checkpoint saving sample; the parameters used can be defined and set in the "src/utils.py" configuration file by referring to step 10.

    ...
    
        # Invocation of the checkpoint saving code
        add_checkpoint_callback_policy(args_opt, callback, rank)
    ...
    # Definition of the checkpoint saving code
    def add_checkpoint_callback_policy(args_param, callback, rank_id):
        r"""
        Add checkpoint policy to callback.
        """
        # Security note: involves validation of paths and input parameters
        if args_param.save_checkpoint:
            # The checkpoint appends epoch_num and step_num info
            ckpt_append_info = [{"epoch_num": args_param.has_trained_epoches, "step_num": args_param.has_trained_steps}]
            ckpt_config = CheckpointConfig(save_checkpoint_steps=args_param.save_checkpoint_steps,
                                           keep_checkpoint_max=args_param.keep_checkpoint_max,
                                           integrated_save=False,
                                           append_info=ckpt_append_info
                                           )
    
    
            ckpoint_cb = ModelCheckpoint(prefix=args_param.ckpt_name_prefix + str(rank_id),
                                         directory=os.path.join(args_param.save_checkpoint_path, f"rank_{rank_id}"),
                                         config=ckpt_config)
    
    
            callback.append(ckpoint_cb)
    ...

  8. The startup script of the pangu code is "train.py". Check whether "train.py" contains code that loads checkpoints. If it does, go to step 13; otherwise, go to step 9.
  9. Add the checkpoint loading code. The following is a checkpoint loading sample. Some checkpoint loading code already exists; the checkpoint loading code for the resumable-training feature needs to be added. The parameters used can be defined and set in the "src/utils.py" configuration file by referring to step 10.

    ...
    # If the model does not enable pipeline parallelism, make the changes in the following function
    def set_parallel_context(args_opt):
    # If the model enables pipeline parallelism, make the changes in the following function
    # Security note: involves validation of paths and input parameters
    def set_pipeline_parallel_context(args_opt):
    # Add the following code before context.set_auto_parallel_context; see the description of the "set_auto_parallel_context" parameters in the MindSpore distributed parallel API documentation...
            
             
            # Added for resumable training
            if not os.path.exists(args_opt.strategy_load_ckpt_path):
                args_opt.strategy_load_ckpt_path = ""
    
            # Added for resumable training; the strategy_ckpt_save_file_path parameter can be set to a path inside the container
            strategy_ckpt_save_file_path = '/job/data/code/fault_torlence/pangu_alpha/strategy.ckpt'
            if args_opt.strategy_load_ckpt_path == strategy_ckpt_save_file_path:
                strategy_ckpt_save_file_path = '/job/data/code/fault_torlence/pangu_alpha/strategy_new.ckpt'
    
            # strategy_ckpt_save_file='strategy.ckpt' is changed below to strategy_ckpt_save_file=strategy_ckpt_save_file_path
            context.set_auto_parallel_context(
                parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False,
                full_batch=bool(args_opt.full_batch), strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path,
                enable_parallel_optimizer=bool(args_opt.optimizer_shard), strategy_ckpt_save_file=strategy_ckpt_save_file_path)
            set_algo_parameters(elementwise_op_strategy_follow=True)
            _set_multi_subgraphs()
    ...
    ...
    # Definition of the checkpoint loading code
    # Security note: involves validation of paths and input parameters
    def restore_checkpoint(args_param, sink_size, dataset, model, network, epoch):
        r"""
        Load checkpoint process.
        """
        print("======start single checkpoint", flush=True)
        ckpt_name = args_param.ckpt_name_prefix
        # For brevity, validation of the save_checkpoint_path and ckpt_name command-line parameters is omitted here; add the relevant validation yourself
        ckpt_pattern = os.path.join(args_param.save_checkpoint_path, "rank_{}".format(D.get_rank()),
                                    f"{ckpt_name}*.ckpt")
        ckpt_all_files = glob.glob(ckpt_pattern)
        if not ckpt_all_files:
            print(f"There is no ckpt file in {args_param.save_checkpoint_path}, "
                  f"current ckpt_files found is {ckpt_files} "
                  f"with pattern {ckpt_pattern}, so skip the loading.")
            return
        ckpt_exp_pattern = os.path.join(
            args_param.save_checkpoint_path,
            "rank_{}".format(D.get_rank()),
            f"{ckpt_name}*_breakpoint.ckpt",
        )
        ckpt_exp_files = glob.glob(ckpt_exp_pattern)
        ckpt_files = []
        for file in ckpt_all_files:
            if file not in ckpt_exp_files:
                ckpt_files.append(file)
    
        if not ckpt_files:
            print(
                f"There is no ckpt file in {args_param.save_checkpoint_path}, "
                f"current ckpt_files found is {ckpt_files} "
                f"with pattern {ckpt_pattern}, so skip the loading."
            )
            return
        ckpt_files.sort(key=os.path.getmtime, reverse=True)
        time_stamp = datetime.datetime.now()
        print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} pre trained ckpt model {ckpt_files} loading",
              flush=True)
        # Load the most recent checkpoint file
        print(f'Start to load from {ckpt_files[0]}')
        param_dict = load_checkpoint(ckpt_files[0])
        if param_dict.get("epoch_num") and param_dict.get("step_num"):
            args_param.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy())
            args_param.has_trained_steps = int(param_dict["step_num"].data.asnumpy())
        model.build(train_dataset=dataset, sink_size=sink_size, epoch=epoch)
        load_param_into_net(network, param_dict)
    ...

  10. Modify the parameters in the "src/utils.py" file.

    ...
        opt.add_argument("--vocab_size",
                          type=int,
                          default=50304, # 根据训练数据集进行修改,此处已修改为样例数据集的取值
                          help="vocabulary size, default is 40000.")
    ...
        opt.add_argument("--data_column_name",
                         type=str,
                         default="text", # 根据数据集定义的字段进行修改,此处已修改为样例数据集的取值
                         help="Column name of datasets")
    ...
        parser.add_argument("--strategy_load_ckpt_path",
                            type=str,
                            default="/job/data/code/fault_torlence/pangu_alpha/strategy/strategy.ckpt", # 断点续训中,根据用户习惯指定容器内路径,且路径不会被训练覆盖。
                            help="The training prallel strategy for the model.")
        parser.add_argument("--tokenizer_path",
                            type=str,
                            default="./tokenizer_path",
                            help="The path where stores vocab and vocab model file")
    ...
    def add_retrain_params(opt):
        """
        Add parameters about retrain.
        """
        opt.add_argument("--pre_trained",
                         type=str,
                         default="/job/data/code/fault_torlence/pangu_alpha/8p", # 指定预训练模型路径,
                         help="Pretrained checkpoint path.")
        opt.add_argument("--save_checkpoint_path",  # 指定模型保存路径
                         type=str,
                         default="/job/data/code/fault_torlence/pangu_alpha/8p",
                         help="Save checkpoint path.")
        opt.add_argument("--keep_checkpoint_max", # 指定模型保存策略:最大数量
                         type=int,
                         default=1,
                         help="Max checkpoint save number.")
        opt.add_argument("--save_checkpoint_steps", # 指定模型保存策略:保存间隔
                         type=int,
                         default=20,
                         help="Save checkpoint step number.")
        opt.add_argument("--save_checkpoint", # 指定当次训练是否保存模型
                         type=ast.literal_eval,
                         default=True,
                         help="Whether save checkpoint in local disk.")
        opt.add_argument("--ckpt_name_prefix", # 指定模型保存策略:文件名前缀
                         type=str,
                         default="pangu",
                         help="Saving checkpoint name prefix.")
    ...

  11. Create an empty file named "group_info_env" in the "/data/atlas_dls/code/pangu_alpha" directory.

    root@ubuntu:/data/atlas_dls/code/pangu_alpha/# 
    pangu_alpha/
    ├── README.md
    ├── README_CN.md
    ├── group_info_env
     ...
    ├── scripts
    ├── serving_increment
    ├── src
    ├── tasks.py
    └── train.py
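
    On the host, "touch /data/atlas_dls/code/pangu_alpha/group_info_env" is enough. A Python equivalent, sandboxed under a temporary directory so the sketch stays self-contained (the real path is the one above):

```python
import tempfile
from pathlib import Path

# Stand-in for /data/atlas_dls/code/pangu_alpha (sandboxed; hypothetical path).
pangu_dir = Path(tempfile.mkdtemp()) / "pangu_alpha"
pangu_dir.mkdir(parents=True)

# Equivalent of: touch .../group_info_env -- creates an empty file.
group_info_env = pangu_dir / "group_info_env"
group_info_env.touch()

print(group_info_env.exists(), group_info_env.stat().st_size)  # True 0
```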

  12. Modify the "group_info_env" path in the "train.py" file.

    ...
         # env variable prepare
         group_info_file = os.getenv("GROUP_INFO_FILE")
         if group_info_file:
             with open(os.path.expanduser("/job/code/group_info_env"), "a") as outfile:
                 outfile.write(f"export GROUP_INFO_FILE_REFLECT={group_info_file}\n")
    ...

  13. Run the task using the a800_vcjob.yaml file from MindXDL-deploy. For yaml parameter modifications, see the yaml parameter description.

Last-Words Code Adaptation Example Based on the ResNet-50 Model

The last-words feature currently supports only the MindSpore framework. Study the methods and samples in the MindSpore callback mechanism (the "ModelCheckpoint" content), then adapt the training startup script.

The cluster scheduling component also enhances the last-words feature. Taking the r2.0.0-alpha branch of the resnet model as an example, add the following content to the "train.py" file.

from mindx_elastic.terminating_message import ExceptionCheckpoint
import os # import if not already present
import datetime # import if not already present
...
def _is_time_interval_valid():
    # Security note: involves validation of paths and input parameters
    ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path, "ckpt_0")
    ckpt_pattern = os.path.join(ckpt_save_dir, "*breakpoint.ckpt")
    ckpt_files = glob.glob(ckpt_pattern)
    if not ckpt_files:
        return True
    else:
        ckpt_files.sort(key=os.path.getmtime, reverse=True)
        last_breakpoint_ckpt = ckpt_files[0]
        last_breakpoint_ckpt_timestamp = os.path.getmtime(last_breakpoint_ckpt)
        if int((datetime.datetime.now() - datetime.timedelta(minutes=1)).timestamp()) > int(last_breakpoint_ckpt_timestamp):
            return True
        return False


def train_net():
...
    # define callbacks
    time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossCallBack(config.has_trained_epoch)
    cb = [time_cb, loss_cb]
    ckpt_save_dir = set_save_ckpt_dir()
    if config.save_checkpoint:
        ckpt_append_info = [{"epoch_num": config.has_trained_epoch, "step_num": config.has_trained_step}]
        config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                     keep_checkpoint_max=config.keep_checkpoint_max,
                                     append_info=ckpt_append_info,
                                     exception_save=_is_time_interval_valid())
        ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)

        cb += [ckpt_cb]
        if _is_time_interval_valid():
            ckpoint_exp = ExceptionCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)
            cb += [ckpoint_exp]
    run_eval(target, model, ckpt_save_dir, cb)
...

Adapting other models is similar: ExceptionCheckpoint is used much like ModelCheckpoint; add the defined ExceptionCheckpoint to the callback list for it to take effect.

Run the task using the a800_vcjob.yaml file from MindXDL-deploy. For yaml parameter modifications, see the yaml parameter description.

Last-Words Code Adaptation Example Based on the Pangu_alpha Model

The last-words feature currently supports only the MindSpore framework. Study the methods and samples in the MindSpore callback mechanism (the "ModelCheckpoint" content), then adapt the training startup script.

The cluster scheduling component also enhances the last-words feature. Taking the master branch of the pangu_alpha model as an example, add the following content to the "train.py" file. The mindx_elastic package must be obtained by downloading the mindx_elastic-{version}-py3x-none-linux_{arch}.whl package and installing it inside the container.

from mindx_elastic.terminating_message import ExceptionCheckpoint
...
def add_checkpoint_callback_policy(args_param, callback, rank_id):
...
        ckpoint_cb = ModelCheckpoint(
            prefix=args_param.ckpt_name_prefix + str(rank_id),
            directory=os.path.join(args_param.save_checkpoint_path,
                                   f"rank_{rank_id}"),
            config=ckpt_config)

        # Exception callback
        # Security note: involves validation of paths and input parameters
        ckpoint_exp = ExceptionCheckpoint(
            prefix=args_param.ckpt_name_prefix + str(rank_id),
            directory=os.path.join(args_param.save_checkpoint_path,
                                   f"rank_{rank_id}"), config=ckpt_config)
        callback.append(ckpoint_cb)
        callback.append(ckpoint_exp)
...

# If no recovery strategy is used in combination, and the last-words checkpoint (rather than the periodic checkpoint) is used for resuming training, the restore_exception_checkpoint function needs to be modified (if the last-words checkpoint of some rank is missing, inconsistent models are loaded, which may cause the resumed training to fail):
def restore_exception_checkpoint(args_param, sink_size, dataset, model, network, epoch):
    """
    Restore exception checkpoint to training model.
    Args:
        args_param: model training parameters
        sink_size: model training sink size
        dataset: dataset used for training
        model: model
        network: pangu_alpha network
        epoch: training epoch

    Returns: load exception checkpoint success or not.

    """
    ckpt_name = args_param.ckpt_name_prefix

    try:
        ckpt_pattern = os.path.join(
            args_param.save_checkpoint_path,
            f"rank_{D.get_rank()}",
            f"{ckpt_name}*breakpoint.ckpt",
        )
        ckpt_files = glob.glob(ckpt_pattern)
        if not ckpt_files:
            return False
        ckpt_files.sort(key=os.path.getmtime, reverse=True)
        print(f" checkpoint files {ckpt_files[0]}")
        param_dict = load_checkpoint(ckpt_files[0])
        print(f" checkpoint param dict epoch num {param_dict.get('epoch_num')}")
        if param_dict.get("epoch_num") and param_dict.get("step_num"):
            args_param.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy())
            args_param.has_trained_steps = int(param_dict["step_num"].data.asnumpy())

        # Load checkpoint files
        model.build(train_dataset=dataset, sink_size=sink_size, epoch=epoch)
        load_param_into_net(network, param_dict)
    except TypeError:
        return False
    else:
        return True

Adapting other models is similar: ExceptionCheckpoint is used much like ModelCheckpoint; add the defined ExceptionCheckpoint to the callback list for it to take effect.

Run the task using the a800_vcjob.yaml file from MindXDL-deploy. For yaml parameter modifications, see the yaml parameter description.

Hybrid parallel model code adaptation example based on the Pangu_alpha model

  1. Download the pangu_alpha code from the master branch of the MindSpore repository. The pangu_alpha model is used as an example to describe how to adapt the hybrid parallel model recovery strategy.
  2. Configure the hybrid parallel model recovery strategy. See the MindSpore documentation for how to use the "GROUP_INFO_FILE" variable. Taking the pangu_alpha model as an example, add the variable to the DL component startup script main.sh as shown below.

    ...
            rankid=$((rank_start + i))
            export DEVICE_ID=${i}
            export RANK_ID=${rankid}
            mkdir -p ${ROOT_PATH}/../device${rankid}
            cd ${ROOT_PATH}/../device${rankid} || exit
            group_info_dir=./group_info.pb
            group_info_file_tmp=$(realpath $group_info_dir)
            export GROUP_INFO_FILE=${group_info_file_tmp}
            echo "start training for rank ${RANK_ID}, device ${DEVICE_ID}"
    ...
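The per-rank export in step 2 can also be sketched in Python, for readers who want to verify the path resolution outside the shell script. Everything below is an illustrative stand-in (the function name and the toy two-rank loop are not part of pangu_alpha or the DL components); it mirrors what main.sh does: resolve group_info.pb to an absolute per-rank path and publish it through GROUP_INFO_FILE.

```python
import os
import tempfile

def export_group_info_file(rank_dir, name="group_info.pb"):
    """Mirror of the main.sh snippet: resolve the per-rank strategy file to an
    absolute path (realpath) and export it through GROUP_INFO_FILE."""
    os.makedirs(rank_dir, exist_ok=True)
    group_info_file = os.path.realpath(os.path.join(rank_dir, name))
    os.environ["GROUP_INFO_FILE"] = group_info_file
    return group_info_file

root = tempfile.mkdtemp()
for rank_id in range(2):  # two local ranks as a toy example
    path = export_group_info_file(os.path.join(root, f"device{rank_id}"))
    print(f"rank {rank_id}: GROUP_INFO_FILE={path}")
```

In the real script each rank process inherits its own value of the variable; here the loop simply overwrites it, so the last rank's path is the one left in the environment.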

  3. Make sure the hybrid parallel model recovery strategy is enabled within the available compute resources. Taking the MindSpore pangu_alpha 2.6B model as an example, confirm in the "src/pangu_alpha_config.py" file that the "args_opt.optimizer_shard" parameter is set to "0".

    # make sure optimizer_shard is set to 0
        elif args_opt.mode == "2.6B":
            args_opt.embedding_size = 2560
            args_opt.num_layers = 32
            args_opt.num_heads = 32
            args_opt.op_level_model_parallel_num = 8
            if args_opt.run_type == "train":
                args_opt.start_lr = 1e-4
                args_opt.end_lr = 1e-6
                args_opt.optimizer_shard = 0

  4. Load the last-words checkpoint according to the recovery strategy. Taking the MindSpore pangu_alpha 2.6B model as an example, verify the following code in the "train.py" file. The adaptation steps are:

    1. Import the Python dependencies. mindx_elastic must be obtained by downloading the mindx_elastic-{version}-py3x-none-linux_{arch}.whl package and installing it inside the container.
    2. Add handling for the parallel-strategy environment variable.
    3. Add the last-words checkpoint loading functions.
    4. Check and adapt the original checkpoint loading function.
    # import dependencies
    import json
    from mindx_elastic.restore_module import RestoreStrategyGenerator
    ...
    # if the model does not enable pipeline parallelism, modify the following function
    def run_train(args_opt):
    # if the model enables pipeline parallelism, modify the following function
    def run_train_pipeline(args_opt):
    # add handling for the parallel-strategy environment variable
    ...
        rank = 0
        device_num = 1
        if args_opt.distribute == "true":
            rank, device_num = set_parallel_context(args_opt)
        context.set_context(save_graphs=False, save_graphs_path="./graphs_of_device_id_" + str(rank))
        # env variable prepare
        # security note: the path, input parameters, and environment variables involved should be validated
        group_info_file = os.getenv("GROUP_INFO_FILE")
        if group_info_file:
            os.environ["GROUP_INFO_FILE_REFLECT"] = group_info_file
            # validation of group_info_file is omitted here for brevity; add the checks you need
            with open(os.path.expanduser("/job/code/group_info_env"), "a") as outfile:
                outfile.write(f"export GROUP_INFO_FILE_REFLECT={group_info_file}\n")
    ...
        if args_opt.pre_trained:
            flag = restore_exception_checkpoint(args_opt, args_opt.sink_size, ds, model,
                                                pangu_alpha_with_grads, epoch=actual_epoch_num)
            if not flag:
                restore_checkpoint(args_opt, args_opt.sink_size, ds, model,
                                   pangu_alpha_with_grads, epoch=actual_epoch_num)
    ...
    # modify the original checkpoint loading function
    # security note: the path, input parameters, and environment variables involved should be validated
    def restore_checkpoint(args_param, sink_size, dataset, model, network, epoch):
        r"""
        Load checkpoint process.
        """
        print("======start single checkpoint", flush=True)
        ckpt_name = args_param.ckpt_name_prefix
        ckpt_pattern = os.path.join(args_param.save_checkpoint_path, "rank_{}".format(D.get_rank()),
                                    f"{ckpt_name}*.ckpt")
        ckpt_all_files = glob.glob(ckpt_pattern)

        if not ckpt_all_files:
            print(f"There is no ckpt file in {args_param.save_checkpoint_path}, "
                  f"current ckpt_files found is {ckpt_all_files} "
                  f"with pattern {ckpt_pattern}, so skip the loading.")
            return

        # exclude last-words (breakpoint) checkpoints from the periodic ones
        ckpt_exp_pattern = os.path.join(args_param.save_checkpoint_path, "rank_{}".format(D.get_rank()),
                                        f"{ckpt_name}*_breakpoint.ckpt")
        ckpt_exp_files = glob.glob(ckpt_exp_pattern)
        ckpt_files = []
        for file in ckpt_all_files:
            if file not in ckpt_exp_files:
                ckpt_files.append(file)

        if not ckpt_files:
            print(f"There is no ckpt file in {args_param.save_checkpoint_path}, "
                  f"current ckpt_files found is {ckpt_files} "
                  f"with pattern {ckpt_pattern}, so skip the loading.")
            return
        ckpt_files.sort(key=os.path.getmtime, reverse=True)
    ...
    # define the last-words checkpoint loading functions
    # security note: the path, input parameters, and environment variables involved should be validated
    def get_exception_checkpoints(args_param):
        r"""
        Load checkpoint process.
        """
        print("======start exception checkpoint", flush=True)
        restore_ranks = os.getenv("RESTORE_RANKS")
        if not restore_ranks:
            return None
    
        restore_rank_list = list(map(int, restore_ranks.split(",")))
        ckpt_file_list = []
        ckpt_name = args_param.ckpt_name_prefix
        for ckpt_rank in restore_rank_list:
            ckpt_pattern = os.path.join(args_param.save_checkpoint_path,
                                        f"rank_{ckpt_rank}",
                                        f"{ckpt_name}*_breakpoint.ckpt")
            ckpt_files = glob.glob(ckpt_pattern)
            if not ckpt_files:
                print(
                    f"There is no ckpt file in {args_param.save_checkpoint_path}, "
                    f"current ckpt_files found is {ckpt_files} "
                    f"with pattern {ckpt_pattern}, so skip the loading.")
                return None
            ckpt_files.sort(key=os.path.getmtime, reverse=True)
            ckpt_file_list.append(ckpt_files[0])
        print(f"checkpoint file {ckpt_file_list}")
        return ckpt_file_list
    
    # security note: the path, input parameters, and environment variables involved should be validated
    def check_exception_checkpoints(ckpt_file_list):
        """
        Check exception checkpoints size.
        Args:
            ckpt_file_list: exception checkpoints
        Returns: result of exception checkpoints size check.
    
        """
        ckpt_size_list = []
        for ckpt_file in ckpt_file_list:
            ckpt_size_list.append(os.path.getsize(ckpt_file))
    
        if len(set(ckpt_size_list)) > 1:
            return False
        return True
    
    # security note: the path, input parameters, and environment variables involved should be validated
    def restore_exception_checkpoint(args_param, sink_size, dataset, model, network, epoch):
        """
        Restore exception checkpoint to training model.
        Args:
            args_param: model training parameters
            sink_size: model training sink size
            dataset: dataset used for training
            model: model
            network: pangu_alpha network
            epoch: training epoch

        Returns: load exception checkpoint success or not.

        """
        restore_strategy_generator = RestoreStrategyGenerator()
        res_query = restore_strategy_generator.gen_fault_tolerance_strategy()
        if not res_query:
            return False

        restore_ranks, restore_dict = res_query
        print(f"restore ranks: {restore_ranks}, restore dict: {restore_dict}")
        if not restore_ranks:
            return False

        if not restore_dict:
            return False

        os.environ["RESTORE_RANKS"] = restore_ranks
        os.environ["RESTORE_RANKS_MAP"] = str(restore_dict)

        if os.getenv("RESTORE_RANKS") == "-1":
            return False

        ckpt_file_list = get_exception_checkpoints(args_param)

        restore_flag = False
        if ckpt_file_list:
            restore_flag = check_exception_checkpoints(ckpt_file_list)

        if not restore_flag:
            return False

        ckpt_name = args_param.ckpt_name_prefix
        restore_ranks_map = os.getenv("RESTORE_RANKS_MAP")
        if not restore_ranks_map:
            return False

        try:
            print("whether run into load process")
            restore_ranks_map_json = json.loads(restore_ranks_map)
            map_rank_id = D.get_rank()
            for key in restore_ranks_map_json.keys():
                key_list = list(key.split(","))
                if str(D.get_rank()) in key_list:
                    map_rank_id = restore_ranks_map_json.get(key)

            print(f"loading map rank id {map_rank_id}")
            ckpt_pattern = os.path.join(args_param.save_checkpoint_path,
                                        f"rank_{map_rank_id}",
                                        f"{ckpt_name}*breakpoint.ckpt")
            ckpt_files = glob.glob(ckpt_pattern)
            if not ckpt_files:
                return False
            ckpt_files.sort(key=os.path.getmtime, reverse=True)
            print(f" checkpoint files {ckpt_files[0]}")
            param_dict = load_checkpoint(ckpt_files[0])
            print(f" checkpoint param dict epoch num {param_dict.get('epoch_num')}")
            if param_dict.get("epoch_num") and param_dict.get("step_num"):
                args_param.has_trained_epoches = int(
                    param_dict["epoch_num"].data.asnumpy())
                args_param.has_trained_steps = int(
                    param_dict["step_num"].data.asnumpy())

            # load the checkpoint file
            model.build(train_dataset=dataset, sink_size=sink_size, epoch=epoch)
            load_param_into_net(network, param_dict)
        except TypeError:
            return False
        else:
            return True
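The rank-mapping step inside restore_exception_checkpoint is the subtle part, so it is isolated below as a minimal, framework-free sketch (the helper name is ours; the parsing logic matches the try block above): RESTORE_RANKS_MAP is a JSON object whose keys are comma-separated rank lists and whose values name the rank whose last-words checkpoint should be loaded.

```python
import json

def map_restore_rank(current_rank, restore_ranks_map):
    """Return the rank whose checkpoint `current_rank` should load.

    Keys of the map are comma-separated rank lists, e.g. "0,1"; the value is
    the rank that holds an equivalent last-words checkpoint. Falls back to
    the current rank when no key matches, as in restore_exception_checkpoint.
    """
    mapping = json.loads(restore_ranks_map)
    map_rank_id = current_rank
    for key, source_rank in mapping.items():
        if str(current_rank) in key.split(","):
            map_rank_id = source_rank
    return map_rank_id

# Ranks 0 and 1 share a checkpoint saved by rank 0; ranks 2 and 3 by rank 2.
restore_map = '{"0,1": 0, "2,3": 2}'
print(map_restore_rank(3, restore_map))  # rank 3 loads rank_2's breakpoint ckpt
```

The mapped rank id is then substituted into the `rank_{map_rank_id}` directory of the glob pattern, which is why a rank can recover even when its own breakpoint ckpt directory is empty.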

  5. Run the task using the a800_vcjob.yaml file in MindXDL-deploy. For yaml parameter modifications, see the yaml parameter description. In the delivered task yaml, the environment variable that takes its value from "metadata.name" must be named "mindx-dls-test", the same as the job name, as shown below.

    ...
    apiVersion: batch.volcano.sh/v1alpha1
    kind: Job
    metadata:
      name: mindx-dls-test                 
      namespace:  xxx                     
      labels:
        ring-controller.atlas: ascend-910   
    ...
          spec:
            terminationGracePeriodSeconds: 600 # see Table 2, "terminationGracePeriodSeconds" values for last-words vcjob tasks
            containers:
            - image: mindspore:b035        
              imagePullPolicy: IfNotPresent
              name: mindspore
              env:
              - name: mindx-dls-test        
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.name
              - name: XDL_IP               
                valueFrom:
            ...

yaml parameter description

To use fault rescheduling, add the "fault-scheduling", "terminationGracePeriodSeconds", and "maxRetry" parameters to the yaml of the delivered vcjob, as described below.

Table 1 "fault-scheduling" values for resumable-training vcjob tasks

| No. | Value | Description |
| --- | ----- | ----------- |
| 1 | grace | The task uses rescheduling. The original Pod is deleted gracefully first; if that has not succeeded after 15 minutes (configurable, see the Volcano configuration), the Pod is deleted forcibly. This value is required when the last-words checkpoint scheme is used. |
| 2 | force | The task uses rescheduling, and the original Pod is deleted forcibly. |
| 3 | off | The task does not use the fault rescheduling feature; the k8s maxRetry still takes effect. |
| 4 | none (no fault-scheduling label) | |
| 5 | any other value | |

Table 2 "terminationGracePeriodSeconds" values for last-words vcjob tasks

| No. | Value | Description |
| --- | ----- | ----------- |
| 1 | 0 < terminationGracePeriodSeconds < value of "grace-over-time" | The time from when the container receives the stop signal until K8s stops it forcibly. It must be greater than 0 and less than the value of the "grace-over-time" parameter in the volcano-*.yaml file, and must also be long enough for the ckpt file to finish saving; adjust it to your actual situation. For details, see the container lifecycle hooks documentation on the K8s website. |
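The constraint in Table 2 can be expressed as a small check. The helper below is purely illustrative (neither the function nor the ckpt_save_s estimate exists in any DL component); it encodes the rule that the grace period must be positive, strictly below grace-over-time, and long enough to finish saving the ckpt file.

```python
def validate_grace_period(termination_grace_s, grace_over_time_s, ckpt_save_s):
    """Check the Table 2 constraint: 0 < terminationGracePeriodSeconds <
    grace-over-time, and the window must also cover the ckpt save time."""
    if not 0 < termination_grace_s < grace_over_time_s:
        return False
    return termination_grace_s >= ckpt_save_s

# 600 s fits under a 900 s grace-over-time and covers a ~300 s ckpt save.
print(validate_grace_period(600, 900, 300))  # True
print(validate_grace_period(900, 900, 300))  # False: not strictly below grace-over-time
```

In practice, measure how long your model actually needs to write a last-words ckpt and leave headroom rather than setting the grace period near either bound.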

To use the resumable-training feature, the memory must be expanded; add the parameters as noted in the comments. The "maxRetry" mechanism is also required. A yaml template example follows:

apiVersion: v1
kind: ConfigMap
metadata:
  name: rings-config-mindx-dls-test     
  namespace: vcjob                    
  labels:
    ring-controller.atlas: ascend-910
data:
  hccl.json: |
    {
        "status":"initializing"
    }
---
apiVersion: v1
kind: ConfigMap
metadata:
  name: fault-config-mindx-dls-test
  namespace: vcjob 
data:
  fault-npus: |
    {
        "status":"initializing"
    }
---
apiVersion: batch.volcano.sh/v1alpha1
kind: Job
metadata:
  name: mindx-dls-test    # note the correspondence with the ConfigMap name
  namespace: vcjob        # choose a suitable namespace as needed (the ConfigMap and Job must match)
  labels:
    ring-controller.atlas: ascend-910  # hccl_controller uses this label to distinguish Ascend 910 scenarios from non-Ascend 910 ones
    fault-scheduling: "force"
spec:
  minAvailable: 1
  schedulerName: volcano    # schedule the task with the Volcano scheduler
  policies:
    - event: PodEvicted
      action: RestartJob
  plugins:
    ssh: []
    env: []
    svc: []
  maxRetry: 3
...
      spec:
        containers:
        - image: mindspore:b035         # training framework image; can be changed
          imagePullPolicy: IfNotPresent
          name: mindspore
...
          command:
          - "/bin/bash"
          - "-c"
          - "cd /job/code/resnet/scripts; chmod +x train_start.sh; ./train_start.sh;" # command that starts the training script; make sure the command and paths exist in the container.
          #args: [ "while true; do sleep 30000; done;"  ]                            # comment out the line above and enable this one to run the training script manually in the container for debugging.
                                                                                     # attach with 'kubectl exec -it -n {namespace} {podname} bash'
          lifecycle:  # add this section to use the last-words checkpoint feature
            preStop:
              exec:
                command: ["/bin/bash", "-c", "cd /job/code/resnet/scripts; bash pre_stop.sh"]
          resources:
            requests:
              huawei.com/Ascend910: 1                                                # number of NPUs requested, at most 8. Add lines below to configure memory, CPU, and other resources
            limits:
              huawei.com/Ascend910: 1                                                # must match the requested number
...
          # add the following lines for the resumable-training memory expansion
          volumeMounts:
          - mountPath: /dev/shm
            name: shm
        volumes:
        - name: shm
          emptyDir:
            medium: Memory
            sizeLimit: 16Gi
...