Class Introduction
Function
Provides a unified entry point for creating and retrieving knowledge graphs.
Prototype
from mx_rag.graphrag import GraphRAGPipeline
GraphRAGPipeline(work_dir, llm, embedding_model, dim, rerank_model, graph_type, graph_name, encrypt_fn, decrypt_fn, **kwargs)
Parameters
| Parameter | Data Type | Required/Optional | Description |
|---|---|---|---|
| work_dir | String | Required | Knowledge graph working directory, which stores the generated graph JSON intermediate files. The directory must have at least 5 GB of free space. If MindFAISS is used, the corresponding vector data is also stored in this directory. The path must be absolute, cannot exceed 1024 characters, cannot be a soft link, and cannot contain two consecutive dots (..). The path cannot be in the following list: ["/etc", "/usr/bin", "/usr/lib", "/usr/lib64", "/sys/", "/dev/", "/sbin", "/tmp"]. |
| llm | Text2TextLLM | Required | LLM instance object. |
| embedding_model | Embeddings | Required | Subclass of langchain_core.embeddings.Embeddings, which can be: |
| dim | Integer | Required | Dimension of the vectors generated by the embedding model. The value range is [1, 1024 × 1024]. |
| rerank_model | Reranker | Optional | Subclass of mx_rag_reranker.Reranker. The default value is None. The options are as follows: |
| graph_type | String | Optional | Graph database type. The value can only be networkx (default) or opengauss. |
| graph_name | String | Optional | Knowledge graph name, which must be a valid identifier. The length range is [1, 255]. The default value is graph. |
| encrypt_fn | Callable | Optional | Callback method used to encrypt the content of the JSON file generated by calling build_graph. The return value is the encrypted character string. Ensure that you provide a correct and secure encryption method. NOTICE: If the file to be uploaded contains personal data such as bank account numbers, ID card numbers, passport numbers, and passwords, set this parameter to ensure personal data security. |
| decrypt_fn | Callable | Optional | Callback method used to decrypt and read the {graph_name}.json file during retrieval when graph_type is networkx. The return value is the decrypted character string. Ensure that you provide a correct and secure decryption method. |
| kwargs | Dict | Optional | Extended parameters: NOTE: age_graph is controlled by users. Use a secure connection mode. |
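The constraints in the parameter table can be checked before the pipeline is constructed. The following is a minimal sketch and not part of mx_rag: the function names check_work_dir and check_scalar_args are hypothetical, a path is assumed to be "in the path list" when it equals or lies under a listed directory, "identifier" is assumed to mean a Python-style identifier, and the [1, 255] range for graph_name is assumed to be a length. The library may apply these rules differently.

```python
import os
import shutil
from typing import List

# Values copied from the parameter table.
FORBIDDEN_DIRS = ["/etc", "/usr/bin", "/usr/lib", "/usr/lib64",
                  "/sys/", "/dev/", "/sbin", "/tmp"]
MAX_PATH_LENGTH = 1024
MIN_FREE_BYTES = 5 * 1024 ** 3  # 5 GB
MAX_DIM = 1024 * 1024
GRAPH_TYPES = ("networkx", "opengauss")


def check_work_dir(path: str) -> List[str]:
    """Return the documented work_dir rules that `path` violates (empty if none)."""
    problems = []
    if not os.path.isabs(path):
        problems.append("must be an absolute path")
    if len(path) > MAX_PATH_LENGTH:
        problems.append("must not exceed 1024 characters")
    if ".." in path:
        problems.append("must not contain '..'")
    if os.path.islink(path):
        problems.append("must not be a soft link")
    # Assumption: equal to, or located under, one of the listed directories.
    for forbidden in FORBIDDEN_DIRS:
        base = forbidden.rstrip("/")
        if path == base or path.startswith(base + "/"):
            problems.append(f"must not be under {base}")
            break
    # Free space can only be measured once the directory exists.
    if os.path.isdir(path) and shutil.disk_usage(path).free < MIN_FREE_BYTES:
        problems.append("needs at least 5 GB of free space")
    return problems


def check_scalar_args(dim: int, graph_type: str = "networkx",
                      graph_name: str = "graph") -> List[str]:
    """Return the documented rules violated by dim, graph_type, and graph_name."""
    problems = []
    if not isinstance(dim, int) or not 1 <= dim <= MAX_DIM:
        problems.append("dim must be an integer in [1, 1024 * 1024]")
    if graph_type not in GRAPH_TYPES:
        problems.append("graph_type must be 'networkx' or 'opengauss'")
    # Assumption: Python-style identifier, with [1, 255] read as the name length.
    if not (graph_name.isidentifier() and 1 <= len(graph_name) <= 255):
        problems.append("graph_name must be an identifier of 1 to 255 characters")
    return problems
```

Both helpers return a list of messages rather than raising, so a caller can report every problem at once and decide whether to abort before calling GraphRAGPipeline.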
Return Value
GraphRAGPipeline object.
Example
import getpass
from paddle.base import libpaddle # fix std::bad_alloc
from langchain_opengauss import OpenGaussSettings, openGaussAGEGraph
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from mx_rag.chain.single_text_to_text import GraphRagText2TextChain
from mx_rag.document import LoaderMng
from mx_rag.embedding.local import TextEmbedding
from mx_rag.graphrag import GraphRAGPipeline
from mx_rag.llm import LLMParameterConfig, Text2TextLLM
from mx_rag.reranker.local import LocalReranker
from mx_rag.utils import ClientParam
work_dir = "/path/to/test_pipeline"  # work_dir must be an absolute path
llm = Text2TextLLM(
    base_url="https://x.x.x.x:port/v1/chat/completions",
    model_name="model_name",
    llm_config=LLMParameterConfig(max_tokens=64 * 1024, temperature=0.6, top_p=0.9),
    client_param=ClientParam(timeout=180, ca_file="/path/to/ca.crt"),
)
rerank_model = LocalReranker("/data/models/bge-reranker-v2-m3/", 0, 20, False)
embedding_model = TextEmbedding.create(model_path="/data/models/bge-large-en-v1.5")
data_load_mng = LoaderMng()
data_load_mng.register_loader(TextLoader, [".txt"])
data_load_mng.register_splitter(
    RecursiveCharacterTextSplitter,
    [".txt"],
    dict(chunk_size=512, chunk_overlap=20),
)
graph_name = "hotpotqa"
graph_type = "opengauss"
conf = OpenGaussSettings(
    user="gaussdb",
    password=getpass.getpass(),
    host="x.x.x.x",
    port="x",
    database="postgres",
)
age_graph = openGaussAGEGraph(
    graph_name,
    conf,
    sslmode="verify-ca",
    sslcert="client.crt",
    sslkey="client.key",
    sslrootcert="cacert.pem",
)
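# graph_type defaults to networkx, so it is passed explicitly for openGauss.
pipeline = GraphRAGPipeline(
    work_dir, llm, embedding_model, 1024, rerank_model,
    graph_type=graph_type, graph_name=graph_name, age_graph=age_graph,
)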
pipeline.upload_files(["./test_graph/hotpotqa.500.txt"], data_load_mng)
pipeline.build_graph()
question = "Which case was brought to court first Miller v. California or Gates v. Collier ?"
contexts = pipeline.retrieve_graph(question)
text2text_chain = GraphRagText2TextChain(
    llm=llm,
    retriever=pipeline.as_retriever(),
    reranker=rerank_model,
)
result = text2text_chain.query(question)
print(f"#contexts: {len(contexts)}")
print(contexts)
print(result)
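The example above does not set encrypt_fn or decrypt_fn. The following is a minimal sketch of one way such a callback pair could be built, assuming that each callback receives a string and returns a string (the parameter table only states the return type) and that the third-party cryptography package is installed. The helper name make_fernet_callbacks is hypothetical, and Fernet is one possible scheme, not necessarily the one recommended for mx_rag.

```python
from typing import Callable, Tuple


def make_fernet_callbacks(key: bytes) -> Tuple[Callable[[str], str], Callable[[str], str]]:
    """Build an (encrypt_fn, decrypt_fn) pair that maps str to str using Fernet."""
    # Imported here so this module still loads where `cryptography` is absent.
    from cryptography.fernet import Fernet

    fernet = Fernet(key)

    def encrypt_fn(plain_text: str) -> str:
        # Fernet operates on bytes and returns a URL-safe base64 token.
        return fernet.encrypt(plain_text.encode("utf-8")).decode("ascii")

    def decrypt_fn(cipher_text: str) -> str:
        # Raises cryptography.fernet.InvalidToken if the text was tampered with
        # or was encrypted under a different key.
        return fernet.decrypt(cipher_text.encode("ascii")).decode("utf-8")

    return encrypt_fn, decrypt_fn
```

A key can be created once with Fernet.generate_key() and kept in a secret store rather than in source code. The resulting functions would then be passed as encrypt_fn=encrypt_fn and decrypt_fn=decrypt_fn when constructing GraphRAGPipeline; per the parameter table, decrypt_fn is used during retrieval when graph_type is networkx.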