forward接口

接口功能

第三方服务调用此接口执行模型推理。

接口实现

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
def forward(self, model_inputs: ModelInput, **kwargs):
        """
        Preprocess the inputs involving multi-lora, and pass the processed inputs to the model
        wrapper for forward inference.
        """
        # sort input by adapter_ids
        do_reorder_requests = False
        revert_adapter_idx = []
        adapter_ids = model_inputs.adapter_ids
        if adapter_ids is not None and self.adapter_manager is not None:
            do_reorder_requests, revert_adapter_idx = self._sort_model_inputs_by_adapter_ids(model_inputs)
        logits = self.model_wrapper.forward(model_inputs, self.cache_pool.npu_cache, **kwargs)
        # sort logits back to the original order (related to lm_head_indices)
        if do_reorder_requests:
            logits = reorder_tensor(logits, revert_adapter_idx)
        return logits

参数说明

参数名称

是否必选

类型

默认值

描述

安全声明

model_inputs

必选

ModelInput

-

模型的输入参数。

推理强依赖数据的合法性,需由用户保证。

**kwargs中的q_lens

可选

List[int]

[]

在并行解码场景下,单次decode增量输入的token长度。

**kwargs中的spec_mask

可选

torch.Tensor

None

并行解码场景下生成的mask。

**kwargs中的atten_mask

可选

torch.Tensor

None

attention计算时需要的mask。

**kwargs中的max_out_len

可选

int

256

最大输出序列长度。

补充说明

forward接口中的model_inputs参数通过ModelInput类进行构造。

文件路径:“mindie_llm/text_generator/utils/input.py”

ModelInput类的初始化参数说明如下:

参数名称

是否必选

类型

默认值

描述

安全声明

input_ids

必选

np.ndarray

-

输入经过tokenizer后,每个token在词表中的索引。

推理强依赖数据的合法性,需由用户保证。

position_ids

必选

np.ndarray或None

-

Token的位置索引。

is_prefill

必选

bool

-

是否为推理首token阶段。

block_tables

必选

np.ndarray

-

存储每个token和它所使用的KV Cache块之间的映射。

slots

必选

np.ndarray

-

存储每个KV Cache块中实际使用的slots的序号。

context_length

必选

np.ndarray或List[int]

-

batch中每个query的长度。(输入+当前输出的长度。)

max_seq_len

必选

int

-

模型支持的最大上下文长度。(即最大可支持的输入+输出的长度。)

prefill_head_indices

必选

np.ndarray或None

-

设置此值可根据索引选择性地输出logits。

query_length

可选

Optional[np.ndarray]

None

请求的长度。

adapter_ids

可选

List[str]

-

lora权重名称的列表。