昇腾社区首页
中文
注册

TEI框架接入MindIE Torch与ATB后端全量适配代码参考

适配昇腾环境及MindIE Torch与ATB后端的TEI代码,其目录结构与TEI的GitHub代码仓完全一致。下方目录树仅列出需要修改的代码文件(原页面中以加粗显示),未列出的文件或文件夹均与TEI的GitHub代码仓保持一致。目录结构如下所示。

text-embeddings-inference
|____core
|  |____src
|  |  |____ infer.rs
|____backends
|  |____grpc-client
|  |  |____src
|  |  |  |____ client.rs
|  |____proto
|  |  |____ embed.proto
|  |____src
|  |  |____ dtype.rs
|  |____python
|  |  |____src
|  |  |  |____ lib.rs
|  |  |____server
|  |  |  |____pyproject.toml
|  |  |  |____requirements.txt
|  |  |  |____text_embeddings_server
|  |  |  |  |_____utils
|  |  |  |  |  |____ env.py
|  |  |  |  |____ server.py
|  |  |  |  |_____models
|  |  |  |  |  |____ __init__.py
|  |  |  |  |  |____ default_model.py
|  |  |  |  |  |____ model.py
|  |  |  |  |  |____ rerank_model.py
|  |  |  |  |  |____ types.py
  • 基于TEI v1.2.3版本,请在text-embeddings-inference/core/src/infer.rs文件的Infer::new方法中将batching_task的数量从2个改为1个,即注释掉源码中创建第二个batching_task的代码(见下方代码中被注释的部分);仅使用一个batching_task时性能更佳。

    该修改为可选项:不做修改不会影响TEI的运行及服务化功能,但修改后性能会有明显提升。

    use crate::queue::{Entry, Metadata, NextBatch, Queue};
    use crate::tokenization::{EncodingInput, RawEncoding, Tokenization};
    use crate::TextEmbeddingsError;
    use std::sync::Arc;
    use std::time::{Duration, Instant};
    use text_embeddings_backend::{Backend, BackendError, Embedding, ModelType};
    use tokio::sync::{mpsc, oneshot, watch, Notify, OwnedSemaphorePermit, Semaphore};
    use tracing::instrument;
    
    /// Entry point for tokenization, embedding and classification requests.
    ///
    /// Requests are tokenized, appended to `queue` and picked up by the background batching
    /// task spawned in `Infer::new`. Clones share the batching-task notifier and the admission
    /// semaphore, both of which are held behind `Arc`.
    #[derive(Debug, Clone)]
    pub struct Infer {
        /// Tokenizer front end used by `tokenize`, `decode`, `embed` and `predict`.
        tokenization: Tokenization,
        /// Pending requests; drained by the batching task.
        queue: Queue,
        /// Shared notifier that wakes the batching task after a request is enqueued.
        notify_batching_task: Arc<Notify>,
        /// Admission limit on concurrent requests (sized by `max_concurrent_requests`).
        limit_concurrent_requests: Arc<Semaphore>,
        /// Model backend; also queried for the model type and health.
        backend: Backend,
    }
    
    impl Infer {
        #[allow(clippy::too_many_arguments)]
        pub fn new(
            tokenization: Tokenization,
            queue: Queue,
            max_concurrent_requests: usize,
            backend: Backend,
        ) -> Self {
            let notify_batching_task = Arc::new(Notify::new());
    
            let (embed_sender, embed_receiver) = mpsc::unbounded_channel();
    
            // Create only one batching task to prefetch batches
            tokio::spawn(batching_task(
                queue.clone(),
                notify_batching_task.clone(),
                embed_sender.clone(),
            ));
            /* 此处源码使用两个batching_task对请求进行队列管理,将此处注释后仅用一个batching_task性能更佳
            tokio::spawn(batching_task(
                queue.clone(),
                notify_batching_task.clone(),
                embed_sender,
            ));
            */
    
            // Create embed task to communicate with backend
            tokio::spawn(backend_task(backend.clone(), embed_receiver));
    
            // Inference limit with a semaphore
            let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
    
            Self {
                tokenization,
                queue,
                notify_batching_task,
                limit_concurrent_requests: semaphore,
                backend,
            }
        }
    
        #[instrument(skip(self))]
        pub async fn tokenize<I: Into<EncodingInput> + std::fmt::Debug>(
            &self,
            inputs: I,
            add_special_tokens: bool,
        ) -> Result<RawEncoding, TextEmbeddingsError> {
            self.tokenization
                .tokenize(inputs.into(), add_special_tokens)
                .await
                .map_err(|err| {
                    metrics::increment_counter!("te_request_failure", "err" => "tokenization");
                    tracing::error!("{err}");
                    err
                })
        }
    
        #[instrument(skip(self))]
        pub async fn decode(
            &self,
            ids: Vec<u32>,
            skip_special_tokens: bool,
        ) -> Result<String, TextEmbeddingsError> {
            self.tokenization
                .decode(ids, skip_special_tokens)
                .await
                .map_err(|err| {
                    metrics::increment_counter!("te_request_failure", "err" => "tokenization");
                    tracing::error!("{err}");
                    err
                })
        }
    
        #[instrument(skip(self))]
        pub fn try_acquire_permit(&self) -> Result<OwnedSemaphorePermit, TextEmbeddingsError> {
            // Limit concurrent requests by acquiring a permit from the semaphore
            self.clone()
                .limit_concurrent_requests
                .try_acquire_owned()
                .map_err(|err| {
                    metrics::increment_counter!("te_request_failure", "err" => "overloaded");
                    tracing::error!("{err}");
                    TextEmbeddingsError::from(err)
                })
        }
        #[instrument(skip(self))]
        pub async fn acquire_permit(&self) -> OwnedSemaphorePermit {
            // Limit concurrent requests by acquiring a permit from the semaphore
            self.clone()
                .limit_concurrent_requests
                .acquire_owned()
                .await
                .expect("Semaphore has been closed. This is a bug.")
        }
    
        #[instrument(skip(self, permit))]
        pub async fn embed_all<I: Into<EncodingInput> + std::fmt::Debug>(
            &self,
            inputs: I,
            truncate: bool,
            permit: OwnedSemaphorePermit,
        ) -> Result<AllEmbeddingsInferResponse, TextEmbeddingsError> {
            let start_time = Instant::now();
    
            if self.is_splade() {
                metrics::increment_counter!("te_request_failure", "err" => "model_type");
                let message = "`embed_all` is not available for SPLADE models".to_string();
                tracing::error!("{message}");
                return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                    message,
                )));
            }
    
            let results = self
                .embed(inputs, truncate, false, &start_time, permit)
                .await?;
    
            let InferResult::AllEmbedding(response) = results else {
                panic!("unexpected enum variant")
            };
    
            // Timings
            let total_time = start_time.elapsed();
    
            // Metrics
            metrics::increment_counter!("te_embed_success");
            metrics::histogram!("te_embed_duration", total_time.as_secs_f64());
            metrics::histogram!(
                "te_embed_tokenization_duration",
                response.metadata.tokenization.as_secs_f64()
            );
            metrics::histogram!(
                "te_embed_queue_duration",
                response.metadata.queue.as_secs_f64()
            );
            metrics::histogram!(
                "te_embed_inference_duration",
                response.metadata.inference.as_secs_f64()
            );
    
            Ok(response)
        }
    
        #[instrument(skip(self, permit))]
        pub async fn embed_sparse<I: Into<EncodingInput> + std::fmt::Debug>(
            &self,
            inputs: I,
            truncate: bool,
            permit: OwnedSemaphorePermit,
        ) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
            let start_time = Instant::now();
    
            if !self.is_splade() {
                metrics::increment_counter!("te_request_failure", "err" => "model_type");
                let message = "Model is not an embedding model with SPLADE pooling".to_string();
                tracing::error!("{message}");
                return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                    message,
                )));
            }
    
            let results = self
                .embed(inputs, truncate, true, &start_time, permit)
                .await?;
    
            let InferResult::PooledEmbedding(response) = results else {
                panic!("unexpected enum variant")
            };
    
            // Timings
            let total_time = start_time.elapsed();
    
            // Metrics
            metrics::increment_counter!("te_embed_success");
            metrics::histogram!("te_embed_duration", total_time.as_secs_f64());
            metrics::histogram!(
                "te_embed_tokenization_duration",
                response.metadata.tokenization.as_secs_f64()
            );
            metrics::histogram!(
                "te_embed_queue_duration",
                response.metadata.queue.as_secs_f64()
            );
            metrics::histogram!(
                "te_embed_inference_duration",
                response.metadata.inference.as_secs_f64()
            );
    
            Ok(response)
        }
    
        #[instrument(skip(self, permit))]
        pub async fn embed_pooled<I: Into<EncodingInput> + std::fmt::Debug>(
            &self,
            inputs: I,
            truncate: bool,
            normalize: bool,
            permit: OwnedSemaphorePermit,
        ) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
            let start_time = Instant::now();
    
            if self.is_splade() && normalize {
                metrics::increment_counter!("te_request_failure", "err" => "model_type");
                let message = "`normalize` is not available for SPLADE models".to_string();
                tracing::error!("{message}");
                return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                    message,
                )));
            }
    
            let results = self
                .embed(inputs, truncate, true, &start_time, permit)
                .await?;
    
            let InferResult::PooledEmbedding(mut response) = results else {
                panic!("unexpected enum variant")
            };
    
            if normalize {
                // Normalize embedding
                let scale = (1.0
                    / response
                        .results
                        .iter()
                        .map(|v| {
                            let v = *v as f64;
                            v * v
                        })
                        .sum::<f64>()
                        .sqrt()) as f32;
                for v in response.results.iter_mut() {
                    *v *= scale;
                }
            }
    
            // Timings
            let total_time = start_time.elapsed();
    
            // Metrics
            metrics::increment_counter!("te_embed_success");
            metrics::histogram!("te_embed_duration", total_time.as_secs_f64());
            metrics::histogram!(
                "te_embed_tokenization_duration",
                response.metadata.tokenization.as_secs_f64()
            );
            metrics::histogram!(
                "te_embed_queue_duration",
                response.metadata.queue.as_secs_f64()
            );
            metrics::histogram!(
                "te_embed_inference_duration",
                response.metadata.inference.as_secs_f64()
            );
    
            Ok(response)
        }
    
        async fn embed<I: Into<EncodingInput> + std::fmt::Debug>(
            &self,
            inputs: I,
            truncate: bool,
            pooling: bool,
            start_time: &Instant,
            _permit: OwnedSemaphorePermit,
        ) -> Result<InferResult, TextEmbeddingsError> {
            if self.is_classifier() {
                metrics::increment_counter!("te_request_failure", "err" => "model_type");
                let message = "Model is not an embedding model".to_string();
                tracing::error!("{message}");
                return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                    message,
                )));
            }
    
            metrics::increment_counter!("te_embed_count");
    
            // Tokenization
            let encoding = self
                .tokenization
                .encode(inputs.into(), truncate)
                .await
                .map_err(|err| {
                    metrics::increment_counter!("te_request_failure", "err" => "tokenization");
                    tracing::error!("{err}");
                    err
                })?;
    
            // MPSC channel to communicate with the background batching task
            let (response_tx, response_rx) = oneshot::channel();
    
            // Append the request to the queue
            self.queue.append(Entry {
                metadata: Metadata {
                    response_tx,
                    tokenization: start_time.elapsed(),
                    queue_time: Instant::now(),
                    prompt_tokens: encoding.input_ids.len(),
                    pooling,
                },
                encoding,
            });
    
            self.notify_batching_task.notify_one();
    
            let response = response_rx
                .await
                .expect(
                    "Infer batching task dropped the sender without sending a response. This is a bug.",
                )
                .map_err(|err| {
                    metrics::increment_counter!("te_request_failure", "err" => "inference");
                    tracing::error!("{err}");
                    err
                })?;
    
            Ok(response)
        }
    
        #[instrument(skip(self, _permit))]
        pub async fn predict<I: Into<EncodingInput> + std::fmt::Debug>(
            &self,
            inputs: I,
            truncate: bool,
            raw_scores: bool,
            _permit: OwnedSemaphorePermit,
        ) -> Result<ClassificationInferResponse, TextEmbeddingsError> {
            if !self.is_classifier() {
                metrics::increment_counter!("te_request_failure", "err" => "model_type");
                let message = "Model is not a classifier model".to_string();
                return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                    message,
                )));
            }
    
            let start_time = Instant::now();
            metrics::increment_counter!("te_predict_count");
    
            // Tokenization
            let encoding = self
                .tokenization
                .encode(inputs.into(), truncate)
                .await
                .map_err(|err| {
                    metrics::increment_counter!("te_request_failure", "err" => "tokenization");
                    tracing::error!("{err}");
                    err
                })?;
    
            // MPSC channel to communicate with the background batching task
            let (response_tx, response_rx) = oneshot::channel();
    
            // Append the request to the queue
            self.queue.append(Entry {
                metadata: Metadata {
                    response_tx,
                    tokenization: start_time.elapsed(),
                    queue_time: Instant::now(),
                    prompt_tokens: encoding.input_ids.len(),
                    pooling: true,
                },
                encoding,
            });
    
            self.notify_batching_task.notify_one();
    
            let response = response_rx
                .await
                .expect(
                    "Infer batching task dropped the sender without sending a response. This is a bug.",
                )
                .map_err(|err| {
                    metrics::increment_counter!("te_request_failure", "err" => "inference");
                    tracing::error!("{err}");
                    err
                })?;
    
            let InferResult::Classification(mut response) = response else {
                panic!("unexpected enum variant")
            };
    
            if !raw_scores {
                // Softmax
                if response.results.len() > 1 {
                    let max = *response
                        .results
                        .iter()
                        .max_by(|x, y| x.abs().partial_cmp(&y.abs()).unwrap())
                        .unwrap();
    
                    let mut den = 0.0;
                    for v in response.results.iter_mut() {
                        *v = (*v - max).exp();
                        den += *v;
                    }
                    for v in response.results.iter_mut() {
                        *v /= den;
                    }
                }
                // Sigmoid
                else {
                    response.results[0] = 1.0 / (1.0 + (-response.results[0]).exp());
                }
            }
    
            // Timings
            let total_time = start_time.elapsed();
    
            // Metrics
            metrics::increment_counter!("te_predict_success");
            metrics::histogram!("te_predict_duration", total_time.as_secs_f64());
            metrics::histogram!(
                "te_predict_tokenization_duration",
                response.metadata.tokenization.as_secs_f64()
            );
            metrics::histogram!(
                "te_predict_queue_duration",
                response.metadata.queue.as_secs_f64()
            );
            metrics::histogram!(
                "te_predict_inference_duration",
                response.metadata.inference.as_secs_f64()
            );
    
            Ok(response)
        }
    
        #[instrument(skip(self))]
        pub fn is_classifier(&self) -> bool {
            matches!(self.backend.model_type, ModelType::Classifier)
        }
    
        #[instrument(skip(self))]
        pub fn is_splade(&self) -> bool {
            matches!(
                self.backend.model_type,
                ModelType::Embedding(text_embeddings_backend::Pool::Splade)
            )
        }
    
        #[instrument(skip(self))]
        pub async fn health(&self) -> bool {
            self.backend.health().await.is_ok()
        }
    
        #[instrument(skip(self))]
        pub fn health_watcher(&self) -> watch::Receiver<bool> {
            self.backend.health_watcher()
        }
    }
    
    /// Pulls batches off the shared queue and hands them to the backend task one at a time.
    ///
    /// Sleeps until `notify` fires, then drains `queue` batch by batch. After sending a batch,
    /// it waits on the paired oneshot receiver until the backend task lets go of the sender.
    /// As a result, at most one batch from this task is in flight at any moment.
    ///
    /// # Panics
    /// Panics if the backend task's receiver has been dropped.
    #[instrument(skip_all)]
    async fn batching_task(
        queue: Queue,
        notify: Arc<Notify>,
        embed_sender: mpsc::UnboundedSender<(NextBatch, oneshot::Sender<()>)>,
    ) {
        loop {
            notify.notified().await;

            // Drain everything currently queued before going back to sleep.
            loop {
                let Some(batch) = queue.next_batch().await else {
                    break;
                };
                let (done_tx, done_rx) = oneshot::channel();
                embed_sender
                    .send((batch, done_tx))
                    .expect("embed receiver was dropped. This is a bug.");
                // Only the hand-back matters, not whether a value or a drop completed it.
                done_rx.await.ok();
            }
        }
    }
    
    /// Runs the batches produced by `batching_task` on the backend.
    ///
    /// For each received `(batch, callback)` pair, the batch inputs (`batch.1`) go to
    /// `backend.predict` or `backend.embed`, depending on the model type. Per-request responses
    /// are delivered from a freshly spawned OS thread, so this loop can move on immediately.
    /// `callback` is dropped at the end of each iteration, which releases the batching task
    /// waiting on the matching receiver.
    ///
    /// The queue time is computed with `saturating_sub`. If the reported inference duration
    /// exceeds the wall time since enqueue, this yields zero instead of panicking on `Duration`
    /// underflow.
    #[instrument(skip_all)]
    async fn backend_task(
        backend: Backend,
        mut embed_receiver: mpsc::UnboundedReceiver<(NextBatch, oneshot::Sender<()>)>,
    ) {
        while let Some((batch, _callback)) = embed_receiver.recv().await {
            match &backend.model_type {
                ModelType::Classifier => {
                    let results = backend.predict(batch.1).await;

                    // Handle sending responses in another thread to avoid starving the backend
                    std::thread::spawn(move || match results {
                        Ok((mut predictions, inference_duration)) => {
                            batch.0.into_iter().enumerate().for_each(|(i, m)| {
                                let infer_metadata = InferMetadata {
                                    prompt_tokens: m.prompt_tokens,
                                    tokenization: m.tokenization,
                                    // Plain `-` on Duration panics on underflow.
                                    queue: m
                                        .queue_time
                                        .elapsed()
                                        .saturating_sub(inference_duration),
                                    inference: inference_duration,
                                };

                                let _ = m.response_tx.send(Ok(InferResult::Classification(
                                    ClassificationInferResponse {
                                        results: predictions.remove(&i).expect(
                                            "prediction not found in results. This is a backend bug.",
                                        ),
                                        metadata: infer_metadata,
                                    },
                                )));
                            });
                        }
                        Err(err) => {
                            batch.0.into_iter().for_each(|m| {
                                let _ = m.response_tx.send(Err(err.clone()));
                            });
                        }
                    });
                }
                ModelType::Embedding(_) => {
                    let results = backend.embed(batch.1).await;

                    // Handle sending responses in another thread to avoid starving the backend
                    std::thread::spawn(move || match results {
                        Ok((mut embeddings, inference_duration)) => {
                            batch.0.into_iter().enumerate().for_each(|(i, m)| {
                                let metadata = InferMetadata {
                                    prompt_tokens: m.prompt_tokens,
                                    tokenization: m.tokenization,
                                    // Plain `-` on Duration panics on underflow.
                                    queue: m
                                        .queue_time
                                        .elapsed()
                                        .saturating_sub(inference_duration),
                                    inference: inference_duration,
                                };

                                let results = match embeddings
                                    .remove(&i)
                                    .expect("embedding not found in results. This is a backend bug.")
                                {
                                    Embedding::Pooled(e) => {
                                        InferResult::PooledEmbedding(PooledEmbeddingsInferResponse {
                                            results: e,
                                            metadata,
                                        })
                                    }
                                    Embedding::All(e) => {
                                        InferResult::AllEmbedding(AllEmbeddingsInferResponse {
                                            results: e,
                                            metadata,
                                        })
                                    }
                                };

                                let _ = m.response_tx.send(Ok(results));
                            })
                        }
                        Err(err) => {
                            batch.0.into_iter().for_each(|m| {
                                let _ = m.response_tx.send(Err(err.clone()));
                            });
                        }
                    });
                }
            };
        }
    }
    
    /// Per-request size and timing information attached to every inference response.
    #[derive(Debug)]
    pub struct InferMetadata {
        /// Number of input ids produced by tokenization.
        pub prompt_tokens: usize,
        /// Time from the start of the request until it was appended to the queue.
        pub tokenization: Duration,
        /// Time since enqueue minus the inference time, i.e. the wait before inference.
        pub queue: Duration,
        /// Inference duration reported for the batch this request belonged to.
        pub inference: Duration,
    }
    
    /// Backend output for a single request, sent back over the request's oneshot channel.
    #[derive(Debug)]
    pub(crate) enum InferResult {
        /// Produced for `ModelType::Classifier` models.
        Classification(ClassificationInferResponse),
        /// Produced when the backend returns `Embedding::Pooled`.
        PooledEmbedding(PooledEmbeddingsInferResponse),
        /// Produced when the backend returns `Embedding::All`.
        AllEmbedding(AllEmbeddingsInferResponse),
    }
    
    /// Result of `Infer::predict`.
    #[derive(Debug)]
    pub struct ClassificationInferResponse {
        /// Raw logits when `raw_scores` is set; otherwise softmax (several logits) or
        /// sigmoid (one logit) scores.
        pub results: Vec<f32>,
        /// Size and timing information for the request.
        pub metadata: InferMetadata,
    }
    
    /// Result of `Infer::embed_pooled` and `Infer::embed_sparse`.
    #[derive(Debug)]
    pub struct PooledEmbeddingsInferResponse {
        /// A single pooled embedding; L2-normalised when `embed_pooled` is called with
        /// `normalize = true`.
        pub results: Vec<f32>,
        /// Size and timing information for the request.
        pub metadata: InferMetadata,
    }
    
    /// Result of `Infer::embed_all` (built from the backend's `Embedding::All`).
    #[derive(Debug)]
    pub struct AllEmbeddingsInferResponse {
        /// Un-pooled embeddings; presumably one inner vector per token (TODO confirm against
        /// the backend).
        pub results: Vec<Vec<f32>>,
        /// Size and timing information for the request.
        pub metadata: InferMetadata,
    }
    
  • 基于TEI v1.2.3版本,请在text-embeddings-inference/backends/grpc-client/src/client.rs文件中为后端的gRPC客户端(Client)添加embed_all和predict两个异步接口(即下方代码中的embed_all与predict方法)。
    /// Single shard Client
    use crate::pb::embedding::v1::embedding_service_client::EmbeddingServiceClient;
    use crate::pb::embedding::v1::*;
    use crate::Result;
    use grpc_metadata::InjectTelemetryContext;
    use tonic::transport::{Channel, Uri};
    use tracing::instrument;
    
    /// gRPC client for the Python backend's `EmbeddingService`.
    #[derive(Debug, Clone)]
    pub struct Client {
        // tonic stub (from `crate::pb`) bound to one transport channel.
        stub: EmbeddingServiceClient<Channel>,
    }
    
    impl Client {
        /// Returns a client connected to the given url
        pub async fn connect(uri: Uri) -> Result<Self> {
            let channel = Channel::builder(uri).connect().await?;
    
            Ok(Self {
                stub: EmbeddingServiceClient::new(channel),
            })
        }
    
        /// Returns a client connected to the given unix socket
        pub async fn connect_uds(path: String) -> Result<Self> {
            let channel = Channel::from_shared("http://[::]:50051".to_string())
                .unwrap()
                .connect_with_connector(tower::service_fn(move |_: Uri| {
                    tokio::net::UnixStream::connect(path.clone())
                }))
                .await?;
    
            Ok(Self {
                stub: EmbeddingServiceClient::new(channel),
            })
        }
    
        /// Get backend health
        #[instrument(skip(self))]
        pub async fn health(&mut self) -> Result<HealthResponse> {
            let request = tonic::Request::new(HealthRequest {}).inject_context();
            let response = self.stub.health(request).await?.into_inner();
            Ok(response)
        }
    
        #[instrument(skip_all)]
        pub async fn embed(
            &mut self,
            input_ids: Vec<u32>,
            token_type_ids: Vec<u32>,
            position_ids: Vec<u32>,
            cu_seq_lengths: Vec<u32>,
            max_length: u32,
        ) -> Result<Vec<Embedding>> {
            let request = tonic::Request::new(EmbedRequest {
                input_ids,
                token_type_ids,
                position_ids,
                max_length,
                cu_seq_lengths,
            })
            .inject_context();
            let response = self.stub.embed(request).await?.into_inner();
            Ok(response.embeddings)
        }
    
        #[instrument(skip_all)]
        pub async fn embed_all(
            &mut self,
            input_ids: Vec<u32>,
            token_type_ids: Vec<u32>,
            position_ids: Vec<u32>,
            cu_seq_lengths: Vec<u32>,
            max_length: u32,
        ) -> Result<Vec<TokenEmbedding>> {
            let request = tonic::Request::new(EmbedRequest {
                input_ids,
                token_type_ids,
                position_ids,
                max_length,
                cu_seq_lengths,
            })
            .inject_context();
            let response = self.stub.embed_all(request).await?.into_inner();
            Ok(response.allembeddings)
        }
    
        #[instrument(skip_all)]
        pub async fn predict(
            &mut self,
            input_ids: Vec<u32>,
            token_type_ids: Vec<u32>,
            position_ids: Vec<u32>,
            cu_seq_lengths: Vec<u32>,
            max_length: u32,
        ) -> Result<Vec<Prediction>> {
            let request = tonic::Request::new(PredictRequest {
                input_ids,
                token_type_ids,
                position_ids,
                max_length,
                cu_seq_lengths,
            })
            .inject_context();
            let response = self.stub.predict(request).await?.into_inner();
            Ok(response.predictions)
        }
    }
  • 基于TEI v1.2.3版本,请在text-embeddings-inference/backends/proto/embed.proto文件中为EmbeddingService添加Embed_all和Predict两个RPC(见下方代码),并定义相应的请求与响应消息类型(PredictRequest、Prediction、PredictResponse、TokenEmbedding、RawEmbedResponse)。
    syntax = "proto3";
    
    package embedding.v1;
    
    service EmbeddingService {
        /// Runs the embedding model on a flattened batch; returns `Embedding` messages.
        rpc Embed (EmbedRequest) returns (EmbedResponse);
        /// Same input as Embed; returns un-pooled (per-token) hidden states grouped as
        /// `TokenEmbedding` messages.
        /// NOTE(review): the method name is part of the gRPC path, so renaming it to the
        /// conventional `EmbedAll` would require the Python server to change as well.
        rpc Embed_all (EmbedRequest) returns (RawEmbedResponse);
        /// Runs a classification model on a flattened batch; returns `Prediction` score vectors.
        rpc Predict (PredictRequest) returns (PredictResponse);
        /// Health check
        rpc Health (HealthRequest) returns (HealthResponse);
    }
    
    /// Empty request for the Health RPC.
    message HealthRequest {}
    /// Empty response; the Rust side treats any successful Health call as healthy.
    message HealthResponse {}
    
    /// Flattened batch for the Predict RPC (same fields as `EmbedRequest`).
    message PredictRequest {
        repeated uint32 input_ids = 1;
        repeated uint32 token_type_ids = 2;
        repeated uint32 position_ids = 3;
        /// Cumulative sequence lengths; presumably the offsets of each sequence inside
        /// `input_ids` (TODO confirm against the batching code).
        repeated uint32 cu_seq_lengths = 4;
        /// Length of the longest request
        uint32 max_length = 5;
    }
    
    /// One score vector returned by the Predict RPC.
    message Prediction {
        repeated float values = 1;
    }
    
    /// Response of the Predict RPC.
    message PredictResponse {
        repeated Prediction predictions = 1;
    }
    
    /// Flattened batch shared by the Embed and Embed_all RPCs.
    message EmbedRequest {
        repeated uint32 input_ids = 1;
        repeated uint32 token_type_ids = 2;
        repeated uint32 position_ids = 3;
        /// Cumulative sequence lengths; presumably the offsets of each sequence inside
        /// `input_ids` (TODO confirm against the batching code).
        repeated uint32 cu_seq_lengths = 4;
        /// Length of the longest request
        uint32 max_length = 5;
    }
    
    
    /// A single embedding vector.
    message Embedding {
        repeated float values = 1;
    }
    
    /// Response of the Embed RPC.
    message EmbedResponse {
        repeated Embedding embeddings = 1;
    }
    
    /// Un-pooled embeddings of one request, presumably one `Embedding` per token
    /// (TODO confirm against the Python server).
    message TokenEmbedding {
        repeated Embedding embeddings = 1;
    }
    
    /// Response of the Embed_all RPC; the Rust client reads the `allembeddings` field.
    message RawEmbedResponse {
        repeated TokenEmbedding allembeddings = 1;
    }
  • 基于TEI v1.2.3版本,请在text-embeddings-inference/backends/src/dtype.rs文件中为DType枚举增加Bfloat16变体(仅在启用python特性时可用),以支持更广泛的模型数据类型。
    use std::fmt;
    
    #[cfg(feature = "clap")]
    use clap::ValueEnum;
    
    /// Numeric data type requested for the model.
    ///
    /// Which variants exist depends on the enabled cargo features. With the `clap` feature, the
    /// type can also be parsed as a CLI value (`ValueEnum`).
    #[derive(Debug, PartialEq)]
    #[cfg_attr(feature = "clap", derive(Clone, ValueEnum))]
    pub enum DType {
        // Float16 is not available on accelerate
        #[cfg(any(
            feature = "python",
            all(feature = "candle", not(feature = "accelerate"))
        ))]
        Float16,
        // Bfloat16 is available only when the "python" feature is enabled
        #[cfg(feature = "python")]
        Bfloat16,
        // Float32 is not available on candle cuda
        // NOTE(review): the cfg below does not exclude a CUDA build of candle; confirm whether
        // this comment is still accurate.
        #[cfg(any(feature = "python", feature = "candle"))]
        Float32,
        // The Q6K variant is currently commented out.
        // #[cfg(feature = "candle")]
        // Q6K,
    }
    
    impl fmt::Display for DType {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            match self {
                // Float16 is not available on accelerate
                #[cfg(any(
                    feature = "python",
                    all(feature = "candle", not(feature = "accelerate"))
                ))]
                DType::Float16 => write!(f, "float16"),
    
                // Bfloat16 is only available with the "python" feature
                #[cfg(feature = "python")]
                DType::Bfloat16 => write!(f, "bfloat16"),
    
                // Float32 is not available on candle cuda
                #[cfg(any(feature = "python", feature = "candle"))]
                DType::Float32 => write!(f, "float32"),
    
                // #[cfg(feature = "candle")]
                // DType::Q6K => write!(f, "q6k"),
            }
        }
    }
    
  • 基于TEI v1.2.3版本,请在text-embeddings-inference/backends/python/src/lib.rs文件中移除模型类型校验处对分类模型的限制,使分类模型能够正常拉起(见下方代码);同时为后端模型增加predict接口,并修改embed接口,添加判断分支:根据batch的成员内容,决定返回所有token的hidden_states,还是池化后的全局hidden_states。
    mod logging;
    mod management;
    
    use backend_grpc_client::Client;
    use nohash_hasher::BuildNoHashHasher;
    use std::collections::HashMap;
    use text_embeddings_backend_core::{
        Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
    };
    use tokio::runtime::Runtime;
    
    /// Backend that delegates inference to a separate backend process over gRPC.
    pub struct PythonBackend {
        // Handle returned by `management::BackendProcess::new`; never read, only kept alive
        // for as long as the backend (presumably it cleans up the process on drop — confirm).
        _backend_process: management::BackendProcess,
        // Runtime used to `block_on` the async gRPC client from synchronous methods.
        tokio_runtime: Runtime,
        // gRPC client connected over the Unix domain socket at `uds_path`.
        backend_client: Client,
    }
    
    impl PythonBackend {
        /// Starts the backend process and connects a gRPC client to it.
        ///
        /// Classifier models are accepted as-is. Embedding models are accepted only with
        /// [`Pool::Cls`] pooling, and that check runs before the backend process is started.
        ///
        /// # Errors
        /// - [`BackendError::Start`] if an embedding model uses pooling other than CLS.
        /// - Errors from `management::BackendProcess::new` are propagated with `?`.
        /// - [`BackendError::Start`] if the Tokio runtime cannot be built or the client cannot
        ///   connect to `uds_path`.
        pub fn new(
            model_path: String,
            dtype: String,
            model_type: ModelType,
            uds_path: String,
            otlp_endpoint: Option<String>,
        ) -> Result<Self, BackendError> {
            // Only embedding pooling needs validating; classifier models pass straight through.
            // (Previously a `match` built a `Pool` value here that was immediately discarded.)
            if let ModelType::Embedding(pool) = model_type {
                if pool != Pool::Cls {
                    return Err(BackendError::Start(format!("{pool:?} is not supported")));
                }
            }

            let backend_process =
                management::BackendProcess::new(model_path, dtype, &uds_path, otlp_endpoint)?;
            // Current-thread runtime used to `block_on` the async gRPC client calls.
            let tokio_runtime = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .map_err(|err| BackendError::Start(format!("Could not start Tokio runtime: {err}")))?;

            let backend_client = tokio_runtime
                .block_on(Client::connect_uds(uds_path))
                .map_err(|err| {
                    BackendError::Start(format!("Could not connect to backend process: {err}"))
                })?;

            Ok(Self {
                _backend_process: backend_process,
                tokio_runtime,
                backend_client,
            })
        }
    }
    
    impl Backend for PythonBackend {
        /// Reports whether the Python backend process answers its gRPC health check.
        ///
        /// # Errors
        ///
        /// Returns [`BackendError::Unhealthy`] if the health RPC fails.
        fn health(&self) -> Result<(), BackendError> {
            if self
                .tokio_runtime
                .block_on(self.backend_client.clone().health())
                .is_err()
            {
                return Err(BackendError::Unhealthy);
            }
            Ok(())
        }

        /// Batches are sent unpadded: flattened token ids delimited by `cumulative_seq_lengths`.
        fn is_padded(&self) -> bool {
            false
        }

        /// Computes embeddings for the sequences of `batch`.
        ///
        /// Sequences listed in `batch.pooled_indices` get an [`Embedding::Pooled`] vector from the
        /// `embed` RPC. Sequences listed in `batch.raw_indices` get one hidden state per token
        /// ([`Embedding::All`]) from the `embed_all` RPC. A batch may mix both kinds, in which
        /// case both RPCs are issued on the whole batch. The returned map is keyed by each
        /// sequence's position in the batch.
        ///
        /// # Errors
        ///
        /// Returns [`BackendError::Inference`] if an RPC fails or if the server returns no result
        /// for a requested sequence.
        fn embed(&self, mut batch: Batch) -> Result<Embeddings, BackendError> {
            let batch_size = batch.len();
            let mut embeddings =
                HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());

            let wants_pooled = !batch.pooled_indices.is_empty();
            let wants_raw = !batch.raw_indices.is_empty();

            if wants_pooled {
                // Each RPC consumes its inputs; copy them only if the raw pass below needs them too.
                let results = self
                    .tokio_runtime
                    .block_on(self.backend_client.clone().embed(
                        take_or_clone(&mut batch.input_ids, wants_raw),
                        take_or_clone(&mut batch.token_type_ids, wants_raw),
                        take_or_clone(&mut batch.position_ids, wants_raw),
                        take_or_clone(&mut batch.cumulative_seq_lengths, wants_raw),
                        batch.max_length,
                    ))
                    .map_err(|err| BackendError::Inference(err.to_string()))?;

                // The server answers with one vector per sequence, in batch order; only the
                // sequences that asked for pooling are kept, keyed by their batch position.
                let mut pooled: Vec<Option<Vec<f32>>> =
                    results.into_iter().map(|r| Some(r.values)).collect();
                for &index in &batch.pooled_indices {
                    let i = index as usize;
                    let values = pooled
                        .get_mut(i)
                        .and_then(Option::take)
                        .ok_or_else(|| missing_result("pooled embedding", i))?;
                    embeddings.insert(i, Embedding::Pooled(values));
                }
            }

            if wants_raw {
                // Real token count of each sequence, used to drop any padded positions the server
                // may return.
                let seq_lengths: Vec<usize> = batch
                    .cumulative_seq_lengths
                    .windows(2)
                    .map(|bounds| (bounds[1] - bounds[0]) as usize)
                    .collect();

                let results = self
                    .tokio_runtime
                    .block_on(self.backend_client.clone().embed_all(
                        batch.input_ids,
                        batch.token_type_ids,
                        batch.position_ids,
                        batch.cumulative_seq_lengths,
                        batch.max_length,
                    ))
                    .map_err(|err| BackendError::Inference(err.to_string()))?;

                let mut raw: Vec<Option<Vec<Vec<f32>>>> = results
                    .into_iter()
                    .map(|sequence| {
                        Some(sequence.embeddings.into_iter().map(|token| token.values).collect())
                    })
                    .collect();
                for &index in &batch.raw_indices {
                    let i = index as usize;
                    let mut tokens = raw
                        .get_mut(i)
                        .and_then(Option::take)
                        .ok_or_else(|| missing_result("token embeddings", i))?;
                    if let Some(&length) = seq_lengths.get(i) {
                        tokens.truncate(length);
                    }
                    embeddings.insert(i, Embedding::All(tokens));
                }
            }

            Ok(embeddings)
        }

        /// Runs the `predict` RPC and returns one score vector per sequence, keyed by batch
        /// position.
        ///
        /// # Errors
        ///
        /// Returns [`BackendError::Inference`] if the RPC fails.
        fn predict(&self, batch: Batch) -> Result<Predictions, BackendError> {
            let batch_size = batch.len();

            let results = self
                .tokio_runtime
                .block_on(self.backend_client.clone().predict(
                    batch.input_ids,
                    batch.token_type_ids,
                    batch.position_ids,
                    batch.cumulative_seq_lengths,
                    batch.max_length,
                ))
                .map_err(|err| BackendError::Inference(err.to_string()))?;

            let mut predictions =
                HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
            predictions.extend(results.into_iter().map(|r| r.values).enumerate());

            Ok(predictions)
        }
    }

    /// Moves the contents out of `values`, or copies them when `keep` is true because a later
    /// request still needs the original.
    fn take_or_clone<T: Clone>(values: &mut Vec<T>, keep: bool) -> Vec<T> {
        if keep {
            values.clone()
        } else {
            std::mem::take(values)
        }
    }

    /// Builds the error reported when the Python server returns no result for a sequence.
    fn missing_result(kind: &str, index: usize) -> BackendError {
        BackendError::Inference(format!(
            "Python backend returned no {kind} for batch index {index}"
        ))
    }
  • 基于TEI v1.2.3版本,请在text-embeddings-inference/backends/python/server/requirements.txt文件中修改“huggingface-hub”“safetensors”“torch”“poetry”的版本信息(以下代码加粗部分),用于适配昇腾推理环境。
    backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
    certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
    charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
    click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
    colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
    deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
    filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
    fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13"
    googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
    grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
    grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
    grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
    grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
    huggingface-hub==0.23.2 ; python_version >= "3.9" and python_version < "3.13"
    idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
    jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
    loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
    markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
    mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
    networkx==3.1 ; python_version >= "3.9" and python_version < "3.13"
    opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
    opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
    opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
    opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
    opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
    opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
    opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
    opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
    opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
    packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
    protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
    pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
    requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
    safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
    setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
    sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
    torch==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
    tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
    typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
    typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
    urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
    win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
    wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
    poetry==1.8.3 ; python_version >= "3.9" and python_version < "3.13"
  • 基于TEI v1.2.3版本,请根据上一步的变更内容,同步修改text-embeddings-inference/backends/python/server/pyproject.toml文件中“safetensors”“torch”版本信息并使用国内镜像源(以下代码加粗部分),用于修改"text-embeddings-server"项目依赖关系并提高Python三方库的下载速度。
    [tool.poetry]
    name = "text-embeddings-server"
    version = "0.1.0"
    description = "Text Embeddings Python gRPC Server"
    authors = ["Olivier Dehaene <olivier@huggingface.co>"]
    [tool.poetry.scripts]
    python-text-embeddings-server = 'text_embeddings_server.cli:app'
    [tool.poetry.dependencies]
    python = ">=3.9,<3.13"
    protobuf = "^4.21.7"
    grpcio = "^1.51.1"
    grpcio-status = "^1.51.1"
    grpcio-reflection = "^1.51.1"
    grpc-interceptor = "^0.15.0"
    typer = "^0.6.1"
    safetensors = "^0.4.3"
    loguru = "^0.6.0"
    opentelemetry-api = "^1.15.0"
    opentelemetry-exporter-otlp = "^1.15.0"
    opentelemetry-instrumentation-grpc = "^0.36b0"
    torch = { version = "^2.1.0" }
    [tool.poetry.extras]
    [tool.poetry.group.dev.dependencies]
    grpcio-tools = "^1.51.1"
    pytest = "^7.3.0"
    [[tool.poetry.source]]
    name = "mirrors"
    url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
    priority = "default"
    [tool.pytest.ini_options]
    markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
    [build-system]
    requires = ["poetry-core>=1.0.0"]
    build-backend = "poetry.core.masonry.api"

    TEI的Python侧参考适配代码调整了包导入顺序,并增加了Google风格的类与方法注释。为凸显重点,此部分修改并未在参考代码内容中加粗显示。

  • 为使TEI框架能够兼容不同后端(MindIE Torch与ATB)且灵活选择运行的NPU设备id,请在text-embeddings-inference/backends/python/server/text_embeddings_server/utils文件夹下新增env.py文件,通过读取环境变量确认TEI拉起模型的backend(默认使用mindietorch后端进行组图优化)与device_id(默认使用卡0加载模型),并设置必要的校验。
    # Copyright Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
    
    import os
    from dataclasses import dataclass, field
    from loguru import logger
    
    
    @dataclass
    class EnvVar:
        """
        Reads environment variables to configure the embedding or reranker model backend and device ID for use in
        other components of TEI-MindIE.

        The variables are read when an instance is created (not once at class-definition time), so a new
        `EnvVar()` reflects the current environment.

        Attributes:
            backend: The backend used for inference, from `TEI_NPU_BACKEND` ('atb' or 'mindietorch'; default
                'mindietorch').
            device_id: The NPU device id to load the model on, from `TEI_NPU_DEVICE` (default '0'). It must be a
                non-negative decimal integer string; ids of 8 or more are accepted but logged with a warning.

        Raises:
            ValueError: If the backend does not belong to ['atb', 'mindietorch'] or the device_id is not a
                non-negative decimal integer string.
        """

        backend: str = field(default_factory=lambda: os.getenv("TEI_NPU_BACKEND", 'mindietorch'))
        device_id: str = field(default_factory=lambda: os.getenv("TEI_NPU_DEVICE", '0'))

        def __post_init__(self):
            logger.info(self.dict())
            if self.backend not in ['atb', 'mindietorch']:
                raise ValueError("Your model backend is invalid.")
            # isdecimal() only accepts characters int() can parse (isdigit() also accepts e.g. '²', which int()
            # rejects). A leading '-' fails this check too, so no separate negativity test is needed.
            if not self.device_id.isdecimal():
                raise ValueError("Your device_id is invalid.")
            if int(self.device_id) >= 8:
                logger.warning(f'Your device_id is {self.device_id}.')

        def dict(self):
            """Returns the instance attributes as a dict (the instance `__dict__` itself, not a copy)."""
            return self.__dict__

    # Module-level configuration shared by the rest of the server; validated once at import time.
    ENV = EnvVar()
  • 基于TEI v1.2.3版本,根据embed.proto中定义的格式,请在text-embeddings-inference/backends/python/server/text_embeddings_server/server.py文件中的EmbeddingService类中修改健康检查接口并添加Embed_all以及Predict的异步接口,将请求内容转换为对应的batch(PaddedBatch)实例,将batch传入到模型的对应预测接口,将其输出值转换为对应服务的响应类实例(以下代码加粗部分)。
    import asyncio
    from pathlib import Path
    from typing import Optional
    
    from grpc import aio
    from loguru import logger
    from grpc_reflection.v1alpha import reflection
    import torch
    
    from text_embeddings_server.models import Model, get_model
    from text_embeddings_server.pb import embed_pb2_grpc, embed_pb2
    from text_embeddings_server.utils.tracing import UDSOpenTelemetryAioServerInterceptor
    from text_embeddings_server.utils.interceptor import ExceptionInterceptor
    from .utils.env import ENV
    if ENV.backend == 'atb':
        import torch_npu
    else:
        import mindietorch
    
    
    class EmbeddingService(embed_pb2_grpc.EmbeddingServiceServicer):
        """gRPC servicer that exposes the text embeddings model.

        Implements the `EmbeddingServiceServicer` interface: `Health`, `Embed`, `Embed_all` and `Predict`. Each
        RPC converts the request into the model's batch type on the model's device, runs the model synchronously,
        and wraps the result in the matching response message.

        Attributes:
            model (Model): The model used for generating embeddings and similarity scores.
            _inference_mode_raii_guard (torch._C._InferenceMode): Keeps torch inference mode enabled for as long
                as the service object is alive.
        """

        def __init__(self, model: Model):
            """Stores the model and enables torch inference mode.

            Args:
                model (Model): The model to use for embedding or reranker service.
            """
            self.model = model
            # Force inference mode for the lifetime of EmbeddingService
            self._inference_mode_raii_guard = torch._C._InferenceMode(True)

        def _to_batch(self, request):
            """Converts a gRPC request into the model's batch type, placed on the model's device."""
            return self.model.batch_type.from_pb(request, self.model.device)

        async def Health(self, request, context):
            """Handles the health check request.

            On an NPU device, copies a small tensor to the model's device (presumably raising if the device is
            unusable). Other device types are reported healthy without a check.

            Args:
                request: The gRPC health check request.
                context: The gRPC context for the request.

            Returns:
                embed_pb2.HealthResponse: An empty response signalling the service is healthy.
            """
            device = self.model.device
            if device.type != "npu":
                return embed_pb2.HealthResponse()
            torch.zeros((2, 2)).to(device)
            return embed_pb2.HealthResponse()

        async def Embed(self, request, context):
            """Handles the embed request: one pooled embedding per sequence.

            Which token's hidden state is returned for each sequence is decided by `self.model.embed`.

            Args:
                request: The gRPC request containing a batch of data for embedding.
                context: The gRPC context for the request.

            Returns:
                embed_pb2.EmbedResponse: The pooled embeddings for the request.
            """
            pooled = self.model.embed(self._to_batch(request))
            return embed_pb2.EmbedResponse(embeddings=pooled)

        async def Embed_all(self, request, context):
            """Handles the embed_all request: per-token hidden states for each sequence.

            Args:
                request: The gRPC request containing a batch of data for embedding.
                context: The gRPC context for the request.

            Returns:
                embed_pb2.RawEmbedResponse: The token embeddings for the request.
            """
            per_token = self.model.embed_all(self._to_batch(request))
            return embed_pb2.RawEmbedResponse(allembeddings=per_token)

        async def Predict(self, request, context):
            """Handles the predict request: similarity scores produced by the model.

            Args:
                request: The gRPC request containing a batch of data for scores prediction.
                context: The gRPC context for the request.

            Returns:
                embed_pb2.PredictResponse: The model's predictions for the request.
            """
            scores = self.model.predict(self._to_batch(request))
            return embed_pb2.PredictResponse(predictions=scores)
    
    
    def serve(
        model_path: Path,
        dtype: Optional[str],
        uds_path: Path,
    ):
        """Runs the text embeddings gRPC server until it terminates.

        Loads the model, registers `EmbeddingService` and gRPC server reflection on an asyncio gRPC server bound
        to a Unix Domain Socket, and blocks until the server stops.

        Args:
            model_path: Path to the model directory.
            dtype: Name of the data type to load the model with (passed to `get_model`).
            uds_path: Path of the Unix Domain Socket to listen on.
        """

        async def serve_inner(
            model_path: Path,
            dtype: Optional[str] = None,
        ):
            """Builds, starts and awaits the asyncio gRPC server.

            Args:
                model_path: Path to the model directory.
                dtype: Name of the data type to load the model with.

            Raises:
                Exception: Re-raised (after logging) if the model cannot be initialized.
            """
            address = f"unix://{uds_path}"

            try:
                model = get_model(model_path, dtype)
            except Exception:
                logger.exception("Error when initializing model")
                raise

            interceptors = [
                ExceptionInterceptor(),
                UDSOpenTelemetryAioServerInterceptor(),
            ]
            grpc_server = aio.server(interceptors=interceptors)
            embed_pb2_grpc.add_EmbeddingServiceServicer_to_server(EmbeddingService(model), grpc_server)

            embedding_service_name = embed_pb2.DESCRIPTOR.services_by_name["EmbeddingService"].full_name
            reflection.enable_server_reflection((embedding_service_name, reflection.SERVICE_NAME), grpc_server)
            grpc_server.add_insecure_port(address)

            await grpc_server.start()
            logger.info(f"Server started at {address}")

            try:
                await grpc_server.wait_for_termination()
            except KeyboardInterrupt:
                logger.info("Signal received. Shutting down")
                await grpc_server.stop(0)

        asyncio.run(serve_inner(model_path, dtype))
  • 基于TEI v1.2.3版本,请在text-embeddings-inference/backends/python/server/text_embeddings_server/models/__init__.py文件中的get_model方法内添加逻辑判断分支(以下代码加粗部分);当硬件平台为npu的情况下,根据环境变量TEI_NPU_BACKEND设置模型后端为MindIE Torch或ATB,根据环境变量TEI_NPU_DEVICE设置torch.device;根据模型的config.json中architectures列表的内容,判定调用default_model(文本嵌入模型类)还是rerank_model (重排序模型类)。
    from pathlib import Path
    from typing import Optional
    
    from loguru import logger
    from transformers import AutoConfig
    import torch
    
    from text_embeddings_server.models.model import Model
    from text_embeddings_server.models.default_model import DefaultModel
    from text_embeddings_server.models.rerank_model import RerankModel
    from ..utils.env import ENV
    if ENV.backend == 'atb':
        import torch_npu
    else:
        import mindietorch
    
    __all__ = ["Model"]

    # Disable gradients
    torch.set_grad_enabled(False)

    # Flash Attention models are optional; fall back gracefully when they cannot be imported.
    FLASH_ATTENTION = True
    try:
        from text_embeddings_server.models.flash_bert import FlashBert
    except ImportError as e:
        logger.warning(f"Could not import Flash Attention enabled models: {e}")
        FLASH_ATTENTION = False

    if FLASH_ATTENTION:
        # __all__ must contain names (strings); appending the class object itself would make
        # `from text_embeddings_server.models import *` raise TypeError.
        __all__.append("FlashBert")
    
    
    def get_model(model_path: Path, dtype: Optional[str]):
        """Loads and returns the appropriate model class based on given path, data type and backend type.

        Args:
            model_path: The path to the pre-trained model.
            dtype: The name of the model's data type: "float32", "float16" or "bfloat16".

        Returns:
            Model: An instance of a subclass of the `Model` class, which could either be:
                - RerankModel (if the architectures[0] in config.json ends with 'Classification')
                - FlashBert (if the model is a BERT model with absolute position embeddings, the device is CUDA,
                  the dtype is float16/bfloat16 and Flash Attention models could be imported)
                - DefaultModel (if all conditions above are not met)

        Raises:
            RuntimeError: If an unknown data type is provided for `dtype`.
            ValueError: If the device is CPU and the dtype is not `float32`.
        """
        dtypes = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
        if dtype not in dtypes:
            raise RuntimeError(f"Unknown dtype {dtype}")
        dtype = dtypes[dtype]

        # ENV.device_id is validated to be a non-empty decimal string (default '0'), so the NPU branch is the one
        # taken in practice; the CUDA/CPU branches are kept from upstream TEI.
        if ENV.device_id:
            device_index = int(ENV.device_id)
            device = torch.device(f"npu:{device_index}")
            if ENV.backend == 'atb':
                torch.npu.set_compile_mode(jit_compile=False)
                torch.npu.set_option({"NPU_FUZZY_COMPILE_BLACKLIST": "ReduceProd"})
                torch.npu.set_device(device)
            else:
                mindietorch.set_device(device_index)
        elif torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            if dtype != torch.float32:
                raise ValueError("CPU device only supports float32 dtype")
            device = torch.device("cpu")

        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        # `architectures` may be missing (None) in config.json.
        architectures = config.architectures or []
        if architectures and architectures[0].endswith("Classification"):
            return RerankModel(model_path, device, dtype)
        if (
            config.model_type == "bert"
            and device.type == "cuda"
            and config.position_embedding_type == "absolute"
            and dtype in [torch.float16, torch.bfloat16]
            and FLASH_ATTENTION
        ):
            return FlashBert(model_path, device, dtype)
        return DefaultModel(model_path, device, dtype)
  • 基于TEI v1.2.3版本,请在text-embeddings-inference/backends/python/server/text_embeddings_server/models/default_model.py文件中根据昇腾模型后端修改DefaultModel类的初始化方法和embed接口,并增加embed_all接口(以下代码加粗部分),返回所有token的hidden_states。
    from pathlib import Path
    from typing import Optional, Type, List
    
    from opentelemetry import trace
    from transformers import AutoModel, AutoConfig
    from loguru import logger
    import torch
    
    from text_embeddings_server.models import Model
    from text_embeddings_server.models.types import PaddedBatch, Embedding, Prediction, TokenEmbedding
    from ..utils.env import ENV
    
    if ENV.backend == 'atb':
        import torch_npu
    else:
        import mindietorch
    
    tracer = trace.get_tracer(__name__)
    
    
    class DefaultModel(Model):
        """Wrapper around a pre-trained embedding model that generates embeddings for batches of input texts.

        The model is loaded with `transformers.AutoModel` when `ENV.backend` is 'atb', and as a TorchScript module
        (compiled for MindIE Torch) otherwise.

        Attributes:
            config: The model configuration loaded from the given pre-trained model path.
            hidden_size: `config.hidden_size`, the size of the model's hidden states.
            is_causal: True when `config.architectures[0]` ends with "CausalLM". `embed` pools causal models on the
                last non-padding token of each sequence and other models on the first token.
        """

        def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
            """
            Loads the pre-trained model for the configured backend and records the settings used for pooling.

            Args:
                model_path: Path to the pre-trained model.
                device: The device to load the model on (e.g., NPU).
                dtype: The data type of the model's parameters (e.g., torch.float16). Used for loading with the
                    ATB backend and forwarded to `Model.__init__`.

            Raises:
                FileNotFoundError: With the MindIE Torch backend, if no `*.pt` file exists under `model_path`.
            """
            self.config = AutoConfig.from_pretrained(str(model_path), trust_remote_code=True)
            self.hidden_size = self.config.hidden_size
            # `architectures` may be missing (None) in config.json.
            architectures = self.config.architectures or []
            self.is_causal = bool(architectures) and architectures[0].endswith('CausalLM')

            if ENV.backend == 'atb':
                model = AutoModel.from_pretrained(
                    model_path,
                    torch_dtype=dtype,
                    trust_remote_code=True
                ).to(device).eval()
            else:
                mindietorch.set_device(device.index)
                # With the MindIE Torch backend, model_path should contain exactly one compiled *.pt file (searched
                # recursively, first match is loaded); otherwise adjust the glob pattern or the file naming.
                pt_file = next(Path(model_path).rglob("*.pt"), None)
                if pt_file is None:
                    raise FileNotFoundError(f"No compiled *.pt file found under {model_path}")
                model = torch.jit.load(pt_file).eval().to(device)
            super(DefaultModel, self).__init__(model=model, dtype=dtype, device=device)

        @property
        def batch_type(self) -> Type[PaddedBatch]:
            """Returns the class type of the batch that the embedding model expects.

            Returns:
                Type[PaddedBatch]: The batch class expected by the embedding model. All tensors in this batch are
                padded to the longest input sequence.
            """
            return PaddedBatch

        def _last_hidden_state(self, batch: PaddedBatch) -> torch.Tensor:
            """Runs the model on `batch` and returns its first output (the last hidden states).

            The result is assumed to have shape (batch, seq_len, hidden_size), as the pooling in `embed` and
            `embed_all` expects. With the ATB backend it stays on the model's device; with MindIE Torch it is
            moved to the CPU.
            """
            input_ids = batch.input_ids.to(self.device)
            attention_mask = batch.attention_mask.to(self.device)
            if ENV.backend == 'atb':
                return self.model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)[0]
            output = self.model(input_ids, attention_mask)
            if isinstance(output, dict):
                return output['last_hidden_state'].to('cpu')
            return output[0].to('cpu')

        @tracer.start_as_current_span("embed")
        def embed(self, batch: PaddedBatch) -> List[Embedding]:
            """
            Generates one embedding per input sequence: the hidden state of the first token, or for causal models
            the hidden state of the last non-padding token.

            Args:
                batch: The input batch containing tokenized data (e.g., input_ids, attention_mask).

            Returns:
                List[Embedding]: One embedding per input sequence in the batch.
            """
            hidden_states = self._last_hidden_state(batch)
            if self.is_causal:
                # Sequences in a padded batch have different real lengths, so position -1 can be a padding token.
                # Pool on the last position whose attention mask is set; for an unpadded sequence this equals -1.
                mask = batch.attention_mask.cpu()
                positions = torch.arange(mask.size(1))
                last_index = (mask * positions).float().argmax(dim=1)
                rows = torch.arange(mask.size(0))
                pooled = hidden_states[rows.to(hidden_states.device), last_index.to(hidden_states.device)]
            else:
                pooled = hidden_states[:, 0]
            return [Embedding(values=values) for values in pooled[: len(batch)].tolist()]

        @tracer.start_as_current_span("embed_all")
        def embed_all(self, batch: PaddedBatch) -> List[TokenEmbedding]:
            """
            Generates the hidden states of all real (non-padding) tokens for each input sequence.

            Args:
                batch: The input batch containing tokenized data (e.g., input_ids, attention_mask).

            Returns:
                List[TokenEmbedding]: One entry per input sequence, holding one embedding per real token.
            """
            hidden_states = self._last_hidden_state(batch).cpu()
            token_mask = batch.attention_mask.cpu().bool()
            results = []
            for i in range(len(batch)):
                # Boolean indexing keeps sequence i's real tokens in order and drops padding positions.
                tokens = hidden_states[i][token_mask[i]].tolist()
                results.append(TokenEmbedding(embeddings=[Embedding(values=token) for token in tokens]))
            return results

        @tracer.start_as_current_span("predict")
        def predict(self, batch: PaddedBatch) -> Optional[List[Prediction]]:
            """Logs an error indicating that the embedding model does not support prediction, and returns None."""
            logger.error("embedding model does not support predict function")
            return None
  • 基于TEI v1.2.3版本,请在text-embeddings-inference/backends/python/server/text_embeddings_server/models/model.py文件中的模型父类Model中添加embed_all和predict抽象接口定义(以下代码加粗部分),默认抛出NotImplementedError;此外,对基于ATB后端拉起的模型,根据服务器SoC信息对已加载的模型权重进行数据格式转换。
    from abc import ABC, abstractmethod
    from typing import List, TypeVar, Type
    
    import torch
    from loguru import logger
    
    from text_embeddings_server.models.types import Batch, Embedding, Prediction, TokenEmbedding
    from ..utils.env import ENV
    if ENV.backend == 'atb':
        import torch_npu
        from atb_llm.utils.initial import NPUSocInfo
    
    B = TypeVar("B", bound=Batch)
    
    
    class Model(ABC):
        """Abstract base class for all model types in TEI-MindIE.
    
        This class defines the common interface and properties for all models used in TEI-MindIE, including 
        methods for embedding generation and similarity score prediction.
    
        Attributes:
            model: The model object used for embedding generation or similarity score prediction.
            device: The device to load the model on (e.g., NPU).
            dtype: The data type of the model's parameters (e.g., torch.float16).
        """
    
        def __init__(
            self,
            model,
            dtype: torch.dtype,
            device: torch.device,
        ):
            """
            Stores the model, dtype and device. For the ATB backend, also converts weight formats to suit the NPU.

            When `ENV.backend` is 'atb', the weights of every `torch.nn.Linear` module are cast with
            `torch_npu.npu_format_cast` to format 29 (NZ) if `NPUSocInfo().need_nz` is set, or to format 2 (ND)
            otherwise. The weights of every `torch.nn.Embedding` module are then cast to format 2.

            Args:
                model: The model object used for embedding generation or similarity score prediction.
                dtype: The data type of the model's parameters (e.g., torch.float16).
                device: The device the model is loaded on (e.g., NPU).
            """
            self.model = model
            self.dtype = dtype
            self.device = device
            if ENV.backend != 'atb':
                return

            self.soc_info = NPUSocInfo()
            use_nz = self.soc_info.need_nz
            linear_format = 29 if use_nz else 2

            for name, module in self.model.named_modules():
                if not isinstance(module, torch.nn.Linear):
                    continue
                if use_nz and name == 'lm_head':
                    # eliminate TransData op before lm_head calculation
                    module.weight.data = torch.nn.parameter.Parameter(module.weight.data)
                module.weight.data = torch_npu.npu_format_cast(module.weight.data, linear_format)

            if use_nz:
                logger.info(f"soc info: {self.soc_info.soc_version}, support NZ")
            else:
                logger.info(f"soc info: {self.soc_info.soc_version}, support ND")

            embedding_modules = (m for _, m in self.model.named_modules() if isinstance(m, torch.nn.Embedding))
            for module in embedding_modules:
                module.weight.data = torch_npu.npu_format_cast(module.weight.data, 2)
    
        @property
        @abstractmethod
        def batch_type(self) -> Type[B]:
            raise NotImplementedError
    
        @abstractmethod
        def embed(self, batch: B) -> List[Embedding]:
            raise NotImplementedError
    
        @abstractmethod
        def embed_all(self, batch: B) -> List[TokenEmbedding]:
            raise NotImplementedError
    
        @abstractmethod
        def predict(self, batch: B) -> List[Prediction]:
            raise NotImplementedError
  • 基于TEI v1.2.3版本,请自行创建text-embeddings-inference/backends/python/server/text_embeddings_server/models/rerank_model.py文件,在其中实现重排序模型类RerankModel,通过其predict接口为rerank请求返回query与各text的匹配分数。
    from pathlib import Path
    from typing import Optional, Type, List
    
    from opentelemetry import trace
    from transformers import AutoModelForSequenceClassification
    from loguru import logger
    import torch
    
    from text_embeddings_server.models import Model
    from text_embeddings_server.models.types import PaddedBatch, Embedding, Prediction, TokenEmbedding
    from ..utils.env import ENV
    if ENV.backend == 'atb':
        import torch_npu
    else:
        import mindietorch
    
    # Module-level OpenTelemetry tracer backing the @tracer.start_as_current_span decorators in this module.
    tracer = trace.get_tracer(__name__)
    
    
    class RerankModel(Model):
        """
        RerankModel is a wrapper around a pre-trained reranker model that predicts similarity scores between a
        given query and texts.

        It extends the Model class and adds similarity score prediction for a batch of input data.
        """

        def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
            """Initializes the RerankModel by loading the pre-trained model based on the given model path and backend.

            Args:
                model_path: Path to the pre-trained model.
                device: The device to load the model on (e.g., NPU).
                dtype: The data type of the model's parameters (e.g., torch.float16).

            Raises:
                FileNotFoundError: If the MindIE Torch backend is selected and no "*.pt" file exists under
                    model_path.
            """

            if ENV.backend == 'atb':
                model = AutoModelForSequenceClassification.from_pretrained(
                    model_path,
                    torch_dtype=dtype,
                    trust_remote_code=True
                ).to(device).eval()
            else:
                mindietorch.set_device(device.index)
                # With the MindIE Torch backend, model_path (the weight directory) should contain exactly one
                # compiled .pt file: the first match found by rglob is loaded. Otherwise adjust the glob pattern
                # and the .pt file naming to your setup.
                compiled_model_path = next(Path(model_path).rglob("*.pt"), None)
                if compiled_model_path is None:
                    raise FileNotFoundError(f"no compiled '*.pt' model file found under {model_path}")
                model = torch.jit.load(compiled_model_path).eval()
            super().__init__(model=model, dtype=dtype, device=device)

        @property
        def batch_type(self) -> Type[PaddedBatch]:
            """Returns the class type of the batch that the reranker model expects.

            Returns:
                Type[PaddedBatch]: The type of the batch class expected by the reranker model. All tensors in this
                batch are padded according to the longest input sequence.
            """

            return PaddedBatch

        @tracer.start_as_current_span("embed")
        def embed(self, batch: PaddedBatch) -> Optional[List[Embedding]]:
            """Logs an error indicating that the reranker model does not support the embed function.

            Returns:
                None.
            """
            logger.error("rerank model does not support embed function")
            return None

        @tracer.start_as_current_span("embed_all")
        def embed_all(self, batch: PaddedBatch) -> Optional[List[TokenEmbedding]]:
            """Logs an error indicating that the reranker model does not support the embed_all function.

            Returns:
                None.
            """
            logger.error("rerank model does not support embed_all function")
            return None

        @tracer.start_as_current_span("predict")
        def predict(self, batch: PaddedBatch) -> List[Prediction]:
            """Predicts the similarity score for each query-text pair in the input batch.

            Args:
                batch: The input batch containing tokenized data (e.g., input_ids, attention_mask). Its tensors
                    are moved to self.device here.

            Returns:
                List[Prediction]: A list of predictions, one per query-text pair in the input batch. Each
                prediction's values hold that pair's logits row (a single score for a one-label reranker).
            """

            kwargs = {"input_ids": batch.input_ids.to(self.device), "attention_mask": batch.attention_mask.to(self.device)}
            if ENV.backend == 'atb':
                logits = self.model(**kwargs, return_dict=True).logits.float()
                # Keep one row per sequence. Flattening with view(-1) would interleave labels across
                # sequences whenever the classifier has more than one output label.
                if logits.dim() == 1:
                    logits = logits.unsqueeze(-1)
                return [Prediction(values=row) for row in logits.tolist()]

            scores = self.model(kwargs["input_ids"], kwargs["attention_mask"])[0].tolist()
            return [
                Prediction(
                    values=scores[i]
                )
                for i in range(len(batch))
            ]
  • 基于TEI v1.2.3版本,请在text-embeddings-inference/backends/python/server/text_embeddings_server/models/types.py文件中为PaddedBatch类增加max_length字段,并去除from_pb中将张量搬运至device的操作(张量保留在CPU上,由模型在推理时自行搬运至所在设备),以兼容MindIE Torch与ATB后端。
    from abc import ABC, abstractmethod
    from dataclasses import dataclass
    from opentelemetry import trace
    import torch
    
    from text_embeddings_server.pb import embed_pb2
    from text_embeddings_server.pb.embed_pb2 import Embedding, Prediction, TokenEmbedding
    
    from ..utils.env import ENV
    if ENV.backend == 'atb':
        import torch_npu
    else:
        import mindietorch
    
    # Module-level OpenTelemetry tracer backing the @tracer.start_as_current_span decorators in this module.
    tracer = trace.get_tracer(__name__)
    
    
    class Batch(ABC):
        """Interface for a model-ready batch built from an EmbedRequest protocol buffer.

        Concrete subclasses (PaddedBatch, FlashBatch) decide how the flat token arrays carried by the
        request are laid out as tensors.

        Methods:
            from_pb(pb, device): Builds a Batch instance from an EmbedRequest message.
            __len__(): Number of sequences held by the batch.
        """

        @classmethod
        @abstractmethod
        def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device) -> "Batch":
            """Builds a batch from an EmbedRequest message.

            Args:
                pb: Protocol buffer message carrying the tokenized inputs.
                device: Device requested by the caller; a subclass may choose to ignore it.

            Returns:
                Batch: The concrete batch instance (e.g. PaddedBatch or FlashBatch).
            """
            raise NotImplementedError()

        @abstractmethod
        def __len__(self):
            """Number of sequences held by the batch."""
            raise NotImplementedError()
    
    
    @dataclass
    class PaddedBatch(Batch):
        """Represents a batch of padded input data.

        This class holds a batch of input sequences, padding them to the same length. The data is represented
        as tensors, and all sequences are padded to the `max_length`.

        Attributes:
            input_ids: Tensor containing input token IDs for each sequence in the batch.
            token_type_ids: Tensor containing token type IDs for each sequence.
            position_ids: Tensor containing position IDs for each sequence.
            attention_mask: Tensor for the attention mask, indicating valid tokens in each sequence.
            max_length: The maximum sequence length for padding.

        Methods:
            from_pb(pb, device): Converts a protocol buffer format EmbedRequest to a PaddedBatch instance.
            __len__(): Returns the number of sequences in the batch.
        """

        input_ids: torch.Tensor
        token_type_ids: torch.Tensor
        position_ids: torch.Tensor
        attention_mask: torch.Tensor
        max_length: int

        @classmethod
        @tracer.start_as_current_span("from_pb")
        def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device) -> "PaddedBatch":
            """Converts an EmbedRequest protocol buffer format to a PaddedBatch instance.

            Every returned tensor is int32 with shape (num_sequences, pb.max_length), zero-padded on the right.

            Args:
                pb: The protocol buffer message containing input data.
                device: Kept for interface compatibility with Batch.from_pb but intentionally unused. The
                    tensors are created and returned on the CPU, and the model moves them to its own device at
                    inference time (see RerankModel.predict).

            Returns:
                PaddedBatch: A PaddedBatch instance containing padded input data.
            """

            # cu_seq_lengths holds cumulative offsets into the flat pb token arrays, so it has one more entry
            # than there are sequences; guard against an empty list producing a negative dimension.
            num_sequences = max(len(pb.cu_seq_lengths) - 1, 0)
            max_length = pb.max_length

            # Allocate all padded tensors at once: slot 0 = input_ids, 1 = token_type_ids,
            # 2 = position_ids, 3 = attention_mask.
            all_tensors = torch.zeros(
                [4, num_sequences, max_length], dtype=torch.int32, device='cpu'
            )

            for i, start_index in enumerate(pb.cu_seq_lengths[:-1]):
                end_index = pb.cu_seq_lengths[i + 1]
                input_length = end_index - start_index

                all_tensors[0, i, :input_length] = torch.tensor(
                    pb.input_ids[start_index:end_index], dtype=torch.int32
                )
                all_tensors[1, i, :input_length] = torch.tensor(
                    pb.token_type_ids[start_index:end_index], dtype=torch.int32
                )
                all_tensors[2, i, :input_length] = torch.tensor(
                    pb.position_ids[start_index:end_index], dtype=torch.int32
                )
                all_tensors[3, i, :input_length] = 1

            return PaddedBatch(
                input_ids=all_tensors[0],
                token_type_ids=all_tensors[1],
                position_ids=all_tensors[2],
                attention_mask=all_tensors[3],
                max_length=max_length,
            )

        def __len__(self):
            """Returns the number of sequences in the batch."""
            return len(self.input_ids)
    
    
    @dataclass
    class FlashBatch(Batch):
        """Represents a batch of input data for flash inference.

        This class is used for models that support flash-based batching, where the sequences are packed
        back-to-back into flat tensors instead of being padded.

        Attributes:
            input_ids: Flat tensor of the input token IDs of all sequences, concatenated.
            token_type_ids: Flat tensor of the token type IDs of all sequences, concatenated.
            position_ids: Flat tensor of the position IDs of all sequences, concatenated.
            cu_seqlens: Cumulative sequence offsets into the flat tensors; sequence i spans
                [cu_seqlens[i], cu_seqlens[i + 1]), so it has one more entry than there are sequences.
            max_s: The maximum sequence length for this batch.
            size: The number of sequences in the batch.

        Methods:
            from_pb(pb, device): Converts a protocol buffer format EmbedRequest to a FlashBatch instance.
            __len__(): Returns the number of sequences in the batch.
        """

        input_ids: torch.Tensor
        token_type_ids: torch.Tensor
        position_ids: torch.Tensor
        cu_seqlens: torch.Tensor
        max_s: int
        size: int

        @classmethod
        @tracer.start_as_current_span("from_pb")
        def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device) -> "FlashBatch":
            """Converts an EmbedRequest protocol buffer format to a FlashBatch instance.

            Args:
                pb: The protocol buffer message containing input data.
                device: The device on which the int32 tensors are allocated; must be a CUDA device.

            Returns:
                FlashBatch: A FlashBatch instance containing input data for a FlashBert model.

            Raises:
                RuntimeError: If the device is not 'cuda', FlashBatch is not supported.
            """
            if device.type != "cuda":
                raise RuntimeError(f"FlashBatch does not support device {device}")

            batch_input_ids = torch.tensor(pb.input_ids, dtype=torch.int32, device=device)
            batch_token_type_ids = torch.tensor(
                pb.token_type_ids, dtype=torch.int32, device=device
            )
            batch_position_ids = torch.tensor(
                pb.position_ids, dtype=torch.int32, device=device
            )

            cu_seqlens = torch.tensor(pb.cu_seq_lengths, dtype=torch.int32, device=device)

            return FlashBatch(
                input_ids=batch_input_ids,
                token_type_ids=batch_token_type_ids,
                position_ids=batch_position_ids,
                cu_seqlens=cu_seqlens,
                max_s=pb.max_length,
                # n cumulative offsets delimit n - 1 sequences.
                size=len(cu_seqlens) - 1,
            )

        def __len__(self):
            """Returns the number of sequences in the batch."""
            return self.size