Class Introduction
Function
Starts a model locally with Transformers and provides text-to-vector embedding. The model weights must belong to a model of the BertModel class supported by Transformers. The class inherits from langchain_core.embeddings.Embeddings and exposes its API. The supported models are BAAI/bge-large-zh-v1.5 and aspire/acge_text_embedding.
If the configured model weights are not in the safetensors format, convert them to safetensors before use. This avoids the security problems of insecure weight formats such as CKPT and BIN.
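Before pointing model_path at a directory, it can help to confirm which weight formats the directory actually contains. The following is a minimal sketch using only the Python standard library. The function name check_weight_format and the list of suffixes treated as insecure are assumptions made for illustration; they are not part of mx_rag.

```python
from pathlib import Path

# Suffixes treated as insecure (pickle-based) weight files.
# This list is an assumption for illustration, not an mx_rag definition.
UNSAFE_SUFFIXES = {".bin", ".ckpt", ".pt", ".pth"}


def check_weight_format(model_dir):
    """Return (safetensors_files, unsafe_files) found directly in model_dir."""
    safe, unsafe = [], []
    for path in sorted(Path(model_dir).iterdir()):
        if not path.is_file():
            continue
        suffix = path.suffix.lower()
        if suffix == ".safetensors":
            safe.append(path.name)
        elif suffix in UNSAFE_SUFFIXES:
            unsafe.append(path.name)
    return safe, unsafe
```

If the second list is not empty, the weights should be converted first. One common route, assuming the model loads with Transformers, is to call save_pretrained with safe_serialization=True and then remove the old weight files.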
Prototype
from mx_rag.embedding.local import TextEmbedding
TextEmbedding(model_path, dev_id, use_fp16, pooling_method, lock)
Parameters
Example (Without Inference Acceleration Enabled)
from paddle.base import libpaddle
from mx_rag.embedding.local import TextEmbedding
# Same as embed = TextEmbedding("/path/to/model", 1).
embed = TextEmbedding.create(model_path="/path/to/model", dev_id=1)
print(embed.embed_documents(['abc', 'bcd']))
print(embed.embed_query('abc'))
Example (with Inference Acceleration Enabled)
import os
from paddle.base import libpaddle
import torch_npu
# Adapt to vectorized inference acceleration.
from modeling_bert_adapter import enable_bert_speed
# Enable vectorized inference acceleration (True: enabled; False: disabled).
os.environ["ENABLE_BOOST"] = "True"
from mx_rag.embedding.local import TextEmbedding
device_id = 1
torch_npu.npu.set_device(f"npu:{device_id}")
# Same as embed = TextEmbedding("/path/to/model", 1).
embed = TextEmbedding.create(model_path="/path/to/model", dev_id=device_id)
print(embed.embed_documents(['abc', 'bcd']))
print(embed.embed_query('abc'))
Multi-Thread Calling Example (for Other Embedding Models)
from paddle.base import libpaddle
import threading
from mx_rag.embedding.local import TextEmbedding
def infer(k, embed):
    # Each thread runs one query and one batch through the shared model.
    print(f"thread_{k}")
    print(embed.embed_query('abc'))
    print(embed.embed_documents(['abc', 'bcd']))

if __name__ == '__main__':
    worker_nums = 2
    threads = []
    # One model instance is shared by all threads; the lock is passed in at creation.
    embed = TextEmbedding.create(model_path='/path/to/model', dev_id=1, pooling_method='cls', lock=threading.Lock())
    for i in range(worker_nums):
        thread = threading.Thread(target=infer, args=(i, embed))
        threads.append(thread)
        thread.start()
    for thread in threads:
        thread.join()
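The example above passes a threading.Lock through the lock parameter. How TextEmbedding uses that lock internally is not described here. The sketch below only illustrates the general pattern of serializing calls to one shared model object, under the assumption that the lock guards each inference call. FakeEmbedding and run are hypothetical stand-ins, not mx_rag classes or functions.

```python
import threading


class FakeEmbedding:
    """Hypothetical stand-in for a shared model object; not part of mx_rag."""

    def __init__(self, lock):
        self._lock = lock
        self.active = 0      # calls currently inside the guarded section
        self.max_active = 0  # highest concurrency observed

    def embed_query(self, text):
        # Only one thread at a time may run the "inference" below.
        with self._lock:
            self.active += 1
            self.max_active = max(self.max_active, self.active)
            vector = [float(len(text))]  # placeholder result, not a real embedding
            self.active -= 1
            return vector


def run(workers=4, calls=50):
    """Call the shared object from several threads and report peak concurrency."""
    embed = FakeEmbedding(lock=threading.Lock())

    def worker():
        for _ in range(calls):
            embed.embed_query("abc")

    threads = [threading.Thread(target=worker) for _ in range(workers)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return embed.max_active
```

Because every call takes the lock first, the peak concurrency reported by run() stays at 1 no matter how many worker threads are started.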