forward_tensor接口
接口功能
模型推理接口,采用Page Attention算子实现高性能推理,依赖上层调度模块实现对模型KV Cache管理能力,支持Continuous Batching。
接口实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33  | def forward_tensor( self, input_ids: torch.Tensor, position_ids: torch.Tensor, is_prefill: bool, kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, input_lengths: torch.Tensor, max_seq_len: int, lm_head_indices: Optional[torch.Tensor] = None, **kwargs ): """Call the `forward_tensor` of `model_wrapper`.""" adapter_ids = kwargs.get("adapter_ids") batch_size = input_lengths.shape[0] if adapter_ids is not None and len(adapter_ids) > batch_size: message = "The length of `adapter_ids` should not be larger than batch size." logger.error(message, ErrorCode.TEXT_GENERATOR_INTERNAL_ERROR) raise ValueError(message) logits = self.model_wrapper.forward_tensor( input_ids=input_ids, position_ids=position_ids, is_prefill=is_prefill, kv_cache=kv_cache, block_tables=block_tables, slots=slots, input_lengths=input_lengths, max_seq_len=max_seq_len, lm_head_indices=lm_head_indices, **kwargs, ) return logits  | 
参数说明
参数名称  | 
是否必选  | 
类型  | 
默认值  | 
描述  | 
安全声明  | 
|---|---|---|---|---|---|
input_ids  | 
必选  | 
torch.Tensor  | 
-  | 
输入经过tokenizer后,每个token在词表中的索引。  | 
推理强依赖数据的合法性,需由用户保证。  | 
position_ids  | 
必选  | 
torch.Tensor  | 
-  | 
Token的位置索引。  | 
|
is_prefill  | 
必选  | 
bool  | 
-  | 
是否为推理首token阶段。  | 
|
kv_cache  | 
必选  | 
List[Tuple[torch.Tensor, torch.Tensor]]  | 
-  | 
KV缓存。  | 
|
block_tables  | 
必选  | 
torch.Tensor  | 
-  | 
存储每个token和它所使用的KV Cache块之间的映射。  | 
|
slots  | 
必选  | 
torch.Tensor  | 
-  | 
存储每个KV Cache块中实际使用的slots的序号。  | 
|
input_lengths  | 
必选  | 
torch.Tensor  | 
-  | 
batch中每个query的长度。(输入+当前输出的长度。)  | 
|
max_seq_len  | 
必选  | 
int  | 
-  | 
模型支持的最大上下文长度。(即最大可支持的输入+输出的长度。)  | 
|
lm_head_indices  | 
可选  | 
Optional[torch.Tensor]  | 
None  | 
设置此值可根据索引选择性地输出logits。  | 
|
**kwargs中的q_lens  | 
可选  | 
List[int]  | 
[]  | 
在并行解码场景下,单次decode增量输入的token长度。  | 
|
**kwargs中的spec_mask  | 
可选  | 
torch.Tensor  | 
None  | 
并行解码场景下生成的mask。  | 
|
**kwargs中的atten_mask  | 
可选  | 
torch.Tensor  | 
None  | 
attention计算时需要的mask。  | 
|
**kwargs中的adapter_ids  | 
可选  | 
None | List[str | None]  | 
-  | 
lora权重名称的列表。  | 
|
**kwargs中的max_out_len  | 
可选  | 
int  | 
256  | 
最大输出序列长度。  |