昇腾社区首页
中文
注册

类功能

功能描述

辅助用户生成模型评估数据集的类,给用户提供原始文档数据集,基于文本生成评估数据集。用户需对生成的评估集进行人工筛选,挑选出符合该领域特征的问答对,才能较好的评估模型在该领域的精度。

函数原型

from mx_rag.tools.finetune.generator.eval_data_generator import EvalDataGenerator
EvalDataGenerator(llm: Text2TextLLM, dataset_path: str, encrypt_fn, decrypt_fn)

输入参数说明

参数名

数据类型

可选/必选

说明

llm

Text2TextLLM

必选

用于生成评估数据集的大模型,详情请参考Text2TextLLM类

dataset_path

str

必选

评估数据集文件存储目录, 路径长度取值范围为[1,1024]。路径不能包含软链接且不允许存在".."。

encrypt_fn

Callable[[str], str]

可选

对生成的Q-D对进行加密存储,默认为None,即不加密处理

须知:

如果上传的文档涉及银行卡号、身份证号、护照号、口令等个人数据,请配置该参数保证个人数据安全。

decrypt_fn

Callable[[str], str]

可选

对已存储的Q-D对进行解密处理,默认为None。

调用示例

from paddle.base import libpaddle
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from mx_rag.document import LoaderMng
from mx_rag.document.loader import DocxLoader
from mx_rag.llm import Text2TextLLM
from mx_rag.tools.finetune.generator.eval_data_generator import EvalDataGenerator
from mx_rag.utils import ClientParam

llm = Text2TextLLM(model_name="Llama3-8B-Chinese-Chat", base_url="https://{ip}:{port}/v1/chat/completions", 
client_param=ClientParam(ca_file="/path/to/ca.crt")
)

dataset_path = "path to data_output"  # 微调合成数据集的输出地址

document_path = "path to document dir"  # 用户提供的原始文档集所在地址

eval_data_generator = EvalDataGenerator(llm, dataset_path)

loader_mng = LoaderMng()

loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md"])
loader_mng.register_loader(loader_class=DocxLoader, file_types=[".docx"])

# 加载文档切分器,使用langchain的
loader_mng.register_splitter(splitter_class=RecursiveCharacterTextSplitter,
                             file_types=[".docx", ".txt", ".md"],
                             splitter_params={"chunk_size": 750,
                                              "chunk_overlap": 150,
                                              "keep_separator": False
                                              }
                             )

split_doc_list = eval_data_generator.generate_origin_document(document_path=document_path, loader_mng=loader_mng)

eval_data_generator.generate_evaluate_data(split_doc_list)