Class Introduction

Function

Provides a unified entry for creating and retrieving knowledge graphs.

Prototype

from mx_rag.graphrag import GraphRAGPipeline

GraphRAGPipeline(work_dir, llm, embedding_model, dim, rerank_model, graph_type, graph_name, encrypt_fn, decrypt_fn, **kwargs)

Parameters

Parameter

Data Type

Required/Optional

Description

work_dir

String

Required

Knowledge graph working directory, which stores the intermediate graph JSON files that are generated. The directory must have at least 5 GB of free space. If MindFAISS is used, the corresponding vector data is also stored in this directory.

The path must be absolute and cannot exceed 1024 characters. It cannot be a symbolic link and cannot contain two consecutive dots (..).

The path cannot be any of the following: ["/etc", "/usr/bin", "/usr/lib", "/usr/lib64", "/sys/", "/dev/", "/sbin", "/tmp"].

llm

Text2TextLLM

Required

LLM instance object.

embedding_model

Embeddings

Required

Instance of a subclass of langchain_core.embeddings.Embeddings. The options are as follows:

  • mx_rag.embedding.local.TextEmbedding
  • mx_rag.embedding.service.TEIEmbedding

dim

Integer

Required

Dimension of the vectors generated by the embedding model. The value range is [1, 1024 × 1024].

rerank_model

Reranker

Optional

Instance of a subclass of mx_rag.reranker.Reranker. The default value is None. The options are as follows:

  • mx_rag.reranker.local.LocalReranker
  • mx_rag.reranker.service.TEIReranker

graph_type

String

Optional

Graph database type. The value can only be networkx (default) or opengauss.

graph_name

String

Optional

Knowledge graph name. It must be a valid identifier with a length in the range [1, 255]. The default value is graph.

encrypt_fn

Callable

Optional

Callback that encrypts the content of the JSON file generated by build_graph. It returns the encrypted string. Ensure that you provide a correct and secure encryption method.

NOTICE:

If the file to be uploaded contains personal data such as bank account numbers, ID card numbers, passport numbers, and passwords, set this parameter to ensure personal data security.

decrypt_fn

Callable

Optional

Callback that decrypts the content of the {graph_name}.json file when it is read during retrieval and graph_type is networkx. It returns the decrypted string. Ensure that you provide a correct and secure decryption method.

kwargs

Dict

Optional

Extended parameters:

  • age_graph: Required when graph_type is opengauss. The value is an openGaussAGEGraph instance, which is the openGauss graph database connection.
  • devs: NPU device. The value is a list[int] that contains exactly one element.
  • node_vector_store: Vector database used to store vectorized nodes for similar node search. The default value is None, meaning that MindFAISS is used as the vector database.
  • concept_vector_store: Vector database used to store vectorized concepts for similar concept search during concept clustering. The default value is None, meaning that MindFAISS is used as the vector database.

NOTE:

The age_graph connection is created and managed by the user. Use a secure connection mode.
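
The value constraints listed in the table can be checked before the constructor is called. The following is a minimal sketch and not part of mx_rag: the helper names check_work_dir, check_dim, and check_graph_name are hypothetical. It assumes that "identifier" means a Python-style identifier, that the range [1, 255] refers to the name length, that "5 GB" can be read as 5 GiB, and that a path is rejected when it equals or lies inside one of the listed directories. The library's own validation may differ.

```python
import os
import shutil

# Values copied from the parameter descriptions above.
FORBIDDEN_DIRS = ["/etc", "/usr/bin", "/usr/lib", "/usr/lib64",
                  "/sys/", "/dev/", "/sbin", "/tmp"]
MAX_PATH_LEN = 1024
MIN_FREE_BYTES = 5 * 1024 ** 3  # assumption: "5 GB" is read as 5 GiB
MAX_DIM = 1024 * 1024
MAX_NAME_LEN = 255


def check_work_dir(path):
    """Hypothetical pre-check: return a list of violated work_dir rules."""
    problems = []
    if not os.path.isabs(path):
        problems.append("not an absolute path")
    if len(path) > MAX_PATH_LEN:
        problems.append("longer than 1024 characters")
    if ".." in path:
        problems.append("contains '..'")
    if os.path.islink(path):
        problems.append("is a symbolic link")
    normalized = path.rstrip("/") + "/"
    for forbidden in FORBIDDEN_DIRS:
        # Assumption: paths equal to or inside a listed directory are rejected.
        if normalized.startswith(forbidden.rstrip("/") + "/"):
            problems.append(f"inside {forbidden}")
    # Free space can only be measured once the directory exists.
    if os.path.isdir(path) and shutil.disk_usage(path).free < MIN_FREE_BYTES:
        problems.append("less than 5 GB free")
    return problems


def check_dim(dim):
    """True if dim is an integer in [1, 1024 * 1024]."""
    return isinstance(dim, int) and not isinstance(dim, bool) and 1 <= dim <= MAX_DIM


def check_graph_name(name):
    """True if name is an identifier of 1 to 255 characters (assumed reading)."""
    return isinstance(name, str) and name.isidentifier() and 1 <= len(name) <= MAX_NAME_LEN
```

With these helpers, a relative path such as "test_pipeline" is reported as not absolute, and a name such as "hot-pot" is rejected because it is not an identifier.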

Return Value

GraphRAGPipeline object.

Example

import getpass
from paddle.base import libpaddle  # fix std::bad_alloc
from langchain_opengauss import OpenGaussSettings, openGaussAGEGraph
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from mx_rag.chain.single_text_to_text import GraphRagText2TextChain
from mx_rag.document import LoaderMng
from mx_rag.embedding.local import TextEmbedding
from mx_rag.graphrag import GraphRAGPipeline
from mx_rag.llm import LLMParameterConfig, Text2TextLLM
from mx_rag.reranker.local import LocalReranker
from mx_rag.utils import ClientParam
work_dir = "/path/to/test_pipeline"  # work_dir must be an absolute path
llm = Text2TextLLM(
    base_url="https://x.x.x.x:port/v1/chat/completions",
    model_name="model_name",
    llm_config=LLMParameterConfig(max_tokens=64 * 1024, temperature=0.6, top_p=0.9),
    client_param=ClientParam(timeout=180, ca_file="/path/to/ca.crt"),
)
rerank_model = LocalReranker("/data/models/bge-reranker-v2-m3/", 0, 20, False)
embedding_model = TextEmbedding.create(model_path="/data/models/bge-large-en-v1.5")
data_load_mng = LoaderMng()
data_load_mng.register_loader(TextLoader, [".txt"])
data_load_mng.register_splitter(
    RecursiveCharacterTextSplitter,
    [".txt"],
    dict(chunk_size=512, chunk_overlap=20)
)
graph_name = "hotpotqa"
graph_type = "opengauss"

conf = OpenGaussSettings(user="gaussdb",
                         password=getpass.getpass(),
                         host="x.x.x.x",
                         port="x",
                         database="postgres")
age_graph = openGaussAGEGraph(graph_name, conf,
                              sslmode="verify-ca",
                              sslcert="client.crt",
                              sslkey="client.key",
                              sslrootcert="cacert.pem")
# age_graph is required because graph_type is "opengauss"
pipeline = GraphRAGPipeline(work_dir, llm, embedding_model, 1024, rerank_model,
                            graph_type=graph_type, graph_name=graph_name,
                            age_graph=age_graph)
pipeline.upload_files(["./test_graph/hotpotqa.500.txt"], data_load_mng)
pipeline.build_graph()
question = "Which case was brought to court first Miller v. California or Gates v. Collier ?"
contexts = pipeline.retrieve_graph(question)
text2text_chain = GraphRagText2TextChain(
    llm=llm,
    retriever=pipeline.as_retriever(),
    reranker=rerank_model)
result = text2text_chain.query(question)
print(f"#contexts: {len(contexts)}")
print(contexts)
print(result)
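
The encrypt_fn and decrypt_fn parameters expect callbacks that return the encrypted and the decrypted string, respectively. The following is a minimal sketch of one possible pair, built on Fernet from the third-party cryptography package. It is an illustration, not the method prescribed by mx_rag. It assumes that each callback receives a single string argument, and it generates the key inline only to keep the example self-contained; in practice the key would be loaded from a secure key store.

```python
from cryptography.fernet import Fernet

# Assumption: in a real deployment the key is loaded from a secure key store
# so that the same key is available at build time and at retrieval time.
_fernet = Fernet(Fernet.generate_key())


def encrypt_fn(plain_text: str) -> str:
    """Encrypt a string and return the Fernet token as a string."""
    return _fernet.encrypt(plain_text.encode("utf-8")).decode("ascii")


def decrypt_fn(cipher_text: str) -> str:
    """Decrypt a token produced by encrypt_fn and return the original string."""
    return _fernet.decrypt(cipher_text.encode("ascii")).decode("utf-8")
```

The two functions could then be passed to the constructor as the encrypt_fn and decrypt_fn arguments. Both must use the same key, otherwise the graph file written by build_graph cannot be read back during retrieval.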