Interface Invocation

  1. Initialize and start the MindIO TFT Controller and Processor modules before initializing the Torch global process group.

    In the distributed initialization method of the Megatron framework, modify the code either intrusively or in decorator form to call the MindIO TFT initialization interfaces (a decorator-style sketch follows the intrusive example below);

    MindIO TFT initialization must be performed before torch.distributed.init_process_group and after torch.cuda.set_device(device);

    If TLS authentication is enabled, set tls_option_top_path to the root directory of the certificate configuration.

    # Intrusive modification of megatron.training.initialize._initialize_distributed
    def _initialize_distributed():
        """Initialize torch.distributed and core model parallel."""
        args = get_args()
        device_count = torch.cuda.device_count()
        if torch.distributed.is_initialized():
            if args.rank == 0:
                print(
                    "torch distributed is already initialized, "
                    "skipping initialization ...",
                    flush=True,
                )
            args.rank = torch.distributed.get_rank()
            args.world_size = torch.distributed.get_world_size()
        else:
            if args.rank == 0:
                print("> initializing torch distributed ...", flush=True)
            # Manually set the device ids.
            if device_count > 0:
                device = args.rank % device_count
                if args.local_rank is not None:
                    assert (
                        args.local_rank == device
                    ), "expected local-rank to be the same as rank % device-count."
                else:
                    args.local_rank = device
                torch.cuda.set_device(device)
    
            #  adapt MindIO TFT here
            from mindio_ttp.framework_ttp import (tft_register_logger_handler, tft_init_controller, tft_start_controller,
                                          tft_init_processor, tft_start_processor, tft_is_reboot_node, adapting_logger)
            tls_option_top_path = '{certificate configuration root directory}'
            enable_tls = False  # set to True and fill in tls_option_top_path when TLS authentication is enabled
            args = get_args()
            default_ip = '127.0.0.1'
            ttp_ip = os.getenv('TTP_ADDR', default_ip)
            controller_ip = os.getenv('CONTROLLER_ADDR', default_ip)
            if controller_ip == default_ip:
                controller_ip = ttp_ip
            processor_ip = os.getenv('PROCESSOR_ADDR', default_ip)
            if processor_ip == default_ip:
                processor_ip = ttp_ip
            port = 8000
     
            cur_rank = args.rank
            world_size = args.world_size
            tft_register_logger_handler(adapting_logger) # skip this step if the framework has its own logger
            enable_worker_reboot = args.enable_worker_reboot if hasattr(args, 'enable_worker_reboot') else False
            enable_ucefault_repair = args.enable_ucefault_repair if hasattr(args, 'enable_ucefault_repair') else False
            enable_mindx = os.getenv('MINDX_TASK_ID', None)  # when integrating with MindCluster, this env var decides when the controller is started
            if cur_rank == 0 and enable_mindx is None:  # without MindCluster, start the MindIO TFT Controller service on rank 0
                tft_init_controller(cur_rank, world_size, False,
                                    enable_arf=enable_worker_reboot)
                tft_start_controller(controller_ip, port, enable_tls=enable_tls, tls_option_top_path=tls_option_top_path)
            tft_init_processor(cur_rank, world_size, False, enable_tls=enable_tls, 
                               tls_option_top_path=tls_option_top_path, enable_uce=enable_ucefault_repair, 
                               enable_arf=enable_worker_reboot)
            tft_start_processor(processor_ip, port)  # start the MindIO TFT Processor service on every worker
            if tft_is_reboot_node():  # with MindIO ARF enabled, check whether this node is an incrementally restarted node
                update_arf_reboot_flag(True)  # record that the current node is a MindIO ARF incrementally restarted node
    
            # Call the init process
            torch.distributed.init_process_group(
                backend=args.distributed_backend,
                world_size=args.world_size,
                rank=args.rank,
                timeout=timedelta(minutes=args.distributed_timeout_minutes),
            )
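
    The example above edits _initialize_distributed in place. For the decorator-style modification mentioned earlier, a minimal sketch is shown below: it temporarily wraps torch.distributed.init_process_group so that the MindIO TFT setup runs after torch.cuda.set_device and immediately before the process group is created. setup_mindio_tft is a hypothetical helper name standing for the tft_init_*/tft_start_* calls from the intrusive example.

    from functools import wraps
    import torch

    def tft_initialize_distributed_wrapper(initialize_distributed):
        @wraps(initialize_distributed)
        def wrapper(*func_args, **func_kwargs):
            original_init_pg = torch.distributed.init_process_group

            def init_pg_with_tft(*pg_args, **pg_kwargs):
                # Runs after torch.cuda.set_device(...) inside _initialize_distributed and
                # right before the process group is created, matching the required ordering.
                setup_mindio_tft()  # hypothetical helper containing the TFT init/start calls shown above
                return original_init_pg(*pg_args, **pg_kwargs)

            torch.distributed.init_process_group = init_pg_with_tft
            try:
                return initialize_distributed(*func_args, **func_kwargs)
            finally:
                torch.distributed.init_process_group = original_init_pg
        return wrapper

    # e.g. megatron.training.initialize._initialize_distributed = \
    #     tft_initialize_distributed_wrapper(megatron.training.initialize._initialize_distributed)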

  2. Call the tft_register interfaces at an appropriate point to register the callback functions needed for repair and to obtain the methods and parameters required to rebuild the model, optimizer, and dataset during repair.

    The previously wrapped build_train_args function takes as arguments the methods and parameters required to rebuild the model, optimizer, and dataset during repair. It is recommended to patch this function so that registering callbacks, setting up the replica relationships, and obtaining the necessary parameters are managed in one place (an application sketch follows the wrapper below).

    def tft_build_train_args_wrapper(build_train_args):
        @wraps(build_train_args)
        def wrapper(*input_args):
            args, timers, train_valid_test_dataset_provider, model_provider, model_type, forward_step_func, process_non_loss_data_func = input_args
            from mindio_ttp.framework_ttp import (tft_register_rename_handler, tft_register_save_ckpt_handler, tft_set_optimizer_replica,
                                                  tft_register_stop_handler, tft_register_clean_handler, tft_register_repair_handler,
                                                  tft_register_rollback_handler, OptimizerType, tft_register_rebuild_group_handler, 
                                                        tft_register_set_stream_handler)
    
            args = get_args()
            replica_info = []
            replica_offset = REPLICA_OFFSET
            moe_flag = args.expert_model_parallel_size > 1
            cur_rank = args.rank
            dp_cp_ranks = get_dp_cp_ranks()
            dense_replica_cnt = get_replica_dp_num() if args.use_distributed_optimizer else len(dp_cp_ranks)
            replica_dict = {
                "rank_list": dp_cp_ranks,
                "replica_cnt": dense_replica_cnt,
                "replica_shift": replica_offset
            }
            replica_info.append(replica_dict)
            if moe_flag:
                dp_ep_ranks = get_dp_ep_ranks()
                moe_replica_cnt = get_replica_dp_num() if args.use_distributed_optimizer else len(dp_ep_ranks)
                replica_dict = {
                    "rank_list": dp_ep_ranks,
                    "replica_cnt": moe_replica_cnt,
                    "replica_shift": replica_offset
                }
                replica_info.append(replica_dict)
            set_build_data_args(model_type, model_provider, train_valid_test_dataset_provider)
            tft_set_optimizer_replica(cur_rank, replica_info)
            tft_register_save_ckpt_handler(save_callback)
            tft_register_rename_handler(rename_callback)
            tft_register_stop_handler(stop_callback)
            tft_register_clean_handler(clean_callback)
            tft_register_repair_handler(repair_callback)
            tft_register_rollback_handler(rollback_callback)
            tft_register_rebuild_group_handler(arf_rebuild_process_group_callback)
            tft_register_set_stream_handler(recover_set_stream)
            return build_train_args(*input_args)
        return wrapper

    The callback functions in the example code are described in the callback implementation section.
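
    The wrapper above can then be applied by patching before training starts. A minimal sketch, assuming the previously wrapped build_train_args lives in a user-maintained training module (my_training_module is an illustrative name):

    # my_training_module is a placeholder for the module that defines build_train_args
    import my_training_module

    my_training_module.build_train_args = tft_build_train_args_wrapper(my_training_module.build_train_args)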

  3. Apply tft_exception_handler to the train function.

    Call tft_exception_handler as a decorator to handle errors raised during training. For training frameworks that MindIO TFT has already adapted, training ends when the iteration loop inside the train function finishes, and the decorator then destroys the Controller and Processor threads; the decorator cannot tell frameworks apart, so when the decorated train function contains an epoch loop, additional adaptation is required (see the sketch after the code below).

    # Implement a decorator for the train function inside the framework to centralize the required modifications
    from functools import wraps
    def tft_train_wrapper(train):
        @wraps(train)
        def wrapper(train_args, test_data_iterator_list):
            from mindio_ttp.framework_ttp import tft_exception_handler
            @tft_exception_handler
            def tft_train(train_args, test_data_iterator_list):
                if is_arf_reboot_node():  # read the ARF flag to check whether this node is an incrementally restarted node (implemented by the user)
                    raise RuntimeError("ARF FINISH")  # with MindIO ARF enabled, the incrementally restarted node reports restart completion to MindIO TFT here
                train_args[OPTIM_INDEX].set_current_step(get_args().iteration)  # set the replica optimizer attribute so that the optimizer records the current iteration step
                return train(*train_args)
            return tft_train(train_args, test_data_iterator_list)
        return wrapper
    
    # Modify the call to train in pretrain as follows, where train_args and test_data_iterator_list are the return values of build_train_args.
    def pretrain(*args, **kwargs):
        iteration, num_floating_point_operations_so_far = tft_train_wrapper(train)(train_args, test_data_iterator_list)
    
    # Alternatively, patch the train function before entering training by replacing the framework's
    # reference to it (the exact module path may differ across Megatron versions)
    def patch_manager():
        from megatron.training import training
        training.train = tft_train_wrapper(training.train)
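
    A minimal sketch of the epoch-loop adaptation mentioned above, under the assumption that the framework drives training with an outer epoch loop that calls a per-epoch train function; run_epochs and train_one_epoch are hypothetical names used only for illustration. The decorator is applied to the function that spans all epochs, so the Controller and Processor threads are only torn down after the entire run finishes.

    from mindio_ttp.framework_ttp import tft_exception_handler

    @tft_exception_handler
    def run_epochs(train_one_epoch, num_epochs, train_args):
        # The decorated function covers every epoch, so the decorator's cleanup only
        # happens once, after the whole training run returns.
        for epoch in range(num_epochs):
            train_one_epoch(*train_args)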

  4. Use tft_start_updating_os and tft_end_updating_os to protect the optimizer update.

    To ensure that the optimizer is updated correctly, the update process must be protected. tft_start_updating_os and tft_end_updating_os can be called through an intrusive modification or by overriding the step method:

    def step(self):
     
        timers = self.config.timers
     
        # Copy gradients from model params to main params.
        if timers is not None:
            timers('optimizer-copy-to-main-grad', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
        self._copy_model_grads_to_main_grads()
        if timers is not None:
            timers('optimizer-copy-to-main-grad').stop()
     
        # Do unscale, check for inf, and update grad scaler only for
        # the case that grad scaler is provided.
        if self.grad_scaler:
     
            # Unscale and check for inf/nan.
            if timers is not None:
                timers('optimizer-unscale-and-check-inf', log_level=1).start(
                    barrier=self.config.barrier_with_L1_time
                )
            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
            if timers is not None:
                timers('optimizer-unscale-and-check-inf').stop()
     
            # We are done with scaling gradients
            # so we can update the loss scale.
            self.grad_scaler.update(found_inf_flag)
     
            # If we found inf/nan, skip the update.
            if found_inf_flag:
                return False, None, None
     
        # Clip the main gradients.
        if timers is not None:
            timers('optimizer-clip-main-grad', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
        grad_norm = None
        if self.config.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.config.clip_grad)
        if timers is not None:
            timers('optimizer-clip-main-grad').stop()
     
        # Count the zeros in the grads.
        if timers is not None:
            timers('optimizer-count-zeros', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
        num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else None
        if timers is not None:
            timers('optimizer-count-zeros').stop()
     
        # Step the optimizer.
        if timers is not None:
            timers('optimizer-inner-step', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
    
        from mindio_ttp.framework_ttp import tft_start_updating_os, tft_end_updating_os
        torch.distributed.barrier()
        if get_args().expert_model_parallel_size == 1:
            tft_start_updating_os(-1)
        self.current_step = get_args().iteration
        self.optimizer.step()
        torch.cuda.synchronize()
        self.current_step += 1
        if get_args().expert_model_parallel_size == 1:
            tft_end_updating_os(self.current_step)
    
        if timers is not None:
            timers('optimizer-inner-step').stop()
     
        # Update params from main params.
        if timers is not None:
            timers('optimizer-copy-main-to-model-params', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
        self._copy_main_params_to_model_params()
        if timers is not None:
            timers('optimizer-copy-main-to-model-params').stop()
     
        # Successful update.
        return True, grad_norm, num_zeros_in_grad

    Here, self.current_step is a newly added attribute of the optimizer object. Calling tft_start_updating_os and tft_end_updating_os marks the optimizer data as dirty: when the optimizer update begins, the start_updating state is reported to MindIO TFT, and when the update finishes, the end_updating state is reported along with the iteration step that the optimizer currently corresponds to.

    In MoE scenarios, Megatron uses a chained optimizer (ChainedOptimizer) that contains the distributed or mixed-precision optimizers. In this case, tft_start_updating_os and tft_end_updating_os must be called in the chained optimizer to mark the optimizer data as dirty, and the state reporting inside the contained distributed or mixed-precision optimizers must be disabled. Calling the MindIO TFT interfaces correctly ensures that the start and end of the optimizer update are reported exactly once during the optimizer update phase of each training iteration, so that MindIO TFT can track the optimizer state correctly (a sketch follows below).
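
    A minimal sketch of this arrangement, assuming the chained optimizer's step method is wrapped in place; the contained distributed or mixed-precision optimizers are assumed to have had their own tft_* reporting removed, and the exact return signature of ChainedOptimizer.step may differ across Megatron versions:

    from functools import wraps
    import torch
    from mindio_ttp.framework_ttp import tft_start_updating_os, tft_end_updating_os
    # get_args is the same Megatron helper used in the examples above

    def tft_chained_step_wrapper(chained_step):
        @wraps(chained_step)
        def wrapper(self, *step_args, **step_kwargs):
            torch.distributed.barrier()
            tft_start_updating_os(-1)                 # report start_updating once for the whole chain
            self.current_step = get_args().iteration  # same new attribute as in the dense case
            result = chained_step(self, *step_args, **step_kwargs)
            torch.cuda.synchronize()
            self.current_step += 1
            tft_end_updating_os(self.current_step)    # report end_updating with the current step count
            return result
        return wrapper

    # e.g. ChainedOptimizer.step = tft_chained_step_wrapper(ChainedOptimizer.step)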

The other interfaces are already used internally by MindIO TFT; if you have special requirements, use them yourself according to the interface documentation.