类功能
功能描述
提供知识图谱创建、检索的统一入口。
函数原型
from mx_rag.graphrag import GraphRAGPipeline GraphRAGPipeline(work_dir, llm, embedding_model, dim, rerank_model, graph_type,graph_name, encrypt_fn,decrypt_fn,kwargs)
输入参数说明
参数名 |
数据类型 |
可选/必选 |
说明 |
|---|---|---|---|
work_dir |
str |
必选 |
知识图谱工作目录,其剩余空间至少为5GB,保存生成的图json中间文件,如果使用的MindFAISS,对应的向量数据也在该路径下。 不能为相对路径,路径长度不能超过1024,不能为软链接且不允许存在".."。 路径不能在路径列表中:["/etc", "/usr/bin", "/usr/lib", "/usr/lib64", "/sys/", "/dev/", "/sbin", "/tmp"]。 |
llm |
Text2TextLLM |
必选 |
大模型接口实例对象。 |
embedding_model |
Embeddings |
必选 |
langchain_core.embeddings.Embeddings的子类,取值包含:
|
dim |
int |
必选 |
嵌入模型生成的向量维度,其取值范围为[1, 1024 * 1024]。 |
rerank_model |
Reranker |
可选 |
mx_rag_reranker.Reranker的子类,默认为“None”,取值包含:
|
graph_type |
str |
可选 |
图数据库类型,默认为“networkx”,其取值仅支持["networkx", "opengauss"]。 |
graph_name |
str |
可选 |
知识图谱名称,默认为“graph”,其取值范围为[1, 255],只能由标识符组成。 |
encrypt_fn |
Callable |
可选 |
回调方法,对调用build_graph产生的json文件内容加密。请注意提供正确加密方法并保证安全性,返回值是加密后的字符串。 须知:
如果上传的文档涉及银行卡号、身份证号、护照号、口令等个人数据,请配置该参数保证个人数据安全。 |
decrypt_fn |
Callable |
可选 |
回调方法,在graph_type为"networkx"时,在检索时会对"{graph_name}.json"解密读取。请注意提供正确解密方法并保证安全性,返回值是解密后的字符串。 |
kwargs |
Dict |
可选 |
扩展参数列表:
说明:
age_graph由用户控制传入,请使用安全的连接方式。 |
返回值说明
GraphRAGPipeline对象。
调用示例
import getpass
from paddle.base import libpaddle # fix std::bad_alloc
from langchain_opengauss import OpenGaussSettings, openGaussAGEGraph
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from mx_rag.chain.single_text_to_text import GraphRagText2TextChain
from mx_rag.document import LoaderMng
from mx_rag.embedding.local import TextEmbedding
from mx_rag.graphrag import GraphRAGPipeline
from mx_rag.llm import LLMParameterConfig, Text2TextLLM
from mx_rag.reranker.local import LocalReranker
from mx_rag.utils import ClientParam
work_dir = "test_pipeline"
llm = Text2TextLLM(
base_url="https://x.x.x.x:port/v1/chat/completions",
model_name="model_name",
llm_config=LLMParameterConfig(max_tokens=64 * 1024, temperature=0.6, top_p=0.9),
client_param=ClientParam(timeout=180, ca_file="/path/to/ca.crt"),
)
rerank_model = LocalReranker("/data/models/bge-reranker-v2-m3/", 0, 20, False)
embedding_model = TextEmbedding.create(model_path="/data/models/bge-large-en-v1.5")
data_load_mng = LoaderMng()
data_load_mng.register_loader(TextLoader, [".txt"])
data_load_mng.register_splitter(
RecursiveCharacterTextSplitter,
[".txt"],
dict(chunk_size=512, chunk_overlap=20)
)
graph_name = "hotpotqa"
graph_type = "opengauss"
conf = OpenGaussSettings(user="gaussdb",
password=getpass.getpass(),
host="x.x.x.x",
port="x",
database="postgres")
age_graph = openGaussAGEGraph(graph_name, conf,
sslmode="verify-ca",
sslcert="client.crt",
sslkey="client.key",
sslrootcert="cacert.pem")
pipeline = GraphRAGPipeline(work_dir, llm, embedding_model, 1024, rerank_model, graph_name=graph_name,
age_graph=age_graph)
pipeline.upload_files(["./test_graph/hotpotqa.500.txt"], data_load_mng)
pipeline.build_graph()
question = "Which case was brought to court first Miller v. California or Gates v. Collier ?"
contexts = pipeline.retrieve_graph(question)
text2text_chain = GraphRagText2TextChain(
llm=llm,
retriever=pipeline.as_retriever(),
reranker=rerank_model)
result = text2text_chain.query(question)
print(f"#contexts: {len(contexts)}")
print(contexts)
print(result)