昇腾社区首页
中文
注册

forward接口

接口功能

第三方服务调用调用此接口执行模型推理。

接口实现

def forward(self, model_inputs, **kwargs):
        logits = self.model_wrapper.forward(model_inputs, self.cache_pool.npu_cache, **kwargs)
        return logits

参数说明

参数名称

是否必选

类型

默认值

描述

安全声明

model_inputs

必选

ModelInput

-

模型的输入参数。

推理强依赖数据的合法性,需由用户保证。

**kwargs中的q_lens

可选

List[int]

-

在并行解码场景下,单次decode增量输入的token长度。

**kwargs中的spec_mask

可选

torch.Tensor

-

并行解码场景下生成的mask。