MultimodalPARunner calls the {model}Router class in "{llm_path}/atb_llm/models/{model}/router_{model}.py" to initialize the model and load its configuration file. Here {model} is the model name, and it must match the "model_type" field in the model's configuration file exactly.
The {model}Router class acts as a router: it tells the framework which model to load, from where, and which configuration file goes with it.
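How the runner locates this class follows from the naming convention above rather than from anything model-specific. Below is a minimal, hypothetical sketch of such a lookup, assuming the router is resolved by a dynamic import keyed on "model_type"; the load_router helper is illustrative only and not part of the framework:

```python
import importlib


def load_router(model_type: str):
    """Hypothetical sketch: resolve the {model}Router class from model_type.

    Mirrors the path convention described above: the module lives at
    atb_llm/models/{model}/router_{model}.py and exposes a class named
    {Model}Router, e.g. model_type "internvl" -> InternvlRouter.
    """
    module = importlib.import_module(
        f"atb_llm.models.{model_type}.router_{model_type}")
    class_name = f"{model_type.capitalize()}Router"
    return getattr(module, class_name)
```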
The {model}Router class inherits from the base class BaseRouter. For model migration and adaptation, the class must implement the "get_config" and "get_tokenizer" methods.
An example is shown below:
```python
from dataclasses import dataclass

import torch

from ..base.router import BaseRouter
# InternvlConfig, InternvlInputBuilder, safe_get_tokenizer_from_pretrained,
# logger, ErrorCode, the architecture constants (INTERNLM2_ARCHITECTURE,
# LLAMA_ARCHITECTURE, QWEN2_ARCHITECTURE), INTERNVL_SYSTEM_PROMPTS, the
# _TEXT/_IMAGE/_VIDEO keys and the process_image_input/process_video_input
# helpers come from the surrounding atb_llm package (imports abridged here).


@dataclass
class InternvlRouter(BaseRouter):
    def get_config(self):
        # Load the model configuration and apply runtime overrides.
        config = InternvlConfig.from_pretrained(self.model_name_or_path)
        if self.max_position_embeddings:
            config.max_position_embeddings = self.max_position_embeddings
        config.model_name_or_path = self.model_name_or_path
        super().check_config(config)
        return config

    def get_tokenizer(self):
        # Tokenizer settings depend on the LLM backbone architecture
        # declared in `llm_config.architectures` of config.json.
        try:
            llm_model_architectures = self.config_dict['llm_config']['architectures'][0]
        except KeyError as e:
            logger.error("`llm_config.architectures` does not exist! Check `config.json`.",
                         ErrorCode.ATB_MODELS_MODEL_PARAM_JSON_INVALID)
            raise ValueError("`llm_config.architectures` does not exist! Check `config.json`.") from e
        if llm_model_architectures == INTERNLM2_ARCHITECTURE:
            tokenizer = safe_get_tokenizer_from_pretrained(
                self.model_name_or_path,
                trust_remote_code=self.trust_remote_code
            )
        elif llm_model_architectures == LLAMA_ARCHITECTURE:
            tokenizer = safe_get_tokenizer_from_pretrained(
                self.model_name_or_path,
                revision=self.revision,
                padding_side="left",
                trust_remote_code=self.trust_remote_code,
                use_fast=False
            )
        elif llm_model_architectures == QWEN2_ARCHITECTURE:
            tokenizer = safe_get_tokenizer_from_pretrained(
                self.model_name_or_path,
                padding_side="left",
                trust_remote_code=self.trust_remote_code,
            )
        else:
            logger.error(
                "`llm_config.architectures` must in "
                f"[{LLAMA_ARCHITECTURE}, {INTERNLM2_ARCHITECTURE}, {QWEN2_ARCHITECTURE}], "
                f"got {llm_model_architectures}.",
                ErrorCode.ATB_MODELS_PARAM_OUT_OF_RANGE)
            raise ValueError(
                "`llm_config.architectures` must in "
                f"[{LLAMA_ARCHITECTURE}, {INTERNLM2_ARCHITECTURE}, {QWEN2_ARCHITECTURE}], "
                f"got {llm_model_architectures}.")
        return tokenizer

    def get_input_builder(self):
        return InternvlInputBuilder(self.tokenizer, self.config)

    def tokenize(self, inputs, **kwargs):
        # Token ids of the image marker tokens <img> and </img>.
        img_begin_id = self.tokenizer.encode("<img>")[-1]
        img_end_id = self.tokenizer.encode("</img>")[-1]
        shm_name_save_path = kwargs.get("shm_name_save_path", None)
        image_size = self.config.force_image_size or self.config.vision_config.image_size
        patch_size = self.config.vision_config.patch_size
        if patch_size == 0:
            logger.error('The vision patch_size of config can not be 0.',
                         ErrorCode.ATB_MODELS_PARAM_OUT_OF_RANGE)
            raise ValueError('The vision patch_size of config can not be 0.')
        num_image_token = int((image_size // patch_size) ** 2 * (self.config.downsample_ratio ** 2))
        use_dynamic_prepro = False if self.config.ps_version == "v1" else True
        system_prompt = INTERNVL_SYSTEM_PROMPTS[self.config.ps_version][self.config.template]
        query = ('<|im_start|>system\n'
                 f'{system_prompt}<|im_end|><|im_start|>user\n')
        text = ""
        image_index = 1
        shm_name_list = []
        shape_value_list = []
        image_num = sum(1 for d in inputs if _IMAGE in d)
        # Build the chat prompt, collecting shared-memory handles for the media.
        for single_input in inputs:
            if _TEXT in single_input:
                text += single_input.get(_TEXT)
                continue
            if _IMAGE in single_input:
                current_query, shm_name_value, shape_value = process_image_input(
                    single_input, image_num, image_index, use_dynamic_prepro,
                    num_image_token, shm_name_save_path
                )
                query += current_query
                image_index += 1
                shm_name_list.append(shm_name_value)
                shape_value_list.append(shape_value)
            elif _VIDEO in single_input:
                current_query, shm_name_value, shape_value = process_video_input(
                    single_input, use_dynamic_prepro, num_image_token, shm_name_save_path
                )
                query += current_query
                shm_name_list += shm_name_value
                shape_value_list += shape_value
            else:
                logger.error("Unsupported media type, it should be a video or image, please check your input.",
                             ErrorCode.ATB_MODELS_PARAM_OUT_OF_RANGE)
                raise KeyError("Unsupported media type, it should be a video or image, please check your input.")
        query += f'{text}<|im_end|><|im_start|>assistant\n'
        query_ids = torch.tensor(self.tokenizer.encode(query))
        # Embed each shared-memory name and shape value into the two token
        # slots between the corresponding <img> ... </img> markers.
        bos_pos_set = torch.nonzero(query_ids == img_begin_id).view(-1)
        eos_pos_set = torch.nonzero(query_ids == img_end_id).view(-1)
        for i, (bos_pos, eos_pos) in enumerate(zip(bos_pos_set, eos_pos_set)):
            if eos_pos - bos_pos < 3:
                logger.error("tokenize input error.", ErrorCode.ATB_MODELS_PARAM_OUT_OF_RANGE)
                raise ValueError("tokenize input error.")
            query_ids[bos_pos + 1] = shm_name_list[i]
            query_ids[bos_pos + 2] = shape_value_list[i]
        return query_ids
```
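One detail in the tokenize method above is worth calling out: the shared-memory name and the image-shape value are written directly into the token sequence, overwriting the two token slots between each "<img>"/"</img>" marker pair, so the model side can recover them later. The standalone toy below (with made-up token ids instead of a real tokenizer) reproduces just that patching step:

```python
import torch

# Toy token ids: 10 stands for <img>, 11 for </img>; the two slots between
# them are placeholders that will carry the shm name and the shape value.
IMG_BEGIN, IMG_END = 10, 11
query_ids = torch.tensor([1, 2, IMG_BEGIN, 0, 0, IMG_END, 3, 4])

shm_name_list = [12345]       # stand-in for an encoded shared-memory name
shape_value_list = [448448]   # stand-in for an encoded image shape

# Find each marker pair and patch the two slots in between, exactly as
# the tokenize method does.
bos_pos_set = torch.nonzero(query_ids == IMG_BEGIN).view(-1)
eos_pos_set = torch.nonzero(query_ids == IMG_END).view(-1)
for i, (bos_pos, eos_pos) in enumerate(zip(bos_pos_set, eos_pos_set)):
    if eos_pos - bos_pos < 3:  # need at least two slots between the markers
        raise ValueError("tokenize input error.")
    query_ids[bos_pos + 1] = shm_name_list[i]
    query_ids[bos_pos + 2] = shape_value_list[i]

print(query_ids)
# tensor([     1,      2,     10,  12345, 448448,     11,      3,      4])
```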