The directory layout of the TEI code adapted for the Ascend environment and the MindIE Torch component is identical to that of the TEI GitHub repository. The files that need to be modified are the ones listed in the code structure tree below; any files or directories not shown remain the same as in the TEI GitHub repository. The directory structure is as follows.
text-embeddings-inference
|____core
| |____src
| | |____ infer.rs
|____backends
| |____grpc-client
| | |____src
| | | |____ client.rs
| |____proto
| | |____ embed.proto
| |____python
| | |____src
| | | |____ lib.rs
| | |____server
| | | |____pyproject.toml
| | | |____requirements.txt
| | | |____text_embeddings_server
| | | | |____ server.py
| | | | |____models
| | | | | |____ __init__.py
| | | | | |____ default_model.py
| | | | | |____ model.py
| | | | | |____ rerank_model.py
| | | | | |____ types.py
The following modification is optional: leaving it out does not affect TEI's operation or its serving functionality, but applying it gives a clear performance improvement.
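This note appears to refer to core/src/infer.rs, which is listed in full right after this sketch. The fragment below is extracted from that file rather than being a separate change; it isolates the relevant modification: in Infer::new() the second batching_task spawn that upstream TEI creates is commented out, so a single batching task prefetches batches for the backend.

// Inside Infer::new(): keep the first batching task ...
tokio::spawn(batching_task(
    queue.clone(),
    notify_batching_task.clone(),
    embed_sender.clone(),
));
// ... and comment out the second one that upstream TEI spawns
/*
tokio::spawn(batching_task(
    queue.clone(),
    notify_batching_task.clone(),
    embed_sender,
));
*/

The complete modified core/src/infer.rs follows.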
// core/src/infer.rs
use crate::queue::{Entry, Metadata, NextBatch, Queue};
use crate::tokenization::{EncodingInput, RawEncoding, Tokenization};
use crate::TextEmbeddingsError;
use std::sync::Arc;
use std::time::{Duration, Instant};
use text_embeddings_backend::{Backend, BackendError, Embedding, ModelType};
use tokio::sync::{mpsc, oneshot, watch, Notify, OwnedSemaphorePermit, Semaphore};
use tracing::instrument;

/// Inference struct
#[derive(Debug, Clone)]
pub struct Infer {
    tokenization: Tokenization,
    queue: Queue,
    /// Shared notify
    notify_batching_task: Arc<Notify>,
    /// Inference limit
    limit_concurrent_requests: Arc<Semaphore>,
    backend: Backend,
}

impl Infer {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        tokenization: Tokenization,
        queue: Queue,
        max_concurrent_requests: usize,
        backend: Backend,
    ) -> Self {
        let notify_batching_task = Arc::new(Notify::new());

        let (embed_sender, embed_receiver) = mpsc::unbounded_channel();

        // Create only one batching task to prefetch batches
        tokio::spawn(batching_task(
            queue.clone(),
            notify_batching_task.clone(),
            embed_sender.clone(),
        ));
        /*
        The upstream source spawns two batching_task instances to manage the request
        queue; with this second spawn commented out, a single batching_task performs better.
        tokio::spawn(batching_task(
            queue.clone(),
            notify_batching_task.clone(),
            embed_sender,
        ));
        */

        // Create embed task to communicate with backend
        tokio::spawn(backend_task(backend.clone(), embed_receiver));

        // Inference limit with a semaphore
        let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));

        Self {
            tokenization,
            queue,
            notify_batching_task,
            limit_concurrent_requests: semaphore,
            backend,
        }
    }

    #[instrument(skip(self))]
    pub async fn tokenize<I: Into<EncodingInput> + std::fmt::Debug>(
        &self,
        inputs: I,
        add_special_tokens: bool,
    ) -> Result<RawEncoding, TextEmbeddingsError> {
        self.tokenization
            .tokenize(inputs.into(), add_special_tokens)
            .await
            .map_err(|err| {
                metrics::increment_counter!("te_request_failure", "err" => "tokenization");
                tracing::error!("{err}");
                err
            })
    }

    #[instrument(skip(self))]
    pub async fn decode(
        &self,
        ids: Vec<u32>,
        skip_special_tokens: bool,
    ) -> Result<String, TextEmbeddingsError> {
        self.tokenization
            .decode(ids, skip_special_tokens)
            .await
            .map_err(|err| {
                metrics::increment_counter!("te_request_failure", "err" => "tokenization");
                tracing::error!("{err}");
                err
            })
    }

    #[instrument(skip(self))]
    pub fn try_acquire_permit(&self) -> Result<OwnedSemaphorePermit, TextEmbeddingsError> {
        // Limit concurrent requests by acquiring a permit from the semaphore
        self.clone()
            .limit_concurrent_requests
            .try_acquire_owned()
            .map_err(|err| {
                metrics::increment_counter!("te_request_failure", "err" => "overloaded");
                tracing::error!("{err}");
                TextEmbeddingsError::from(err)
            })
    }

    #[instrument(skip(self))]
    pub async fn acquire_permit(&self) -> OwnedSemaphorePermit {
        // Limit concurrent requests by acquiring a permit from the semaphore
        self.clone()
            .limit_concurrent_requests
            .acquire_owned()
            .await
            .expect("Semaphore has been closed. This is a bug.")
    }

    #[instrument(skip(self, permit))]
    pub async fn embed_all<I: Into<EncodingInput> + std::fmt::Debug>(
        &self,
        inputs: I,
        truncate: bool,
        permit: OwnedSemaphorePermit,
    ) -> Result<AllEmbeddingsInferResponse, TextEmbeddingsError> {
        let start_time = Instant::now();

        if self.is_splade() {
            metrics::increment_counter!("te_request_failure", "err" => "model_type");
            let message = "`embed_all` is not available for SPLADE models".to_string();
            tracing::error!("{message}");
            return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                message,
            )));
        }

        let results = self
            .embed(inputs, truncate, false, &start_time, permit)
            .await?;

        let InferResult::AllEmbedding(response) = results else {
            panic!("unexpected enum variant")
        };

        // Timings
        let total_time = start_time.elapsed();

        // Metrics
        metrics::increment_counter!("te_embed_success");
        metrics::histogram!("te_embed_duration", total_time.as_secs_f64());
        metrics::histogram!(
            "te_embed_tokenization_duration",
            response.metadata.tokenization.as_secs_f64()
        );
        metrics::histogram!(
            "te_embed_queue_duration",
            response.metadata.queue.as_secs_f64()
        );
        metrics::histogram!(
            "te_embed_inference_duration",
            response.metadata.inference.as_secs_f64()
        );

        Ok(response)
    }

    #[instrument(skip(self, permit))]
    pub async fn embed_sparse<I: Into<EncodingInput> + std::fmt::Debug>(
        &self,
        inputs: I,
        truncate: bool,
        permit: OwnedSemaphorePermit,
    ) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
        let start_time = Instant::now();

        if !self.is_splade() {
            metrics::increment_counter!("te_request_failure", "err" => "model_type");
            let message = "Model is not an embedding model with SPLADE pooling".to_string();
            tracing::error!("{message}");
            return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                message,
            )));
        }

        let results = self
            .embed(inputs, truncate, true, &start_time, permit)
            .await?;

        let InferResult::PooledEmbedding(response) = results else {
            panic!("unexpected enum variant")
        };

        // Timings
        let total_time = start_time.elapsed();

        // Metrics
        metrics::increment_counter!("te_embed_success");
        metrics::histogram!("te_embed_duration", total_time.as_secs_f64());
        metrics::histogram!(
            "te_embed_tokenization_duration",
            response.metadata.tokenization.as_secs_f64()
        );
        metrics::histogram!(
            "te_embed_queue_duration",
            response.metadata.queue.as_secs_f64()
        );
        metrics::histogram!(
            "te_embed_inference_duration",
            response.metadata.inference.as_secs_f64()
        );

        Ok(response)
    }

    #[instrument(skip(self, permit))]
    pub async fn embed_pooled<I: Into<EncodingInput> + std::fmt::Debug>(
        &self,
        inputs: I,
        truncate: bool,
        normalize: bool,
        permit: OwnedSemaphorePermit,
    ) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
        let start_time = Instant::now();

        if self.is_splade() && normalize {
            metrics::increment_counter!("te_request_failure", "err" => "model_type");
            let message = "`normalize` is not available for SPLADE models".to_string();
            tracing::error!("{message}");
            return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                message,
            )));
        }

        let results = self
            .embed(inputs, truncate, true, &start_time, permit)
            .await?;

        let InferResult::PooledEmbedding(mut response) = results else {
            panic!("unexpected enum variant")
        };

        if normalize {
            // Normalize embedding
            let scale = (1.0
                / response
                    .results
                    .iter()
                    .map(|v| {
                        let v = *v as f64;
                        v * v
                    })
                    .sum::<f64>()
                    .sqrt()) as f32;
            for v in response.results.iter_mut() {
                *v *= scale;
            }
        }

        // Timings
        let total_time = start_time.elapsed();

        // Metrics
        metrics::increment_counter!("te_embed_success");
        metrics::histogram!("te_embed_duration", total_time.as_secs_f64());
        metrics::histogram!(
            "te_embed_tokenization_duration",
            response.metadata.tokenization.as_secs_f64()
        );
        metrics::histogram!(
            "te_embed_queue_duration",
            response.metadata.queue.as_secs_f64()
        );
        metrics::histogram!(
            "te_embed_inference_duration",
            response.metadata.inference.as_secs_f64()
        );

        Ok(response)
    }

    async fn embed<I: Into<EncodingInput> + std::fmt::Debug>(
        &self,
        inputs: I,
        truncate: bool,
        pooling: bool,
        start_time: &Instant,
        _permit: OwnedSemaphorePermit,
    ) -> Result<InferResult, TextEmbeddingsError> {
        if self.is_classifier() {
            metrics::increment_counter!("te_request_failure", "err" => "model_type");
            let message = "Model is not an embedding model".to_string();
            tracing::error!("{message}");
            return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                message,
            )));
        }

        metrics::increment_counter!("te_embed_count");

        // Tokenization
        let encoding = self
            .tokenization
            .encode(inputs.into(), truncate)
            .await
            .map_err(|err| {
                metrics::increment_counter!("te_request_failure", "err" => "tokenization");
                tracing::error!("{err}");
                err
            })?;

        // MPSC channel to communicate with the background batching task
        let (response_tx, response_rx) = oneshot::channel();

        // Append the request to the queue
        self.queue.append(Entry {
            metadata: Metadata {
                response_tx,
                tokenization: start_time.elapsed(),
                queue_time: Instant::now(),
                prompt_tokens: encoding.input_ids.len(),
                pooling,
            },
            encoding,
        });

        self.notify_batching_task.notify_one();

        let response = response_rx
            .await
            .expect(
                "Infer batching task dropped the sender without sending a response. This is a bug.",
            )
            .map_err(|err| {
                metrics::increment_counter!("te_request_failure", "err" => "inference");
                tracing::error!("{err}");
                err
            })?;

        Ok(response)
    }

    #[instrument(skip(self, _permit))]
    pub async fn predict<I: Into<EncodingInput> + std::fmt::Debug>(
        &self,
        inputs: I,
        truncate: bool,
        raw_scores: bool,
        _permit: OwnedSemaphorePermit,
    ) -> Result<ClassificationInferResponse, TextEmbeddingsError> {
        if !self.is_classifier() {
            metrics::increment_counter!("te_request_failure", "err" => "model_type");
            let message = "Model is not a classifier model".to_string();
            return Err(TextEmbeddingsError::Backend(BackendError::Inference(
                message,
            )));
        }

        let start_time = Instant::now();
        metrics::increment_counter!("te_predict_count");

        // Tokenization
        let encoding = self
            .tokenization
            .encode(inputs.into(), truncate)
            .await
            .map_err(|err| {
                metrics::increment_counter!("te_request_failure", "err" => "tokenization");
                tracing::error!("{err}");
                err
            })?;

        // MPSC channel to communicate with the background batching task
        let (response_tx, response_rx) = oneshot::channel();

        // Append the request to the queue
        self.queue.append(Entry {
            metadata: Metadata {
                response_tx,
                tokenization: start_time.elapsed(),
                queue_time: Instant::now(),
                prompt_tokens: encoding.input_ids.len(),
                pooling: true,
            },
            encoding,
        });

        self.notify_batching_task.notify_one();

        let response = response_rx
            .await
            .expect(
                "Infer batching task dropped the sender without sending a response. This is a bug.",
            )
            .map_err(|err| {
                metrics::increment_counter!("te_request_failure", "err" => "inference");
                tracing::error!("{err}");
                err
            })?;

        let InferResult::Classification(mut response) = response else {
            panic!("unexpected enum variant")
        };

        if !raw_scores {
            // Softmax
            if response.results.len() > 1 {
                let max = *response
                    .results
                    .iter()
                    .max_by(|x, y| x.abs().partial_cmp(&y.abs()).unwrap())
                    .unwrap();

                let mut den = 0.0;
                for v in response.results.iter_mut() {
                    *v = (*v - max).exp();
                    den += *v;
                }
                for v in response.results.iter_mut() {
                    *v /= den;
                }
            }
            // Sigmoid
            else {
                response.results[0] = 1.0 / (1.0 + (-response.results[0]).exp());
            }
        }

        // Timings
        let total_time = start_time.elapsed();

        // Metrics
        metrics::increment_counter!("te_predict_success");
        metrics::histogram!("te_predict_duration", total_time.as_secs_f64());
        metrics::histogram!(
            "te_predict_tokenization_duration",
            response.metadata.tokenization.as_secs_f64()
        );
        metrics::histogram!(
            "te_predict_queue_duration",
            response.metadata.queue.as_secs_f64()
        );
        metrics::histogram!(
            "te_predict_inference_duration",
            response.metadata.inference.as_secs_f64()
        );

        Ok(response)
    }

    #[instrument(skip(self))]
    pub fn is_classifier(&self) -> bool {
        matches!(self.backend.model_type, ModelType::Classifier)
    }

    #[instrument(skip(self))]
    pub fn is_splade(&self) -> bool {
        matches!(
            self.backend.model_type,
            ModelType::Embedding(text_embeddings_backend::Pool::Splade)
        )
    }

    #[instrument(skip(self))]
    pub async fn health(&self) -> bool {
        self.backend.health().await.is_ok()
    }

    #[instrument(skip(self))]
    pub fn health_watcher(&self) -> watch::Receiver<bool> {
        self.backend.health_watcher()
    }
}

#[instrument(skip_all)]
async fn batching_task(
    queue: Queue,
    notify: Arc<Notify>,
    embed_sender: mpsc::UnboundedSender<(NextBatch, oneshot::Sender<()>)>,
) {
    loop {
        notify.notified().await;

        while let Some(next_batch) = queue.next_batch().await {
            let (callback_sender, callback_receiver) = oneshot::channel();
            embed_sender
                .send((next_batch, callback_sender))
                .expect("embed receiver was dropped. This is a bug.");
            let _ = callback_receiver.await;
        }
    }
}

#[instrument(skip_all)]
async fn backend_task(
    backend: Backend,
    mut embed_receiver: mpsc::UnboundedReceiver<(NextBatch, oneshot::Sender<()>)>,
) {
    while let Some((batch, _callback)) = embed_receiver.recv().await {
        match &backend.model_type {
            ModelType::Classifier => {
                let results = backend.predict(batch.1).await;

                // Handle sending responses in another thread to avoid starving the backend
                std::thread::spawn(move || match results {
                    Ok((mut predictions, inference_duration)) => {
                        batch.0.into_iter().enumerate().for_each(|(i, m)| {
                            let infer_metadata = InferMetadata {
                                prompt_tokens: m.prompt_tokens,
                                tokenization: m.tokenization,
                                queue: m.queue_time.elapsed() - inference_duration,
                                inference: inference_duration,
                            };

                            let _ = m.response_tx.send(Ok(InferResult::Classification(
                                ClassificationInferResponse {
                                    results: predictions.remove(&i).expect(
                                        "prediction not found in results. This is a backend bug.",
                                    ),
                                    metadata: infer_metadata,
                                },
                            )));
                        });
                    }
                    Err(err) => {
                        batch.0.into_iter().for_each(|m| {
                            let _ = m.response_tx.send(Err(err.clone()));
                        });
                    }
                });
            }
            ModelType::Embedding(_) => {
                let results = backend.embed(batch.1).await;

                // Handle sending responses in another thread to avoid starving the backend
                std::thread::spawn(move || match results {
                    Ok((mut embeddings, inference_duration)) => {
                        batch.0.into_iter().enumerate().for_each(|(i, m)| {
                            let metadata = InferMetadata {
                                prompt_tokens: m.prompt_tokens,
                                tokenization: m.tokenization,
                                queue: m.queue_time.elapsed() - inference_duration,
                                inference: inference_duration,
                            };

                            let results = match embeddings
                                .remove(&i)
                                .expect("embedding not found in results. This is a backend bug.")
                            {
                                Embedding::Pooled(e) => {
                                    InferResult::PooledEmbedding(PooledEmbeddingsInferResponse {
                                        results: e,
                                        metadata,
                                    })
                                }
                                Embedding::All(e) => {
                                    InferResult::AllEmbedding(AllEmbeddingsInferResponse {
                                        results: e,
                                        metadata,
                                    })
                                }
                            };

                            let _ = m.response_tx.send(Ok(results));
                        })
                    }
                    Err(err) => {
                        batch.0.into_iter().for_each(|m| {
                            let _ = m.response_tx.send(Err(err.clone()));
                        });
                    }
                });
            }
        };
    }
}

#[derive(Debug)]
pub struct InferMetadata {
    pub prompt_tokens: usize,
    pub tokenization: Duration,
    pub queue: Duration,
    pub inference: Duration,
}

#[derive(Debug)]
pub(crate) enum InferResult {
    Classification(ClassificationInferResponse),
    PooledEmbedding(PooledEmbeddingsInferResponse),
    AllEmbedding(AllEmbeddingsInferResponse),
}

#[derive(Debug)]
pub struct ClassificationInferResponse {
    pub results: Vec<f32>,
    pub metadata: InferMetadata,
}

#[derive(Debug)]
pub struct PooledEmbeddingsInferResponse {
    pub results: Vec<f32>,
    pub metadata: InferMetadata,
}

#[derive(Debug)]
pub struct AllEmbeddingsInferResponse {
    pub results: Vec<Vec<f32>>,
    pub metadata: InferMetadata,
}
// backends/grpc-client/src/client.rs
/// Single shard Client
use crate::pb::embedding::v1::embedding_service_client::EmbeddingServiceClient;
use crate::pb::embedding::v1::*;
use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use tonic::transport::{Channel, Uri};
use tracing::instrument;

/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
    stub: EmbeddingServiceClient<Channel>,
}

impl Client {
    /// Returns a client connected to the given url
    pub async fn connect(uri: Uri) -> Result<Self> {
        let channel = Channel::builder(uri).connect().await?;

        Ok(Self {
            stub: EmbeddingServiceClient::new(channel),
        })
    }

    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
        let channel = Channel::from_shared("http://[::]:50051".to_string())
            .unwrap()
            .connect_with_connector(tower::service_fn(move |_: Uri| {
                tokio::net::UnixStream::connect(path.clone())
            }))
            .await?;

        Ok(Self {
            stub: EmbeddingServiceClient::new(channel),
        })
    }

    /// Get backend health
    #[instrument(skip(self))]
    pub async fn health(&mut self) -> Result<HealthResponse> {
        let request = tonic::Request::new(HealthRequest {}).inject_context();
        let response = self.stub.health(request).await?.into_inner();
        Ok(response)
    }

    #[instrument(skip_all)]
    pub async fn embed(
        &mut self,
        input_ids: Vec<u32>,
        token_type_ids: Vec<u32>,
        position_ids: Vec<u32>,
        cu_seq_lengths: Vec<u32>,
        max_length: u32,
    ) -> Result<Vec<Embedding>> {
        let request = tonic::Request::new(EmbedRequest {
            input_ids,
            token_type_ids,
            position_ids,
            max_length,
            cu_seq_lengths,
        })
        .inject_context();
        let response = self.stub.embed(request).await?.into_inner();
        Ok(response.embeddings)
    }

    #[instrument(skip_all)]
    pub async fn embed_all(
        &mut self,
        input_ids: Vec<u32>,
        token_type_ids: Vec<u32>,
        position_ids: Vec<u32>,
        cu_seq_lengths: Vec<u32>,
        max_length: u32,
    ) -> Result<Vec<TokenEmbedding>> {
        let request = tonic::Request::new(EmbedRequest {
            input_ids,
            token_type_ids,
            position_ids,
            max_length,
            cu_seq_lengths,
        })
        .inject_context();
        let response = self.stub.embed_all(request).await?.into_inner();
        Ok(response.allembeddings)
    }

    #[instrument(skip_all)]
    pub async fn predict(
        &mut self,
        input_ids: Vec<u32>,
        token_type_ids: Vec<u32>,
        position_ids: Vec<u32>,
        cu_seq_lengths: Vec<u32>,
        max_length: u32,
    ) -> Result<Vec<Prediction>> {
        let request = tonic::Request::new(PredictRequest {
            input_ids,
            token_type_ids,
            position_ids,
            max_length,
            cu_seq_lengths,
        })
        .inject_context();
        let response = self.stub.predict(request).await?.into_inner();
        Ok(response.predictions)
    }
}
// backends/proto/embed.proto
syntax = "proto3";

package embedding.v1;

service EmbeddingService {
    /// Decode token for a list of prefilled batches
    rpc Embed (EmbedRequest) returns (EmbedResponse);
    rpc Embed_all (EmbedRequest) returns (RawEmbedResponse);
    rpc Predict (PredictRequest) returns (PredictResponse);
    /// Health check
    rpc Health (HealthRequest) returns (HealthResponse);
}

message HealthRequest {}
message HealthResponse {}

message PredictRequest {
    repeated uint32 input_ids = 1;
    repeated uint32 token_type_ids = 2;
    repeated uint32 position_ids = 3;
    repeated uint32 cu_seq_lengths = 4;
    /// Length of the longest request
    uint32 max_length = 5;
}

message Prediction {
    repeated float values = 1;
}

message PredictResponse {
    repeated Prediction predictions = 1;
}

message EmbedRequest {
    repeated uint32 input_ids = 1;
    repeated uint32 token_type_ids = 2;
    repeated uint32 position_ids = 3;
    repeated uint32 cu_seq_lengths = 4;
    /// Length of the longest request
    uint32 max_length = 5;
}

message Embedding {
    repeated float values = 1;
}

message EmbedResponse {
    repeated Embedding embeddings = 1;
}

message TokenEmbedding {
    repeated Embedding embeddings = 1;
}

message RawEmbedResponse {
    repeated TokenEmbedding allembeddings = 1;
}
// backends/python/src/lib.rs
mod logging;
mod management;

use backend_grpc_client::Client;
use nohash_hasher::BuildNoHashHasher;
use std::collections::HashMap;
use text_embeddings_backend_core::{
    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
};
use tokio::runtime::Runtime;

pub struct PythonBackend {
    _backend_process: management::BackendProcess,
    tokio_runtime: Runtime,
    backend_client: Client,
}

impl PythonBackend {
    pub fn new(
        model_path: String,
        dtype: String,
        model_type: ModelType,
        uds_path: String,
        otlp_endpoint: Option<String>,
    ) -> Result<Self, BackendError> {
        match model_type {
            ModelType::Classifier => {
                let pool = Pool::Cls;
                pool
            }
            ModelType::Embedding(pool) => {
                if pool != Pool::Cls {
                    return Err(BackendError::Start(format!("{pool:?} is not supported")));
                }
                pool
            }
        };

        let backend_process =
            management::BackendProcess::new(model_path, dtype, &uds_path, otlp_endpoint)?;
        let tokio_runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .map_err(|err| BackendError::Start(format!("Could not start Tokio runtime: {err}")))?;

        let backend_client = tokio_runtime
            .block_on(Client::connect_uds(uds_path))
            .map_err(|err| {
                BackendError::Start(format!("Could not connect to backend process: {err}"))
            })?;

        Ok(Self {
            _backend_process: backend_process,
            tokio_runtime,
            backend_client,
        })
    }
}

impl Backend for PythonBackend {
    fn health(&self) -> Result<(), BackendError> {
        if self
            .tokio_runtime
            .block_on(self.backend_client.clone().health())
            .is_err()
        {
            return Err(BackendError::Unhealthy);
        }
        Ok(())
    }

    fn is_padded(&self) -> bool {
        false
    }

    fn embed(&self, batch: Batch) -> Result<Embeddings, BackendError> {
        /*
        if !batch.raw_indices.is_empty() {
            return Err(BackendError::Inference(
                "raw embeddings are not supported for the Python backend.".to_string(),
            ));
        }
        */
        let batch_size = batch.len();

        let mut embeddings =
            HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());

        if !batch.pooled_indices.is_empty() {
            let results = self
                .tokio_runtime
                .block_on(self.backend_client.clone().embed(
                    batch.input_ids,
                    batch.token_type_ids,
                    batch.position_ids,
                    batch.cumulative_seq_lengths,
                    batch.max_length,
                ))
                .map_err(|err| BackendError::Inference(err.to_string()))?;
            let pooled_embeddings: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect();

            for (i, e) in pooled_embeddings.into_iter().enumerate() {
                embeddings.insert(i, Embedding::Pooled(e));
            }
        } else if !batch.raw_indices.is_empty() {
            let results = self
                .tokio_runtime
                .block_on(self.backend_client.clone().embed_all(
                    batch.input_ids,
                    batch.token_type_ids,
                    batch.position_ids,
                    batch.cumulative_seq_lengths,
                    batch.max_length,
                ))
                .map_err(|err| BackendError::Inference(err.to_string()))?;

            let mut raw_embeddings = Vec::new();
            for token_embedding in results {
                let mut two_dim_list = Vec::new();
                for embeddings in token_embedding.embeddings {
                    let values = embeddings.values.clone();
                    two_dim_list.push(values);
                }
                raw_embeddings.push(two_dim_list);
            }

            for (i, e) in raw_embeddings.into_iter().enumerate() {
                embeddings.insert(i, Embedding::All(e));
            }
        }

        Ok(embeddings)
    }

    fn predict(&self, batch: Batch) -> Result<Predictions, BackendError> {
        let batch_size = batch.len();
        let results = self
            .tokio_runtime
            .block_on(self.backend_client.clone().predict(
                batch.input_ids,
                batch.token_type_ids,
                batch.position_ids,
                batch.cumulative_seq_lengths,
                batch.max_length,
            ))
            .map_err(|err| BackendError::Inference(err.to_string()))?;

        let predictions_result: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect();

        let mut predictions =
            HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
        for (i, r) in predictions_result.into_iter().enumerate() {
            predictions.insert(i, r);
        }

        Ok(predictions)
    }
}
# backends/python/server/requirements.txt
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.2 ; python_version >= "3.9" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.1 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
torch==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
poetry==1.8.3 ; python_version >= "3.9" and python_version < "3.13"
# backends/python/server/pyproject.toml
[tool.poetry]
name = "text-embeddings-server"
version = "0.1.0"
description = "Text Embeddings Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

[tool.poetry.scripts]
python-text-embeddings-server = 'text_embeddings_server.cli:app'

[tool.poetry.dependencies]
python = ">=3.9,<3.13"
protobuf = "^4.21.7"
grpcio = "^1.51.1"
grpcio-status = "^1.51.1"
grpcio-reflection = "^1.51.1"
grpc-interceptor = "^0.15.0"
typer = "^0.6.1"
safetensors = "^0.4.3"
loguru = "^0.6.0"
opentelemetry-api = "^1.15.0"
opentelemetry-exporter-otlp = "^1.15.0"
opentelemetry-instrumentation-grpc = "^0.36b0"
torch = { version = "^2.1.0" }

[tool.poetry.extras]

[tool.poetry.group.dev.dependencies]
grpcio-tools = "^1.51.1"
pytest = "^7.3.0"

[[tool.poetry.source]]
name = "mirrors"
url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
priority = "default"

[tool.pytest.ini_options]
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
# backends/python/server/text_embeddings_server/server.py
import asyncio
import torch
import mindietorch
from grpc import aio
from loguru import logger

from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import Optional

from text_embeddings_server.models import Model, get_model
from text_embeddings_server.pb import embed_pb2_grpc, embed_pb2
from text_embeddings_server.utils.tracing import UDSOpenTelemetryAioServerInterceptor
from text_embeddings_server.utils.interceptor import ExceptionInterceptor


class EmbeddingService(embed_pb2_grpc.EmbeddingServiceServicer):
    def __init__(self, model: Model):
        self.model = model
        # Force inference mode for the lifetime of EmbeddingService
        self._inference_mode_raii_guard = torch._C._InferenceMode(True)

    async def Health(self, request, context):
        if self.model.device.type == "cuda":
            torch.zeros((2, 2), device="cuda")
        return embed_pb2.HealthResponse()

    async def Embed(self, request, context):
        batch = self.model.batch_type.from_pb(request, self.model.device)

        embeddings = self.model.embed(batch)

        return embed_pb2.EmbedResponse(embeddings=embeddings)

    async def Embed_all(self, request, context):
        batch = self.model.batch_type.from_pb(request, self.model.device)

        embeddings = self.model.embed_all(batch)

        return embed_pb2.RawEmbedResponse(allembeddings=embeddings)

    async def Predict(self, request, context):
        batch = self.model.batch_type.from_pb(request, self.model.device)

        predictions = self.model.predict(batch)

        return embed_pb2.PredictResponse(predictions=predictions)


def serve(
    model_path: Path,
    dtype: Optional[str],
    uds_path: Path,
):
    async def serve_inner(
        model_path: Path,
        dtype: Optional[str] = None,
    ):
        unix_socket = f"unix://{uds_path}"

        try:
            model = get_model(model_path, dtype)
        except Exception:
            logger.exception("Error when initializing model")
            raise

        server = aio.server(
            interceptors=[
                ExceptionInterceptor(),
                UDSOpenTelemetryAioServerInterceptor(),
            ]
        )
        embed_pb2_grpc.add_EmbeddingServiceServicer_to_server(
            EmbeddingService(model), server
        )
        SERVICE_NAMES = (
            embed_pb2.DESCRIPTOR.services_by_name["EmbeddingService"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(SERVICE_NAMES, server)
        server.add_insecure_port(unix_socket)

        await server.start()

        logger.info(f"Server started at {unix_socket}")

        try:
            await server.wait_for_termination()
        except KeyboardInterrupt:
            logger.info("Signal received. Shutting down")
            await server.stop(0)

    asyncio.run(serve_inner(model_path, dtype))
# backends/python/server/text_embeddings_server/models/__init__.py
import os
import torch
import mindietorch

from loguru import logger
from pathlib import Path
from typing import Optional
from transformers import AutoConfig

from text_embeddings_server.models.model import Model
from text_embeddings_server.models.default_model import DefaultModel
from text_embeddings_server.models.rerank_model import RerankModel

__all__ = ["Model"]

# Disable gradients
torch.set_grad_enabled(False)

FLASH_ATTENTION = True
try:
    from text_embeddings_server.models.flash_bert import FlashBert
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    __all__.append(FlashBert)


def get_model(model_path: Path, dtype: Optional[str]):
    if dtype == "float32":
        dtype = torch.float32
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    # Select the NPU device from the TEI_NPU_DEVICE environment variable (default "0")
    deviceIdx = os.environ.get('TEI_NPU_DEVICE', '0')
    if deviceIdx != None and deviceIdx.isdigit() and int(deviceIdx) >= 0 and int(deviceIdx) <= 7:
        mindietorch.set_device(int(deviceIdx))
        device = torch.device(f"npu:{int(deviceIdx)}")

    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    if config.architectures[0].endswith("Classification"):
        return RerankModel(model_path, device, dtype)
    else:
        if (
            config.model_type == "bert"
            and device.type == "cuda"
            and config.position_embedding_type == "absolute"
            and dtype in [torch.float16, torch.bfloat16]
            and FLASH_ATTENTION
        ):
            return FlashBert(model_path, device, dtype)
        else:
            return DefaultModel(model_path, device, dtype)

    raise NotImplementedError
# backends/python/server/text_embeddings_server/models/default_model.py
import torch
import mindietorch

from pathlib import Path
from typing import Type, List
from opentelemetry import trace
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModel, AutoConfig
from loguru import logger

from text_embeddings_server.models import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding, Prediction, TokenEmbedding

tracer = trace.get_tracer(__name__)


class DefaultModel(Model):
    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
        mindietorch.set_device(device.index)
        # Load the MindIE Torch compiled TorchScript model exported for the NPU
        model = torch.jit.load(next(Path(model_path).rglob("*.pt"))).eval().to(device)
        self.model_path = str(model_path)
        self.hidden_size = AutoConfig.from_pretrained(model_path, trust_remote_code=True).hidden_size

        super(DefaultModel, self).__init__(model=model, dtype=dtype, device=device)

    @property
    def batch_type(self) -> Type[PaddedBatch]:
        return PaddedBatch

    @tracer.start_as_current_span("embed")
    def embed(self, batch: PaddedBatch) -> List[Embedding]:
        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
        output = self.model(kwargs["input_ids"].to(self.device), kwargs["attention_mask"].to(self.device))
        if isinstance(output, dict):
            embedding = output['last_hidden_state'].to('cpu')
        else:
            embedding = output[0].to('cpu')
        embedding = embedding[:, 0].contiguous()
        cpu_results = embedding.view(-1).tolist()

        return [
            Embedding(
                values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size]
            )
            for i in range(len(batch))
        ]

    @tracer.start_as_current_span("embed_all")
    def embed_all(self, batch: PaddedBatch):
        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
        output = self.model(kwargs["input_ids"].to(self.device), kwargs["attention_mask"].to(self.device))
        if isinstance(output, dict):
            embedding = output['last_hidden_state'].to('cpu').contiguous()
        else:
            embedding = output[0].to('cpu').contiguous()
        cpu_results = embedding.view(-1).tolist()

        embedding_result = []
        for i in range(len(batch)):
            embedding_tmp = [
                Embedding(
                    values=cpu_results[(j + i * batch.max_length) * self.hidden_size : (j + 1 + i * batch.max_length) * self.hidden_size]
                )
                for j in range(batch.input_ids.size()[1])
            ]
            tokenembeddings = TokenEmbedding(embeddings=embedding_tmp)
            embedding_result.append(tokenembeddings)
        return embedding_result

    @tracer.start_as_current_span("predict")
    def predict(self, batch: PaddedBatch) -> List[Prediction]:
        print("embedding model does not support predict function")
# backends/python/server/text_embeddings_server/models/model.py
import torch

from abc import ABC, abstractmethod
from typing import List, TypeVar, Type

from text_embeddings_server.models.types import Batch, Embedding, Prediction, TokenEmbedding

B = TypeVar("B", bound=Batch)


class Model(ABC):
    def __init__(
        self,
        model,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.model = model
        self.dtype = dtype
        self.device = device

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        raise NotImplementedError

    @abstractmethod
    def embed(self, batch: B) -> List[Embedding]:
        raise NotImplementedError

    @abstractmethod
    def embed_all(self, batch: B) -> List[TokenEmbedding]:
        raise NotImplementedError

    @abstractmethod
    def predict(self, batch: B) -> List[Prediction]:
        raise NotImplementedError
# backends/python/server/text_embeddings_server/models/rerank_model.py
import torch
import mindietorch

from pathlib import Path
from typing import Type, List
from opentelemetry import trace
from loguru import logger

from text_embeddings_server.models import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding, Prediction, TokenEmbedding

tracer = trace.get_tracer(__name__)


class RerankModel(Model):
    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
        mindietorch.set_device(device.index)
        model = torch.jit.load(next(Path(model_path).rglob("*.pt"))).eval().to(device)

        super(RerankModel, self).__init__(model=model, dtype=dtype, device=device)

    @property
    def batch_type(self) -> Type[PaddedBatch]:
        return PaddedBatch

    @tracer.start_as_current_span("embed")
    def embed(self, batch: PaddedBatch) -> List[Embedding]:
        print("rerank model does not support embed function")

    @tracer.start_as_current_span("embed_all")
    def embed_all(self, batch: PaddedBatch) -> List[TokenEmbedding]:
        print("rerank model does not support embed_all function")

    @tracer.start_as_current_span("predict")
    def predict(self, batch: PaddedBatch) -> List[Prediction]:
        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
        scores = self.model(kwargs["input_ids"].to(self.device), kwargs["attention_mask"].to(self.device))[0].tolist()
        return [
            Prediction(
                values=scores[i]
            )
            for i in range(len(batch))
        ]
# backends/python/server/text_embeddings_server/models/types.py
import torch
import mindietorch

from abc import ABC, abstractmethod
from dataclasses import dataclass
from opentelemetry import trace

from text_embeddings_server.pb import embed_pb2
from text_embeddings_server.pb.embed_pb2 import Embedding, Prediction, TokenEmbedding

tracer = trace.get_tracer(__name__)


class Batch(ABC):
    @classmethod
    @abstractmethod
    def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device) -> "Batch":
        raise NotImplementedError

    @abstractmethod
    def __len__(self):
        raise NotImplementedError


@dataclass
class PaddedBatch(Batch):
    input_ids: torch.Tensor
    token_type_ids: torch.Tensor
    position_ids: torch.Tensor
    attention_mask: torch.Tensor
    max_length: int

    @classmethod
    @tracer.start_as_current_span("from_pb")
    def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device) -> "PaddedBatch":
        # Allocate padded tensors all at once
        all_tensors = torch.zeros(
            [4, len(pb.cu_seq_lengths) - 1, pb.max_length], dtype=torch.int32, device='cpu'
        )
        max_length = pb.max_length
        for i, start_index in enumerate(pb.cu_seq_lengths[:-1]):
            end_index = pb.cu_seq_lengths[i + 1]
            input_length = end_index - start_index

            all_tensors[0, i, :input_length] = torch.tensor(
                pb.input_ids[start_index:end_index], dtype=torch.int32
            )
            all_tensors[1, i, :input_length] = torch.tensor(
                pb.token_type_ids[start_index:end_index], dtype=torch.int32
            )
            all_tensors[2, i, :input_length] = torch.tensor(
                pb.position_ids[start_index:end_index], dtype=torch.int32
            )
            all_tensors[3, i, :input_length] = 1

        """
        # Move padded tensors all at once
        all_tensors = all_tensors.to(device)
        """

        return PaddedBatch(
            input_ids=all_tensors[0],
            token_type_ids=all_tensors[1],
            position_ids=all_tensors[2],
            attention_mask=all_tensors[3],
            max_length=max_length,
        )

    def __len__(self):
        return len(self.input_ids)


@dataclass
class FlashBatch(Batch):
    input_ids: torch.Tensor
    token_type_ids: torch.Tensor
    position_ids: torch.Tensor

    cu_seqlens: torch.Tensor
    max_s: int
    size: int

    @classmethod
    @tracer.start_as_current_span("from_pb")
    def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device) -> "FlashBatch":
        if device.type != "cuda":
            raise RuntimeError(f"FlashBatch does not support device {device}")
        batch_input_ids = torch.tensor(pb.input_ids, dtype=torch.int32, device=device)
        batch_token_type_ids = torch.tensor(
            pb.token_type_ids, dtype=torch.int32, device=device
        )
        batch_position_ids = torch.tensor(
            pb.position_ids, dtype=torch.int32, device=device
        )

        cu_seqlens = torch.tensor(pb.cu_seq_lengths, dtype=torch.int32, device=device)

        return FlashBatch(
            input_ids=batch_input_ids,
            token_type_ids=batch_token_type_ids,
            position_ids=batch_position_ids,
            cu_seqlens=cu_seqlens,
            max_s=pb.max_length,
            size=len(cu_seqlens) - 1,
        )

    def __len__(self):
        return self.size