昇腾社区首页
中文
注册

类功能

功能描述

本地使用transformers启动模型,提供文本相关性计算功能。继承抽象类Reranker。当前支持的模型:bge-reranker-large、bge-reranker-base。

配置的模型如果不是safetensors权重格式,请先将模型权重转换为safetensors格式后再使用,防止使用ckpt、bin等不安全的模型权重格式引入安全问题。

函数原型

from mx_rag.reranker.local import LocalReranker
LocalReranker(model_path, dev_id, k, use_fp16)

输入参数说明

参数名

数据类型

可选/必选

说明

model_path

str

必选

模型权重文件目录,路径长度不能超过1024,不能为软链接和相对路径。

  • 目录下的各文件大小不能超过10GB、深度不超过64,且文件总个数不超过512。
  • 运行用户的属组,以及非运行用户不能有该目录下文件的写权限。
  • 目录下的文件以及文件的上一级目录的属组必须是运行用户。

dev_id

int

可选

模型运行在哪张卡上,取值范围:[0, 63],默认值为“0”

k

int

可选

精排后返回最相关的k个结果,取值范围:[1, 10000],默认值为“1”

use_fp16

bool

可选

是否将模型转换为半精度,默认值为“True”

返回值说明

LocalReranker对象。

调用示例

from paddle.base import libpaddle
from langchain_core.documents import Document
from mx_rag.reranker.local import LocalReranker
# 同LocalReranker(model_path="path to model", dev_id=0)
doc_1 = Document(
                page_content="我是小红",
                metadata={"source": ""}
            )
doc_2 = Document(
                page_content="我是小明",
                metadata={"source": ""}
            )
docs = [doc_1, doc_2]
rerank = LocalReranker.create(model_path="path to model", dev_id=0)
scores = rerank.rerank('你好', [doc.page_content for doc in docs])
res = rerank.rerank_top_k(docs, scores)
print(res)