sample接口

接口功能

此接口对推理输出的logitis进行后处理。返回的第一个参数“next_tokens”表示被选中的token id。第二个参数“logits_or_logprobs”表示该批次被选中token的logits值或概率对数logprobs。其中,logits是在贪心搜索场景下的返回值,而logprobs是在随机采样场景下的返回值。若输入批次的所有请求均为贪心搜索,则该数值为None。

接口实现

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
def sample(
        self,
        logits: Any,
        sampling_data: SamplingData,
        sampling_param: SamplingParam
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Call the sampler of mindie-llm.
        This method samples from the input logits based on the post-processing parameters and selects the token ids.
        Args:
            logits: A Tensor of corresponding backend. It is the output obtained from the model's forward propagation.
            sampling_data: Some sampling metadata like input and out token ids.
            sampling_param: Some sampling parameters like penalty, temperature, top_k and top_p, etc.
        Returns:
            next_tokens: A numpy array of the next tokens' ids.
            logits_or_logprobs: A numpy array of the logits (for greedy search) or the log-probability (for multinomial
                sampling) of the chosen tokens. Note that it would be a `None` if all sequences in the batch do greedy
                search.
        """
        logits_or_logprobs, next_tokens = self.sampler(logits, sampling_data, sampling_param)
        return next_tokens, logits_or_logprobs

参数说明

参数名称

是否必选

类型

默认值

描述

logits

Any

-

对应后端类型的Tensor,当前支持torch.Tensor或mindspore.Tensor。未经过softmax函数处理的网络输出结果,通常表示每个类别的得分或概率。

sampling_data

SamplingData

None

SamplingData类实例,用于传入batch后处理元数据。通过SamplingData类方法。SamplingData.from_numpy可以创建类实例sampling_data。

sampling_param

SamplingParam

None

SamplingParam类实例,用于传入batch后处理参数。通过SamplingParam类方法。SamplingParam.from_numpy可以创建类实例sampling_param。

补充说明

sampling_data通过SamplingData.from_numpy进行构造。

文件路径:“mindie_llm/text_generator/utils/sampling_metadata.py”

表1 SamplingData.from_numpy接口参数说明

参数名称

是否必选

类型

默认值

描述

all_input_ids

np.ndarray

None

二维int数组,包含每个请求所有输入加输出的token id。用于repetition_penalty计算。

output_ids

np.ndarray

None

二维int数组,包含每个请求所有输出的token id。用于frequency_penalty和presence_penalty计算。

to_tensor

callable

None

to_tensor方法,用于将np.ndarray生成不同后端的张量。若传入的all_input_ids或output_ids不为None,则本项必选。

is_prefill

bool

True

用于判断是否在生成第一个token。本项与request_ids成对传入,可触发缓存机制,提升性能。

request_ids

np.ndarray

None

一维int数组,批处理中所有请求的唯一标识。本项与is_prefill成对传入,可触发缓存机制,提升性能。

sampling_param通过SamplingParam.from_numpy进行构造。

文件路径:“mindie_llm/text_generator/utils/sampling_metadata.py”

表2 SamplingParam.from_numpy接口参数说明

参数名称

是否必选

类型

默认值

描述

repetition_penalty

np.ndarray

None

一维float数组,对应批处理中每个请求的重复惩罚。

frequency_penalty

np.ndarray

None

一维float数组,对应批处理中每个请求的频率惩罚。

presence_penalty

np.ndarray

None

一维float数组,对应批处理中每个请求的存在惩罚。

temperature

np.ndarray

None

一维float数组,对应批处理中每个请求的温度。

top_k

np.ndarray

None

一维int数组,对应批处理中每个请求在采样时依次从最高概率选项中选择的数量。

top_p

np.ndarray

None

一维float数组,对应批处理中每个请求在采样时的累加概率阈值。

seed

np.ndarray

None

一维int数组,对应批处理中每个请求在采样时的随机种子。

do_sample

np.ndarray

None

一维bool数组,对应批处理中每个请求是否做采样。

to_tensor

callable

None

to_tensor方法,用于将np.ndarray生成不同后端的张量。若传入的任意一个参数不为None,则本项必选。

后处理参数数组元素数值说明:

针对单个request而言,目前支持惩罚参数和采样参数两类后处理参数,如表3表4所示。

表3 惩罚参数

参数名称

类型

取值要求

描述

repetition_penalty

float

> 0,建议不超过2

重复惩罚的参数,对输入输出中已存在的token施加除法级惩罚,1.0表示没有惩罚。

frequency_penalty

float

负数表示奖励,建议 > 0

频率惩罚,根据输出中已存在的token的出现频率施加减法级惩罚,0.0表示没有惩罚。

presence_penalty

float

负数表示奖励,建议 > 0

存在惩罚,对输出中已存在的token施加减法级惩罚,0.0表示没有惩罚。

若模型路径下的config.json和generate_config.json中都没有配置pad_token_id时,需要手动添加pad_token_id。

可配置范围:[-1, vocab_size],建议值:vocab_size。

表4 采样参数

参数名称

类型

取值要求

描述

temperature

float

> 0,建议不超过2

控制生成文本的随机性,值越高文本越随机。

top_k

int

0 < top_k < 词表长度

限制每次生成时候概率最高的k个数量。

top_p

float

0.0 < top_p <= 1.0

选择概率总和达到p的所有选项,用于控制生成的多样性。

seed

int

>= 0

设置随机数种子,以确保结果的可重复性。

do_sample

bool

布尔值

决定是否使用抽样策略生成文本,而非选择概率最高的选项。

设置为“False”时,若存在采样参数,则自动变为“True”。