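OpenAI Interface Adaptation

After receiving a request, the service side calls tokenize(); when the request is in OpenAI format, it calls the make_context() interface of the model-side InputBuilder class.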

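  1. InputBuilder class

    To adapt a new multimodal model to the service, create a new subclass "XXXInputBuilder" that inherits from the base class "InputBuilder", and override its make_context() method.

    Taking Qwen-VL as an example, the class diagram of QwenVlInputBuilder is shown below. After overriding the class, instantiate it in get_input_builder() of the model's Router. The corresponding file is "/usr/local/Ascend/atb-models/atb_llm/models/qwen/router_qwen.py".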

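  2. make_context() function input

    An OpenAI-format request takes its input as List[Dict[str, Dict]], which supports multi-turn conversations. Each turn is a Dict with two additional fields, "role" and "content": "role" identifies the speaker of that turn, and "content" holds what was said in that turn, in the same format as in 1.

    A code example follows: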

    [
        {
            "role": "user",
            "content": [
                {"image": "/XXX/XXX/image.png"},
                {"video": "/XXX/XXX/video.mp4"},
                {"audio": "/XXX/XXX/audio.mp3"},
                {"text": "What is in the image?"}
            ]
        },
        {
            "role": "assistant",
            "content": [
                {"text": "A cute panda."}
            ]
        }
        ...
    ]
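
To make the nesting of this input concrete, here is a minimal sketch that walks a conversation in the format above and lists its media entries. The helper name `collect_media` and the constant `MEDIA_KEYS` are hypothetical and are not part of the InputBuilder API; the sketch only assumes the "role" / "content" layout shown in the example.

```python
from typing import Dict, List, Tuple

# Media keys that appear in the example request above (assumed to be the full set here).
MEDIA_KEYS = ("image", "video", "audio")


def collect_media(conversation: List[Dict]) -> List[Tuple[str, str, str]]:
    """Return (role, media_type, path) for every media item, in input order."""
    found = []
    for message in conversation:
        content = message["content"]
        if not isinstance(content, list):
            raise ValueError('The conversation "content" should be a List[Dict].')
        for item in content:
            for key in MEDIA_KEYS:
                if key in item:
                    found.append((message["role"], key, item[key]))
    return found


conversation = [
    {
        "role": "user",
        "content": [
            {"image": "/XXX/XXX/image.png"},
            {"text": "What is in the image?"},
        ],
    },
    {"role": "assistant", "content": [{"text": "A cute panda."}]},
]

media = collect_media(conversation)
print(media)
```

Text-only turns contribute nothing to the result, so the list contains exactly one entry per media file that would have to be loaded.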
    
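  3. make_context() function implementation

    This function serves the same purpose as tokenize(): converting the input into input_ids. The difference is that it usually arranges the whole content according to the chat template supported by the model.

    Its implementation steps match those of tokenize(); the main addition is step 2 below. The order of the steps is not mandatory and can be adjusted to suit the actual implementation.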

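    1. Convert the input into a query of type Str, using special tokens as delimiters.
    2. Assemble the content according to the template.
    3. Encode the converted query to obtain token_ids.
    4. Iterate over the input, load and process the media data, compute the size of input_ids, and apply padding.
    5. Store the processed pixel_value data in shared memory.
    6. Encode the shared memory name and the shape of the stored data.
    7. Embed the encoded name and shape into input_ids, and return input_ids as a one-dimensional torch.Tensor (device=cpu).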
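
Steps 5 and 6 depend on passing pixel data through shared memory and later locating it again by name and shape. The sketch below illustrates that round trip with Python's standard `multiprocessing.shared_memory` module and NumPy. It is a stand-in under stated assumptions, not the shm_utils implementation: the helper names `write_to_shm` and `read_from_shm` are hypothetical, and the int64 encoding of the name and shape is left out.

```python
from multiprocessing import shared_memory

import numpy as np


def write_to_shm(pixels: np.ndarray) -> shared_memory.SharedMemory:
    """Copy an array into a newly created shared memory block (hypothetical helper)."""
    shm = shared_memory.SharedMemory(create=True, size=pixels.nbytes)
    view = np.ndarray(pixels.shape, dtype=pixels.dtype, buffer=shm.buf)
    view[:] = pixels
    return shm


def read_from_shm(name: str, shape: tuple, dtype=np.float32) -> np.ndarray:
    """Attach to a block by name and return a private copy of its contents."""
    shm = shared_memory.SharedMemory(name=name)
    view = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
    result = view.copy()
    del view  # drop the buffer export before closing the mapping
    shm.close()
    return result


# A toy "pixel_value" tensor with a leading batch dimension.
pixels = np.arange(12, dtype=np.float32).reshape(1, 3, 2, 2)
shm = write_to_shm(pixels)
restored = read_from_shm(shm.name, pixels.shape)

shm.close()
shm.unlink()  # the creator releases the block once the consumer is done
```

The reader needs only the block name and the array shape, which is why those two values are what gets encoded and embedded into input_ids.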

    A code example follows:

    import os
    from typing import Dict, List, Optional

    import numpy as np
    import torch

    # Excerpt: InputBuilder, shm_utils, _image_preprocess and _ERROR_BAD_CHAT_FORMAT
    # are provided by the surrounding model code and are not shown here.


    class QwenVlInputBuilder(InputBuilder):

        def make_context(
                self,
                rank: int,
                conversation: List[Dict[str, List[Dict]]],
                system: str = "You are a helpful assistant.",
                **kwargs):
            if self.generation_config["chat_format"] != 'chatml':
                raise ValueError(_ERROR_BAD_CHAT_FORMAT)
            if not isinstance(conversation[0]["content"], list):
                raise ValueError("The conversation \"content\" should be a List[Dict].")

            shm_name_save_path = kwargs.get('shm_name_save_path', None)
            self.rank = rank
            max_window_size = kwargs.get('max_window_size', None)
            if max_window_size is None:
                max_window_size = self.generation_config["max_window_size"]

            context_tokens = self._apply_chat_template(
                conversation,
                system=system,
                max_window_size=max_window_size,
                shm_name_save_path=shm_name_save_path,
            )
            return context_tokens

        def _apply_chat_template(
                self,
                conversation: List[Dict[str, List[Dict]]],
                system: str = "",
                max_window_size: int = 6144,
                shm_name_save_path: Optional[str] = None,
                **kwargs):

            # Fetch the special tokens and build the system prompt segment.
            im_start_tokens = [self.tokenizer.im_start_id]
            im_end_tokens = [self.tokenizer.im_end_id]
            nl_tokens = self.tokenizer.encode("\n")

            system_tokens_part = self._tokenize_str("system", system, nl_tokens)
            system_tokens = im_start_tokens + system_tokens_part + im_end_tokens

            shm_name_list = []
            shape_value_list = []
            content_key = "content"
            image_key = "image"
            for message in conversation:
                for single_input in message[content_key]:
                    if image_key not in single_input.keys():
                        continue
                    # 4. Iterate over the input, load and process the media data,
                    #    compute the size of `input_ids`, and apply `padding`.
                    image_pixel = _image_preprocess(single_input[image_key])
                    image_pixel = image_pixel[None, :]  # add a leading batch dimension
                    if shm_name_save_path is None:
                        shm_name_save_dir = os.path.dirname(os.path.dirname(single_input[image_key]))
                        shm_name_save_path = os.path.join(shm_name_save_dir, "shm_name.txt")

                    # 5. Store the processed `pixel_value` data in shared memory.
                    shm = shm_utils.create_shm(image_pixel.nbytes, shm_name_save_path)
                    shared_array = np.ndarray(image_pixel.shape, dtype=np.float32, buffer=shm.buf)
                    shared_array[:] = image_pixel

                    # 6. Encode the shared memory `name` and the `shape` of the stored data.
                    shm_name = shm_utils.encode_shm_name_to_int64(shm.name)
                    shape_value = shm_utils.encode_shape_to_int64(image_pixel.shape)
                    shm_name_list.append(shm_name)
                    shape_value_list.append(shape_value)

            # 1. Convert the input into a `Str` query, using special `token`s as delimiters.
            #    The last message is the current query; earlier messages are history.
            query = self.tokenizer.from_list_format(conversation.pop()[content_key])

            # History is collected newest-first and kept separate from the system
            # prompt so that the system prompt stays at the front of the context.
            history_tokens = []
            for message in conversation[::-1]:
                turn_query = self.tokenizer.from_list_format(message[content_key])
                if message["role"] == self.user_role_name:
                    query_tokens = nl_tokens + im_start_tokens + \
                        self._tokenize_str(self.user_role_name, turn_query, nl_tokens) + im_end_tokens + nl_tokens
                elif message["role"] == self.system_role_name:
                    query_tokens = im_start_tokens + \
                        self._tokenize_str(self.system_role_name, turn_query, nl_tokens) + im_end_tokens
                else:
                    raise ValueError(f"message role not supported yet: {message['role']}")

                current_context_size = (
                    len(system_tokens) + len(query_tokens) + len(history_tokens)
                )
                if current_context_size < max_window_size:
                    history_tokens = query_tokens + history_tokens
                else:
                    break

            # 2. Assemble the content according to the `template`.
            context_tokens = (
                system_tokens
                + history_tokens
                + nl_tokens
                + im_start_tokens
                # 3. Encode the converted `query` to obtain `token_ids`.
                + self._tokenize_str(self.user_role_name, query, nl_tokens)
                + im_end_tokens
                + nl_tokens
                + im_start_tokens
                + self.tokenizer.encode(self.system_role_name)
                + nl_tokens
            )

            # 7. Embed the encoded `name` and `shape` into `input_ids`, in the two
            #    positions that follow each image start token.
            context_tokens_tensor = torch.tensor(context_tokens)
            bos_pos = torch.where(torch.eq(context_tokens_tensor, self.image_start_id))[0]
            image_num = bos_pos.shape[0]
            for i in range(image_num):
                context_tokens[bos_pos[i] + 1] = shm_name_list[i]
                context_tokens[bos_pos[i] + 2] = shape_value_list[i]

            return context_tokens
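
To see what the token concatenation above produces, it can help to look at the same layout at the string level. The following sketch renders a conversation as ChatML text, assuming the usual `<|im_start|>` / `<|im_end|>` markers and assuming that `_tokenize_str` encodes the role, a newline, and then the content. The function `render_chatml` is hypothetical and handles only "text" items; media entries are skipped.

```python
from typing import Dict, List

IM_START, IM_END = "<|im_start|>", "<|im_end|>"


def render_chatml(conversation: List[Dict],
                  system: str = "You are a helpful assistant.") -> str:
    """Render a text-only conversation as a ChatML prompt string (illustrative)."""
    parts = [f"{IM_START}system\n{system}{IM_END}"]
    for message in conversation:
        # Concatenate the text items of this turn; non-text items are ignored here.
        text = "".join(item.get("text", "") for item in message["content"])
        parts.append(f"\n{IM_START}{message['role']}\n{text}{IM_END}")
    # Open the reply turn so that the model continues from here.
    parts.append(f"\n{IM_START}assistant\n")
    return "".join(parts)


prompt = render_chatml([{"role": "user", "content": [{"text": "Hi"}]}])
print(prompt)
```

Each turn is wrapped in the same start and end markers, turns are separated by a single newline, and the prompt ends with an opened assistant turn.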
    
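  4. Overriding get_input_builder() in the Router

    The service side obtains each model-side InputBuilder through the get_input_builder() interface of the model's Router, so be sure to override this function.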
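
The relationship between the Router and the InputBuilder subclass can be sketched as follows. Every class here is a simplified, hypothetical stand-in: the real InputBuilder and Router base classes in atb_llm have their own constructors and attributes, which are not reproduced, and `ToyTokenizer` exists only so the example runs on its own.

```python
from typing import Dict, List


class InputBuilder:
    """Hypothetical stand-in for the real base class."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def make_context(self, rank: int, conversation: List[Dict], **kwargs):
        raise NotImplementedError


class XXXInputBuilder(InputBuilder):
    """Model-specific builder that overrides make_context()."""

    def make_context(self, rank: int, conversation: List[Dict], **kwargs):
        # Toy behavior: join all text items and encode them.
        text = " ".join(
            item["text"]
            for message in conversation
            for item in message["content"]
            if "text" in item
        )
        return self.tokenizer.encode(text)


class XXXRouter:
    """Hypothetical stand-in for the model's Router."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def get_input_builder(self) -> InputBuilder:
        # Return the model-specific builder instead of a generic one.
        return XXXInputBuilder(self.tokenizer)


class ToyTokenizer:
    """Toy tokenizer: one token id per character."""

    def encode(self, text: str) -> List[int]:
        return [ord(ch) for ch in text]


router = XXXRouter(ToyTokenizer())
builder = router.get_input_builder()
token_ids = builder.make_context(0, [{"role": "user", "content": [{"text": "Hi"}]}])
print(token_ids)
```

Because the service side only ever asks the Router for a builder, returning the subclass from get_input_builder() is what causes the model-specific make_context() to be used.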