Callback Implementation
- Obtain the parameters and methods required for fault recovery from the framework.
@dataclass
class BuildDataArgs:
    model_type_ = None
    model_provider_ = None
    train_valid_test_datasets_provider_ = None


def set_build_data_args(model_type, model_provider, train_valid_test_datasets_provider):
    BuildDataArgs.model_type_ = model_type
    BuildDataArgs.model_provider_ = model_provider
    BuildDataArgs.train_valid_test_datasets_provider_ = train_valid_test_datasets_provider
Taking Megatron as an example, the code above caches the parameters Megatron needs to rebuild the model, optimizer, and datasets. Users should decide which parameters to obtain based on their own framework.
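A minimal usage sketch follows; the exact call site is an assumption (in Megatron it would typically sit next to the call to pretrain), and model_provider / train_valid_test_datasets_provider stand for the user's own provider functions.

from megatron.core.enums import ModelType

# Cache the providers once, before training starts, so that the repair callbacks can
# later rebuild the model, optimizer and data iterators on a faulty rank.
set_build_data_args(ModelType.encoder_or_decoder,
                    model_provider,
                    train_valid_test_datasets_provider)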
- Callbacks related to the MindIO TTP feature.
def save_callback(step: int, save_info: list, train_args=None, ctx=None):
    model = train_args[TRAIN_PARAM][MODEL_INDEX]
    optimizer = train_args[TRAIN_PARAM][OPTIM_INDEX]
    opt_param_scheduler = train_args[TRAIN_PARAM][SCHEDULER_INDEX]
    global_args = get_args()
    # Update learning rate.
    if global_args.train_samples is None:
        global_args.consumed_train_samples = step * global_args.global_batch_size
    if train_args[TRAIN_PARAM][SCHEDULER_INDEX].num_steps != global_args.consumed_train_samples:
        train_args[TRAIN_PARAM][SCHEDULER_INDEX].step(global_args.global_batch_size)
    if hasattr(optimizer, 'optim_nums'):
        for _, info_dict in enumerate(save_info):
            optim_idx = info_dict.get("type", 0)
            rank_list = info_dict.get("ranks", None)
            save_rank = rank_list[0]
            optimizer.set_dump_args(optim_idx, save_rank, step, rank_list)
    else:
        rank_list = save_info[0].get("ranks", None)
        save_rank = rank_list[0]
        optimizer.set_dump_args(save_rank, step, rank_list)
    save_checkpoint(step, model, optimizer, opt_param_scheduler,
                    global_args.num_floating_point_operations_so_far)


def rename_callback(step: int, ctx=None):
    iteration = step
    rank = torch.distributed.get_rank()
    args = get_args()
    tmp_dir = 'iter_{:07d}_tmp'.format(iteration)
    fin_dir = 'iter_{:07d}'.format(iteration)
    src_path = os.path.join(args.save, tmp_dir)
    dst_path = os.path.join(args.save, fin_dir)
    try:
        src_check_ret, err_msg, src_abs_path = FileUtils.regular_file_path(src_path, args.save, False)
        dst_check_ret, err_msg, dst_abs_path = FileUtils.regular_file_path(dst_path, args.save, False)
        if (not src_check_ret) or (not dst_check_ret):
            raise FileNotFoundError(f"rank:{rank} rename path error. {err_msg}")
        os.rename(src_abs_path, dst_abs_path)
    except FileNotFoundError as e:
        print(f"{e}")
        return False
    # And update the latest iteration
    tracker_filename = get_checkpoint_tracker_filename(args.save)
    is_path_valid, err_msg, tracker_filename = FileUtils.regular_file_path(tracker_filename, "/", False)
    if not is_path_valid:
        print_rank_0(err_msg)
        raise Exception("tracker_filename is not valid")
    with open(tracker_filename, 'w') as f:
        f.write(str(iteration))
    return True


def get_checkpoint_name(checkpoints_path, iteration, release=False,
                        pipeline_parallel=None, tensor_rank=None, pipeline_rank=None,
                        expert_parallel=None, expert_rank=None,
                        return_base_dir=False, tmp=False):
    if release:
        directory = 'release'
    else:
        directory = 'iter_{:07d}'.format(iteration)
    if tmp:
        directory = directory + "_tmp"
    if return_base_dir:
        common_path = os.path.join(checkpoints_path, directory)
        return common_path
    # Use both the tensor and pipeline MP rank.
    if pipeline_parallel is None:
        pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
    if tensor_rank is None:
        tensor_rank = mpu.get_tensor_model_parallel_rank()
    if pipeline_rank is None:
        pipeline_rank = mpu.get_pipeline_model_parallel_rank()
    if expert_parallel is None:
        expert_parallel = (mpu.get_expert_model_parallel_world_size() > 1)
    if expert_rank is None:
        expert_rank = mpu.get_expert_model_parallel_rank()
    # Use both the tensor and pipeline MP rank. If using the distributed
    # optimizer, then the optimizer's path must additionally include the
    # data parallel rank.
    if not pipeline_parallel:
        common_path = os.path.join(checkpoints_path, directory,
                                   f'mp_rank_{tensor_rank:02d}')
    else:
        common_path = os.path.join(checkpoints_path, directory,
                                   f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}')
    if expert_parallel:
        common_path = common_path + f'_{expert_rank:03d}'
    return os.path.join(common_path, "model_optim_rng.pt")


def save_checkpoint(iteration, model, optimizer, opt_param_scheduler,
                    num_floating_point_operations_so_far):
    """Save a model checkpoint."""
    args = get_args()
    # Only rank zero of the data parallel writes to the disk.
    model = unwrap_model(model)
    ckpt_format = args.dist_ckpt_format if args.use_dist_ckpt else 'torch'
    print('rank {} is saving checkpoint at iteration {:7d} to {} in {} format'
          .format(args.rank, iteration, args.save, ckpt_format))
    # Collect rng state across data parallel ranks.
    rng_state = get_rng_state(args.use_dist_ckpt)
    tmp = optimizer.get_error_dump()
    # Checkpoint name.
    checkpoint_name = get_checkpoint_name(args.save, iteration,
                                          return_base_dir=args.use_dist_ckpt, tmp=tmp)
    check_ret, err_msg, checkpoint_name = FileUtils.regular_file_path(checkpoint_name, '/', False)
    if not check_ret:
        raise Exception(f"rank {args.rank} get checkpoint name error, {err_msg}")
    # Save distributed optimizer's custom parameter state.
    if args.use_distributed_optimizer and not args.no_save_optim and optimizer is not None \
            and not args.use_dist_ckpt:
        optim_checkpoint_name = \
            get_distributed_optimizer_checkpoint_name(checkpoint_name)
        ensure_directory_exists(optim_checkpoint_name)
        optimizer.save_parameter_state(optim_checkpoint_name)
    if not args.use_distributed_optimizer and (args.fp16 or args.bf16):
        if not hasattr(optimizer.config, 'reuse_fp32_param'):
            optimizer._copy_main_params_to_model_params()
    cur_rank = torch.distributed.get_rank()
    save_flag = optimizer.need_write_file()
    # Collect args, model, RNG.
    if not torch.distributed.is_initialized() \
            or save_flag \
            or args.use_dist_ckpt:
        optim_sd_kwargs = {}
        if args.use_dist_ckpt and args.use_distributed_optimizer:
            optim_sd_kwargs['sharding_type'] = ('fully_sharded_bucket_space'
                                                if args.ckpt_fully_parallel_save
                                                else 'dp_zero_gather_scatter')
            print_rank_0(f'Storing distributed optimizer sharded state of type {optim_sd_kwargs["sharding_type"]}')
        state_dict = generate_state_dict(args, model, optimizer, opt_param_scheduler, rng_state,
                                         args.use_dist_ckpt, iteration, optim_sd_kwargs=optim_sd_kwargs)
        state_dict['num_floating_point_operations_so_far'] = num_floating_point_operations_so_far
        # Save.
        ensure_directory_exists(checkpoint_name)
        torch.save(state_dict, checkpoint_name)
The MindIO TTP feature reuses the main checkpoint-saving logic of the Megatron framework. However, because the workers that participate in checkpoint saving are no longer ordered, both the naming of checkpoint files and the logic that decides whether a worker saves a checkpoint must be modified. After an end-of-life (emergency) save, the file names may differ from those Megatron would normally produce, so rename_callback renames the files to match the framework's naming convention.
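For illustration, the sketch below shows how the temporary and final directory names relate (the iteration number is arbitrary): an end-of-life save written with tmp=True lands in an "_tmp" directory, and rename_callback later promotes it to the name Megatron expects and updates the iteration tracker file.

# Illustrative only: tmp/final checkpoint directory names for iteration 1000.
iteration = 1000
tmp_dir = 'iter_{:07d}_tmp'.format(iteration)   # written by save_checkpoint via get_checkpoint_name(..., tmp=True)
fin_dir = 'iter_{:07d}'.format(iteration)       # name Megatron expects when resuming
print(tmp_dir, '->', fin_dir)                   # iter_0001000_tmp -> iter_0001000
# rename_callback(1000) performs os.rename(<save>/iter_0001000_tmp, <save>/iter_0001000)
# and writes "1000" into Megatron's latest_checkpointed_iteration.txt tracker file.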
- Callbacks related to MindIO UCE and MindIO ARF.
def stop_callback(train_args=None, ctx=None):
    # stop and clean device
    device = get_current_device()
    torch.npu.set_device(device)
    torch_npu.npu.stop_device(device)


def clean_callback(is_uce_error=False, train_args=None, ctx=None):
    """
    This function does the following:
    1) get the UCE check result from torch_npu
    2) do some cleanup before rebuild (to avoid OOM) when the check result is UCE_HIGH_LEVEL
    3) resume HCCL and restart the device
    """
    device = get_device()
    rank = torch.distributed.get_rank()
    torch.npu.set_device(device)
    ret = RET_OK
    if is_uce_error:
        check_memory_result = torch_npu.npu.check_uce_in_memory(device)
        if check_memory_result == UCE_LOW_LEVEL:
            # no need to rebuild
            ret = RET_NO_REBUILD
        elif check_memory_result == UCE_HIGH_LEVEL:
            # need to rebuild
            clear_memory(train_args[TRAIN_PARAM][MODEL_INDEX],
                         train_args[TRAIN_PARAM][OPTIM_INDEX],
                         train_args[TRAIN_PARAM][CONFIG_INDEX])
            train_args[TRAIN_PARAM][MODEL_INDEX] = None
            train_args[TRAIN_PARAM][OPTIM_INDEX] = None
            train_args[TRAIN_PARAM][SCHEDULER_INDEX] = None
            train_args[TRAIN_PARAM][CONFIG_INDEX] = None
            ret = RET_OK
        else:
            # exit
            ret = RET_ERROR
    if hasattr(torch.distributed, 'reinit_process_group'):
        torch.distributed.reinit_process_group(group=None, rebuild_link=False)
    torch.npu.restart_device(device)
    return ret


def repair_callback(step: int, need_rebuild: bool, error_ranks: list, repair_info: dict,
                    train_args=None, ctx=None):
    from mindio_ttp.framework_ttp import RepairType, OptimizerType
    torch.npu.set_device(get_current_device())
    optim_idxs = repair_info.get("type", OptimizerType.ATTENTION.value)
    repair_type = repair_info.get("repair_type", None)
    src_ranks = repair_info.get("src", [])
    dest_ranks = repair_info.get("dst", [])
    rank_list = repair_info.get("rank_list", [])
    build_repair_group(rank_list)
    if repair_type == RepairType.RT_SEND.value:
        # send side
        return send_rank_repair(src_ranks, dest_ranks, optim_idxs, step, rank_list,
                                train_args[TRAIN_PARAM])
    elif repair_type in [RepairType.RT_UCE_HIGHLEVEL.value, RepairType.RT_UCE_LOWLEVEL.value,
                         RepairType.RT_RECV_REPAIR.value]:
        return recv_rank_repair(src_ranks, dest_ranks, optim_idxs, step, need_rebuild, rank_list,
                                train_args[TRAIN_PARAM])
    else:
        return RET_ERROR


def rollback_callback(step: int, train_args=None, ctx=None):
    args = get_args()
    torch.npu.set_device(get_current_device())
    global temp_memory_ckpt
    rank = torch.distributed.get_rank()
    # Update consumed train samples and learning rate.
    if args.train_samples is None:
        args.consumed_train_samples = step * args.global_batch_size
    if train_args[TRAIN_PARAM][SCHEDULER_INDEX].num_steps != args.consumed_train_samples:
        train_args[TRAIN_PARAM][SCHEDULER_INDEX].step(args.global_batch_size)
    feature_rollback()
    gather_model_params_from_optimizer(train_args[TRAIN_PARAM][OPTIM_INDEX], step)
    rebuild_dataset(train_args)
    set_memory_ckpt(None)
    temp_memory_ckpt = None
    destroy_repair_group()
    return RET_OK


def feature_rollback():
    # Other features re-initialize or roll back data after the fault is rectified.
    args = get_args()
    if hasattr(args, "num_experts") and args.num_experts:
        mpu._MOE_AUX_LOSSES_LOGGING_TRACKER = {}  # re-initialize the global tensor


def send_rank_repair(src_ranks: list, dest_ranks: list, optim_idxs: list, step: int,
                     rank_list: list, train_args):
    rank_list.sort()
    rank_ = torch.distributed.get_rank()
    repair_group = get_repair_group()
    for idx, src_rank in enumerate(src_ranks):
        dest_rank, optim_idx = dest_ranks[idx], optim_idxs[idx]
        if src_rank != rank_:
            return RET_ERROR
        if src_rank == dest_rank:
            return RET_ERROR
        save_and_send_ckpt(src_rank, dest_rank, step, optim_idx, train_args)
    for idx, _ in enumerate(src_ranks):
        dest_rank, optim_idx = dest_ranks[idx], optim_idxs[idx]
        train_args[OPTIM_INDEX].send_optim_param_state(dest_rank, repair_group, step, optim_idx)
    # Repair other global tensors. `global_tensor` is a placeholder for a user-defined
    # global tensor that must stay consistent with the replica rank.
    for idx, src_rank in enumerate(src_ranks):
        dest_rank = dest_ranks[idx]
        torch.distributed.send(global_tensor, dst=dest_rank, group=repair_group)
    return RET_OK


def save_and_send_ckpt(src_rank, dest_rank, step, optim_idx, train_args):
    """
    Save the memory checkpoint and send it to the destination rank.
    """
    repair_group = get_repair_group()
    state_dict = save_memory_ckpt(train_args[OPTIM_INDEX], train_args[SCHEDULER_INDEX], step, optim_idx)
    buffer = io.BytesIO()
    torch.save(state_dict, buffer)
    state_dict_bytes = buffer.getvalue()
    state_dict_tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(state_dict_bytes)).to('npu')
    # Send the tensor size first
    size_tensor = torch.tensor([state_dict_tensor.numel()], dtype=torch.long).to('npu')
    torch.distributed.send(size_tensor, dst=dest_rank, group=repair_group)
    # Send the serialized state_dict tensor
    torch.distributed.send(state_dict_tensor, dst=dest_rank, group=repair_group)
    return RET_OK


def recv_rank_repair(src_ranks: list, dest_ranks: list, optim_idxs: list, step: int,
                     need_rebuild: bool, rank_list: list, train_args):
    # low-level and rollback-only case
    if src_ranks == dest_ranks:
        return RET_OK
    rank_ = torch.distributed.get_rank()
    rank_list.sort()
    if need_rebuild:
        build_local_embedding_group()
        model, optimizer, lr_scheduler, config = rebuild_model_and_optimizer(
            BuildDataArgs.model_provider_, BuildDataArgs.model_type_)
        train_args[MODEL_INDEX] = model
        train_args[OPTIM_INDEX] = optimizer
        train_args[SCHEDULER_INDEX] = lr_scheduler
        train_args[CONFIG_INDEX] = config
    for idx, src_rank in enumerate(src_ranks):
        dest_rank, optim_idx = dest_ranks[idx], optim_idxs[idx]
        if dest_rank != rank_:
            return RET_ERROR
        if src_rank != dest_rank:
            recv_ckpt_from_peer(src_rank, dest_rank, step, rank_list)
    # Combine the state_dicts and load them once to avoid precision problems.
    state_dict = get_memory_ckpt()
    load_memory_ckpt(train_args[MODEL_INDEX], train_args[OPTIM_INDEX],
                     train_args[SCHEDULER_INDEX], state_dict, None)
    repair_group = get_repair_group()
    for idx, src_rank in enumerate(src_ranks):
        dest_rank, optim_idx = dest_ranks[idx], optim_idxs[idx]
        train_args[OPTIM_INDEX].recv_and_load_optim_param_state(src_rank, repair_group, step, optim_idx)
    # Repair other global tensors. `size`, `dtype` and `global_tensor` are placeholders
    # describing the user-defined global tensor being repaired.
    for idx, src_rank in enumerate(src_ranks):
        recv_tensor = torch.empty(size, dtype=dtype, device="npu")
        torch.distributed.recv(recv_tensor, src=src_rank, group=repair_group)
        global_tensor.data.copy_(recv_tensor)
    return RET_OK


def recv_ckpt_from_peer(src_rank, dest_rank, step, rank_list: list):
    """
    Receive the memory checkpoint and repair the train() parameters.
    """
    repair_group = get_repair_group()
    # Receive the tensor size first
    size_tensor = torch.tensor([0], dtype=torch.long, device='npu')
    torch.distributed.recv(size_tensor, src=src_rank, group=repair_group)
    size = size_tensor.item()
    # Receive the serialized state_dict tensor
    state_dict_tensor = torch.empty(size, dtype=torch.uint8, device='npu')
    torch.distributed.recv(state_dict_tensor, src=src_rank, group=repair_group)
    # Deserialize the state_dict
    state_dict_bytes = state_dict_tensor.cpu().numpy().tobytes()
    buffer = io.BytesIO(state_dict_bytes)
    device_count = torch.npu.device_count()
    if device_count == 0:
        raise ValueError("torch.npu.device_count() returned 0!")
    map_location = {
        'npu:' + str(src_rank % device_count): 'npu:' + str(dest_rank % device_count)
    }
    loaded_state_dict = torch.load(buffer, map_location=map_location)
    set_memory_ckpt(loaded_state_dict)


def rebuild_model_and_optimizer(model_provider, model_type):
    args = get_args()
    load = args.load
    args.load = None
    ori_embedding_group = mpu._EMBEDDING_GROUP
    mpu._EMBEDDING_GROUP = get_local_embedding_group()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider, model_type)
    mpu._EMBEDDING_GROUP = ori_embedding_group
    args.load = load
    config = get_model_config(model[0])
    return model, optimizer, lr_scheduler, config


def rebuild_dataset(args):
    # repair the data iterators
    train_data_iterator, valid_data_iterator, test_data_iterator \
        = build_data_iterator(args[TRAIN_PARAM][MODEL_INDEX])
    args[TRAIN_PARAM][TRAIN_DATA_INDEX] = train_data_iterator
    args[TRAIN_PARAM][VALID_DATA_INDEX] = valid_data_iterator
    args[TEST_DATA_ITER][0] = test_data_iterator
    return RET_OK


def build_data_iterator(model):
    args = get_args()
    if args.virtual_pipeline_model_parallel_size is not None:
        train_data_iterator = []
        valid_data_iterator = []
        test_data_iterator = []
        for i in range(len(model)):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            iterators = build_train_valid_test_data_iterators(
                BuildDataArgs.train_valid_test_datasets_provider_)
            train_data_iterator.append(iterators[0])
            valid_data_iterator.append(iterators[1])
            test_data_iterator.append(iterators[2])
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                BuildDataArgs.train_valid_test_datasets_provider_)
    return train_data_iterator, valid_data_iterator, test_data_iterator


def update_arf_reboot_flag(new_state):
    # For nodes incrementally restarted by MindIO ARF, use the optimizer data received
    # during the repair instead of loading the checkpoint saved before the fault.
    global ARF_REBOOT_FLAG, LOAD, PRETRAINED_CHECKPOINT
    args = get_args()
    if new_state:
        LOAD = args.load
        PRETRAINED_CHECKPOINT = args.pretrained_checkpoint
        args.load = None
        args.pretrained_checkpoint = None
    elif ARF_REBOOT_FLAG:
        args.load = LOAD
        args.pretrained_checkpoint = PRETRAINED_CHECKPOINT
    ARF_REBOOT_FLAG = new_state


def arf_rebuild_process_group_callback(fault_ranks: list, train_args, ctx):
    models, optimizer = train_args[TRAIN_PARAM][MODEL_INDEX], train_args[TRAIN_PARAM][OPTIM_INDEX]
    args = get_args()
    update_arf_reboot_flag(False)
    # 1.1 destroy all process groups
    torch.distributed.destroy_process_group()
    # 1.2 init all process groups
    torch.distributed.init_process_group(
        backend=args.distributed_backend,
        world_size=args.world_size,
        rank=args.rank,
        timeout=timedelta(minutes=args.distributed_timeout_minutes),
    )
    # 2.1 destroy mpu groups
    delete_mpu_group(mpu)
    # 2.2 destroy the global memory buffer
    mpu.destroy_global_memory_buffer()
    # 3.1 init mpu
    mpu.initialize_model_parallel(
        context_parallel_size=args.context_parallel_size,
        tensor_model_parallel_size=args.tensor_model_parallel_size,
        expert_model_parallel_size=args.expert_model_parallel_size,
        distributed_timeout_minutes=args.distributed_timeout_minutes,
        pipeline_model_parallel_size=args.pipeline_model_parallel_size,
        nccl_communicator_config_path=args.nccl_communicator_config_path,
        pipeline_model_parallel_split_rank=args.pipeline_model_parallel_split_rank,
        virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size,
    )
    update_model_and_optim_related_group(models, optimizer)


def delete_mpu_group(mpu_module):
    group_space = [mpu_group for mpu_group in dir(mpu_module) if 'GROUP' in mpu_group]
    for group_name in group_space:
        setattr(mpu_module, group_name, None)


def update_model_and_optim_related_group(models, optimizer):
    for model in models:
        # dense
        for buffer in model.buffers:
            # fix ParamAndGradBuffer attributes
            buffer.data_parallel_group = mpu._DATA_PARALLEL_GROUP_WITH_CP
            buffer.data_parallel_world_size = torch.distributed.get_world_size(
                group=mpu._DATA_PARALLEL_GROUP_WITH_CP)
            for bucket in buffer.buckets:
                # fix Bucket attributes
                bucket.data_parallel_group = mpu._DATA_PARALLEL_GROUP_WITH_CP
                bucket.data_parallel_world_size = torch.distributed.get_world_size(
                    group=mpu._DATA_PARALLEL_GROUP_WITH_CP)
                bucket.data_parallel_rank = torch.distributed.get_rank(mpu._DATA_PARALLEL_GROUP_WITH_CP)
        # moe
        for expert_buffer in model.expert_parallel_buffers:
            # fix ParamAndGradBuffer attributes
            expert_buffer.data_parallel_group = mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP
            expert_buffer.data_parallel_world_size = \
                torch.distributed.get_world_size(group=mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP)
            for bucket in expert_buffer.buckets:
                # fix Bucket attributes
                bucket.data_parallel_group = mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP
                bucket.data_parallel_world_size = torch.distributed.get_world_size(
                    group=mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP)
                bucket.data_parallel_rank = torch.distributed.get_rank(mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP)
    if hasattr(optimizer, 'optim_nums'):
        # Update the communication groups held by the optimizer as required.
        optimizer.chained_optimizers[0].ori_dp_group = mpu._DATA_PARALLEL_GROUP_WITH_CP
        optimizer.chained_optimizers[1].ori_dp_group = mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP
    else:
        optimizer.ori_dp_group = mpu._DATA_PARALLEL_GROUP_WITH_CP
    if not get_args().use_distributed_optimizer:
        # A non-distributed optimizer does not hold replica communication groups.
        return
    # fix optimizer attributes
    if hasattr(optimizer, 'optim_nums'):
        # The distributed optimizer holds replica groups for replica-optimizer communication.
        optimizer.chained_optimizers[0].data_parallel_group = ttp_get_dp_cp_replica_group()
        optimizer.chained_optimizers[0].data_parallel_group_gloo = ttp_get_dp_cp_replica_group_gloo()
        optimizer.chained_optimizers[0].ori_dp_list = torch.distributed.get_process_group_ranks(
            mpu._DATA_PARALLEL_GROUP_WITH_CP)
        optimizer.chained_optimizers[1].data_parallel_group = ttp_get_dp_ep_replica_group()
        optimizer.chained_optimizers[1].data_parallel_group_gloo = ttp_get_dp_ep_replica_group_gloo()
        optimizer.chained_optimizers[1].ori_dp_list = torch.distributed.get_process_group_ranks(
            mpu._DATA_MODULO_EXPERT_PARALLEL_GROUP)
    else:
        optimizer.data_parallel_group = ttp_get_dp_cp_replica_group()
        optimizer.data_parallel_group_gloo = ttp_get_dp_cp_replica_group_gloo()
        optimizer.ori_dp_list = torch.distributed.get_process_group_ranks(mpu._DATA_PARALLEL_GROUP_WITH_CP)
    return


# (Optional) If stream switching is used during training, reset the streams used in
# training to ensure correct scheduling.
class CustomFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        torch.cuda.set_stream(torch.cuda.default_stream())
        return input

    @staticmethod
    def backward(ctx, grad):
        torch.cuda.set_stream(torch.cuda.default_stream())
        return grad


def recover_set_stream():
    input_tensor = torch.empty(1, dtype=torch.float32, device="npu", requires_grad=True)
    grad_tensor = torch.empty(1, dtype=torch.float32, device="npu", requires_grad=True)
    output_tensor = CustomFunction.apply(input_tensor)
    output_tensor.backward(grad_tensor)
The stop and clean callbacks are required for MindIO UCE and MindIO ARF. stop_callback uses the torch_npu.npu.stop_device interface to make the other workers pause training; once the repair completes, all workers re-enter training together. In clean_callback, if the current fault is a UCE fault, the fault type is detected through the torch_npu.npu.check_uce_in_memory interface; if the type is UCE_HIGH_LEVEL, the repair must rebuild the model and optimizer so that memory is re-allocated and tensor states are refreshed, fully repairing the UCE fault. Afterwards, the device is restarted through the torch.npu.restart_device interface.
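The helpers get_current_device() and get_device() used by these callbacks are user code not shown in this section. A minimal sketch of get_current_device(), assuming one NPU per local process and a launcher (e.g. torchrun) that sets LOCAL_RANK, could look like the following:

import os
import torch
import torch_npu  # noqa: F401  (registers the torch.npu backend)


def get_current_device():
    # Assumption: LOCAL_RANK maps one-to-one to a local NPU device index.
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    return local_rank % torch.npu.device_count()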
repair_callback implements data repair between workers that are replicas of each other: the faulty worker rebuilds its model and optimizer, then receives and loads the checkpoint data from the healthy replica worker, while the healthy replica worker saves and sends its checkpoint data.
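The repair-group helpers referenced above (build_repair_group, get_repair_group, destroy_repair_group) are also user code. A minimal sketch, assuming the HCCL backend and a plain torch.distributed process group over the ranks involved in the repair, might be:

import torch

_REPAIR_GROUP = None


def build_repair_group(rank_list):
    # Note: by default torch.distributed.new_group is collective over the default group;
    # newer PyTorch versions offer use_local_synchronization=True so that only the ranks
    # in rank_list need to enter this call.
    global _REPAIR_GROUP
    _REPAIR_GROUP = torch.distributed.new_group(ranks=sorted(rank_list), backend="hccl")


def get_repair_group():
    return _REPAIR_GROUP


def destroy_repair_group():
    global _REPAIR_GROUP
    if _REPAIR_GROUP is not None:
        torch.distributed.destroy_process_group(_REPAIR_GROUP)
        _REPAIR_GROUP = None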
After the repair completes, all workers enter rollback_callback together. In the rollback phase, the parameters held by the optimizer are gathered back into the model, and other training state, such as the learning rate and the datasets, is also rolled back to the correct position.
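set_memory_ckpt/get_memory_ckpt and temp_memory_ckpt hold the in-memory checkpoint between the repair and rollback phases and are cleared at the end of rollback_callback; a minimal sketch of this holder (user code, not a MindIO interface) is just module-level state:

# Minimal sketch of the in-memory checkpoint holder cleared at the end of rollback_callback.
_MEMORY_CKPT = None
temp_memory_ckpt = None


def set_memory_ckpt(state_dict):
    global _MEMORY_CKPT
    _MEMORY_CKPT = state_dict


def get_memory_ckpt():
    return _MEMORY_CKPT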
The MindIO ARF feature relies on the data-transfer capability used in MindIO UCE repair. arf_rebuild_process_group_callback is used by MindIO ARF to rebuild the communication groups after workers have been incrementally restarted and to replace the communication groups held by the model and optimizer, ensuring that the rebuilt groups are used in training after the fault is repaired.
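How an incrementally restarted worker learns that it should skip checkpoint loading is deployment-specific. A purely hypothetical sketch (the TTP_ARF_REBOOT environment variable is an illustrative assumption, not a MindIO interface) would call update_arf_reboot_flag(True) before the model and optimizer are built; arf_rebuild_process_group_callback later calls update_arf_reboot_flag(False) to restore args.load:

import os

# Hypothetical: decide from the environment whether this process was incrementally
# restarted by MindIO ARF; the variable name is an illustrative assumption.
if os.getenv("TTP_ARF_REBOOT", "0") == "1":
    update_arf_reboot_flag(True)  # clear args.load / args.pretrained_checkpoint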
In addition, for other global tensors that need repair: if a tensor does not depend on the training iteration, it can simply be re-initialized (released) in the rollback phase; if it does depend on the training iteration and its mapping is consistent with that of the replica optimizer, it must be rebuilt in the repair phase and its data repaired through point-to-point communication.
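Expanding the "repair other global tensors" placeholder in send_rank_repair/recv_rank_repair, a minimal sketch for a single user-defined global tensor (here a hypothetical GLOBAL_EXTRA_STATE living on the NPU) sends it from the healthy replica and copies it into the rebuilt tensor on the faulty rank:

import torch

# Hypothetical global tensor that depends on the training iteration.
GLOBAL_EXTRA_STATE = torch.zeros(1024, dtype=torch.float32, device="npu")


def send_global_tensor(dest_rank, repair_group):
    # Healthy replica side (called from send_rank_repair).
    torch.distributed.send(GLOBAL_EXTRA_STATE, dst=dest_rank, group=repair_group)


def recv_global_tensor(src_rank, repair_group):
    # Faulty-rank side (called from recv_rank_repair) after the tensor has been re-created.
    recv_tensor = torch.empty_like(GLOBAL_EXTRA_STATE)
    torch.distributed.recv(recv_tensor, src=src_rank, group=repair_group)
    GLOBAL_EXTRA_STATE.data.copy_(recv_tensor)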