昇腾社区首页
中文
注册

类功能

功能描述

进行微调合成数据自动生成的类,给用户提供原始文档集切分处理,基于切分后的文本生成及筛选微调合成数据。

函数原型

  • 微调合成数据配置类:
    from mx_rag.tools.finetune.generator import DataProcessConfig
    @dataclass
    class DataProcessConfig():
        generate_qd_prompt: str = GENERATE_QD_PROMPT
        llm_preferred_prompt: str = SCORING_QD_PROMPT
        question_number: int = 3
        featured: bool = True
        featured_percentage: float = 0.8
        preferred: bool = True
        llm_threshold_score: float = 0.8
        rewrite: bool = True
        query_rewrite_number: int = 2
  • 微调合成数据方法类:
    from mx_rag.tools.finetune.generator import TrainDataGenerator
    TrainDataGenerator(llm: Text2TextLLM, dataset_path: str, reranker: Reranker, encrypt_fn, decrypt_fn)

输入参数说明

微调合成数据配置类DataProcessConfig参数详情:

参数名

数据类型

可选/必选

说明

generate_qd_prompt

str

可选

微调合成数据自动生成所用prompt,用户可根据相应领域进行修改,微调效果更佳。默认值为GENERATE_QD_PROMPT,字符串长度范围(0, 1*1024*1024]

llm_preferred_prompt

str

可选

微调合成数据筛选过程所用prompt,用户可根据相应领域进行修改,微调效果更佳。默认值为SCORING_QD_PROMPT,字符串长度范围(0, 1*1024*1024]

question_number

int

可选

每个原始文本切片对应生成的问题数,该数量越大,生成的问题角度越全面,有利于微调效果,但是耗时较长,默认值为3,取值范围(0, 20]

featured

bool

可选

基于BM25+Reranker数据相关性评分融合筛选,默认值为True

featured_percentage

float

可选

基于BM25+Reranker融合筛选后比例, 取值范围(0.0, 1.0),默认值为0.8。

preferred

bool

可选

基于LLM数据相关性评分筛选,默认值为True

llm_threshold_score

float

可选

基于LLM数据相关性评分筛选后比例, 取值范围(0.0, 1.0),默认值为0.8。

rewrite

bool

可选

基于LLM对生成的数据进行语义多角度重写扩充,默认值为True

query_rewrite_number

int

可选

针对每个问答对重写扩充的数量,默认值为2,取值范围(0, 20]

GENERATE_QD_PROMPT和SCORING_QD_PROMPT定义如下:

GENERATE_QD_PROMPT = """阅读文章,生成一个相关的问题,例如:
文章:气候变化对海洋生态系统造成了严重的影响,其中包括海洋温度上升、海平面上升、酸化等问题。这些变化对海洋生物种群分布、生态圈的稳定性以及渔业等方面都产生了深远影响。在全球变暖的背景下,保护海洋生态系统已经成为当务之急。 
问题:气候变化对海洋生态系统的影响主要体现在哪些方面?
文章:零售业是人工智能应用的另一个重要领域。通过数据分析和机器学习算法,零售商可以更好地了解消费者的购买行为、趋势和偏好。人工智能技术可以帮助零售商优化库存管理、推荐系统、市场营销等方面的工作,提高销售额和客户满意度。
问题:人工智能是如何帮助零售商改善客户体验和销售业绩的?
请仿照样例对以下文章提{question_number}个相关问题:
文章:{doc}
输出格式为以下,按照问题1,问题2...进行编号,冒号后面不要再出现数字编号:
问题1:...
...
"""
SCORING_QD_PROMPT = """您的任务是评估给定问题与文档之间的相关性。相关性评分应该在0到1之间,其中1表示非常相关,0表示不相关。评分应该基于文档内容回答问题的直接程度。
请仔细阅读问题和文档,然后基于以下标准给出一个相关性评分:
- 如果文档直接回答了问题,给出接近1的分数。
- 如果文档与问题相关,但不是直接回答,给出一个介于0和1之间的分数,根据相关程度递减。
- 如果文档与问题不相关,给出0。
例如:
问题:小明昨天吃了什么饭?
文档:小明昨天和朋友出去玩,还聚了餐,吃的海底捞,真是快乐的一天。
因为文档直接回答了问题的内容,因此给出0.99的分数
问题:小红学习成绩怎么样?
文档:小红在班上上课积极,按时完成作业,帮助同学,被老师评为了班级积极分子。
文档中并没有提到小红的学习成绩,只是提到了上课积极,按时完成作业,因此给出0.10的分数
请基于上述标准,为以下问题与文档对给出一个相关性评分,评分分数保留小数点后2位数:
问题: {query}
文档: {doc}
"""

微调合成数据方法类TrainDataGenerator参数介绍:

参数名

数据类型

可选/必选

说明

llm

Text2TextLLM

必选

用于微调合成数据生成及筛选的大模型,详情请参考Text2TextLLM类

dataset_path

str

必选

自动生成和筛选后的微调合成数据集文件存储目录,路径长度取值范围为[1,1024]。路径不能包含软链接且不允许存在".."。

reranker

Reranker

必选

用于微调合成数据筛选过程中的reranker,详情请参考Reranker

encrypt_fn

Callable[[str], str]

可选

对生成的Q-D对进行加密存储,默认为None,即不加密处理

须知:

如果上传的文档涉及银行卡号、身份证号、护照号、口令等个人数据,请配置该参数保证个人数据安全。

decrypt_fn

Callable[[str], str]

可选

对已存储的Q-D对进行解密处理,默认为None。

调用示例

from paddle.base import libpaddle
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from mx_rag.document import LoaderMng
from mx_rag.document.loader import DocxLoader
from mx_rag.llm import Text2TextLLM
from mx_rag.reranker.local import LocalReranker
from mx_rag.tools.finetune.generator import TrainDataGenerator, DataProcessConfig

from mx_rag.utils import ClientParam


llm = Text2TextLLM(model_name="Llama3-8B-Chinese-Chat", base_url="https://{ip}:{port}/v1/chat/completions", 
client_param=ClientParam(ca_file="/path/to/ca.crt")
)
reranker = LocalReranker("/home/data/bge-reranker-large", dev_id=0)
dataset_path = "path to data_output"  # 微调合成数据集的输出地址

document_path = "path to document dir"  # 用户提供的原始文档集所在地址

loader_mng = LoaderMng()

loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md"])
loader_mng.register_loader(loader_class=DocxLoader, file_types=[".docx"])

# 加载文档切分器,使用langchain的
loader_mng.register_splitter(splitter_class=RecursiveCharacterTextSplitter,
                             file_types=[".docx", ".txt", ".md"],
                             splitter_params={"chunk_size": 750,
                                              "chunk_overlap": 150,
                                              "keep_separator": False
                                              }
                             )

train_data_generator = TrainDataGenerator(llm, dataset_path, reranker)

split_doc_list = train_data_generator.generate_origin_document(document_path=document_path, loader_mng=loader_mng)
config = DataProcessConfig()
train_data_generator.generate_train_data(split_doc_list, config)