Script Adaptation

This section provides script adaptation examples for fault recovery. Select a script adaptation example based on your actual requirements.

The sample model code below may differ from the version you obtain. Use the code of the actual version.

PyTorch-based Fault Recovery Example

  1. Download ResNet50_ID4149_for_PyTorch from the master branch in the PyTorch code repository and use it as the training code.
  2. Prepare the dataset for ResNet-50 and comply with its usage specifications.
  3. Upload the dataset to the storage node as an administrator.
    1. Go to the /data/atlas_dls/public directory and upload the dataset to any directory, for example, /data/atlas_dls/public/dataset/resnet50/imagenet.
      root@ubuntu:/data/atlas_dls/public/dataset/resnet50/imagenet# pwd
      Command output:
      /data/atlas_dls/public/dataset/resnet50/imagenet
      
    2. Run the du -sh command to check the dataset size.
      root@ubuntu:/data/atlas_dls/public/dataset/resnet50/imagenet# du -sh
      Command output:
      11G
      
  4. Decompress the training code downloaded in step 1 to the local host, and upload the ModelZoo-PyTorch/PyTorch/built-in/cv/classification/ResNet50_ID4149_for_PyTorch directory from the decompressed code to a directory in the environment, for example, /data/atlas_dls/public/code/.
  5. Go to the mindcluster-deploy repository, select a branch based on mindcluster-deploy Version Description, and obtain the train_start.sh, utils.sh, and rank_table.sh files in the samples/train/resumable-training/fault-rescheduling/withRanktable/pytorch/resnet50 directory. Then, create a scripts directory in the training code and construct the following directory structure on the management node:
    root@ubuntu:/data/atlas_dls/public/code/ResNet50_ID4149_for_PyTorch/scripts/#
    scripts/
    ├── rank_table.sh
    ├── utils.sh
    └── train_start.sh
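  6. In /data/atlas_dls/public/code/ResNet50_ID4149_for_PyTorch, modify main.py as shown in the following code to adjust the logic of model saving and loading. With these changes, the process that saves checkpoints writes the best model to ./rank_<rank>/model_best.pth.tar, and when args.resume is set, the first model_best.pth.tar found in a ./rank* directory is loaded.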
    import argparse
    import glob
    import os
    ...
        if args.resume:
            candidate_ckpt_path = ""
            for p in glob.glob(f"./rank*"):
                best_ckpt_path = os.path.join(p, "model_best.pth.tar")
                if os.path.exists(best_ckpt_path):
                    candidate_ckpt_path = best_ckpt_path
                    break
            if candidate_ckpt_path:
                print("[gpu id:", args.gpu, "]", "=> loading checkpoint '{}'".format(candidate_ckpt_path))
                # Map model to be loaded to specified single npu.
                loc = 'npu:{}'.format(args.gpu)
                checkpoint = torch.load(candidate_ckpt_path, map_location=loc)
                print(f"load checkpoint to : {loc}")
                args.start_epoch = checkpoint['epoch']
                best_acc1 = checkpoint['best_acc1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("[gpu id:", args.gpu, "]", "=> loaded checkpoint '{}' (epoch {})".format(candidate_ckpt_path, checkpoint['epoch']))
            else:
                print("no valid ckpt found to resume.")
    ...
            if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
                save_path = f"./rank_{args.rank}"
                if not os.path.exists(save_path):
                    os.makedirs(save_path, exist_ok=True)
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, save_path=save_path)
    ...
    ...
    # Modify the original save_checkpoint function.
    def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', save_path="./"):
        if is_best:
            target_path = os.path.join(save_path, 'model_best.pth.tar')
            torch.save(state, target_path)
            print(f"save ckpt to {target_path} done. Best epoch for now is :{state['epoch']}")
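The resume logic above depends only on a directory convention, so it can be checked without PyTorch or an NPU. The following is a minimal standalone sketch of that convention: find_best_ckpt and rank_save_dir are hypothetical helper names, root stands for the working directory of the training process, and an empty file replaces torch.save(). Unlike the original loop, it sorts the glob results, because glob.glob returns paths in arbitrary order.

```python
import glob
import os
import tempfile


def find_best_ckpt(root="."):
    """Return the first <root>/rank*/model_best.pth.tar that exists, or ""."""
    # glob.glob() returns matches in arbitrary order; sort them so the choice
    # is deterministic.
    for rank_dir in sorted(glob.glob(os.path.join(root, "rank*"))):
        candidate = os.path.join(rank_dir, "model_best.pth.tar")
        if os.path.exists(candidate):
            return candidate
    return ""


def rank_save_dir(root, rank):
    """Create (if needed) and return the per-rank save directory <root>/rank_<rank>."""
    path = os.path.join(root, f"rank_{rank}")
    os.makedirs(path, exist_ok=True)
    return path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        print(repr(find_best_ckpt(root)))  # '' : nothing to resume from yet
        best = os.path.join(rank_save_dir(root, 0), "model_best.pth.tar")
        open(best, "wb").close()  # empty placeholder instead of torch.save()
        print(find_best_ckpt(root) == best)  # True
```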

MindSpore-based Fault Recovery Example

  1. Download the master branch code from the MindSpore code repository, rename the models/official/cv/ResNet directory to resnet, and use it as the training code.
  2. Run the following command to create a code directory on the management node and upload the training code to the directory:
    mkdir /data/atlas_dls/code
  3. Go to the mindcluster-deploy repository, select a branch based on mindcluster-deploy Version Description, and obtain the train_start.sh and main.sh files in the samples/train/resumable-training/fault-rescheduling/withRanktable/mindspore/resnet50 directory. Then, place the files in the resnet/scripts directory of the training code to construct the following directory structure on the management node:
    root@ubuntu:/data/atlas_dls/public/code/resnet/scripts/#
    scripts/
    ├── main.sh
     ...
    ├── run_distribute_train.sh
    ├── run_distribute_train_gpu.sh
    └── train_start.sh
  4. Modify the train_start.sh file in the /data/atlas_dls/public/code/resnet/scripts directory.
    1. Change dataset_path to the actual dataset directory in the container.
    2. Change config_yaml_path to the actual configuration file path in the container.
    These global configuration parameters (the dataset path and the configuration file path) can be changed as required. When adapting other models, add or delete parameters as required.
    dataset_path=/job/data/imagenet/train
    config_yaml_path=/job/code/resnet/resnet50_imagenet2012_config.yaml

    The train_start.sh script calls the main.sh script to start a training job. When adapting other models, adjust the environment variable configuration, startup script path, and startup script parameters in main.sh according to the usage guide of the training startup script (train.py in this example).

    # main.sh: For this example (ResNet-50 model), you do not need to modify the script. For adaptation of other models, add, delete, or modify environment variables based on your needs, and then modify the training startup script path and corresponding parameters, that is, the part invoked by the Python command in the main.sh script.
    # In this example, the Python command for single-server single-processor training is as follows:
    python ${ROOT_PATH}/../train.py --data_path=${DATA_PATH} --config_path=${CONFIG_PATH} 
    # In this example, the command for single-server multi-processor and distributed training is as follows:
    python ${ROOT_PATH}/../train.py --run_distribute=True --device_num=${RANK_SIZE} --data_path=${DATA_PATH} --config_path=${CONFIG_PATH} 
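The two Python commands above differ only in the distributed flags. The sketch below assembles either argument list with a hypothetical helper, build_train_cmd; choosing the distributed form whenever RANK_SIZE is greater than 1 is an assumption about how main.sh decides, and the ROOT_PATH value in the demo is hypothetical.

```python
def build_train_cmd(root_path, data_path, config_path, rank_size):
    """Return the train.py argument list that main.sh runs (hypothetical helper).

    Assumption: rank_size == 1 selects the single-processor command and any
    larger rank_size selects the multi-processor/distributed command.
    """
    cmd = ["python", f"{root_path}/../train.py"]
    if rank_size > 1:
        # Extra flags used only by the multi-processor/distributed command.
        cmd += ["--run_distribute=True", f"--device_num={rank_size}"]
    cmd += [f"--data_path={data_path}", f"--config_path={config_path}"]
    return cmd


if __name__ == "__main__":
    # Hypothetical ROOT_PATH; the other two paths come from train_start.sh.
    print(" ".join(build_train_cmd("/job/code/resnet/scripts", "/job/data/imagenet/train",
                                   "/job/code/resnet/resnet50_imagenet2012_config.yaml", 8)))
```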
  5. Modify the resnet50_imagenet2012_config.yaml file in /data/atlas_dls/public/code/resnet/config/ to configure saving and loading of the model and of the graph build results.
    ...
    run_distribute: False
    enable_profiling: False
    data_path: "/cache/data"
    output_dir: "/job/code/output" # Checkpoint saving path. Change the path as required.
    load_path: "/cache/checkpoint_path/"
    device_target: "Ascend"
    checkpoint_path: "./checkpoint/"
    checkpoint_file_path: ""
    ...
    net_name: "resnet50"
    dataset: "imagenet2012"
    device_num: 1
    pre_trained: "/job/code/output/resnet50/imagenet2012/ckpt" # Pre-trained model loading path (directory or file) in the container. If a directory is specified, .ckpt files under it are found by fuzzy search and the latest one is loaded. Change the path as required, consistent with the training YAML.
    run_eval: False
    eval_dataset_path: ""
    parameter_server: False
    filter_weight: False
    save_best_ckpt: True
    eval_start_epoch: 40
    ...
    network_dataset: "resnet50_imagenet2012"
    
    
    # Retraining options
    save_graphs: False  # Whether to save the graph build result.
    save_graphs_path: "./graphs" # Path for saving the graph build result.
    has_trained_epoch: 0 # Number of epochs the model has already been trained. The default value is 0.
    has_trained_step: 0 # Number of steps the model has already been trained. The default value is 0.
    ---
    # Help information of each option
    enable_modelarts: "Whether training on modelarts, default: False"
    ...
    batch_size: "Batch size for training and evaluation"
    epoch_size: "Total training epochs."
    checkpoint_path: "The location of the checkpoint file."
    checkpoint_file_path: "The location of the checkpoint file."
    save_graphs: "Whether save graphs during training, default: False."
    save_graphs_path: "Path to save graphs."
  6. The startup script of the ResNet code is train.py. Check whether train.py contains code for saving checkpoints.
    • If the code exists, skip this step.
    • If the code does not exist, add the following sample code for saving checkpoints. You need to define and set required parameters in the configuration file. Refer to the following snippet for adapting other models, and add the code for saving checkpoints based on the specific content of the startup script. For details, see the tutorial on the official website of MindSpore.
    ...
        # Model saving code
        if config.save_checkpoint:
            ckpt_append_info = [{"epoch_num": 0, "step_num": 0}]
            config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                         keep_checkpoint_max=config.keep_checkpoint_max,
                                         append_info=ckpt_append_info)
            ckpt_cb = ModelCheckpoint(prefix=config.net_name, directory=config.save_ckpt_dir+"_"+str(config.rank_id), config=config_ck)
            cb += [ckpt_cb]
    ...
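The saving code writes each rank's checkpoints to save_ckpt_dir + "_" + rank_id, and the loading code in step 8 finds them with glob.glob(config.pre_trained + "*/*.ckpt"). The sketch below shows how the two fit together, assuming that save_ckpt_dir resolves to the same path as pre_trained in the configuration file. The helper names and the file name resnet50-1_100.ckpt are hypothetical.

```python
import glob
import os
import tempfile


def rank_ckpt_dir(save_ckpt_dir, rank_id):
    # Same expression as directory=config.save_ckpt_dir + "_" + str(config.rank_id)
    return save_ckpt_dir + "_" + str(rank_id)


def find_ckpts(pre_trained):
    # Same pattern as glob.glob(config.pre_trained + "*/*.ckpt"). No separator is
    # inserted, so ".../ckpt" matches ".../ckpt_0/*.ckpt", ".../ckpt_1/*.ckpt", ...
    return sorted(glob.glob(pre_trained + "*/*.ckpt"))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        base = os.path.join(tmp, "ckpt")  # stands in for both save_ckpt_dir and pre_trained
        for rank_id in (0, 1):
            os.makedirs(rank_ckpt_dir(base, rank_id))
            # Hypothetical file name; ModelCheckpoint chooses the real one.
            open(os.path.join(rank_ckpt_dir(base, rank_id), "resnet50-1_100.ckpt"), "w").close()
        print(len(find_ckpts(base)))           # 2
        print(len(find_ckpts(base + os.sep)))  # 0: a trailing separator breaks the match
```

As the last line shows, pre_trained must not end with a path separator for this pattern to reach the per-rank directories.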
  7. The startup script of the ResNet code is train.py. Check whether train.py contains code for loading checkpoints. If it does, the configuration is complete; go to the next section. If it does not, go to step 8.
  8. Add the code for loading checkpoints to train.py. The following is a checkpoint loading example. You need to define and set required parameters in the configuration file. Refer to the following snippet for adapting other models, and add the code for loading checkpoints based on the specific content of the startup script. For details, see the tutorial on the official website of MindSpore.
    1. Modify src/utils.py to add the code for reading epochs. After a checkpoint is loaded, the training log will be printed from the epoch where the checkpoint is saved.
      ...
      def init_weight(net, cfg):
          """init_weight"""
          if cfg.pre_trained:
              if not os.path.isfile(cfg.pre_trained):
                  cfg.logger.warning("There is not ckpt file: %s", cfg.pre_trained)
              else:
                  param_dict = ms.load_checkpoint(cfg.pre_trained)
                  if cfg.filter_weight:
                      filter_list = [x.name for x in net.end_point.get_parameters()]
                      filter_checkpoint_parameter_by_list(param_dict, filter_list)
                  ms.load_param_into_net(net, param_dict)
                  cfg.start_epoch = int(param_dict.get('epoch_num', ms.Tensor(0, ms.int32)).asnumpy().item())
                  cfg.logger.info("Pre trained ckpt mode: %s loading", cfg.pre_trained)
      ...
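    2. In train.py, replace the call to the original init_weight function with _try_to_init_weight to load the .ckpt file. _try_to_init_weight tries the most recent .ckpt file first and, if loading fails, falls back to the next most recent one, up to INIT_WEIGHT_MAX_ATTEMPTS times. This prevents training errors caused by loading an incomplete checkpoint.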
      import glob
      ...
      # Search for the latest .ckpt file in the pre_trained directory.
      def _find_latest_ckpt():
          ckpt_files = glob.glob(config.pre_trained+"*/*.ckpt")
          if ckpt_files:
              ckpt_files.sort(key=os.path.getmtime, reverse=True)
          return ckpt_files
      
      # Attempt to load the .ckpt file for INIT_WEIGHT_MAX_ATTEMPTS times.
      def _try_to_init_weight(net, config):
          if os.path.isfile(config.pre_trained):
              latest_ckpt = [config.pre_trained]
          else:
              latest_ckpt = _find_latest_ckpt()
      
          if not latest_ckpt:
              config.logger.warning("There is not ckpt file: %s", config.pre_trained)
              return
      
          init_weight_attempts = 0
          INIT_WEIGHT_MAX_ATTEMPTS = 5
          while latest_ckpt and init_weight_attempts < INIT_WEIGHT_MAX_ATTEMPTS:
              try:
                  config.pre_trained = latest_ckpt[0]
                  init_weight(net, config)
                  break
              except Exception:
                  config.logger.warning("Pre trained ckpt %s format is incorrect, try to load the next most recent ckpt", config.pre_trained)
                  if latest_ckpt[1:]:
                      # Fall back to the next most recent checkpoint.
                      latest_ckpt = latest_ckpt[1:]
                      init_weight_attempts += 1
                      continue
                  config.logger.error("no more ckpt to load: %s", config.pre_trained)
                  raise ValueError("ckpt format is incorrect, no more ckpt to load, load ckpt failed.")
      
      ...
      @moxing_wrapper()
      def train_net():
          """train net"""
          target = config.device_target
          set_parameter()
          set_output_dir(config)
          config.logger = get_logger(config.log_dir, config.rank_id, config.parameter_server)
          dataset = create_dataset(dataset_path=config.data_path, do_train=True,
                                   batch_size=config.batch_size, train_image_size=config.train_image_size,
                                   eval_image_size=config.eval_image_size, target=target,
                                   distribute=config.run_distribute)
          step_size = dataset.get_dataset_size()
          net = resnet(class_num=config.class_num)
          if config.parameter_server:
              net.set_param_ps()
          # Use _try_to_init_weight, instead of the original init_weight function, to load the .ckpt file. This prevents training errors caused by incomplete checkpoint loading.
          _try_to_init_weight(net, config)
      
          if config.resume_ckpt:
              resume_param = ms.load_checkpoint(config.resume_ckpt,
                                                choice_func=lambda x: not x.startswith(('learning_rate', 'global_step')))
              config.start_epoch = int(resume_param.get('epoch_num', ms.Tensor(0, ms.int32)).asnumpy().item())
          lr = ms.Tensor(init_lr(step_size=step_size))
      ...
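The fallback in _try_to_init_weight follows a general pattern: sort the candidates newest first and move on to the next file when loading raises. The sketch below reproduces that pattern with a pluggable loader so that it runs without MindSpore; load_with_fallback and newest_first are hypothetical names, and loader stands in for init_weight. One deliberate difference: the sketch raises ValueError when the attempt limit is reached, whereas the loop above then stops without raising.

```python
import os
import tempfile


def newest_first(paths):
    """Return checkpoint paths sorted by modification time, newest first."""
    return sorted(paths, key=os.path.getmtime, reverse=True)


def load_with_fallback(paths, loader, max_attempts=5):
    """Try loader(path) on up to max_attempts paths, newest first.

    Returns (path, result) for the first successful load. Raises ValueError
    if there are no paths or every attempted load fails.
    """
    failures = []
    for path in newest_first(paths)[:max_attempts]:
        try:
            return path, loader(path)
        except Exception as exc:  # e.g. a checkpoint truncated by a crash mid-write
            failures.append((path, exc))
    raise ValueError(f"no loadable checkpoint; {len(failures)} attempt(s) failed")


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        old, new = os.path.join(tmp, "old.ckpt"), os.path.join(tmp, "new.ckpt")
        for path, mtime in ((old, 1000), (new, 2000)):
            open(path, "w").close()
            os.utime(path, (mtime, mtime))

        def loader(path):
            # Pretend the newest file was truncated by a crash during saving.
            if path == new:
                raise OSError("truncated checkpoint")
            return "loaded"

        chosen, _ = load_with_fallback([old, new], loader)
        print(os.path.basename(chosen))  # old.ckpt
```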

TensorFlow-based Fault Recovery Example

  1. Download ResNet50_ID0360_for_TensorFlow2.X from the master branch of the TensorFlow code repository and use it as the training code. Make sure that the TensorFlow version in the training image matches the TensorFlow version required by the model code.
  2. Upload the dataset to the storage node as an administrator.
    1. Go to the /data/atlas_dls/public directory and upload the dataset to any directory, for example, /data/atlas_dls/public/dataset/resnet50/imagenet_TF.
      root@ubuntu:/data/atlas_dls/public/dataset/resnet50/imagenet_TF# pwd
      /data/atlas_dls/public/dataset/resnet50/imagenet_TF
    2. Run the du -sh command to check the dataset size.
      root@ubuntu:/data/atlas_dls/public/dataset/resnet50/imagenet_TF# du -sh
      42G
  3. Decompress the training code downloaded in step 1 to the local host, rename the ResNet50_ID0360_for_TensorFlow2.X directory in ModelZoo-TensorFlow-master/TensorFlow2/built-in/cv/image_classification/ to ResNet50_for_TensorFlow_2.6_code, and upload it to /data/atlas_dls/public/code/ in the environment.
  4. Go to the mindcluster-deploy repository, select a branch based on mindcluster-deploy Version Description, and obtain the train_start.sh, utils.sh, and rank_table.sh files in the samples/train/basic-training/ranktable directory. Then, create a scripts directory in the training code and construct the following directory structure on the management node:
    /data/atlas_dls/public/code/ResNet50_for_TensorFlow_2.6_code/
    ├──  scripts
    │   ├──  train_start.sh
    │   ├──  utils.sh
    │   ├──  rank_table.sh
    │    ...
  5. Modify the tensorflow/tf2_common/training/controller.py file in the training code so that a log message is printed when a checkpoint file is loaded.
    class Controller(object):
      """Class that facilitates training and evaluation of models."""
      def __init__(
        ...
        # Restore Model if needed.
        if self.checkpoint_manager is not None:
          model_restored = self._restore_model()
          logging.info("loading checkpoint %s", model_restored)
          if not model_restored and self.checkpoint_manager.checkpoint_interval:
            # If the model is not restored from a checkpoint, save an initial
            # checkpoint.
            ckpt_path = self.checkpoint_manager.save(
                checkpoint_number=self.global_step)
            logging.info("Saved checkpoints in %s", ckpt_path)
        # Create and initialize the interval triggers.
        self.eval_trigger = utils.IntervalTrigger(self.eval_interval,
                                                  self.eval_offset)
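The added log line prints the value returned by _restore_model. Assuming _restore_model returns the result of tf.train.CheckpointManager.restore_or_initialize() (the restored checkpoint path, or None when nothing was restored), the logged value shows whether the job resumed from a checkpoint. The sketch below replays the Controller's decision with a hypothetical FakeCheckpointManager stand-in so that it runs without TensorFlow; it is an illustration, not the TensorFlow API.

```python
import logging


class FakeCheckpointManager:
    """Hypothetical stand-in for tf.train.CheckpointManager (illustration only)."""

    def __init__(self, latest=None, checkpoint_interval=100):
        self.latest = latest  # path of an existing checkpoint, or None
        self.checkpoint_interval = checkpoint_interval
        self.saved = []

    def restore_or_initialize(self):
        # Assumed contract: the restored path, or None if nothing was restored.
        return self.latest

    def save(self, checkpoint_number=None):
        path = f"ckpt-{checkpoint_number}"
        self.saved.append(path)
        return path


def restore_or_save_initial(manager, global_step=0):
    """Same decision as the Controller snippet: restore, else save an initial checkpoint."""
    model_restored = manager.restore_or_initialize()
    logging.info("loading checkpoint %s", model_restored)
    if not model_restored and manager.checkpoint_interval:
        ckpt_path = manager.save(checkpoint_number=global_step)
        logging.info("Saved checkpoints in %s", ckpt_path)
    return model_restored


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    restore_or_save_initial(FakeCheckpointManager(latest=None))  # fresh start
    restore_or_save_initial(FakeCheckpointManager(latest="ckpt-500"), global_step=500)  # resume
```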

Pangu_alpha Model Adaptation Example

  1. Download the master branch code from the MindSpore code repository, rename the models/official/nlp/Pangu_alpha directory to pangu_alpha, and use it as the training code. To use the model script of this version, ensure that MindSpore 2.0.0 or later and the MindFormers component are installed in the image.
  2. Run the following command to create a code directory on the management node, and then upload the training code to the directory:
    mkdir /data/atlas_dls/code
  3. Go to the mindcluster-deploy repository, select a branch based on mindcluster-deploy Version Description, and obtain the train_start.sh and main.sh files in the samples/train/resumable-training/fault-rescheduling/withRanktable/mindspore/pangu_alpha directory. For Pangu models with more than 10 billion parameters, use the corresponding files in the samples/train/resumable-training/fault-rescheduling/withRanktable/mindspore/pangu_alpha_13B directory instead. Then, place the files in the pangu_alpha/scripts directory of the training code to construct the following directory structure on the management node:
    root@ubuntu:/data/atlas_dls/code/pangu_alpha/scripts/# 
    scripts/
    ├── main.sh
    ├── run_cluster_export.sh
    ├── run_distribute_eval_gpu.sh
    ├── run_distribute_eval.sh
     ...
    ├── run_distribute_train.sh
    ├── run_standalone_eval.sh
    ├── run_standalone_export.sh
    ├── run_standalone_predict.sh
    └── train_start.sh
  4. Modify the train_start.sh file in the /data/atlas_dls/code/pangu_alpha/scripts directory, and change dataset to the actual dataset directory in the container.
    ...
    # Training dataset path. Change it as required.
    # Security tip. The verification of paths and input parameters is involved.
    dataset="/job/data/train_data"
    
    # Set training environment variables.
    set_env
    
    # Single-server training scenario
    if [[ "$server_count" == "1" ]]; then
        server_id=0
        if [ ${device_count} -lt 8 ]; then
            echo "Less than 8 card training is not supported for pangu alpha model." | tee log
        fi
        if [ ${device_count} -eq 8 ]; then
            bash main.sh ${device_count} ${server_count} ${RANK_TABLE_FILE} ${server_id} ${dataset}
        fi
    
    # Distributed training scenario
    else
        server_id=$(get_server_id)
        if [ $? -eq 1 ];then
            echo "get server id failed."
            exit 1
        fi
        echo "server id is: "${server_id}
        bash main.sh ${device_count} ${server_count} ${RANK_TABLE_FILE} ${server_id} ${dataset}
    fi
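The branching in train_start.sh can be summarized as a small decision function. The sketch below is a Python rendering for illustration only: plan_launch is a hypothetical name, and get_server_id stands in for the shell function of the same name, assumed here to raise an exception on failure (the script instead checks $? and exits).

```python
def plan_launch(server_count, device_count, get_server_id):
    """Return the server_id that train_start.sh passes to main.sh, or None if training is not started."""
    if server_count == 1:
        if device_count < 8:
            print("Less than 8 card training is not supported for pangu alpha model.")
            return None
        # Single-server training starts only with exactly 8 devices.
        return 0 if device_count == 8 else None
    # Distributed training: look up this server's ID.
    return get_server_id()


if __name__ == "__main__":
    print(plan_launch(2, 8, lambda: 1))  # 1
```

Written this way, it is easier to see that the single-server branch only prints a message for fewer than 8 devices, without exiting with an error code.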
  5. Skip this step for models with 10 billion or fewer parameters. To train a model with hundreds of billions of parameters and recover it from faults within 5 minutes, additional script adaptation is required. The following example uses the Pangu_alpha code from the master branch of the MindSpore code repository. In this example, the elastic training job configuration and script adaptation have already been completed.
    1. Modify the args_opt.num_layers, args_opt.stage_num and args_opt.micro_size parameters in the src/pangu_alpha_config.py file.
      def set_parse_200B(args_opt):
          """
              Set config for 200B mode
          """
          args_opt.embedding_size = 16384
          args_opt.num_layers = 32                 # Model layers
          args_opt.num_heads = 128
          if args_opt.per_batch_size == 0:
              args_opt.per_batch_size = 1
          args_opt.word_emb_dp = 0
          if args_opt.run_type == "train":
              args_opt.start_lr = 6e-5
              args_opt.end_lr = 6e-6
              args_opt.stage_num = 8               # Number of pipeline stages
              args_opt.micro_size = 16             # Microbatch size in pipeline parallelism mode. The value must be greater than args_opt.stage_num.
              args_opt.op_level_model_parallel_num = 16
              if args_opt.optimizer_shard == 1:
                  args_opt.op_level_model_parallel_num = 8
          elif args_opt.run_type == "predict":
              args_opt.stage_num = 4
              args_opt.micro_size = 1
              args_opt.op_level_model_parallel_num = 16
              if args_opt.optimizer_shard == 1:
                  args_opt.op_level_model_parallel_num = 8
    2. In addition, set micro_batch_interleaved in src/utils.py to 1, either by specifying it or by directly changing its value. For details, see how stage_device_num, data_parallel_num, batch_size, and micro_batch_interleaved are calculated in the run_train_pipeline function of train.py. The result must meet the following condition: batch_size of PanguAlphaConfig is a multiple of data_parallel of TransformerOpParallelConfig.
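A quick divisibility check helps before launching a large job. The sketch below assumes that run_train_pipeline derives stage_device_num = device_num / stage_num and data_parallel_num = stage_device_num / op_level_model_parallel_num; verify these relations against the train.py you use. pipeline_layout and batch_size_ok are hypothetical helpers.

```python
def pipeline_layout(device_num, stage_num, model_parallel_num):
    """Return (stage_device_num, data_parallel_num); raise ValueError if they are not whole numbers.

    Assumed relations (check them against run_train_pipeline in your train.py):
      stage_device_num  = device_num / stage_num
      data_parallel_num = stage_device_num / model_parallel_num
    """
    if device_num % stage_num:
        raise ValueError(f"{device_num} devices cannot be split evenly into {stage_num} stages")
    stage_device_num = device_num // stage_num
    if stage_device_num % model_parallel_num:
        raise ValueError(f"{stage_device_num} devices per stage are not divisible by "
                         f"model parallel size {model_parallel_num}")
    return stage_device_num, stage_device_num // model_parallel_num


def batch_size_ok(batch_size, data_parallel_num):
    """The condition from the text: batch_size must be a multiple of data_parallel."""
    return batch_size % data_parallel_num == 0


if __name__ == "__main__":
    # Training values from set_parse_200B: stage_num = 8, op_level_model_parallel_num = 16
    print(pipeline_layout(256, 8, 16))  # (32, 2)
```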
  6. The startup script of the Pangu code is train.py. Check whether train.py contains code for saving checkpoints.
    • If the code exists, skip this step.
    • If the code does not exist, add the following sample code for saving checkpoints. You can define and set the required parameters in the src/utils.py file as described in step 9.
    ...
    
        # Call the code for saving checkpoints.
        add_checkpoint_callback_policy(args_opt, callback, rank)
    ...
    # Define the checkpoint saving code.
    def add_checkpoint_callback_policy(args_param, callback, rank_id):
        r"""
        Add checkpoint policy to callback.
        """
        # Security tip. The verification of paths and input parameters is involved.
        if args_param.save_checkpoint:
            # Save the epoch_num and step_num information in the checkpoints.
            ckpt_append_info = [{"epoch_num": args_param.has_trained_epoches, "step_num": args_param.has_trained_steps}]
            ckpt_config = CheckpointConfig(save_checkpoint_steps=args_param.save_checkpoint_steps,
                                           keep_checkpoint_max=args_param.keep_checkpoint_max,
                                           integrated_save=False,
                                           append_info=ckpt_append_info
                                           )
    
    
            ckpoint_cb = ModelCheckpoint(prefix=args_param.ckpt_name_prefix + str(rank_id),
                                         directory=os.path.join(args_param.save_checkpoint_path, f"rank_{rank_id}"),
                                         config=ckpt_config)
    
    
            callback.append(ckpoint_cb)
    ...
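The saving callback places each rank's files in <save_checkpoint_path>/rank_<rank_id> with the prefix <ckpt_name_prefix><rank_id>. The sketch below computes the resulting path; the <prefix>-<epoch>_<step>.ckpt file name is an assumption about MindSpore's default ModelCheckpoint naming, and ckpt_location and matches_load_pattern are hypothetical helpers. Because the prefix starts with ckpt_name_prefix, the loading pattern f"{ckpt_name}*.ckpt" in restore_checkpoint (step 8) matches these files.

```python
import fnmatch
import os


def ckpt_location(save_checkpoint_path, ckpt_name_prefix, rank_id, epoch, step):
    """Expected path of one rank's checkpoint.

    The directory and prefix follow add_checkpoint_callback_policy; the
    "<prefix>-<epoch>_<step>.ckpt" file name is an assumed naming scheme.
    """
    directory = os.path.join(save_checkpoint_path, f"rank_{rank_id}")
    prefix = ckpt_name_prefix + str(rank_id)
    return os.path.join(directory, f"{prefix}-{epoch}_{step}.ckpt")


def matches_load_pattern(path, ckpt_name):
    """True if the file name matches the loading pattern f"{ckpt_name}*.ckpt"."""
    return fnmatch.fnmatch(os.path.basename(path), f"{ckpt_name}*.ckpt")


if __name__ == "__main__":
    path = ckpt_location("/job/data/code/fault_torlence/pangu_alpha/8p", "pangu", 3, 1, 20)
    print(path)  # /job/data/code/fault_torlence/pangu_alpha/8p/rank_3/pangu3-1_20.ckpt
    print(matches_load_pattern(path, "pangu"))  # True
```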
  7. The startup script of the Pangu code is train.py. Check whether train.py contains code for loading checkpoints. If it does, go to step 10. If it does not, go to step 8.
  8. Add the code for loading checkpoints to train.py. The following example already contains part of the checkpoint loading code; add the checkpoint loading code related to elastic training. You can define and set the required parameters in the src/utils.py configuration file as described in step 9.
    ...
    # If pipeline parallelism is not enabled for the running model, modify the following function:
    def set_parallel_context(args_opt):
    # If pipeline parallelism is enabled for the running model, modify the following function:
    # Security tip. The verification of paths and input parameters is involved.
    def set_pipeline_parallel_context(args_opt):
    # Add the following code before mindspore.set_auto_parallel_context. For details about the set_auto_parallel_context parameters, refer to MindSpore Parallel Distributed Training Mode.
            # Content added for elastic training: if the strategy file to be loaded does not exist, reset the load path to an empty string.
            if not os.path.exists(args_opt.strategy_load_ckpt_path):
                args_opt.strategy_load_ckpt_path = ""

            # Content added for elastic training. Set strategy_ckpt_save_file_path to a path in the container as required.
            strategy_ckpt_save_file_path = '/job/data/code/fault_torlence/pangu_alpha/strategy.ckpt'
            # If the save path is the same as the load path, save the strategy to a different file.
            if args_opt.strategy_load_ckpt_path == strategy_ckpt_save_file_path:
                strategy_ckpt_save_file_path = '/job/data/code/fault_torlence/pangu_alpha/strategy_new.ckpt'

            # Change strategy_ckpt_save_file=strategy.ckpt to strategy_ckpt_save_file=strategy_ckpt_save_file_path.
            # If set_auto_parallel_context does not specify strategy_ckpt_save_file, add it manually, as in the last argument below:
            mindspore.set_auto_parallel_context(
                parallel_mode=args_opt.parallel_mode, gradients_mean=False, search_mode=args_opt.search_mode,
                full_batch=bool(args_opt.full_batch), loss_repeated_mean=True,
                device_num=device_num, enable_parallel_optimizer=bool(args_opt.optimizer_shard),
                pipeline_stages=args_opt.stage_num, enable_alltoall=bool(args_opt.enable_alltoall),
                strategy_ckpt_save_file=strategy_ckpt_save_file_path)
    ...
    ...
    # Define the checkpoint loading code.
    # Security tip. The verification of paths and input parameters is involved.
    def restore_checkpoint(args_param, sink_size, dataset, model, network, epoch):
        r"""
        Load checkpoint process.
        """
        print("======start single checkpoint", flush=True)
        ckpt_name = args_param.ckpt_name_prefix
        # To simplify the document, the verification of the command line parameters save_checkpoint_path and ckpt_name is omitted. Do not forget to verify them.
        ckpt_pattern = os.path.join(args_param.save_checkpoint_path, "rank_{}".format(D.get_rank()),
                                    f"{ckpt_name}*.ckpt")
        ckpt_all_files = glob.glob(ckpt_pattern)
        if not ckpt_all_files:
            print(f"There is no ckpt file in {args_param.save_checkpoint_path}, "
                  f"current ckpt_files found is {ckpt_all_files} "
                  f"with pattern {ckpt_pattern}, so skip the loading.")
            return
        ckpt_exp_pattern = os.path.join(
            args_param.save_checkpoint_path,
            "rank_{}".format(D.get_rank()),
            f"{ckpt_name}*_breakpoint.ckpt",
        )
        ckpt_exp_files = glob.glob(ckpt_exp_pattern)
        ckpt_files = []
        for file in ckpt_all_files:
            if file not in ckpt_exp_files:
                ckpt_files.append(file)
    
        if not ckpt_files:
            print(
                f"There is no ckpt file in {args_param.save_checkpoint_path}, "
                f"current ckpt_files found is {ckpt_files} "
                f"with pattern {ckpt_pattern}, so skip the loading."
            )
            return
        ckpt_files.sort(key=os.path.getmtime, reverse=True)
        time_stamp = datetime.datetime.now()
        print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} pre trained ckpt model {ckpt_files} loading",
              flush=True)
        # Load the latest checkpoint file.
        print(f'Start to load from {ckpt_files[0]}')
        param_dict = load_checkpoint(ckpt_files[0])
        if "epoch_num" in param_dict and "step_num" in param_dict:
            args_param.has_trained_epoches = int(param_dict["epoch_num"].data.asnumpy())
            args_param.has_trained_steps = int(param_dict["step_num"].data.asnumpy())
        model.build(train_dataset=dataset, sink_size=sink_size, epoch=epoch)
        load_param_into_net(network, param_dict)
    ...
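The file selection in restore_checkpoint can be tested on its own. The sketch below keeps its rules (per-rank directory, ckpt_name_prefix pattern, *_breakpoint.ckpt files excluded, newest file by modification time) but drops the MindSpore calls; latest_regular_ckpt is a hypothetical name, rank_id stands in for D.get_rank(), and the file names in the demo are hypothetical.

```python
import glob
import os
import tempfile


def latest_regular_ckpt(save_checkpoint_path, ckpt_name, rank_id):
    """Newest <ckpt_name>*.ckpt in rank_<rank_id>, ignoring *_breakpoint.ckpt; None if absent."""
    rank_dir = os.path.join(save_checkpoint_path, f"rank_{rank_id}")
    all_files = glob.glob(os.path.join(rank_dir, f"{ckpt_name}*.ckpt"))
    breakpoint_files = set(glob.glob(os.path.join(rank_dir, f"{ckpt_name}*_breakpoint.ckpt")))
    regular = [f for f in all_files if f not in breakpoint_files]
    if not regular:
        return None
    # Same choice as ckpt_files.sort(key=os.path.getmtime, reverse=True) followed by [0]
    return max(regular, key=os.path.getmtime)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        rank_dir = os.path.join(tmp, "rank_0")
        os.makedirs(rank_dir)
        for name, mtime in (("pangu0-1_20.ckpt", 100), ("pangu0-1_40.ckpt", 200),
                            ("pangu0-1_45_breakpoint.ckpt", 300)):
            path = os.path.join(rank_dir, name)
            open(path, "w").close()
            os.utime(path, (mtime, mtime))
        # The breakpoint file is newest but skipped, so pangu0-1_40.ckpt is chosen.
        print(os.path.basename(latest_regular_ckpt(tmp, "pangu", 0)))  # pangu0-1_40.ckpt
```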
  9. Modify the parameters in the src/utils.py script.
    ...
        opt.add_argument("--vocab_size",
                         type=int,
                         default=50304, # Change the value based on the training dataset. Here, the value has been changed to the value of the sample dataset.
                         help="vocabulary size, default is 50304.")
    ...
        opt.add_argument("--data_column_name",
                         type=str,
                         default="text", # Change the value based on the field defined by the dataset. Here, the value has been changed to the value of the sample dataset.
                         help="Column name of datasets")
    ...
        parser.add_argument("--strategy_load_ckpt_path",
                            type=str,
                            default="/job/data/code/fault_torlence/pangu_alpha/strategy/strategy.ckpt", # Path in the container used during elastic training. Specify it as required. Training does not overwrite this path.
                            help="The training parallel strategy for the model.")
        parser.add_argument("--tokenizer_path",
                            type=str,
                            default="./tokenizer_path",
                            help="The path where stores vocab and vocab model file")
    ...
    def add_retrain_params(opt):
        """
        Add parameters about retrain.
        """
        opt.add_argument("--pre_trained",
                         type=str,
                         default="/job/data/code/fault_torlence/pangu_alpha/8p", # Path of the pre-trained model.
                         help="Pretrained checkpoint path.")
        opt.add_argument("--save_checkpoint_path",  
                         type=str,
                         default="/job/data/code/fault_torlence/pangu_alpha/8p",   # Path for saving the model.
                         help="Save checkpoint path.")
        opt.add_argument("--keep_checkpoint_max", # Model saving policy: maximum quantity.
                         type=int,
                         default=1,
                         help="Max checkpoint save number.")
        opt.add_argument("--save_checkpoint_steps", # Model saving policy: saving interval.
                         type=int,
                         default=20,
                         help="Save checkpoint step number.")
        opt.add_argument("--save_checkpoint", # Whether to save the model in the current training.
                         type=ast.literal_eval,
                         default=True,
                         help="Whether save checkpoint in local disk.")
        opt.add_argument("--ckpt_name_prefix", # Model saving policy: file name prefix.
                         type=str,
                         default="pangu",
                         help="Saving checkpoint name prefix.")
    ...
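The --save_checkpoint option uses type=ast.literal_eval rather than type=bool because argparse passes the raw command-line string to the type callable, and bool("False") is True for any non-empty string. The sketch below registers a subset of the options above on a fresh parser to show the parsed values; add_retrain_params_subset is a hypothetical name, not the complete add_retrain_params.

```python
import argparse
import ast


def add_retrain_params_subset(opt):
    """A subset of add_retrain_params above, with the same types and defaults."""
    opt.add_argument("--save_checkpoint", type=ast.literal_eval, default=True,
                     help="Whether save checkpoint in local disk.")
    opt.add_argument("--save_checkpoint_steps", type=int, default=20,
                     help="Save checkpoint step number.")
    opt.add_argument("--keep_checkpoint_max", type=int, default=1,
                     help="Max checkpoint save number.")
    opt.add_argument("--ckpt_name_prefix", type=str, default="pangu",
                     help="Saving checkpoint name prefix.")


parser = argparse.ArgumentParser()
add_retrain_params_subset(parser)

if __name__ == "__main__":
    args = parser.parse_args(["--save_checkpoint", "False", "--save_checkpoint_steps", "50"])
    print(args.save_checkpoint, args.save_checkpoint_steps)  # False 50
```

A value such as false (lowercase) is not a Python literal, so ast.literal_eval raises ValueError and argparse rejects the argument as invalid.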
  10. Create an empty file group_info_env in the /data/atlas_dls/code/pangu_alpha directory.
    root@ubuntu:/data/atlas_dls/code/pangu_alpha/# 
    pangu_alpha/
    ├── README.md
    ├── README_CN.md
    ├── group_info_env
     ...
    ├── scripts
    ├── serving_increment
    ├── src
    ├── tasks.py
    └── train.py
  11. In train.py, change the path of the group_info_env file. In this example, the path in the container is /job/code/group_info_env.
    ...
        # env variable prepare
        group_info_file = os.getenv("GROUP_INFO_FILE")
        if group_info_file:
            with open(os.path.expanduser("/job/code/group_info_env"), "a") as outfile:
                outfile.write(f"export GROUP_INFO_FILE_REFLECT={group_info_file}\n")
    ...
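The snippet above writes the value of the GROUP_INFO_FILE environment variable into group_info_env as a shell export line. The sketch below isolates that logic so it can be tried outside training; record_group_info and its environ parameter are hypothetical, added so the environment can be injected, and the GROUP_INFO_FILE value in the demo is made up. Because the file is opened in append mode, each run adds another line; if the file is later sourced by a shell, the last export wins.

```python
import os
import tempfile


def record_group_info(env_file, environ=None):
    """Append an export line for GROUP_INFO_FILE to env_file; return True if written."""
    environ = os.environ if environ is None else environ
    group_info_file = environ.get("GROUP_INFO_FILE")
    if not group_info_file:
        return False
    # "a" mode: every run appends a line instead of overwriting the file.
    with open(os.path.expanduser(env_file), "a") as outfile:
        outfile.write(f"export GROUP_INFO_FILE_REFLECT={group_info_file}\n")
    return True


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        env_file = os.path.join(tmp, "group_info_env")
        fake_env = {"GROUP_INFO_FILE": "/tmp/group_info.pb"}  # hypothetical value
        record_group_info(env_file, fake_env)
        record_group_info(env_file, fake_env)
        with open(env_file) as f:
            print(f.read(), end="")  # the same export line, twice
```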