昇腾社区首页
中文
注册

类功能

功能描述

将输入的query通过大模型提取关键词,再使用BM25进行topk检索,该类继承langchain_core.retrievers.BaseRetriever,通过调用基类的invoke方法使用检索功能,输入的query长度不超过100万。

函数原型

from mx_rag.retrievers.bm_retriever import BMRetriever
# 所有参数需通过关键字参数传递
BMRetriever(docs, llm, k, llm_config, prompt, preprocess_func)

输入参数说明

参数名

数据类型

可选/必选

说明

docs

List[Document]

必选

待检索的文档列表。

llm

Text2TextLLM

必选

大模型对象实例,具体类型请参见Text2TextLLM类

k

int

可选

检索返回的top k,取值范围:[1,10000],默认值为“1”

llm_config

LLMParameterConfig

可选

调用大模型参数,此处默认值temperature为“0.5”,top_p为“0.95”,其余参数说明请参见LLMParameterConfig类

prompt

langchain_core.prompts.PromptTemplate

可选

默认值如下,其中question字符串是固定的,不能更改,表示输入的问题;prompt.input_variables必须包含question,prompt.template的长度取值范围为(0,1 * 1024 * 1024],表示提示词。实际请求大模型的query为prompt拼接question,其有效取值依赖MindIE的配置,请参见《MindIE Motor开发指南》中的“MindIE Motor组件 > MindIE Server > 配置参数说明”章节中关于“maxSeqLen”的说明。注意:prompt和question的语言类型最好保持一致,或者指明大模型回答语言类型,否则会影响大模型回答效果。

PromptTemplate(
input_variables=["question"],
template="""根据问题提取关键词,不超过10个。关键词尽量切分为动词、名词、或形容词等单独的词,
不要长词组(目的是更好的匹配检索到语义相关但表述不同的相关资料)。请根据给定参考资料提取关键词,关键词之间使用逗号分隔,比如{{关键词1, 关键词2}}
Question: CANN如何安装?
Keywords: CANN, 安装, install

Question: MindStudio 容器镜像怎么制作
Keywords: MindStudio, 容器镜像, Docker build

Question: {question}
Keywords:
""")

preprocess_func

Callable[[str], List[str]]

可选

BM25检索前预处理,对大模型返回的文本串数据进行切分获取关键词列表。默认对字符串使用逗号进行切分。

调用示例

from mx_rag.document.loader import DocxLoader
from mx_rag.chain import SingleText2TextChain
from mx_rag.llm import Text2TextLLM
from mx_rag.retrievers.bm_retriever import BMRetriever
from langchain_text_splitters import RecursiveCharacterTextSplitter
from mx_rag.utils import ClientParam
loader = DocxLoader("/path/to/MindIE.docx")
docs = loader.load_and_split(RecursiveCharacterTextSplitter(chunk_size=750, chunk_overlap=150))
client_param = ClientParam(ca_file="/path/to/ca.crt")
llm = Text2TextLLM(base_url="https://ip:port/v1/chat/completions", model_name="qianwen-7b", client_param = client_param)
bm_retriever = BMRetriever(docs=docs, llm=llm, k=10)
text2text_chain = SingleText2TextChain(llm=llm, retriever=bm_retriever)
res = text2text_chain.query("怎么安装MindIE?")
print(res)