Adapting the Flash{model}ForCausalLM Class

The Flash{model}ForCausalLM class implements model initialization, weight loading, and forward inference. It lives in "{llm_path}/atb_llm/models/{model}/flash_causal_{model}.py".

Taking InternVL2.5 as an example:

# Excerpt: project-internal imports (logger, ErrorCode, architecture constants,
# weight-loading and parallel-linear helpers, etc.) are omitted for brevity.
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from ..base.flash_causal_multimodal import MultiModalLLm

class FlashInternvlForCausalLM(MultiModalLLm):
    def __init__(self, config, weights, **kwargs):
        super().__init__(config, weights, **kwargs)
        self.config = config
        self.trust_remote_code = kwargs.get('trust_remote_code', False)
        self.weights = weights  # same object as the caller's (reference assignment, identical id)
        self.dtype = weights.dtype
        self.vision_config = config.vision_config
        enable_vit_dp = kwargs.get('enable_vit_dp', True)
        setattr(self.vision_config, 'enable_vit_dp', enable_vit_dp)
        self.llm_config = config.llm_config
        self.llm_config.quantize = None
        # image-preprocessing parameters
        self.downsample_ratio = config.downsample_ratio
        self.vit_hidden_size = self.vision_config.hidden_size
        self.llm_hidden_size = self.llm_config.hidden_size
        self.image_size = config.force_image_size or self.vision_config.image_size
        self.patch_size = self.vision_config.patch_size
        self.select_layer = config.select_layer
        self.num_image_token = int((self.image_size // self.patch_size) ** 2 * (self.downsample_ratio ** 2))
        self.neftune_alpha = None
        self.im_mask = None
        self.template = config.template
        self.ps_version = config.ps_version
        if self.template not in ['Hermes-2', 'internlm2-chat', 'phi3-chat', 'internvl2_5']:
            raise ValueError(
                f"Unsupported template {self.template}, supported templates are `Hermes-2`, "
                "`internlm2-chat`, `phi3-chat`, `internvl2_5`. Please check the value of 'template' in config.json"
            )
        if self.ps_version not in ['v1', 'v2']:
            raise ValueError(
                f"Unsupported ps_version {self.ps_version}, supported templates are `v1` and `v2`."
                "Please check the value of 'ps_version' in config.json"
            )
        self.npu_id = weights.device.index
        self.process_group, self.device = initialize_torch_distributed(self.tp_rank, self.npu_id, self.tp_world_size)
        self.init_llm_model_type()
        self.init_vision_model()
        self.init_mlp_projector()
        self.init_language_model()
        self.init_normalizer()
        if self.llm_model_type == ACTION_INTERNLM2:
            if self.dtype != torch.float16:
                logger.error(
                    f"{self.dtype} is unsupported, supported dtypes are float16."
                    "Please check the value of 'torch_dtype' in config.json",
                    ErrorCode.ATB_MODELS_MODEL_PARAM_JSON_INVALID
                )
                raise ValueError(
                    f"{self.dtype} is unsupported, supported dtypes are float16."
                    "Please check the value of 'torch_dtype' in config.json"
                )
            self.llm_embedding_layer = self.language_model.get_embedding_layer()
            self.processor = safe_get_tokenizer_from_pretrained(
                config.model_name_or_path, trust_remote_code=self.trust_remote_code
            )
            self.config.eos_token_id = INTERNLM2_EOS_TOKEN_ID
        elif self.llm_model_type == ACTION_LLAMA:
            self.llm_embedding_layer = self.language_model.model.embed_tokens
            self.processor = safe_get_tokenizer_from_pretrained(
                self.config.model_name_or_path, trust_remote_code=self.trust_remote_code, use_fast=False
            )
            self.config.eos_token_id = self.llm_config.eos_token_id
        elif self.llm_model_type == ACTION_QWEN2:
            self.llm_embedding_layer = self.language_model.transformer.wte
            self.processor = safe_get_tokenizer_from_pretrained(
                self.config.model_name_or_path, padding_side="left", trust_remote_code=self.trust_remote_code,
            )
            self.config.eos_token_id = self.llm_config.eos_token_id
        self.img_begin_id = self.processor.encode("<img>")[-1]
        self.img_end_id = self.processor.encode("</img>")[-1]
        self.img_context_token_id = self.processor.encode("<IMG_CONTEXT>")[-1]
    def init_module_weight(self, module, weights, prefix="model", prefixskip=None):
        model_weights = list(module.state_dict().keys())
        for model_weight in model_weights:
            if prefixskip and prefixskip in model_weight:
                continue
            saved_weight = torch.nn.Parameter(
                weights.get_tensor(f"{prefix}.{model_weight}"), requires_grad=False
            )
            if not self.vision_config.enable_vit_dp:
                saved_weight = internvl_tensor_parallel_split(
                    model_weight, prefix, self.tp_rank, self.tp_world_size, saved_weight)
            model_weight_list = model_weight.split(".")
            target_module = module
            for nxt_module in model_weight_list[:-1]:
                target_module = getattr(target_module, nxt_module)
            setattr(target_module, model_weight_list[-1], saved_weight)
    def init_llm_model_type(self):
        llm_model_architectures = self.llm_config.architectures[0]
        if llm_model_architectures == INTERNLM2_ARCHITECTURE:
            self.llm_model_type = ACTION_INTERNLM2  # InternLM2 backbone: InternVL2-2B/8B/20B
        elif llm_model_architectures == LLAMA_ARCHITECTURE:
            self.llm_model_type = ACTION_LLAMA  # Llama/Yi backbone: InternVL2-40B/76B
        elif llm_model_architectures == QWEN2_ARCHITECTURE:
            self.llm_model_type = ACTION_QWEN2  # Qwen2 backbone: InternVL2-1B
        else:
            logger.error("Currently only InternVL-Chat-V1-2, InternVL-Chat-V1-5 and InternVL2 are supported. "
                         "Please check `config.json`.",
                         ErrorCode.ATB_MODELS_MODEL_PARAM_JSON_INVALID)
            raise KeyError("Currently only InternVL-Chat-V1-2, InternVL-Chat-V1-5 and InternVL2 are supported. "
                           "Please check `config.json`.")
    def init_normalizer(self):
        weight, bias = create_standardization_params(IMAGENET_MEAN, IMAGENET_STD, RESCALE_FACTOR, CONV_CHANNELS)
        self.normalizer = nn.Conv2d(in_channels=CONV_CHANNELS, out_channels=CONV_CHANNELS,
                                    kernel_size=1, groups=CONV_GROUPS)
        self.normalizer.weight = nn.Parameter(data=weight, requires_grad=False)
        self.normalizer.bias = nn.Parameter(data=bias, requires_grad=False)
        self.normalizer.npu()
        # Normalizer warmup
        self.normalizer(torch.randn(MAX_NUM_PATCHES, CONV_CHANNELS, IMAGE_SIZE, IMAGE_SIZE, device='npu'))
    def init_vision_model(self):
        self.vision_model = InternVisionModel(self.vision_config, self.process_group).to(dtype=self.dtype)
        self.init_module_weight(self.vision_model, self.weights, prefix="vision_model")
        self.vision_model = self.vision_model.to(self.device)
    def init_mlp_projector(self):
        if self.downsample_ratio == 0:
            raise ZeroDivisionError("Downsample ratio will be zero")
        input_dim = self.vit_hidden_size * int(np.divide(1, self.downsample_ratio)) ** 2
        if self.vision_config.enable_vit_dp:
            self.mlp1 = nn.Sequential(
                nn.LayerNorm(input_dim),
                nn.Linear(input_dim, self.llm_hidden_size),
                nn.GELU(),
                nn.Linear(self.llm_hidden_size, self.llm_hidden_size)
            ).to(dtype=self.dtype)
        else:
            self.mlp1 = nn.Sequential(
                nn.LayerNorm(input_dim),
                ColumnLinear(input_dim, self.llm_hidden_size, gather_output=False, process_group=self.process_group),
                nn.GELU(),
                RowLinear(self.llm_hidden_size, self.llm_hidden_size, process_group=self.process_group)
            ).to(dtype=self.dtype)
        self.init_module_weight(self.mlp1, self.weights, prefix="mlp1")
        self.mlp1 = self.mlp1.to(self.device)
    def init_language_model(self):
        model_type = self.llm_model_type
        if model_type == ACTION_INTERNLM2:
            self.language_model = FlashInternlm2ForCausalLM(
                self.config,
                self.weights,
                lmhead_prefix=LMHEAD_PREFIX.replace('lm_head', 'output'),
                model_prefix=MODEL_PREFIX,
            )
        elif model_type == ACTION_LLAMA:
            self.language_model = LlamaForCausalLM(
                self.llm_config,
                self.weights,
                lmhead_prefix=LMHEAD_PREFIX,
                model_prefix=MODEL_PREFIX,
            )
        elif model_type == ACTION_QWEN2:
            self.language_model = FlashQwen2ForCausalLM(
                self.llm_config,
                self.weights,
                lmhead_prefix=LMHEAD_PREFIX,
                model_prefix=MODEL_PREFIX,
                transformer_wte_parallel=False,
            )
        else:
            logger.error(f"Currently only {LLAMA_ARCHITECTURE}{INTERNLM2_ARCHITECTURE}{QWEN2_ARCHITECTURE} "
                         "are supported. Please check `config.json`.",
                         ErrorCode.ATB_MODELS_MODEL_PARAM_JSON_INVALID)
            raise KeyError(f"Currently only {LLAMA_ARCHITECTURE}{INTERNLM2_ARCHITECTURE}{QWEN2_ARCHITECTURE} "
                           "are supported. Please check `config.json`.")
        self.language_model.skip_word_embedding = True
    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
        # guard once before both reshapes that divide by scale_factor
        if scale_factor == 0:
            raise ZeroDivisionError("Scale factor must not be zero")
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(np.divide(c, scale_factor)))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(np.divide(c, scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            print_log(self.tp_rank, logger.warning, 
                      "In ps_version 'v1', the height and width have not been swapped back, "
                      "which results in a transposed image.")
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x
    def noised_embed(self, vit_embeds, noise_alpha=5):
        dims = torch.tensor(vit_embeds.size(1) * vit_embeds.size(2))
        if dims == 0:
            raise ZeroDivisionError("Dim of the tensor is zero")
        mag_norm = np.divide(noise_alpha, torch.sqrt(dims))
        noise = torch.zeros_like(vit_embeds).uniform_(-mag_norm, mag_norm)
        return vit_embeds + noise
    def extract_feature(self, pixel_values):
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=False,
                return_dict=True).last_hidden_state
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True).hidden_states[self.select_layer]
        vit_embeds = vit_embeds[:, 1:, :]
        if self.training and self.neftune_alpha is not None:
            vit_embeds = self.noised_embed(vit_embeds, self.neftune_alpha)
        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds
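
The init_normalizer method above folds the ImageNet mean/std normalization into a grouped 1x1 convolution so that it runs on the NPU rather than the CPU. The helper create_standardization_params is not shown in this file; the following is a minimal re-derivation of the parameters it would have to return (the ImageNet constants and rescale factor below are assumptions, not taken from the source):

import torch
import torch.nn as nn

# Assumed standard ImageNet constants; the real values come from the model code.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
RESCALE_FACTOR = 1.0 / 255  # uint8 pixel values -> [0, 1]

mean = torch.tensor(IMAGENET_MEAN)
std = torch.tensor(IMAGENET_STD)
# A grouped 1x1 conv computes y_c = x_c * w_c + b_c per channel, which equals
# (x_c * rescale - mean_c) / std_c when w_c = rescale / std_c and b_c = -mean_c / std_c.
weight = (RESCALE_FACTOR / std).reshape(3, 1, 1, 1)
bias = -mean / std

conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1, groups=3)
conv.weight = nn.Parameter(weight, requires_grad=False)
conv.bias = nn.Parameter(bias, requires_grad=False)

x = torch.randint(0, 256, (1, 3, 4, 4)).float()
expected = (x * RESCALE_FACTOR - mean.view(3, 1, 1)) / std.view(3, 1, 1)
print(torch.allclose(conv(x), expected, atol=1e-5))  # True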

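To see what num_image_token and pixel_shuffle actually compute, here is a self-contained sketch of the shape flow. The values image_size=448, patch_size=14, downsample_ratio=0.5 are typical for InternVL2.5 but are assumptions here; check config.json for your checkpoint:

import torch

image_size, patch_size, downsample_ratio = 448, 14, 0.5
# (448 // 14) ** 2 = 1024 ViT patches; the 0.5 downsample keeps a quarter of them
num_image_token = int((image_size // patch_size) ** 2 * downsample_ratio ** 2)
print(num_image_token)  # 256 tokens per image tile

# pixel_shuffle shape flow: trade spatial resolution for channel width
n, h, w, c = 1, 32, 32, 1024
x = torch.randn(n, h, w, c)
x = x.view(n, w, int(h * downsample_ratio), int(c / downsample_ratio))
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(n, int(h * downsample_ratio), int(w * downsample_ratio),
           int(c / downsample_ratio ** 2))
print(x.shape)  # torch.Size([1, 16, 16, 4096]) -> 16 * 16 = 256 visual tokens
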
In addition to the model-initialization methods above, two key methods must also be adapted for forward inference: prepare_prefill_token and forward.

  1. prepare_prefill_token performs the preprocessing, vision-encoder, and tokenizer steps, and returns the input_embeds in which the encoder features have been fused with the token embeddings (a standalone sketch of this fusion appears after this list).
    def prepare_prefill_token(self, multimodalinputs, processor):
        text = multimodalinputs.text
        image = multimodalinputs.image
        video = multimodalinputs.video
        current_query = ""
        vit_embeds = None
        if image is not None:
            use_dynamic_prepro = self.ps_version != "v1"
            pixel_values = load_and_preprocess_image(
                image, normalizer=self.normalizer,
                use_dynamic_prepro=use_dynamic_prepro).to(self.dtype).to(self.device)
            vit_embeds = self.extract_feature(pixel_values).to(self.dtype).to(self.device)
            image_tokens_num = self.num_image_token * vit_embeds.shape[0]
            current_query = f'<img>{"<IMG_CONTEXT>" * image_tokens_num}</img>\n'
        elif video is not None:
            pixel_values, num_patches_list = load_video(video)
            pixel_values = pixel_values.to(self.dtype).to(self.device)
            vit_embeds = self.extract_feature(pixel_values).to(self.dtype).to(self.device)
            for i, num_patch in enumerate(num_patches_list):
                current_query += (f'Frame{i + 1}: '
                                  f'<img>{"<IMG_CONTEXT>" * num_patch * self.num_image_token}</img>\n')
        system_prompt = INTERNVL_SYSTEM_PROMPTS[self.ps_version][self.template]
        texts = ('<|im_start|>system\n'
                 f'{system_prompt}<|im_end|><|im_start|>user\n')
        texts += current_query
        texts += f'{text}<|im_end|><|im_start|>assistant\n'
        input_ids = processor.encode(texts)
        input_ids = torch.tensor(input_ids, requires_grad=False).to(self.device)
        input_embeds = self.llm_embedding_layer(input_ids)
        sequence_length, embedding_size = input_embeds.shape
        input_ids = input_ids.reshape(sequence_length)
        if vit_embeds is None:
            # text-only request: no visual features to splice in
            return input_embeds.reshape(-1, embedding_size)
        vit_embeds = vit_embeds.reshape(-1, embedding_size)
        selected = (input_ids == self.img_context_token_id)
        try:
            # zero the <IMG_CONTEXT> placeholder embeddings, then add the ViT features
            input_embeds[selected] = input_embeds[selected] * torch.zeros(
                1, dtype=self.dtype, device=self.device) + vit_embeds
        except Exception as e:
            error_msg = (f'{e} \ninput_embeds[selected].shape is {input_embeds[selected].shape}, '
                         f'vit_embeds.shape is {vit_embeds.shape}\n'
                         f'Please check whether the shape of input_embeds[selected] matches that of vit_embeds.\n'
                         f'If not, check whether self.img_context_token_id matches '
                         f'the token id of "<IMG_CONTEXT>" in the processor.')
            logger.error(error_msg, ErrorCode.ATB_MODELS_PARAM_OUT_OF_RANGE)
            raise ValueError(error_msg) from e
        input_embeds = input_embeds.reshape(-1, embedding_size)
        return input_embeds
    
  2. forward: since the underlying base LLM has already been adapted, this method simply delegates to self.language_model.forward.
    def forward(
            self,
            input_ids: torch.Tensor,
            position_ids: torch.Tensor,
            is_prefill: bool,
            kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
            block_tables: torch.Tensor,
            slots: torch.Tensor,
            input_lengths: torch.Tensor,
            max_seq_len: int,
            lm_head_indices: Optional[torch.Tensor] = None,
            **kwargs):
        # 1-D inputs are raw token ids (service scenario); run the multimodal
        # prefill path first (prepare_prefill_token_service, not shown here)
        if is_prefill and input_ids.dim() == 1:
            input_ids = self.prepare_prefill_token_service(input_ids)
        return self.language_model.forward(input_ids,
                                           position_ids,
                                           is_prefill,
                                           kv_cache,
                                           block_tables,
                                           slots,
                                           input_lengths,
                                           max_seq_len,
                                           lm_head_indices)
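
The core of prepare_prefill_token is the boolean-mask splice that overwrites the <IMG_CONTEXT> placeholder embeddings with ViT features. Below is a minimal, toy-sized sketch of that step; all sizes and token ids are made up for illustration:

import torch

embedding_size = 8
img_context_token_id = 99  # hypothetical id; the real one comes from the tokenizer
input_ids = torch.tensor([1, 2, 99, 99, 99, 3])     # text tokens + 3 image slots
input_embeds = torch.randn(len(input_ids), embedding_size)
vit_embeds = torch.randn(3, embedding_size)          # one feature row per slot

selected = input_ids == img_context_token_id
if int(selected.sum()) != vit_embeds.shape[0]:
    raise ValueError("placeholder count must match the number of ViT rows")
# same effect as the masked assignment in prepare_prefill_token above
input_embeds[selected] = vit_embeds.reshape(-1, embedding_size)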