forward_tensor接口
接口功能
模型推理接口,采用PageAttention算子实现高性能推理,依赖上层调度模块实现对模型KVCache管理能力,支持continuous batching。
接口实现
def forward_tensor(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
is_prefill: bool,
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_seq_len: int,
lm_head_indices: Optional[torch.Tensor] = None,
**kwargs):
logits = self.model_wrapper.forward_tensor(
input_ids=input_ids,
position_ids=position_ids,
is_prefill=is_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_seq_len=max_seq_len,
lm_head_indices=lm_head_indices,
**kwargs,
)
return logits
参数说明
参数名称 |
是否必选 |
类型 |
默认值 |
描述 |
安全声明 |
|---|---|---|---|---|---|
input_ids |
必选 |
torch.Tensor |
- |
输入经过tokenizer后,每个token在词表中的索引。 |
推理强依赖数据的合法性,需由用户保证。 |
position_ids |
必选 |
torch.Tensor |
- |
Token的位置索引。 |
|
is_prefill |
必选 |
bool |
- |
是否为推理首token阶段。 |
|
kv_cache |
必选 |
List[Tuple[torch.Tensor, torch.Tensor]] |
- |
KV 缓存。 |
|
block_tables |
必选 |
torch.Tensor |
- |
存储每个token和它所使用的kv cache块之间的映射。 |
|
slots |
必选 |
torch.Tensor |
- |
存储每个kv cache块中实际使用的slots的序号。 |
|
input_lengths |
必选 |
torch.Tensor |
- |
batch中每个query的长度。(输入+当前输出的长度) |
|
max_seq_len |
必选 |
int |
- |
模型支持的最大上下文长度。(即最大可支持的输入+输出的长度) |
|
lm_head_indices |
可选 |
Optional[torch.Tensor] |
None |
设置此值可根据索引选择性地输出logits。 |
|
**kwargs中的q_lens |
可选 |
List[int] |
- |
在并行解码场景下,单次decode增量输入的token长度。 |
|
**kwargs中的spec_mask |
可选 |
torch.Tensor |
- |
并行解码场景下生成的mask。 |