forward_tensor接口

接口功能

模型推理接口,采用Page Attention算子实现高性能推理,依赖上层调度模块实现对模型KV Cache管理能力,支持Continuous Batching。

接口实现

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def forward_tensor(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        is_prefill: bool,
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_seq_len: int,
        lm_head_indices: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Call the `forward_tensor` of `model_wrapper`."""
        adapter_ids = kwargs.get("adapter_ids")
        batch_size = input_lengths.shape[0]
        if adapter_ids is not None and len(adapter_ids) > batch_size:
            message = "The length of `adapter_ids` should not be larger than batch size."
            logger.error(message, ErrorCode.TEXT_GENERATOR_INTERNAL_ERROR)
            raise ValueError(message)
        logits = self.model_wrapper.forward_tensor(
            input_ids=input_ids,
            position_ids=position_ids,
            is_prefill=is_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_seq_len=max_seq_len,
            lm_head_indices=lm_head_indices,
            **kwargs,
        )
        return logits

参数说明

参数名称

是否必选

类型

默认值

描述

安全声明

input_ids

必选

torch.Tensor

-

输入经过tokenizer后,每个token在词表中的索引。

推理强依赖数据的合法性,需由用户保证。

position_ids

必选

torch.Tensor

-

Token的位置索引。

is_prefill

必选

bool

-

是否为推理首token阶段。

kv_cache

必选

List[Tuple[torch.Tensor, torch.Tensor]]

-

KV缓存。

block_tables

必选

torch.Tensor

-

存储每个token和它所使用的KV Cache块之间的映射。

slots

必选

torch.Tensor

-

存储每个KV Cache块中实际使用的slots的序号。

input_lengths

必选

torch.Tensor

-

batch中每个query的长度。(输入+当前输出的长度。)

max_seq_len

必选

int

-

模型支持的最大上下文长度。(即最大可支持的输入+输出的长度。)

lm_head_indices

可选

Optional[torch.Tensor]

None

设置此值可根据索引选择性地输出logits。

**kwargs中的q_lens

可选

List[int]

[]

在并行解码场景下,单次decode增量输入的token长度。

**kwargs中的spec_mask

可选

torch.Tensor

None

并行解码场景下生成的mask。

**kwargs中的atten_mask

可选

torch.Tensor

None

attention计算时需要的mask。

**kwargs中的adapter_ids

可选

None | List[str | None]

-

lora权重名称的列表。

**kwargs中的max_out_len

可选

int

256

最大输出序列长度。