Large Model Migration

  1. Create the ascendspeed/model/internlm_model.py file.

    cd ascendspeed/model/
    vi internlm_model.py

  2. In internlm_model.py, import the structures shared with the LLaMA model, as shown below.

    from ascendspeed.model import llama_model

    # Reuse the rotary position embedding implementation from the LLaMA model.
    RotaryEmbedding = llama_model.RotaryEmbedding
    apply_rotary_pos_emb = llama_model.apply_rotary_pos_emb

    # Note: the following three assignments must appear after InternLMParallelAttention
    # is defined (step 3). They swap the LLaMA attention for the InternLM version and
    # expose the patched model classes under InternLM names.
    llama_model.LlamaParallelAttention = InternLMParallelAttention
    InternModel = llama_model.LlamaModel
    InternModelPipe = llama_model.LlamaModelPipe
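
    The attention class migrated in step 3 additionally references several AscendSpeed/Megatron utilities that are not covered by the shared structures above. The import sketch below is illustrative only: the module paths are assumptions modeled on the LLaMA implementation and should be checked against ascendspeed/model/llama_model.py in your AscendSpeed tree.

    # Illustrative import sketch -- verify every path against llama_model.py in your version.
    import math

    import torch

    from ascendspeed import get_args
    from ascendspeed import mpu
    from ascendspeed.core import parallel_state, utils  # must provide divide and split_tensor_along_last_dim
    from ascendspeed.model.module import MegatronModule
    from ascendspeed.model.enums import AttnMaskType, AttnType
    from ascendspeed.model.utils import attention_mask_func
    from ascendspeed.model.fused_softmax import NPUFusedScaleMaskSoftmax
    from ascendspeed.model.triangle_attention import TriangleAttention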

  3. Migrate the attention module code of the InternLM model.

    Before modification:
    class MHA(nn.Module):                                                                                                                    
        def __init__(                                                                                         
            self,                                                                                           
            embed_dim: int,                                                                                   
            num_heads: int,                                                                                    
            process_group: Optional[torch.distributed.ProcessGroup],                                     
            dropout: float = 0.0,                                                                              
            softmax_scale: float = None,                                                                          
            causal: bool = False,                                                                         
            layer_idx: int = None,                                                                            
            rotary_emb_dim: int = 0,                                                                        
            rotary_emb_scale_base: int = 0,                                                               
            use_flash_attn: bool = True,                                                                
            device: Optional[torch.device] = None,                                                         
            dtype: Optional[torch.dtype] = None,                                                             
        ) -> None:                                                                                             
            factory_kwargs = {"device": device, "dtype": dtype}                                       
            super().__init__()                                                                                  
            self.embed_dim = embed_dim                                                                         
            self.causal = causal                                                                                                
            self.layer_idx = layer_idx                                                                  
            self.rotary_emb_dim = rotary_emb_dim                                                         
            self.use_flash_attn = use_flash_attn                                                     
            self.num_heads = num_heads                                                                                 
            assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"                                                                                           
            self.head_dim = self.embed_dim // num_heads                                              
            if self.rotary_emb_dim > 0:                                                                   
                self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)                                             
            # notice here should change bias=True                                                 
            self.Wqkv = ColumnParallelLinearTorch(                                                   
                embed_dim,                                                                              
                3 * embed_dim,                                                                         
                process_group,                                                                         
                bias=True,                                                                                  
                sequence_parallel=gpc.config.parallel.sequence_parallel,                      
                **factory_kwargs,                                                                      
            )  # according to https://spaces.ac.cn/archives/9577
            inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention      
            inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention                                                                                     
            self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)                                                                      
            self.inner_cross_attn = inner_cross_attn_cls(                                       
                causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout      
            )                                                                                              
            # output projection always have the bias (for now)                                 
            self.out_proj = RowParallelLinearTorch(                                               
                embed_dim,                                                                              
                embed_dim,                                                                              
                process_group,                                                                          
                sequence_parallel=gpc.config.parallel.sequence_parallel,                      
                **factory_kwargs,                                                                     
            )                                                                                             
            # need to assign tp attribute so that internlm know it is tensor parallel module                                                                                               
            if gpc.get_world_size(ParallelMode.TENSOR) > 1:                                     
                for name in ["out_proj", "Wqkv"]:                                                  
                    for param in getattr(self, name).parameters():                              
                        setattr(param, IS_TENSOR_PARALLEL, True)                                  
        def forward(self, x, seqlen=None, inference_params=None, **kwargs):                 
            if kwargs.get("indexes", None) is not None:                                          
                return self._packed_forward(x=x, inference_params=inference_params, **kwargs)                                                                                           
            else:                                                                                            
                return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs)                                                 
        def _forward(self, x, seqlen=None, inference_params=None, **kwargs):               
            """                                                                                          
            Arguments:                                                                                   
                x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.                                                                                     
                    If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we                                                                                       
                    split x during sequence parallel, we split the batch * seqlen dimension                                                                                           
                    (in case batch is small).                                                        
            """                                                                                           
            qkv = self.Wqkv(x)                                                                         
            if seqlen is None:                                                                        
                qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)                                                                                   
            else:                                                                                        
                qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)                                                                           
            if self.rotary_emb_dim > 0:                                                             
                kwargs["inference_params"] = inference_params                                   
                qkv = self.rotary_emb(qkv, **kwargs)                                              
            if inference_params is None:                                                            
                if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:                                                               
                    with torch.cuda.amp.autocast(dtype=torch.bfloat16):                        
                        if qkv.dtype not in [torch.float16, torch.bfloat16]:                   
                            qkv = qkv.to(torch.bfloat16)                                               
                        context = self.inner_attn(qkv).to(x.dtype)                               
                else:                                                                                     
                    context = self.inner_attn(qkv)                                                   
            else:                                                                                            
                q = qkv[:, :, 0]                                                                       
                assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"                                                                                       
                kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)     
                # If we're processing the prompt, causal=None (use self.causal).            
                # If we're decoding, then causal=False.                                            
                causal = None if inference_params.sequence_len_offset == 0 else False        
                context = self.inner_cross_attn(q, kv, causal=causal)                            
            if seqlen is None:                                                                         
                context = rearrange(context, "b s h d -> b s (h d)")                           
            else:                                                                                           
                context = rearrange(context, "b s h d -> (b s) (h d)")                        
            out = self.out_proj(context)                                                             
            return out                                                                                  
        def _packed_forward(self, x, inference_params=None, **kwargs):                       
            """                                                                                          
            Arguments:                                                                                  
                x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.                                                                                      
                    If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we                                                                                              
                    split x during sequence parallel, we split the batch * seqlen dimension                                                                                            
                    (in case batch is small).                                                        
            """                                                                                          
            qkv = self.Wqkv(x)  # total x hsz'                                                     
            qkv = rearrange(qkv, "t (three h d) -> t three h d", three=3, d=self.head_dim)  # total x 3 x n_head x d                                                     
            qkv = self.rotary_emb(qkv, **kwargs)                                                   
            kwargs.pop("indexes")                                                                      
            if inference_params is None:                                                            
                if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:                                                                
                    with torch.cuda.amp.autocast(dtype=torch.bfloat16):                         
                        if qkv.dtype not in [torch.float16, torch.bfloat16]:                   
                            qkv = qkv.to(torch.bfloat16)                                              
                        context = self.inner_attn(qkv, **kwargs).to(x.dtype)                   
                else:                                                                                    
                    context = self.inner_attn(qkv, **kwargs)                                        
            else:                                                                                        
                raise RuntimeError("Not support this right now")                                   
            context = rearrange(context, "b h d -> b (h d)")  # recover the shape           
            out = self.out_proj(context)                                                            
            return out 

    After modification:

    class InternLMParallelAttention(MegatronModule):                                            
        def __init__(self, init_method,                                                           
                     output_layer_init_method, layer_number,                                  
                     attention_type=AttnType.self_attn,                                         
                     attn_mask_type=AttnMaskType.causal):                                      
            super(InternLMParallelAttention, self).__init__()                                 
            args = get_args()                                                                        
            self.fp16 = args.fp16                                                                                    
            self.bf16 = args.bf16                                                                   
            self.sequence_parallel = args.sequence_parallel                                   
            self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling      
            self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32                   
            if self.apply_query_key_layer_scaling:                                               
                self.attention_softmax_in_fp32 = True                                          
            self.layer_number = max(1, layer_number)                                            
            self.attention_type = attention_type                                                
            self.attn_mask_type = attn_mask_type                                                   
            self.init_method = init_method                                                        
            self.output_layer_init_method = output_layer_init_method                         
            self.num_attention_heads = args.num_attention_heads                               
            projection_size = args.kv_channels * args.num_attention_heads                  
            world_size = parallel_state.get_tensor_model_parallel_world_size()                
            self.hidden_size_per_partition = utils.divide(projection_size, world_size)   
            self.hidden_size_per_attention_head = utils.divide(                               
                projection_size, args.num_attention_heads)                                       
            self.num_attention_heads_per_partition = utils.divide(                           
                args.num_attention_heads, world_size)                                             
            if attention_type == AttnType.self_attn:
                # InternLM keeps a bias on the QKV projection (bias=True), unlike LLaMA.
                self.query_key_value = mpu.ColumnParallelLinear(
                    args.hidden_size, 3 * projection_size, bias=True, gather_output=False,
                    init_method=self.init_method,
                    sequence_parallel_enabled=self.sequence_parallel)
            coeff = None                                                                              
            self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)             
            if self.apply_query_key_layer_scaling:                                               
                coeff = self.layer_number                                                          
                self.norm_factor *= coeff                                                         
            # Guard against coeff being None when query-key layer scaling is disabled;
            # both branches reduce to a 1/sqrt(head_dim) softmax scale.
            softmax_scale = (1.0 / self.norm_factor) if coeff is None else coeff * (1.0 / self.norm_factor)
            self.scale_mask_softmax = NPUFusedScaleMaskSoftmax(
                self.fp16, self.bf16, self.attn_mask_type, args.masked_softmax_fusion,
                attention_mask_func, self.attention_softmax_in_fp32, softmax_scale)
            self.rotary_emb = RotaryEmbedding(self.hidden_size_per_attention_head)            
            self.use_triangle_attn = args.triangle_attn                                         
            if self.use_triangle_attn:                                                                 
                self.triangle_attn = TriangleAttention(block_size=1024,                           
                                                       masked_softmax_func=self.scale_mask_softmax)                                                                                        
            # The output projection also carries a bias (bias=True), matching InternLM's out_proj.
            self.dense = mpu.RowParallelLinear(
                projection_size, args.hidden_size, bias=True, input_is_parallel=True,
                init_method=self.output_layer_init_method, skip_bias_add=False,
                sequence_parallel_enabled=self.sequence_parallel)
        def forward(self, hidden_states, attention_mask, layer_past=None,                   
                    get_key_value=False):                                                              
            if self.attention_type == AttnType.self_attn:                                       
                mixed_x_layer, _ = self.query_key_value(hidden_states)                       
                new_tensor_shape = mixed_x_layer.size()[:-1] + \                             
                                   (self.num_attention_heads_per_partition,                
                                    3 * self.hidden_size_per_attention_head)                 
                mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)                        
                (query_layer,                                                                        
                 key_layer,                                                                            
                 value_layer) = utils.split_tensor_along_last_dim(mixed_x_layer, 3)       
            query_layer = query_layer.permute(1, 2, 0, 3).contiguous()                      
            key_layer = key_layer.permute(1, 2, 0, 3).contiguous()                           
            value_layer = value_layer.permute(1, 2, 0, 3).contiguous()                      
            cos, sin = self.rotary_emb(value_layer, seq_len=new_tensor_shape[0])            
            query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, offset=0)                                                                                     
            if layer_past is not None:                                                             
                past_key, past_value = layer_past                                                
                key_layer = torch.cat((past_key.type_as(key_layer),                          
                                       key_layer), dim=0)                                      
                value_layer = torch.cat((past_value.type_as(value_layer),                   
                                         value_layer), dim=0)                                 
            if get_key_value:                                                                         
                present = (key_layer, value_layer)                                              
    
            if self.use_triangle_attn and layer_past is None:                                  
                context_layer = self.triangle_attn(query_layer, key_layer, value_layer, attention_mask)                                                                                   
                output, _ = self.dense(context_layer)                                          
                return output                                                                           
            attention_scores = torch.matmul(query_layer, key_layer.transpose(3, 2))      
            if get_key_value:                                                                         
                with torch.no_grad():                                                                
                    if layer_past is not None:                                                   
                        attention_mask = attention_mask[                                       
                                         ...,                                                     
                                         attention_scores.size(3) - 1,                      
                                         :attention_scores.size(3)].unsqueeze(2)          
                    else:                                                                             
                        attention_mask = attention_mask[                                       
                                         ...,                                                    
                                         :attention_scores.size(3),                            
                                         :attention_scores.size(3)]                              
            attention_probs = self.scale_mask_softmax(attention_scores,                     
                                                      attention_mask)                        
            context_layer = torch.matmul(attention_probs, value_layer)                         
            bs, nh, sq, hd = context_layer.shape                                                    
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()                    
            context_layer = context_layer.view(sq, bs, nh * hd)                                
            output, _ = self.dense(context_layer)                                                
            if get_key_value:                                                                         
                output = [output, present]                                                           
            return output   
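
    For reference, the tensor shapes through InternLMParallelAttention.forward in Megatron's [sequence, batch, hidden] layout are sketched below (illustrative only; sq, b, np and hn are shorthand introduced here, not names from the code):

    # Shape walk-through (sq = sequence length, b = micro-batch size,
    # np = attention heads per tensor-parallel partition, hn = hidden size per head):
    #   hidden_states                        [sq, b, h]
    #   query_key_value(hidden_states)       [sq, b, 3 * np * hn]
    #   view + split_tensor_along_last_dim   q, k, v each [sq, b, np, hn]
    #   permute(1, 2, 0, 3)                  q, k, v each [b, np, sq, hn]
    #   attention_scores = q @ k^T           [b, np, sq, sq]
    #   attention_probs @ v                  [b, np, sq, hn]
    #   permute(2, 0, 1, 3) + view           [sq, b, np * hn]
    #   dense(context_layer)                 [sq, b, h]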

  4. Create the pretrain_intern.py entry file and import the data and forward-step modules shared with LLaMA, as shown below:

    from ascendspeed.training import pretrain                                                    
    from pretrain_llama import (                                                                    
        data_post_process,                                                                            
        forward_step,                                                                                   
        train_valid_test_datasets_provider                                                         
    )   
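
    The model_provider function in step 5 and the entry point in step 6 rely on a few more symbols than the LLaMA-shared imports above. A hedged sketch of the remaining imports, assuming the same paths that pretrain_llama.py uses:

    # Illustrative only -- mirror the corresponding imports in pretrain_llama.py.
    import torch
    import deepspeed
    from deepspeed.runtime.utils import see_memory_usage
    from deepspeed.accelerator import get_accelerator

    from ascendspeed import get_args, print_rank_0, mpu
    from ascendspeed.model.internlm_model import InternModel, InternModelPipe
    # torch.npu.set_compile_mode in step 6 requires the torch_npu adapter to be loaded;
    # it is usually pulled in transitively when ascendspeed is imported.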

  5. In the pretrain_intern.py entry file, define the model_provider function that instantiates the InternLM model, as shown below:

    def model_provider(pre_process=True, post_process=True):                                    
        """Build the model."""                                                                          
        print_rank_0('building InternLM model ...')                                               
        see_memory_usage(f"Before Building Model", force=True)                                     
        args = get_args()                                                                                   
        with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),       
                                 remote_device=None if args.remote_device == 'none' else args.remote_device,                                                                               
                                 config_dict_or_path=args.deepspeed_config,                       
                                 enabled=args.zero_stage == 3,                                        
                                 mpu=mpu):                                                              
            if args.deepspeed and not args.no_pipeline_parallel:                               
                model = InternModelPipe(parallel_output=True)                                     
                attention_mask = torch.tril(torch.ones(                                              
                    (1, args.seq_length, args.seq_length), device=get_accelerator().current_device_name())).view(                                     
                    1, 1, args.seq_length, args.seq_length)                                       
                # Convert attention mask to binary:                                                 
                attention_mask = (attention_mask < 0.5)                                            
                if args.fp16:                                                                          
                    attention_mask = attention_mask.half()                                        
                elif args.bf16:                                                                        
                    attention_mask = attention_mask.bfloat16()                                   
                # Attention mask must be bool.                                                      
                args.attn_mask = attention_mask.to(torch.bool)                                     
            else:                                                                                        
                model = InternModel(                                                                  
                    parallel_output=True,                                                              
                    add_pooler=False,                                                                   
                    pre_process=pre_process,                                                         
                    post_process=post_process                                                        
                )                                                                                         
        see_memory_usage(f"After Building Model", force=True)                                       
        return model
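
    The causal mask built above can be checked in isolation: torch.tril gives the lower-triangular pattern and the "< 0.5" comparison inverts it, so True marks the future positions that the softmax blocks. A standalone illustration with a hypothetical seq_length of 4:

    import torch

    seq_length = 4  # illustrative value, not args.seq_length
    mask = torch.tril(torch.ones((1, seq_length, seq_length))).view(1, 1, seq_length, seq_length)
    mask = (mask < 0.5)
    print(mask[0, 0].int())
    # tensor([[0, 1, 1, 1],
    #         [0, 0, 1, 1],
    #         [0, 0, 0, 1],
    #         [0, 0, 0, 0]], dtype=torch.int32)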

  6. In the pretrain_intern.py entry file, define the pretraining entry point as follows:

    if __name__ == "__main__":
        torch.npu.set_compile_mode(jit_compile=True)                                             
        pretrain(train_valid_test_datasets_provider, model_provider, forward_step,           
                 args_defaults={'tokenizer_type': 'PretrainedFromHF'},                          
                 data_post_process=data_post_process)