类功能
功能描述
本地使用transformers启动模型,提供文本至向量的embedding功能。需要使用transformers支持的BertModel类模型权重。类继承实现了langchain_core.embeddings.Embeddings接口。当前支持的模型:BAAI/bge-large-zh-v1.5,aspire/acge_text_embedding。
函数原型
from mx_rag.embedding.local import TextEmbedding TextEmbedding(model_path, dev_id, use_fp16, pooling_method, lock)
输入参数说明
参数名 |
数据类型 |
可选/必选 |
说明 |
|---|---|---|---|
model_path |
str |
必选 |
模型权重文件目录,路径长度不能超过1024,不能为软链接和相对路径。
|
dev_id |
int |
可选 |
模型运行NPU ID,通过npu-smi info查询可用ID,取值范围[0, 63],默认为卡0。 |
use_fp16 |
bool |
可选 |
是否将模型转换为半精度,默认为“True”。 |
pooling_method |
str |
可选 |
选择处理last_hidden_state的方式,取值范围['cls', 'mean', 'max', 'lasttoken'],默认'cls'。 |
lock |
multiprocessing.synchronize.Lock或_thread.LockType |
可选 |
local model不支持多线程或者多进程进行处理,如果用户需要多进程或者多线程调用此接口需要申请锁。默认值为None。 可选值:
|
不启用推理加速调用示例
from paddle.base import libpaddle
from mx_rag.embedding.local import TextEmbedding
# 同embed = TextEmbedding("/path/to/model", 1)
embed = TextEmbedding.create(model_path="/path/to/model", dev_id=1)
print(embed.embed_documents(['abc', 'bcd']))
print(embed.embed_query('abc'))
开启推理加速调用示例
import os from paddle.base import libpaddle import torch_npu # 适配向量化推理加速 from modeling_bert_adapter import enable_bert_speed # 使能向量化推理加速(设置为"True"时表示使能,"False"表示不使能) os.environ["ENABLE_BOOST"] = "True" from mx_rag.embedding.local import TextEmbedding device_id = 1 torch_npu.npu.set_device(f"npu:{device_id}") # 同embed = TextEmbedding("/path/to/model", 1) embed = TextEmbedding.create(model_path="/path/to/model", dev_id=device_id ) print(embed.embed_documents(['abc', 'bcd'])) print(embed.embed_query('abc'))
多线程调用示例(其余嵌入模型均可参考该示例)
from paddle.base import libpaddle
import threading
from mx_rag.embedding.local import TextEmbedding
def infer(k, embed):
print(f"thread_{k}")
print(embed.embed_query('abc'))
print(embed.embed_documents(['abc', 'bcd']))
if __name__ == '__main__':
worker_nums=2
threads = []
embed = TextEmbedding.create(model_path='/path/to/model', dev_id=1, pooling_method='cls', lock=threading.Lock())
for i in range(worker_nums):
thread = threading.Thread(target=infer, args=(i,embed,))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
父主题: TextEmbedding