Callback Implementation

  1. Obtain the parameters and methods required for fault recovery from the framework.

    @dataclass
    class BuildDataArgs:
        # Class-level storage for the framework objects required to rebuild the
        # model, optimizer, and datasets during fault recovery.
        model_type_ = None
        model_provider_ = None
        train_valid_test_datasets_provider_ = None
    
    def set_build_data_args(model_type, model_provider, train_valid_test_datasets_provider):
        BuildDataArgs.model_type_ = model_type
        BuildDataArgs.model_provider_ = model_provider
        BuildDataArgs.train_valid_test_datasets_provider_ = train_valid_test_datasets_provider

    Taking Megatron as an example, this step obtains the parameters Megatron needs in order to build the model, optimizer, and datasets; users should decide which parameters to obtain based on their own framework. The callbacks in the following steps also index into a train_args structure handed over by the framework adapter, as in the sketch below.
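
    A minimal sketch of this wiring, assuming hypothetical index constants for train_args and a pretrain-style entry point (neither is part of Megatron or MindIO; they must match the real adapter code):

    # Hypothetical layout of the train_args structure used by the callbacks below.
    TRAIN_PARAM = 0          # train_args[TRAIN_PARAM]: list of training objects
    TEST_DATA_ITER = 1       # train_args[TEST_DATA_ITER]: [test_data_iterator]
    # Indices inside train_args[TRAIN_PARAM]:
    MODEL_INDEX = 0          # model (list of model chunks)
    OPTIM_INDEX = 1          # optimizer
    SCHEDULER_INDEX = 2      # learning-rate scheduler
    CONFIG_INDEX = 3         # model config
    TRAIN_DATA_INDEX = 4     # train data iterator
    VALID_DATA_INDEX = 5     # valid data iterator

    # Illustrative registration from a Megatron-style training entry point.
    def pretrain_with_fault_tolerance(train_valid_test_datasets_provider,
                                      model_provider, model_type):
        # Hand the framework objects to the recovery module before training starts,
        # so repair callbacks can later rebuild the model, optimizer, and datasets.
        set_build_data_args(model_type, model_provider,
                            train_valid_test_datasets_provider)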

  2. Callbacks related to the MindIO TTP feature.

    def save_callback(step: int, save_info: list, train_args=None, ctx=None):
        model = train_args[TRAIN_PARAM][MODEL_INDEX]
        optimizer = train_args[TRAIN_PARAM][OPTIM_INDEX]
        opt_param_scheduler = train_args[TRAIN_PARAM][SCHEDULER_INDEX]
        global_args = get_args()
    
        # Update learning rate.
        if global_args.train_samples is None:
            global_args.consumed_train_samples = step * global_args.global_batch_size
        if opt_param_scheduler.num_steps != global_args.consumed_train_samples:
            opt_param_scheduler.step(global_args.global_batch_size)
        if hasattr(optimizer, 'optim_nums'):
            for _, info_dict in enumerate(save_info):
                optim_idx = info_dict.get("type", 0)
                rank_list = info_dict.get("ranks", None)
                save_rank = rank_list[0]
                optimizer.set_dump_args(optim_idx, save_rank, step, rank_list)
        else:
            rank_list = save_info[0].get("ranks", None)
            save_rank = rank_list[0]
            optimizer.set_dump_args(save_rank, step, rank_list)
        save_checkpoint(step, model, optimizer, opt_param_scheduler,
                        global_args.num_floating_point_operations_so_far)
    
    def rename_callback(step: int, ctx=None):
        iteration = step
        rank = torch.distributed.get_rank()
        args = get_args()
        tmp_dir = 'iter_{:07d}_tmp'.format(iteration)
        fin_dir = 'iter_{:07d}'.format(iteration)
        src_path = os.path.join(args.save, tmp_dir)
        dst_path = os.path.join(args.save, fin_dir)
        try:
            src_check_ret, src_err_msg, src_abs_path = FileUtils.regular_file_path(src_path, args.save, False)
            dst_check_ret, dst_err_msg, dst_abs_path = FileUtils.regular_file_path(dst_path, args.save, False)
            if (not src_check_ret) or (not dst_check_ret):
                raise FileNotFoundError(f"rank:{rank} rename path error. {src_err_msg}{dst_err_msg}")
            os.rename(src_abs_path, dst_abs_path)
        except FileNotFoundError as e:
            print(f"{e}")
            return False
    
        # And update the latest iteration
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        is_path_valid, err_msg, tracker_filename = FileUtils.regular_file_path(tracker_filename, "/", False)
        if not is_path_valid:
            print_rank_0(err_msg)
            raise Exception("  tracker_filename is not valid")
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))
        return True
    
    def get_checkpoint_name(checkpoints_path, iteration, release=False,
                            pipeline_parallel=None,
                            tensor_rank=None, pipeline_rank=None,
                            expert_parallel=None, expert_rank=None,
                            return_base_dir=False, tmp=False):
        if release:
            directory = 'release'
        else:
            directory = 'iter_{:07d}'.format(iteration)
        if tmp:
            directory = directory + "_tmp"
        if return_base_dir:
            common_path = os.path.join(checkpoints_path, directory)
            return common_path
    
        # Use both the tensor and pipeline MP rank.
        if pipeline_parallel is None:
            pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
        if tensor_rank is None:
            tensor_rank = mpu.get_tensor_model_parallel_rank()
        if pipeline_rank is None:
            pipeline_rank = mpu.get_pipeline_model_parallel_rank()
        if expert_parallel is None:
            expert_parallel = (mpu.get_expert_model_parallel_world_size() > 1)
        if expert_rank is None:
            expert_rank = mpu.get_expert_model_parallel_rank()
        # Use both the tensor and pipeline MP rank. If using the distributed
        # optimizer, then the optimizer's path must additionally include the
        # data parallel rank.
        if not pipeline_parallel:
            common_path = os.path.join(checkpoints_path, directory,
                                f'mp_rank_{tensor_rank:02d}')
        else:
            common_path = os.path.join(checkpoints_path, directory,
                    f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}')
        if expert_parallel:
            common_path = common_path + f'_{expert_rank:03d}'
    
        return os.path.join(common_path, "model_optim_rng.pt")
    
    def save_checkpoint(iteration, model, optimizer, opt_param_scheduler,
                        num_floating_point_operations_so_far):
        """Save a model checkpoint."""
        args = get_args()
        # Only rank zero of the data parallel writes to the disk.
        model = unwrap_model(model)
        ckpt_format = args.dist_ckpt_format if args.use_dist_ckpt else 'torch'
        print('rank {} is saving checkpoint at iteration {:7d} to {} in {} format'
                               .format(args.rank, iteration, args.save, ckpt_format))
        # Collect rng state across data parallel ranks.
        rng_state = get_rng_state(args.use_dist_ckpt)
        tmp = optimizer.get_error_dump()
        # Checkpoint name.
        checkpoint_name = get_checkpoint_name(args.save, iteration, return_base_dir=args.use_dist_ckpt, tmp=tmp)
        check_ret, err_msg, checkpoint_name = FileUtils.regular_file_path(checkpoint_name, '/', False)
        if not check_ret:
            raise Exception(f"rank {args.rank} get checkpoint name error, {err_msg}")
        # Save distributed optimizer's custom parameter state.
        if args.use_distributed_optimizer and not args.no_save_optim and optimizer is not None and not args.use_dist_ckpt:
            optim_checkpoint_name = \
                get_distributed_optimizer_checkpoint_name(checkpoint_name)
            ensure_directory_exists(optim_checkpoint_name)
            optimizer.save_parameter_state(optim_checkpoint_name)
        if not args.use_distributed_optimizer and (args.fp16 or args.bf16):
            if not hasattr(optimizer.config, 'reuse_fp32_param'):
                optimizer._copy_main_params_to_model_params()
        cur_rank = torch.distributed.get_rank()
        save_flag = optimizer.need_write_file()
        # Collect args, model, RNG.
        if not torch.distributed.is_initialized() \
                or save_flag \
                or args.use_dist_ckpt:
            optim_sd_kwargs = {}
            if args.use_dist_ckpt and args.use_distributed_optimizer:
                optim_sd_kwargs['sharding_type'] = ('fully_sharded_bucket_space'
                                                    if args.ckpt_fully_parallel_save
                                                    else 'dp_zero_gather_scatter')
                print_rank_0(f'Storing distributed optimizer sharded state of type {optim_sd_kwargs["sharding_type"]}')
            state_dict = generate_state_dict(args, model, optimizer, opt_param_scheduler, rng_state,
                                             args.use_dist_ckpt, iteration, optim_sd_kwargs=optim_sd_kwargs)
            state_dict['num_floating_point_operations_so_far'] = num_floating_point_operations_so_far
            # Save.
            ensure_directory_exists(checkpoint_name)
            torch.save(state_dict, checkpoint_name)

    MindIO TTP reuses the main checkpoint-saving logic of the Megatron framework. However, because the workers that take part in saving a checkpoint are no longer ordered, both the checkpoint file naming and the logic that decides whether a worker writes a checkpoint must be modified. After an end-of-life save, the file names may differ from those Megatron would normally produce, so rename_callback renames them to match the framework's naming convention, as illustrated by the sketch below.
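
    The temporary-to-final renaming can be exercised in isolation. A minimal sketch, assuming the iter_{:07d}_tmp naming produced by get_checkpoint_name(tmp=True) above and Megatron's latest_checkpointed_iteration.txt tracker file:

    import os

    def demo_rename(save_dir: str, iteration: int):
        # An end-of-life save first writes into an "_tmp" directory ...
        tmp_dir = os.path.join(save_dir, 'iter_{:07d}_tmp'.format(iteration))
        fin_dir = os.path.join(save_dir, 'iter_{:07d}'.format(iteration))
        os.makedirs(tmp_dir, exist_ok=True)
        # ... which rename_callback then moves to the name Megatron expects.
        os.rename(tmp_dir, fin_dir)
        # The tracker file records the latest complete iteration so that a
        # subsequent restart loads this checkpoint.
        with open(os.path.join(save_dir, 'latest_checkpointed_iteration.txt'), 'w') as f:
            f.write(str(iteration))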

  3. Callbacks related to MindIO UCE and MindIO ARF.

    def stop_callback(train_args=None, ctx=None):
        # stop and clean device
        device = get_current_device()
        torch.npu.set_device(device)
        torch_npu.npu.stop_device(device)
    
    def clean_callback(is_uce_error=False, train_args=None, ctx=None):
        """
        this function do:
        1) get UCE check result from torch_npu
        2) do some clear before rebuild (avoid OOM) when the check result is UCE_HIGH_LEVEL
        3) HCCL resume and restart device
        """
        device = get_current_device()
        rank = torch.distributed.get_rank()
        torch.npu.set_device(device)
        ret = RET_OK
        if is_uce_error:
            check_memory_result = torch_npu.npu.check_uce_in_memory(device)
            if check_memory_result == UCE_LOW_LEVEL:  # no need rebuild
                ret = RET_NO_REBUILD
            elif check_memory_result == UCE_HIGH_LEVEL:  # need rebuild
                clear_memory(train_args[TRAIN_PARAM][MODEL_INDEX],
                                 train_args[TRAIN_PARAM][OPTIM_INDEX],
                                 train_args[TRAIN_PARAM][CONFIG_INDEX])
                train_args[TRAIN_PARAM][MODEL_INDEX] = None
                train_args[TRAIN_PARAM][OPTIM_INDEX] = None
                train_args[TRAIN_PARAM][SCHEDULER_INDEX] = None
                train_args[TRAIN_PARAM][CONFIG_INDEX] = None
                ret = RET_OK
            else:  # exit
                ret = RET_ERROR
        if hasattr(torch.distributed, 'reinit_process_group'):
            torch.distributed.reinit_process_group(group=None, rebuild_link=False)
        torch.npu.restart_device(device)
        return ret
    
    def repair_callback(step: int, need_rebuild: bool, error_ranks: list, repair_info: dict,
                        train_args=None, ctx=None):
        from mindio_ttp.framework_ttp import RepairType, OptimizerType
        torch.npu.set_device(get_current_device())
        optim_idxs = repair_info.get("type", OptimizerType.ATTENTION.value)
        repair_type = repair_info.get("repair_type", None)
        src_ranks = repair_info.get("src", [])
        dest_ranks = repair_info.get("dst", [])
        rank_list = repair_info.get("rank_list", [])
        build_repair_group(rank_list)
        if repair_type == RepairType.RT_SEND.value:  # send
            return send_rank_repair(src_ranks, dest_ranks, optim_idxs, step, rank_list, train_args[TRAIN_PARAM])
        elif repair_type in [RepairType.RT_UCE_HIGHLEVEL.value,
                             RepairType.RT_UCE_LOWLEVEL.value,
                             RepairType.RT_RECV_REPAIR.value]:
            return recv_rank_repair(src_ranks, dest_ranks, optim_idxs, step,
                                    need_rebuild, rank_list, train_args[TRAIN_PARAM])
        else:
            return RET_ERROR
    
    def rollback_callback(step: int, train_args=None, ctx=None):
        args = get_args()
        torch.npu.set_device(get_current_device())
        global temp_memory_ckpt
        rank = torch.distributed.get_rank()
        # Update consumed train samples and learning rate.
        if args.train_samples is None:
            args.consumed_train_samples = step * args.global_batch_size
        if train_args[TRAIN_PARAM][SCHEDULER_INDEX].num_steps != args.consumed_train_samples:
            train_args[TRAIN_PARAM][SCHEDULER_INDEX].step(args.global_batch_size)
        feature_rollback()
        gather_model_params_from_optimizer(train_args[TRAIN_PARAM][OPTIM_INDEX], step)
        rebuild_dataset(train_args)
        set_memory_ckpt(None)
        temp_memory_ckpt = None
        destroy_repair_group()
        return RET_OK
    
    def feature_rollback():
        # Other features re-initialize or roll back their data after the fault is rectified.
        args = get_args()
        if hasattr(args, "num_experts") and args.num_experts:
            mpu._MOE_AUX_LOSSES_LOGGING_TRACKER = {}  # re-initialize the global tensor
    
    def send_rank_repair(src_ranks: list, dest_ranks: list, optim_idxs: list,
                         step: int, rank_list: list, train_args):
        rank_list.sort()
        rank_ = torch.distributed.get_rank()
        repair_group = get_repair_group()
        for idx, src_rank in enumerate(src_ranks):
            dest_rank, optim_idx = dest_ranks[idx], optim_idxs[idx]
            if src_rank != rank_:
                return RET_ERROR
            if src_rank == dest_rank:
                return RET_ERROR
            save_and_send_ckpt(src_rank, dest_rank, step, optim_idx, train_args)
        for idx, _ in enumerate(src_ranks):
            dest_rank, optim_idx = dest_ranks[idx], optim_idxs[idx]
            train_args[OPTIM_INDEX].send_optim_param_state(dest_rank, repair_group, step, optim_idx)
    
        # Repair other global tensors (placeholder: global_tensor stands for each
        # feature-specific tensor that must stay consistent across replicas).
        for idx, _ in enumerate(src_ranks):
            dest_rank = dest_ranks[idx]
            torch.distributed.send(global_tensor, dst=dest_rank, group=repair_group)
        return RET_OK
    
    def save_and_send_ckpt(src_rank, dest_rank, step, optim_idx, train_args):
        """
        Save memory checkpoint and send to dest rank.
        """
        repair_group = get_repair_group()
        state_dict = save_memory_ckpt(train_args[OPTIM_INDEX], train_args[SCHEDULER_INDEX],
                                          step, optim_idx)
        buffer = io.BytesIO()
        torch.save(state_dict, buffer)
        state_dict_bytes = buffer.getvalue()
        state_dict_tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(state_dict_bytes)).to('npu')
        # Send tensor size first
        size_tensor = torch.tensor([state_dict_tensor.numel()], dtype=torch.long).to('npu')
        torch.distributed.send(size_tensor, dst=dest_rank, group=repair_group)
        # Send the serialized state_dict tensor
        torch.distributed.send(state_dict_tensor, dst=dest_rank, group=repair_group)
        return RET_OK
    
    def recv_rank_repair(src_ranks: list, dest_ranks: list, optim_idxs: list,
                         step: int, need_rebuild: bool, rank_list: list, train_args):
        # low level and only rollback case
        if src_ranks == dest_ranks:
            return RET_OK
        rank_ = torch.distributed.get_rank()
        rank_list.sort()
        if need_rebuild:
            build_local_embedding_group()
            model, optimizer, lr_scheduler, config = rebuild_model_and_optimizer(
                BuildDataArgs.model_provider_, BuildDataArgs.model_type_)
            train_args[MODEL_INDEX] = model
            train_args[OPTIM_INDEX] = optimizer
            train_args[SCHEDULER_INDEX] = lr_scheduler
            train_args[CONFIG_INDEX] = config
    
        for idx, src_rank in enumerate(src_ranks):
            dest_rank, optim_idx = dest_ranks[idx], optim_idxs[idx]
            if dest_rank != rank_:
                return RET_ERROR
            if src_rank != dest_rank:
                recv_ckpt_from_peer(src_rank, dest_rank, step, rank_list)
        # Combine the state_dicts and load them once to avoid precision issues.
        state_dict = get_memory_ckpt()
        load_memory_ckpt(train_args[MODEL_INDEX], train_args[OPTIM_INDEX], train_args[SCHEDULER_INDEX],
                             state_dict, None)
        repair_group = get_repair_group()
        for idx, src_rank in enumerate(src_ranks):
            dest_rank, optim_idx = dest_ranks[idx], optim_idxs[idx]
            train_args[OPTIM_INDEX].recv_and_load_optim_param_state(src_rank, repair_group, step, optim_idx)
    
        # Repair other global tensors (placeholder: the receive buffer's shape and
        # dtype must match the tensor sent by the peer, and global_tensor is the
        # tensor being restored).
        for idx, src_rank in enumerate(src_ranks):
            recv_tensor = torch.empty(size, dtype=dtype, device="npu")
            torch.distributed.recv(recv_tensor, src=src_rank, group=repair_group)
            global_tensor.data.copy_(recv_tensor)
        return RET_OK
    
    def recv_ckpt_from_peer(src_rank, dest_rank, step, rank_list: list):
        """
        receive memory checkpoint and repair train() param.
        """
        repair_group = get_repair_group()
        # Receive tensor size first
        size_tensor = torch.tensor([0], dtype=torch.long, device='npu')
        torch.distributed.recv(size_tensor, src=src_rank, group=repair_group)
        size = size_tensor.item()
        # Receive the serialized state_dict tensor
        state_dict_tensor = torch.empty(size, dtype=torch.uint8, device='npu')
        torch.distributed.recv(state_dict_tensor, src=src_rank, group=repair_group)
        # Deserialize the state_dict
        state_dict_bytes = state_dict_tensor.cpu().numpy().tobytes()
        buffer = io.BytesIO(state_dict_bytes)
        device_count = torch.npu.device_count()
        if device_count == 0:
            raise ValueError(f"torch.npu.device_count return 0!")
        map_location = {
            'npu:' + str(src_rank % device_count): 'npu:' + str(dest_rank % device_count)
        }
        loaded_state_dict = torch.load(buffer, map_location=map_location)
        set_memory_ckpt(loaded_state_dict)
    
    def rebuild_model_and_optimizer(model_provider, model_type):
        args = get_args()
        load = args.load
        args.load = None
    
        ori_embedding_group = mpu._EMBEDDING_GROUP
        mpu._EMBEDDING_GROUP = get_local_embedding_group()
        model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider, model_type)
        mpu._EMBEDDING_GROUP = ori_embedding_group
        args.load = load
        config = get_model_config(model[0])
        return model, optimizer, lr_scheduler, config
    
    def rebuild_dataset(train_args):
        # Rebuild the data iterators after repair.
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_data_iterator(train_args[TRAIN_PARAM][MODEL_INDEX])
        train_args[TRAIN_PARAM][TRAIN_DATA_INDEX] = train_data_iterator
        train_args[TRAIN_PARAM][VALID_DATA_INDEX] = valid_data_iterator
        train_args[TEST_DATA_ITER][0] = test_data_iterator
    
        return RET_OK
    
    def build_data_iterator(model):
        args = get_args()
        if args.virtual_pipeline_model_parallel_size is not None:
            train_data_iterator = []
            valid_data_iterator = []
            test_data_iterator = []
            for i in range(len(model)):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                iterators = build_train_valid_test_data_iterators(
                    BuildDataArgs.train_valid_test_datasets_provider_)
                train_data_iterator.append(iterators[0])
                valid_data_iterator.append(iterators[1])
                test_data_iterator.append(iterators[2])
        else:
            train_data_iterator, valid_data_iterator, test_data_iterator \
                = build_train_valid_test_data_iterators(
                    BuildDataArgs.train_valid_test_datasets_provider_)
        return train_data_iterator, valid_data_iterator, test_data_iterator
    
    # For nodes restarted incrementally by MindIO ARF, use the optimizer data received
    # during repair instead of loading the checkpoint saved before the fault.
    def update_arf_reboot_flag(new_state):
        global ARF_REBOOT_FLAG, LOAD, PRETRAINED_CHECKPOINT
        args = get_args()
        if new_state:
            LOAD = args.load
            PRETRAINED_CHECKPOINT = args.pretrained_checkpoint
            args.load = None
            args.pretrained_checkpoint = None
        elif ARF_REBOOT_FLAG:
            args.load = LOAD
            args.pretrained_checkpoint = PRETRAINED_CHECKPOINT
        ARF_REBOOT_FLAG = new_state
    
    def arf_rebuild_process_group_callback(fault_ranks: list, train_args, ctx):
        models, optimizer = train_args[TRAIN_PARAM][MODEL_INDEX], train_args[TRAIN_PARAM][OPTIM_INDEX]
        args = get_args()
        update_arf_reboot_flag(False)
        # 1.1 destroy_all_process_group
        torch.distributed.destroy_process_group()
        # 1.2 init_all_process_group
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size,
            rank=args.rank,
            timeout=timedelta(minutes=args.distributed_timeout_minutes),
        )
        # 2.1 destroy mpu group
        delete_mpu_group(mpu)
        # 2.2 destroy global memory buffer
        mpu.destroy_global_memory_buffer()
        # 3.1 init_mpu
        mpu.initialize_model_parallel(
            context_parallel_size=args.context_parallel_size,
            tensor_model_parallel_size=args.tensor_model_parallel_size,
            expert_model_parallel_size=args.expert_model_parallel_size,
            distributed_timeout_minutes=args.distributed_timeout_minutes,
            pipeline_model_parallel_size=args.pipeline_model_parallel_size,
            nccl_communicator_config_path=args.nccl_communicator_config_path,
            pipeline_model_parallel_split_rank=args.pipeline_model_parallel_split_rank,
            virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size,
        )
        update_model_and_optim_related_group(models, optimizer)
    
    def delete_mpu_group(mpu_module):
        group_space = [mpu_group for mpu_group in dir(mpu_module) if 'GROUP' in mpu_group]
        for group_name in group_space:
            setattr(mpu_module, group_name, None)
    
    def update_model_and_optim_related_group(models, optimizer):
        for model in models:
            # dense
            for buffer in model.buffers:
                # fix ParamAndGradBuffer attributes
                buffer.data_parallel_group = mpu._DATA_PARALLEL_GROUP_WITH_CP
                buffer.data_parallel_world_size = torch.distributed.get_world_size(group=mpu._DATA_PARALLEL_GROUP_WITH_CP)
                for bucket in buffer.buckets:
                    # fix Bucket attributes
                    bucket.data_parallel_group = mpu._DATA_PARALLEL_GROUP_WITH_CP
                    bucket.data_parallel_world_size = torch.distributed.get_world_size(
                        group=mpu._DATA_PARALLEL_GROUP_WITH_CP)
                    bucket.data_parallel_rank = torch.distributed.get_rank(mpu._DATA_PARALLEL_GROUP_WITH_CP)
            # moe
            for expert_buffer in model.expert_parallel_buffers:
                # fix ParamAndGradBuffer attributes
                expert_buffer.data_parallel_group = mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP
                expert_buffer.data_parallel_world_size = \
                    torch.distributed.get_world_size(group=mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP)
                for bucket in expert_buffer.buckets:
                    # fix Bucket attributes
                    bucket.data_parallel_group = mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP
                    bucket.data_parallel_world_size = torch.distributed.get_world_size(
                        group=mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP)
                    bucket.data_parallel_rank = torch.distributed.get_rank(mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP)
    
        if hasattr(optimizer, 'optim_nums'):  # update the communication groups held by the optimizer as required
            optimizer.chained_optimizers[0].ori_dp_group = mpu._DATA_PARALLEL_GROUP_WITH_CP
            optimizer.chained_optimizers[1].ori_dp_group = mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP
        else:
            optimizer.ori_dp_group = mpu._DATA_PARALLEL_GROUP_WITH_CP
    
        if not get_args().use_distributed_optimizer:  # a non-distributed optimizer holds no replica communication groups
            return
    
        # fix optimizer attributes
        if hasattr(optimizer, 'optim_nums'):  # the distributed optimizer holds replica groups for replica-optimizer communication
            optimizer.chained_optimizers[0].data_parallel_group = ttp_get_dp_cp_replica_group()
            optimizer.chained_optimizers[0].data_parallel_group_gloo = ttp_get_dp_cp_replica_group_gloo()
            optimizer.chained_optimizers[0].ori_dp_list = torch.distributed.get_process_group_ranks(
                mpu._DATA_PARALLEL_GROUP_WITH_CP)
            optimizer.chained_optimizers[1].data_parallel_group = ttp_get_dp_ep_replica_group()
            optimizer.chained_optimizers[1].data_parallel_group_gloo = ttp_get_dp_ep_replica_group_gloo()
            optimizer.chained_optimizers[1].ori_dp_list = torch.distributed.get_process_group_ranks(
                mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP)
        else:
            optimizer.data_parallel_group = ttp_get_dp_cp_replica_group()
            optimizer.data_parallel_group_gloo = ttp_get_dp_cp_replica_group_gloo()
            optimizer.ori_dp_list = torch.distributed.get_process_group_ranks(mpu._DATA_PARALLEL_GROUP_WITH_CP)
        return
    
    # (Optional) If training involves stream switching, reset the streams used for training to ensure correct scheduling.
    class CustomFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            torch.cuda.set_stream(torch.cuda.default_stream())
            return input
        @staticmethod
        def backward(ctx, grad):
            torch.cuda.set_stream(torch.cuda.default_stream())
            return grad
    
    def recover_set_stream():
        input_tensor = torch.empty(1, dtype=torch.float32, device="npu", requires_grad=True)
        grad_tensor = torch.empty(1, dtype=torch.float32, device="npu", requires_grad=True)
        output_tensor = CustomFunction.apply(input_tensor)
        output_tensor.backward(grad_tensor)

    The stop and clean callbacks are mandatory for both MindIO UCE and MindIO ARF. stop_callback uses the torch_npu.npu.stop_device interface to pause training on the other workers; once the repair is complete, all workers re-enter training together. In clean_callback, if the current fault is a UCE fault, the torch_npu.npu.check_uce_in_memory interface is used to determine the UCE fault type: for UCE_HIGH_LEVEL, the repair must rebuild the model and optimizer so that memory is re-allocated and tensor state is refreshed, fully repairing the UCE fault. Afterwards, the device is restarted via the torch.npu.restart_device interface. The return codes assumed by these callbacks are sketched below.
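
    The return codes and UCE levels used above (RET_OK, RET_NO_REBUILD, RET_ERROR, UCE_LOW_LEVEL, UCE_HIGH_LEVEL) are not defined in these snippets. A hedged sketch with placeholder values follows; the real values must be kept consistent with what MindIO TTP expects from the callbacks and with the levels returned by torch_npu.npu.check_uce_in_memory.

    # Placeholder constants only; align them with MindIO TTP and torch_npu.
    RET_OK = 0           # repair succeeded, training continues
    RET_NO_REBUILD = 1   # UCE_LOW_LEVEL case: rollback only, no model/optimizer rebuild
    RET_ERROR = -1       # unrecoverable fault, exit training

    UCE_LOW_LEVEL = 1    # check_uce_in_memory result: no rebuild needed
    UCE_HIGH_LEVEL = 2   # check_uce_in_memory result: rebuild model/optimizer to reallocate memory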

    repair_callback implements data repair between workers that are replicas of each other: the faulty worker rebuilds its model and optimizer and then receives and loads the checkpoint data from the healthy replica worker, while the healthy replica worker saves and sends its checkpoint data, using the size-then-payload transfer sketched below.
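
    The checkpoint transfer in save_and_send_ckpt / recv_ckpt_from_peer follows a simple size-then-payload protocol: the sender serializes the state_dict with torch.save into a byte tensor and first sends its length so the receiver can allocate a matching buffer. A minimal, self-contained sketch of that pattern (on CPU tensors over a gloo-backed group; the callbacks above additionally move the payload to NPU and remap devices with map_location):

    import io
    import torch
    import torch.distributed as dist

    def send_object(obj, dst, group=None):
        # Serialize an arbitrary picklable object into a byte tensor.
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        payload = torch.frombuffer(bytearray(buffer.getvalue()), dtype=torch.uint8)
        # Send the size first so the peer can allocate a receive buffer.
        dist.send(torch.tensor([payload.numel()], dtype=torch.long), dst=dst, group=group)
        dist.send(payload, dst=dst, group=group)

    def recv_object(src, group=None):
        # Receive the size, allocate the buffer, then receive and deserialize.
        size = torch.tensor([0], dtype=torch.long)
        dist.recv(size, src=src, group=group)
        payload = torch.empty(size.item(), dtype=torch.uint8)
        dist.recv(payload, src=src, group=group)
        return torch.load(io.BytesIO(payload.numpy().tobytes()))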

    After the repair completes, all workers enter rollback_callback together. In the rollback phase, the parameters held by the optimizer are gathered back into the model, and other training state, such as the learning rate and the dataset iterators, must also be rolled back to the correct position.

    MindIO ARF relies on the data-transfer functionality used in MindIO UCE repair. arf_rebuild_process_group_callback rebuilds the communication groups after MindIO ARF incrementally restarts a worker and replaces the groups held by the model and optimizer, ensuring that training after the fault repair uses the rebuilt communication groups.

    In addition, for any other global tensor that needs repair: if it has no dependency on the training iteration, it can simply be re-initialized (released) in the rollback phase; if it does depend on the training iteration and its mapping matches that of the replica optimizer, it must be rebuilt in the repair phase and its data repaired via point-to-point communication, as in the sketch below.
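
    For the second case, the placeholder loops in send_rank_repair and recv_rank_repair can be filled in per feature. A hedged sketch, assuming the global tensor has the same shape and dtype on both replica ranks and that the repair group already exists:

    import torch

    def repair_global_tensor(global_tensor, src_rank, dest_rank, repair_group):
        # Run on both sides of a replica pair during the repair phase.
        rank = torch.distributed.get_rank()
        if rank == src_rank:
            # Healthy replica: send its copy of the global tensor.
            torch.distributed.send(global_tensor, dst=dest_rank, group=repair_group)
        elif rank == dest_rank:
            # Faulty replica: receive into a buffer of identical shape/dtype and
            # copy it into the tensor that training keeps using.
            recv_tensor = torch.empty_like(global_tensor)
            torch.distributed.recv(recv_tensor, src=src_rank, group=repair_group)
            global_tensor.data.copy_(recv_tensor)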