cd ascendspeed/model/
vi internlm_model.py
from ascendspeed.model import llama_model

RotaryEmbedding = llama_model.RotaryEmbedding
apply_rotary_pos_emb = llama_model.apply_rotary_pos_emb
llama_model.LlamaParallelAttention = InternLMParallelAttention
InternModel = llama_model.LlamaModel
InternModelPipe = llama_model.LlamaModelPipe
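These assignments patch AscendSpeed's LLaMA implementation in place: LlamaModel / LlamaModelPipe are reused unchanged, and only the attention class they resolve through the llama_model module is swapped for InternLMParallelAttention. A toy sketch of the pattern (stand-in classes only, not the real AscendSpeed modules):

# Toy sketch of the monkey-patching pattern used above; the classes here are
# stand-ins, not the real AscendSpeed modules.
import types

llama_model = types.SimpleNamespace()

class LlamaParallelAttention:
    """Stand-in for the original attention class used by LlamaModel."""

class LlamaModel:
    def __init__(self):
        # The attention class is looked up through the module attribute at
        # build time, so rebinding that attribute changes what gets constructed.
        self.attention = llama_model.LlamaParallelAttention()

llama_model.LlamaParallelAttention = LlamaParallelAttention
llama_model.LlamaModel = LlamaModel

class InternLMParallelAttention:
    """Stand-in for the replacement attention (bias=True projections)."""

# Swap the attention implementation, then reuse the LLaMA model class unchanged.
llama_model.LlamaParallelAttention = InternLMParallelAttention
InternModel = llama_model.LlamaModel

print(type(InternModel().attention).__name__)  # InternLMParallelAttention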
class MHA(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        process_group: Optional[torch.distributed.ProcessGroup],
        dropout: float = 0.0,
        softmax_scale: float = None,
        causal: bool = False,
        layer_idx: int = None,
        rotary_emb_dim: int = 0,
        rotary_emb_scale_base: int = 0,
        use_flash_attn: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads

        if self.rotary_emb_dim > 0:
            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)

        # note: bias needs to be changed to True here
        self.Wqkv = ColumnParallelLinearTorch(
            embed_dim,
            3 * embed_dim,
            process_group,
            bias=True,
            sequence_parallel=gpc.config.parallel.sequence_parallel,
            **factory_kwargs,
        )  # according to https://spaces.ac.cn/archives/9577

        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )

        # output projection always has the bias (for now)
        self.out_proj = RowParallelLinearTorch(
            embed_dim,
            embed_dim,
            process_group,
            sequence_parallel=gpc.config.parallel.sequence_parallel,
            **factory_kwargs,
        )
        # need to assign the tp attribute so that internlm knows it is a tensor parallel module
        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
            for name in ["out_proj", "Wqkv"]:
                for param in getattr(self, name).parameters():
                    setattr(param, IS_TENSOR_PARALLEL, True)

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        if kwargs.get("indexes", None) is not None:
            return self._packed_forward(x=x, inference_params=inference_params, **kwargs)
        else:
            return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs)

    def _forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
        """
        qkv = self.Wqkv(x)
        if seqlen is None:
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
        else:
            qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)

        if self.rotary_emb_dim > 0:
            kwargs["inference_params"] = inference_params
            qkv = self.rotary_emb(qkv, **kwargs)

        if inference_params is None:
            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
                        qkv = qkv.to(torch.bfloat16)
                    context = self.inner_attn(qkv).to(x.dtype)
            else:
                context = self.inner_attn(qkv)
        else:
            q = qkv[:, :, 0]
            assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
            kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
            # If we're processing the prompt, causal=None (use self.causal).
            # If we're decoding, then causal=False.
            causal = None if inference_params.sequence_len_offset == 0 else False
            context = self.inner_cross_attn(q, kv, causal=causal)

        if seqlen is None:
            context = rearrange(context, "b s h d -> b s (h d)")
        else:
            context = rearrange(context, "b s h d -> (b s) (h d)")

        out = self.out_proj(context)
        return out

    def _packed_forward(self, x, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
        """
        qkv = self.Wqkv(x)  # total x hsz'
        qkv = rearrange(qkv, "t (three h d) -> t three h d", three=3, d=self.head_dim)  # total x 3 x n_head x d
        qkv = self.rotary_emb(qkv, **kwargs)
        kwargs.pop("indexes")

        if inference_params is None:
            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
                        qkv = qkv.to(torch.bfloat16)
                    context = self.inner_attn(qkv, **kwargs).to(x.dtype)
            else:
                context = self.inner_attn(qkv, **kwargs)
        else:
            raise RuntimeError("Not support this right now")

        context = rearrange(context, "b h d -> b (h d)")  # recover the shape
        out = self.out_proj(context)
        return out
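For reference, the shape handling in MHA._forward relies on einops: the fused QKV projection output of shape (batch, seqlen, 3 * num_heads * head_dim) is unpacked into (batch, seqlen, 3, num_heads, head_dim) before attention, and indexing the third dimension yields q, k and v. A small standalone check of that rearrange with toy sizes (assumes torch and einops are installed):

# Standalone check of the qkv rearrange used in MHA._forward (toy sizes).
import torch
from einops import rearrange

batch, seqlen, num_heads, head_dim = 2, 8, 4, 16
qkv = torch.randn(batch, seqlen, 3 * num_heads * head_dim)

# (b, s, 3*h*d) -> (b, s, 3, h, d): index 0/1/2 on dim 2 gives q, k, v
qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=head_dim)
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
assert q.shape == (batch, seqlen, num_heads, head_dim)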
The modified code is as follows:
class InternLMParallelAttention(MegatronModule):

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.causal):
        super(InternLMParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16
        self.sequence_parallel = args.sequence_parallel
        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.init_method = init_method
        self.output_layer_init_method = output_layer_init_method
        self.num_attention_heads = args.num_attention_heads
        projection_size = args.kv_channels * args.num_attention_heads

        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = utils.divide(projection_size, world_size)
        self.hidden_size_per_attention_head = utils.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = utils.divide(
            args.num_attention_heads, world_size)

        if attention_type == AttnType.self_attn:
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                bias=True,
                gather_output=False,
                init_method=self.init_method,
                sequence_parallel_enabled=self.sequence_parallel)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = NPUFusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff * (1.0 / self.norm_factor))

        self.rotary_emb = RotaryEmbedding(self.hidden_size_per_attention_head)

        self.use_triangle_attn = args.triangle_attn
        if self.use_triangle_attn:
            self.triangle_attn = TriangleAttention(
                block_size=1024,
                masked_softmax_func=self.scale_mask_softmax)

        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            bias=True,
            input_is_parallel=True,
            init_method=self.output_layer_init_method,
            skip_bias_add=False,
            sequence_parallel_enabled=self.sequence_parallel)

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
        if self.attention_type == AttnType.self_attn:
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            (query_layer,
             key_layer,
             value_layer) = utils.split_tensor_along_last_dim(mixed_x_layer, 3)

            query_layer = query_layer.permute(1, 2, 0, 3).contiguous()
            key_layer = key_layer.permute(1, 2, 0, 3).contiguous()
            value_layer = value_layer.permute(1, 2, 0, 3).contiguous()

            cos, sin = self.rotary_emb(value_layer, seq_len=new_tensor_shape[0])
            query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer,
                                                          cos, sin, offset=0)

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        if self.use_triangle_attn and layer_past is None:
            context_layer = self.triangle_attn(query_layer, key_layer,
                                               value_layer, attention_mask)
            output, _ = self.dense(context_layer)
            return output

        attention_scores = torch.matmul(query_layer, key_layer.transpose(3, 2))

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)

        context_layer = torch.matmul(attention_probs, value_layer)
        bs, nh, sq, hd = context_layer.shape
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
        context_layer = context_layer.view(sq, bs, nh * hd)

        output, _ = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output
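Both versions apply rotary position embedding to the query and key: RotaryEmbedding produces per-position cos/sin tables and apply_rotary_pos_emb rotates each channel pair of q and k by an angle proportional to the token position. The sketch below is a generic, self-contained illustration of that computation; the helper names and tensor layout are chosen here for clarity and are not the exact AscendSpeed signatures.

# Generic rotary position embedding sketch (illustrative layout only, not the
# exact AscendSpeed RotaryEmbedding / apply_rotary_pos_emb signatures).
import torch

def build_cos_sin(seq_len, dim, base=10000.0):
    # one frequency per channel pair, one angle per (position, frequency)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len).float()
    freqs = torch.outer(t, inv_freq)            # (seq_len, dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)     # (seq_len, dim)
    return emb.cos(), emb.sin()

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin):
    # q, k: (batch, heads, seq_len, head_dim); cos/sin broadcast over batch/heads
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 4, 8, 16)
cos, sin = build_cos_sin(seq_len=8, dim=16)
q_rot, k_rot = apply_rope(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape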
from ascendspeed.training import pretrain
from pretrain_llama import (
    data_post_process,
    forward_step,
    train_valid_test_datasets_provider
)
def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building InternLM model ...')
    see_memory_usage(f"Before Building Model", force=True)

    args = get_args()
    with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
                             remote_device=None if args.remote_device == 'none' else args.remote_device,
                             config_dict_or_path=args.deepspeed_config,
                             enabled=args.zero_stage == 3,
                             mpu=mpu):
        if args.deepspeed and not args.no_pipeline_parallel:
            model = InternModelPipe(parallel_output=True)
            attention_mask = torch.tril(torch.ones(
                (1, args.seq_length, args.seq_length),
                device=get_accelerator().current_device_name())).view(
                    1, 1, args.seq_length, args.seq_length)

            # Convert attention mask to binary:
            attention_mask = (attention_mask < 0.5)
            if args.fp16:
                attention_mask = attention_mask.half()
            elif args.bf16:
                attention_mask = attention_mask.bfloat16()

            # Attention mask must be bool.
            args.attn_mask = attention_mask.to(torch.bool)
        else:
            model = InternModel(
                parallel_output=True,
                add_pooler=False,
                pre_process=pre_process,
                post_process=post_process
            )
    see_memory_usage(f"After Building Model", force=True)
    return model
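In model_provider the causal mask is built once from torch.tril and then inverted with < 0.5, so True marks the future positions that must be masked out. The same construction can be checked standalone (CPU, toy sequence length):

# Standalone check of the attention-mask construction in model_provider.
import torch

seq_length = 4
attention_mask = torch.tril(torch.ones(
    (1, seq_length, seq_length))).view(1, 1, seq_length, seq_length)

# after inversion, True == "masked" (strictly above the diagonal, i.e. future tokens)
attention_mask = (attention_mask < 0.5).to(torch.bool)
print(attention_mask[0, 0].int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int32)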
if __name__ == "__main__":
    torch.npu.set_compile_mode(jit_compile=True)
    pretrain(train_valid_test_datasets_provider,
             model_provider,
             forward_step,
             args_defaults={'tokenizer_type': 'PretrainedFromHF'},
             data_post_process=data_post_process)