Reference adaptation code for TGI v0.9.4
```
Tgi-MindIE
├── cover
│   ├── cli.py
│   ├── models
│   │   └── __init__.py
│   └── utils
│       ├── __init__.py
│       ├── hub.py
│       ├── import_utils.py
│       └── peft.py
├── tgi_npu
│   ├── __init__.py
│   ├── cache_manager.py
│   ├── info.py
│   ├── mind_models.py
│   └── tokens_mindie.py
├── pyproject.toml
├── install.sh
└── requirements.txt
```
The purpose of each source file is described in Table 1.
**Table 1** Source files and their purpose

| Source file | Purpose |
|---|---|
| cover/models/__init__.py | Replaces server/text_generation_server/models/__init__.py in the original TGI repository, routing model inference to MindIE LLM. |
| cover/cli.py | Replaces server/text_generation_server/cli.py in the original TGI repository, adding a loguru log filter for the tgi_npu module. |
| cover/utils/peft.py | Introduces a method for loading PEFT models to support PEFT inference (single LoRA). |
| cover/utils/__init__.py | Replaces server/text_generation_server/utils/__init__.py in the original TGI repository, exporting the download_and_unload_peft method from the newly added peft module. |
| cover/utils/hub.py | Replaces server/text_generation_server/utils/hub.py in the original TGI repository, adding compatibility with the PEFT weight file name adapter_model.safetensors. |
| cover/utils/import_utils.py | Wraps torch memory-related operations such as emptying the cache, synchronizing, and querying memory information. |
| tgi_npu/__init__.py | Performs the initialization required for the NPU hardware environment. |
| tgi_npu/cache_manager.py | KV cache manager; mainly initializes the KV cache for the NPU. |
| tgi_npu/info.py | NPU device information. |
| tgi_npu/mind_models.py | Defines the inference entry class MindModel and its batch format MindFlashCausalLMBatch, which inherit from the original FlashCausalLM and FlashCausalLMBatch respectively. In MindModel, generate_token keeps most of the original code but is adapted to the MindIE LLM call flow: the forward method now calls the forward_tensor method provided by MindIE LLM, and warmup is adjusted for NPU memory-access characteristics. |
| tgi_npu/tokens_mindie.py | Post-processing (sampling) code. |
| pyproject.toml | Packaging configuration for the adapter. |
| requirements.txt | Python dependencies. |
| install.sh | Installation script. |
Sample code:
- Tgi-MindIE/install.sh
```bash
#!/usr/bin/env bash
# install-origin
if [ -d "./tgi_origin" ]; then
  echo "./tgi_origin directory already exists!"
  exit 1
fi
git clone -b v0.9.4 https://github.com/huggingface/text-generation-inference.git tgi_origin
cp cover/cli.py tgi_origin/server/text_generation_server/
cp cover/models/__init__.py tgi_origin/server/text_generation_server/models
cp cover/utils/__init__.py tgi_origin/server/text_generation_server/utils
cp cover/utils/hub.py tgi_origin/server/text_generation_server/utils
cp cover/utils/import_utils.py tgi_origin/server/text_generation_server/utils
cp cover/utils/peft.py tgi_origin/server/text_generation_server/utils
sed -i "s/requires_padding, 16/requires_padding, 128/g" tgi_origin/router/src/infer.rs
sed -i "s/prefill_logprobs: true/prefill_logprobs: false/g" tgi_origin/router/client/src/client.rs
sed -i "s/Vec::from(encoding.get_ids())/encoding.get_ids()/g" tgi_origin/router/src/validation.rs
sed -i "s/bnb, accelerate/accelerate/g" tgi_origin/server/Makefile
cd tgi_origin && make install-server && make install-router && make install-launcher
cd .. && pip install -e .
pip install -r requirements.txt
```
- Tgi-MindIE/requirements.txt
```text
accelerate==0.29.1
safetensors==0.4.4
torch==2.1.0
transformers==4.40.1
peft==0.13.2
```
- Tgi-MindIE/cover/models/__init__.py: MindIE LLM entry point, switching the original TGI framework from GPU models to MindIE LLM.
```python
# This file was copied from project[huggingface][text-generation-inference]
from typing import Optional

import torch
from loguru import logger

from text_generation_server.models.model import Model

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
    "Model",
    "get_model",
]


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
    if dtype is None:
        dtype = torch.float16
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    # 1. Import the TGI-on-NPU adapter package, tgi_npu
    try:
        import torch_npu
        from tgi_npu import MindModel

        npu_module_imported = True
    except (ImportError, NotImplementedError) as excp:
        npu_module_imported = False
        logger.error(f"Error caught: {str(excp)}")

    # 2. Route to the MindIE LLM model; the set of supported text models is the same as MindIE LLM
    if npu_module_imported and torch.npu.is_available():
        return MindModel(model_id=model_id, trust_remote_code=trust_remote_code)
    else:
        logger.error("NPU environment error!")
        raise ValueError("NPU environment error!")
```
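For context, a hedged sketch of how this entry point is invoked by the TGI server process; the model path is hypothetical, and running it requires a healthy NPU environment:

```python
# Hypothetical invocation, mirroring how text_generation_server.server calls get_model()
model = get_model(
    model_id="/data/models/llama2-7b",   # hypothetical local model path
    revision=None,
    sharded=False,
    quantize=None,
    dtype="float16",
    trust_remote_code=False,
)
print(type(model).__name__)              # expected: MindModel when the NPU environment is healthy
```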
- Tgi-MindIE/cover/cli.py
# This file was copied from project[huggingface][text-generation-inference] import os import sys from pathlib import Path from typing import Optional from enum import Enum import typer from loguru import logger app = typer.Typer() MODEL_SUFFIX = ".safetensors" class Quantization(str, Enum): bitsandbytes = "bitsandbytes" gptq = "gptq" smooth_quant = 'smooth_quant' class Dtype(str, Enum): float16 = "float16" bloat16 = "bfloat16" @app.command() def serve( model_id: str, revision: Optional[str] = None, sharded: bool = False, quantize: Optional[Quantization] = None, dtype: Optional[Dtype] = None, trust_remote_code: bool = False, uds_path: Path = "/tmp/text-generation-server", logger_level: str = "INFO", json_output: bool = False, otlp_endpoint: Optional[str] = None, ): if sharded: assert ( os.getenv("RANK", None) is not None ), "RANK must be set when sharded is True" assert ( os.getenv("WORLD_SIZE", None) is not None ), "WORLD_SIZE must be set when sharded is True" assert ( os.getenv("MASTER_ADDR", None) is not None ), "MASTER_ADDR must be set when sharded is True" assert ( os.getenv("MASTER_PORT", None) is not None ), "MASTER_PORT must be set when sharded is True" # Remove default handler logger.remove() logger.add( sys.stdout, format="{message}", filter="text_generation_server", level=logger_level, serialize=json_output, backtrace=True, diagnose=False, ) logger.add( sys.stdout, format="{message}", filter="tgi_npu", level=logger_level, serialize=json_output, backtrace=True, diagnose=False, ) # Import here after the logger is added to log potential import exceptions from text_generation_server import server from text_generation_server.tracing import setup_tracing # Setup OpenTelemetry distributed tracing if otlp_endpoint is not None: setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint) # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value dtype = None if dtype is None else dtype.value if dtype is not None and quantize is not None: raise RuntimeError( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." ) server.serve( model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path ) @app.command() def download_weights( model_id: str, revision: Optional[str] = None, extension: str = MODEL_SUFFIX, auto_convert: bool = True, logger_level: str = "INFO", json_output: bool = False, ): # Remove default handler logger.remove() logger.add( sys.stdout, format="{message}", filter="text_generation_server", level=logger_level, serialize=json_output, backtrace=True, diagnose=False, ) # Import here after the logger is added to log potential import exceptions from text_generation_server import utils # Test if files were already download try: utils.weight_files(model_id, revision, extension) logger.info("Files are already present on the host. 
" "Skipping download.") return # Local files not found except (utils.LocalEntryNotFoundError, FileNotFoundError): pass is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv( "WEIGHTS_CACHE_OVERRIDE", None ) is not None if not is_local_model: # Try to download weights from the hub try: filenames = utils.weight_hub_files(model_id, revision, extension) utils.download_weights(filenames, model_id, revision) # Successfully downloaded weights return # No weights found on the hub with this extension except utils.EntryNotFoundError as e: # Check if we want to automatically convert to safetensors or if we can use .bin weights instead if not extension == MODEL_SUFFIX or not auto_convert: raise e elif (Path(model_id) / "adapter_config.json").exists(): # Try to load as a local PEFT model try: logger.warning(f'{revision}') utils.download_and_unload_peft( model_id, revision, trust_remote_code=True ) utils.weight_files(model_id, revision, extension) return except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): pass # Try to see if there are local pytorch weights try: # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE local_pt_files = utils.weight_files(model_id, revision, ".bin") # No local pytorch weights except utils.LocalEntryNotFoundError: if extension == MODEL_SUFFIX: logger.warning( f"No safetensors weights found for model {model_id} at revision {revision}. " f"Downloading PyTorch weights." ) # Try to see if there are pytorch weights on the hub pt_filenames = utils.weight_hub_files(model_id, revision, ".bin") # Download pytorch weights local_pt_files = utils.download_weights(pt_filenames, model_id, revision) if auto_convert: logger.warning( f"No safetensors weights found for model {model_id} at revision {revision}. " f"Converting PyTorch weights to safetensors." ) # Safetensors final filenames local_st_files = [ p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files ] try: from transformers import AutoConfig import transformers config = AutoConfig.from_pretrained( model_id, revision=revision, ) architecture = config.architectures[0] class_ = getattr(transformers, architecture) # Name for this varible depends on transformers version. discard_names = getattr(class_, "_tied_weights_keys", []) discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", [])) except Exception as e: discard_names = [] # Convert pytorch weights to safetensors utils.convert_files(local_pt_files, local_st_files, discard_names) @app.command() def quantize( model_id: str, output_dir: str, revision: Optional[str] = None, logger_level: str = "INFO", json_output: bool = False, trust_remote_code: bool = False, upload_to_model_id: Optional[str] = None, percdamp: float = 0.01, act_order: bool = False, ): if revision is None: revision = "main" download_weights( model_id=model_id, revision=revision, logger_level=logger_level, json_output=json_output, ) from text_generation_server.utils.gptq.quantize import quantize quantize( model_id=model_id, bits=4, groupsize=128, output_dir=output_dir, revision=revision, trust_remote_code=trust_remote_code, upload_to_model_id=upload_to_model_id, percdamp=percdamp, act_order=act_order, ) if __name__ == "__main__": app()
- Tgi-MindIE/cover/utils/__init__.py
```python
# This file was copied from project[huggingface][text-generation-inference]
from text_generation_server.utils.convert import convert_file, convert_files
from text_generation_server.utils.dist import initialize_torch_distributed
from text_generation_server.utils.weights import Weights
from text_generation_server.utils.peft import download_and_unload_peft
from text_generation_server.utils.hub import (
    weight_files,
    weight_hub_files,
    download_weights,
    EntryNotFoundError,
    LocalEntryNotFoundError,
    RevisionNotFoundError,
)
from text_generation_server.utils.tokens import (
    NextTokenChooser,
    HeterogeneousNextTokenChooser,
    StoppingCriteria,
    StopSequenceCriteria,
    FinishReason,
    Sampling,
    Greedy,
)

__all__ = [
    "convert_file",
    "convert_files",
    "initialize_torch_distributed",
    "weight_files",
    "weight_hub_files",
    "download_weights",
    "download_and_unload_peft",
    "EntryNotFoundError",
    "HeterogeneousNextTokenChooser",
    "LocalEntryNotFoundError",
    "RevisionNotFoundError",
    "Greedy",
    "NextTokenChooser",
    "Sampling",
    "StoppingCriteria",
    "StopSequenceCriteria",
    "FinishReason",
    "Weights",
]
```
- Tgi-MindIE/cover/utils/hub.py
# This file was copied from project[huggingface][text-generation-inference] import time import os from datetime import timedelta from pathlib import Path from typing import Optional, List from loguru import logger from huggingface_hub import HfApi, hf_hub_download from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from huggingface_hub.utils import ( LocalEntryNotFoundError, EntryNotFoundError, RevisionNotFoundError, # Import here to ease try/except in other part of the lib ) WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None) def weight_hub_files( model_id: str, revision: Optional[str] = None, extension: str = ".safetensors" ) -> List[str]: api = HfApi() info = api.model_info(model_id, revision=revision) def is_valid_filename(filename: str, extension: str) -> bool: return ( filename.endswith(extension) and len(filename.split("/")) == 1 and "arguments" not in filename and "args" not in filename and "training" not in filename ) filenames = [ s.rfilename for s in info.siblings if is_valid_filename(s.rfilename, extension) ] if not filenames: raise EntryNotFoundError( f"No {extension} weights found for model {model_id} and revision {revision}.", None, ) return filenames def try_to_load_from_cache( model_id: str, revision: Optional[str], filename: str ) -> Optional[Path]: """Try to load a file from the Hugging Face cache""" if revision is None: revision = "main" object_id = model_id.replace("/", "--") repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}" if not repo_cache.is_dir(): # No cache for this model none_type = None logger.info("No cache for this model") return none_type refs_dir = repo_cache / "refs" snapshots_dir = repo_cache / "snapshots" # Resolve refs (for instance to convert main to the associated commit sha) if refs_dir.is_dir(): revision_file = refs_dir / revision if revision_file.exists(): with revision_file.open() as f: revision = f.read() # Check if revision folder exists if not snapshots_dir.exists(): none_type = None logger.info("Revision folder doesn't exsist") return none_type cached_shas = os.listdir(snapshots_dir) if revision not in cached_shas: none_type = None # No cache for this revision and we won't try to return a random revision logger.info("No cache for this revision and we won't try to return a random revision") return none_type # Check if file exists in cache cached_file = snapshots_dir / revision / filename return cached_file if cached_file.is_file() else None def weight_files( model_id: str, revision: Optional[str] = None, extension: str = ".safetensors" ) -> List[Path]: """Get the local files""" # Local model if Path(model_id).exists() and Path(model_id).is_dir(): local_files = list(Path(model_id).glob(f"*{extension}")) filtered_files = [f for f in local_files if "adapter" not in str(f)] if not filtered_files: raise FileNotFoundError( f"No local weights found in {model_id} with extension {extension}" ) return filtered_files try: filenames = weight_hub_files(model_id, revision, extension) except EntryNotFoundError as e: if extension != ".safetensors": raise e # Try to see if there are pytorch weights pt_filenames = weight_hub_files(model_id, revision, extension=".bin") # Change pytorch extension to safetensors extension # It is possible that we have safetensors weights locally even though they are not on the # hub if we converted weights locally without pushing them filenames = [ f"{Path(pt_filename).stem.lstrip('pytorch_')}.safetensors" \ for pt_filename in pt_filenames ] if WEIGHTS_CACHE_OVERRIDE is not None: files = [] 
for filename in filenames: p = Path(WEIGHTS_CACHE_OVERRIDE) / filename if not p.exists(): raise FileNotFoundError( f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}." ) files.append(p) return files files = [] for filename in filenames: cache_file = try_to_load_from_cache( model_id, revision=revision, filename=filename ) if cache_file is None: raise LocalEntryNotFoundError( f"File {filename} of model {model_id} not found in " f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. " f"Please run `text-generation-server download-weights {model_id}` first." ) files.append(cache_file) return files def download_weights( filenames: List[str], model_id: str, revision: Optional[str] = None ) -> List[Path]: """Download the safetensors files from the hub""" def download_file(filename="", tries=5, backoff: int = 5): local_file = try_to_load_from_cache(model_id, revision, filename) if local_file is not None: logger.info(f"File {filename} already present in cache.") return Path(local_file) else: invalid_value = None return invalid_value for i in range(tries): try: logger.info(f"Download file: {filename}") start_time = time.time() local_file = hf_hub_download( filename=filename, repo_id=model_id, revision=revision, local_files_only=False, ) logger.info( f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}." ) return Path(local_file) except Exception as e: logger.error(e) if i + 1 == tries: raise Exception(f"Failed to download {filename} after {tries} attempts.") from e else: logger.info(f"Retrying in {backoff} seconds") time.sleep(backoff) logger.info(f"Retry {i + 1}/{tries - 1}") # We do this instead of using tqdm because we want to parse the logs with the launcher start_time = time.time() files = [] for i, filename in enumerate(filenames): file = download_file(filename) elapsed = timedelta(seconds=int(time.time() - start_time)) remaining = len(filenames) - (i + 1) eta = (elapsed / (i + 1)) * remaining if remaining > 0 else timedelta(0) logger.info(f"Download: [{i + 1}/{len(filenames)}] -- ETA: {eta}") files.append(file) return files
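The adapter's change to weight_files matters for merged PEFT models: when scanning a local directory, files whose names contain "adapter" (for example adapter_model.safetensors) are skipped. A hedged usage sketch with a hypothetical path:

```python
from text_generation_server.utils.hub import weight_files

# Hypothetical local directory that contains both merged *.safetensors shards
# and the original adapter_model.safetensors; the latter is filtered out.
files = weight_files("/data/models/merged-llama2", extension=".safetensors")
print([f.name for f in files])
```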
- Tgi-MindIE/cover/utils/import_utils.py
```python
# This file was copied from project[huggingface][text-generation-inference]
import torch


def is_xpu_available():
    try:
        import intel_extension_for_pytorch
    except ImportError:
        return False

    return hasattr(torch, "xpu") and torch.xpu.is_available()


def get_cuda_free_memory(device, memory_fraction):
    total_free_memory, _ = torch.cuda.mem_get_info(device)
    total_gpu_memory = torch.cuda.get_device_properties(device).total_memory
    free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory)
    return free_memory


def get_xpu_free_memory(device, memory_fraction):
    total_gpu_memory = torch.xpu.get_device_properties(device).total_memory
    free_memory = int(total_gpu_memory * 0.5)
    return free_memory


SYSTEM = None
if torch.version.hip is not None:
    SYSTEM = "rocm"
    empty_cache = torch.cuda.empty_cache
    synchronize = torch.cuda.synchronize
    get_free_memory = get_cuda_free_memory
elif torch.version.cuda is not None and torch.cuda.is_available():
    SYSTEM = "cuda"
    empty_cache = torch.cuda.empty_cache
    synchronize = torch.cuda.synchronize
    get_free_memory = get_cuda_free_memory
elif is_xpu_available():
    SYSTEM = "xpu"
    empty_cache = torch.xpu.empty_cache
    synchronize = torch.xpu.synchronize
    get_free_memory = get_xpu_free_memory
else:
    SYSTEM = "cpu"

    def noop(*args, **kwargs):
        pass

    empty_cache = noop
    synchronize = noop
    get_free_memory = noop
```
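A short sketch of how these wrappers are meant to be used by callers such as warmup. On a host without CUDA or XPU (including the NPU host, where MindModel uses torch_npu APIs directly instead) they resolve to the no-op branch:

```python
import torch
from text_generation_server.utils.import_utils import SYSTEM, empty_cache, synchronize, get_free_memory

empty_cache()                  # frees cached allocator memory, or does nothing on the "cpu" branch
synchronize()
if SYSTEM == "cuda":
    free_bytes = get_free_memory(torch.device("cuda:0"), memory_fraction=0.9)
```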
- Tgi-MindIE/cover/utils/peft.py
```python
# This file was copied from project[huggingface][text-generation-inference]
import os

import torch
from loguru import logger
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM


def download_and_unload_peft(model_id, revision, trust_remote_code):
    torch_dtype = torch.float16

    logger.info("Trying to load a Peft model. It might take a while without feedback")
    try:
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            low_cpu_mem_usage=True,
        )
    except Exception:
        model = AutoPeftModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            low_cpu_mem_usage=True,
        )
    logger.info("Peft model detected.")
    logger.info("Merging the lora weights.")

    base_model_id = model.peft_config["default"].base_model_name_or_path

    model = model.merge_and_unload()

    os.makedirs(model_id, exist_ok=True)
    cache_dir = model_id
    logger.info(f"Saving the newly created merged model to {cache_dir}")
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id, trust_remote_code=trust_remote_code
    )
    model.save_pretrained(cache_dir, safe_serialization=True)
    model.config.save_pretrained(cache_dir)
    tokenizer.save_pretrained(cache_dir)
```
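A hedged usage sketch: download_weights in cli.py calls this function when a local directory contains adapter_config.json; afterwards the merged full weights and tokenizer are written back into the same directory. The path below is hypothetical:

```python
from text_generation_server.utils.peft import download_and_unload_peft

# Hypothetical LoRA adapter directory containing adapter_config.json and adapter weights
download_and_unload_peft("/data/models/my-lora-adapter", revision=None, trust_remote_code=True)
# The directory now also holds the merged model in safetensors format, ready for weight_files().
```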
- Tgi-MindIE/tgi_npu/__init__.py
```python
#!/usr/bin/env python3
# coding=utf-8
# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.

import torch
import torch_npu
from loguru import logger

from tgi_npu.mind_models import MindModel


def init():
    torch._C._InferenceMode(True)
    soc_version = torch_npu._C._npu_get_soc_version()
    if soc_version not in [104, 220, 221, 222, 223, 224]:
        logger.info("Some ops are not supported on this SoC!")
        option = {"NPU_FUZZY_COMPILE_BLACKLIST": "ReduceNansum"}
    else:
        option = {"NPU_FUZZY_COMPILE_BLACKLIST": "GatherElements"}
    try:
        torch.npu.set_option(option)
        logger.warning("Finish init for NPU device!")
    except Exception as e:
        logger.error(f"Failed to init for NPU device: {e}!")


init()
```
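Note that init() runs at import time, so simply importing the package (as cover/models/__init__.py does) is enough to apply the NPU options:

```python
# Importing the adapter applies the fuzzy-compile option and exposes the model class
from tgi_npu import MindModel   # triggers init() on first import
```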
- Tgi-MindIE/tgi_npu/cache_manager.py
- The CacheManager class allocates the KV cache for every request in a batch. Compared with the CacheManager on GPU, the following adaptations are made (a worked example of the block accounting follows the listing below):
  - BLOCK_SIZE (the number of slots in one block) is recommended to be set to 128.
  - KV cache allocation must account for whether the data uses the ND or NZ layout, which is decided by the NPU card model.
# This file was copied from project[huggingface][text-generation-inference] import math from typing import Optional, List, Tuple import gc import torch from tgi_npu.info import NPUSocInfo # 1. 建议修改为128 BLOCK_SIZE: int = 128 # Will be set in warmup CACHE_MANAGER: Optional["CacheManager"] = None class CacheManager: def __init__( self, num_blocks: int, num_layers: int, num_heads: int, head_size: int, repeat_slots: bool, dtype: torch.dtype, device: torch.device, ): self.block_size = BLOCK_SIZE self.num_blocks = num_blocks self.repeat_slots = repeat_slots # 2.根据npu卡设置kvcache中tensor的数据排布格式 self.need_nz = NPUSocInfo().need_nz if self.need_nz: self.kv_cache = [ ( torch.empty( (num_blocks, num_heads * head_size // 16, self.block_size, 16), dtype=dtype, device=device, ), torch.empty( (num_blocks, num_heads * head_size // 16, self.block_size, 16), dtype=dtype, device=device, ), ) for _ in range(num_layers) ] else: self.kv_cache = [ ( torch.empty( (num_blocks, self.block_size, num_heads, head_size), dtype=dtype, device=device, ), torch.empty( (num_blocks, self.block_size, num_heads, head_size), dtype=dtype, device=device, ), ) for _ in range(num_layers) ] self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu") self.slots = torch.arange( 0, num_blocks * self.block_size, dtype=torch.int64 ).view(num_blocks, self.block_size) def __repr__(self): return (f"CacheManager: " f"num_blocks={self.num_blocks}," f"block_size={self.block_size}," f"free_block_mask={self.free_block_mask}," f"slots={self.slots}," f"k_cache shape={self.kv_cache[0][0].shape}," f"v_cache shape={self.kv_cache[0][1].shape}") def allocate( self, needed_blocks_slots: List[Tuple[int, int]], blocks: int, max_blocks: int, device: torch.device, ): # Get free blocks indices by finding values in mask that are not set to 0 free_block_indices = self.free_block_mask.nonzero() if blocks > len(free_block_indices): raise RuntimeError( f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks" ) # Slice by the number of required blocks block_indices = free_block_indices[:blocks] block_indices = block_indices.flatten() # Padded block tables block_tables_tensor = torch.zeros( (len(needed_blocks_slots), max_blocks), dtype=torch.int32 ) # Allocate paged attention blocks cumulative_blocks = 0 slots = [] block_tables = [] for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots): # Get allocated blocks for this sequence allocated_blocks = block_indices[ cumulative_blocks: cumulative_blocks + needed_blocks ] # Get slots for the allocated blocks all_slots = self.slots[allocated_blocks].flatten() # Repeat slots in the case of context sliding window if needed_slots > len(all_slots) and self.repeat_slots: repeats = math.ceil(needed_slots / len(all_slots)) all_slots = all_slots.repeat(repeats) allocated_slots = all_slots[:needed_slots] slots.append(allocated_slots) block_tables.append(allocated_blocks.tolist()) block_tables_tensor[i, :needed_blocks] = allocated_blocks cumulative_blocks += needed_blocks block_tables = block_tables block_tables_tensor = block_tables_tensor.to(device) slots = torch.concat(slots).to(device) # Allocate the required number of blocks by setting the mask to 0 self.free_block_mask[block_indices] = 0 return block_tables, block_tables_tensor, slots def free(self, block_indices: Optional[List[int]]): if block_indices is not None and block_indices: # Reset mask self.free_block_mask[block_indices] = 1 def set_cache_manager( num_blocks: int, num_layers: int, num_heads: int, 
head_size: int, repeat_slots: bool, dtype: torch.dtype, device: torch.device, ) -> CacheManager: global CACHE_MANAGER if CACHE_MANAGER is not None: del CACHE_MANAGER torch.npu.empty_cache() gc.collect() CACHE_MANAGER = CacheManager( num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device ) return CACHE_MANAGER def get_cache_manager() -> CacheManager: global CACHE_MANAGER if CACHE_MANAGER is None: raise RuntimeError("cache manager was not initialized") return CACHE_MANAGER
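To make the block accounting concrete, here is a minimal sketch with illustrative numbers of how many 128-slot blocks a single request reserves, mirroring the arithmetic in MindFlashCausalLMBatch.from_pb and CacheManager.allocate:

```python
import math

BLOCK_SIZE = 128                                      # recommended slot count per block on NPU
input_length, max_new_tokens = 300, 100               # hypothetical request
total_tokens = input_length + max_new_tokens - 1      # 399: the last generated token needs no new slot
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)  # 4 blocks, i.e. 512 slots reserved
```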
- Tgi-MindIE/tgi_npu/info.py
  - Determines whether the NPU card uses the ND or NZ data layout.
```python
#!/usr/bin/env python3
# coding=utf-8
# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.

from dataclasses import dataclass

import torch_npu


@dataclass
class NPUSocInfo:
    soc_version: int = -1
    need_nz: bool = False

    def __post_init__(self):
        self.soc_version = torch_npu._C._npu_get_soc_version()
        if self.soc_version in (100, 101, 102, 103, 104, 200, 201, 202, 203):
            self.need_nz = True
```
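A minimal sketch of how the flag is consumed (the CacheManager above reads need_nz to pick the KV tensor shape); the print statement is illustrative only:

```python
from tgi_npu.info import NPUSocInfo

soc_info = NPUSocInfo()                       # __post_init__ queries torch_npu for the SoC version
layout = "NZ" if soc_info.need_nz else "ND"
print(f"SoC {soc_info.soc_version} uses the {layout} KV cache layout")
```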
- Tgi-MindIE/tgi_npu/mind_models.py
- MindIE model and request batch classes.
- Adds the text-model class MindModel: MindModel inherits from the TGI FlashCausalLM class. The main adaptation is initializing the MindIE LLM model and obtaining the model, the tokenizer, and model-related information.
- Adds the text-request batch class MindFlashCausalLMBatch: created as the MindIE LLM text-request batch, inheriting from the original TGI FlashCausalLMBatch class.
# Part of codes in this file was copied from project[huggingface][text-generation-inference] #!/usr/bin/env python3 # coding=utf-8 import math import itertools import os from typing import Optional, Tuple, List, Type from dataclasses import dataclass import torch_npu import torch from loguru import logger from opentelemetry import trace import numpy as np from transformers import PreTrainedTokenizerBase from mindie_llm.text_generator.adapter.generator_torch import GeneratorTorch from mindie_llm.modeling.backend_type import BackendType from text_generation_server.models.flash_causal_lm import FlashCausalLM, FlashCausalLMBatch from text_generation_server.models.types import ( Batch, PrefillTokens, Generation, GeneratedText ) from text_generation_server.utils import StoppingCriteria from text_generation_server.pb import generate_pb2 from text_generation_server.utils.dist import RANK, MEMORY_FRACTION from text_generation_server.utils.import_utils import ( empty_cache, synchronize, get_free_memory, ) from tgi_npu.tokens_mindie import MindIELLMHeterogeneousNextTokenChooser from tgi_npu.cache_manager import ( BLOCK_SIZE, get_cache_manager, set_cache_manager, ) tracer = trace.get_tracer(__name__) @dataclass class MindFlashCausalLMBatch(FlashCausalLMBatch): # 1. 在基类FlashCausalLMBatch上增加字段,next_token_chooser(MindIE LLM的后处理类)、all_input_ids_tensor为( # 给next_token_chooser进行sampling) next_token_chooser: MindIELLMHeterogeneousNextTokenChooser all_input_ids_tensor: torch.Tensor block_tables: Optional[List[List[int]]] block_tables_tensor: Optional[torch.Tensor] def __repr__(self): return (f"MindFlashCausalLMBatch: batch_id={self.batch_id}," f"requests_idx_mapping={self.requests_idx_mapping}," f"input_ids={self.input_ids}," f"position_ids={self.position_ids}," f"cu_seqlen_prefill={self.cu_seqlen_prefill}," f"start_slots={self.start_slots}," f"slot_indices={self.slot_indices}," f"needed_blocks_slots={self.needed_blocks_slots}," f"block_tables={self.block_tables}," f"block_tables_tensor={self.block_tables_tensor}," f"slots={self.slots}," f"max_seqlen={self.max_seqlen}," f"prefill_head_indices={self.prefill_head_indices}," f"prefill_next_token_indices={self.prefill_next_token_indices}," f"prefill_cu_outlens={self.prefill_cu_outlens}," f"input_lengths={self.input_lengths}," f"input_lengths_tensor={self.input_lengths_tensor}," f"prefix_offsets={self.prefix_offsets}," f"read_offsets={self.read_offsets}," f"all_input_ids_tensor={self.all_input_ids_tensor}," f"next_token_chooser={self.next_token_chooser}," f"stopping_criterias={self.stopping_criterias}," f"blocks={self.blocks}," f"max_blocks={self.max_blocks}") def __len__(self): return len(self.requests) def __del__(self): if self.block_tables is not None and self.block_tables: global CACHE_MANAGER # Free blocks get_cache_manager().free( list(itertools.chain.from_iterable(self.block_tables)) ) @classmethod def batch_tokenized_inputs(cls, requests, tokenizer): batch_inputs = [] max_truncation = 0 for r in requests: batch_inputs.append(r.inputs) max_truncation = max(max_truncation, r.truncate) batch_tokenized_inputs = tokenizer(batch_inputs, truncation=True, max_length=max_truncation)["input_ids"] return batch_tokenized_inputs @classmethod def from_pb( cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, dtype: torch.dtype, device: torch.device, ) -> "MindFlashCausalLMBatch": batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer) position_ids = [] cu_seqlen_prefill = [0] needed_blocks_slots = [] start_slots = [] slot_indices = [] 
input_lengths = [] prefix_offsets = [] read_offsets = [] all_input_ids = [] requests_idx_mapping = {} all_prefill_logprobs = True no_prefill_logprobs = True prefill_head_indices = [] prefill_next_token_indices = [] prefill_cu_outlens = [0] next_token_chooser_parameters = [] stopping_criterias = [] # Cumulative length cumulative_length = 0 cumulative_max_length = 0 prefill_out_cumulative_length = 0 blocks = 0 max_seqlen = 0 max_length = 0 max_blocks = 0 # Parse batch for i, (r, tokenized_input) in enumerate( zip(pb.requests, batch_tokenized_inputs) ): # request id -> idx in list mapping requests_idx_mapping[r.id] = i tokenized_input = tokenized_input[-r.truncate:] if ( tokenized_input[0] == tokenizer.bos_token_id and tokenized_input[1] == tokenizer.bos_token_id ): tokenized_input = tokenized_input[1:] input_length = len(tokenized_input) input_lengths.append(input_length) prefix_offsets.append(input_length - 5) read_offsets.append(input_length) all_input_ids.append(tokenized_input) # Position ids request_position_ids = torch.arange(0, input_length, dtype=torch.int32) position_ids.append(request_position_ids) # Add cumulative lengths of all previous inputs cu_seqlen_prefill.append(cumulative_length + input_length) next_token_chooser_parameters.append(r.parameters) stopping_criteria = StoppingCriteria.from_pb( r.stopping_parameters, tokenizer ) max_new_tokens = stopping_criteria.max_new_tokens stopping_criterias.append(stopping_criteria) # Paged attention # Remove one as the first token des not have a past total_tokens = input_length + max_new_tokens - 1 needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) blocks += needed_blocks needed_blocks_slots.append((needed_blocks, total_tokens)) start_slots.append(cumulative_max_length) request_slot_indices = torch.arange( cumulative_max_length, cumulative_max_length + input_length, dtype=torch.int64, ) slot_indices.append(request_slot_indices) all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs if r.prefill_logprobs: prefill_head_indices.append(request_position_ids + cumulative_length) prefill_next_token_indices.append( prefill_out_cumulative_length + input_length - 1 ) prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) prefill_out_cumulative_length += input_length else: prefill_head_indices.append( torch.tensor( [cumulative_length + input_length - 1], dtype=torch.int64 ) ) prefill_next_token_indices.append(prefill_out_cumulative_length) prefill_cu_outlens.append(prefill_out_cumulative_length + 1) prefill_out_cumulative_length += 1 # Update cumulative_length += input_length cumulative_max_length += total_tokens max_seqlen = max(max_seqlen, input_length) max_blocks = max(max_blocks, needed_blocks) max_length = max(max_length, input_length + max_new_tokens) # 2.构建后处理类 next_token_chooser = MindIELLMHeterogeneousNextTokenChooser.from_pb( pb=next_token_chooser_parameters, dtype=dtype, device=device ) start_slots = torch.tensor(start_slots, dtype=torch.int64) # 3.构建传给后处理类的input id tensor all_input_ids_tensor = np.zeros( (len(all_input_ids), max_length), dtype=np.int64 ) for i, input_ids in enumerate(all_input_ids): all_input_ids_tensor[i, : len(input_ids)] = input_ids # Create tensors on device all_input_ids_tensor = torch.tensor( all_input_ids_tensor, dtype=torch.int64, device=device ) if len(pb.requests) > 1: input_ids = np.concatenate(all_input_ids, dtype=np.int64) position_ids = torch.cat(position_ids) slot_indices = torch.cat(slot_indices) else: 
input_ids = all_input_ids[0] position_ids = position_ids[0] slot_indices = slot_indices[0] cu_seqlen_prefill = torch.tensor( cu_seqlen_prefill, device=device, dtype=torch.int64 ) position_ids = position_ids.to(device) slot_indices = slot_indices.to(device) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) input_lengths_tensor = torch.tensor( input_lengths, dtype=torch.int64, device=device ) if all_prefill_logprobs: prefill_head_indices = None prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 elif no_prefill_logprobs: prefill_head_indices = cu_seqlen_prefill[1:] - 1 prefill_next_token_indices = None else: prefill_head_indices = torch.tensor( torch.cat(prefill_head_indices), dtype=torch.int64, device=device ) prefill_next_token_indices = torch.tensor( prefill_next_token_indices, dtype=torch.int64, device=device ) return cls( batch_id=pb.id, requests=pb.requests, requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, start_slots=start_slots, slot_indices=slot_indices, needed_blocks_slots=needed_blocks_slots, block_tables=None, block_tables_tensor=None, slots=None, max_seqlen=max_seqlen, prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, prefill_cu_outlens=prefill_cu_outlens, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, blocks=blocks, max_blocks=max_blocks, ) @classmethod @tracer.start_as_current_span("concatenate") def concatenate(cls, batches: List["MindFlashCausalLMBatch"]) -> "MindFlashCausalLMBatch": # Batch attributes requests = [] requests_idx_mapping = {} blocks = 0 total_batch_size = 0 total_slots = 0 max_blocks = 0 max_length = 0 max_seqlen = 0 for b in batches: total_batch_size += len(b) total_slots += len(b.slots) blocks += b.blocks max_blocks = max(max_blocks, b.max_blocks) max_seqlen = max(max_seqlen, b.max_seqlen) max_length = max( max_length, max( input_length + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens for input_length, stopping_criteria in zip( b.input_lengths, b.stopping_criterias ) ), ) input_ids = batches[0].input_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size) slots = batches[0].slots.new_empty(total_slots) slot_indices = batches[0].slot_indices.new_empty(total_batch_size) input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( total_batch_size ) block_tables_tensor = batches[0].block_tables_tensor.new_zeros( (total_batch_size, max_blocks) ) all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( (total_batch_size, max_length) ) start_slots = [] block_tables = [] all_input_ids = [] input_lengths = [] prefix_offsets = [] read_offsets = [] next_token_chooser_parameters = [] stopping_criterias = [] # Cumulative length cumulative_batch_size = 0 cumulative_slots = 0 for i, batch in enumerate(batches): requests.extend(batch.requests) if i == 0: requests_idx_mapping = batch.requests_idx_mapping else: # We need to offset the mapping for each batch by the cumulative batch size for k, v in batch.requests_idx_mapping.items(): requests_idx_mapping[k] = v + cumulative_batch_size start_index = cumulative_batch_size end_index = cumulative_batch_size + len(batch) slots_start_index = cumulative_slots slots_end_index = 
cumulative_slots + len(batch.slots) # Copy tensors (GPU) input_ids[start_index:end_index] = batch.input_ids position_ids[start_index:end_index] = batch.position_ids slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor slots[slots_start_index:slots_end_index] = batch.slots all_input_ids_tensor[ start_index:end_index, : batch.all_input_ids_tensor.shape[1] ] = batch.all_input_ids_tensor[:, :max_length] block_tables_tensor[ start_index:end_index, : batch.block_tables_tensor.shape[1] ] = batch.block_tables_tensor[:, :max_blocks] start_slots.append(batch.start_slots + cumulative_slots) block_tables.extend(batch.block_tables) all_input_ids.extend(batch.all_input_ids) input_lengths.extend(batch.input_lengths) prefix_offsets.extend(batch.prefix_offsets) read_offsets.extend(batch.read_offsets) next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) stopping_criterias.extend(batch.stopping_criterias) # Update cumulative_batch_size += len(batch) cumulative_slots += len(batch.slots) start_slots = torch.concat(start_slots) next_token_chooser = MindIELLMHeterogeneousNextTokenChooser.from_pb( pb=next_token_chooser_parameters, dtype=batches[0].next_token_chooser.dtype, device=batches[0].next_token_chooser.device, ) # Needed to avoid dropping blocks when the batches will go out of scope for b in batches: b.block_tables = None del b return cls( batch_id=batches[0].batch_id, requests=requests, requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, start_slots=start_slots, slot_indices=slot_indices, needed_blocks_slots=None, block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, max_seqlen=max_seqlen, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, blocks=blocks, max_blocks=max_blocks, ) @tracer.start_as_current_span("filter") def filter(self, request_ids: List[int]) -> "MindFlashCausalLMBatch": if len(request_ids) == 0: raise ValueError("Batch must have at least one request") # We assume that if len(requests) == len(self) then the requests are the same if len(request_ids) == len(self): return self device = self.input_ids.device # New values after filtering requests_idx_mapping = {} # Used to index into tensors indices = [] # slots to keep after filtering slot_filtering_indices = torch.zeros( self.slots.shape[0], dtype=torch.bool, device=device ) # Create on CPU to only move to GPU once instead of at every copy slot_indices = torch.empty(len(request_ids), dtype=torch.int64) max_seqlen = 0 requests = [] start_slots = [] block_tables = [] all_input_ids = [] input_lengths = [] prefix_offsets = [] read_offsets = [] stopping_criterias = [] blocks = 0 max_blocks = 0 # Cumulative length cumulative_max_length = 0 for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] indices.append(idx) requests_idx_mapping[request_id] = i requests.append(self.requests[idx]) # Get length request_input_length = self.input_lengths[idx] max_seqlen = max(max_seqlen, request_input_length) all_input_ids.append(self.all_input_ids[idx]) input_lengths.append(request_input_length) 
prefix_offsets.append(self.prefix_offsets[idx]) read_offsets.append(self.read_offsets[idx]) stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) remaining_tokens = ( stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) request_block_table = self.block_tables[idx] blocks += len(request_block_table) block_tables.append(request_block_table) start_slots.append(cumulative_max_length) # Copy to tensor (CPU) slot_indices[i] = cumulative_max_length + request_input_length - 1 # Set slice slot_filtering_indices[ self.start_slots[idx] : self.start_slots[idx] + request_input_length + remaining_tokens - 1 ] = True cumulative_max_length += request_input_length + remaining_tokens - 1 max_blocks = max(max_blocks, len(request_block_table)) block_indices_to_free = [] # Iterate on all requests for i, r in enumerate(self.requests): # Filter requests that are not part of the new batch if r.id not in requests_idx_mapping.keys(): block_indices_to_free.extend(self.block_tables[i]) # Free blocks get_cache_manager().free(block_indices_to_free) # Needed to avoid dropping blocks when the batches will go out of scope self.block_tables = None # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices] block_tables_tensor = self.block_tables_tensor[indices] input_lengths_tensor = self.input_lengths_tensor[indices] slots = self.slots[slot_filtering_indices] next_token_chooser = self.next_token_chooser.filter(indices) start_slots = torch.tensor(start_slots, dtype=torch.int64) # Move to GPU now that we have the whole tensor slot_indices = slot_indices.to(device) return MindFlashCausalLMBatch( batch_id=self.batch_id, requests=requests, requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, start_slots=start_slots, slot_indices=slot_indices, needed_blocks_slots=None, block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, max_seqlen=max_seqlen, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, input_lengths=input_lengths, input_lengths_tensor=input_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, blocks=blocks, max_blocks=max_blocks, ) def to_pb(self) -> generate_pb2.CachedBatch: return generate_pb2.CachedBatch( id=self.batch_id, request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.blocks * BLOCK_SIZE, ) class MindModel(FlashCausalLM): def __init__( self, model_id: str, trust_remote_code: bool ): logger.warning("Initialize mindie-llm model.") rank = int(os.getenv("RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "1")) model_config = { 'backend_type': BackendType.ATB, 'rank': rank, 'world_size': world_size, 'model_id': model_id, 'num_threads': 8, 'local_rank': rank, 'npu_device_id': rank, 'trust_remote_code': trust_remote_code } # 1. 
初始化mindie llm模型,获得的self.model_runner包含模型信息 self.model_runner = GeneratorTorch(model_config) super(MindModel, self).__init__( model=self.model_runner.model_wrapper.model_runner.model, tokenizer=self.model_runner.tokenizer, num_layers=self.model_runner.model_info.num_layers, num_kv_heads=self.model_runner.model_info.num_kv_heads, head_size=self.model_runner.model_info.head_size, dtype=self.model_runner.model_info.dtype, device=self.model_runner.model_info.device, rank=self.model_runner.rank, world_size=self.model_runner.world_size, ) logger.warning("MindModel from tgi_npu initialized.") def __del__(self): del self.model_runner.model_wrapper @property def batch_type(self) -> Type[MindFlashCausalLMBatch]: return MindFlashCausalLMBatch def forward( self, batch: MindFlashCausalLMBatch ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: input_ids = batch.input_ids position_ids = batch.position_ids cu_seqlen_prefill = batch.cu_seqlen_prefill kv_cache = get_cache_manager().kv_cache block_tables = batch.block_tables_tensor slots = batch.slots[batch.slot_indices] input_lengths = batch.input_lengths_tensor max_s = batch.max_seqlen lm_head_indices = batch.prefill_head_indices # 2. 为每个batch请求生成token # 改用MindModel初始化生成的self.model_runner的forward_tensor推理接口(链接)进行推理,输入为MindFlashCausalLMBatch类(链接)中相应字段 return self.model_runner.forward_tensor( input_ids=input_ids, position_ids=position_ids, is_prefill=cu_seqlen_prefill is not None, kv_cache=kv_cache, block_tables=block_tables, slots=slots, input_lengths=input_lengths, max_seq_len=max_s, lm_head_indices=lm_head_indices, ) @tracer.start_as_current_span("generate_token") def generate_token( self, batch: MindFlashCausalLMBatch ) -> Tuple[List[Generation], Optional[MindFlashCausalLMBatch]]: prefill = batch.cu_seqlen_prefill is not None prefill_logprobs = batch.prefill_next_token_indices is not None if batch.needed_blocks_slots: # Allocate blocks to this batch block_tables, block_tables_tensor, slots = get_cache_manager().allocate( batch.needed_blocks_slots, batch.blocks, batch.max_blocks, batch.input_ids.device, ) batch.needed_blocks_slots = None batch.block_tables = block_tables batch.block_tables_tensor = block_tables_tensor batch.slots = slots try: out = self.forward(batch) except Exception as e: del batch raise e if prefill: next_token_logits = ( out[batch.prefill_next_token_indices] if prefill_logprobs else out ) else: logger.debug(f"Decode batch size {batch.input_ids.shape[0]}") next_token_logits = out # 3. 
后处理部分,输入参数适配MindIELLMHeterogeneousNextTokenChooser类中的采样方法 request_ids = [req.id for req in batch.requests] next_input_ids, next_token_logprobs = batch.next_token_chooser( request_ids, prefill, batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits ) if prefill: if len(batch) > 1 and prefill_logprobs: # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs # When batch == 1, we will just use the batch.input_ids values directly prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) next_position_ids = batch.position_ids.new_empty(len(batch)) batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] # We do not need cu_seqlen_prefill anymore batch.cu_seqlen_prefill = None else: prefill_logprobs = None next_position_ids = batch.position_ids # Cumulative length cumulative_length = 0 # Results generations: List[Generation] = [] stopped = True # Zipped iterator iterator = zip( batch.input_lengths, batch.all_input_ids, ) # We do two for loops as the first one can run completely asynchronously from the GPU while for the second # one, we need to first do a GPU <-> CPU sync # It is faster if we delay this sync for the maximum amount of time # For each member of the batch for i, ( input_length, all_input_ids, ) in enumerate(iterator): # Indexing metadata start_index = cumulative_length end_index = cumulative_length + input_length if prefill: # Indexing metadata out_start_index = batch.prefill_cu_outlens[i] out_end_index = batch.prefill_cu_outlens[i + 1] out_length = out_end_index - out_start_index # Initialize position_ids # In decode, we do not need this as we can just increment position ids next_position_ids[i] = batch.position_ids[end_index - 1] # Used to gather prefill logprobs # Copy batch.input_ids to prefill_token_indices if prefill_logprobs: if len(batch) > 1: prefill_tokens_indices[ out_start_index: out_end_index - 1 ] = batch.input_ids[start_index + 1: start_index + out_length] else: # Set prefill_tokens_indices to the correct slice prefill_tokens_indices = batch.input_ids[ start_index + 1: start_index + out_length ] batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] cumulative_length += input_length # Set values in batch batch.input_ids = next_input_ids batch.position_ids = next_position_ids + 1 batch.input_lengths_tensor += 1 batch.slot_indices += 1 if prefill and prefill_logprobs: # Get prefill logprobs prefill_logprobs_tensor = torch.log_softmax(out, -1) prefill_logprobs = torch.gather( prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1) ) # GPU <-> CPU sync prefill_logprobs = prefill_logprobs.view(-1).tolist() # GPU <-> CPU sync next_token_logprobs = next_token_logprobs next_token_ids = batch.input_ids.tolist() # Zipped iterator iterator = zip( batch.requests, batch.input_lengths, batch.prefix_offsets, batch.read_offsets, batch.stopping_criterias, batch.all_input_ids, batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, next_token_ids, next_token_logprobs, ) # For each member of the batch for i, ( request, input_length, prefix_offset, read_offset, stopping_criteria, all_input_ids, do_sample, seed, next_token_id, next_token_logprob, ) in enumerate(iterator): # Append next token to all tokens all_input_ids.append(next_token_id) # Generated token next_token_text, prefix_offset, read_offset = self.decode_token( all_input_ids, prefix_offset, read_offset, ) # Evaluate stopping criteria stop, reason = stopping_criteria( next_token_id, next_token_text, ) if not stop: stopped = False # 
Shard generations # All generations will be appended in the rust sharded client if i % self.world_size == self.rank: if stop: # Decode generated tokens # decode & decode_token output_text = self.decode( all_input_ids[-stopping_criteria.current_tokens:] ) generated_text = GeneratedText( output_text, stopping_criteria.current_tokens, reason, seed if do_sample else None, ) else: generated_text = None # Prefill if prefill and request.prefill_logprobs: out_start_index = batch.prefill_cu_outlens[i] out_end_index = batch.prefill_cu_outlens[i + 1] # Remove generated token to only have prefill and add nan for first prompt token request_prefill_logprobs = [float("nan")] + prefill_logprobs[ out_start_index: out_end_index - 1 ] prefill_token_ids = all_input_ids[:-1] prefill_texts = self.tokenizer.batch_decode( prefill_token_ids, clean_up_tokenization_spaces=False, skip_special_tokens=False, ) # PrefillTokens & Tokens prefill_tokens = PrefillTokens( prefill_token_ids, request_prefill_logprobs, prefill_texts ) else: prefill_tokens = None generation = Generation( request.id, prefill_tokens, next_token_id, next_token_logprob, next_token_text, next_token_id in self.all_special_ids, generated_text, ) generations.append(generation) # Update values batch.input_lengths[i] = input_length + 1 batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids if stopped: del batch # No need to return a batch if we know that all requests stopped return_value = None return generations, return_value batch.prefill_cu_outlens = None batch.prefill_head_indices = None batch.prefill_next_token_indices = None batch.max_seqlen = batch.max_seqlen + 1 return generations, batch # 需要适配的部分是获取Warmup Batch计算时峰值NPU显存占用、NPU卡最大显存、以及计算出装满剩余显存的最大Token数量,峰值时Token数量的计算。 def warmup(self, batch: MindFlashCausalLMBatch): # The warmup batch is the biggest batch we could ever receive empty_cache() peak_memory = torch_npu.npu.max_memory_allocated() logger.warning(f">>>>before warmup peak_memory {peak_memory}") try: cache_manager = set_cache_manager( batch.blocks, self.num_layers, self.num_kv_heads, self.head_size, False, self.dtype, self.device, ) _, batch = self.generate_token(batch) except torch.cuda.OutOfMemoryError as e: raise RuntimeError( f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. 
" f"You need to decrease `--max-batch-prefill-tokens`" ) from e # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) # Calculate the number of blocks that can be allocated with the free memory dtype_size = torch.tensor([], dtype=self.dtype).element_size() cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size torch_npu.npu.synchronize() # NPU卡总显存 total_gpu_memory = torch_npu.npu.get_device_properties(self.device).total_memory # 峰值显存 peak_memory = torch_npu.npu.max_memory_allocated() logger.info( f">>>>dtype_size {dtype_size}, cache_block_size {cache_block_size}, num_kv_heads {self.num_kv_heads}, " f"total_cache_size {total_cache_size}, peak_memory {peak_memory}") # 剩余显存 total_free_memory = total_gpu_memory - peak_memory logger.info(f">>>>total_free_memory {total_free_memory}, total_gpu_memory {total_gpu_memory}, " f"MEMORY_FRACTION {MEMORY_FRACTION}") # 剩余可用显存 free_memory = max( 0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory ) # 总共可支持的KV Block num_blocks = ( # Leave 5% for some wiggle room int((free_memory * 0.95) // total_cache_size) # Add batch.blocks as we allocated it above, so it is included in the peak memory. + cache_manager.num_blocks ) del batch del cache_manager set_cache_manager( num_blocks, self.num_layers, self.num_kv_heads, self.head_size, # self.sliding_window is not None, False, self.dtype, self.device, ) logger.warning(f">>>>real CacheManger {get_cache_manager()}") peak_memory = torch_npu.npu.max_memory_allocated() logger.warning(f">>>>end warmup peak_memory {peak_memory}") logger.warning(f"Warmup return {int(num_blocks * BLOCK_SIZE)}") #总共可支持的KV slot(total batch tokens),BLOCK_SIZE建议128 return int(num_blocks * BLOCK_SIZE)
- Tgi-MindIE/tgi_npu/tokens_mindie.py: Adds the post-processing class MindIELLMHeterogeneousNextTokenChooser. Based on the sampling parameters of each request in a batch, one post-processing object is constructed per batch; sampling selects the next token(s) from the logits produced by inference.
# Part of codes in this file was copied from project[huggingface][text-generation-inference] #!/usr/bin/env python3 # coding=utf-8 import os from typing import List import torch from loguru import logger from text_generation_server.pb import generate_pb2 from mindie_llm.text_generator.utils.sampling_metadata import SamplingData, SamplingParam from mindie_llm.modeling.backend_type import BackendType from mindie_llm.text_generator.utils.config import SamplerConfig from mindie_llm.text_generator.samplers.sampler import Sampler import numpy as np WRAPPER_KEY = "tensor_wrapper" try: from mindie_llm.text_generator.utils.sampling_metadata import TensorWrapper except ImportError: class TensorWrapper: def __init__(self, backend, device): self.device = device self.backend = backend def __call__(self, data): if data.dtype == np.int32: dtype = torch.int32 elif data.dtype == np.bool_: dtype = torch.bool else: dtype = None return torch.tensor(data, dtype=dtype, device=self.device) WRAPPER_KEY = "to_tensor" def do_filter(sample_param: List, indices): if any(sample_param): return [sample_param[i] for i in indices] return sample_param class MindIELLMHeterogeneousNextTokenChooser: def __init__(self, dtype: torch.dtype, device: torch.device, watermark: List[bool], temperature: List[float], repetition_penalty: List[float], frequency_penalty: List[float], top_k: List[int], top_p: List[float], typical_p: List[float], do_sample: List[bool], seeds: List[int], grammars: List[str], grammar_types: List[int], fsm_grammar_states: List[int], sample_method, # mindie-llm sampling method ): if any([x != 0.0 for x in frequency_penalty]): logger.warning(f"Frequency_penalty is supported now in mindie-llm but not supported in TGI v0.9.4") if any(watermark): logger.warning(f"Watermark not supported now in mindie-llm") if any([x < 1.0 for x in typical_p]): logger.warning(f"Typical_p not supported now in mindie-llm") if any(grammar_types) or any(grammars) or any(fsm_grammar_states): logger.warning(f"Grammar not supported now in mindie-llm") # 1. 初始化函数,从grpc接收的后采样参数构造mindie sampler self.tensor_wrapper = TensorWrapper(BackendType.ATB, device) wrapper_dict = {WRAPPER_KEY: TensorWrapper(BackendType.ATB, device)} self.sample_params = SamplingParam.from_numpy( repetition_penalty=np.array(repetition_penalty, dtype=np.float16), presence_penalty=None, temperature=np.array(temperature, dtype=np.float16), top_k=np.array(top_k), top_p=np.array(top_p), seed=np.array(seeds).astype(np.int32), do_sample=np.array(do_sample), **wrapper_dict ) # Temp store for filter self.temperature = temperature self.repetition_penalty = repetition_penalty self.frequency_penalty = frequency_penalty self.top_k = top_k self.top_p = top_p self.seeds = seeds self.do_sample = do_sample self.choice = sample_method self.dtype = dtype self.device = device self.seeds = seeds self.do_sample = self.sample_params.do_sample_meta.do_sample_array self.fsm_grammar_states = fsm_grammar_states # 2. 
根据输入token id(input_ids)及推理的logits(scores),采样出下一个tokenid(next_ids) def __call__(self, request_ids: List, is_prefill: bool, input_ids: torch.Tensor, scores: torch.Tensor, ): batch_size = scores.shape[0] speculate_size = 1 scores = scores.view(batch_size, speculate_size, -1) # Don't use SamplingData.from_numpy to avoid tensor.cpu() to transfer large data input_ids_int32 = input_ids.to(torch.int32) sample_data = SamplingData(all_input_ids=input_ids_int32, output_ids=input_ids_int32, is_prefill=is_prefill, request_ids=np.array(request_ids)) next_ids = torch.zeros((batch_size, speculate_size), device=scores.device, dtype=torch.long) for j in range(speculate_size): _scores = scores[:, j] batch_logits, _next_ids = self.choice(batch_logits=_scores, batch_sampling_data=sample_data, batch_sampling_params=self.sample_params) scores[:, j] = _scores next_ids[:, j] = torch.from_numpy(_next_ids) next_ids = next_ids.view(batch_size * speculate_size) allscores = scores.view(batch_size * speculate_size, -1) alllogprobs = torch.log_softmax(allscores, -1) logprobs = alllogprobs next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1) return next_ids, next_logprobs @classmethod def from_pb( cls, pb: List[generate_pb2.NextTokenChooserParameters], dtype: torch.dtype, device: torch.device, ) -> "MindIELLMHeterogeneousNextTokenChooser": curr_rank = int(os.getenv("RANK", "0")) sample_method = Sampler(SamplerConfig( rank=curr_rank, backend_type=BackendType.ATB, npu_id=curr_rank )) return MindIELLMHeterogeneousNextTokenChooser( watermark=[pb_.watermark for pb_ in pb], temperature=[pb_.temperature for pb_ in pb], repetition_penalty=[pb_.repetition_penalty for pb_ in pb], frequency_penalty=[], top_k=[pb_.top_k for pb_ in pb], top_p=[pb_.top_p for pb_ in pb], typical_p=[pb_.typical_p for pb_ in pb], do_sample=[pb_.do_sample for pb_ in pb], seeds=[pb_.seed for pb_ in pb], device=device, dtype=dtype, grammars=[], grammar_types=[], fsm_grammar_states=[], sample_method=sample_method ) def filter(self, indices): self.repetition_penalty = do_filter(self.repetition_penalty, indices) self.frequency_penalty = do_filter(self.frequency_penalty, indices) self.temperature = do_filter(self.temperature, indices) self.top_k = do_filter(self.top_k, indices) self.top_p = do_filter(self.top_p, indices) self.seeds = do_filter(self.seeds, indices) self.do_sample = do_filter(self.do_sample, indices) self.sample_params = SamplingParam.from_numpy( repetition_penalty=np.array(self.repetition_penalty, dtype=np.float16), presence_penalty=None, frequency_penalty=np.array(self.frequency_penalty, dtype=np.float16), temperature=np.array(self.temperature, dtype=np.float16), top_k=np.array(self.top_k), top_p=np.array(self.top_p), seed=np.array(self.seeds).astype(np.int32), do_sample=np.array(self.do_sample), **self.wrapper_dict ) return self
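For reference, a minimal sketch of how `__call__` above turns raw logits into per-token log-probabilities once the sampler has picked the next ids; shapes and vocabulary size are hypothetical:

```python
import torch

batch_size, vocab_size = 4, 32000                     # hypothetical shapes
logits = torch.randn(batch_size, vocab_size)          # stand-in for the scores from forward_tensor
next_ids = torch.tensor([11, 42, 7, 99])              # ids chosen by the MindIE LLM sampler

logprobs = torch.log_softmax(logits, dim=-1)
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)  # one logprob per request
```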
- Tgi-MindIE/pyproject.toml
[tool.poetry] name = "tgi-npu" version = "0.1.0" description = "NPU MindIE Adapter for TGI v0.9.4" authors = ["Your Name <you@example.com>"] exclude = ["router", "cover"] [tool.poetry.dependencies] python = "^3.10" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api"