```
Tgi-MindIE
|______ cover
|       |______ models
|       |       |______ __init__.py
|       |______ cli.py
|       |______ server.py
|______ tgi_npu
|       |______ __init__.py
|       |______ cache_manager.py
|       |______ info.py
|       |______ mind_model.py
|       |______ token_mindie.py
|       |______ vlm_mind_models.py
|______ pyproject.toml
|______ install.sh
|______ README.md
```
The purpose of each source file is described in the table below:
| Source file | Purpose |
|---|---|
| `cover/models/__init__.py` | Replaces `server/text_generation_server/models/__init__.py` in the upstream repository; routes model inference to MindIE LLM. |
| `cover/cli.py` | Replaces `server/text_generation_server/cli.py` in the upstream repository; adds log filtering for the `tgi_npu` module. |
| `cover/server.py` | Replaces `server/text_generation_server/server.py` in the upstream repository; adds the `tgi_npu` entry point for multimodal models. |
| `tgi_npu/__init__.py` | Performs the initialization required for the NPU hardware environment. |
| `tgi_npu/cache_manager.py` | KV cache manager; mainly handles KV cache initialization for the NPU. |
| `tgi_npu/info.py` | NPU SoC information. |
| `tgi_npu/mind_model.py` | Defines the inference model entry class `MindModel` and its batch format `MindFlashCausalLMBatch`, which inherit from the upstream `FlashCausalLM` and `FlashCausalLMBatch` respectively. In `MindModel`, `generate_token` keeps most of the upstream code, adapted to the MindIE LLM call flow: `forward` now calls the `forward_tensor` method provided by MindIE LLM, and `warmup` is adjusted for NPU memory-access characteristics. |
| `tgi_npu/token_mindie.py` | Post-sampling code. |
| `tgi_npu/vlm_mind_models.py` | Defines the multimodal model entry class `VlmMindModel` and its batch format `VlmMindFlashCausalLMBatch`, which inherit from `MindModel` and `MindFlashCausalLMBatch` respectively. In `VlmMindFlashCausalLMBatch`, `batch_tokenized_inputs` uses the tokenize method provided by MindIE LLM to encode inputs into the format expected by multimodal models. |
| `pyproject.toml` | Packaging configuration for installing the adapter. |
Sample code:
`install.sh`:

```bash
#!/usr/bin/env bash
# install-origin
if [ -d "./tgi_origin" ]; then
    echo "./tgi_origin directory has already exist!"
    exit 1
fi

git clone -b v2.0.4 https://github.com/huggingface/text-generation-inference.git tgi_origin

cp cover/cli.py tgi_origin/server/text_generation_server/
cp cover/models/__init__.py tgi_origin/server/text_generation_server/models

sed -i "s/requires_padding, 16, window_size/requires_padding, 128, window_size/g" tgi_origin/router/src/infer.rs
sed -i "s/prefill_logprobs: true/prefill_logprobs: false/g" tgi_origin/router/client/src/client.rs
sed -i "s/bnb, accelerate, quantize, peft, outlines/accelerate, quantize, peft, outlines/g" tgi_origin/server/Makefile

cd tgi_origin && make install-server && make install-router && make install-launcher
cd .. && pip install -e .
```
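After installation, the adapted server is started through the usual TGI entry points. The snippet below is a minimal client sketch, not part of the adapter: it assumes a hypothetical local model path was served via `text-generation-launcher` with the router listening on port 8080, and queries TGI's standard `/generate` REST endpoint.

```python
# Minimal client sketch (hypothetical host/port and model path). Assumes the server
# was started in the usual TGI way, e.g.:
#   text-generation-launcher --model-id /path/to/your/model --port 8080
import requests

resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "What is deep learning?",
        "parameters": {"max_new_tokens": 64, "do_sample": False},
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["generated_text"])
```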
`cover/models/__init__.py`:

```python
# This file was copied from project[huggingface][text-generation-inference]
from typing import Optional

import torch
from loguru import logger

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "get_model",
]


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
    if speculate is not None:
        logger.warning("Speculate Decoding is not supported now!")
    set_speculate(0)

    try:
        import torch_npu
        from tgi_npu import MindModel

        npu_module_imported = True
    except (ImportError, NotImplementedError) as excp:
        npu_module_imported = False
        logger.error(f"Error catched: {str(excp)}")

    if npu_module_imported and torch.npu.is_available():
        return MindModel(model_id)
    else:
        logger.error("NPU enviroment error!!!!!!!!!!!!")
        raise ValueError("NPU enviroment error!!!!!!!!!!!!")
```
`cover/cli.py`:

```python
import os
import sys
from pathlib import Path
from typing import Optional
from enum import Enum

import typer
from loguru import logger
from huggingface_hub import hf_hub_download

app = typer.Typer()

MODEL_SUFFIX = ".safetensors"
CONFIG_FILENAME = "config.json"


class Quantization(str, Enum):
    bitsandbytes = "bitsandbytes"
    bitsandbytes_nf4 = "bitsandbytes-nf4"
    bitsandbytes_fp4 = "bitsandbytes-fp4"
    gptq = "gptq"
    awq = "awq"
    eetq = "eetq"
    fp8 = "fp8"


class Dtype(str, Enum):
    float16 = "float16"
    bloat16 = "bfloat16"


@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: Optional[Quantization] = None,
    speculate: Optional[int] = None,
    dtype: Optional[Dtype] = None,
    trust_remote_code: bool = False,
    uds_path: Path = "/tmp/text-generation-server",
    logger_level: str = "INFO",
    json_output: bool = False,
    otlp_endpoint: Optional[str] = None,
):
    if sharded:
        assert (
            os.getenv("RANK", None) is not None
        ), "RANK must be set when sharded is True"
        assert (
            os.getenv("WORLD_SIZE", None) is not None
        ), "WORLD_SIZE must be set when sharded is True"
        assert (
            os.getenv("MASTER_ADDR", None) is not None
        ), "MASTER_ADDR must be set when sharded is True"
        assert (
            os.getenv("MASTER_PORT", None) is not None
        ), "MASTER_PORT must be set when sharded is True"

    # Remove default handler
    logger.remove()
    logger.add(
        sys.stdout,
        format="{message}",
        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )
    logger.add(
        sys.stdout,
        format="{message}",
        filter="tgi_npu",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
        diagnose=False,
    )
```
`cover/server.py`:

```python
import asyncio
import os
import torch
import torch_npu
import time
import signal

from grpc import aio
from loguru import logger

from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional

from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model

try:
    from tgi_npu.vlm_mind_models import VlmMindFlashCausalLMBatch

    VLM_BATCH_TYPES = {VlmMindFlashCausalLMBatch}
except (ImportError, NotImplementedError):
    # These imports can fail on CPU/Non flash.
    VLM_BATCH_TYPES = set()

from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_model_id

soc_version = torch_npu._C._npu_get_soc_version()
if soc_version not in [104, 220, 221, 222, 223, 224]:
    logger.info("Some ops do not support in this soc !")
    option = {"NPU_FUZZY_COMPILE_BLACKLIST": "ReduceNansum"}
    torch.npu.set_option(option)
else:
    option = {"NPU_FUZZY_COMPILE_BLACKLIST": "GatherElements"}
    torch.npu.set_option(option)


class SignalHandler:
    KEEP_PROCESSING = True

    def __init__(self):
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.KEEP_PROCESSING = False


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
    def __init__(
        self,
        model: Model,
        cache: Cache,
        quantize: Optional[str],
        server_urls: List[str],
    ):
        self.cache = cache
        self.model = model
        self.quantize = quantize
        self.server_urls = server_urls
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        if model.device.type == "cuda" or model.device.type == "npu":
            # Force inference mode for the lifetime of TextGenerationService
            self._inference_mode_raii_guard = torch._C._InferenceMode(True)
        self.step = 0

    async def Info(self, request, context):
        return self.model.info

    async def Health(self, request, context):
        if self.model.device.type == "cuda":
            torch.zeros((2, 2)).cuda()
        return generate_pb2.HealthResponse()

    async def ServiceDiscovery(self, request, context):
        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)

    async def ClearCache(self, request, context):
        if request.HasField("id"):
            self.cache.delete(request.id)
        else:
            self.cache.clear()
        return generate_pb2.ClearCacheResponse()

    async def FilterBatch(self, request, context):
        batch = self.cache.pop(request.batch_id)
        if batch is None:
            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
        filtered_batch = batch.filter(request.request_ids)
        self.cache.set(filtered_batch)
        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

    async def Warmup(self, request, context):
        for i, r in enumerate(request.batch.requests):
            r.parameters.typical_p = 1.0
            r.prefill_logprobs = False
        if self.quantize == "gptq":
            try:
                # When using GPTQ, Exllama kernels need some global kernels
                # For which we have the finale shapes only after the model has loaded
                # This will allocate those buffers.
                from text_generation_server.layers.gptq import (
                    create_exllama_buffers,
                    set_device,
                )

                set_device(self.model.device)
                create_exllama_buffers(request.max_prefill_tokens)
            except ImportError:
                pass

        if (
            self.model.batch_type in VLM_BATCH_TYPES
        ):  # Hack, i would rather use kwargs in the `from_pb` call
            for i, r in enumerate(request.batch.requests):
                r.inputs = r.inputs.split('!')[0]
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
                self.model.tokenizer,
                self.model.tokenize,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )
        max_supported_total_tokens = self.model.warmup(batch)

        return generate_pb2.WarmupResponse(
            max_supported_total_tokens=max_supported_total_tokens
        )

    async def Prefill(self, request, context):
        start = time.time_ns()
        if (
            self.model.batch_type in VLM_BATCH_TYPES
        ):  # Hack, i would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
                self.model.tokenizer,
                self.model.tokenize,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )
        generations, next_batch, timings = self.model.generate_token(batch)
        self.cache.set(next_batch)

        return generate_pb2.PrefillResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
        )
```
`tgi_npu/__init__.py`:

```python
#!/usr/bin/env python3
# coding=utf-8
# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
import torch
import torch_npu
from loguru import logger

from tgi_npu.mind_models import MindModel


def init():
    torch._C._InferenceMode(True)
    soc_version = torch_npu._C._npu_get_soc_version()
    if soc_version not in [104, 220, 221, 222, 223, 224]:
        logger.info("Some op does not support for this soc!")
        option = {"NPU_FUZZY_COMPILE_BLACKLIST": "ReduceNansum"}
    else:
        option = {"NPU_FUZZY_COMPILE_BLACKLIST": "GatherElements"}
    try:
        torch.npu.set_option(option)
        logger.warning("Finish init for NPU device!")
    except Exception as e:
        logger.error(f"Failed to init for NPU device: {e}!")


init()
```
`tgi_npu/cache_manager.py`:

```python
# Part of codes in this file was copied from project[huggingface][text-generation-inference]
import math
from typing import Optional, List, Tuple
import gc

import torch

from tgi_npu.info import NPUSocInfo

BLOCK_SIZE: int = 128
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None


class CacheManager:
    def __init__(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        repeat_slots: bool,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.block_size = BLOCK_SIZE
        self.num_blocks = num_blocks
        self.repeat_slots = repeat_slots
        # for NZ/ND data format display
        self.need_nz = NPUSocInfo().need_nz

        if self.need_nz:
            self.kv_cache = [
                (
                    torch.empty(
                        (num_blocks, num_heads * head_size // 16, self.block_size, 16),
                        dtype=dtype,
                        device=device,
                    ),
                    torch.empty(
                        (num_blocks, num_heads * head_size // 16, self.block_size, 16),
                        dtype=dtype,
                        device=device,
                    ),
                )
                for _ in range(num_layers)
            ]
        else:
            self.kv_cache = [
                (
                    torch.empty(
                        (num_blocks, self.block_size, num_heads, head_size),
                        dtype=dtype,
                        device=device,
                    ),
                    torch.empty(
                        (num_blocks, self.block_size, num_heads, head_size),
                        dtype=dtype,
                        device=device,
                    ),
                )
                for _ in range(num_layers)
            ]
        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
        self.slots = torch.arange(
            0, num_blocks * self.block_size, dtype=torch.int64
        ).view(num_blocks, self.block_size)

    def __repr__(self):
        return (f"CacheManager: "
                f"num_blocks={self.num_blocks},"
                f"block_size={self.block_size},"
                f"free_block_mask={self.free_block_mask},"
                f"slots={self.slots},"
                f"k_cache shape={self.kv_cache[0][0].shape},"
                f"v_cache shape={self.kv_cache[0][1].shape}")

    def allocate(
        self,
        needed_blocks_slots: List[Tuple[int, int]],
        blocks: int,
        max_blocks: int,
        device: torch.device,
    ):
        # Get free blocks indices by finding values in mask that are not set to 0
        free_block_indices = self.free_block_mask.nonzero()
        if blocks > len(free_block_indices):
            raise RuntimeError(
                f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
            )

        # Slice by the number of required blocks
        block_indices = free_block_indices[:blocks]
        block_indices = block_indices.flatten()

        # Padded block tables
        block_tables_tensor = torch.zeros(
            (len(needed_blocks_slots), max_blocks), dtype=torch.int32
        )

        # Allocate paged attention blocks
        cumulative_blocks = 0
        slots = []
        block_tables = []
        for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots):
            # Get allocated blocks for this sequence
            allocated_blocks = block_indices[
                cumulative_blocks: cumulative_blocks + needed_blocks
            ]
            # Get slots for the allocated blocks
            all_slots = self.slots[allocated_blocks].flatten()

            # Repeat slots in the case of context sliding window
            if needed_slots > len(all_slots) and self.repeat_slots:
                repeats = math.ceil(needed_slots / len(all_slots))
                all_slots = all_slots.repeat(repeats)

            allocated_slots = all_slots[:needed_slots]

            slots.append(allocated_slots)
            block_tables.append(allocated_blocks.tolist())
            block_tables_tensor[i, :needed_blocks] = allocated_blocks
            cumulative_blocks += needed_blocks

        block_tables = block_tables
        block_tables_tensor = block_tables_tensor.to(device)
        slots = torch.concat(slots).to(device)

        # Allocate the required number of blocks by setting the mask to 0
        self.free_block_mask[block_indices] = 0

        return block_tables, block_tables_tensor, slots

    def free(self, block_indices: Optional[List[int]]):
        if block_indices is not None and block_indices:
            # Reset mask
            self.free_block_mask[block_indices] = 1


def set_cache_manager(
    num_blocks: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    repeat_slots: bool,
    dtype: torch.dtype,
    device: torch.device,
) -> CacheManager:
    global CACHE_MANAGER
    if CACHE_MANAGER is not None:
        del CACHE_MANAGER
        torch.npu.empty_cache()
        gc.collect()

    CACHE_MANAGER = CacheManager(
        num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
    )
    return CACHE_MANAGER


def get_cache_manager() -> CacheManager:
    global CACHE_MANAGER
    if CACHE_MANAGER is None:
        raise RuntimeError("cache manager was not initialized")

    return CACHE_MANAGER
```
`tgi_npu/info.py`:

```python
from dataclasses import dataclass

import torch_npu


@dataclass
class NPUSocInfo:
    soc_name: str = ""
    soc_version: int = -1
    need_nz: bool = False

    def __post_init__(self):
        self.soc_version = torch_npu._C._npu_get_soc_version()
        if self.soc_version in (100, 101, 102, 103, 104, 200, 201, 202, 203):
            self.need_nz = True
```
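For reference, a small illustration of how the `need_nz` flag is typically consumed; it simply mirrors the layout branch already shown in `tgi_npu/cache_manager.py` and introduces no new behaviour.

```python
# Illustration only: choose the KV-cache block layout based on NPUSocInfo.need_nz,
# mirroring the branch in tgi_npu/cache_manager.py (NZ vs. ND data format).
from tgi_npu.info import NPUSocInfo

soc = NPUSocInfo()
if soc.need_nz:
    # NZ layout: (num_blocks, num_heads * head_size // 16, block_size, 16)
    print(f"soc_version={soc.soc_version}: using NZ KV-cache layout")
else:
    # ND layout: (num_blocks, block_size, num_heads, head_size)
    print(f"soc_version={soc.soc_version}: using ND KV-cache layout")
```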
`tgi_npu/mind_model.py`:

```python
# Part of codes in this file was copied from project[huggingface][text-generation-inference]
import math
import time
import os
from typing import Optional, Tuple, List, Type
from dataclasses import dataclass

import torch_npu
import torch
from loguru import logger
from opentelemetry import trace
import numpy as np

from text_generation_server.models.flash_causal_lm import FlashCausalLM, FlashCausalLMBatch
from text_generation_server.models.types import (
    Batch,
    Tokens,
    Generation,
    GeneratedText
)
from text_generation_server.utils import StoppingCriteria
from text_generation_server.pb import generate_pb2
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils.dist import RANK, MEMORY_FRACTION
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import cache_manager as tgi_cache_manager
from transformers import PreTrainedTokenizerBase

from mindie_llm.text_generator.adapter.generator_torch import GeneratorTorch
from mindie_llm.modeling.backend_type import BackendType

from tgi_npu.tokens_mindie import MindIELLMHeterogeneousNextTokenChooser
from tgi_npu.cache_manager import (
    BLOCK_SIZE,
    get_cache_manager,
    set_cache_manager,
)

tracer = trace.get_tracer(__name__)


@dataclass
class MindFlashCausalLMBatch(FlashCausalLMBatch):
    next_token_chooser: MindIELLMHeterogeneousNextTokenChooser
    all_input_ids_tensor: torch.Tensor

    def __repr__(self):
        return (f"MindFlashCausalLMBatch: batch_id={self.batch_id},"
                f"requests_idx_mapping={self.requests_idx_mapping},"
                f"input_ids={self.input_ids},"
                f"position_ids={self.position_ids},"
                f"cu_seqlen_prefill={self.cu_seqlen_prefill},"
                f"start_slots={self.start_slots},"
                f"slot_indices={self.slot_indices},"
                f"needed_blocks_slots={self.needed_blocks_slots},"
                f"block_tables={self.block_tables},"
                f"block_tables_tensor={self.block_tables_tensor},"
                f"slots={self.slots},"
                f"max_seqlen={self.max_seqlen},"
                f"prefill_head_indices={self.prefill_head_indices},"
                f"prefill_next_token_indices={self.prefill_next_token_indices},"
                f"prefill_cu_outlens={self.prefill_cu_outlens},"
                f"input_lengths={self.input_lengths},"
                f"input_lengths_tensor={self.input_lengths_tensor},"
                f"prefix_offsets={self.prefix_offsets},"
                f"read_offsets={self.read_offsets},"
                f"all_input_ids_tensor={self.all_input_ids_tensor},"
                f"next_token_chooser={self.next_token_chooser},"
                f"stopping_criterias={self.stopping_criterias},"
                f"blocks={self.blocks},"
                f"max_blocks={self.max_blocks}")

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device
    ) -> "MindFlashCausalLMBatch":
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)

    @classmethod
    def from_tokenized(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        batch_tokenized_inputs,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "MindFlashCausalLMBatch":
        position_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate:]
            if (
                tokenized_input[0] == tokenizer.bos_token_id
                and tokenized_input[1] == tokenizer.bos_token_id
            ):
                tokenized_input = tokenized_input[1:]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token des not have a past
            speculative_length = get_speculate()
            speculative_length = 0 if speculative_length is None else speculative_length
            total_tokens = input_length + max_new_tokens - 1 + speculative_length
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int64
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(
                max_length, input_length + max_new_tokens + speculative_length
            )

        next_token_chooser = MindIELLMHeterogeneousNextTokenChooser.from_pb(
            pb=next_token_chooser_parameters, dtype=dtype, device=device
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int64
        )
        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int64, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.tensor(
                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=None,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["MindFlashCausalLMBatch"]) -> "MindFlashCausalLMBatch":
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

        blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
            total_batch_size,
        )

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        fsm_grammar_states = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
            stopping_criterias.extend(batch.stopping_criterias)

            top_n_tokens.extend(batch.top_n_tokens)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = MindIELLMHeterogeneousNextTokenChooser.from_pb(
            pb=next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device
        )

        speculative_ids = (
            torch.cat([b.speculative_ids for b in batches], dim=0)
            if batches[0].speculative_ids is not None
            else None
        )

        # Needed to avoid dropping blocks when the batches will go out of scope
        for b in batches:
            b.block_tables = None
            del b

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            speculative_ids=speculative_ids,
        )

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.blocks * BLOCK_SIZE,
        )


class MindModel(FlashCausalLM):
    def __init__(
        self,
        model_id: str
    ):
        logger.warning("Initialize mindie-llm model.")
        rank = int(os.getenv("RANK", "0"))
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        model_config = {
            'backend_type': BackendType.ATB,
            'rank': rank,
            'world_size': world_size,
            'model_id': model_id,
            'num_threads': 8,
            'local_rank': rank,
            'npu_device_id': rank
        }
        self.model_runner = GeneratorTorch(model_config)
        super(MindModel, self).__init__(
            model=self.model_runner.model_wrapper.model_runner.model,
            tokenizer=self.model_runner.tokenizer,
            num_layers=self.model_runner.model_info.num_layers,
            num_kv_heads=self.model_runner.model_info.num_kv_heads,
            head_size=self.model_runner.model_info.head_size,
            dtype=self.model_runner.model_info.dtype,
            device=self.model_runner.model_info.device,
            rank=self.model_runner.rank,
            world_size=self.model_runner.world_size,
        )
        logger.warning("MindModel from tgi_npu initialized.")

    def __del__(self):
        del self.model_runner.model_wrapper

    @property
    def batch_type(self) -> Type[MindFlashCausalLMBatch]:
        return MindFlashCausalLMBatch

    def forward(
        self,
        batch: MindFlashCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Assume return logits, speculative_logits"""
        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = get_cache_manager().kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        max_s = batch.max_seqlen
        lm_head_indices = batch.prefill_head_indices
        speculative_logits = None

        return self.model_runner.forward_tensor(
            input_ids=input_ids,
            position_ids=position_ids,
            is_prefill=cu_seqlen_prefill is not None,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_seq_len=max_s,
            lm_head_indices=lm_head_indices,
        ), speculative_logits

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: MindFlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[MindFlashCausalLMBatch], Tuple[int, int]]:
        start = time.time_ns()
        prefill = batch.cu_seqlen_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None

        # check if need slots
        if batch.needed_blocks_slots:
            # Allocate blocks to this batch
            block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
                batch.needed_blocks_slots,
                batch.blocks,
                batch.max_blocks,
                batch.input_ids.device,
            )
            batch.needed_blocks_slots = None
            batch.block_tables = block_tables
            batch.block_tables_tensor = block_tables_tensor
            batch.slots = slots

        try:
            out, speculative_logits = self.forward(batch)
        except Exception as e:
            del batch
            raise e

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
                speculative_logits = (
                    speculative_logits[batch.prefill_next_token_indices]
                    if prefill_logprobs
                    else speculative_logits
                )
        else:
            logger.debug(f"Decode batch size {batch.input_ids.shape[0]}")
            next_token_logits = out

        speculate = get_speculate()
        request_ids = [req.id for req in batch.requests]
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            request_ids,
            prefill,
            batch.all_input_ids_tensor[:, : batch.max_seqlen],
            next_token_logits,
            speculate
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
            # We do not need cu_seqlen_prefill anymore
            batch.cu_seqlen_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        index = 0
        for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[out_start_index: out_end_index - 1] = (
                            batch.input_ids[start_index + 1: start_index + out_length]
                        )
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1: start_index + out_length
                        ]

            for _ in range(n_accepted_ids):
                index += 1

            cumulative_length += input_length

        batch.all_input_ids_tensor.scatter_(
            1,
            batch.input_lengths_tensor.view(batch.input_lengths_tensor.shape[0], 1),
            next_input_ids.view(next_input_ids.shape[0], 1)
        )

        # Update values
        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
        batch.speculative_ids = speculative_ids
        batch.position_ids = next_position_ids + accepted_ids
        batch.input_lengths_tensor += accepted_ids
        batch.slot_indices += accepted_ids

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
        accepted_ids = accepted_ids.tolist()
        start_decode = time.time_ns()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        index = 0
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Append next token to all tokens
            next_token_texts = []
            left = 0

            if n_accepted_ids > 1:
                if RANK == 0:
                    logger.debug(f"Speculated ids {n_accepted_ids - 1}")

            current_stopped = False
            for j in range(index, index + n_accepted_ids):
                # Generated token
                next_token_id = next_token_ids[j]
                all_input_ids.append(next_token_id)

                # Generated token
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids,
                    prefix_offset,
                    read_offset,
                )
                next_token_texts.append(next_token_text)

                stop, reason = stopping_criteria(
                    next_token_id,
                    next_token_text,
                )

                if stop:
                    left = index + n_accepted_ids - j - 1
                    current_stopped = True
                    break
                else:
                    current_stopped = False
            stopped = stopped and current_stopped

            _next_token_ids = next_token_ids[index: index + n_accepted_ids - left]
            _next_token_logprobs = next_token_logprobs[
                index: index + n_accepted_ids - left
            ]
            index += n_accepted_ids

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
                        all_input_ids,
                        prefix_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens
                        - 1,
                        read_offset=len(all_input_ids)
                        - stopping_criteria.current_tokens,
                        skip_special_tokens=True,
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index: out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )

                    prefill_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                else:
                    prefill_tokens = None

                if top_n_tokens > 0:
                    all_top_tokens = []
                    for top_token_ids, top_token_logprobs in zip(
                        top_token_ids, top_token_logprobs
                    ):
                        toptoken_texts = self.tokenizer.batch_decode(
                            top_token_ids,
                            clean_up_tokenization_spaces=False,
                            skip_special_tokens=False,
                        )
                        special_toptokens = [
                            token_id in self.all_special_ids
                            for token_id in top_token_ids
                        ]
                        top_tokens = Tokens(
                            top_token_ids,
                            top_token_logprobs,
                            toptoken_texts,
                            special_toptokens,
                        )
                        all_top_tokens.append(top_tokens)
                    top_tokens = all_top_tokens
                else:
                    top_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    Tokens(
                        _next_token_ids,
                        _next_token_logprobs,
                        next_token_texts,
                        [nid in self.all_special_ids for nid in _next_token_ids],
                    ),
                    generated_text,
                    top_tokens,
                )

                generations.append(generation)

            # Update values
            batch.input_lengths[i] = input_length + n_accepted_ids
            if batch.input_lengths[i] > batch.max_seqlen:
                batch.max_seqlen = batch.input_lengths[i]
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            del batch
            # No need to return a batch if we know that all requests stopped
            forward_ns = start_decode - start
            decode_ns = time.time_ns() - start_decode
            none_batch = None
            return generations, none_batch, (forward_ns, decode_ns)

        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)

    def warmup(self, batch: MindFlashCausalLMBatch):
        # The warmup batch is the biggest batch we could ever receive
        torch.npu.empty_cache()
        peak_memory = torch_npu.npu.max_memory_allocated()
        logger.info(f">>>>before warmup peak_memory {peak_memory}")
        try:
            cache_manager = set_cache_manager(
                batch.blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                False,
                self.dtype,
                self.device,
            )
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e

        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

        torch_npu.npu.synchronize()
        total_gpu_memory = torch_npu.npu.get_device_properties(self.device).total_memory
        peak_memory = torch_npu.npu.max_memory_allocated()
        logger.info(
            f">>>>dtype_size {dtype_size}, cache_block_size {cache_block_size}, num_kv_heads {self.num_kv_heads}, "
            f"total_cache_size {total_cache_size}, peak_memory {peak_memory}")
        total_free_memory = total_gpu_memory - peak_memory
        logger.info(f">>>>total_free_memory {total_free_memory}, total_gpu_memory {total_gpu_memory}, "
                    f"MEMORY_FRACTION {MEMORY_FRACTION}")
        free_memory = max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory)

        num_blocks = (
            # Leave 5% for some wiggle room
            int((free_memory * 0.95) // total_cache_size)
            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
            + cache_manager.num_blocks
        )

        del batch
        del cache_manager

        real_manager = set_cache_manager(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.sliding_window is not None,
            self.dtype,
            self.device,
        )
        tgi_cache_manager.CACHE_MANAGER = real_manager
        logger.warning(f">>>>real CacheManger {get_cache_manager()}")
        peak_memory = torch_npu.npu.max_memory_allocated()
        logger.warning(f">>>>end warmup peak_memory {peak_memory}")
        logger.warning(f"Warmup return {int(num_blocks * BLOCK_SIZE)}")
        return int(num_blocks * BLOCK_SIZE)
```
`tgi_npu/token_mindie.py`:

```python
# Part of codes in this file was copied from project[huggingface][text-generation-inference]
import os
from typing import List, Optional

from loguru import logger
import numpy as np
import torch

from text_generation_server.pb import generate_pb2
from mindie_llm.text_generator.utils.sampling_metadata import SamplingData, SamplingParam
from mindie_llm.text_generator.utils.config import SamplerConfig
from mindie_llm.text_generator.samplers.sampler import Sampler
from mindie_llm.modeling.backend_type import BackendType

WRAPPER_KEY = "tensor_wrapper"

try:
    from mindie_llm.text_generator.utils.sampling_metadata import TensorWrapper
except ImportError:
    class TensorWrapper:
        def __init__(self, backend, device):
            self.device = device
            self.backend = backend

        def __call__(self, data):
            if data.dtype == np.int32:
                dtype = torch.int32
            elif data.dtype == np.bool_:
                dtype = torch.bool
            else:
                dtype = None
            return torch.tensor(data, dtype=dtype, device=self.device)

    WRAPPER_KEY = "to_tensor"


def do_filter(sample_param: List, indices):
    if any(sample_param):
        return [sample_param[i] for i in indices]
    return sample_param


class MindIELLMHeterogeneousNextTokenChooser:
    def __init__(
        self,
        dtype: torch.dtype,
        device: torch.device,
        watermark: List[bool],
        temperature: List[float],
        repetition_penalty: List[float],
        frequency_penalty: List[float],
        top_k: List[int],
        top_p: List[float],
        typical_p: List[float],
        do_sample: List[bool],
        seeds: List[int],
        grammars: List[str],
        grammar_types: List[int],
        fsm_grammar_states: List[int],
        sample_method,  # mindie-llm sampling method
    ):
        if any(watermark):
            logger.warning(f"Watermark not supported now in mindie-llm")
        if any([x < 1.0 for x in typical_p]):
            logger.warning(f"Typical_p not supported now in mindie-llm")
        if any(grammar_types) or any(grammars) or any(fsm_grammar_states):
            logger.warning(f"Grammar not supported now in mindie-llm")

        self.tensor_wrapper = TensorWrapper(BackendType.ATB, device)
        self.wrapper_dict = {WRAPPER_KEY: TensorWrapper(BackendType.ATB, device)}
        self.sample_params = SamplingParam.from_numpy(
            repetition_penalty=np.array(repetition_penalty, dtype=np.float16),
            presence_penalty=None,
            frequency_penalty=np.array(frequency_penalty, dtype=np.float16),
            temperature=np.array(temperature, dtype=np.float16),
            top_k=np.array(top_k),
            top_p=np.array(top_p),
            seed=np.array(seeds).astype(np.int32),
            do_sample=np.array(do_sample),
            **self.wrapper_dict
        )

        # Temp store for filter
        self.temperature = temperature
        self.repetition_penalty = repetition_penalty
        self.frequency_penalty = frequency_penalty
        self.top_k = top_k
        self.top_p = top_p
        self.seeds = seeds
        self.do_sample = do_sample

        self.choice = sample_method
        self.dtype = dtype
        self.device = device
        self.seeds = seeds
        self.do_sample = self.sample_params.do_sample_meta.do_sample_array
        self.fsm_grammar_states = fsm_grammar_states

    def __call__(
        self,
        request_ids: List,
        is_prefill: bool,
        input_ids: torch.Tensor,
        scores: torch.Tensor,
        speculate: int
    ):
        batch_size = scores.shape[0]
        speculate_size = 1
        scores = scores.view(batch_size, speculate_size, -1)
        # Don't use SamplingData.from_numpy to avoid tensor.cpu() to transfer large data
        input_ids_int32 = input_ids.to(torch.int32)
        sample_data = SamplingData(all_input_ids=input_ids_int32, output_ids=input_ids_int32,
                                   is_prefill=is_prefill, request_ids=np.array(request_ids))

        next_ids = torch.zeros((batch_size, speculate_size), device=scores.device, dtype=torch.long)
        for j in range(speculate_size):
            _scores = scores[:, j]
            batch_logits, _next_ids = self.choice(batch_logits=_scores, batch_sampling_data=sample_data,
                                                  batch_sampling_params=self.sample_params)
            scores[:, j] = _scores
            next_ids[:, j] = torch.from_numpy(_next_ids)

        next_ids = next_ids.view(batch_size * speculate_size)
        allscores = scores.view(batch_size * speculate_size, -1)
        alllogprobs = torch.log_softmax(allscores, -1)

        accepted_ids = torch.ones_like(next_ids)
        logprobs = alllogprobs
        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
        speculative_ids = None

        return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids

    @classmethod
    def from_pb(
        cls,
        pb: List[generate_pb2.NextTokenChooserParameters],
        dtype: torch.dtype,
        device: torch.device,
        fsm_grammar_states: Optional[List[int]] = None,
    ) -> "MindIELLMHeterogeneousNextTokenChooser":
        curr_rank = int(os.getenv("RANK", "0"))
        sample_method = Sampler(SamplerConfig(rank=curr_rank, backend_type=BackendType.ATB, npu_id=curr_rank))
        return MindIELLMHeterogeneousNextTokenChooser(
            watermark=[pb_.watermark for pb_ in pb],
            temperature=[pb_.temperature for pb_ in pb],
            repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
            frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
            top_k=[pb_.top_k for pb_ in pb],
            top_p=[pb_.top_p for pb_ in pb],
            typical_p=[pb_.typical_p for pb_ in pb],
            do_sample=[pb_.do_sample for pb_ in pb],
            seeds=[pb_.seed for pb_ in pb],
            device=device,
            dtype=dtype,
            grammars=[pb_.grammar for pb_ in pb],
            grammar_types=[pb_.grammar_type for pb_ in pb],
            fsm_grammar_states=(
                fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
            ),
            sample_method=sample_method
        )

    def filter(self, indices):
        self.repetition_penalty = do_filter(self.repetition_penalty, indices)
        self.frequency_penalty = do_filter(self.frequency_penalty, indices)
        self.temperature = do_filter(self.temperature, indices)
        self.top_k = do_filter(self.top_k, indices)
        self.top_p = do_filter(self.top_p, indices)
        self.seeds = do_filter(self.seeds, indices)
        self.do_sample = do_filter(self.do_sample, indices)

        self.sample_params = SamplingParam.from_numpy(
            repetition_penalty=np.array(self.repetition_penalty, dtype=np.float16),
            presence_penalty=None,
            frequency_penalty=np.array(self.frequency_penalty, dtype=np.float16),
            temperature=np.array(self.temperature, dtype=np.float16),
            top_k=np.array(self.top_k),
            top_p=np.array(self.top_p),
            seed=np.array(self.seeds).astype(np.int32),
            do_sample=np.array(self.do_sample),
            **self.wrapper_dict
        )

        return self
```
`tgi_npu/vlm_mind_models.py`:

```python
import re
from typing import List, Type, Dict
from dataclasses import dataclass

import torch
from loguru import logger
from atb_llm.runner.tokenizer_wrapper import TokenizerWrapper
from text_generation_server.pb import generate_pb2
from transformers import PreTrainedTokenizerBase

from tgi_npu.mind_models import MindFlashCausalLMBatch, MindModel

IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")


def split(string) -> List[Dict[str, str]]:
    parts = []
    cursor = 0
    for pattern in IMAGES.finditer(string):
        start = pattern.start()
        if start != cursor:
            parts.append({"text": string[cursor:start]})

        parts.append({"image": pattern.group(1)})
        cursor = pattern.end()

    if cursor != len(string):
        parts.append({"text": string[cursor:]})

    return parts


@dataclass
class VlmMindFlashCausalLMBatch(MindFlashCausalLMBatch):
    @classmethod
    def batch_tokenized_inputs(cls, requests, tokenize):
        inputs = []
        for r in requests:
            splits = split(r.inputs)
            single_input = tokenize(splits).tolist()
            inputs.append(single_input)
        return inputs

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        tokenize,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "VlmMindFlashCausalLMBatch":
        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenize)
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
        return batch


class VlmMindModel(MindModel):
    def __init__(
        self,
        model_id: str
    ):
        logger.warning("Initialize mindie-llm model for vlm.")
        super(VlmMindModel, self).__init__(model_id)
        self.tokenize = TokenizerWrapper(model_id).tokenize
        logger.warning("VlmMindModel from tgi_npu initialized.")

    @property
    def batch_type(self) -> Type[VlmMindFlashCausalLMBatch]:
        return VlmMindFlashCausalLMBatch
```
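To make the multimodal input format concrete, here is a small usage sketch of the `split` helper shown above; the prompt string and image path are made up for illustration only.

```python
# Usage sketch (made-up prompt and image path): markdown-style image references in the
# request text are extracted as {"image": ...} parts and the surrounding text as
# {"text": ...} parts, before being passed to the MindIE LLM tokenize method.
from tgi_npu.vlm_mind_models import split

parts = split("Describe this picture: ![](/data/images/cat.jpg) in one sentence.")
# parts == [
#     {"text": "Describe this picture: "},
#     {"image": "/data/images/cat.jpg"},
#     {"text": " in one sentence."},
# ]
```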
`pyproject.toml`:

```toml
[tool.poetry]
name = "tgi-npu"
version = "0.1.0"
description = "NPU MindIE Adapter for TGI v2.0.4"
authors = ["Your Name <you@example.com>"]
readme = "README.md"
exclude = ["router", "cover"]

[tool.poetry.dependencies]
python = "^3.10"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
```