昇腾社区首页
中文
注册
开发者
下载

类功能

功能描述

提供知识图谱创建、检索的统一入口。

函数原型

from mx_rag.graphrag import GraphRAGPipeline

GraphRAGPipeline(work_dir, llm, embedding_model, dim, rerank_model, graph_type,graph_name, kwargs)

输入参数说明

参数名

数据类型

可选/必选

说明

work_dir

str

必选

知识图谱工作目录,其剩余空间至少为5GB,保存生成的图json中间文件,如果使用的MindFAISS,对应的向量数据也在该路径下。

不能为相对路径,路径长度不能超1024,不能为软链接且不允许存在".."。

路径不能在路径列表中:["/etc", "/usr/bin", "/usr/lib", "/usr/lib64", "/sys/", "/dev/", "/sbin"]。

llm

Text2TextLLM

必选

大模型接口实例对象。

embedding_model

Embeddings

必选

langchain_core.embeddings.Embeddings的子类,取值包含:

  • mx_rag.embedding.local.TextEmbedding
  • mx_rag.embedding.service.TEIEmbedding

rerank_model

Reranker

可选

mx_rag_reranker.Reranker的子类,默认为“None”,取值包含:

  • mx_rag.reranker.local.LocalReranker
  • mx_rag.reranker.service.TEIReranker

graph_type

str

可选

图数据库类型,默认为“networkx”,其取值仅支持["networkx", "opengauss"]。

graph_name

str

可选

知识图谱名称,默认为“graph”,其取值范围为[1, 255]。

dim

int

必须

嵌入模型生成的向量维度,其取值范围为[1, 1024 * 1024]。

kwargs

Dict

可选

扩展参数列表:

  • graph_conf:当图数据库类型为openGauss时,需要指定该参数,其类型为OpenGaussSettings,指定了连接图数据库的相关参数。
  • devs:指定NPU设备,为一个只包含一个元素的list。
  • node_vector_store: 用于存储向量化节点以实现相似节点搜索的向量数据库。默认为None,此时将使用MindFAISS作为向量数据库。
  • concept_vector_store: 在对概念进行聚类时,用于存储向量化概念以实现相似概念搜索的向量数据库。默认为None,此时将使用MindFAISS作为向量数据库。

返回值说明

GraphRAGPipeline对象。

调用示例

from paddle.base import libpaddle  # fix std::bad_alloc
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from mx_rag.chain.single_text_to_text import GraphRagText2TextChain
from mx_rag.document import LoaderMng
from mx_rag.embedding.local import TextEmbedding
from mx_rag.graphrag import GraphRAGPipeline
from mx_rag.llm import LLMParameterConfig, Text2TextLLM
from mx_rag.reranker.local import LocalReranker
from mx_rag.utils import ClientParam
work_dir = "test_pipeline"
llm = Text2TextLLM(
    base_url="https://x.x.x.x:port/v1/chat/completions",
    model_name="model_name",
    llm_config=LLMParameterConfig(max_tokens=64 * 1024, temperature=0.6, top_p=0.9),
    client_param=ClientParam(timeout=180, ca_file="/path/to/ca.crt"),
)
rerank_model = LocalReranker("/data/models/bge-reranker-v2-m3/", 0, 20, False)
embedding_model = TextEmbedding.create(model_path="/data/models/bge-large-en-v1.5")
data_load_mng = LoaderMng()
data_load_mng.register_loader(TextLoader, [".txt"])
data_load_mng.register_splitter(
    RecursiveCharacterTextSplitter,
    [".txt"],
    dict(chunk_size=512, chunk_overlap=20)
)
graph_name = "hotpotqa"
graph_type = "networkx"
pipeline = GraphRAGPipeline(work_dir, llm, embedding_model,rerank_model, 1024, graph_name=graph_name, graph_type=graph_type)
pipeline.upload_files(["./test_graph/hotpotqa.500.txt"], data_load_mng)
pipeline.build_graph()

question = "Which case was brought to court first Miller v. California or Gates v. Collier ?"
contexts = pipeline.retrieve_graph(graph_name, graph_type, question)
text2text_chain = GraphRagText2TextChain(
    llm=llm,
    retriever=pipeline.as_retriever(graph_name, graph_type),
    reranker=rerank_model)
result = text2text_chain.query(question)
print(f"#contexts: {len(contexts)}")
print(contexts)
print(result)