cd FastChat-76f0424d1add61aadc8e5bdeed5ebe540f266ba3
vi fastchat/train/train_mem.py
Before:
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.

# Need to call this before importing transformers.
from fastchat.train.llama_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
)

replace_llama_attn_with_flash_attn()

from fastchat.train.train import train

if __name__ == "__main__":
    train()
After:
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.

# Need to call this before importing transformers.
# from fastchat.train.llama_flash_attn_monkey_patch import (
#     replace_llama_attn_with_flash_attn,
# )  # disable Flash Attention

# replace_llama_attn_with_flash_attn()

import os
import torch
import torch_npu      # make torch support the NPU
import deepspeed
import deepspeed_npu  # make deepspeed support the NPU
from torch_npu.contrib import transfer_to_npu

from fastchat.train.train import train

os.environ["WANDB_MODE"] = "offline"
os.environ["WANDB_DISABLED"] = "TRUE"  # disable WANDB training logging

if __name__ == "__main__":
    torch.npu.set_compile_mode(jit_compile=True)  # set static compile mode
    deepspeed.init_distributed('hccl')            # initialize deepspeed with the HCCL backend
    train()
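As a quick sanity check before launching training (a minimal sketch, not part of FastChat; it only assumes the torch_npu and deepspeed_npu packages installed for this port), the following confirms that the NPU backend is actually visible to PyTorch:

import torch
import torch_npu                                # registers the torch.npu device backend
import deepspeed
import deepspeed_npu                            # patches DeepSpeed for Ascend/HCCL
from torch_npu.contrib import transfer_to_npu   # transparently redirects .cuda() calls to .npu()

print(torch.npu.is_available())   # expect True on a machine with Ascend NPUs
print(torch.npu.device_count())   # number of visible NPU devices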
vi fastchat/serve/inference.py
After:
if (device == "cuda" and num_gpus == 1) or device == "mps":
    model.npu(device)  # .npu replaces .cuda
Before:
@torch.inference_mode()
def generate_stream(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
After:
@torch.no_grad()  # equivalent replacement
def generate_stream(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
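For this code path the two decorators behave the same: both disable autograd tracking during generation, which is all generate_stream relies on (inference_mode additionally forbids later autograd use of the outputs, which is not needed here). A minimal illustration, independent of FastChat:

import torch

x = torch.ones(3, requires_grad=True)

with torch.no_grad():
    y = x * 2              # no autograd graph is recorded

print(y.requires_grad)     # False, same result as under torch.inference_mode()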
After:
out = model(torch.as_tensor(input_ids).npu(device), use_cache=True)  # .npu replaces .cuda
After:
input_ids=torch.as_tensor([[token]]).npu(device),  # .npu replaces .cuda
Comment out the fp16/CUDA device check in the def __post_init__(self): method, i.e. comment out the following code.
if (
    self.framework == "pt"
    and is_torch_available()
    and (self.device.type != "cuda")
    and (get_xla_device_type(self.device) != "GPU")
    and (self.fp16 or self.fp16_full_eval)
):
    raise ValueError(
        "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
        " (`--fp16_full_eval`) can only be used on CUDA devices."
    )
Comment out the version check in the _compare_versions function, i.e. comment out the following code.
if not ops[op](version.parse(got_ver), version.parse(want_ver)):
    raise ImportError(
        f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}"
    )
Before:
if torch.cuda.is_available():
    if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
        # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
        rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
    else:
        rng_states["cuda"] = torch.cuda.random.get_rng_state()
After:
if torch.cuda.is_available():
    if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
        # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
        rng_states["cuda"] = torch.npu.random.get_rng_state_all()  # .npu replaces .cuda
    else:
        rng_states["cuda"] = torch.npu.random.get_rng_state()  # .npu replaces .cuda
Before:
if torch.cuda.is_available():
    if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
        torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
    else:
        try:
            torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
        except Exception as e:
After:
if torch.cuda.is_available():
    if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
        torch.npu.random.set_rng_state_all(checkpoint_rng_state["cuda"])  # .npu replaces .cuda
    else:
        try:
            torch.npu.random.set_rng_state(checkpoint_rng_state["cuda"])  # .npu replaces .cuda
        except Exception as e:
Before:
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
After:
import torch
import torch_npu
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
Before:
if attention_mask is not None:
    if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
        raise ValueError(
            f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
        )
    attn_weights = attn_weights + attention_mask
    attn_weights = torch.max(
        attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
    )
After:
if attention_mask is not None:
    if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
        raise ValueError(
            f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
        )
    attn_weights = attn_weights + attention_mask
    attn_weights = torch.max(
        attn_weights,
        torch.tensor(torch.finfo(attn_weights.dtype).min).npu(torch.npu.current_device()),  # .npu replaces .cuda
    )
Before:
class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)
After:
class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (
            torch.tensor(base).double()
            ** (torch.arange(0, dim, 2).float().to(device) / dim).double()
        )  # without this change the result is computed incorrectly on ARM
        self.register_buffer("inv_freq", inv_freq)
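The only change here is carrying out the exponentiation for the rotary inverse frequencies in float64 before the division. A small sketch of the two expressions side by side (dim=128 and base=10000 are the LLaMA defaults; run on CPU just to see what each computes):

import torch

dim, base = 128, 10000

# original: integer base raised to a float32 exponent
inv_freq_fp32 = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

# patched: exponentiation carried out in float64, as above
inv_freq_fp64 = 1.0 / (
    torch.tensor(base).double() ** (torch.arange(0, dim, 2).float() / dim).double()
)

print((inv_freq_fp32.double() - inv_freq_fp64).abs().max())  # the float32 rounding error being avoided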
Before:
def forward(self, hidden_states):
    input_dtype = hidden_states.dtype
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    return (self.weight * hidden_states).to(input_dtype)
After:
def forward(self, hidden_states):
    input_dtype = hidden_states.dtype
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True).half()
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    return (self.weight * hidden_states).to(input_dtype)
Before:
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)
After:
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.zeros((tgt_len, tgt_len), device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 1)
    mask = mask.to(dtype)
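After this change _make_causal_mask no longer returns an additive mask of 0 / dtype-min values; it returns a 0/1 mask in which 1 marks the positions a query token may attend to. A small sketch of what it produces for tgt_len = 4 (ignoring past_key_values_length):

import torch

tgt_len = 4
mask = torch.zeros((tgt_len, tgt_len))
mask_cond = torch.arange(tgt_len)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 1)
print(mask)
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])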
Modify the _expand_mask function.
Before:
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
After:
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    # inverted_mask = 1.0 - expanded_mask

    return expanded_mask
Before:
past_key_value = (key_states, value_states) if use_cache else None

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

attn_weights = attn_weights + attention_mask
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
After:
past_key_value = (key_states, value_states) if use_cache else None

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
attn_weights = torch_npu.npu_scaled_masked_softmax(attn_weights, attention_mask, (1 / math.sqrt(self.head_dim)))
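torch_npu.npu_scaled_masked_softmax fuses the scaling, masking and softmax of the attention scores into a single NPU kernel. As a hedged reference only (assuming the mask argument is boolean with True at the positions to be masked out, which matches the `<= 1.` mask construction shown below), it computes roughly the following:

import torch

def scaled_masked_softmax_ref(attn_weights, bool_mask, scale):
    # reference sketch of the fused kernel: scale the raw scores,
    # suppress masked positions, then softmax over the key dimension
    scores = attn_weights * scale
    scores = scores.masked_fill(bool_mask, torch.finfo(scores.dtype).min)
    return torch.softmax(scores, dim=-1)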
Before:
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
    combined_attention_mask = None
    if input_shape[-1] > 1:
        combined_attention_mask = _make_causal_mask(
            input_shape,
            inputs_embeds.dtype,
            device=inputs_embeds.device,
            past_key_values_length=past_key_values_length,
        )

    if attention_mask is not None:
        expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
            inputs_embeds.device
        )
        combined_attention_mask = (
            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
        )

    return combined_attention_mask
After:
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
    combined_attention_mask = None
    if input_shape[-1] > 1:
        combined_attention_mask = _make_causal_mask(
            input_shape,
            inputs_embeds.dtype,
            device=inputs_embeds.device,
            past_key_values_length=past_key_values_length,
        )

    if attention_mask is not None:
        expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
            inputs_embeds.device
        )
        combined_attention_mask = expanded_attn_mask + combined_attention_mask

    return combined_attention_mask <= 1.
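Why returning `combined_attention_mask <= 1.` works: the patched causal mask is 1 where attending is allowed and the patched padding mask is 1 where a token is real, so their sum equals 2 exactly at positions that are both causally visible and unpadded. Everything else is <= 1 and must be suppressed, so the comparison yields the boolean "mask out" tensor consumed by npu_scaled_masked_softmax above. A small sketch with a sequence of three tokens whose last token is padding:

import torch

causal = torch.tensor([[1., 0., 0.],
                       [1., 1., 0.],
                       [1., 1., 1.]])    # patched _make_causal_mask: 1 = causally visible
padding = torch.tensor([[1., 1., 0.]])   # patched _expand_mask: 1 = real token, last one is padding

combined = padding.expand(3, -1) + causal
print(combined <= 1.)                    # True = position must be masked out
# tensor([[False,  True,  True],
#         [False, False,  True],
#         [False, False,  True]])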