sample接口
接口功能
此接口对推理输出的logitis进行后处理。
接口实现
def sample(self, logits, sampling_data, sampling_param): _, next_tokens = self.sampler(logits, sampling_data, sampling_param) return next_tokens
参数说明
参数名称 |
是否必选 |
类型 |
默认值 |
描述 |
---|---|---|---|---|
logits |
是 |
Any |
- |
对应后端类型的Tensor,当前支持torch.Tensor或mindspore.Tensor。未经过softmax函数处理的网络输出结果,通常表示每个类别的得分或概率。 |
sampling_data |
否 |
SamplingData |
None |
SamplingData类实例,用于传入batch后处理元数据。通过SamplingData类方法SamplingData.from_numpy可以创建类实例sampling_data。 |
sampling_param |
否 |
SamplingParam |
None |
SamplingParam类实例,用于传入batch后处理参数。通过SamplingParam类方法SamplingParam.from_numpy可以创建类实例sampling_param。 |
补充说明
sampling_data通过SamplingData.from_numpy进行构造。
文件路径:mindie_llm/text_generator/utils/sampling_metadata.py。
参数名称 |
是否必选 |
类型 |
默认值 |
描述 |
---|---|---|---|---|
all_input_ids |
否 |
np.ndarray |
None |
二维int数组,包含每个请求所有输入加输出的token id。用于repetition_penalty计算。 |
output_ids |
否 |
np.ndarray |
None |
二维int数组,包含每个请求所有输出的token id。用于frequency_penalty和presence_penalty计算。 |
to_tensor |
否 |
callable |
None |
to_tensor方法,用于将np.ndarray生成不同后端的张量。若传入的all_input_ids或output_ids不为None,则本项必选。 |
is_prefill |
否 |
bool |
True |
用于判断是否在生成第一个token。本项与request_ids成对传入,可触发缓存机制,提升性能。 |
request_ids |
否 |
np.ndarray |
None |
一维int数组,批处理中所有请求的唯一标识。本项与is_prefill成对传入,可触发缓存机制,提升性能。 |
sampling_param通过SamplingParam.from_numpy进行构造。
文件路径:mindie_llm/text_generator/utils/sampling_metadata.py。
参数名称 |
是否必选 |
类型 |
默认值 |
描述 |
---|---|---|---|---|
repetition_penalty |
否 |
np.ndarray |
None |
一维float数组,对应批处理中每个请求的重复惩罚。 |
frequency_penalty |
否 |
np.ndarray |
None |
一维float数组,对应批处理中每个请求的频率惩罚。 |
presence_penalty |
否 |
np.ndarray |
None |
一维float数组,对应批处理中每个请求的存在惩罚。 |
temperature |
否 |
np.ndarray |
None |
一维float数组,对应批处理中每个请求的温度。 |
top_k |
否 |
np.ndarray |
None |
一维int数组,对应批处理中每个请求在采样时依次从最高概率选项中选择的数量。 |
top_p |
否 |
np.ndarray |
None |
一维float数组,对应批处理中每个请求在采样时的累加概率阈值。 |
seed |
否 |
np.ndarray |
None |
一维int数组,对应批处理中每个请求在采样时的随机种子。 |
do_sample |
否 |
np.ndarray |
None |
一维bool数组,对应批处理中每个请求是否做采样。 |
to_tensor |
否 |
callable |
None |
to_tensor方法,用于将np.ndarray生成不同后端的张量。若传入的任意一个参数不为None,则本项必选。 |
后处理参数数组元素数值说明:
针对单个request而言,目前支持惩罚参数和采样参数两类后处理参数。
参数名称 |
类型 |
取值要求 |
描述 |
---|---|---|---|
repetition_penalty |
float |
> 0,建议不超过2 |
重复惩罚的参数,对输入输出中已存在的token施加除法级惩罚,1.0表示没有惩罚。 |
frequency_penalty |
float |
负数表示奖励,建议 > 0 |
频率惩罚,根据输出中已存在的token的出现频率施加减法级惩罚,0.0表示没有惩罚。 |
presence_penalty |
float |
负数表示奖励,建议 > 0 |
存在惩罚,对输出中已存在的token施加减法级惩罚,0.0表示没有惩罚。 |
参数名称 |
类型 |
取值要求 |
描述 |
---|---|---|---|
temperature |
float |
> 0,建议不超过2 |
控制生成文本的随机性,值越高文本越随机。 |
top_k |
int |
0 < top_k < 词表长度 |
限制每次生成时候只考虑概率最高的k个选项。 |
top_p |
float |
0.0 < top_p <= 1.0 |
选择概率总和达到p的所有选项,用于控制生成的多样性。 |
seed |
int |
>= 0 |
设置随机数种子,以确保结果的可重复性。 |
do_sample |
bool |
布尔值 |
输入False时若存在采样参数,则自动变为True。决定是否使用抽样策略生成文本,而非选择概率最高的选项。 |