Large Model Migration

  1. Modify the LLaMA model suite training script "fastchat/train/train_mem.py". A quick NPU environment check follows the modified script.

    cd FastChat-76f0424d1add61aadc8e5bdeed5ebe540f266ba3
    vi fastchat/train/train_mem.py 

    Before:

    # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
    
    # Need to call this before importing transformers.
    from fastchat.train.llama_flash_attn_monkey_patch import (
        replace_llama_attn_with_flash_attn,
    )
    
    replace_llama_attn_with_flash_attn()
    
    from fastchat.train.train import train
    
    if __name__ == "__main__":
        train()

    After:

    # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
    
    # Need to call this before importing transformers.
    # from fastchat.train.llama_flash_attn_monkey_patch import (
    #    replace_llama_attn_with_flash_attn,
    #)        (disable Flash Attention)
    
    # replace_llama_attn_with_flash_attn()
    import os
    import torch
    import torch_npu       #  (add NPU support to torch)
    import deepspeed
    import deepspeed_npu       #  (add NPU support to deepspeed)
    from torch_npu.contrib import transfer_to_npu
    
    from fastchat.train.train import train
    
    os.environ["WANDB_MODE"] = "offline"
    os.environ["WANDB_DISABLED"] = "TRUE"                  # disable WANDB training logging
    
    if __name__ == "__main__":
        torch.npu.set_compile_mode(jit_compile=True)         #  enable static compilation mode
        deepspeed.init_distributed('hccl')                    #  initialize deepspeed with the HCCL backend
    
        train()
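
    Before launching training it can help to confirm that the NPU stack is usable. The following is a minimal sanity-check sketch, not part of the original suite; it assumes torch_npu is installed and at least one Ascend device is visible.

    import torch
    import torch_npu  # noqa: F401  # registers the NPU backend and the .npu() tensor method

    print("NPU available:", torch.npu.is_available())
    print("NPU device count:", torch.npu.device_count())
    if torch.npu.is_available():
        x = torch.ones(2, 2).npu()        # move a small tensor to the default NPU
        print("sum on NPU:", (x + x).sum().item())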

  2. Modify the LLaMA model suite script "fastchat/serve/inference.py". A short device-placement sketch follows this list of changes.

    vi fastchat/serve/inference.py
    1. Line 133.
      Before:
          if (device == "cuda" and num_gpus == 1) or device == "mps":
              model.to(device)

      After:

          if (device == "cuda" and num_gpus == 1) or device == "mps":
              model.npu(device)                                  #  .npu replaces .cuda
    2. Line 142.

      Before:

      @torch.inference_mode()
      def generate_stream(
          model,
          tokenizer,
          params: Dict,
          device: str,
          context_len: int,
          stream_interval: int = 2,
          judge_sent_end: bool = False,
      ):

      After:

      @torch.no_grad()                                            #  equivalent replacement for @torch.inference_mode()
      def generate_stream(
          model,
          tokenizer,
          params: Dict,
          device: str,
          context_len: int,
          stream_interval: int = 2,
          judge_sent_end: bool = False,
      ):
    3. Line 177.
      Before:
      out = model(torch.as_tensor([input_ids], device=device), use_cache=True)    

      After:

      out = model(torch.as_tensor([input_ids]).npu(device), use_cache=True)   #  .npu replaces .cuda
    4. Line 193.
      Before:
      input_ids=torch.as_tensor([[token]], device=device),                

      After:

      input_ids=torch.as_tensor([[token]]).npu(device),                  #  .npu replaces .cuda
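
    The pattern across the four edits above is replacing CUDA-style placement (device=... or .to(device)) with the .npu() tensor method provided by torch_npu. A minimal sketch of the two forms, assuming torch_npu is installed; the names and values are illustrative only.

      import torch
      import torch_npu  # noqa: F401  # enables the .npu() tensor method

      input_ids = [1, 2, 3]

      # CUDA-style placement, as in the original FastChat code:
      #   t = torch.as_tensor([input_ids], device="cuda")

      # NPU-style placement used above: build the tensor on CPU, then move it;
      # the guide passes an explicit device, e.g. torch.as_tensor([input_ids]).npu(device)
      t = torch.as_tensor([input_ids]).npu()
      print(t.device)   # e.g. npu:0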

  3. Modify the transformers package scripts.

    1. Modify "/root/miniconda3/envs/fs37/lib/python3.7/site-packages/transformers/training_args.py". Adjust this path to match the Python environment actually in use.

      Comment out the FP16 device check in the def __post_init__(self): method, i.e., comment out the following code.

              if (
                  self.framework == "pt"
                  and is_torch_available()
                  and (self.device.type != "cuda")
                  and (get_xla_device_type(self.device) != "GPU")
                  and (self.fp16 or self.fp16_full_eval)
              ):
                  raise ValueError(
                      "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
                      " (`--fp16_full_eval`) can only be used on CUDA devices."
                  )
    2. Modify "/root/miniconda3/envs/py37/lib/python3.7/site-packages/transformers/utils/versions.py".

      Comment out the version check in the def _compare_versions function, i.e., comment out the following code.

          if not ops[op](version.parse(got_ver), version.parse(want_ver)):
              raise ImportError(
                  f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}"
              )
    3. Modify "/root/miniconda3/envs/test/lib/python3.7/site-packages/transformers/trainer.py".
      • Change 1.

        Before:

            if torch.cuda.is_available():
                if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                    # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
                    rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
                else:
                    rng_states["cuda"] = torch.cuda.random.get_rng_state()
        After:
            if torch.cuda.is_available():
                if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                    # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
                    rng_states["cuda"] = torch.npu.random.get_rng_state_all()     #  .npu replaces .cuda
                else:
                    rng_states["cuda"] = torch.npu.random.get_rng_state()          #  .npu replaces .cuda
      • Change 2.

        Before:

            if torch.cuda.is_available():
                if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                    torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
                else:
                    try:
                        torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
                    except Exception as e:

        After:

            if torch.cuda.is_available():
                if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                    torch.npu.random.set_rng_state_all(checkpoint_rng_state["cuda"])    #  .npu replaces .cuda
                else:
                    try:
                        torch.npu.random.set_rng_state(checkpoint_rng_state["cuda"])     #  .npu replaces .cuda
                    except Exception as e:
    4. Modify "/root/miniconda3/envs/test/lib/python3.7/site-packages/transformers/models/llama/modeling_llama.py".
      • Change 1: import torch_npu.

        Before:

        import torch
        import torch.utils.checkpoint
        from torch import nn
        from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

        After:

        import torch
        import torch_npu
        import torch.utils.checkpoint
        from torch import nn
        from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
      • Change 2.
        Before:
            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                    )
                attn_weights = attn_weights + attention_mask
                attn_weights = torch.max(
                    attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)     
                )
        After:
            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                    )
                attn_weights = attn_weights + attention_mask
                attn_weights = torch.max(
                    attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min).npu(torch.npu.current_device())   #  .npu replaces .cuda
                )
      • Change 3: use double precision to work around a torch precision bug on the ARM architecture. A small comparison sketch follows the modified code.

        Before:

        class LlamaRotaryEmbedding(torch.nn.Module):
            def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
                super().__init__()
                inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))    
                self.register_buffer("inv_freq", inv_freq)

        After:

        class LlamaRotaryEmbedding(torch.nn.Module):
            def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
                super().__init__()
                inv_freq = 1.0 / (torch.tensor(base).double() ** (torch.arange(0, dim, 2).float().to(device) / dim).double())    # without this change the result is wrong on ARM
                self.register_buffer("inv_freq", inv_freq)
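
        To see why the exponentiation is routed through double precision, the following CPU-only comparison sketch (illustrative, not from the source) contrasts the float32 path with the float64 workaround; on ARM the float32 path is the one reported to produce wrong inv_freq values.

        import torch

        dim, base = 128, 10000
        exponent = torch.arange(0, dim, 2).float() / dim

        inv_freq_f32 = 1.0 / (base ** exponent)                                   # original float32 path
        inv_freq_f64 = 1.0 / (torch.tensor(base).double() ** exponent.double())   # workaround from this guide

        # On x86 the two agree closely; a large gap here would indicate the precision issue.
        print((inv_freq_f32.double() - inv_freq_f64).abs().max())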
        
      • Change 4: cast float32 to float16 to speed up computation. A quick numerical check follows the modified code.

        Before:

            def forward(self, hidden_states):
                input_dtype = hidden_states.dtype
                variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
                hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        
                return (self.weight * hidden_states).to(input_dtype)

        After:

            def forward(self, hidden_states):
                input_dtype = hidden_states.dtype
                variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True).half()
                hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        
                return (self.weight * hidden_states).to(input_dtype)
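
        A quick CPU-only numerical probe (illustrative; the eps value and random data are assumptions) of how much rounding the variance to float16 changes the RMSNorm output relative to the all-float32 variance. It only models the rounding introduced by .half(); whether the difference is acceptable should be judged against the model's accuracy targets.

        import torch

        eps = 1e-6
        weight = torch.ones(4096, dtype=torch.float16)
        hidden = torch.randn(2, 16, 4096, dtype=torch.float16)

        # original path: variance kept in float32
        var32 = hidden.to(torch.float32).pow(2).mean(-1, keepdim=True)
        out32 = (weight * (hidden * torch.rsqrt(var32 + eps))).to(hidden.dtype)

        # modified path: variance rounded to float16 first (as in the change above),
        # then promoted back to float32 so the comparison runs on any CPU build
        var16 = hidden.to(torch.float32).pow(2).mean(-1, keepdim=True).half()
        out16 = (weight * (hidden * torch.rsqrt(var16.float() + eps))).to(hidden.dtype)

        print("max abs diff:", (out32.float() - out16.float()).abs().max().item())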
      • Change 5: simplify the attention-mask computation. The basic idea is to replace the additive 0 / -65504 mask with a 0 / 1 (False / True) mask, which requires less computation, and to feed it to the fused scaled-masked-softmax operator (see Change 6). A toy example of the new masks follows the modified _expand_mask code.

        Before:

        def _make_causal_mask(
            input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
        ):
            """
            Make causal mask used for bi-directional self-attention.
            """
            bsz, tgt_len = input_ids_shape
            mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
            mask_cond = torch.arange(mask.size(-1), device=device)
            mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)               
            mask = mask.to(dtype)
        

        After:

        def _make_causal_mask(
            input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
        ):
            """
            Make causal mask used for bi-directional self-attention.
            """
            bsz, tgt_len = input_ids_shape
            mask = torch.zeros((tgt_len, tgt_len), device=device)                               
            mask_cond = torch.arange(mask.size(-1), device=device)
            mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 1)
            mask = mask.to(dtype)

        Modify the def _expand_mask function.

        Before:

        def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
            """
            Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
            """
            bsz, src_len = mask.size()
            tgt_len = tgt_len if tgt_len is not None else src_len
        
            expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
        
            inverted_mask = 1.0 - expanded_mask
        
            return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

        After:

        def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
            """
            Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
            """
            bsz, src_len = mask.size()
            tgt_len = tgt_len if tgt_len is not None else src_len
        
            expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
        
            # inverted_mask = 1.0 - expanded_mask
        
            return expanded_mask
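
        To see what the reworked masks contain, here is a toy CPU example (illustrative; it re-implements the two modified functions inline rather than importing them) for tgt_len = 3 with the last position padded.

        import torch

        tgt_len = 3

        # modified _make_causal_mask: 1 where a position may be attended, 0 above the diagonal
        causal = torch.zeros(tgt_len, tgt_len)
        cond = torch.arange(tgt_len)
        causal.masked_fill_(cond < (cond + 1).view(tgt_len, 1), 1)
        print(causal)
        # tensor([[1., 0., 0.],
        #         [1., 1., 0.],
        #         [1., 1., 1.]])

        # modified _expand_mask: the padding mask (1 = real token, 0 = padding) is expanded as-is
        attention_mask = torch.tensor([[1., 1., 0.]])
        expanded = attention_mask[:, None, None, :].expand(1, 1, tgt_len, tgt_len)
        print(expanded[0, 0])
        # tensor([[1., 1., 0.],
        #         [1., 1., 0.],
        #         [1., 1., 0.]])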
      • Change 6: use the fused softmax operator to replace the separate scale, mask, and softmax steps. A CPU reference check follows the modified code.

        Before:

            past_key_value = (key_states, value_states) if use_cache else None
        
            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
            attn_weights = attn_weights + attention_mask
            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        After:

            past_key_value = (key_states, value_states) if use_cache else None
        
            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
            attn_weights = torch_npu.npu_scaled_masked_softmax(attn_weights, attention_mask, (1 / math.sqrt(self.head_dim)))
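
        On a host without an NPU, a plain-PyTorch reference of what the fused call is expected to compute can still be useful for unit tests. The sketch below assumes npu_scaled_masked_softmax computes softmax(scale * x) with positions where the boolean mask is True excluded; treat that reading of the operator's semantics as an assumption to be verified against the torch_npu documentation, and on an NPU host compare the reference output element-wise with the fused call above.

        import math
        import torch

        def ref_scaled_masked_softmax(attn_weights, bool_mask, scale):
            # bool_mask: True marks positions that must not be attended to
            scores = attn_weights * scale
            scores = scores.masked_fill(bool_mask, torch.finfo(scores.dtype).min)
            return torch.softmax(scores, dim=-1)

        head_dim = 64
        attn_weights = torch.randn(1, 1, 4, 4)
        causal_mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)[None, None]

        print(ref_scaled_masked_softmax(attn_weights, causal_mask, 1 / math.sqrt(head_dim)))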
      • Change 7: modify the def _prepare_decoder_attention_mask function. A toy illustration of the resulting boolean mask follows the modified code.
        Before:
            def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        
                combined_attention_mask = None
                if input_shape[-1] > 1:
                    combined_attention_mask = _make_causal_mask(
                        input_shape,
                        inputs_embeds.dtype,
                        device=inputs_embeds.device,
                        past_key_values_length=past_key_values_length,
                    )
        
                if attention_mask is not None:
                    expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                        inputs_embeds.device
                    )
                    combined_attention_mask = (
                        expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
                    )
        
                return combined_attention_mask

        After:

            def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
                combined_attention_mask = None
                if input_shape[-1] > 1:
                    combined_attention_mask = _make_causal_mask(
                        input_shape,
                        inputs_embeds.dtype,
                        device=inputs_embeds.device,
                        past_key_values_length=past_key_values_length,
                    )
        
                if attention_mask is not None:
                    expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                        inputs_embeds.device
                    )
                    combined_attention_mask = expanded_attn_mask + combined_attention_mask
        
                return combined_attention_mask <= 1.                     # boolean mask: True marks positions to mask out
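
        The trailing "<= 1." converts the summed 0/1 masks into the boolean form used by the fused operator: a position is kept only when it is both a real (non-padding) token and inside the causal window, i.e. when the sum equals 2. A toy illustration (CPU, reusing the values from the Change 5 example):

        import torch

        causal = torch.tensor([[1., 0., 0.],
                               [1., 1., 0.],
                               [1., 1., 1.]])
        padding = torch.tensor([[1., 1., 0.],
                                [1., 1., 0.],
                                [1., 1., 0.]])   # last column is a padded token

        combined = padding + causal
        print(combined <= 1.)
        # tensor([[False,  True,  True],
        #         [False, False,  True],
        #         [False, False,  True]])
        # True  -> masked out (padding or a future position)
        # False -> attended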