
Replica Optimizer Feature Implementation

  1. Add new optimizer attributes.
    class TFTReplicaOptimizer(DistributedOptimizer):
        def __init__(self,
                     optimizer: torch.optim.Optimizer,
                     config: OptimizerConfig,
                     grad_scaler: MegatronGradScaler,
                     init_state_fn: Optional[Callable],
                     per_model_buffers: Dict[int, List[ParamAndGradBuffer]],
                     data_parallel_group: torch.distributed.ProcessGroup,
                     data_parallel_group_gloo: torch.distributed.ProcessGroup,
                     data_parallel_group_idx: int,
                     ori_dp_group=None):
            self.args = get_args()
            if hasattr(self.args, 'optimizer_replica_num') and self.args.optimizer_replica_num > 1:
                self.replica_num = self.args.optimizer_replica_num
            else:
                self.replica_num = 2
            if torch.distributed.get_world_size(group=ori_dp_group) == 1:
                raise ValueError('High availability does not support a data parallel world size of 1!')
            if (torch.distributed.get_world_size(group=ori_dp_group) % self.replica_num) != 0:
                raise ValueError('High availability requires the data parallel world size to be divisible by replica_num!')
            if data_parallel_group is None:
                raise ValueError('High availability does not support a data parallel group of None!')
    
            super().__init__(optimizer, config, grad_scaler, init_state_fn,
                             per_model_buffers, data_parallel_group, data_parallel_group_gloo, data_parallel_group_idx)
            self.error_dump = False
            self.save_args = {}
            self.current_step = 0
            self.ori_dp_group = ori_dp_group
            self.ori_dp_list = torch.distributed.get_process_group_ranks(ori_dp_group)

    Here, error_dump marks a MindIO TTP end-of-life save and distinguishes it from the periodic CheckPoint save implemented by the framework; save_args records the information required for the end-of-life save; current_step records the correspondence between the current optimizer state and the iteration step; ori_dp_group is the original DP group under the PTD partitioning; and ori_dp_list is the rank list corresponding to ori_dp_group. A sketch of how these ranks split into replica groups follows.
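
    For reference, the following is a minimal, self-contained sketch (plain Python, hypothetical rank values) of how the original DP rank list splits into replica sub-groups; it mirrors the slicing used by get_index_map and build_dp_cp_replica_group later in this section and is not part of the shipped code.

    def split_into_replicas(ori_dp_list, replica_num):
        dp_size = len(ori_dp_list)
        replica_size = dp_size // replica_num
        # Consecutive slices of the original DP rank list form the replica groups;
        # optimizer shards are partitioned inside each slice and duplicated across slices.
        return [ori_dp_list[i:i + replica_size] for i in range(0, dp_size, replica_size)]

    # Example: with replica_num = 2 and an original DP group of 8 ranks, ranks [0..3]
    # and [4..7] each hold a full copy of the optimizer state, sharded 4 ways internally.
    print(split_into_replicas(list(range(8)), 2))   # [[0, 1, 2, 3], [4, 5, 6, 7]]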

  2. Methods in the optimizer related to the MindIO TTP feature.
    @staticmethod
    def get_index_map(dp_ranks, save_ranks_list, replica_num: int):
        dp_size = len(dp_ranks)
        replica_size = dp_size // replica_num
        dp_ranks_tmp = [dp_ranks[i:i + replica_size] for i in range(0, dp_size, replica_size)]
        dp_ranks_maps = {}
        for data_parallel_ranks in dp_ranks_tmp:
            for i in range(replica_size):
                dp_ranks_maps[data_parallel_ranks[i]] = i
        tup = [(rank, si) for si, rank in enumerate(save_ranks_list)]
        tup.sort(key=lambda x: dp_ranks_maps.get(x[0]))
        ti_to_si = {}
        for ti, (rank, si) in enumerate(tup):
            ti_to_si[ti] = si
    
        return ti_to_si
    
    def get_error_dump(self):
        return self.error_dump
    
    def set_current_step(self, step):
        self.current_step = step
    
    def set_update_successful(self, flag):
        self.update_successful = flag
    
    def set_dump_args(self, rank, step, rank_list):
        self.save_args['step'] = step
        self.save_args['rank'] = rank
        self.save_args['rank_list'] = rank_list
        self.error_dump = True
    
        dp_size = len(self.ori_dp_list)
        replica_size = dp_size // self.replica_num
        dp_ranks_tmp = [self.ori_dp_list[i:i + replica_size] for i in range(0, dp_size, replica_size)]
        dp_ranks_maps = {}
        for data_parallel_ranks in dp_ranks_tmp:
            for i in range(replica_size):
                dp_ranks_maps[data_parallel_ranks[i]] = i
    
        for save_rank in rank_list:
            if dp_ranks_maps.get(save_rank) == 0:
                self.save_args['rank'] = save_rank
                break
    
    def need_write_file(self):
        cur_rank = torch.distributed.get_rank()
        if self.error_dump and self.save_args['rank'] == cur_rank:
            return True
        elif not self.error_dump and torch.distributed.get_rank(group=self.ori_dp_group) == 0:
            return True
        else:
            return False
    
    def get_parameter_state_dp_zero_for_ttp(self):
        global_rank = torch.distributed.get_rank()
        save_rank = self.save_args['rank']
        save_step = self.save_args['step']
        save_rank_list = self.save_args['rank_list']
        if save_step != self.current_step:
            raise RuntimeError(f"rank {global_rank} save step is not right, please check.  \n")
        merge_states = self.save_local_parameter_state(use_npu=False)
        sorted_save_rank_list = sorted(save_rank_list)  # torch gathers in sorted rank order internally
        ti_to_si = self.get_index_map(self.ori_dp_list, sorted_save_rank_list, self.replica_num)
        save_group_gloo = torch.distributed.new_group(sorted_save_rank_list, backend="gloo",
                                                      use_local_synchronization=True)
        final_state = {}
        for gbuf_idx, gbuf_states in merge_states.items():
            final_state[gbuf_idx] = {}
            for dtype, all_buckets_states in gbuf_states.items():
                world_tensors = {}
                final_state[gbuf_idx][dtype] = {}
                for key, all_buckets_state in all_buckets_states.items():
                    for bucket_tensor in all_buckets_state:
                        send_tensor = bucket_tensor.cpu()
                        # Gather tensor list.
                        if global_rank == save_rank:
                            recv_tensors = [
                                torch.empty((send_tensor.numel(),), dtype=torch.float32, device="cpu")
                                for _ in range(len(save_rank_list))
                            ]
                        else:
                            recv_tensors = None
    
                        # Gather.
                        torch.distributed.gather(send_tensor, recv_tensors, save_rank, save_group_gloo,)
                        if global_rank == save_rank:
                            res = []
                            for i in range(len(save_rank_list)):
                                res.append(recv_tensors[ti_to_si.get(i)])
                            if len(res) != len(recv_tensors):
                                raise ValueError(
                                    "The length of received doesn`t match the expected number of receive tensors.")
                            if key not in world_tensors:
                                world_tensors[key] = []
                            world_tensors[key].append(torch.cat(res))
                final_state[gbuf_idx][dtype] = world_tensors
    
        final_state["per_bucket_numel"] = self.per_bucket_numel
        final_state["per_bucket_numel_unpadded"] = self.per_bucket_numel_unpadded
        return final_state
    
    def copy_all_main_param_to_model_param(self, param_state):
        for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
            for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
                for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
                    main_param = param_state[gbuf_idx][dtype]["param"][bucket_idx]
                    bucket_param = self.buffers[gbuf_idx].buckets[bucket_idx].param_data
                    offset = self.buffers[gbuf_idx].buckets[bucket_idx].offset
                    if main_param.nelement() != bucket_param.nelement():
                        raise RuntimeError("The number of elements in main_param and bucket_param must be the same.")
                    for model_param, _ in gbuf_range_map["param_map"].items():
                        param_world_start, param_world_end, _ = self.buffers[gbuf_idx].param_index_map[model_param]
                        param_bucket_start = param_world_start - offset
                        param_bucket_end = param_world_end - offset
                        main_param_shard = main_param[param_bucket_start : param_bucket_end]
                        model_param.view(-1).detach().copy_(main_param_shard)
    
    def save_parameters_state_ttp(self):
        cur_rank = torch.distributed.get_rank()
        save_rank = self.save_args['rank']
        if cur_rank not in self.save_args['rank_list']:
            return None
    
        state_dict = self.get_parameter_state_dp_zero_for_ttp()
        if cur_rank == save_rank:  # if fp32 params are reused, skip copying main params back to model params
            if not hasattr(self.config, 'reuse_fp32_param') or not self.config.reuse_fp32_param:
                self.copy_all_main_param_to_model_param(state_dict)
            return state_dict
    
        return None
    
    def save_parameter_state_impl(self):
        if self.error_dump:
            state_dict = self.save_parameters_state_ttp()
        else:
            state_dict = self.get_parameter_state_dp_zero()
        return state_dict
    
    def save_parameter_state(self, filename: str):
        cur_rank = torch.distributed.get_rank()
        state_dict = self.save_parameter_state_impl()
        if self.error_dump:
            save_rank = self.save_args['rank']
            if cur_rank == save_rank:
                torch.save(state_dict, filename)
                print(f"errdump rank: {cur_rank} successfully saved parameters")
        else:
            if torch.distributed.get_rank(self.ori_dp_group) == 0:
                torch.save(state_dict, filename)
                print(f"normal rank: {cur_rank} successfully saved parameters")

    Following the native Megatron implementation, an error_dump branch is added. Because the end-of-life save is performed only by the selected workers, it must be distinguished from the periodic CheckPoint save, and the ranks that take part in the save also differ from those of the periodic CheckPoint save, which requires the extra handling shown in the index-mapping sketch below.
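
    As an illustration of the index remapping described above, the following hedged sketch calls get_index_map directly (it is a staticmethod, so no optimizer instance is needed); the rank values are hypothetical and only show how the gathered tensors are reordered into shard order.

    dp_ranks = [0, 1, 2, 3, 4, 5, 6, 7]              # original DP group, replica_num = 2
    save_ranks_list = sorted([0, 5, 2, 7])           # hypothetical ranks selected for the dump
    ti_to_si = TFTReplicaOptimizer.get_index_map(dp_ranks, save_ranks_list, replica_num=2)
    # Maps shard index -> position in the gathered tensor list (gather order follows the
    # sorted rank list), so the shards can be concatenated in shard order:
    print(ti_to_si)                                  # {0: 0, 1: 2, 2: 1, 3: 3}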

  3. Methods in the optimizer related to MindIO UCE and MindIO ARF.
    def send_param_state(self, dst, group):
        for _, gbuf_range_maps in enumerate(self.gbuf_ranges):
            for _, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
                for _, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
                    for model_param, _ in gbuf_range_map["param_map"].items():
                        group_index, group_order = self.model_param_group_index_map[model_param]
                        main_param = self.optimizer.param_groups[group_index]["params"][group_order]
                        optim_state = self.optimizer.state[main_param]
                        tensors = {
                            "param": main_param,
                            **optim_state,
                        }
                        torch.distributed.send(tensors["param"].detach().npu(), dst=dst, group=group)
                        torch.distributed.send(tensors["exp_avg"].detach().npu(), dst=dst, group=group)
                        torch.distributed.send(tensors["exp_avg_sq"].detach().npu(), dst=dst, group=group)
    
    def recv_param_state(self, src, group):
        for _, gbuf_range_maps in enumerate(self.gbuf_ranges):
            for _, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
                for _, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
                    for model_param, _ in gbuf_range_map["param_map"].items():
                        # Main param & optimizer states.
                        group_index, group_order = self.model_param_group_index_map[model_param]
                        main_param = self.optimizer.param_groups[group_index]["params"][group_order]
                        optim_state = self.optimizer.state[main_param]
                        tensors = {
                            "param": main_param,
                            **optim_state,
                        }
                        torch.distributed.recv(tensors["param"].data, src=src, group=group)
                        torch.distributed.recv(tensors["exp_avg"].data, src=src, group=group)
                        torch.distributed.recv(tensors["exp_avg_sq"].data, src=src, group=group)
    
    def state_dict_memory(self):
        return self.state_dict()
    
    def load_state_dict_memory(self, state_dict):
        self.load_state_dict(state_dict)
    
    def remove_hook_handle(self):
        if self.remove_pre_hook_handle:
            self.remove_pre_hook_handle.remove()
            self.remove_pre_hook_handle = None

    Point-to-point send/receive of parameters is added to the optimizer so that, during MindIO UCE fault recovery and MindIO ARF incremental restart, the replica optimizer's data can be transferred to the repaired replica worker. The optimizer class methods transfer the parameters held by the optimizer together with the first- and second-moment estimates; other parameters that need to be transferred are implemented in the callback functions, and users can also implement the transfer of additional data themselves. A usage sketch follows.
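
    A hedged sketch (not the actual MindIO callback) of how the repair flow could pair recv_param_state/send_param_state once the replica group has been re-established; replica_group, repaired_rank and peer_rank are assumptions supplied by the surrounding repair logic, not values defined in this document.

    def sync_optimizer_state_after_repair(optimizer, replica_group, repaired_rank, peer_rank):
        cur_rank = torch.distributed.get_rank()
        if cur_rank == peer_rank:
            # The healthy replica holder streams main params, exp_avg and exp_avg_sq.
            optimizer.send_param_state(dst=repaired_rank, group=replica_group)
        elif cur_rank == repaired_rank:
            # The repaired worker receives the same tensors in the same traversal order.
            optimizer.recv_param_state(src=peer_rank, group=replica_group)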

  4. Create the replica DP groups used by the replica optimizer.
    def build_dp_cp_replica_group(dp_cp_ranks: list):
        if len(dp_cp_ranks) % REPLICA_NUM != 0:
            raise ValueError(f"High availability requires the size of dp_cp_ranks:{dp_cp_ranks} "
                             f"to be divisible by replica_num:{REPLICA_NUM}!")
        global DP_CP_REPLICA_GROUP, DP_CP_REPLICA_GROUP_GLOO
        cur_rank = torch.distributed.get_rank()
        replica_group_size = len(dp_cp_ranks) // REPLICA_NUM
        replica_lists = [dp_cp_ranks[i * replica_group_size : (i + 1) * replica_group_size] for i in range(0, REPLICA_NUM)]
        for replica_list in replica_lists:
            if cur_rank in replica_list:
                replica_group = torch.distributed.new_group(replica_list, use_local_synchronization=True)
                replica_group_gloo = torch.distributed.new_group(replica_list, backend="gloo",
                                                                 use_local_synchronization=True)
                DP_CP_REPLICA_GROUP = replica_group
                DP_CP_REPLICA_GROUP_GLOO = replica_group_gloo
    
    
    def build_dp_ep_replica_group(dp_ep_ranks: list):
        if len(dp_ep_ranks) % REPLICA_NUM != 0:
            raise ValueError(f"High availability requires the size of dp_ep_ranks:{dp_ep_ranks} "
                             f"to be divisible by replica_num:{REPLICA_NUM}")
        global DP_EP_REPLICA_GROUP, DP_EP_REPLICA_GROUP_GLOO
        cur_rank = torch.distributed.get_rank()
        replica_group_size = len(dp_ep_ranks) // REPLICA_NUM
        replica_lists = [dp_ep_ranks[i * replica_group_size: (i + 1) * replica_group_size] for i in range(0, REPLICA_NUM)]
        for replica_list in replica_lists:
            if cur_rank in replica_list:
                replica_group = torch.distributed.new_group(replica_list, use_local_synchronization=True)
                replica_group_gloo = torch.distributed.new_group(replica_list, backend="gloo",
                                                                 use_local_synchronization=True)
                DP_EP_REPLICA_GROUP = replica_group
                DP_EP_REPLICA_GROUP_GLOO = replica_group_gloo

    After obtaining the rank list of the dp_ep or dp_cp group, pass it as an argument to create the replica groups, which are used to synchronize model and optimizer data within each replica group, as sketched below.
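
    A minimal call sketch, assuming the dp_cp / dp_ep rank lists for the current rank have already been computed during parallel-state initialization (the lists below are illustrative values for an 8-rank DP group with REPLICA_NUM = 2):

    dp_cp_ranks = [0, 1, 2, 3, 4, 5, 6, 7]
    dp_ep_ranks = [0, 1, 2, 3, 4, 5, 6, 7]
    build_dp_cp_replica_group(dp_cp_ranks)   # sets DP_CP_REPLICA_GROUP / DP_CP_REPLICA_GROUP_GLOO
    build_dp_ep_replica_group(dp_ep_ranks)   # sets DP_EP_REPLICA_GROUP / DP_EP_REPLICA_GROUP_GLOO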

  5. Use the replica optimizer in the framework.
    1. Replace the optimizer class used when constructing the optimizer.
      def tft_get_megatron_optimizer_based_on_param_groups(
          config: OptimizerConfig,
          param_groups: List,
          per_model_buffers: Optional[Dict[int, List[ParamAndGradBuffer]]] = None,
          data_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
          data_parallel_group_gloo: Optional[torch.distributed.ProcessGroup] = None,
          data_parallel_group_idx: Optional[int] = None,
          ori_dp_group: Optional[torch.distributed.ProcessGroup] = None,
      ) -> MegatronOptimizer:
          if config.optimizer == 'adam':
              optimizer = Adam(
                  param_groups,
                  lr=config.lr,
                  weight_decay=config.weight_decay,
                  betas=(config.adam_beta1, config.adam_beta2),
                  eps=config.adam_eps,
              )
              def init_state_fn(opt):
                  for group in opt.param_groups:
                      for p in group['params']:
                          if len(opt.state[p]) == 0:
                              opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
                              opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)
          elif config.optimizer == 'sgd':
              optimizer = SGD(
                  param_groups,
                  lr=config.lr,
                  weight_decay=config.weight_decay,
                  momentum=config.sgd_momentum,
              )
              init_state_fn = None
          else:
              raise Exception('{} optimizer is not supported.'.format(config.optimizer))
      
          if config.fp16 or config.bf16 or config.use_distributed_optimizer:
              grad_scaler = None
      
              # Constant loss scale.
              if config.loss_scale:
                  grad_scaler = ConstantGradScaler(config.loss_scale)
              # Dynamic loss scale.
              else:
                  if config.fp16:
                      grad_scaler = DynamicGradScaler(
                          initial_scale=config.initial_loss_scale,
                          min_scale=config.min_loss_scale,
                          growth_factor=2.0,
                          backoff_factor=0.5,
                          growth_interval=config.loss_scale_window,
                          hysteresis=config.hysteresis,
                      )
              optimizer_args = [
                  optimizer,
                  config,
                  grad_scaler,
                  init_state_fn,
              ]
              if config.use_distributed_optimizer:
                  optimizer = TFTReplicaOptimizer(
                      *optimizer_args,
                      per_model_buffers=per_model_buffers,
                      data_parallel_group=data_parallel_group,
                      data_parallel_group_gloo=data_parallel_group_gloo,
                      data_parallel_group_idx=data_parallel_group_idx,
                      ori_dp_group=ori_dp_group
                  )
              else:
                  optimizer = TFTFP16ReplicaOptimizer(*optimizer_args, ori_dp_group=ori_dp_group)
              return optimizer
      
          # FP32.
          return TFTFP32ReplicaOptimizer(optimizer, config, init_state_fn, ori_dp_group=ori_dp_group)

      If a new optimizer class is wrapped in this way, the optimizer used on the framework side must be replaced; note that if the input arguments change, the call sites must be updated as well. One possible replacement is sketched below.
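
      A hedged sketch of one possible replacement, mirroring the wrapper style used in step 2 below; the patch point and the helper get_original_data_parallel_group() are illustrative assumptions, not the exact integration code.

      from functools import wraps

      def tft_get_optimizer_wrapper(get_megatron_optimizer_based_on_param_groups):
          @wraps(get_megatron_optimizer_based_on_param_groups)
          def wrapper(*args, **kwargs):
              # Forward every original argument and add the original DP group that the
              # replica optimizer additionally requires (helper name is hypothetical).
              kwargs.setdefault('ori_dp_group', get_original_data_parallel_group())
              return tft_get_megatron_optimizer_based_on_param_groups(*args, **kwargs)
          return wrapper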

    2. For the distributed optimizer, additional changes in the framework flow are required.
      def tft_clip_grad_norm_fp32_wrapper(clip_grad_norm_fp32):
          @wraps(clip_grad_norm_fp32)
          def wrapper(parameters, grads_for_norm, *args, **kwargs):
              if get_args().use_distributed_optimizer:
                  norm_type = kwargs.get('norm_type', 2)
                  for idx, grad in enumerate(grads_for_norm):
                      grads_for_norm[idx] = grad / get_replica_dp_num() ** (1/norm_type)
              return clip_grad_norm_fp32(parameters, grads_for_norm, *args, **kwargs)
          return wrapper
      
      def tft_build_model_gbuf_range_wrapper(_build_model_gbuf_range):
          @wraps(_build_model_gbuf_range)
          def wrapper(param_and_grad_buffer, bucket_index):
              ori_dp_group = param_and_grad_buffer.data_parallel_group
              is_expert_parallel = []
              for param in param_and_grad_buffer.buckets[0].params_list:
                  is_expert_parallel.append(not getattr(param, 'allreduce', True))
              if all(is_expert_parallel):
                  param_and_grad_buffer.data_parallel_group = get_dp_ep_replica_group()
              elif not any(is_expert_parallel):
                  param_and_grad_buffer.data_parallel_group = get_dp_cp_replica_group()
              else:
                  raise ValueError("Mixed boolean values found.")
              data = _build_model_gbuf_range(param_and_grad_buffer, bucket_index)
              param_and_grad_buffer.data_parallel_group = ori_dp_group
              return data
          return wrapper

      The clip_grad_norm_fp32 method performs gradient clipping. When the distributed optimizer is used, the gradient norm is accumulated within the original DP group, but because the replica optimizer changes how the distributed optimizer is sharded, the gradients must be corrected here (see the numeric check below). _build_model_gbuf_range is a method of the Megatron distributed optimizer that shards the optimizer state within the DP group; the replica DP group must be used for this sharding, and note that when the MoE feature is enabled, the dense and sparse layers of the model must be handled separately.
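
      A small numeric check (plain Python, hypothetical values) of the correction factor applied in tft_clip_grad_norm_fp32_wrapper: assuming the norm is still accumulated over the original DP group as described above, every shard's contribution is counted replica_num times, so dividing each grad by replica_num ** (1 / norm_type) restores the single-copy norm.

      replica_num, p = 2, 2
      local_norm_terms = [3.0 ** p, 4.0 ** p]              # one replica copy of the shards: ||g|| = 5
      duplicated = local_norm_terms * replica_num          # each term counted replica_num times
      corrected = [t / replica_num for t in duplicated]    # grad / replica_num**(1/p) => term / replica_num
      print(sum(duplicated) ** (1 / p))                    # ~7.07, the inflated norm
      print(sum(corrected) ** (1 / p))                     # 5.0, the true norm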