昇腾社区首页
中文
注册

类功能

功能描述

用于适配RAG SDK的文生文、图生图、文生图的各种chain,同时也提供访问MxRAGCache的能力,当缓存未命中时,将进行大模型推理,然后将结果再刷新至缓存。

函数原型

from mx_rag.cache import CacheChainChat
CacheChainChat(cache,chain,convert_data_to_cache,convert_data_to_user)

输入参数说明

参数名

数据类型

可选/必选

说明

cache

MxRAGCache

必选

RAG SDK缓存。

chain

Chain

必选

RAG SDK chain,用于访问大模型。

convert_data_to_cache

Callable[[Any], Dict]

可选

该回调函数主要用于当用户数据无法转换为字符串格式时,由用户提供转换函数。

默认为不做转换。

convert_data_to_user

Callable[[Dict], Any]

可选

该回调函数主要是配合“convert_data_to_cache”使用,当用户问题命中时,将cache存储的格式转换为用户格式。

默认为不做转换。

调用示例

import time
from paddle.base import libpaddle
from langchain.text_splitter import RecursiveCharacterTextSplitter
from mx_rag.chain import SingleText2TextChain
from mx_rag.document.loader import DocxLoader
from mx_rag.embedding.local import TextEmbedding
from mx_rag.knowledge import KnowledgeDB
from mx_rag.knowledge.knowledge import KnowledgeStore
from mx_rag.llm import Text2TextLLM
from mx_rag.storage.document_store import SQLiteDocstore
from mx_rag.knowledge.handler import upload_files
from mx_rag.document import LoaderMng
from mx_rag.storage.vectorstore import MindFAISS
from mx_rag.utils import ClientParam
from mx_rag.cache import CacheChainChat, MxRAGCache, SimilarityCacheConfig

#向量维度
dim = 1024
# NPU卡id
dev = 0

similarity_config = SimilarityCacheConfig(
    vector_config={
        "vector_type": "npu_faiss_db",
        "x_dim": dim,
        "devs": [dev],

    },
    cache_config="sqlite",
    emb_config={
        "embedding_type": "local_text_embedding",
        "x_dim": dim,
        "model_path": "/path to emb",  # emb 模型路径
        "dev_id": dev
    },
    similarity_config={
        "similarity_type": "local_reranker",
        "model_path": "/path to reranker",  # reranker 模型路径
        "dev_id": dev
    },

    retrieval_top_k=1,
    cache_size=1000,
    clean_size=20,
    similarity_threshold=0.86,
    data_save_folder="/save path",  # 落盘路径
    disable_report=True
)
similarity_cache = MxRAGCache("similarity_cache", similarity_config)

# cache 初始化
cache = MxRAGCache("similarity_cache", similarity_config)
# Step1离线构建知识库,首先注册文档处理器
loader_mng = LoaderMng()
# 加载文档加载器,可以使用RAG SDK自有的,也可以使用langchain的
loader_mng.register_loader(DocxLoader, [".docx"])
# 加载文档切分器,使用langchain的
loader_mng.register_splitter(RecursiveCharacterTextSplitter, [".xlsx", ".docx", ".pdf"],
                             {"chunk_size": 200, "chunk_overlap": 50, "keep_separator": False})

emb = TextEmbedding(model_path="/path to emb", dev_id=dev)

# 初始化文档chunk关系数据库
chunk_store = SQLiteDocstore(db_path="./sql.db")
# 初始化知识管理关系数据库
knowledge_store = KnowledgeStore(db_path="./sql.db")
# 初始化矢量检索

vector_store = MindFAISS(x_dim=dim,
                         devs=[dev],
                         load_local_index="./faiss.index"
                         )

#添加知识库及管理员
knowledge_store.add_knowledge(knowledge_name="test", user_id='Default', role='admin')
# 初始化知识库管理
knowledge_db = KnowledgeDB(knowledge_store=knowledge_store,
                           chunk_store=chunk_store,
                           vector_store=vector_store,
                           knowledge_name="test",
                           user_id='Default',
                           white_paths=["/home"])
# 完成离线知识库构建,上传领域知识test.docx文档。
upload_files(knowledge_db, ["/path to files"],
             loader_mng=loader_mng,
             embed_func=emb.embed_documents,
             force=True)
# Step2在线问题答复,初始化检索器
retriever = vector_store.as_retriever(document_store=chunk_store,
                                      embed_func=emb.embed_documents, k=3, score_threshold=0.3)
# 配置reranker

# 配置text生成text大模型chain,具体ip端口请根据实际情况适配修改
llm = Text2TextLLM(base_url="https://<ip>:<port>",
                   model_name="Llama3-8B-Chinese-Chat",
                   client_param=ClientParam(ca_file="/path/to/ca.crt"))
text2text_chain = SingleText2TextChain(llm=llm, retriever=retriever)
cache_chain = CacheChainChat(chain=text2text_chain, cache=cache)
start_time = time.time()
res = cache_chain.query("请描述2024年高考作文题目")
end_time = time.time()
print(f"no cache query time cost:{(end_time - start_time) * 1000}ms")
print(f"no cache answer {res}")
start_time = time.time()
res = cache_chain.query("2024年的高考题目是什么", )
end_time = time.time()
print(f"cache query time cost:{(end_time - start_time) * 1000}ms")
print(f"cache answer {res}")