类功能
功能描述
进行微调合成数据自动生成的类,给用户提供原始文档集切分处理,基于切分后的文本生成及筛选微调合成数据。
函数原型
- 微调合成数据配置类:
from mx_rag.tools.finetune.generator import DataProcessConfig @dataclass class DataProcessConfig(): generate_qd_prompt: str = GENERATE_QD_PROMPT llm_preferred_prompt: str = SCORING_QD_PROMPT question_number: int = 3 featured: bool = True featured_percentage: float = 0.8 preferred: bool = True llm_threshold_score: float = 0.8 rewrite: bool = True query_rewrite_number: int = 2
- 微调合成数据方法类:
from mx_rag.tools.finetune.generator import TrainDataGenerator TrainDataGenerator(llm: Text2TextLLM, dataset_path: str, reranker: Reranker, encrypt_fn, decrypt_fn)
输入参数说明
微调合成数据配置类DataProcessConfig参数详情:
参数名 |
数据类型 |
可选/必选 |
说明 |
---|---|---|---|
generate_qd_prompt |
str |
可选 |
微调合成数据自动生成所用prompt,用户可根据相应领域进行修改,微调效果更佳。默认值为GENERATE_QD_PROMPT,字符串长度范围(0, 1*1024*1024] |
llm_preferred_prompt |
str |
可选 |
微调合成数据筛选过程所用prompt,用户可根据相应领域进行修改,微调效果更佳。默认值为SCORING_QD_PROMPT,字符串长度范围(0, 1*1024*1024] |
question_number |
int |
可选 |
每个原始文本切片对应生成的问题数,该数量越大,生成的问题角度越全面,有利于微调效果,但是耗时较长,默认值为3,取值范围(0, 20] |
featured |
bool |
可选 |
基于BM25+Reranker数据相关性评分融合筛选,默认值为True |
featured_percentage |
float |
可选 |
基于BM25+Reranker融合筛选后比例, 取值范围(0.0, 1.0),默认值为0.8。 |
preferred |
bool |
可选 |
基于LLM数据相关性评分筛选,默认值为True |
llm_threshold_score |
float |
可选 |
基于LLM数据相关性评分筛选后比例, 取值范围(0.0, 1.0),默认值为0.8。 |
rewrite |
bool |
可选 |
基于LLM对生成的数据进行语义多角度重写扩充,默认值为True |
query_rewrite_number |
int |
可选 |
针对每个问答对重写扩充的数量,默认值为2,取值范围(0, 20] |
GENERATE_QD_PROMPT和SCORING_QD_PROMPT定义如下:
GENERATE_QD_PROMPT = """阅读文章,生成一个相关的问题,例如: 文章:气候变化对海洋生态系统造成了严重的影响,其中包括海洋温度上升、海平面上升、酸化等问题。这些变化对海洋生物种群分布、生态圈的稳定性以及渔业等方面都产生了深远影响。在全球变暖的背景下,保护海洋生态系统已经成为当务之急。 问题:气候变化对海洋生态系统的影响主要体现在哪些方面? 文章:零售业是人工智能应用的另一个重要领域。通过数据分析和机器学习算法,零售商可以更好地了解消费者的购买行为、趋势和偏好。人工智能技术可以帮助零售商优化库存管理、推荐系统、市场营销等方面的工作,提高销售额和客户满意度。 问题:人工智能是如何帮助零售商改善客户体验和销售业绩的? 请仿照样例对以下文章提{question_number}个相关问题: 文章:{doc} 输出格式为以下,按照问题1,问题2...进行编号,冒号后面不要再出现数字编号: 问题1:... ... """ SCORING_QD_PROMPT = """您的任务是评估给定问题与文档之间的相关性。相关性评分应该在0到1之间,其中1表示非常相关,0表示不相关。评分应该基于文档内容回答问题的直接程度。 请仔细阅读问题和文档,然后基于以下标准给出一个相关性评分: - 如果文档直接回答了问题,给出接近1的分数。 - 如果文档与问题相关,但不是直接回答,给出一个介于0和1之间的分数,根据相关程度递减。 - 如果文档与问题不相关,给出0。 例如: 问题:小明昨天吃了什么饭? 文档:小明昨天和朋友出去玩,还聚了餐,吃的海底捞,真是快乐的一天。 因为文档直接回答了问题的内容,因此给出0.99的分数 问题:小红学习成绩怎么样? 文档:小红在班上上课积极,按时完成作业,帮助同学,被老师评为了班级积极分子。 文档中并没有提到小红的学习成绩,只是提到了上课积极,按时完成作业,因此给出0.10的分数 请基于上述标准,为以下问题与文档对给出一个相关性评分,评分分数保留小数点后2位数: 问题: {query} 文档: {doc} """
微调合成数据方法类TrainDataGenerator参数介绍:
参数名 |
数据类型 |
可选/必选 |
说明 |
---|---|---|---|
llm |
Text2TextLLM |
必选 |
用于微调合成数据生成及筛选的大模型,详情请参考Text2TextLLM类 |
dataset_path |
str |
必选 |
自动生成和筛选后的微调合成数据集文件存储目录,路径长度取值范围为[1,1024]。路径不能包含软链接且不允许存在".."。 |
reranker |
Reranker |
必选 |
用于微调合成数据筛选过程中的reranker,详情请参考Reranker |
encrypt_fn |
Callable[[str], str] |
可选 |
对生成的Q-D对进行加密存储,默认为None,即不加密处理 须知:
如果上传的文档涉及银行卡号、身份证号、护照号、口令等个人数据,请配置该参数保证个人数据安全。 |
decrypt_fn |
Callable[[str], str] |
可选 |
对已存储的Q-D对进行解密处理,默认为None。 |
调用示例
from paddle.base import libpaddle from langchain_community.document_loaders import TextLoader from langchain_text_splitters import RecursiveCharacterTextSplitter from mx_rag.document import LoaderMng from mx_rag.document.loader import DocxLoader from mx_rag.llm import Text2TextLLM from mx_rag.reranker.local import LocalReranker from mx_rag.tools.finetune.generator import TrainDataGenerator, DataProcessConfig from mx_rag.utils import ClientParam llm = Text2TextLLM(model_name="Llama3-8B-Chinese-Chat", base_url="https://{ip}:{port}/v1/chat/completions", client_param=ClientParam(ca_file="/path/to/ca.crt") ) reranker = LocalReranker("/home/data/bge-reranker-large", dev_id=0) dataset_path = "path to data_output" # 微调合成数据集的输出地址 document_path = "path to document dir" # 用户提供的原始文档集所在地址 loader_mng = LoaderMng() loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md"]) loader_mng.register_loader(loader_class=DocxLoader, file_types=[".docx"]) # 加载文档切分器,使用langchain的 loader_mng.register_splitter(splitter_class=RecursiveCharacterTextSplitter, file_types=[".docx", ".txt", ".md"], splitter_params={"chunk_size": 750, "chunk_overlap": 150, "keep_separator": False } ) train_data_generator = TrainDataGenerator(llm, dataset_path, reranker) split_doc_list = train_data_generator.generate_origin_document(document_path=document_path, loader_mng=loader_mng) config = DataProcessConfig() train_data_generator.generate_train_data(split_doc_list, config)