Class Introduction

Function

Loads a model locally through Transformers and provides text-to-vector embedding. The model weights must belong to a BertModel-class model supported by Transformers. The class inherits from langchain_core.embeddings.Embeddings. The supported models are BAAI/bge-large-zh-v1.5 and aspire/acge_text_embedding.

If the configured model weights are not in the safetensors format, convert them to safetensors before use. This avoids the security risks of insecure weight formats such as CKPT and BIN.
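As a pre-flight check for the requirement above, the following minimal sketch lists files in a model directory whose extension suggests a non-safetensors weight format. The function name find_insecure_weights and the extension set are assumptions made for illustration; they are not part of mx_rag.

```python
from pathlib import Path

# Extensions commonly used for pickle-based weight files. This set is an
# assumption for the sketch; adjust it to match your own model directory.
INSECURE_SUFFIXES = {".bin", ".ckpt", ".pt", ".pth"}


def find_insecure_weights(model_dir):
    """Return sorted relative paths of files that look like non-safetensors weights."""
    root = Path(model_dir)
    return sorted(
        str(p.relative_to(root))
        for p in root.rglob("*")
        if p.is_file() and p.suffix.lower() in INSECURE_SUFFIXES
    )
```

An empty result means no file with one of the listed extensions was found; any file that is reported should be converted to safetensors before the directory is passed as model_path.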

Prototype

from mx_rag.embedding.local import TextEmbedding
TextEmbedding(model_path, dev_id, use_fp16, pooling_method, lock)

Parameters

Parameter

Data Type

Required/Optional

Description

model_path

String

Required

Model weight file directory. The path length cannot exceed 1024 characters. The path cannot be a soft link or a relative path.

  • Each file in the directory cannot exceed 10 GB, the directory depth cannot exceed 64 levels, and the total number of files cannot exceed 512.
  • Neither the running user's group nor other users can have write permission on the files in the directory.
  • All files within the directory, as well as the parent directory itself, must have their group ownership set to the running user.

    The path cannot be within any of the following directories: ["/etc", "/usr/bin", "/usr/lib", "/usr/lib64", "/sys/", "/dev/", "/sbin", "/tmp"].
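The model_path constraints above can be checked before constructing the class. The sketch below is one possible pre-check, not the validation that mx_rag performs internally: the function name check_model_path is hypothetical, the restricted-path rule is interpreted as "equal to or located under a listed directory", and the permission and ownership rules are left out.

```python
import os

MAX_PATH_LEN = 1024
MAX_FILE_SIZE = 10 * 1024 ** 3  # 10 GB per file
MAX_DEPTH = 64
MAX_FILES = 512
FORBIDDEN = ["/etc", "/usr/bin", "/usr/lib", "/usr/lib64",
             "/sys/", "/dev/", "/sbin", "/tmp"]


def check_model_path(path, forbidden=FORBIDDEN):
    """Return a list of violations; an empty list means every check passed."""
    problems = []
    if len(path) > MAX_PATH_LEN:
        problems.append("path is longer than 1024 characters")
    if not os.path.isabs(path):
        problems.append("path is relative")
    if os.path.islink(path):
        problems.append("path is a soft link")

    # Compare both the normalized and the fully resolved form of the path.
    candidates = {os.path.abspath(path), os.path.realpath(path)}
    for prefix in forbidden:
        base = prefix.rstrip("/")
        if any(c == base or c.startswith(base + "/") for c in candidates):
            problems.append(f"path is under {prefix}")

    if not os.path.isdir(path):
        problems.append("path is not a directory")
        return problems

    file_count = 0
    for root, _dirs, files in os.walk(path):
        rel = os.path.relpath(root, path)
        depth = 0 if rel == "." else rel.count(os.sep) + 1
        if depth > MAX_DEPTH:
            problems.append("directory nesting is deeper than 64 levels")
            break
        for name in files:
            file_count += 1
            # lstat does not follow links, so a broken link cannot raise here.
            if os.lstat(os.path.join(root, name)).st_size > MAX_FILE_SIZE:
                problems.append(f"{name} is larger than 10 GB")
    if file_count > MAX_FILES:
        problems.append("directory holds more than 512 files")
    return problems
```

Calling check_model_path("/path/to/model") returns the list of violated rules, which can be logged or raised before the model is loaded.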

dev_id

Integer

Optional

ID of the NPU on which the model runs. Run npu-smi info to query the available IDs. The value range is [0, 63]. The default value is 0.

use_fp16

Bool

Optional

Whether to use FP16. The default value is True.

pooling_method

String

Optional

Pooling method applied to last_hidden_state to produce the embedding vector. Valid values are 'cls', 'mean', 'max', and 'lasttoken'. The default value is 'cls'.
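To make the four pooling_method values concrete, the sketch below shows what these names conventionally mean when a sequence of token vectors is reduced to one vector. It is written in plain Python for a single sequence and assumes right padding with the [CLS] token at position 0; the function pool is hypothetical, and the actual implementation inside TextEmbedding may differ.

```python
def pool(last_hidden_state, attention_mask, method="cls"):
    """Reduce one sequence of token vectors to a single vector.

    last_hidden_state: list of token vectors (seq_len x hidden_size)
    attention_mask:    list of 0/1 flags, 1 for real tokens, 0 for padding
    """
    # Keep only the real (non-padding) tokens, preserving their order.
    tokens = [vec for vec, keep in zip(last_hidden_state, attention_mask) if keep]
    if method == "cls":
        # Vector of the first token.
        return list(last_hidden_state[0])
    if method == "lasttoken":
        # Vector of the last non-padding token.
        return list(tokens[-1])
    if method == "mean":
        # Element-wise average over non-padding tokens.
        return [sum(col) / len(tokens) for col in zip(*tokens)]
    if method == "max":
        # Element-wise maximum over non-padding tokens.
        return [max(col) for col in zip(*tokens)]
    raise ValueError(f"unsupported pooling method: {method}")
```

For example, with three real tokens [1.0, 2.0], [3.0, 8.0], [5.0, 2.0] followed by one padding token, 'mean' gives [3.0, 4.0] and 'max' gives [5.0, 8.0].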

lock

multiprocessing.synchronize.Lock or _thread.LockType

Optional

The local model cannot safely be called from multiple threads or processes at the same time. To call this API from multiple threads or processes, pass a lock. The default value is None.

Value options:

  • None: No lock is used. In this case, this API does not support concurrency.
  • multiprocessing.Lock(): process lock. In this case, this API supports multi-process calling.
  • threading.Lock(): thread lock. In this case, this API supports multi-thread calling.
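The effect of the lock options can be illustrated without an NPU. In the sketch below, StubEmbedding is a hypothetical stand-in for a model that must not be entered concurrently; it is not the mx_rag class, and how TextEmbedding uses the lock internally is not described in this section. The stub only shows that a shared threading.Lock() lets several threads call the same object while the calls themselves run one at a time.

```python
import threading
import time


class StubEmbedding:
    """Hypothetical stand-in for a model that is not safe to call concurrently."""

    def __init__(self, lock=None):
        self.lock = lock
        self.active = 0      # calls currently inside _infer
        self.max_active = 0  # highest concurrency observed

    def _infer(self, text):
        self.active += 1
        self.max_active = max(self.max_active, self.active)
        time.sleep(0.01)  # simulate inference time
        self.active -= 1
        return [float(len(text))]

    def embed_query(self, text):
        if self.lock is None:
            return self._infer(text)  # no protection against concurrent calls
        with self.lock:               # only one caller runs inference at a time
            return self._infer(text)


def run_concurrently(embed, workers=4):
    """Call embed_query from several threads and report the peak concurrency."""
    threads = [threading.Thread(target=embed.embed_query, args=("abc",))
               for _ in range(workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return embed.max_active
```

With lock=threading.Lock() the peak concurrency reported by run_concurrently is 1; with lock=None several calls can overlap, which is the situation the lock parameter is meant to prevent.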

Example (Without Inference Acceleration Enabled)

from paddle.base import libpaddle
from mx_rag.embedding.local import TextEmbedding
# Same as embed = TextEmbedding("/path/to/model", 1).
embed = TextEmbedding.create(model_path="/path/to/model", dev_id=1)
print(embed.embed_documents(['abc', 'bcd']))
print(embed.embed_query('abc'))

Example (with Inference Acceleration Enabled)

import os
from paddle.base import libpaddle
import torch_npu
# Adapt to vectorized inference acceleration.
from modeling_bert_adapter import enable_bert_speed
# Enable vectorized inference acceleration (True: enabled; False: disabled).
os.environ["ENABLE_BOOST"] = "True"
from mx_rag.embedding.local import TextEmbedding
device_id = 1
torch_npu.npu.set_device(f"npu:{device_id}")
# Same as embed = TextEmbedding("/path/to/model", 1).
embed = TextEmbedding.create(model_path="/path/to/model", dev_id=device_id)
print(embed.embed_documents(['abc', 'bcd']))
print(embed.embed_query('abc'))

Multi-Thread Calling Example (for Other Embedding Models)

from paddle.base import libpaddle
import threading
from mx_rag.embedding.local import TextEmbedding
def infer(k, embed):
    print(f"thread_{k}")
    print(embed.embed_query('abc'))
    print(embed.embed_documents(['abc', 'bcd']))

if __name__ == '__main__':
    worker_nums = 2
    threads = []
    embed = TextEmbedding.create(model_path='/path/to/model', dev_id=1, pooling_method='cls', lock=threading.Lock())
    for i in range(worker_nums):
        thread = threading.Thread(target=infer, args=(i, embed))
        threads.append(thread)
        thread.start()
    for thread in threads:
        thread.join()