昇腾社区首页
中文
注册

基础修改

  1. 编辑“torchrun”文件。
    1. 查找环境中的“torchrun”文件。
      which torchrun
    2. 打开以上命令显示路径下的“torchrun”文件。
      vim {torchrun文件路径}/torchrun
    3. 按“i”进入编辑模式,修改以下内容。
      # 增加加粗内容
      import re
      import sys
      import mindio_ttp.framework_ttp
      from torch.distributed.run import main as torch_main
    4. 按“Esc”键,输入:wq!,按“Enter”保存并退出编辑。
  2. 编辑“training.py”文件。
    1. 打开Megatron框架中的“training.py”文件。
      vim megatron/training/training.py
    2. 按“i”进入编辑模式,修改pretrain方法的结构。
      • 将train方法使用的入参封装为从另一方法获取的list变量,在train方法调用前后都有使用的变量也建议加入MindIO TFT的管理,以确保对齐所有对象的生命周期,避免在故障修复时发生内存泄露。
        def build_train_args(*input_args):
            args, timers, train_valid_test_dataset_provider, model_provider, model_type, forward_step_func, process_non_loss_data_func = input_args
            from megatron.training.training import setup_model_and_optimizer
            # Model, optimizer, and learning rate.
            timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
            model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
                model_provider, model_type)
            timers('model-and-optimizer-setup').stop()
            print_datetime('after model, optimizer, and learning rate, scheduler are built')
            config = get_model_config(model[0])
            # Data stuff.
            timers('train/valid/test-data-iterators-setup', log_level=0).start(
                barrier=True)
            if args.virtual_pipeline_model_parallel_size is not None:
                train_data_iterator = []
                valid_data_iterator = []
                test_data_iterator = []
                for i in range(len(model)):
                    mpu.set_virtual_pipeline_model_parallel_rank(i)
                    iterators = build_train_valid_test_data_iterators(
                        train_valid_test_dataset_provider)
                    train_data_iterator.append(iterators[0])
                    valid_data_iterator.append(iterators[1])
                    test_data_iterator.append(iterators[2])
            else:
                train_data_iterator, valid_data_iterator, test_data_iterator \
                    = build_train_valid_test_data_iterators(
                    train_valid_test_dataset_provider)
            timers('train/valid/test-data-iterators-setup').stop()
            print_datetime('after dataloaders are built')
            # Print setup timing.
            print_rank_0('done with setup ...')
            timers.log(['model-and-optimizer-setup',
                        'train/valid/test-data-iterators-setup'], barrier=True)
            train_args = [forward_step_func,
                          model, optimizer, opt_param_scheduler,
                          train_data_iterator, valid_data_iterator, process_non_loss_data_func, config]
            test_data_iterator_list = [test_data_iterator]
            return train_args, test_data_iterator_list
      • 在pretrain中修改train函数的使用方法。
        def pretrain(*args, **kwargs):
            ##############YOUR CODE##############
            train_args, test_data_iterator_list = build_train_args(args, timers, train_valid_test_dataset_provider,
                                                                   model_provider,
                                                                   model_type, forward_step_func, process_non_loss_data_func)
            iteration, num_floating_point_operations_so_far = train(train_args, test_data_iterator_list)
            test_data_iterator = test_data_iterator_list[0]
            forward_step_func, model, optimizer, opt_param_scheduler, train_data_iterator, valid_data_iterator, process_non_loss_data_func, config = train_args
            ##############YOUR CODE##############

        请确保train函数传入的参数为列表包装的有序参数列表,以便参数对象可以动态改变,达到故障修复的目的,在训练结束后将参数从列表中取出,以备后续流程使用。

    3. 按“Esc”键,输入:wq!,按“Enter”保存并退出编辑。
  3. 与框架交互的其他修改。
    • 使能MindIO ARF功能时,对框架的修改。
      def tft_reboot_skip_wrapper(fn):
          @wraps(fn)
          def wrapper(*args, **kwargs):
              if is_arf_reboot_node():
                  return None
              res = fn(*args, **kwargs)
              return res
          return wrapper
      
      def tft_reboot_new_group_wrapper(fn):
          @wraps(fn)
          def wrapper(*args, **kwargs):
              if is_arf_reboot_node() and "backend" in kwargs and kwargs["backend"] == "gloo":
                  return None
              res = fn(*args, **kwargs)
              return res
          return wrapper
      
      def tft_reboot_build_train_valid_test_data_iterators_wrapper(fn):
          @wraps(fn)
          def wrapper(*args, **kwargs):
              res = fn(*args, **kwargs)
              if is_arf_reboot_node():
                  get_args().do_train = True
              return res
          return wrapper

      MindIO ARF功能进行进程级恢复时,将进行增量worker重启,此时训练进程初始化阶段存在跨节点通信,因此会遇到通信不匹配的情况,建议对通信跳过处理,如tft_reboot_skip_wrapper所示,对初始化中遇到的通信操作进行patch或在代码中做侵入式修改。

      此外在gloo组建组时也会进行一次通信,建议跳过gloo组创建,如tft_reboot_new_group_wrapper。

      在Megatron框架中,有flag变量do_train控制当前worker是否参与训练,也需要对该flag变量进行修改,使训练进程进入train函数并上报MindIO ARF初始化完成的信号。

    • CheckPoint中需要保存的args在每个迭代中更新。
      def tft_training_log_wrapper(training_log):
          @wraps(training_log)
          def wrapper(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration,
                       loss_scale, report_memory_flag, skipped_iter,
                       grad_norm, params_norm, num_zeros_in_grad):
              report_memory_flag = training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration,
                       loss_scale, report_memory_flag, skipped_iter,
                       grad_norm, params_norm, num_zeros_in_grad)
              arguments = get_args()
              batch_size = mpu.get_data_parallel_world_size() * \
                           arguments.micro_batch_size * \
                           get_num_microbatches()
              num_floating_point_operations_so_far = num_floating_point_operations(arguments, batch_size)
              arguments.num_floating_point_operations_so_far += num_floating_point_operations_so_far
              arguments.iteration = iteration
              return report_memory_flag
          return wrapper

      在Megatron框架中,iteration和num_floating_point_operations_so_far也会被保存到CheckPoint中,为保证临终遗言所保存的CheckPoint文件完整,需要在每个迭代完成后更新相应变量。