类功能
功能描述
将输入的query通过大模型提取关键词,再使用BM25进行topk检索,该类继承langchain_core.retrievers.BaseRetriever,通过调用基类的invoke方法使用检索功能,输入的query长度不超过100万。
函数原型
from mx_rag.retrievers.bm_retriever import BMRetriever # 所有参数需通过关键字参数传递 BMRetriever(docs, llm, k, llm_config, prompt, preprocess_func)
输入参数说明
参数名 |
数据类型 |
可选/必选 |
说明 |
---|---|---|---|
docs |
List[Document] |
必选 |
待检索的文档列表。 |
llm |
Text2TextLLM |
必选 |
大模型对象实例,具体类型请参见Text2TextLLM类。 |
k |
int |
可选 |
检索返回的top k,取值范围:[1,10000],默认值为“1”。 |
llm_config |
LLMParameterConfig |
可选 |
调用大模型参数,此处默认值temperature为“0.5”,top_p为“0.95”,其余参数说明请参见LLMParameterConfig类。 |
prompt |
langchain_core.prompts.PromptTemplate |
可选 |
默认值如下,其中question字符串是固定的,不能更改,表示输入的问题;prompt.input_variables必须包含question,prompt.template的长度取值范围为(0,1 * 1024 * 1024],表示提示词。实际请求大模型的query为prompt拼接question,其有效取值依赖MindIE的配置,请参见《MindIE Motor开发指南》中的“MindIE Motor组件 > MindIE Server > 配置参数说明”章节中关于“maxSeqLen”的说明。注意:prompt和question的语言类型最好保持一致,或者指明大模型回答语言类型,否则会影响大模型回答效果。 PromptTemplate( input_variables=["question"], template="""根据问题提取关键词,不超过10个。关键词尽量切分为动词、名词、或形容词等单独的词, 不要长词组(目的是更好的匹配检索到语义相关但表述不同的相关资料)。请根据给定参考资料提取关键词,关键词之间使用逗号分隔,比如{{关键词1, 关键词2}} Question: CANN如何安装? Keywords: CANN, 安装, install Question: MindStudio 容器镜像怎么制作 Keywords: MindStudio, 容器镜像, Docker build Question: {question} Keywords: """) |
preprocess_func |
Callable[[str], List[str]] |
可选 |
BM25检索前预处理,对大模型返回的文本串数据进行切分获取关键词列表。默认对字符串使用逗号进行切分。 |
调用示例
from mx_rag.document.loader import DocxLoader from mx_rag.chain import SingleText2TextChain from mx_rag.llm import Text2TextLLM from mx_rag.retrievers.bm_retriever import BMRetriever from langchain_text_splitters import RecursiveCharacterTextSplitter from mx_rag.utils import ClientParam loader = DocxLoader("/path/to/MindIE.docx") docs = loader.load_and_split(RecursiveCharacterTextSplitter(chunk_size=750, chunk_overlap=150)) client_param = ClientParam(ca_file="/path/to/ca.crt") llm = Text2TextLLM(base_url="https://ip:port/v1/chat/completions", model_name="qianwen-7b", client_param = client_param) bm_retriever = BMRetriever(docs=docs, llm=llm, k=10) text2text_chain = SingleText2TextChain(llm=llm, retriever=bm_retriever) res = text2text_chain.query("怎么安装MindIE?") print(res)