Class Introduction
Function
Splits the input text into chunks, embeds them, and uses a clustering function to group the embedded chunks into semantic clusters. It then calculates the cosine similarity between each context chunk and the question. Within each cluster, it deletes the chunks with the lowest similarity according to the specified compression rate. The chunks that remain are the ones most relevant to the prompt, which yields a compressed summary of the long text.
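The following self-contained sketch shows the per-cluster pruning step described above. It assumes the chunks have already been embedded and given cluster labels. It uses a hypothetical `drop_ratio` parameter, defined as the fraction of each cluster's chunks to delete (rounded down), in place of the compression rate. How ClusterCompressor interprets its compression rate internally is not stated here, so treat this as an illustration of the technique, not the library's implementation.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two vectors; 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def cluster_compress(chunks: List[str],
                     chunk_vecs: List[List[float]],
                     question_vec: List[float],
                     labels: Sequence[int],
                     drop_ratio: float) -> str:
    """Remove the least question-relevant chunks inside each cluster.

    drop_ratio is a hypothetical parameter: the fraction of each cluster's
    chunks to delete (rounded down). It is not necessarily mx_rag's
    compression rate.
    """
    if not (len(chunks) == len(chunk_vecs) == len(labels)):
        raise ValueError("chunks, embeddings and labels must have equal length")

    # Group chunk indices by their cluster label
    clusters: Dict[int, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        clusters[int(label)].append(idx)

    keep = set()
    for members in clusters.values():
        # Most question-similar chunks first, compared only within this cluster
        ranked = sorted(members,
                        key=lambda i: cosine(chunk_vecs[i], question_vec),
                        reverse=True)
        n_drop = int(len(members) * drop_ratio)
        keep.update(ranked[:len(members) - n_drop])

    # Reassemble the retained chunks in their original order
    return " ".join(chunks[i] for i in sorted(keep))


# Toy data: 2-D "embeddings" and a question vector pointing at the cat topic
chunks = ["Cats purr.", "Dogs bark.", "Stocks fell.", "Cats nap."]
vecs = [[1.0, 0.0], [0.6, 0.8], [0.0, 1.0], [0.9, 0.1]]
question = [1.0, 0.0]
labels = [0, 0, 1, 0]
print(cluster_compress(chunks, vecs, question, labels, drop_ratio=0.5))
```

Chunks are ranked only against other chunks in the same cluster, so every semantic cluster keeps its most question-relevant members. In the toy data, "Stocks fell." survives because it is the only chunk in its cluster. Ranking all four chunks globally with the same ratio would remove it together with "Dogs bark.".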
Prototype
from mx_rag.compress.cluster_compressor import ClusterCompressor
class ClusterCompressor(cluster_func, embed, splitter, dev_id):
Parameters
| Parameter | Data Type | Required/Optional | Description |
|---|---|---|---|
| cluster_func | Callable[[List[List[float]]], Union[List[int], np.ndarray]] | Required | Clustering function that groups the embedded document chunks into semantic clusters. It must return a List[int] or an ndarray whose length equals the number of document chunks (one cluster label per chunk). The length cannot exceed 1000 × 1000. |
| embed | Embeddings | Required | Embedding object that converts document chunks into vectors. It must be an instance of a subclass of Embeddings from langchain_core.embeddings. |
| splitter | TextSplitter | Optional | Document splitter. It must be an instance of a subclass of TextSplitter from langchain_text_splitters.base. Default: RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=0, separators=[".", "!", "?", "\n", ",", ";", " ", ""]) from langchain.text_splitter. |
| dev_id | Integer | Optional | NPU ID. Run the npu-smi info command to query the available IDs. Value range: [0, 63]. Default: 0. |
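Any callable that meets the cluster_func contract can be used, and scikit-learn is not required. The sketch below defines a hypothetical threshold_cluster function that performs simple single-pass "leader" clustering. It is swapped in here purely for illustration and is not part of mx_rag. The 0.8 cosine threshold is an arbitrary assumption. The function returns one integer label per input embedding, in input order, which is the shape the table above requires.

```python
import math
from typing import List


def threshold_cluster(embeddings: List[List[float]], threshold: float = 0.8) -> List[int]:
    """Hypothetical cluster_func: single-pass "leader" clustering.

    Each embedding joins the first existing cluster whose leader (first member)
    has cosine similarity >= threshold with it; otherwise it starts a new
    cluster. Returns one integer label per embedding, in input order.
    """
    leaders: List[List[float]] = []  # unit-normalized first member of each cluster
    labels: List[int] = []
    for vec in embeddings:
        # Normalize so a dot product equals cosine similarity (zero vectors stay zero)
        norm = math.sqrt(sum(x * x for x in vec)) or 1.0
        unit = [x / norm for x in vec]
        for label, leader in enumerate(leaders):
            if sum(a * b for a, b in zip(unit, leader)) >= threshold:
                labels.append(label)
                break
        else:
            # No sufficiently similar cluster: this vector leads a new one
            leaders.append(unit)
            labels.append(len(leaders) - 1)
    return labels


print(threshold_cluster([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]))
```

The function takes a single positional argument, so it could be passed directly as cluster_func=threshold_cluster. To use a different threshold, wrap it with the standard library, for example functools.partial(threshold_cluster, threshold=0.7).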
Example
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sklearn.cluster import HDBSCAN
from mx_rag.compress.cluster_compressor import ClusterCompressor
from mx_rag.embedding.local import TextEmbedding
from mx_rag.embedding.service import TEIEmbedding
from mx_rag.utils import ClientParam

# Long text to compress and the question it is compressed for
context = """Prompt text to be compressed"""
question = "Provide a title for the above content."

# Set tei_emb to True to use a TEI embedding service instead of a local model.
# Replace the placeholder URL, CA certificate path, and model path with real values.
tei_emb = False
if tei_emb:
    emb = TEIEmbedding.create(url="https://ip:port/embed", client_param=ClientParam(ca_file="/path/to/ca.crt"))
else:
    emb = TextEmbedding(model_path="embedding_path", dev_id=0)


def _get_community(sentences_embedding):
    # Community division: cluster the chunk embeddings with HDBSCAN
    # and return one cluster label per chunk
    node_num = len(sentences_embedding)
    min_cluster_size = 2
    hdbscan = HDBSCAN(min_cluster_size=min(min_cluster_size, node_num))
    labels = hdbscan.fit_predict(sentences_embedding)
    return labels


# Split the context into chunks of roughly 100 characters
splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=0, separators=[".", "!", "?", "\n", ",", ";", " ", ""])
compressor = ClusterCompressor(cluster_func=_get_community, embed=emb, splitter=splitter, dev_id=0)
# 0.6 is passed as the compression rate
res = compressor.compress_texts(context, question, 0.6)
print(res)