
Reference Adaptation Code for TGI v2.0.4

The file directory structure is as follows.
Tgi-MindIE
 |______cover
        |______models
               |______ __init__.py
        |______cli.py
        |______server.py
 |______tgi_npu
        |______ __init__.py
        |______cache_manager.py
        |______info.py
        |______mind_models.py
        |______tokens_mindie.py
        |______vlm_mind_models.py
 |______pyproject.toml
 |______install.sh
 |______requirements_qwen_vl.txt

The purpose of each source file is described in Table 1.

Table 1 Source files and their purposes

  • cover/models/__init__.py — Replaces server/text_generation_server/models/__init__.py in the original repository; routes inference models to MindIE LLM.
  • cover/cli.py — Replaces server/text_generation_server/cli.py in the original repository; adds log filtering for the tgi_npu module.
  • cover/server.py — (Qwen-VL adaptation) Replaces server/text_generation_server/server.py in the original repository; adds the tgi_npu entry point for multimodal models.
  • tgi_npu/__init__.py — Performs the initialization required for the NPU environment.
  • tgi_npu/cache_manager.py — KV cache manager; initializes the KV cache for the NPU.
  • tgi_npu/info.py — NPU device information.
  • tgi_npu/mind_models.py — Defines the inference model entry class MindModel and its data exchange format MindFlashCausalLMBatch, which inherit from the original repository's FlashCausalLM and FlashCausalLMBatch respectively. In MindModel, generate_token keeps most of the original code but is adjusted to the MindIE LLM call flow: forward now calls the forward_tensor method provided by MindIE LLM, and warmup is modified for the NPU's memory-access characteristics.
  • tgi_npu/tokens_mindie.py — Post-processing (token sampling) code.
  • tgi_npu/vlm_mind_models.py — (Qwen-VL adaptation) Defines the multimodal model entry class VlmMindModel and its data exchange format VlmMindFlashCausalLMBatch, which inherit from MindModel and MindFlashCausalLMBatch respectively. In VlmMindFlashCausalLMBatch, batch_tokenized_inputs uses the tokenize method provided by MindIE LLM to encode the inputs into the format required by the multimodal model.
  • pyproject.toml — Configuration file for packaging the adaptation.
  • install.sh — One-click installation script.
  • requirements_qwen_vl.txt — (Qwen-VL adaptation) The transformers version required by the Qwen-VL model.

Sample code:

  • Tgi-MindIE/install.sh
    #!/usr/bin/env bash
    # install-origin
    if [ -d "./tgi_origin" ]; then 
        echo "./tgi_origin directory already exists!"
        exit 1
    fi
    
    git clone -b v2.0.4 https://github.com/huggingface/text-generation-inference.git tgi_origin
    
    cp cover/cli.py tgi_origin/server/text_generation_server/
    # Uncomment this line when running the Qwen-VL model
    # cp cover/server.py tgi_origin/server/text_generation_server/
    cp cover/models/__init__.py tgi_origin/server/text_generation_server/models
    sed -i "s/requires_padding, 16, window_size/requires_padding, 128, window_size/g" tgi_origin/router/src/infer.rs
    sed -i "s/prefill_logprobs: true/prefill_logprobs: false/g" tgi_origin/router/client/src/client.rs
    sed -i "s/bnb, accelerate, quantize, peft, outlines/accelerate, quantize, peft, outlines/g" tgi_origin/server/Makefile
    
    cd tgi_origin && make install-server && make install-router && make install-launcher
    
    cd .. && pip install -e .
    # Uncomment this line when running the Qwen-VL model
    # pip install -r requirements_qwen_vl.txt
  • Tgi-MindIE/requirements_qwen_vl.txt
    transformers==4.30.2
  • MindIE LLM integration: switch the original TGI framework from calling GPU models to calling MindIE LLM
    Tgi-MindIE/cover/models/__init__.py
    # This file was copied from project[huggingface][text-generation-inference]
    
    from typing import Optional
    
    import torch
    from loguru import logger
    
    from transformers.configuration_utils import PretrainedConfig
    from text_generation_server.utils.speculate import get_speculate, set_speculate
    from text_generation_server.models.model import Model
    
    # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
    # in PyTorch 1.12 and later.
    torch.backends.cuda.matmul.allow_tf32 = True
    
    # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
    torch.backends.cudnn.allow_tf32 = True
    
    # Disable gradients
    torch.set_grad_enabled(False)
    
    __all__ = [
        "Model",
        "get_model",
    ]
    
    
    def get_model(
        model_id: str,
        revision: Optional[str],
        sharded: bool,
        quantize: Optional[str],
        speculate: Optional[int],
        dtype: Optional[str],
        trust_remote_code: bool,
    ) -> Model:
        if speculate is not None:
            logger.warning("Speculate Decoding is not supported now!")
        set_speculate(0)
        
        # 1. Import the TGI adaptation package for NPU, tgi_npu
        try:
            import torch_npu
            from tgi_npu import MindModel
            # for Qwen-VL
            # from tgi_npu import VlmMindModel
            npu_module_imported = True
        except (ImportError, NotImplementedError) as excp:
            npu_module_imported = False
            logger.error(f"Error catched: {str(excp)}")
    
        if npu_module_imported and torch.npu.is_available():
            config_dict, _ = PretrainedConfig.get_config_dict(
                model_id, revision=revision, trust_remote_code=trust_remote_code)
            model_type = config_dict.get("model_type", None)
            # 2. Multimodal support covers the Qwen-VL model; text model support follows MindIE LLM
            # for Qwen-VL
            # if model_type == 'qwen' and 'visual' in config_dict:
            #     return VlmMindModel(model_id=model_id, trust_remote_code=trust_remote_code)
            return MindModel(model_id=model_id, trust_remote_code=trust_remote_code)
        else:
            logger.error("NPU enviroment error!!!!!!!!!!!!")
            raise ValueError("NPU enviroment error!!!!!!!!!!!!")
  • Tgi-MindIE/cover/cli.py
    import os
    import sys
    from pathlib import Path
    from typing import Optional
    from enum import Enum
    import typer
    
    from loguru import logger
    
    from huggingface_hub import hf_hub_download
    
    app = typer.Typer()
    
    MODEL_SUFFIX = ".safetensors"
    CONFIG_FILENAME = "config.json"
    
    
    class Quantization(str, Enum):
        bitsandbytes = "bitsandbytes"
        bitsandbytes_nf4 = "bitsandbytes-nf4"
        bitsandbytes_fp4 = "bitsandbytes-fp4"
        gptq = "gptq"
        awq = "awq"
        eetq = "eetq"
        fp8 = "fp8"
    
    
    class Dtype(str, Enum):
        float16 = "float16"
        bloat16 = "bfloat16"
    
    
    @app.command()
    def serve(
            model_id: str,
            revision: Optional[str] = None,
            sharded: bool = False,
            quantize: Optional[Quantization] = None,
            speculate: Optional[int] = None,
            dtype: Optional[Dtype] = None,
            trust_remote_code: bool = False,
            uds_path: Path = "/tmp/text-generation-server",
            logger_level: str = "INFO",
            json_output: bool = False,
            otlp_endpoint: Optional[str] = None,
    ):
        if sharded:
            assert (
                    os.getenv("RANK", None) is not None
            ), "RANK must be set when sharded is True"
            assert (
                    os.getenv("WORLD_SIZE", None) is not None
            ), "WORLD_SIZE must be set when sharded is True"
            assert (
                    os.getenv("MASTER_ADDR", None) is not None
            ), "MASTER_ADDR must be set when sharded is True"
            assert (
                    os.getenv("MASTER_PORT", None) is not None
            ), "MASTER_PORT must be set when sharded is True"
    
        # Remove default handler
        logger.remove()
        logger.add(
            sys.stdout,
            format="{message}",
            filter="text_generation_server",
            level=logger_level,
            serialize=json_output,
            backtrace=True,
            diagnose=False,
        )
        logger.add(
            sys.stdout,
            format="{message}",
            filter="tgi_npu",
            level=logger_level,
            serialize=json_output,
            backtrace=True,
            diagnose=False,
        )
    
        # Import here after the logger is added to log potential import exceptions
        from text_generation_server import server
        from text_generation_server.tracing import setup_tracing
    
        # Setup OpenTelemetry distributed tracing
        if otlp_endpoint is not None:
            setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
    
        # Downgrade enum into str for easier management later on
        quantize = None if quantize is None else quantize.value
        dtype = None if dtype is None else dtype.value
        if dtype is not None and quantize not in {
            None,
            "bitsandbytes",
            "bitsandbytes-nf4",
            "bitsandbytes-fp4",
        }:
            raise RuntimeError(
                "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
            )
        server.serve(
            model_id,
            revision,
            sharded,
            quantize,
            speculate,
            dtype,
            trust_remote_code,
            uds_path,
        )
    
    
    @app.command()
    def download_weights(
            model_id: str,
            revision: Optional[str] = None,
            extension: str = MODEL_SUFFIX,
            auto_convert: bool = True,
            logger_level: str = "INFO",
            json_output: bool = False,
            trust_remote_code: bool = False,
    ):
        # Remove default handler
        logger.remove()
        logger.add(
            sys.stdout,
            format="{message}",
            filter="text_generation_server",
            level=logger_level,
            serialize=json_output,
            backtrace=True,
            diagnose=False,
        )
    
        # Import here after the logger is added to log potential import exceptions
        from text_generation_server import utils
    
        # Test if files were already downloaded
        try:
            utils.weight_files(model_id, revision, extension)
            logger.info("Files are already present on the host. " "Skipping download.")
            return
        # Local files not found
        except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
            pass
    
        is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
            "WEIGHTS_CACHE_OVERRIDE", None
        ) is not None
    
        if not is_local_model:
            try:
                adapter_config_filename = hf_hub_download(
                    model_id, revision=revision, filename="adapter_config.json"
                )
                utils.download_and_unload_peft(
                    model_id, revision, trust_remote_code=trust_remote_code
                )
                is_local_model = True
                utils.weight_files(model_id, revision, extension)
                return
            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
                pass
    
            try:
                import json
    
                config = hf_hub_download(
                    model_id, revision=revision, filename=CONFIG_FILENAME
                )
                with open(config, "r") as f:
                    config = json.load(f)
    
                base_model_id = config.get("base_model_name_or_path", None)
                if base_model_id and base_model_id != model_id:
                    try:
                        logger.info(f"Downloading parent model {base_model_id}")
                        download_weights(
                            model_id=base_model_id,
                            revision="main",
                            extension=extension,
                            auto_convert=auto_convert,
                            logger_level=logger_level,
                            json_output=json_output,
                            trust_remote_code=trust_remote_code,
                        )
                    except Exception:
                        pass
            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
                pass
    
            # Try to download weights from the hub
            try:
                filenames = utils.weight_hub_files(model_id, revision, extension)
                utils.download_weights(filenames, model_id, revision)
                # Successfully downloaded weights
                return
    
            # No weights found on the hub with this extension
            except utils.EntryNotFoundError as e:
                # Check if we want to automatically convert to safetensors or if we can use .bin weights instead
                if not extension == MODEL_SUFFIX or not auto_convert:
                    raise e
    
        elif (Path(model_id) / "adapter_config.json").exists():
            # Try to load as a local PEFT model
            try:
                utils.download_and_unload_peft(
                    model_id, revision, trust_remote_code=trust_remote_code
                )
                utils.weight_files(model_id, revision, extension)
                return
            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
                pass
        elif (Path(model_id) / CONFIG_FILENAME).exists():
            # Try to load as a local Medusa model
            try:
                import json
    
                config = Path(model_id) / CONFIG_FILENAME
                with open(config, "r") as f:
                    config = json.load(f)
    
                base_model_id = config.get("base_model_name_or_path", None)
                if base_model_id:
                    try:
                        logger.info(f"Downloading parent model {base_model_id}")
                        download_weights(
                            model_id=base_model_id,
                            revision="main",
                            extension=extension,
                            auto_convert=auto_convert,
                            logger_level=logger_level,
                            json_output=json_output,
                            trust_remote_code=trust_remote_code,
                        )
                    except Exception:
                        pass
            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
                pass
    
        # Try to see if there are local pytorch weights
        try:
            # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
            try:
                local_pt_files = utils.weight_files(model_id, revision, ".bin")
            except Exception:
                local_pt_files = utils.weight_files(model_id, revision, ".pt")
    
        # No local pytorch weights
        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
            if extension == MODEL_SUFFIX:
                logger.warning(
                    f"No safetensors weights found for model {model_id} at revision {revision}. "
                    f"Downloading PyTorch weights."
                )
    
            # Try to see if there are pytorch weights on the hub
            pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
            # Download pytorch weights
            local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
    
        if auto_convert:
            if not trust_remote_code:
                logger.warning(
                    f"BREAKING CHANGE in 2.0: Safetensors conversion is disabled without `--trust-remote-code` "
                    f"because Pickle files are unsafe and can essentially contain remote code execution!"
                    f"Please check for more information here:"
                    f" https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
                )
    
            logger.warning(
                f"No safetensors weights found for model {model_id} at revision {revision}. "
                f"Converting PyTorch weights to safetensors."
            )
    
            # Safetensors final filenames
            local_st_files = [
                p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
                for p in local_pt_files
            ]
            try:
                import transformers
                import json
    
                if is_local_model:
                    config_filename = os.path.join(model_id, CONFIG_FILENAME)
                else:
                    config_filename = hf_hub_download(
                        model_id, revision=revision, filename=CONFIG_FILENAME
                    )
                with open(config_filename, "r") as f:
                    config = json.load(f)
                architecture = config["architectures"][0]
    
                class_ = getattr(transformers, architecture)
    
                # Name for this variable depends on transformers version.
                discard_names = getattr(class_, "_tied_weights_keys", [])
    
            except Exception as e:
                discard_names = []
            # Convert pytorch weights to safetensors
            utils.convert_files(local_pt_files, local_st_files, discard_names)
    
    
    @app.command()
    def quantize(
            model_id: str,
            output_dir: str,
            revision: Optional[str] = None,
            logger_level: str = "INFO",
            json_output: bool = False,
            trust_remote_code: bool = False,
            upload_to_model_id: Optional[str] = None,
            percdamp: float = 0.01,
            act_order: bool = False,
    ):
        if revision is None:
            revision = "main"
        download_weights(
            model_id=model_id,
            revision=revision,
            logger_level=logger_level,
            json_output=json_output,
        )
        from text_generation_server.utils.gptq.quantize import quantize
    
        quantize(
            model_id=model_id,
            bits=4,
            groupsize=128,
            output_dir=output_dir,
            revision=revision,
            trust_remote_code=trust_remote_code,
            upload_to_model_id=upload_to_model_id,
            percdamp=percdamp,
            act_order=act_order,
        )
    
    
    if __name__ == "__main__":
        app()
  • (Qwen-VL adaptation) Tgi-MindIE/cover/server.py
    import asyncio
    import os
    import torch
    import torch_npu
    import time
    import signal
    
    from grpc import aio
    from loguru import logger
    
    from grpc_reflection.v1alpha import reflection
    from pathlib import Path
    from typing import List, Optional
    
    from text_generation_server.cache import Cache
    from text_generation_server.interceptor import ExceptionInterceptor
    from text_generation_server.models import Model, get_model
    
    # Add the newly introduced VlmMindFlashCausalLMBatch
    try:
        from tgi_npu.vlm_mind_models import VlmMindFlashCausalLMBatch
    
        VLM_BATCH_TYPES = {VlmMindFlashCausalLMBatch}
    except (ImportError, NotImplementedError):
        # These imports can fail on CPU/Non flash.
        VLM_BATCH_TYPES = set()
    
    from text_generation_server.pb import generate_pb2_grpc, generate_pb2
    from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
    from text_generation_server.models.globals import set_model_id
    
    soc_version = torch_npu._C._npu_get_soc_version()
    if soc_version not in [104, 220, 221, 222, 223, 224]:
        option = {"NPU_FUZZY_COMPILE_BLACKLIST": "ReduceNansum"}
    else:
        option = {"NPU_FUZZY_COMPILE_BLACKLIST": "GatherElements"}
    torch.npu.set_option(option)
    
    
    class SignalHandler:
        KEEP_PROCESSING = True
    
        def __init__(self):
            signal.signal(signal.SIGINT, self.exit_gracefully)
            signal.signal(signal.SIGTERM, self.exit_gracefully)
    
        def exit_gracefully(self, signum, frame):
            print(f"Exiting gracefully: Signal {signum}")
            self.KEEP_PROCESSING = False
    
    
    class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        def __init__(
            self,
            model: Model,
            cache: Cache,
            quantize: Optional[str],
            server_urls: List[str],
        ):
            self.cache = cache
            self.model = model
            self.quantize = quantize
            self.server_urls = server_urls
            # For some reason, inference_mode does not work well with GLOO which we use on CPU
            if model.device.type == "cuda" or model.device.type == "npu":
                # Force inference mode for the lifetime of TextGenerationService
                self._inference_mode_raii_guard = torch._C._InferenceMode(True)
            self.step = 0
    
        async def Info(self, request, context):
            return self.model.info
    
        async def Health(self, request, context):
            if self.model.device.type == "cuda":
                torch.zeros((2, 2)).cuda()
            return generate_pb2.HealthResponse()
    
        async def ServiceDiscovery(self, request, context):
            return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
    
        async def ClearCache(self, request, context):
            if request.HasField("id"):
                self.cache.delete(request.id)
            else:
                self.cache.clear()
            return generate_pb2.ClearCacheResponse()
    
        async def FilterBatch(self, request, context):
            batch = self.cache.pop(request.batch_id)
            if batch is None:
                raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
            filtered_batch = batch.filter(request.request_ids)
            self.cache.set(filtered_batch)
    
            return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
    
        async def Warmup(self, request, context):
            for i, r in enumerate(request.batch.requests):
                r.parameters.typical_p = 1.0
                r.prefill_logprobs = False
            if self.quantize == "gptq":
                try:
                    # When using GPTQ, Exllama kernels need some global kernels
                    # For which we have the finale shapes only after the model has loaded
                    # This will allocate those buffers.
                    from text_generation_server.layers.gptq import (
                        create_exllama_buffers,
                        set_device,
                    )
    
                    set_device(self.model.device)
                    create_exllama_buffers(request.max_prefill_tokens)
                except ImportError:
                    pass
    
            if (
                self.model.batch_type in VLM_BATCH_TYPES
            ):  # Hack, i would rather use kwargs in the `from_pb` call
                for i, r in enumerate(request.batch.requests):
                    r.inputs = r.inputs.split('!')[0]
                batch = self.model.batch_type.from_pb_processor(
                    request.batch,
                    self.model.tokenizer,
                    self.model.tokenize,
                    self.model.dtype,
                    self.model.device,
                )
            else:
                batch = self.model.batch_type.from_pb(
                    request.batch, self.model.tokenizer, self.model.dtype, self.model.device
                )
            max_supported_total_tokens = self.model.warmup(batch)
    
            return generate_pb2.WarmupResponse(
                max_supported_total_tokens=max_supported_total_tokens
            )
    
        async def Prefill(self, request, context):
            start = time.time_ns()
            if (
                self.model.batch_type in VLM_BATCH_TYPES
            ):  # Hack, i would rather use kwargs in the `from_pb` call
                batch = self.model.batch_type.from_pb_processor(
                    request.batch,
                    self.model.tokenizer,
                    self.model.tokenize,
                    self.model.dtype,
                    self.model.device,
                )
            else:
                batch = self.model.batch_type.from_pb(
                    request.batch, self.model.tokenizer, self.model.dtype, self.model.device
                )
    
            generations, next_batch, timings = self.model.generate_token(batch)
            self.cache.set(next_batch)
    
            return generate_pb2.PrefillResponse(
                generations=[generation.to_pb() for generation in generations],
                batch=next_batch.to_pb() if next_batch else None,
                forward_ns=timings[0],
                decode_ns=timings[1],
                total_ns=time.time_ns() - start,
            )
    
        async def Decode(self, request, context):
            start = time.time_ns()
            if len(request.batches) == 0:
                raise ValueError("Must provide at least one batch")
    
            batches = []
            for batch_pb in request.batches:
                batch = self.cache.pop(batch_pb.id)
                if batch is None:
                    raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
                batches.append(batch)
    
            if len(batches) == 0:
                raise ValueError("All batches are empty")
    
            if len(batches) > 1:
                start_concat = time.time_ns()
                batch = self.model.batch_type.concatenate(batches)
                concat_ns = time.time_ns() - start_concat
            else:
                batch = batches[0]
                concat_ns = None
    
            generations, next_batch, timings = self.model.generate_token(batch)
            self.step += 1
            self.cache.set(next_batch)
    
            return generate_pb2.DecodeResponse(
                generations=[generation.to_pb() for generation in generations],
                batch=next_batch.to_pb() if next_batch else None,
                concat_ns=concat_ns,
                forward_ns=timings[0],
                decode_ns=timings[1],
                total_ns=time.time_ns() - start,
            )
    
    
    def serve(
            model_id: str,
            revision: Optional[str],
            sharded: bool,
            quantize: Optional[str],
            speculate: Optional[int],
            dtype: Optional[str],
            trust_remote_code: bool,
            uds_path: Path,
    ):
        async def serve_inner(
            model_id: str,
            revision: Optional[str],
            sharded: bool = False,
            quantize: Optional[str] = None,
            speculate: Optional[int] = None,
            dtype: Optional[str] = None,
            trust_remote_code: bool = False,
        ):
            unix_socket_template = "unix://{}-{}"
            if sharded:
                server_urls = [
                    unix_socket_template.format(uds_path, rank)
                    for rank in range(int(os.environ["WORLD_SIZE"]))
                ]
                local_url = server_urls[int(os.environ["RANK"])]
            else:
                local_url = unix_socket_template.format(uds_path, 0)
                server_urls = [local_url]
    
            try:
                model = get_model(
                    model_id,
                    revision,
                    sharded,
                    quantize,
                    speculate,
                    dtype,
                    trust_remote_code,
                )
            except Exception as e:
                logger.exception("Error when initializing model", e)
                raise e
    
            server = aio.server(
                interceptors=[
                    ExceptionInterceptor(),
                    UDSOpenTelemetryAioServerInterceptor(),
                ]
            )
            generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
                TextGenerationService(model, Cache(), quantize, server_urls), server
            )
            SERVICE_NAMES = (
                generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
                reflection.SERVICE_NAME,
            )
            reflection.enable_server_reflection(SERVICE_NAMES, server)
            server.add_insecure_port(local_url)
    
            await server.start()
    
            logger.info(f"Server started at {local_url}, pid s {os.getpid()}")
            signal_handler = SignalHandler()
            while signal_handler.KEEP_PROCESSING:
                await asyncio.sleep(0.5)
    
        set_model_id(model_id)
        asyncio.run(
            serve_inner(
                model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
            )
        )
  • Tgi-MindIE/tgi_npu/__init__.py
    #!/usr/bin/env python3
    # coding=utf-8
    # Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    
    
    import torch
    import torch_npu
    from loguru import logger
    
    from tgi_npu.mind_models import MindModel
    
    
    def init():
        torch._C._InferenceMode(True)
        soc_version = torch_npu._C._npu_get_soc_version()
        if soc_version not in [104, 220, 221, 222, 223, 224]:
            logger.info("Some op does not support for this soc!")
            option = {"NPU_FUZZY_COMPILE_BLACKLIST": "ReduceNansum"}
        else:
            option = {"NPU_FUZZY_COMPILE_BLACKLIST": "GatherElements"}
        try:
            torch.npu.set_option(option)
            logger.warning("Finish init for NPU device!")
        except Exception as e:
            logger.error(f"Failed to init for NPU device: {e}!")
    
    
    init()
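
    A small usage note (assuming torch_npu is installed and an NPU device is visible): importing the package runs init() above as a side effect, so no explicit setup call is needed before building a model.

    # Hypothetical import-time check; init() runs when the package is first imported.
    from tgi_npu import MindModel  # side effect: inference mode enabled, NPU compile option set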
  • Allocation of KV cache for each request in a batch is handled by the CacheManager class. Compared with the CacheManager class on GPU, the following adaptations are made (a usage sketch follows the listing below):
    • BLOCK_SIZE (the number of slots in one block) is recommended to be 128.
    • KV cache allocation must account for whether the data uses the ND or NZ layout, which is determined by the card model.

      Tgi-MindIE/tgi_npu/cache_manager.py

      # Part of codes in this file was copied from project[huggingface][text-generation-inference]
      
      import math
      from typing import Optional, List, Tuple
      import gc
      import torch
      from tgi_npu.info import NPUSocInfo
      
      # 1. A value of 128 is recommended
      BLOCK_SIZE: int = 128
      # Will be set in warmup
      CACHE_MANAGER: Optional["CacheManager"] = None
      
      
      class CacheManager:
          def __init__(
                  self,
                  num_blocks: int,
                  num_layers: int,
                  num_heads: int,
                  head_size: int,
                  repeat_slots: bool,
                  dtype: torch.dtype,
                  device: torch.device,
          ):
              self.block_size = BLOCK_SIZE
              self.num_blocks = num_blocks
              self.repeat_slots = repeat_slots
              # 2. Set the data layout of the KV cache tensors according to the NPU card
              self.need_nz = NPUSocInfo().need_nz
              if self.need_nz:
                  self.kv_cache = [
                      (
                          torch.empty(
                              (num_blocks, num_heads * head_size // 16, self.block_size, 16),
                              dtype=dtype,
                              device=device,
                          ),
                          torch.empty(
                              (num_blocks, num_heads * head_size // 16, self.block_size, 16),
                              dtype=dtype,
                              device=device,
                          ),
                      )
                      for _ in range(num_layers)
                  ]
              else:
                  self.kv_cache = [
                      (
                          torch.empty(
                              (num_blocks, self.block_size, num_heads, head_size),
                              dtype=dtype,
                              device=device,
                          ),
                          torch.empty(
                              (num_blocks, self.block_size, num_heads, head_size),
                              dtype=dtype,
                              device=device,
                          ),
                      )
                      for _ in range(num_layers)
                  ]
              self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
              self.slots = torch.arange(
                  0, num_blocks * self.block_size, dtype=torch.int64
              ).view(num_blocks, self.block_size)
      
          def __repr__(self):
              return (f"CacheManager: "
                      f"num_blocks={self.num_blocks},"
                      f"block_size={self.block_size},"
                      f"free_block_mask={self.free_block_mask},"
                      f"slots={self.slots},"
                      f"k_cache shape={self.kv_cache[0][0].shape},"
                      f"v_cache shape={self.kv_cache[0][1].shape}")
      
          def allocate(
                  self,
                  needed_blocks_slots: List[Tuple[int, int]],
                  blocks: int,
                  max_blocks: int,
                  device: torch.device,
          ):
              # Get free blocks indices by finding values in mask that are not set to 0
              free_block_indices = self.free_block_mask.nonzero()
              if blocks > len(free_block_indices):
                  raise RuntimeError(
                      f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
                  )
      
              # Slice by the number of required blocks
              block_indices = free_block_indices[:blocks]
              block_indices = block_indices.flatten()
      
              # Padded block tables
              block_tables_tensor = torch.zeros(
                  (len(needed_blocks_slots), max_blocks), dtype=torch.int32
              )
      
              # Allocate paged attention blocks
              cumulative_blocks = 0
              slots = []
              block_tables = []
              for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots):
                  # Get allocated blocks for this sequence
                  allocated_blocks = block_indices[
                                     cumulative_blocks: cumulative_blocks + needed_blocks
                                     ]
                  # Get slots for the allocated blocks
                  all_slots = self.slots[allocated_blocks].flatten()
      
                  # Repeat slots in the case of context sliding window
                  if needed_slots > len(all_slots) and self.repeat_slots:
                      repeats = math.ceil(needed_slots / len(all_slots))
                      all_slots = all_slots.repeat(repeats)
      
                  allocated_slots = all_slots[:needed_slots]
      
                  slots.append(allocated_slots)
                  block_tables.append(allocated_blocks.tolist())
                  block_tables_tensor[i, :needed_blocks] = allocated_blocks
                  cumulative_blocks += needed_blocks
      
              block_tables = block_tables
              block_tables_tensor = block_tables_tensor.to(device)
              slots = torch.concat(slots).to(device)
      
              # Allocate the required number of blocks by setting the mask to 0
              self.free_block_mask[block_indices] = 0
      
              return block_tables, block_tables_tensor, slots
      
          def free(self, block_indices: Optional[List[int]]):
              if block_indices is not None and block_indices:
                  # Reset mask
                  self.free_block_mask[block_indices] = 1
      
      
      def set_cache_manager(
              num_blocks: int,
              num_layers: int,
              num_heads: int,
              head_size: int,
              repeat_slots: bool,
              dtype: torch.dtype,
              device: torch.device,
      ) -> CacheManager:
          global CACHE_MANAGER
          if CACHE_MANAGER is not None:
              del CACHE_MANAGER
              torch.npu.empty_cache()
              gc.collect()
      
          CACHE_MANAGER = CacheManager(
              num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
          )
          return CACHE_MANAGER
      
      
      def get_cache_manager() -> CacheManager:
          global CACHE_MANAGER
          if CACHE_MANAGER is None:
              raise RuntimeError("cache manager was not initialized")
      
          return CACHE_MANAGER
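
      Usage sketch for the cache manager above, as referenced earlier. It is a minimal illustration with made-up sizes (2 layers, 64 blocks), assuming an initialized NPU device and a float16 KV cache; it is not part of the reference code.

      import torch
      import torch_npu  # noqa: F401, needed so that the "npu" device type is available
      from tgi_npu.cache_manager import set_cache_manager, get_cache_manager

      device = torch.device("npu:0")
      # Create the global manager: 64 blocks of BLOCK_SIZE (128) slots each.
      set_cache_manager(
          num_blocks=64, num_layers=2, num_heads=8, head_size=128,
          repeat_slots=False, dtype=torch.float16, device=device,
      )
      manager = get_cache_manager()
      # One request that needs 2 blocks / 200 slots (200 <= 2 * 128).
      block_tables, block_tables_tensor, slots = manager.allocate(
          needed_blocks_slots=[(2, 200)], blocks=2, max_blocks=2, device=device,
      )
      manager.free(block_tables[0])  # release the blocks once the request finishes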
  • Determining the ND/NZ data layout of the NPU card

    Tgi-MindIE/tgi_npu/info.py

    from dataclasses import dataclass
    import torch_npu

    @dataclass
    class NPUSocInfo:
        soc_name: str = ""
        soc_version: int = -1
        need_nz: bool = False

        def __post_init__(self):
            self.soc_version = torch_npu._C._npu_get_soc_version()
            if self.soc_version in (100, 101, 102, 103, 104, 200, 201, 202, 203):
                self.need_nz = True
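
    A brief sketch of how the need_nz flag above is consumed when shaping the KV cache (the shapes mirror cache_manager.py; the sizes are illustrative only and not part of the reference code):

    from tgi_npu.info import NPUSocInfo

    num_blocks, block_size, num_heads, head_size = 4, 128, 8, 128
    if NPUSocInfo().need_nz:
        # NZ layout: (num_blocks, num_heads * head_size // 16, block_size, 16)
        kv_shape = (num_blocks, num_heads * head_size // 16, block_size, 16)
    else:
        # ND layout: (num_blocks, block_size, num_heads, head_size)
        kv_shape = (num_blocks, block_size, num_heads, head_size)
    print(kv_shape)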
  • MindIE model and request batch classes.
    • Add the text model class MindModel: MindModel inherits from the TGI framework's FlashCausalLM class. The main adaptation is initializing the MindIE LLM model and obtaining the model, the tokenizer, and model-related information.
    • Add the text request batch class MindFlashCausalLMBatch: MindFlashCausalLMBatch serves as the text request batch for MindIE LLM and inherits from the original TGI framework's FlashCausalLMBatch class.

      Tgi-MindIE/tgi_npu/mind_models.py

      # Part of codes in this file was copied from project[huggingface][text-generation-inference]
      
      
      import math
      import time
      import os
      from typing import Optional, Tuple, List, Type
      from dataclasses import dataclass
      
      import torch_npu
      import torch
      from loguru import logger
      from opentelemetry import trace
      import numpy as np
      
      from text_generation_server.models.flash_causal_lm import FlashCausalLM, FlashCausalLMBatch
      from text_generation_server.models.types import (
          Batch,
          Tokens,
          Generation,
          GeneratedText
      )
      
      from text_generation_server.utils import StoppingCriteria
      from text_generation_server.pb import generate_pb2
      from text_generation_server.utils.speculate import get_speculate
      from text_generation_server.utils.dist import RANK, MEMORY_FRACTION
      from text_generation_server.utils.tokens import batch_top_tokens
      from text_generation_server.models import cache_manager as tgi_cache_manager
      
      from transformers import PreTrainedTokenizerBase
      from mindie_llm.text_generator.adapter.generator_torch import GeneratorTorch
      from mindie_llm.modeling.backend_type import BackendType
      
      from tgi_npu.tokens_mindie import MindIELLMHeterogeneousNextTokenChooser
      
      from tgi_npu.cache_manager import (
          BLOCK_SIZE,
          get_cache_manager,
          set_cache_manager,
      )
      
      tracer = trace.get_tracer(__name__)
      
      
      @dataclass
      class MindFlashCausalLMBatch(FlashCausalLMBatch):
          # 1. Two fields are added on top of the base class FlashCausalLMBatch: next_token_chooser
          # (the MindIE LLM post-processing class) and all_input_ids_tensor (fed to next_token_chooser for sampling)
          next_token_chooser: MindIELLMHeterogeneousNextTokenChooser
          all_input_ids_tensor: torch.Tensor
      
          def __repr__(self):
              return (f"MindFlashCausalLMBatch: batch_id={self.batch_id},"
                      f"requests_idx_mapping={self.requests_idx_mapping},"
                      f"input_ids={self.input_ids},"
                      f"position_ids={self.position_ids},"
                      f"cu_seqlen_prefill={self.cu_seqlen_prefill},"
                      f"start_slots={self.start_slots},"
                      f"slot_indices={self.slot_indices},"
                      f"needed_blocks_slots={self.needed_blocks_slots},"
                      f"block_tables={self.block_tables},"
                      f"block_tables_tensor={self.block_tables_tensor},"
                      f"slots={self.slots},"
                      f"max_seqlen={self.max_seqlen},"
                      f"prefill_head_indices={self.prefill_head_indices},"
                      f"prefill_next_token_indices={self.prefill_next_token_indices},"
                      f"prefill_cu_outlens={self.prefill_cu_outlens},"
                      f"input_lengths={self.input_lengths},"
                      f"input_lengths_tensor={self.input_lengths_tensor},"
                      f"prefix_offsets={self.prefix_offsets},"
                      f"read_offsets={self.read_offsets},"
                      f"all_input_ids_tensor={self.all_input_ids_tensor},"
                      f"next_token_chooser={self.next_token_chooser},"
                      f"stopping_criterias={self.stopping_criterias},"
                      f"blocks={self.blocks},"
                      f"max_blocks={self.max_blocks}")
      
          @classmethod
          def from_pb(
                  cls,
                  pb: generate_pb2.Batch,
                  tokenizer: PreTrainedTokenizerBase,
                  dtype: torch.dtype,
                  device: torch.device
          ) -> "MindFlashCausalLMBatch":
              batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
              return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
      
          @classmethod
          def from_tokenized(
                  cls,
                  pb: generate_pb2.Batch,
                  tokenizer: PreTrainedTokenizerBase,
                  batch_tokenized_inputs,
                  dtype: torch.dtype,
                  device: torch.device,
          ) -> "MindFlashCausalLMBatch":
              position_ids = []
              cu_seqlen_prefill = [0]
              needed_blocks_slots = []
              start_slots = []
              slot_indices = []
      
              input_lengths = []
              prefix_offsets = []
              read_offsets = []
              all_input_ids = []
              requests_idx_mapping = {}
      
              all_prefill_logprobs = True
              no_prefill_logprobs = True
              prefill_head_indices = []
              prefill_next_token_indices = []
              prefill_cu_outlens = [0]
      
              next_token_chooser_parameters = []
              stopping_criterias = []
              top_n_tokens = []
      
              # Cumulative length
              cumulative_length = 0
              cumulative_max_length = 0
              prefill_out_cumulative_length = 0
      
              blocks = 0
              max_seqlen = 0
              max_length = 0
              max_blocks = 0
      
              # Parse batch
              for i, (r, tokenized_input) in enumerate(
                      zip(pb.requests, batch_tokenized_inputs)
              ):
                  # request id -> idx in list mapping
                  requests_idx_mapping[r.id] = i
      
                  tokenized_input = tokenized_input[-r.truncate:]
                  if (
                          tokenized_input[0] == tokenizer.bos_token_id
                          and tokenized_input[1] == tokenizer.bos_token_id
                  ):
                      tokenized_input = tokenized_input[1:]
      
                  input_length = len(tokenized_input)
                  input_lengths.append(input_length)
      
                  prefix_offsets.append(input_length - 5)
                  read_offsets.append(input_length)
      
                  all_input_ids.append(tokenized_input)
      
                  # Position ids
                  request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
                  position_ids.append(request_position_ids)
      
                  # Add cumulative lengths of all previous inputs
                  cu_seqlen_prefill.append(cumulative_length + input_length)
      
                  next_token_chooser_parameters.append(r.parameters)
      
                  stopping_criteria = StoppingCriteria.from_pb(
                      r.stopping_parameters, tokenizer
                  )
                  max_new_tokens = stopping_criteria.max_new_tokens
                  stopping_criterias.append(stopping_criteria)
                  top_n_tokens.append(r.top_n_tokens)
      
                  # Paged attention
                  # Remove one as the first token does not have a past
                  speculative_length = get_speculate()
                  speculative_length = 0 if speculative_length is None else speculative_length
                  total_tokens = input_length + max_new_tokens - 1 + speculative_length
                  needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
                  blocks += needed_blocks
                  needed_blocks_slots.append((needed_blocks, total_tokens))
                  start_slots.append(cumulative_max_length)
      
                  request_slot_indices = torch.arange(
                      cumulative_max_length,
                      cumulative_max_length + input_length,
                      dtype=torch.int64,
                  )
                  slot_indices.append(request_slot_indices)
      
                  all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
                  no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
      
                  if r.prefill_logprobs:
                      prefill_head_indices.append(request_position_ids + cumulative_length)
                      prefill_next_token_indices.append(
                          prefill_out_cumulative_length + input_length - 1
                      )
                      prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                      prefill_out_cumulative_length += input_length
                  else:
                      prefill_head_indices.append(
                          torch.tensor(
                              [cumulative_length + input_length - 1], dtype=torch.int64
                          )
                      )
                      prefill_next_token_indices.append(prefill_out_cumulative_length)
                      prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                      prefill_out_cumulative_length += 1
      
                  # Update
                  cumulative_length += input_length
                  cumulative_max_length += total_tokens
                  max_seqlen = max(max_seqlen, input_length)
                  max_blocks = max(max_blocks, needed_blocks)
                  max_length = max(
                      max_length, input_length + max_new_tokens + speculative_length
                  )
              # 2. Build the post-processing (next-token chooser) class
              next_token_chooser = MindIELLMHeterogeneousNextTokenChooser.from_pb(
                  pb=next_token_chooser_parameters, dtype=dtype, device=device
              )
      
              start_slots = torch.tensor(start_slots, dtype=torch.int64)
              
              # 3. Build the input id tensor passed to the post-processing class
              all_input_ids_tensor = np.zeros(
                  (len(all_input_ids), max_length), dtype=np.int64
              )
              for i, input_ids in enumerate(all_input_ids):
                  all_input_ids_tensor[i, : len(input_ids)] = input_ids
      
              all_input_ids_tensor = torch.tensor(
                  all_input_ids_tensor, dtype=torch.int64, device=device
              )
      
              if len(pb.requests) > 1:
                  input_ids = np.concatenate(all_input_ids, dtype=np.int64)
                  position_ids = torch.cat(position_ids)
                  slot_indices = torch.cat(slot_indices)
              else:
                  input_ids = all_input_ids[0]
                  position_ids = position_ids[0]
                  slot_indices = slot_indices[0]
      
              cu_seqlen_prefill = torch.tensor(
                  cu_seqlen_prefill, device=device, dtype=torch.int64
              )
              position_ids = position_ids.to(device)
              slot_indices = slot_indices.to(device)
              input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
              input_lengths_tensor = torch.tensor(
                  input_lengths, dtype=torch.int64, device=device
              )
      
              if all_prefill_logprobs:
                  prefill_head_indices = None
                  prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
              elif no_prefill_logprobs:
                  prefill_head_indices = cu_seqlen_prefill[1:] - 1
                  prefill_next_token_indices = None
              else:
                  prefill_head_indices = torch.tensor(
                      torch.cat(prefill_head_indices), dtype=torch.int64, device=device
                  )
                  prefill_next_token_indices = torch.tensor(
                      prefill_next_token_indices, dtype=torch.int64, device=device
                  )
              top_n_tokens_tensor = torch.tensor(
                  top_n_tokens, device=device, dtype=torch.int64
              )
      
              return cls(
                  batch_id=pb.id,
                  requests=pb.requests,
                  requests_idx_mapping=requests_idx_mapping,
                  input_ids=input_ids,
                  position_ids=position_ids,
                  cu_seqlen_prefill=cu_seqlen_prefill,
                  start_slots=start_slots,
                  slot_indices=slot_indices,
                  needed_blocks_slots=needed_blocks_slots,
                  block_tables=None,
                  block_tables_tensor=None,
                  slots=None,
                  max_seqlen=max_seqlen,
                  prefill_head_indices=prefill_head_indices,
                  prefill_next_token_indices=prefill_next_token_indices,
                  prefill_cu_outlens=prefill_cu_outlens,
                  input_lengths=input_lengths,
                  input_lengths_tensor=input_lengths_tensor,
                  prefix_offsets=prefix_offsets,
                  read_offsets=read_offsets,
                  all_input_ids=all_input_ids,
                  all_input_ids_tensor=all_input_ids_tensor,
                  next_token_chooser=next_token_chooser,
                  stopping_criterias=stopping_criterias,
                  top_n_tokens=top_n_tokens,
                  top_n_tokens_tensor=top_n_tokens_tensor,
                  blocks=blocks,
                  max_blocks=max_blocks,
                  speculative_ids=None,
              )
      
          @classmethod
          @tracer.start_as_current_span("concatenate")
          def concatenate(cls, batches: List["MindFlashCausalLMBatch"]) -> "MindFlashCausalLMBatch":
              # Batch attributes
              requests = []
              requests_idx_mapping = {}
      
              blocks = 0
              total_batch_size = 0
              total_slots = 0
              max_blocks = 0
              max_length = 0
              max_seqlen = 0
              for b in batches:
                  total_batch_size += len(b)
                  total_slots += len(b.slots)
                  blocks += b.blocks
                  speculative_length = (
                      b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
                  )
                  max_blocks = max(max_blocks, b.max_blocks)
                  max_seqlen = max(max_seqlen, b.max_seqlen)
                  max_length = max(
                      max_length,
                      max(
                          input_length
                          + stopping_criteria.max_new_tokens
                          + speculative_length
                          - stopping_criteria.current_tokens
                          for input_length, stopping_criteria in zip(
                              b.input_lengths, b.stopping_criterias
                          )
                      ),
                  )
      
              input_ids = batches[0].input_ids.new_empty(total_batch_size)
              position_ids = batches[0].position_ids.new_empty(total_batch_size)
              slots = batches[0].slots.new_empty(total_slots)
              slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
              input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
                  total_batch_size
              )
              block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
                  (total_batch_size, max_blocks)
              )
      
              all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
                  (total_batch_size, max_length)
              )
              top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
                  total_batch_size,
              )
      
              start_slots = []
              block_tables = []
              all_input_ids = []
      
              input_lengths = []
              prefix_offsets = []
              read_offsets = []
      
              next_token_chooser_parameters = []
              fsm_grammar_states = []
              stopping_criterias = []
              top_n_tokens = []
      
              # Cumulative length
              cumulative_batch_size = 0
              cumulative_slots = 0
      
              for i, batch in enumerate(batches):
                  requests.extend(batch.requests)
      
                  if i == 0:
                      requests_idx_mapping = batch.requests_idx_mapping
                  else:
                      # We need to offset the mapping for each batch by the cumulative batch size
                      for k, v in batch.requests_idx_mapping.items():
                          requests_idx_mapping[k] = v + cumulative_batch_size
      
                  start_index = cumulative_batch_size
                  end_index = cumulative_batch_size + len(batch)
                  slots_start_index = cumulative_slots
                  slots_end_index = cumulative_slots + len(batch.slots)
      
                  # Copy tensors (GPU)
                  input_ids[start_index:end_index] = batch.input_ids
                  position_ids[start_index:end_index] = batch.position_ids
                  slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
                  input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
                  top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
                  slots[slots_start_index:slots_end_index] = batch.slots
      
                  all_input_ids_tensor[
                  start_index:end_index, : batch.all_input_ids_tensor.shape[1]
                  ] = batch.all_input_ids_tensor[:, :max_length]
      
                  block_tables_tensor[
                  start_index:end_index, : batch.block_tables_tensor.shape[1]
                  ] = batch.block_tables_tensor[:, :max_blocks]
      
                  start_slots.append(batch.start_slots + cumulative_slots)
      
                  block_tables.extend(batch.block_tables)
                  all_input_ids.extend(batch.all_input_ids)
      
                  input_lengths.extend(batch.input_lengths)
                  prefix_offsets.extend(batch.prefix_offsets)
                  read_offsets.extend(batch.read_offsets)
      
                  next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
                  fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
                  stopping_criterias.extend(batch.stopping_criterias)
      
                  top_n_tokens.extend(batch.top_n_tokens)
      
                  # Update
                  cumulative_batch_size += len(batch)
                  cumulative_slots += len(batch.slots)
      
              start_slots = torch.concat(start_slots)
      
              next_token_chooser = MindIELLMHeterogeneousNextTokenChooser.from_pb(
                  pb=next_token_chooser_parameters,
                  dtype=batches[0].next_token_chooser.dtype,
                  device=batches[0].next_token_chooser.device
              )
      
              speculative_ids = (
                  torch.cat([b.speculative_ids for b in batches], dim=0)
                  if batches[0].speculative_ids is not None
                  else None
              )
      
              # Needed to avoid dropping blocks when the batches will go out of scope
              for b in batches:
                  b.block_tables = None
                  del b
      
              return cls(
                  batch_id=batches[0].batch_id,
                  requests=requests,
                  requests_idx_mapping=requests_idx_mapping,
                  input_ids=input_ids,
                  position_ids=position_ids,
                  cu_seqlen_prefill=None,
                  start_slots=start_slots,
                  slot_indices=slot_indices,
                  needed_blocks_slots=None,
                  block_tables=block_tables,
                  block_tables_tensor=block_tables_tensor,
                  slots=slots,
                  max_seqlen=max_seqlen,
                  prefill_head_indices=None,
                  prefill_next_token_indices=None,
                  prefill_cu_outlens=None,
                  input_lengths=input_lengths,
                  input_lengths_tensor=input_lengths_tensor,
                  prefix_offsets=prefix_offsets,
                  read_offsets=read_offsets,
                  all_input_ids=all_input_ids,
                  all_input_ids_tensor=all_input_ids_tensor,
                  next_token_chooser=next_token_chooser,
                  stopping_criterias=stopping_criterias,
                  top_n_tokens=top_n_tokens,
                  top_n_tokens_tensor=top_n_tokens_tensor,
                  blocks=blocks,
                  max_blocks=max_blocks,
                  speculative_ids=speculative_ids,
              )
      
          def to_pb(self) -> generate_pb2.CachedBatch:
              return generate_pb2.CachedBatch(
                  id=self.batch_id,
                  request_ids=[r.id for r in self.requests],
                  size=len(self),
                  max_tokens=self.blocks * BLOCK_SIZE,
              )
      
      class MindModel(FlashCausalLM):
          def __init__(
                  self,
                  model_id: str,
                  trust_remote_code: bool
          ):
              logger.warning("Initialize mindie-llm model.")
              rank = int(os.getenv("RANK", "0"))
              world_size = int(os.getenv("WORLD_SIZE", "1"))
              model_config = {
                  'backend_type': BackendType.ATB,
                  'rank': rank,
                  'world_size': world_size,
                  'model_id': model_id,
                  'num_threads': 8,
                  'local_rank': rank,
                  'npu_device_id': rank,
                  'trust_remote_code': trust_remote_code
              }
              # 1. Initialize the MindIE LLM model; the resulting self.model_runner carries the model information
              self.model_runner = GeneratorTorch(model_config)
              super(MindModel, self).__init__(
                  model=self.model_runner.model_wrapper.model_runner.model,
                  tokenizer=self.model_runner.tokenizer,
                  num_layers=self.model_runner.model_info.num_layers,
                  num_kv_heads=self.model_runner.model_info.num_kv_heads,
                  head_size=self.model_runner.model_info.head_size,
                  dtype=self.model_runner.model_info.dtype,
                  device=self.model_runner.model_info.device,
                  rank=self.model_runner.rank,
                  world_size=self.model_runner.world_size,
              )
              logger.warning("MindModel from tgi_npu initialized.")
      
          def __del__(self):
              del self.model_runner.model_wrapper
      
          @property
          def batch_type(self) -> Type[MindFlashCausalLMBatch]:
              return MindFlashCausalLMBatch
      
          def forward(
                  self, batch: MindFlashCausalLMBatch
          ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
              """Assume return logits, speculative_logits"""
              input_ids = batch.input_ids
              position_ids = batch.position_ids
              cu_seqlen_prefill = batch.cu_seqlen_prefill
              kv_cache = get_cache_manager().kv_cache
              block_tables = batch.block_tables_tensor
              slots = batch.slots[batch.slot_indices]
              input_lengths = batch.input_lengths_tensor
              max_s = batch.max_seqlen
              lm_head_indices = batch.prefill_head_indices
              speculative_logits = None
      
              # 2. Generate tokens for each request in the batch.
              # Inference now goes through the forward_tensor interface of the self.model_runner created in
              # MindModel.__init__; its inputs are the corresponding fields of MindFlashCausalLMBatch.
              return self.model_runner.forward_tensor(
                  input_ids=input_ids,
                  position_ids=position_ids,
                  is_prefill=cu_seqlen_prefill is not None,
                  kv_cache=kv_cache,
                  block_tables=block_tables,
                  slots=slots,
                  input_lengths=input_lengths,
                  max_seq_len=max_s,
                  lm_head_indices=lm_head_indices,
              ), speculative_logits
      
          @tracer.start_as_current_span("generate_token")
          def generate_token(
                  self, batch: MindFlashCausalLMBatch
          ) -> Tuple[List[Generation], Optional[MindFlashCausalLMBatch], Tuple[int, int]]:
              start = time.time_ns()
              prefill = batch.cu_seqlen_prefill is not None
              prefill_logprobs = batch.prefill_next_token_indices is not None
              # Check whether slots still need to be allocated for this batch
              if batch.needed_blocks_slots:
                  # Allocate blocks to this batch
                  block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
                      batch.needed_blocks_slots,
                      batch.blocks,
                      batch.max_blocks,
                      batch.input_ids.device,
                  )
                  batch.needed_blocks_slots = None
                  batch.block_tables = block_tables
                  batch.block_tables_tensor = block_tables_tensor
                  batch.slots = slots
      
              try:
                  out, speculative_logits = self.forward(batch)
              except Exception as e:
                  del batch
                  raise e
      
              if prefill:
                  next_token_logits = (
                      out[batch.prefill_next_token_indices] if prefill_logprobs else out
                  )
                  if speculative_logits is not None:
                      speculative_logits = (
                          speculative_logits[batch.prefill_next_token_indices]
                          if prefill_logprobs
                          else speculative_logits
                      )
              else:
                  logger.debug(f"Decode batch size {batch.input_ids.shape[0]}")
                  next_token_logits = out
      
              speculate = get_speculate()
              
              # 3. Post-processing: the arguments are adapted to the sampling method of MindIELLMHeterogeneousNextTokenChooser
              request_ids = [req.id for req in batch.requests]
              (
                  next_input_ids,
                  next_token_logprobs,
                  logprobs,
                  accepted_ids,
                  speculative_ids,
              ) = batch.next_token_chooser(
                  request_ids,
                  prefill,
                  batch.all_input_ids_tensor[:, : batch.max_seqlen],
                  next_token_logits,
                  speculate
              )
      
              batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
                  batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
              )
      
              if prefill:
                  if len(batch) > 1 and prefill_logprobs:
                      # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                      # When batch == 1, we will just use the batch.input_ids values directly
                      prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
      
                  next_position_ids = batch.position_ids.new_empty(len(batch))
                  batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
                  # We do not need cu_seqlen_prefill anymore
                  batch.cu_seqlen_prefill = None
              else:
                  prefill_logprobs = None
                  next_position_ids = batch.position_ids
      
              # Cumulative length
              cumulative_length = 0
      
              # Results
              generations: List[Generation] = []
              stopped = True
      
              # Zipped iterator
              iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)
      
              # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
              # one, we need to first do a GPU <-> CPU sync
              # It is faster if we delay this sync for the maximum amount of time
      
              # For each member of the batch
              index = 0
              for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
                  # Indexing metadata
                  start_index = cumulative_length
                  end_index = cumulative_length + input_length
      
                  if prefill:
                      # Indexing metadata
                      out_start_index = batch.prefill_cu_outlens[i]
                      out_end_index = batch.prefill_cu_outlens[i + 1]
                      out_length = out_end_index - out_start_index
      
                      # Initialize position_ids
                      # In decode, we do not need this as we can just increment position ids
                      next_position_ids[i] = batch.position_ids[end_index - 1]
      
                      # Used to gather prefill logprobs
                      # Copy batch.input_ids to prefill_token_indices
                      if prefill_logprobs:
                          if len(batch) > 1:
                              prefill_tokens_indices[out_start_index: out_end_index - 1] = (
                                  batch.input_ids[start_index + 1: start_index + out_length]
                              )
                          else:
                              # Set prefill_tokens_indices to the correct slice
                              prefill_tokens_indices = batch.input_ids[
                                                       start_index + 1: start_index + out_length
                                                       ]
      
                  for _ in range(n_accepted_ids):
                      index += 1
      
                  cumulative_length += input_length
      
              batch.all_input_ids_tensor.scatter_(1, batch.input_lengths_tensor.view(batch.input_lengths_tensor.shape[0], 1),
                                                  next_input_ids.view(next_input_ids.shape[0], 1))
              # Update values
              batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
              batch.speculative_ids = speculative_ids
              batch.position_ids = next_position_ids + accepted_ids
              batch.input_lengths_tensor += accepted_ids
              batch.slot_indices += accepted_ids
      
              if prefill and prefill_logprobs:
                  # Get prefill logprobs
                  prefill_logprobs_tensor = torch.log_softmax(out, -1)
                  prefill_logprobs = torch.gather(
                      prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
                  )
                  # GPU <-> CPU sync
                  prefill_logprobs = prefill_logprobs.view(-1).tolist()
      
              # GPU <-> CPU sync
              next_token_logprobs = next_token_logprobs.tolist()
              next_token_ids = next_input_ids.tolist()
              accepted_ids = accepted_ids.tolist()
              start_decode = time.time_ns()
      
              # Zipped iterator
              iterator = zip(
                  batch.requests,
                  batch.input_lengths,
                  batch.prefix_offsets,
                  batch.read_offsets,
                  batch.stopping_criterias,
                  batch.all_input_ids,
                  batch.next_token_chooser.do_sample,
                  batch.next_token_chooser.seeds,
                  batch.top_n_tokens,
                  accepted_ids,
                  batch_top_token_ids,
                  batch_top_token_logprobs,
              )
      
              # For each member of the batch
              index = 0
              for i, (
                      request,
                      input_length,
                      prefix_offset,
                      read_offset,
                      stopping_criteria,
                      all_input_ids,
                      do_sample,
                      seed,
                      top_n_tokens,
                      n_accepted_ids,
                      top_token_ids,
                      top_token_logprobs,
              ) in enumerate(iterator):
                  # Append next token to all tokens
                  next_token_texts = []
                  left = 0
      
                  if n_accepted_ids > 1:
                      if RANK == 0:
                          logger.debug(f"Speculated ids {n_accepted_ids - 1}")
      
                  current_stopped = False
                  for j in range(index, index + n_accepted_ids):
                      # Generated token
                      next_token_id = next_token_ids[j]
                      all_input_ids.append(next_token_id)
                      # Generated token
                      next_token_text, prefix_offset, read_offset = self.decode_token(
                          all_input_ids,
                          prefix_offset,
                          read_offset,
                      )
      
                      next_token_texts.append(next_token_text)
      
                      stop, reason = stopping_criteria(
                          next_token_id,
                          next_token_text,
                      )
      
                      if stop:
                          left = index + n_accepted_ids - j - 1
                          current_stopped = True
                          break
                      else:
                          current_stopped = False
                  stopped = stopped and current_stopped
      
                  _next_token_ids = next_token_ids[index: index + n_accepted_ids - left]
                  _next_token_logprobs = next_token_logprobs[
                                         index: index + n_accepted_ids - left
                                         ]
                  index += n_accepted_ids
      
                  # Shard generations
                  # All generations will be appended in the rust sharded client
                  if i % self.world_size == self.rank:
                      if stop:
                          # Decode generated tokens
                          output_text, _, _ = self.decode_token(
                              all_input_ids,
                              prefix_offset=len(all_input_ids)
                                            - stopping_criteria.current_tokens
                                            - 1,
                              read_offset=len(all_input_ids)
                                          - stopping_criteria.current_tokens,
                              skip_special_tokens=True,
                          )
                          generated_text = GeneratedText(
                              output_text,
                              stopping_criteria.current_tokens,
                              reason,
                              seed if do_sample else None,
                          )
                      else:
                          generated_text = None
      
                      # Prefill
                      if prefill and request.prefill_logprobs:
                          out_start_index = batch.prefill_cu_outlens[i]
                          out_end_index = batch.prefill_cu_outlens[i + 1]
      
                          # Remove generated token to only have prefill and add nan for first prompt token
                          request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                                                                      out_start_index: out_end_index - 1
                                                                      ]
                          prefill_token_ids = all_input_ids[:-1]
                          prefill_texts = self.tokenizer.batch_decode(
                              prefill_token_ids,
                              clean_up_tokenization_spaces=False,
                              skip_special_tokens=False,
                          )
      
                          prefill_tokens = Tokens(
                              prefill_token_ids,
                              request_prefill_logprobs,
                              prefill_texts,
                              is_special=[],
                          )
                      else:
                          prefill_tokens = None
      
                      if top_n_tokens > 0:
                          all_top_tokens = []
                          for top_token_ids, top_token_logprobs in zip(
                                  top_token_ids, top_token_logprobs
                          ):
                              toptoken_texts = self.tokenizer.batch_decode(
                                  top_token_ids,
                                  clean_up_tokenization_spaces=False,
                                  skip_special_tokens=False,
                              )
                              special_toptokens = [
                                  token_id in self.all_special_ids
                                  for token_id in top_token_ids
                              ]
                              top_tokens = Tokens(
                                  top_token_ids,
                                  top_token_logprobs,
                                  toptoken_texts,
                                  special_toptokens,
                              )
                              all_top_tokens.append(top_tokens)
                          top_tokens = all_top_tokens
                      else:
                          top_tokens = None
      
                      generation = Generation(
                          request.id,
                          prefill_tokens,
                          Tokens(
                              _next_token_ids,
                              _next_token_logprobs,
                              next_token_texts,
                              [nid in self.all_special_ids for nid in _next_token_ids],
                          ),
                          generated_text,
                          top_tokens,
                      )
      
                      generations.append(generation)
      
                  # Update values
                  batch.input_lengths[i] = input_length + n_accepted_ids
                  if batch.input_lengths[i] > batch.max_seqlen:
                      batch.max_seqlen = batch.input_lengths[i]
                  batch.prefix_offsets[i] = prefix_offset
                  batch.read_offsets[i] = read_offset
                  batch.all_input_ids[i] = all_input_ids
      
              if stopped:
                  del batch
                  # No need to return a batch if we know that all requests stopped
                  forward_ns = start_decode - start
                  decode_ns = time.time_ns() - start_decode
                  none_batch = None
                  return generations, none_batch, (forward_ns, decode_ns)
      
              batch.prefill_cu_outlens = None
              batch.prefill_head_indices = None
              batch.prefill_next_token_indices = None
      
              forward_ns = start_decode - start
              decode_ns = time.time_ns() - start_decode
              return generations, batch, (forward_ns, decode_ns)
      
          # Parts that need adaptation: measure the peak NPU memory used while computing the warmup batch and the
          # total NPU device memory, then derive the maximum number of tokens that fit into the remaining memory
          # (a standalone sketch of this arithmetic follows the listing).
          def warmup(self, batch: MindFlashCausalLMBatch):
              # The warmup batch is the biggest batch we could ever receive
              torch.npu.empty_cache()
      
              peak_memory = torch_npu.npu.max_memory_allocated()
              logger.info(f">>>>before warmup peak_memory {peak_memory}")
              try:
                  cache_manager = set_cache_manager(
                      batch.blocks,
                      self.num_layers,
                      self.num_kv_heads,
                      self.head_size,
                      False,
                      self.dtype,
                      self.device,
                  )
                  _, batch, _ = self.generate_token(batch)
              except torch.cuda.OutOfMemoryError as e:
                  raise RuntimeError(
                      f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                      f"You need to decrease `--max-batch-prefill-tokens`"
                  ) from e
      
              # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
              # Calculate the number of blocks that can be allocated with the free memory
              dtype_size = torch.tensor([], dtype=self.dtype).element_size()
              cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
              total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
              torch_npu.npu.synchronize()
      
              # Total NPU device memory
              total_gpu_memory = torch_npu.npu.get_device_properties(self.device).total_memory
      
              # Peak memory usage
              peak_memory = torch_npu.npu.max_memory_allocated()
              logger.info(
                  f">>>>dtype_size {dtype_size}, cache_block_size {cache_block_size}, num_kv_heads {self.num_kv_heads}, "
                  f"total_cache_size {total_cache_size}, peak_memory {peak_memory}")
      
              # Remaining free device memory
              total_free_memory = total_gpu_memory - peak_memory
              logger.info(f">>>>total_free_memory {total_free_memory}, total_gpu_memory {total_gpu_memory}, "
                             f"MEMORY_FRACTION {MEMORY_FRACTION}")
      
              # Free memory usable for the KV cache after reserving (1 - MEMORY_FRACTION) of total memory
              free_memory = max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory)
              # Total number of KV cache blocks that can be supported
              num_blocks = (
                  # Leave 5% for some wiggle room
                      int((free_memory * 0.95) // total_cache_size)
                      # Add batch.blocks as we allocated it above, so it is included in the peak memory.
                      + cache_manager.num_blocks
              )
      
              del batch
              del cache_manager
      
              real_manager = set_cache_manager(
                  num_blocks,
                  self.num_layers,
                  self.num_kv_heads,
                  self.head_size,
                  self.sliding_window is not None,
                  self.dtype,
                  self.device,
              )
              tgi_cache_manager.CACHE_MANAGER = real_manager
              logger.warning(f">>>>real CacheManger {get_cache_manager()}")
              peak_memory = torch_npu.npu.max_memory_allocated()
              logger.warning(f">>>>end warmup peak_memory {peak_memory}")
              logger.warning(f"Warmup return {int(num_blocks * BLOCK_SIZE)}")
      
              # Total number of supported KV slots (total batch tokens); a BLOCK_SIZE of 128 is recommended
              return int(num_blocks * BLOCK_SIZE)
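
    The arithmetic in warmup above can be checked in isolation. The following is a minimal, standalone sketch and is not part of the adaptation code; the device sizes, model shape and warmup block count are illustrative assumptions, while the formulas mirror the ones used in MindModel.warmup.

    # Standalone sketch of the warmup KV-block budget; all concrete numbers are assumptions.
    BLOCK_SIZE = 128                     # KV block size recommended for NPU
    MEMORY_FRACTION = 0.9                # assumed fraction of device memory TGI may use
    dtype_size = 2                       # element size of float16 in bytes
    num_layers, num_kv_heads, head_size = 40, 8, 128   # assumed model shape

    total_gpu_memory = 64 * 1024 ** 3    # assumed total NPU device memory (64 GB)
    peak_memory = 30 * 1024 ** 3         # assumed peak memory after the warmup forward
    warmup_blocks = 512                  # assumed blocks allocated for the warmup batch

    # Bytes per KV block: K and V (factor 2) for every layer.
    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size

    # Free memory after subtracting the peak usage and the (1 - MEMORY_FRACTION) reserve.
    total_free_memory = total_gpu_memory - peak_memory
    free_memory = max(0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory)

    # Keep 5% headroom; add back the warmup blocks already contained in peak_memory.
    num_blocks = int((free_memory * 0.95) // total_cache_size) + warmup_blocks
    print(num_blocks, num_blocks * BLOCK_SIZE)   # blocks and total KV slots (max batch tokens)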
  • Add the post-processing class MindIELLMHeterogeneousNextTokenChooser: based on the sampling parameters of each request in a batch, one post-processing object is built per batch; sampling then draws the next token(s) from the logits produced by inference. A usage sketch follows the listing below.
    Tgi-MindIE/tgi_npu/tokens_mindie.py
    # Part of codes in this file was copied from project[huggingface][text-generation-inference]
    
    import os
    from typing import List, Optional
    from loguru import logger
    import numpy as np
    import torch
    from text_generation_server.pb import generate_pb2
    
    from mindie_llm.text_generator.utils.sampling_metadata import SamplingData, SamplingParam
    from mindie_llm.text_generator.utils.config import SamplerConfig
    from mindie_llm.text_generator.samplers.sampler import Sampler
    
    from mindie_llm.modeling.backend_type import BackendType
    
    WRAPPER_KEY = "tensor_wrapper"
    try:
        from mindie_llm.text_generator.utils.sampling_metadata import TensorWrapper
    except ImportError:
        class TensorWrapper:
            def __init__(self, backend, device):
                self.device = device
                self.backend = backend
    
            def __call__(self, data):
                if data.dtype == np.int32:
                    dtype = torch.int32
                elif data.dtype == np.bool_:
                    dtype = torch.bool
                else:
                    dtype = None
                return torch.tensor(data, dtype=dtype, device=self.device)
    
    
        WRAPPER_KEY = "to_tensor"
    
    
    def do_filter(sample_param: List, indices):
        if any(sample_param):
            return [sample_param[i] for i in indices]
        return sample_param
    
    
    class MindIELLMHeterogeneousNextTokenChooser:
        def __init__(
                self,
                dtype: torch.dtype,
                device: torch.device,
                watermark: List[bool],
                temperature: List[float],
                repetition_penalty: List[float],
                frequency_penalty: List[float],
                top_k: List[int],
                top_p: List[float],
                typical_p: List[float],
                do_sample: List[bool],
                seeds: List[int],
                grammars: List[str],
                grammar_types: List[int],
                fsm_grammar_states: List[int],
                sample_method,  # mindie-llm sampling method
        ):
            if any(watermark):
                logger.warning(f"Watermark not supported now in mindie-llm")
            if any([x < 1.0 for x in typical_p]):
                logger.warning(f"Typical_p not supported now in mindie-llm")
            if any(grammar_types) or any(grammars) or any(fsm_grammar_states):
                logger.warning(f"Grammar not supported now in mindie-llm")
            
            # 1. Constructor: build the MindIE sampler from the post-sampling parameters received over gRPC
            self.tensor_wrapper = TensorWrapper(BackendType.ATB, device)
            self.wrapper_dict = {WRAPPER_KEY: TensorWrapper(BackendType.ATB, device)}
            self.sample_params = SamplingParam.from_numpy(
                repetition_penalty=np.array(repetition_penalty, dtype=np.float16),
                presence_penalty=None,
                frequency_penalty=np.array(frequency_penalty, dtype=np.float16),
                temperature=np.array(temperature, dtype=np.float16),
                top_k=np.array(top_k),
                top_p=np.array(top_p),
                seed=np.array(seeds).astype(np.int32),
                do_sample=np.array(do_sample),
                **self.wrapper_dict
            )
    
            # Temp store for filter
            self.temperature = temperature
            self.repetition_penalty = repetition_penalty
            self.frequency_penalty = frequency_penalty
            self.top_k = top_k
            self.top_p = top_p
            self.seeds = seeds
            self.do_sample = do_sample
    
            self.choice = sample_method
            self.dtype = dtype
            self.device = device
            self.seeds = seeds
            self.do_sample = self.sample_params.do_sample_meta.do_sample_array
            self.fsm_grammar_states = fsm_grammar_states
    
        # 2. Sample the next token ids (next_ids) from the input token ids (input_ids) and the inference logits (scores)
        def __call__(self,
                     request_ids: List,
                     is_prefill: bool,
                     input_ids: torch.Tensor,
                     scores: torch.Tensor,
                     speculate: int
                     ):
            batch_size = scores.shape[0]
            speculate_size = 1
            scores = scores.view(batch_size, speculate_size, -1)
    
            # Do not use SamplingData.from_numpy, to avoid transferring large tensors through tensor.cpu()
            input_ids_int32 = input_ids.to(torch.int32)
            sample_data = SamplingData(all_input_ids=input_ids_int32, output_ids=input_ids_int32, is_prefill=is_prefill,
                                       request_ids=np.array(request_ids))
            next_ids = torch.zeros((batch_size, speculate_size), device=scores.device, dtype=torch.long)
            for j in range(speculate_size):
                _scores = scores[:, j]
    
                batch_logits, _next_ids = self.choice(batch_logits=_scores, batch_sampling_data=sample_data,
                                                      batch_sampling_params=self.sample_params)
                scores[:, j] = _scores
                next_ids[:, j] = torch.from_numpy(_next_ids)
            next_ids = next_ids.view(batch_size * speculate_size)
            allscores = scores.view(batch_size * speculate_size, -1)
            alllogprobs = torch.log_softmax(allscores, -1)
    
            accepted_ids = torch.ones_like(next_ids)
            logprobs = alllogprobs
    
            next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
    
            speculative_ids = None
            return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
    
        @classmethod
        def from_pb(
                cls,
                pb: List[generate_pb2.NextTokenChooserParameters],
                dtype: torch.dtype,
                device: torch.device,
                fsm_grammar_states: Optional[List[int]] = None,
        ) -> "MindIELLMHeterogeneousNextTokenChooser":
            curr_rank = int(os.getenv("RANK", "0"))
            sample_method = Sampler(SamplerConfig(rank=curr_rank, backend_type=BackendType.ATB, npu_id=curr_rank))
            return MindIELLMHeterogeneousNextTokenChooser(
                watermark=[pb_.watermark for pb_ in pb],
                temperature=[pb_.temperature for pb_ in pb],
                repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
                frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
                top_k=[pb_.top_k for pb_ in pb],
                top_p=[pb_.top_p for pb_ in pb],
                typical_p=[pb_.typical_p for pb_ in pb],
                do_sample=[pb_.do_sample for pb_ in pb],
                seeds=[pb_.seed for pb_ in pb],
                device=device,
                dtype=dtype,
                grammars=[pb_.grammar for pb_ in pb],
                grammar_types=[pb_.grammar_type for pb_ in pb],
                fsm_grammar_states=(
                    fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
                ),
                sample_method=sample_method
            )
    
        def filter(self, indices):
            self.repetition_penalty = do_filter(self.repetition_penalty, indices)
            self.frequency_penalty = do_filter(self.frequency_penalty, indices)
            self.temperature = do_filter(self.temperature, indices)
            self.top_k = do_filter(self.top_k, indices)
            self.top_p = do_filter(self.top_p, indices)
            self.seeds = do_filter(self.seeds, indices)
            self.do_sample = do_filter(self.do_sample, indices)
    
            self.sample_params = SamplingParam.from_numpy(
                repetition_penalty=np.array(self.repetition_penalty, dtype=np.float16),
                presence_penalty=None,
                frequency_penalty=np.array(self.frequency_penalty, dtype=np.float16),
                temperature=np.array(self.temperature, dtype=np.float16),
                top_k=np.array(self.top_k),
                top_p=np.array(self.top_p),
                seed=np.array(self.seeds).astype(np.int32),
                do_sample=np.array(self.do_sample),
                **self.wrapper_dict
            )
            return self
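
    A minimal usage sketch of the chooser is shown below. It is not part of the adaptation code and assumes an NPU environment where mindie_llm and torch_npu are installed and an "npu" device is available; the sampling values and vocabulary size are placeholders.

    import torch
    from text_generation_server.pb import generate_pb2
    from tgi_npu.tokens_mindie import MindIELLMHeterogeneousNextTokenChooser

    # Build one chooser for a single-request batch from gRPC-style sampling parameters.
    params = [
        generate_pb2.NextTokenChooserParameters(
            temperature=0.7, top_k=20, top_p=0.9, typical_p=1.0,
            do_sample=True, seed=42, repetition_penalty=1.1,
            frequency_penalty=0.0, watermark=False,
        )
    ]
    chooser = MindIELLMHeterogeneousNextTokenChooser.from_pb(
        pb=params, dtype=torch.float16, device=torch.device("npu:0")
    )

    # One decode step: scores holds the logits of the last position for each request.
    vocab_size = 32000                                              # placeholder vocabulary size
    all_input_ids = torch.randint(0, vocab_size, (1, 16), device="npu:0")
    scores = torch.randn(1, vocab_size, dtype=torch.float16, device="npu:0")
    next_ids, next_logprobs, logprobs, accepted_ids, _ = chooser(
        request_ids=[0], is_prefill=False, input_ids=all_input_ids,
        scores=scores, speculate=0,
    )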
  • (Qwen-VL model adaptation) MindIE multimodal model and multimodal request batch classes.
    • Add the multimodal model class VlmMindModel: it inherits from MindModel and, during initialization, additionally sets up the tokenize method, which is provided by the MindIE LLM side and dedicated to multimodal encoding.
    • Add the multimodal request batch class VlmMindFlashCausalLMBatch: it acts as the MindIE LLM multimodal request batch and inherits from the NPU-adapted MindFlashCausalLMBatch. It mainly unpacks and reassembles the incoming request fields and encodes them with tokenize (see the sketch after the listing below).
      Tgi-MindIE/tgi_npu/vlm_mind_models.py
      import re
      from typing import List, Type, Dict
      from dataclasses import dataclass
      import torch
      from loguru import logger
      from atb_llm.runner.tokenizer_wrapper import TokenizerWrapper
      from text_generation_server.pb import generate_pb2
      from transformers import PreTrainedTokenizerBase
      from tgi_npu.mind_models import MindFlashCausalLMBatch, MindModel
      IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
      def split(string) -> List[Dict[str, str]]:
          parts = []
          cursor = 0
          for pattern in IMAGES.finditer(string):
              start = pattern.start()
              if start != cursor:
                  parts.append({"text": string[cursor:start]})
              parts.append({"image": pattern.group(1)})
              cursor = pattern.end()
          if cursor != len(string):
              parts.append({"text": string[cursor:]})
          return parts
      @dataclass
      class VlmMindFlashCausalLMBatch(MindFlashCausalLMBatch):
          @classmethod
          def batch_tokenized_inputs(cls, requests, tokenize):
              inputs = []
              for r in requests:
                  splits = split(r.inputs)
                  single_input = tokenize(splits).tolist()
                  inputs.append(single_input)
              return inputs
          @classmethod
          def from_pb_processor(
                  cls,
                  pb: generate_pb2.Batch,
                  tokenizer: PreTrainedTokenizerBase,
                  tokenize,
                  dtype: torch.dtype,
                  device: torch.device,
          ) -> "VlmMindFlashCausalLMBatch":
              batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenize)
              batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
              return batch
      
      class VlmMindModel(MindModel):
          def __init__(
                  self,
                  model_id: str,
                  trust_remote_code: bool
          ):
              logger.warning("Initialize mindie-llm model for vlm.")
              # Reuse the parent-class initialization
              super(VlmMindModel, self).__init__(model_id, trust_remote_code)
              # Additionally set tokenize, a method provided by MindIE LLM dedicated to multimodal encoding
              self.tokenize = TokenizerWrapper(model_id, trust_remote_code=trust_remote_code).tokenize
              logger.warning("VlmMindModel from tgi_npu initialized.")
          @property
          def batch_type(self) -> Type[VlmMindFlashCausalLMBatch]:
              # Return the multimodal batch type, used on the server side to convert received gRPC requests into multimodal batches
              return VlmMindFlashCausalLMBatch
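
      The split helper above turns a prompt containing Markdown image syntax into the text/image parts expected by tokenize. A small illustrative sketch follows; the URL is a placeholder, and importing tgi_npu.vlm_mind_models requires the MindIE environment because it pulls in atb_llm.

      from tgi_npu.vlm_mind_models import split

      prompt = "Describe this picture ![img](https://example.com/cat.png) in one sentence."
      print(split(prompt))
      # Expected result, in order:
      # [{'text': 'Describe this picture '},
      #  {'image': 'https://example.com/cat.png'},
      #  {'text': ' in one sentence.'}]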
  • Tgi-MindIE/pyproject.toml
    [tool.poetry]
    name = "tgi-npu"
    version = "0.1.0"
    description = "NPU MindIE Adapter for TGI v2.0.4"
    authors = ["Your Name <you@example.com>"]
    exclude = ["router", "cover"]
    
    
    
    [tool.poetry.dependencies]
    python = "^3.10"
    
    [build-system]
    requires = ["poetry-core"]
    build-backend = "poetry.core.masonry.api"
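
    After installation, a quick smoke test can confirm that the adapter package built from this pyproject.toml is visible to Python. This is a hypothetical check, not part of the adaptation code; it assumes a machine where MindIE LLM and torch_npu are available, since importing tgi_npu triggers the NPU initialization in tgi_npu/__init__.py.

    import importlib.metadata

    print(importlib.metadata.version("tgi-npu"))        # expected: 0.1.0 (from pyproject.toml)

    import tgi_npu                                      # runs the NPU initialization
    from tgi_npu.mind_models import MindModel           # text-only entry class
    from tgi_npu.vlm_mind_models import VlmMindModel    # Qwen-VL entry class (optional)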