Replica Optimizer Implementation
- Add new attributes to the optimizer.
```python
class TFTReplicaOptimizer(DistributedOptimizer):
    def __init__(self, optimizer: torch.optim.Optimizer, config: OptimizerConfig,
                 grad_scaler: MegatronGradScaler, init_state_fn: Optional[Callable],
                 per_model_buffers: Dict[int, List[ParamAndGradBuffer]],
                 data_parallel_group: torch.distributed.ProcessGroup,
                 data_parallel_group_gloo: torch.distributed.ProcessGroup,
                 data_parallel_group_idx: int, ori_dp_group=None):
        self.args = get_args()
        # Number of optimizer-state replicas; falls back to 2 when not configured.
        if hasattr(self.args, 'optimizer_replica_num') and self.args.optimizer_replica_num > 1:
            self.replica_num = self.args.optimizer_replica_num
        else:
            self.replica_num = 2
        if torch.distributed.get_world_size(group=ori_dp_group) == 1:
            raise ValueError('High availability does not support a data parallel world size of 1!')
        if (torch.distributed.get_world_size(group=ori_dp_group) % self.replica_num) != 0:
            raise ValueError('High availability requires the data parallel world size '
                             'to be divisible by replica_num!')
        if data_parallel_group is None:
            raise ValueError('High availability does not support data_parallel_group being None!')
        super().__init__(optimizer, config, grad_scaler, init_state_fn, per_model_buffers,
                         data_parallel_group, data_parallel_group_gloo, data_parallel_group_idx)
        self.error_dump = False           # marks a MindIO TTP emergency dump
        self.save_args = {}               # information required by the emergency dump
        self.current_step = 0             # iteration step matching the current optimizer state
        self.ori_dp_group = ori_dp_group  # original DP group from the PTD partitioning
        self.ori_dp_list = torch.distributed.get_process_group_ranks(ori_dp_group)
```
Here, error_dump flags a MindIO TTP emergency (end-of-life) save and distinguishes it from the framework's periodic checkpoint saving; save_args records the information required during an emergency save; current_step tracks which iteration step the current optimizer state corresponds to; ori_dp_group is the original DP group from the PTD partitioning; and ori_dp_list is the rank list of ori_dp_group.
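For illustration only, this is roughly what save_args holds once an emergency dump has been requested; the concrete values below are invented for the example:

```python
# Illustrative contents of save_args after set_dump_args() has been called
# (values are made up; the keys match those set in set_dump_args below).
save_args = {
    'step': 1200,               # iteration whose optimizer state must be dumped
    'rank': 4,                  # rank chosen to write the dump file
    'rank_list': [4, 5, 6, 7],  # ranks that participate in the dump (assumed to cover one full replica)
}
```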
- MindIO TTP-related methods in the optimizer.
```python
@staticmethod
def get_index_map(dp_ranks, save_ranks_list, replica_num: int):
    dp_size = len(dp_ranks)
    replica_size = dp_size // replica_num
    dp_ranks_tmp = [dp_ranks[i:i + replica_size] for i in range(0, dp_size, replica_size)]
    dp_ranks_maps = {}
    for data_parallel_ranks in dp_ranks_tmp:
        for i in range(replica_size):
            dp_ranks_maps[data_parallel_ranks[i]] = i
    tup = [(rank, si) for si, rank in enumerate(save_ranks_list)]
    tup.sort(key=lambda x: dp_ranks_maps.get(x[0]))
    ti_to_si = {}
    for ti, (rank, si) in enumerate(tup):
        ti_to_si[ti] = si
    return ti_to_si

def get_error_dump(self):
    return self.error_dump

def set_current_step(self, step):
    self.current_step = step

def set_update_successful(self, flag):
    self.update_successful = flag

def set_dump_args(self, rank, step, rank_list):
    self.save_args['step'] = step
    self.save_args['rank'] = rank
    self.save_args['rank_list'] = rank_list
    self.error_dump = True
    dp_size = len(self.ori_dp_list)
    replica_size = dp_size // self.replica_num
    dp_ranks_tmp = [self.ori_dp_list[i:i + replica_size] for i in range(0, dp_size, replica_size)]
    dp_ranks_maps = {}
    for data_parallel_ranks in dp_ranks_tmp:
        for i in range(replica_size):
            dp_ranks_maps[data_parallel_ranks[i]] = i
    # Prefer the rank that sits at index 0 of its replica sub-group as the save rank.
    for save_rank in rank_list:
        if dp_ranks_maps.get(save_rank) == 0:
            self.save_args['rank'] = save_rank
            break

def need_write_file(self):
    cur_rank = torch.distributed.get_rank()
    if self.error_dump and self.save_args['rank'] == cur_rank:
        return True
    elif not self.error_dump and torch.distributed.get_rank(group=self.ori_dp_group) == 0:
        return True
    else:
        return False

def get_parameter_state_dp_zero_for_ttp(self):
    global_rank = torch.distributed.get_rank()
    save_rank = self.save_args['rank']
    save_step = self.save_args['step']
    save_rank_list = self.save_args['rank_list']
    if save_step != self.current_step:
        raise RuntimeError(f"rank {global_rank} save step is not right, please check.")
    merge_states = self.save_local_parameter_state(use_npu=False)
    sorted_save_rank_list = sorted(save_rank_list)  # torch saves in this order internally
    ti_to_si = self.get_index_map(self.ori_dp_list, sorted_save_rank_list, self.replica_num)
    save_group_gloo = torch.distributed.new_group(sorted_save_rank_list, backend="gloo",
                                                  use_local_synchronization=True)
    final_state = {}
    for gbuf_idx, gbuf_states in merge_states.items():
        final_state[gbuf_idx] = {}
        for dtype, all_buckets_states in gbuf_states.items():
            world_tensors = {}
            final_state[gbuf_idx][dtype] = {}
            for key, all_buckets_state in all_buckets_states.items():
                for bucket_tensor in all_buckets_state:
                    send_tensor = bucket_tensor.cpu()
                    # Gather tensor list.
                    if global_rank == save_rank:
                        recv_tensors = [
                            torch.empty((send_tensor.numel(),), dtype=torch.float32, device="cpu")
                            for _ in range(len(save_rank_list))
                        ]
                    else:
                        recv_tensors = None
                    # Gather.
                    torch.distributed.gather(send_tensor, recv_tensors, save_rank, save_group_gloo)
                    if global_rank == save_rank:
                        res = []
                        for i in range(len(save_rank_list)):
                            res.append(recv_tensors[ti_to_si.get(i)])
                        if len(res) != len(recv_tensors):
                            raise ValueError(
                                "The length of received doesn't match the expected number of receive tensors.")
                        if key not in world_tensors:
                            world_tensors[key] = []
                        world_tensors[key].append(torch.cat(res))
            final_state[gbuf_idx][dtype] = world_tensors
    final_state["per_bucket_numel"] = self.per_bucket_numel
    final_state["per_bucket_numel_unpadded"] = self.per_bucket_numel_unpadded
    return final_state

def copy_all_main_param_to_model_param(self, param_state):
    for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
        for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
            for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
                main_param = param_state[gbuf_idx][dtype]["param"][bucket_idx]
                bucket_param = self.buffers[gbuf_idx].buckets[bucket_idx].param_data
                offset = self.buffers[gbuf_idx].buckets[bucket_idx].offset
                if main_param.nelement() != bucket_param.nelement():
                    raise RuntimeError("The number of elements in main_param and bucket_param must be the same.")
                for model_param, _ in gbuf_range_map["param_map"].items():
                    param_world_start, param_world_end, _ = self.buffers[gbuf_idx].param_index_map[model_param]
                    param_bucket_start = param_world_start - offset
                    param_bucket_end = param_world_end - offset
                    main_param_shard = main_param[param_bucket_start:param_bucket_end]
                    model_param.view(-1).detach().copy_(main_param_shard)

def save_parameters_state_ttp(self):
    cur_rank = torch.distributed.get_rank()
    save_rank = self.save_args['rank']
    if cur_rank not in self.save_args['rank_list']:
        return None
    state_dict = self.get_parameter_state_dp_zero_for_ttp()
    if cur_rank == save_rank:
        # If fp32 params are reused, skip copying them back into the model params.
        if not hasattr(self.config, 'reuse_fp32_param') or not self.config.reuse_fp32_param:
            self.copy_all_main_param_to_model_param(state_dict)
        return state_dict
    return None

def save_parameter_state_impl(self):
    if self.error_dump:
        state_dict = self.save_parameters_state_ttp()
    else:
        state_dict = self.get_parameter_state_dp_zero()
    return state_dict

def save_parameter_state(self, filename: str):
    cur_rank = torch.distributed.get_rank()
    state_dict = self.save_parameter_state_impl()
    if self.error_dump:
        save_rank = self.save_args['rank']
        if cur_rank == save_rank:
            torch.save(state_dict, filename)
            print(f"errdump rank: {cur_rank} successfully saved parameters")
    else:
        if torch.distributed.get_rank(self.ori_dp_group) == 0:
            torch.save(state_dict, filename)
            print(f"normal rank: {cur_rank} successfully saved parameters")
```
Following Megatron's native implementation, an error_dump branch is added. An emergency save is performed only by the selected workers, so it must be distinguished from periodic checkpoint saving; the ranks that participate in the save also differ from those used for periodic checkpoints and require extra handling.
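As a rough orientation, an emergency dump would be driven from MindIO TTP roughly as sketched below; the callback name and the way MindIO TTP invokes it are assumptions, not the MindIO API:

```python
def ttp_dump_callback(optimizer, step, save_rank, rank_list, filename):
    # Hypothetical glue code: MindIO TTP is assumed to supply the step to dump,
    # a candidate save rank, and the ranks that still hold a complete replica.
    optimizer.set_dump_args(save_rank, step, rank_list)  # switches the optimizer into error_dump mode
    optimizer.save_parameter_state(filename)             # only the selected rank writes the file
```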
- MindIO UCE- and MindIO ARF-related methods in the optimizer.
```python
def send_param_state(self, dst, group):
    for _, gbuf_range_maps in enumerate(self.gbuf_ranges):
        for _, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
            for _, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
                for model_param, _ in gbuf_range_map["param_map"].items():
                    group_index, group_order = self.model_param_group_index_map[model_param]
                    main_param = self.optimizer.param_groups[group_index]["params"][group_order]
                    optim_state = self.optimizer.state[main_param]
                    tensors = {
                        "param": main_param,
                        **optim_state,
                    }
                    torch.distributed.send(tensors["param"].detach().npu(), dst=dst, group=group)
                    torch.distributed.send(tensors["exp_avg"].detach().npu(), dst=dst, group=group)
                    torch.distributed.send(tensors["exp_avg_sq"].detach().npu(), dst=dst, group=group)

def recv_param_state(self, src, group):
    for _, gbuf_range_maps in enumerate(self.gbuf_ranges):
        for _, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
            for _, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
                for model_param, _ in gbuf_range_map["param_map"].items():
                    # Main param & optimizer states.
                    group_index, group_order = self.model_param_group_index_map[model_param]
                    main_param = self.optimizer.param_groups[group_index]["params"][group_order]
                    optim_state = self.optimizer.state[main_param]
                    tensors = {
                        "param": main_param,
                        **optim_state,
                    }
                    torch.distributed.recv(tensors["param"].data, src=src, group=group)
                    torch.distributed.recv(tensors["exp_avg"].data, src=src, group=group)
                    torch.distributed.recv(tensors["exp_avg_sq"].data, src=src, group=group)

def state_dict_memory(self):
    return self.state_dict()

def load_state_dict_memory(self, state_dict):
    self.load_state_dict(state_dict)

def remove_hook_handle(self):
    if self.remove_pre_hook_handle:
        self.remove_pre_hook_handle.remove()
        self.remove_pre_hook_handle = None
```
Point-to-point send/receive of parameters is added to the optimizer so that, during MindIO UCE fault repair and MindIO ARF incremental restart, the replica optimizer's data can be transferred to the repaired replica worker. The optimizer class methods transfer the parameters held by the optimizer together with the first- and second-moment estimates; any other parameters that need to be transferred are handled in the callback functions, and users may also implement the transfer of additional data themselves.
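As an illustration of how these methods might be used during a repair, the sketch below pairs a healthy replica rank with the repaired rank inside their replica group; how MindIO UCE/ARF selects that pair and invokes the callback is an assumption, not taken from the source:

```python
import torch.distributed


def repair_optimizer_state(optimizer, healthy_rank, repaired_rank, replica_group):
    # Hypothetical repair callback: the healthy replica streams its optimizer shard
    # (params plus exp_avg / exp_avg_sq) to the rank that was just recovered.
    cur_rank = torch.distributed.get_rank()
    if cur_rank == healthy_rank:
        optimizer.send_param_state(dst=repaired_rank, group=replica_group)
    elif cur_rank == repaired_rank:
        optimizer.recv_param_state(src=healthy_rank, group=replica_group)
```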
- Build the replica DP groups used by the replica optimizer.
```python
def build_dp_cp_replica_group(dp_cp_ranks: list):
    if len(dp_cp_ranks) % REPLICA_NUM != 0:
        raise ValueError(f"High availability requires the size of dp_cp_ranks:{dp_cp_ranks} "
                         f"to be divisible by replica_num:{REPLICA_NUM}!")
    global DP_CP_REPLICA_GROUP, DP_CP_REPLICA_GROUP_GLOO
    cur_rank = torch.distributed.get_rank()
    replica_group_size = len(dp_cp_ranks) // REPLICA_NUM
    replica_lists = [dp_cp_ranks[i * replica_group_size: (i + 1) * replica_group_size]
                     for i in range(REPLICA_NUM)]
    for replica_list in replica_lists:
        if cur_rank in replica_list:
            replica_group = torch.distributed.new_group(replica_list, use_local_synchronization=True)
            replica_group_gloo = torch.distributed.new_group(replica_list, backend="gloo",
                                                             use_local_synchronization=True)
            DP_CP_REPLICA_GROUP = replica_group
            DP_CP_REPLICA_GROUP_GLOO = replica_group_gloo


def build_dp_ep_replica_group(dp_ep_ranks: list):
    if len(dp_ep_ranks) % REPLICA_NUM != 0:
        raise ValueError(f"High availability requires the size of dp_ep_ranks:{dp_ep_ranks} "
                         f"to be divisible by replica_num:{REPLICA_NUM}!")
    global DP_EP_REPLICA_GROUP, DP_EP_REPLICA_GROUP_GLOO
    cur_rank = torch.distributed.get_rank()
    replica_group_size = len(dp_ep_ranks) // REPLICA_NUM
    replica_lists = [dp_ep_ranks[i * replica_group_size: (i + 1) * replica_group_size]
                     for i in range(REPLICA_NUM)]
    for replica_list in replica_lists:
        if cur_rank in replica_list:
            replica_group = torch.distributed.new_group(replica_list, use_local_synchronization=True)
            replica_group_gloo = torch.distributed.new_group(replica_list, backend="gloo",
                                                             use_local_synchronization=True)
            DP_EP_REPLICA_GROUP = replica_group
            DP_EP_REPLICA_GROUP_GLOO = replica_group_gloo
```
After the rank list of the dp_ep or dp_cp group is obtained, it is passed in to build the corresponding replica group, which is then used to synchronize model and optimizer data within the replica group.
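The builders above rely on module-level state; below is a minimal sketch of the globals and accessors they imply (the wrappers further down call `get_dp_cp_replica_group`, `get_dp_ep_replica_group`, and `get_replica_dp_num`). The exact definitions, including `REPLICA_NUM = 2`, are assumptions based on the names used in this document:

```python
# Assumed module-level state for the replica groups (names taken from the snippets in this section).
REPLICA_NUM = 2
DP_CP_REPLICA_GROUP = None
DP_CP_REPLICA_GROUP_GLOO = None
DP_EP_REPLICA_GROUP = None
DP_EP_REPLICA_GROUP_GLOO = None


def get_dp_cp_replica_group():
    # Replica group built from the dp_cp ranks that contains the current rank.
    if DP_CP_REPLICA_GROUP is None:
        raise RuntimeError("dp_cp replica group has not been built yet.")
    return DP_CP_REPLICA_GROUP


def get_dp_ep_replica_group():
    # Replica group built from the dp_ep ranks that contains the current rank.
    if DP_EP_REPLICA_GROUP is None:
        raise RuntimeError("dp_ep replica group has not been built yet.")
    return DP_EP_REPLICA_GROUP


def get_replica_dp_num():
    # Number of optimizer-state replicas inside the original DP group.
    return REPLICA_NUM
```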
- Use the replica optimizer in the framework.
- Modify the optimizer class used when obtaining the optimizer.
```python
def tft_get_megatron_optimizer_based_on_param_groups(
        config: OptimizerConfig,
        param_groups: List,
        per_model_buffers: Optional[Dict[int, List[ParamAndGradBuffer]]] = None,
        data_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
        data_parallel_group_gloo: Optional[torch.distributed.ProcessGroup] = None,
        data_parallel_group_idx: Optional[int] = None,
        ori_dp_group: Optional[torch.distributed.ProcessGroup] = None,
) -> MegatronOptimizer:
    if config.optimizer == 'adam':
        optimizer = Adam(
            param_groups,
            lr=config.lr,
            weight_decay=config.weight_decay,
            betas=(config.adam_beta1, config.adam_beta2),
            eps=config.adam_eps,
        )

        def init_state_fn(opt):
            for group in opt.param_groups:
                for p in group['params']:
                    if len(opt.state[p]) == 0:
                        opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
                        opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)

    elif config.optimizer == 'sgd':
        optimizer = SGD(
            param_groups,
            lr=config.lr,
            weight_decay=config.weight_decay,
            momentum=config.sgd_momentum,
        )
        init_state_fn = None
    else:
        raise Exception('{} optimizer is not supported.'.format(config.optimizer))

    if config.fp16 or config.bf16 or config.use_distributed_optimizer:
        grad_scaler = None

        # Constant loss scale.
        if config.loss_scale:
            grad_scaler = ConstantGradScaler(config.loss_scale)
        # Dynamic loss scale.
        else:
            if config.fp16:
                grad_scaler = DynamicGradScaler(
                    initial_scale=config.initial_loss_scale,
                    min_scale=config.min_loss_scale,
                    growth_factor=2.0,
                    backoff_factor=0.5,
                    growth_interval=config.loss_scale_window,
                    hysteresis=config.hysteresis,
                )

        optimizer_args = [
            optimizer,
            config,
            grad_scaler,
            init_state_fn,
        ]
        if config.use_distributed_optimizer:
            optimizer = TFTReplicaOptimizer(
                *optimizer_args,
                per_model_buffers=per_model_buffers,
                data_parallel_group=data_parallel_group,
                data_parallel_group_gloo=data_parallel_group_gloo,
                data_parallel_group_idx=data_parallel_group_idx,
                ori_dp_group=ori_dp_group
            )
        else:
            optimizer = TFTFP16ReplicaOptimizer(*optimizer_args, ori_dp_group=ori_dp_group)
        return optimizer

    # FP32.
    return TFTFP32ReplicaOptimizer(optimizer, config, init_state_fn, ori_dp_group=ori_dp_group)
```
If a new optimizer class is wrapped, the optimizer used on the framework side must be replaced accordingly; note that if the input arguments change, the call sites must be updated as well.
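One way to make the framework pick up the new factory is to swap the original Megatron function for the TFT version. The sketch below is an assumption: the module path, the private function name `_get_megatron_optimizer_based_on_param_groups`, and the direct-assignment patch depend on your Megatron version, and because the TFT version adds the ori_dp_group argument, the callers must be adapted too:

```python
# Hypothetical patch: point Megatron's per-param-group optimizer factory at the TFT version.
# Verify the module path and function name against the Megatron version you use.
import megatron.core.optimizer as megatron_optimizer

megatron_optimizer._get_megatron_optimizer_based_on_param_groups = (
    tft_get_megatron_optimizer_based_on_param_groups
)
```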
- For the distributed optimizer, additional modifications are required in the framework flow.
```python
def tft_clip_grad_norm_fp32_wrapper(clip_grad_norm_fp32):
    @wraps(clip_grad_norm_fp32)
    def wrapper(parameters, grads_for_norm, *args, **kwargs):
        if get_args().use_distributed_optimizer:
            norm_type = kwargs.get('norm_type', 2)
            for idx, grad in enumerate(grads_for_norm):
                grads_for_norm[idx] = grad / get_replica_dp_num() ** (1 / norm_type)
        return clip_grad_norm_fp32(parameters, grads_for_norm, *args, **kwargs)
    return wrapper


def tft_build_model_gbuf_range_wrapper(_build_model_gbuf_range):
    @wraps(_build_model_gbuf_range)
    def wrapper(param_and_grad_buffer, bucket_index):
        ori_dp_group = param_and_grad_buffer.data_parallel_group
        is_expert_parallel = []
        for param in param_and_grad_buffer.buckets[0].params_list:
            is_expert_parallel.append(not getattr(param, 'allreduce', True))
        if all(is_expert_parallel):
            param_and_grad_buffer.data_parallel_group = get_dp_ep_replica_group()
        elif not any(is_expert_parallel):
            param_and_grad_buffer.data_parallel_group = get_dp_cp_replica_group()
        else:
            raise ValueError("Mixed boolean values found.")
        data = _build_model_gbuf_range(param_and_grad_buffer, bucket_index)
        param_and_grad_buffer.data_parallel_group = ori_dp_group
        return data
    return wrapper
```
clip_grad_norm_fp32 performs gradient clipping; with the distributed optimizer the gradient norm is accumulated within the original DP group, but because the replica optimizer changes the distributed optimizer's sharding scheme, the gradients must be corrected here. build_model_gbuf_range is a method of Megatron's distributed optimizer that shards the optimizer within the DP group; the replica DP group must be used when sharding the optimizer, and note that when the MoE feature is used, the model's dense layers and sparse (expert) layers must be handled separately.
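The correction factor in tft_clip_grad_norm_fp32_wrapper can be sanity-checked numerically: if every gradient element is counted replica_num times in the norm reduction, scaling each element by replica_num ** (-1 / norm_type) restores the true p-norm, since replica_num * |g / replica_num ** (1/p)| ** p == |g| ** p. A standalone check (the tensor values are arbitrary, for illustration only):

```python
import torch

replica_num, norm_type = 2, 2
g = torch.randn(5)

# Simulate replica_num copies of the same gradient shard entering the norm reduction,
# each pre-scaled by replica_num ** (-1 / norm_type) as the wrapper does.
scaled_copies = torch.cat([g / replica_num ** (1 / norm_type)] * replica_num)

print(torch.allclose(scaled_copies.norm(p=norm_type), g.norm(p=norm_type)))  # True
```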