The Flash{model}ForCausalLM class implements model initialization, weight loading, and forward inference. It is located in the file "{llm_path}/atb_llm/models/{model}/flash_causal_{model}.py".
Taking InternVL2.5 as an example:
```python
from ..base.flash_causal_multimodal import MultiModalLLm


class FlashInternvlForCausalLM(FlashForCausalLM):
    def __init__(self, config, weights, **kwargs):
        super().__init__(config, weights, **kwargs)
        self.config = config
        self.trust_remote_code = kwargs.get('trust_remote_code', False)
        self.weights = weights  # same id as the input weights; passed by reference
        self.dtype = weights.dtype
        self.vision_config = config.vision_config
        enable_vit_dp = kwargs.get('enable_vit_dp', True)
        setattr(self.vision_config, 'enable_vit_dp', enable_vit_dp)
        self.llm_config = config.llm_config
        self.llm_config.quantize = None
        # Image-preprocessing related parameters
        self.downsample_ratio = config.downsample_ratio
        self.vit_hidden_size = self.vision_config.hidden_size
        self.llm_hidden_size = self.llm_config.hidden_size
        self.image_size = config.force_image_size or self.vision_config.image_size
        self.patch_size = self.vision_config.patch_size
        self.select_layer = config.select_layer
        self.num_image_token = int((self.image_size // self.patch_size) ** 2 * (self.downsample_ratio ** 2))
        self.neftune_alpha = None
        self.im_mask = None
        self.template = config.template
        self.ps_version = config.ps_version
        if self.template not in ['Hermes-2', 'internlm2-chat', 'phi3-chat', 'internvl2_5']:
            raise ValueError(
                f"Unsupported template {self.template}, supported templates are `Hermes-2`, "
                "`internlm2-chat`, `phi3-chat`, `internvl2_5`. Please check the value of 'template' in config.json"
            )
        if self.ps_version not in ['v1', 'v2']:
            raise ValueError(
                f"Unsupported ps_version {self.ps_version}, supported templates are `v1` and `v2`."
                "Please check the value of 'ps_version' in config.json"
            )
        self.npu_id = weights.device.index
        self.process_group, self.device = initialize_torch_distributed(self.tp_rank, self.npu_id, self.tp_world_size)
        self.init_llm_model_type()
        self.init_vision_model()
        self.init_mlp_projector()
        self.init_language_model()
        self.init_normalizer()
        if self.llm_model_type == ACTION_INTERNLM2:
            if self.dtype != torch.float16:
                logger.error(
                    f"{self.dtype} is unsupported, supported dtypes are float16."
                    "Please check the value of 'torch_dtype' in config.json",
                    ErrorCode.ATB_MODELS_MODEL_PARAM_JSON_INVALID
                )
                raise ValueError(
                    f"{self.dtype} is unsupported, supported dtypes are float16."
                    "Please check the value of 'torch_dtype' in config.json"
                )
            self.llm_embedding_layer = self.language_model.get_embedding_layer()
            self.processor = safe_get_tokenizer_from_pretrained(
                config.model_name_or_path,
                trust_remote_code=self.trust_remote_code
            )
            self.config.eos_token_id = INTERNLM2_EOS_TOKEN_ID
        elif self.llm_model_type == ACTION_LLAMA:
            self.llm_embedding_layer = self.language_model.model.embed_tokens
            self.processor = safe_get_tokenizer_from_pretrained(
                self.config.model_name_or_path,
                trust_remote_code=self.trust_remote_code,
                use_fast=False
            )
            self.config.eos_token_id = self.llm_config.eos_token_id
        elif self.llm_model_type == ACTION_QWEN2:
            self.llm_embedding_layer = self.language_model.transformer.wte
            self.processor = safe_get_tokenizer_from_pretrained(
                self.config.model_name_or_path,
                padding_side="left",
                trust_remote_code=self.trust_remote_code,
            )
            self.config.eos_token_id = self.llm_config.eos_token_id
        self.img_begin_id = self.processor.encode("<img>")[-1]
        self.img_end_id = self.processor.encode("</img>")[-1]
        self.img_context_token_id = self.processor.encode("<IMG_CONTEXT>")[-1]

    def init_module_weight(self, module, weights, prefix="model", prefixskip=None):
        model_weights = [model_weight for model_weight in module.state_dict().keys()]
        for model_weight in model_weights:
            if prefixskip and prefixskip in model_weight:
                continue
            saved_weight = torch.nn.Parameter(
                weights.get_tensor(f"{prefix}.{model_weight}"),
                requires_grad=False
            )
            if not self.vision_config.enable_vit_dp:
                saved_weight = internvl_tensor_parallel_split(model_weight, prefix,
                                                              self.tp_rank, self.tp_world_size, saved_weight)
            model_weight_list = model_weight.split(".")
            target_module = module
            for nxt_module in model_weight_list[:-1]:
                target_module = getattr(target_module, nxt_module)
            setattr(target_module, model_weight_list[-1], saved_weight)

    def init_llm_model_type(self):
        llm_model_architectures = self.llm_config.architectures[0]
        if llm_model_architectures == INTERNLM2_ARCHITECTURE:
            self.llm_model_type = ACTION_INTERNLM2  # internlm: VL2-2B, VL2-8B, VL2-20B
        elif llm_model_architectures == LLAMA_ARCHITECTURE:
            self.llm_model_type = ACTION_LLAMA  # llama, yi: VL2-40B, VL2-76B
        elif llm_model_architectures == QWEN2_ARCHITECTURE:
            self.llm_model_type = ACTION_QWEN2  # qwen: VL2-1B
        else:
            logger.error("Currently only InternVL-Chat-V1-2、InternVL-Chat-V1-5、InternVL2 are supported. "
                         "Please check `config.json`.",
                         ErrorCode.ATB_MODELS_MODEL_PARAM_JSON_INVALID)
            raise KeyError("Currently only InternVL-Chat-V1-2、InternVL-Chat-V1-5、InternVL2 are supported. "
                           "Please check `config.json`.")

    def init_normalizer(self):
        weight, bias = create_standardization_params(IMAGENET_MEAN, IMAGENET_STD, RESCALE_FACTOR, CONV_CHANNELS)
        self.normalizer = nn.Conv2d(in_channels=CONV_CHANNELS, out_channels=CONV_CHANNELS, kernel_size=1,
                                    groups=CONV_GROUPS)
        self.normalizer.weight = nn.Parameter(data=weight, requires_grad=False)
        self.normalizer.bias = nn.Parameter(data=bias, requires_grad=False)
        self.normalizer.npu()
        # Normalizer warmup
        self.normalizer(torch.randn(MAX_NUM_PATCHES, CONV_CHANNELS, IMAGE_SIZE, IMAGE_SIZE, device='npu'))

    def init_vision_model(self):
        self.vision_model = InternVisionModel(self.vision_config, self.process_group).to(dtype=self.dtype)
        self.init_module_weight(self.vision_model, self.weights, prefix="vision_model")
        self.vision_model = self.vision_model.to(self.device)

    def init_mlp_projector(self):
        if self.downsample_ratio == 0:
            raise ZeroDivisionError("Downsample ratio will be zero")
        input_dim = self.vit_hidden_size * int(np.divide(1, self.downsample_ratio)) ** 2
        if self.vision_config.enable_vit_dp:
            self.mlp1 = nn.Sequential(
                nn.LayerNorm(input_dim),
                nn.Linear(input_dim, self.llm_hidden_size),
                nn.GELU(),
                nn.Linear(self.llm_hidden_size, self.llm_hidden_size)
            ).to(dtype=self.dtype)
        else:
            self.mlp1 = nn.Sequential(
                nn.LayerNorm(input_dim),
                ColumnLinear(input_dim, self.llm_hidden_size, gather_output=False, process_group=self.process_group),
                nn.GELU(),
                RowLinear(self.llm_hidden_size, self.llm_hidden_size, process_group=self.process_group)
            ).to(dtype=self.dtype)
        self.init_module_weight(self.mlp1, self.weights, prefix="mlp1")
        self.mlp1 = self.mlp1.to(self.device)

    def init_language_model(self):
        model_type = self.llm_model_type
        if model_type == ACTION_INTERNLM2:
            self.language_model = FlashInternlm2ForCausalLM(
                self.config,
                self.weights,
                lmhead_prefix=LMHEAD_PREFIX.replace('lm_head', 'output'),
                model_prefix=MODEL_PREFIX,
            )
        elif model_type == ACTION_LLAMA:
            self.language_model = LlamaForCausalLM(
                self.llm_config,
                self.weights,
                lmhead_prefix=LMHEAD_PREFIX,
                model_prefix=MODEL_PREFIX,
            )
        elif model_type == ACTION_QWEN2:
            self.language_model = FlashQwen2ForCausalLM(
                self.llm_config,
                self.weights,
                lmhead_prefix=LMHEAD_PREFIX,
                model_prefix=MODEL_PREFIX,
                transformer_wte_parallel=False,
            )
        else:
            logger.error(f"Currently only {LLAMA_ARCHITECTURE}、{INTERNLM2_ARCHITECTURE}、{QWEN2_ARCHITECTURE} "
                         "are supported. Please check `config.json`.",
                         ErrorCode.ATB_MODELS_MODEL_PARAM_JSON_INVALID)
            raise KeyError(f"Currently only {LLAMA_ARCHITECTURE}、{INTERNLM2_ARCHITECTURE}、{QWEN2_ARCHITECTURE} "
                           "are supported. Please check `config.json`.")
        self.language_model.skip_word_embedding = True

    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        if scale_factor == 0:
            raise ZeroDivisionError("Scale factor will be zero")
        x = x.view(n, w, int(h * scale_factor), int(np.divide(c, scale_factor)))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        if scale_factor == 0:
            raise ZeroDivisionError("Scale factor will be zero")
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(np.divide(c, scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            print_log(self.tp_rank, logger.warning,
                      "In ps_version 'v1', the height and width have not been swapped back, "
                      "which results in a transposed image.")
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def noised_embed(self, vit_embeds, noise_alpha=5):
        dims = torch.tensor(vit_embeds.size(1) * vit_embeds.size(2))
        if dims == 0:
            raise ZeroDivisionError("Dim of the tensor is zero")
        mag_norm = np.divide(noise_alpha, torch.sqrt(dims))
        noise = torch.zeros_like(vit_embeds).uniform_(-mag_norm, mag_norm)
        return vit_embeds + noise

    def extract_feature(self, pixel_values):
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=False,
                return_dict=True).last_hidden_state
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True).hidden_states[self.select_layer]
        vit_embeds = vit_embeds[:, 1:, :]
        if self.training and self.neftune_alpha is not None:
            vit_embeds = self.noised_embed(vit_embeds, self.neftune_alpha)
        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds
```
In addition to the initialization-related methods above, two further methods must be adapted for forward inference: prepare_prefill_token and forward, shown in the two code blocks below.
```python
    def prepare_prefill_token(self, multimodalinputs, processor):
        text = multimodalinputs.text
        image = multimodalinputs.image
        video = multimodalinputs.video
        current_query = ""
        if image is not None:
            use_dynamic_prepro = False if self.ps_version == "v1" else True
            pixel_values = load_and_preprocess_image(image, normalizer=self.normalizer,
                                                     use_dynamic_prepro=use_dynamic_prepro).to(self.dtype).to(self.device)
            vit_embeds = self.extract_feature(pixel_values).to(self.dtype).to(self.device)
            image_tokens_num = self.num_image_token * vit_embeds.shape[0]
            current_query = (f'<img>{"<IMG_CONTEXT>" * image_tokens_num}</img>\n')
        elif video is not None:
            pixel_values, num_patches_list = load_video(video)
            pixel_values = pixel_values.to(self.dtype).to(self.device)
            vit_embeds = self.extract_feature(pixel_values).to(self.dtype).to(self.device)
            for i, num_patch in enumerate(num_patches_list):
                current_query += (f'Frame{i+1}: '
                                  f'<img>{"<IMG_CONTEXT>" * num_patch * self.num_image_token}</img>\n')
        system_prompt = INTERNVL_SYSTEM_PROMPTS[self.ps_version][self.template]
        texts = ('<|im_start|>system\n'
                 f'{system_prompt}<|im_end|><|im_start|>user\n')
        texts += current_query
        texts += (f'{text}<|im_end|><|im_start|>assistant\n')
        input_ids = processor.encode(texts)
        input_ids = torch.tensor(input_ids, requires_grad=False).to(self.device)
        input_embeds = self.llm_embedding_layer(input_ids)
        sequence_length, embedding_size = input_embeds.shape
        input_ids = input_ids.reshape(sequence_length)
        vit_embeds = vit_embeds.reshape(-1, embedding_size)
        selected = (input_ids == self.img_context_token_id)
        try:
            input_embeds[selected] = input_embeds[selected] * torch.zeros(1, dtype=self.dtype, device=self.device) \
                + vit_embeds.reshape(-1, embedding_size)
        except Exception as e:
            error_msg = f'{e} \ninput_embeds[selected].shape is {input_embeds[selected].shape}, '\
                        f'vit_embeds.shape is {vit_embeds.shape}\n'\
                        f'Please check whether shape of input_embeds[selected] matches the shape of vit_embeds.\n'\
                        f'If not, please check whether self.img_context_token_id '\
                        f'and the token-id of "<IMG_CONTEXT>" in processor are the same.'
            logger.error(error_msg, ErrorCode.ATB_MODELS_PARAM_OUT_OF_RANGE)
            raise ValueError(error_msg) from e
        input_embeds = input_embeds.reshape(-1, embedding_size)
        return input_embeds
```
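The key step above is scattering the vision features into the text embedding at the positions of the <IMG_CONTEXT> placeholder tokens. A minimal standalone sketch with dummy tensors and a made-up token id illustrates the mechanism, along with the shape mismatch that the try/except guards against:

```python
import torch

img_context_token_id = 92546                           # made-up id, for illustration only
input_ids = torch.tensor([1, 92546, 92546, 92546, 5])  # prompt with three <IMG_CONTEXT> slots
input_embeds = torch.randn(5, 8)                       # [sequence_length, embedding_size]
vit_embeds = torch.randn(3, 8)                         # must provide one row per <IMG_CONTEXT> token

selected = (input_ids == img_context_token_id)
# Zero the placeholder rows and add the vision rows, mirroring the assignment above;
# if the number of selected rows differs from vit_embeds' row count, this is exactly
# the shape mismatch that prepare_prefill_token reports via its error message.
input_embeds[selected] = input_embeds[selected] * 0 + vit_embeds
print(input_embeds.shape)  # torch.Size([5, 8]): text embeddings with image rows filled in
```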
```python
    def forward(
            self,
            input_ids: torch.Tensor,
            position_ids: torch.Tensor,
            is_prefill: bool,
            kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
            block_tables: torch.Tensor,
            slots: torch.Tensor,
            input_lengths: torch.Tensor,
            max_seq_len: int,
            lm_head_indices: Optional[torch.Tensor] = None,
            **kwargs):
        if is_prefill and input_ids.dim() == 1:
            input_ids = self.prepare_prefill_token_service(input_ids)
        return self.language_model.forward(input_ids,
                                           position_ids,
                                           is_prefill,
                                           kv_cache,
                                           block_tables,
                                           slots,
                                           input_lengths,
                                           max_seq_len,
                                           lm_head_indices)
```
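forward itself only decides whether the incoming input_ids still need to be expanded into multimodal embeddings before being handed to the language model. The toy sketch below restates that dispatch rule with a hypothetical helper; in the class above the actual expansion is performed by prepare_prefill_token_service, whose implementation is not shown here.

```python
import torch

def needs_service_side_embedding(input_ids: torch.Tensor, is_prefill: bool) -> bool:
    # Hypothetical helper mirroring the `is_prefill and input_ids.dim() == 1` check in forward:
    # 1-D prefill inputs are raw token ids from the serving layer and must be expanded first,
    # while 2-D inputs are embeddings already built by prepare_prefill_token.
    return is_prefill and input_ids.dim() == 1

print(needs_service_side_embedding(torch.tensor([1, 2, 3]), True))  # True: prefill token ids
print(needs_service_side_embedding(torch.randn(3, 4096), True))     # False: prepared embeddings
print(needs_service_side_embedding(torch.tensor([7]), False))       # False: decode step
```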