The directory layout of the TEI code adapted for the Ascend environment with the MindIE Torch and ATB backends is identical to that of the upstream TEI GitHub repository. The code files that need to be modified are the ones listed in the directory tree below; any file or folder not shown remains unchanged from the TEI GitHub repository. The directory structure is as follows.
text-embeddings-inference
|____core
| |____src
| | |____infer.rs
|____backends
| |____grpc-client
| | |____src
| | | |____client.rs
| |____proto
| | |____embed.proto
| |____src
| | |____dtype.rs
| |____python
| | |____src
| | | |____lib.rs
| | |____server
| | | |____pyproject.toml
| | | |____requirements.txt
| | | |____text_embeddings_server
| | | | |____utils
| | | | | |____env.py
| | | | |____server.py
| | | | |____models
| | | | | |______init__.py
| | | | | |____default_model.py
| | | | | |____model.py
| | | | | |____rerank_model.py
| | | | | |____types.py
The following modification is optional: leaving it out does not affect TEI's operation or its serving functionality, but applying it brings a noticeable performance improvement.
use crate::queue::{Entry, Metadata, NextBatch, Queue}; use crate::tokenization::{EncodingInput, RawEncoding, Tokenization}; use crate::TextEmbeddingsError; use std::sync::Arc; use std::time::{Duration, Instant}; use text_embeddings_backend::{Backend, BackendError, Embedding, ModelType}; use tokio::sync::{mpsc, oneshot, watch, Notify, OwnedSemaphorePermit, Semaphore}; use tracing::instrument; /// Inference struct #[derive(Debug, Clone)] pub struct Infer { tokenization: Tokenization, queue: Queue, /// Shared notify notify_batching_task: Arc<Notify>, /// Inference limit limit_concurrent_requests: Arc<Semaphore>, backend: Backend, } impl Infer { #[allow(clippy::too_many_arguments)] pub fn new( tokenization: Tokenization, queue: Queue, max_concurrent_requests: usize, backend: Backend, ) -> Self { let notify_batching_task = Arc::new(Notify::new()); let (embed_sender, embed_receiver) = mpsc::unbounded_channel(); // Create only one batching task to prefetch batches tokio::spawn(batching_task( queue.clone(), notify_batching_task.clone(), embed_sender.clone(), )); /* 此处源码使用两个batching_task对请求进行队列管理,将此处注释后仅用一个batching_task性能更佳 tokio::spawn(batching_task( queue.clone(), notify_batching_task.clone(), embed_sender, )); */ // Create embed task to communicate with backend tokio::spawn(backend_task(backend.clone(), embed_receiver)); // Inference limit with a semaphore let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); Self { tokenization, queue, notify_batching_task, limit_concurrent_requests: semaphore, backend, } } #[instrument(skip(self))] pub async fn tokenize<I: Into<EncodingInput> + std::fmt::Debug>( &self, inputs: I, add_special_tokens: bool, ) -> Result<RawEncoding, TextEmbeddingsError> { self.tokenization .tokenize(inputs.into(), add_special_tokens) .await .map_err(|err| { metrics::increment_counter!("te_request_failure", "err" => "tokenization"); tracing::error!("{err}"); err }) } #[instrument(skip(self))] pub async fn decode( &self, ids: Vec<u32>, skip_special_tokens: bool, ) -> Result<String, TextEmbeddingsError> { self.tokenization .decode(ids, skip_special_tokens) .await .map_err(|err| { metrics::increment_counter!("te_request_failure", "err" => "tokenization"); tracing::error!("{err}"); err }) } #[instrument(skip(self))] pub fn try_acquire_permit(&self) -> Result<OwnedSemaphorePermit, TextEmbeddingsError> { // Limit concurrent requests by acquiring a permit from the semaphore self.clone() .limit_concurrent_requests .try_acquire_owned() .map_err(|err| { metrics::increment_counter!("te_request_failure", "err" => "overloaded"); tracing::error!("{err}"); TextEmbeddingsError::from(err) }) } #[instrument(skip(self))] pub async fn acquire_permit(&self) -> OwnedSemaphorePermit { // Limit concurrent requests by acquiring a permit from the semaphore self.clone() .limit_concurrent_requests .acquire_owned() .await .expect("Semaphore has been closed. 
This is a bug.") } #[instrument(skip(self, permit))] pub async fn embed_all<I: Into<EncodingInput> + std::fmt::Debug>( &self, inputs: I, truncate: bool, permit: OwnedSemaphorePermit, ) -> Result<AllEmbeddingsInferResponse, TextEmbeddingsError> { let start_time = Instant::now(); if self.is_splade() { metrics::increment_counter!("te_request_failure", "err" => "model_type"); let message = "`embed_all` is not available for SPLADE models".to_string(); tracing::error!("{message}"); return Err(TextEmbeddingsError::Backend(BackendError::Inference( message, ))); } let results = self .embed(inputs, truncate, false, &start_time, permit) .await?; let InferResult::AllEmbedding(response) = results else { panic!("unexpected enum variant") }; // Timings let total_time = start_time.elapsed(); // Metrics metrics::increment_counter!("te_embed_success"); metrics::histogram!("te_embed_duration", total_time.as_secs_f64()); metrics::histogram!( "te_embed_tokenization_duration", response.metadata.tokenization.as_secs_f64() ); metrics::histogram!( "te_embed_queue_duration", response.metadata.queue.as_secs_f64() ); metrics::histogram!( "te_embed_inference_duration", response.metadata.inference.as_secs_f64() ); Ok(response) } #[instrument(skip(self, permit))] pub async fn embed_sparse<I: Into<EncodingInput> + std::fmt::Debug>( &self, inputs: I, truncate: bool, permit: OwnedSemaphorePermit, ) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> { let start_time = Instant::now(); if !self.is_splade() { metrics::increment_counter!("te_request_failure", "err" => "model_type"); let message = "Model is not an embedding model with SPLADE pooling".to_string(); tracing::error!("{message}"); return Err(TextEmbeddingsError::Backend(BackendError::Inference( message, ))); } let results = self .embed(inputs, truncate, true, &start_time, permit) .await?; let InferResult::PooledEmbedding(response) = results else { panic!("unexpected enum variant") }; // Timings let total_time = start_time.elapsed(); // Metrics metrics::increment_counter!("te_embed_success"); metrics::histogram!("te_embed_duration", total_time.as_secs_f64()); metrics::histogram!( "te_embed_tokenization_duration", response.metadata.tokenization.as_secs_f64() ); metrics::histogram!( "te_embed_queue_duration", response.metadata.queue.as_secs_f64() ); metrics::histogram!( "te_embed_inference_duration", response.metadata.inference.as_secs_f64() ); Ok(response) } #[instrument(skip(self, permit))] pub async fn embed_pooled<I: Into<EncodingInput> + std::fmt::Debug>( &self, inputs: I, truncate: bool, normalize: bool, permit: OwnedSemaphorePermit, ) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> { let start_time = Instant::now(); if self.is_splade() && normalize { metrics::increment_counter!("te_request_failure", "err" => "model_type"); let message = "`normalize` is not available for SPLADE models".to_string(); tracing::error!("{message}"); return Err(TextEmbeddingsError::Backend(BackendError::Inference( message, ))); } let results = self .embed(inputs, truncate, true, &start_time, permit) .await?; let InferResult::PooledEmbedding(mut response) = results else { panic!("unexpected enum variant") }; if normalize { // Normalize embedding let scale = (1.0 / response .results .iter() .map(|v| { let v = *v as f64; v * v }) .sum::<f64>() .sqrt()) as f32; for v in response.results.iter_mut() { *v *= scale; } } // Timings let total_time = start_time.elapsed(); // Metrics metrics::increment_counter!("te_embed_success"); metrics::histogram!("te_embed_duration", 
total_time.as_secs_f64()); metrics::histogram!( "te_embed_tokenization_duration", response.metadata.tokenization.as_secs_f64() ); metrics::histogram!( "te_embed_queue_duration", response.metadata.queue.as_secs_f64() ); metrics::histogram!( "te_embed_inference_duration", response.metadata.inference.as_secs_f64() ); Ok(response) } async fn embed<I: Into<EncodingInput> + std::fmt::Debug>( &self, inputs: I, truncate: bool, pooling: bool, start_time: &Instant, _permit: OwnedSemaphorePermit, ) -> Result<InferResult, TextEmbeddingsError> { if self.is_classifier() { metrics::increment_counter!("te_request_failure", "err" => "model_type"); let message = "Model is not an embedding model".to_string(); tracing::error!("{message}"); return Err(TextEmbeddingsError::Backend(BackendError::Inference( message, ))); } metrics::increment_counter!("te_embed_count"); // Tokenization let encoding = self .tokenization .encode(inputs.into(), truncate) .await .map_err(|err| { metrics::increment_counter!("te_request_failure", "err" => "tokenization"); tracing::error!("{err}"); err })?; // MPSC channel to communicate with the background batching task let (response_tx, response_rx) = oneshot::channel(); // Append the request to the queue self.queue.append(Entry { metadata: Metadata { response_tx, tokenization: start_time.elapsed(), queue_time: Instant::now(), prompt_tokens: encoding.input_ids.len(), pooling, }, encoding, }); self.notify_batching_task.notify_one(); let response = response_rx .await .expect( "Infer batching task dropped the sender without sending a response. This is a bug.", ) .map_err(|err| { metrics::increment_counter!("te_request_failure", "err" => "inference"); tracing::error!("{err}"); err })?; Ok(response) } #[instrument(skip(self, _permit))] pub async fn predict<I: Into<EncodingInput> + std::fmt::Debug>( &self, inputs: I, truncate: bool, raw_scores: bool, _permit: OwnedSemaphorePermit, ) -> Result<ClassificationInferResponse, TextEmbeddingsError> { if !self.is_classifier() { metrics::increment_counter!("te_request_failure", "err" => "model_type"); let message = "Model is not a classifier model".to_string(); return Err(TextEmbeddingsError::Backend(BackendError::Inference( message, ))); } let start_time = Instant::now(); metrics::increment_counter!("te_predict_count"); // Tokenization let encoding = self .tokenization .encode(inputs.into(), truncate) .await .map_err(|err| { metrics::increment_counter!("te_request_failure", "err" => "tokenization"); tracing::error!("{err}"); err })?; // MPSC channel to communicate with the background batching task let (response_tx, response_rx) = oneshot::channel(); // Append the request to the queue self.queue.append(Entry { metadata: Metadata { response_tx, tokenization: start_time.elapsed(), queue_time: Instant::now(), prompt_tokens: encoding.input_ids.len(), pooling: true, }, encoding, }); self.notify_batching_task.notify_one(); let response = response_rx .await .expect( "Infer batching task dropped the sender without sending a response. 
This is a bug.", ) .map_err(|err| { metrics::increment_counter!("te_request_failure", "err" => "inference"); tracing::error!("{err}"); err })?; let InferResult::Classification(mut response) = response else { panic!("unexpected enum variant") }; if !raw_scores { // Softmax if response.results.len() > 1 { let max = *response .results .iter() .max_by(|x, y| x.abs().partial_cmp(&y.abs()).unwrap()) .unwrap(); let mut den = 0.0; for v in response.results.iter_mut() { *v = (*v - max).exp(); den += *v; } for v in response.results.iter_mut() { *v /= den; } } // Sigmoid else { response.results[0] = 1.0 / (1.0 + (-response.results[0]).exp()); } } // Timings let total_time = start_time.elapsed(); // Metrics metrics::increment_counter!("te_predict_success"); metrics::histogram!("te_predict_duration", total_time.as_secs_f64()); metrics::histogram!( "te_predict_tokenization_duration", response.metadata.tokenization.as_secs_f64() ); metrics::histogram!( "te_predict_queue_duration", response.metadata.queue.as_secs_f64() ); metrics::histogram!( "te_predict_inference_duration", response.metadata.inference.as_secs_f64() ); Ok(response) } #[instrument(skip(self))] pub fn is_classifier(&self) -> bool { matches!(self.backend.model_type, ModelType::Classifier) } #[instrument(skip(self))] pub fn is_splade(&self) -> bool { matches!( self.backend.model_type, ModelType::Embedding(text_embeddings_backend::Pool::Splade) ) } #[instrument(skip(self))] pub async fn health(&self) -> bool { self.backend.health().await.is_ok() } #[instrument(skip(self))] pub fn health_watcher(&self) -> watch::Receiver<bool> { self.backend.health_watcher() } } #[instrument(skip_all)] async fn batching_task( queue: Queue, notify: Arc<Notify>, embed_sender: mpsc::UnboundedSender<(NextBatch, oneshot::Sender<()>)>, ) { loop { notify.notified().await; while let Some(next_batch) = queue.next_batch().await { let (callback_sender, callback_receiver) = oneshot::channel(); embed_sender .send((next_batch, callback_sender)) .expect("embed receiver was dropped. This is a bug."); let _ = callback_receiver.await; } } } #[instrument(skip_all)] async fn backend_task( backend: Backend, mut embed_receiver: mpsc::UnboundedReceiver<(NextBatch, oneshot::Sender<()>)>, ) { while let Some((batch, _callback)) = embed_receiver.recv().await { match &backend.model_type { ModelType::Classifier => { let results = backend.predict(batch.1).await; // Handle sending responses in another thread to avoid starving the backend std::thread::spawn(move || match results { Ok((mut predictions, inference_duration)) => { batch.0.into_iter().enumerate().for_each(|(i, m)| { let infer_metadata = InferMetadata { prompt_tokens: m.prompt_tokens, tokenization: m.tokenization, queue: m.queue_time.elapsed() - inference_duration, inference: inference_duration, }; let _ = m.response_tx.send(Ok(InferResult::Classification( ClassificationInferResponse { results: predictions.remove(&i).expect( "prediction not found in results. 
This is a backend bug.", ), metadata: infer_metadata, }, ))); }); } Err(err) => { batch.0.into_iter().for_each(|m| { let _ = m.response_tx.send(Err(err.clone())); }); } }); } ModelType::Embedding(_) => { let results = backend.embed(batch.1).await; // Handle sending responses in another thread to avoid starving the backend std::thread::spawn(move || match results { Ok((mut embeddings, inference_duration)) => { batch.0.into_iter().enumerate().for_each(|(i, m)| { let metadata = InferMetadata { prompt_tokens: m.prompt_tokens, tokenization: m.tokenization, queue: m.queue_time.elapsed() - inference_duration, inference: inference_duration, }; let results = match embeddings .remove(&i) .expect("embedding not found in results. This is a backend bug.") { Embedding::Pooled(e) => { InferResult::PooledEmbedding(PooledEmbeddingsInferResponse { results: e, metadata, }) } Embedding::All(e) => { InferResult::AllEmbedding(AllEmbeddingsInferResponse { results: e, metadata, }) } }; let _ = m.response_tx.send(Ok(results)); }) } Err(err) => { batch.0.into_iter().for_each(|m| { let _ = m.response_tx.send(Err(err.clone())); }); } }); } }; } } #[derive(Debug)] pub struct InferMetadata { pub prompt_tokens: usize, pub tokenization: Duration, pub queue: Duration, pub inference: Duration, } #[derive(Debug)] pub(crate) enum InferResult { Classification(ClassificationInferResponse), PooledEmbedding(PooledEmbeddingsInferResponse), AllEmbedding(AllEmbeddingsInferResponse), } #[derive(Debug)] pub struct ClassificationInferResponse { pub results: Vec<f32>, pub metadata: InferMetadata, } #[derive(Debug)] pub struct PooledEmbeddingsInferResponse { pub results: Vec<f32>, pub metadata: InferMetadata, } #[derive(Debug)] pub struct AllEmbeddingsInferResponse { pub results: Vec<Vec<f32>>, pub metadata: InferMetadata, }
/// Single shard Client use crate::pb::embedding::v1::embedding_service_client::EmbeddingServiceClient; use crate::pb::embedding::v1::*; use crate::Result; use grpc_metadata::InjectTelemetryContext; use tonic::transport::{Channel, Uri}; use tracing::instrument; /// Text Generation Inference gRPC client #[derive(Debug, Clone)] pub struct Client { stub: EmbeddingServiceClient<Channel>, } impl Client { /// Returns a client connected to the given url pub async fn connect(uri: Uri) -> Result<Self> { let channel = Channel::builder(uri).connect().await?; Ok(Self { stub: EmbeddingServiceClient::new(channel), }) } /// Returns a client connected to the given unix socket pub async fn connect_uds(path: String) -> Result<Self> { let channel = Channel::from_shared("http://[::]:50051".to_string()) .unwrap() .connect_with_connector(tower::service_fn(move |_: Uri| { tokio::net::UnixStream::connect(path.clone()) })) .await?; Ok(Self { stub: EmbeddingServiceClient::new(channel), }) } /// Get backend health #[instrument(skip(self))] pub async fn health(&mut self) -> Result<HealthResponse> { let request = tonic::Request::new(HealthRequest {}).inject_context(); let response = self.stub.health(request).await?.into_inner(); Ok(response) } #[instrument(skip_all)] pub async fn embed( &mut self, input_ids: Vec<u32>, token_type_ids: Vec<u32>, position_ids: Vec<u32>, cu_seq_lengths: Vec<u32>, max_length: u32, ) -> Result<Vec<Embedding>> { let request = tonic::Request::new(EmbedRequest { input_ids, token_type_ids, position_ids, max_length, cu_seq_lengths, }) .inject_context(); let response = self.stub.embed(request).await?.into_inner(); Ok(response.embeddings) } #[instrument(skip_all)] pub async fn embed_all( &mut self, input_ids: Vec<u32>, token_type_ids: Vec<u32>, position_ids: Vec<u32>, cu_seq_lengths: Vec<u32>, max_length: u32, ) -> Result<Vec<TokenEmbedding>> { let request = tonic::Request::new(EmbedRequest { input_ids, token_type_ids, position_ids, max_length, cu_seq_lengths, }) .inject_context(); let response = self.stub.embed_all(request).await?.into_inner(); Ok(response.allembeddings) } #[instrument(skip_all)] pub async fn predict( &mut self, input_ids: Vec<u32>, token_type_ids: Vec<u32>, position_ids: Vec<u32>, cu_seq_lengths: Vec<u32>, max_length: u32, ) -> Result<Vec<Prediction>> { let request = tonic::Request::new(PredictRequest { input_ids, token_type_ids, position_ids, max_length, cu_seq_lengths, }) .inject_context(); let response = self.stub.predict(request).await?.into_inner(); Ok(response.predictions) } }
syntax = "proto3"; package embedding.v1; service EmbeddingService { /// Decode token for a list of prefilled batches rpc Embed (EmbedRequest) returns (EmbedResponse); rpc Embed_all (EmbedRequest) returns (RawEmbedResponse); rpc Predict (PredictRequest) returns (PredictResponse); /// Health check rpc Health (HealthRequest) returns (HealthResponse); } message HealthRequest {} message HealthResponse {} message PredictRequest { repeated uint32 input_ids = 1; repeated uint32 token_type_ids = 2; repeated uint32 position_ids = 3; repeated uint32 cu_seq_lengths = 4; /// Length of the longest request uint32 max_length = 5; } message Prediction { repeated float values = 1; } message PredictResponse { repeated Prediction predictions = 1; } message EmbedRequest { repeated uint32 input_ids = 1; repeated uint32 token_type_ids = 2; repeated uint32 position_ids = 3; repeated uint32 cu_seq_lengths = 4; /// Length of the longest request uint32 max_length = 5; } message Embedding { repeated float values = 1; } message EmbedResponse { repeated Embedding embeddings = 1; } message TokenEmbedding { repeated Embedding embeddings = 1; } message RawEmbedResponse { repeated TokenEmbedding allembeddings = 1; }
use std::fmt; #[cfg(feature = "clap")] use clap::ValueEnum; #[derive(Debug, PartialEq)] #[cfg_attr(feature = "clap", derive(Clone, ValueEnum))] pub enum DType { // Float16 is not available on accelerate #[cfg(any( feature = "python", all(feature = "candle", not(feature = "accelerate")) ))] Float16, // Bfloat16 is available only when the "python" feature is enabled #[cfg(feature = "python")] Bfloat16, // Float32 is not available on candle cuda #[cfg(any(feature = "python", feature = "candle"))] Float32, // #[cfg(feature = "candle")] // Q6K, } impl fmt::Display for DType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { // Float16 is not available on accelerate #[cfg(any( feature = "python", all(feature = "candle", not(feature = "accelerate")) ))] DType::Float16 => write!(f, "float16"), // Bfloat16 is only available with the "python" feature #[cfg(feature = "python")] DType::Bfloat16 => write!(f, "bfloat16"), // Float32 is not available on candle cuda #[cfg(any(feature = "python", feature = "candle"))] DType::Float32 => write!(f, "float32"), // #[cfg(feature = "candle")] // DType::Q6K => write!(f, "q6k"), } } }
mod logging; mod management; use backend_grpc_client::Client; use nohash_hasher::BuildNoHashHasher; use std::collections::HashMap; use text_embeddings_backend_core::{ Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions, }; use tokio::runtime::Runtime; pub struct PythonBackend { _backend_process: management::BackendProcess, tokio_runtime: Runtime, backend_client: Client, } impl PythonBackend { pub fn new( model_path: String, dtype: String, model_type: ModelType, uds_path: String, otlp_endpoint: Option<String>, ) -> Result<Self, BackendError> { match model_type { ModelType::Classifier => { let pool = Pool::Cls; pool } ModelType::Embedding(pool) => { if pool != Pool::Cls { return Err(BackendError::Start(format!("{pool:?} is not supported"))); } pool } }; let backend_process = management::BackendProcess::new(model_path, dtype, &uds_path, otlp_endpoint)?; let tokio_runtime = tokio::runtime::Builder::new_current_thread() .enable_all() .build() .map_err(|err| BackendError::Start(format!("Could not start Tokio runtime: {err}")))?; let backend_client = tokio_runtime .block_on(Client::connect_uds(uds_path)) .map_err(|err| { BackendError::Start(format!("Could not connect to backend process: {err}")) })?; Ok(Self { _backend_process: backend_process, tokio_runtime, backend_client, }) } } impl Backend for PythonBackend { fn health(&self) -> Result<(), BackendError> { if self .tokio_runtime .block_on(self.backend_client.clone().health()) .is_err() { return Err(BackendError::Unhealthy); } Ok(()) } fn is_padded(&self) -> bool { false } fn embed(&self, batch: Batch) -> Result<Embeddings, BackendError> { /* if !batch.raw_indices.is_empty() { return Err(BackendError::Inference( "raw embeddings are not supported for the Python backend.".to_string(), )); } */ let batch_size = batch.len(); let mut embeddings = HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); if !batch.pooled_indices.is_empty() { let results = self .tokio_runtime .block_on(self.backend_client.clone().embed( batch.input_ids, batch.token_type_ids, batch.position_ids, batch.cumulative_seq_lengths, batch.max_length, )) .map_err(|err| BackendError::Inference(err.to_string()))?; let pooled_embeddings: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect(); for (i, e) in pooled_embeddings.into_iter().enumerate() { embeddings.insert(i, Embedding::Pooled(e)); } } else if !batch.raw_indices.is_empty() { let results = self .tokio_runtime .block_on(self.backend_client.clone().embed_all( batch.input_ids, batch.token_type_ids, batch.position_ids, batch.cumulative_seq_lengths, batch.max_length, )) .map_err(|err| BackendError::Inference(err.to_string()))?; let mut raw_embeddings = Vec::new(); for token_embedding in results { let mut two_dim_list = Vec::new(); for embeddings in token_embedding.embeddings { let values = embeddings.values.clone(); two_dim_list.push(values); } raw_embeddings.push(two_dim_list); } for (i, e) in raw_embeddings.into_iter().enumerate() { embeddings.insert(i, Embedding::All(e)); } } Ok(embeddings) } fn predict(&self, batch: Batch) -> Result<Predictions, BackendError> { let batch_size = batch.len(); let results = self .tokio_runtime .block_on(self.backend_client.clone().predict( batch.input_ids, batch.token_type_ids, batch.position_ids, batch.cumulative_seq_lengths, batch.max_length, )) .map_err(|err| BackendError::Inference(err.to_string()))?; let predictions_result: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect(); let mut predictions = 
HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); for (i, r) in predictions_result.into_iter().enumerate() { predictions.insert(i, r); } Ok(predictions) } }
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13" fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.23.2 ; python_version >= "3.9" and python_version < "3.13" idna==3.4 ; python_version >= "3.9" and python_version < "3.13" jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13" mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13" networkx==3.1 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" packaging==23.1 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13" sympy==1.12 ; python_version >= "3.9" and python_version < "3.13" torch==2.1.0 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13" win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13" poetry==1.8.3 ; python_version >= "3.9" and python_version < "3.13"
[tool.poetry] name = "text-embeddings-server" version = "0.1.0" description = "Text Embeddings Python gRPC Server" authors = ["Olivier Dehaene <olivier@huggingface.co>"] [tool.poetry.scripts] python-text-embeddings-server = 'text_embeddings_server.cli:app' [tool.poetry.dependencies] python = ">=3.9,<3.13" protobuf = "^4.21.7" grpcio = "^1.51.1" grpcio-status = "^1.51.1" grpcio-reflection = "^1.51.1" grpc-interceptor = "^0.15.0" typer = "^0.6.1" safetensors = "^0.4.3" loguru = "^0.6.0" opentelemetry-api = "^1.15.0" opentelemetry-exporter-otlp = "^1.15.0" opentelemetry-instrumentation-grpc = "^0.36b0" torch = { version = "^2.1.0" } [tool.poetry.extras] [tool.poetry.group.dev.dependencies] grpcio-tools = "^1.51.1" pytest = "^7.3.0" [[tool.poetry.source]] name = "mirrors" url = "https://pypi.tuna.tsinghua.edu.cn/simple/" priority = "default" [tool.pytest.ini_options] markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"] [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api"
The Python-side reference adaptation code for TEI adjusts the package import order and adds Google-style class and method docstrings. To keep the focus on the functional changes, these modifications are not highlighted in bold in the reference code below.
# Copyright Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
import os
from dataclasses import dataclass

from loguru import logger


@dataclass
class EnvVar:
    """
    Reads environment variables to configure the embedding or reranker model backend and
    device ID for use in other components of TEI-MindIE.

    Attributes:
        backend: The backend model type to be used for inference (e.g., 'atb' or 'mindietorch').
        device_id: The device id to load model on (between '0' and '7').

    Raises:
        ValueError: If the backend does not belong to ['atb', 'mindietorch'] or the device_id is invalid.
    """
    backend: str = os.getenv("TEI_NPU_BACKEND", 'mindietorch')
    device_id: str = os.getenv("TEI_NPU_DEVICE", '0')

    def __post_init__(self):
        logger.info(self.dict())
        if self.backend not in ['atb', 'mindietorch']:
            raise ValueError("Your model backend is invalid.")
        if not self.device_id.isdigit():
            raise ValueError("Your device_id is invalid.")
        if int(self.device_id) < 0:
            raise ValueError("Your device_id is invalid.")
        if int(self.device_id) >= 8:
            logger.warning(f'Your device_id is {self.device_id}.')

    def dict(self):
        return self.__dict__


ENV = EnvVar()
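As a usage sketch (the values shown are examples rather than defaults), the two environment variables read by EnvVar select the backend and the NPU device; they must be set before any TEI-MindIE Python module is imported:

import os

# Example values: pick the ATB backend on NPU device 0.
os.environ["TEI_NPU_BACKEND"] = "atb"  # 'atb' or 'mindietorch'
os.environ["TEI_NPU_DEVICE"] = "0"     # '0' .. '7'

from text_embeddings_server.utils.env import ENV

print(ENV.backend, ENV.device_id)  # -> atb 0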
import asyncio from pathlib import Path from typing import Optional from grpc import aio from loguru import logger from grpc_reflection.v1alpha import reflection import torch from text_embeddings_server.models import Model, get_model from text_embeddings_server.pb import embed_pb2_grpc, embed_pb2 from text_embeddings_server.utils.tracing import UDSOpenTelemetryAioServerInterceptor from text_embeddings_server.utils.interceptor import ExceptionInterceptor from .utils.env import ENV if ENV.backend == 'atb': import torch_npu else: import mindietorch class EmbeddingService(embed_pb2_grpc.EmbeddingServiceServicer): """Handles gRPC requests for the text embeddings service. This class implements the EmbeddingServiceServicer interface, providing methods for embedding, prediction and health check functionalities. Attributes: model (Model): The model used for generating embeddings and similarity scores. _inference_mode_raii_guard (torch._C._InferenceMode): A context manager to enforce inference mode. """ def __init__(self, model: Model): """Initializes the EmbeddingService with a model. Args: model (Model): The model to use for embedding or reranker service. """ self.model = model # Force inference mode for the lifetime of EmbeddingService self._inference_mode_raii_guard = torch._C._InferenceMode(True) async def Health(self, request, context): """Handles the health check request. Verifies if the model is able to run on the provided device. Args: request: The gRPC health check request. context: The gRPC context for the request. Returns: embed_pb2.HealthResponse: A response indicating the service health status. """ if self.model.device.type == "npu": health_check_tensor = torch.zeros((2, 2)) health_check_tensor.to(self.model.device) return embed_pb2.HealthResponse() async def Embed(self, request, context): """Handles the embed request. Processes the incoming request, generates embeddings using the provided model, and returns the hidden state of the first token for each sequence in the batch. Args: request: The gRPC request containing a batch of data for embedding. context: The gRPC context for the request. Returns: embed_pb2.EmbedResponse: A response containing generated the sliced embeddings for the request. """ batch = self.model.batch_type.from_pb(request, self.model.device) embeddings = self.model.embed(batch) return embed_pb2.EmbedResponse(embeddings=embeddings) async def Embed_all(self, request, context): """Handles the embed_all request. Processes the incoming request, generates embeddings using the provided model, and returns hidden states of all tokens for each sequence in the batch. Args: request: The gRPC request containing a batch of data for embedding. context: The gRPC context for the request. Returns: embed_pb2.RawEmbedResponse: A response containing all the embeddings for the request. """ batch = self.model.batch_type.from_pb(request, self.model.device) embeddings = self.model.embed_all(batch) return embed_pb2.RawEmbedResponse(allembeddings=embeddings) async def Predict(self, request, context): """Handles the predict request. Processes the incoming request, generates predictions using the provided model, and returns final scores. Args: request: The gRPC request containing a batch of data for scores prediction. context: The gRPC context for the request. Returns: embed_pb2.PredictResponse: A response containing the model's predictions for the request. 
""" batch = self.model.batch_type.from_pb(request, self.model.device) predictions = self.model.predict(batch) return embed_pb2.PredictResponse(predictions=predictions) def serve( model_path: Path, dtype: Optional[str], uds_path: Path, ): """Starts the gRPC server and serves the text embedding service. This function initializes the model and starts a server that listens for incoming requests to generate embeddings, predictions or check health. Args: model_path: Path to the model directory. dtype: Data type for model initialization. uds_path: Path to the Unix Domain Socket for the server to listen on. """ async def serve_inner( model_path: Path, dtype: Optional[str] = None, ): """The inner asynchronous function to run the gRPC server. Args: model_path: Path to the model directory. dtype: Data type for model initialization. Raises: Exception: If there is an error during model initialization. """ unix_socket = f"unix://{uds_path}" try: model = get_model(model_path, dtype) except Exception: logger.exception("Error when initializing model") raise server = aio.server( interceptors=[ ExceptionInterceptor(), UDSOpenTelemetryAioServerInterceptor(), ] ) embed_pb2_grpc.add_EmbeddingServiceServicer_to_server( EmbeddingService(model), server ) SERVICE_NAMES = ( embed_pb2.DESCRIPTOR.services_by_name["EmbeddingService"].full_name, reflection.SERVICE_NAME, ) reflection.enable_server_reflection(SERVICE_NAMES, server) server.add_insecure_port(unix_socket) await server.start() logger.info(f"Server started at {unix_socket}") try: await server.wait_for_termination() except KeyboardInterrupt: logger.info("Signal received. Shutting down") await server.stop(0) asyncio.run(serve_inner(model_path, dtype))
from pathlib import Path from typing import Optional from loguru import logger from transformers import AutoConfig import torch from text_embeddings_server.models.model import Model from text_embeddings_server.models.default_model import DefaultModel from text_embeddings_server.models.rerank_model import RerankModel from ..utils.env import ENV if ENV.backend == 'atb': import torch_npu else: import mindietorch __all__ = ["Model"] # Disable gradients torch.set_grad_enabled(False) FLASH_ATTENTION = True try: from text_embeddings_server.models.flash_bert import FlashBert except ImportError as e: logger.warning(f"Could not import Flash Attention enabled models: {e}") FLASH_ATTENTION = False if FLASH_ATTENTION: __all__.append(FlashBert) def get_model(model_path: Path, dtype: Optional[str]): """Loads and returns the appropriate model class based on given path, data type and backend type. Args: model_path: The path to the pre-trained model. dtype: The data type of the model. If not specified or invalid, a RuntimeError will be raised. Returns: Model: An instance of a subclass of the `Model` class, which could either be: - RerankModel (if the architectures[0] in config.json ends with 'Classification') - FlashBert (if the model is based on BERT and CUDA is available) - DefaultModel (if all conditions above are not met) Raises: RuntimeError: If an unknown data type is provided for `dtype`. ValueError: If the device is CPU and the dtype is not `float32`. """ if dtype == "float32": dtype = torch.float32 elif dtype == "float16": dtype = torch.float16 elif dtype == "bfloat16": dtype = torch.bfloat16 else: raise RuntimeError(f"Unknown dtype {dtype}") if ENV.device_id: if ENV.backend == 'atb': torch.npu.set_compile_mode(jit_compile=False) option = {"NPU_FUZZY_COMPILE_BLACKLIST": "ReduceProd"} torch.npu.set_option(option) device = torch.device(f"npu:{int(ENV.device_id)}") torch.npu.set_device(torch.device(f"npu:{int(ENV.device_id)}")) else: mindietorch.set_device(int(ENV.device_id)) device = torch.device(f"npu:{int(ENV.device_id)}") elif torch.cuda.is_available(): device = torch.device("cuda") else: if dtype != torch.float32: raise ValueError("CPU device only supports float32 dtype") device = torch.device("cpu") config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) if config.architectures[0].endswith("Classification"): return RerankModel(model_path, device, dtype) else: if ( config.model_type == "bert" and device.type == "cuda" and config.position_embedding_type == "absolute" and dtype in [torch.float16, torch.bfloat16] and FLASH_ATTENTION ): return FlashBert(model_path, device, dtype) else: return DefaultModel(model_path, device, dtype) raise NotImplementedError
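A minimal usage sketch of get_model follows. The model path is a placeholder: for the ATB backend it must contain Hugging Face weights, for the MindIE Torch backend a single compiled .pt file; running it also assumes an Ascend NPU environment with the selected backend installed.

from pathlib import Path

from text_embeddings_server.models import get_model

# Placeholder path; dtype must be one of "float32", "float16" or "bfloat16".
model = get_model(Path("/path/to/model"), "float16")
print(type(model).__name__)  # DefaultModel or RerankModel, depending on config.json
print(model.batch_type)      # PaddedBatch for both model classes on NPU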
from pathlib import Path
from typing import Optional, Type, List

from opentelemetry import trace
from transformers import AutoModel, AutoConfig
from loguru import logger
import torch

from text_embeddings_server.models import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding, Prediction, TokenEmbedding
from ..utils.env import ENV

if ENV.backend == 'atb':
    import torch_npu
else:
    import mindietorch

tracer = trace.get_tracer(__name__)


class DefaultModel(Model):
    """DefaultModel is a wrapper around a pre-trained embedding model that generates embeddings
    for input texts. It extends the Model class and adds functionality to provide embedding
    generation according to the batch of input data.

    Attributes:
        config: The model configuration of the given pre-trained model.
        hidden_size: The size of the model's hidden states.
        is_causal: A flag indicating whether the embedding model is a causal model whose
            `architecture` ends with "CausalLM". If True, the embed function returns the embedding
            of the last token; otherwise, it returns the embedding of the first token.
    """

    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
        """
        Initializes the DefaultModel by loading the pre-trained model based on the given model path
        and backend type, while also setting necessary member variables for embedding generation.

        Args:
            model_path: Path to the pre-trained model.
            device: The device to load the model on (e.g., NPU).
            dtype: The data type of the model's parameters (e.g., torch.float16).
        """
        self.config = AutoConfig.from_pretrained(str(model_path), trust_remote_code=True)
        self.hidden_size = self.config.hidden_size
        self.is_causal = self.config.architectures[0].endswith('CausalLM')
        if ENV.backend == 'atb':
            model = AutoModel.from_pretrained(
                model_path, torch_dtype=dtype, trust_remote_code=True
            ).to(device).eval()
        else:
            mindietorch.set_device(device.index)
            # When using the MindIE Torch backend, make sure that model_path (the model weight
            # directory) contains exactly one compiled and optimized .pt file; otherwise adjust
            # the file-matching pattern and the .pt file naming to your actual setup.
            model = torch.jit.load(next(Path(model_path).rglob("*.pt"))).eval().to(device)
        super(DefaultModel, self).__init__(model=model, dtype=dtype, device=device)

    @property
    def batch_type(self) -> Type[PaddedBatch]:
        """Returns the class type of the batch that the embedding model expects.

        Returns:
            Type[PaddedBatch]: The type of the batch class expected by the embedding model.
                All tensors in this batch are padded according to the longest input sequence.
        """
        return PaddedBatch

    @tracer.start_as_current_span("embed")
    def embed(self, batch: PaddedBatch) -> List[Embedding]:
        """
        Generates embeddings for a given batch of input data and returns the hidden state of the
        first or last token for each input sequence in the batch.

        Args:
            batch: The input batch containing tokenized data (e.g., input_ids, attention_mask).

        Returns:
            List[Embedding]: A list of embeddings, where each embedding represents the hidden state
                of the first or last token for each input sequence in the batch.
        """
        kwargs = {"input_ids": batch.input_ids.to(self.device),
                  "attention_mask": batch.attention_mask.to(self.device)}
        if ENV.backend == 'atb':
            output = self.model(**kwargs, return_dict=True)
            embedding = output[0]
        else:
            output = self.model(kwargs["input_ids"], kwargs["attention_mask"])
            if isinstance(output, dict):
                embedding = output['last_hidden_state'].to('cpu')
            else:
                embedding = output[0].to('cpu')
        if self.is_causal:
            # For causal models, get the embedding of the last token
            embedding = embedding[:, -1]
        else:
            # For non-causal models, get the embedding of the first token
            embedding = embedding[:, 0]
        cpu_results = embedding.contiguous().view(-1).tolist()

        return [
            Embedding(
                values=cpu_results[i * self.hidden_size: (i + 1) * self.hidden_size]
            )
            for i in range(len(batch))
        ]

    @tracer.start_as_current_span("embed_all")
    def embed_all(self, batch: PaddedBatch) -> List[TokenEmbedding]:
        """
        Generates embeddings for a given batch of input data and returns hidden states of all
        tokens for each input sequence in the batch.

        Args:
            batch: The input batch containing tokenized data (e.g., input_ids, attention_mask).

        Returns:
            List[TokenEmbedding]: A list of token embeddings, where each token embedding represents
                hidden states of all tokens for each input sequence in the batch.
        """
        kwargs = {"input_ids": batch.input_ids.to(self.device),
                  "attention_mask": batch.attention_mask.to(self.device)}
        if ENV.backend == 'atb':
            output = self.model(**kwargs, return_dict=True)
            embedding = output[0]
        else:
            output = self.model(kwargs["input_ids"], kwargs["attention_mask"])
            if isinstance(output, dict):
                embedding = output['last_hidden_state'].to('cpu')
            else:
                embedding = output[0].to('cpu')
        cpu_results = embedding.contiguous().view(-1).tolist()

        embedding_result = []
        for i in range(len(batch)):
            base_index = i * batch.max_length * self.hidden_size
            tmp_embedding = [
                Embedding(values=cpu_results[
                    base_index + j * self.hidden_size: base_index + (j + 1) * self.hidden_size
                ])
                for j in range(batch.input_ids.size(1))
            ]
            token_embeddings = TokenEmbedding(embeddings=tmp_embedding)
            embedding_result.append(token_embeddings)
        return embedding_result

    @tracer.start_as_current_span("predict")
    def predict(self, batch: PaddedBatch) -> Optional[List[Prediction]]:
        """Logs an error indicating that the embedding model does not support the predict function."""
        logger.error("embedding model does not support predict function")
from abc import ABC, abstractmethod
from typing import List, TypeVar, Type

import torch
from loguru import logger

from text_embeddings_server.models.types import Batch, Embedding, Prediction, TokenEmbedding
from ..utils.env import ENV

if ENV.backend == 'atb':
    import torch_npu
    from atb_llm.utils.initial import NPUSocInfo

B = TypeVar("B", bound=Batch)


class Model(ABC):
    """Abstract base class for all model types in TEI-MindIE.

    This class defines the common interface and properties for all models used in TEI-MindIE,
    including methods for embedding generation and similarity score prediction.

    Attributes:
        model: The model object used for embedding generation or similarity score prediction.
        device: The device to load the model on (e.g., NPU).
        dtype: The data type of the model's parameters (e.g., torch.float16).
    """

    def __init__(
        self,
        model,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """
        Initializes the Model class based on the given model object, device and dtype.
        For the ATB backend, this function also executes npu_format_cast on the model's
        named_modules according to the NPU's SoC information.

        Args:
            model: The model object used for embedding generation or similarity score prediction.
            device: The device to load the model on (e.g., NPU).
            dtype: The data type of the model's parameters (e.g., torch.float16).
        """
        self.model = model
        self.dtype = dtype
        self.device = device

        if ENV.backend == 'atb':
            def trans_data(model, soc_info):
                if not soc_info.need_nz:
                    for _, module in model.named_modules():
                        if isinstance(module, torch.nn.Linear):
                            module.weight.data = torch_npu.npu_format_cast(module.weight.data, 2)
                    logger.info(f"soc info: {soc_info.soc_version}, support ND")
                else:
                    for name, module in model.named_modules():
                        if isinstance(module, torch.nn.Linear):
                            if name == 'lm_head':
                                # eliminate TransData op before lm_head calculation
                                module.weight.data = torch.nn.parameter.Parameter(module.weight.data)
                            module.weight.data = torch_npu.npu_format_cast(module.weight.data, 29)
                    logger.info(f"soc info: {soc_info.soc_version}, support NZ")
                    for _, module in model.named_modules():
                        if isinstance(module, torch.nn.Embedding):
                            module.weight.data = torch_npu.npu_format_cast(module.weight.data, 2)

            self.soc_info = NPUSocInfo()
            trans_data(self.model, self.soc_info)

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        raise NotImplementedError

    @abstractmethod
    def embed(self, batch: B) -> List[Embedding]:
        raise NotImplementedError

    @abstractmethod
    def embed_all(self, batch: B) -> List[TokenEmbedding]:
        raise NotImplementedError

    @abstractmethod
    def predict(self, batch: B) -> List[Prediction]:
        raise NotImplementedError
from pathlib import Path
from typing import Optional, Type, List

from opentelemetry import trace
from transformers import AutoModelForSequenceClassification
from loguru import logger
import torch

from text_embeddings_server.models import Model
from text_embeddings_server.models.types import PaddedBatch, Embedding, Prediction, TokenEmbedding
from ..utils.env import ENV

if ENV.backend == 'atb':
    import torch_npu
else:
    import mindietorch

tracer = trace.get_tracer(__name__)


class RerankModel(Model):
    """
    RerankModel is a wrapper around a pre-trained reranker model that predicts similarity scores
    between a given query and texts. It extends the Model class and adds functionality to provide
    similarity score prediction according to the batch of input data.
    """

    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
        """Initializes the RerankModel by loading the pre-trained model based on the given
        model path and backend.

        Args:
            model_path: Path to the pre-trained model.
            device: The device to load the model on (e.g., NPU).
            dtype: The data type of the model's parameters (e.g., torch.float16).
        """
        if ENV.backend == 'atb':
            model = AutoModelForSequenceClassification.from_pretrained(
                model_path, torch_dtype=dtype, trust_remote_code=True
            ).to(device).eval()
        else:
            mindietorch.set_device(device.index)
            # When using the MindIE Torch backend, make sure that model_path (the model weight
            # directory) contains exactly one compiled and optimized .pt file; otherwise adjust
            # the file-matching pattern and the .pt file naming to your actual setup.
            model = torch.jit.load(next(Path(model_path).rglob("*.pt"))).eval()
        super(RerankModel, self).__init__(model=model, dtype=dtype, device=device)

    @property
    def batch_type(self) -> Type[PaddedBatch]:
        """Returns the class type of the batch that the reranker model expects.

        Returns:
            Type[PaddedBatch]: The type of the batch class expected by the reranker model.
                All tensors in this batch are padded according to the longest input sequence.
        """
        return PaddedBatch

    @tracer.start_as_current_span("embed")
    def embed(self, batch: PaddedBatch) -> Optional[List[Embedding]]:
        """Logs an error indicating that the reranker model does not support the embed function."""
        logger.error("rerank model does not support embed function")

    @tracer.start_as_current_span("embed_all")
    def embed_all(self, batch: PaddedBatch) -> Optional[List[TokenEmbedding]]:
        """Logs an error indicating that the reranker model does not support the embed_all function."""
        logger.error("rerank model does not support embed_all function")

    @tracer.start_as_current_span("predict")
    def predict(self, batch: PaddedBatch) -> List[Prediction]:
        """Predicts the similarity score for each query-text pair in the input batch.

        Args:
            batch: The input batch containing tokenized data (e.g., input_ids, attention_mask).

        Returns:
            List[Prediction]: A list of predictions, where each prediction represents the
                similarity score for each query-text pair in the input batch.
        """
        kwargs = {"input_ids": batch.input_ids.to(self.device),
                  "attention_mask": batch.attention_mask.to(self.device)}
        if ENV.backend == 'atb':
            scores = self.model(**kwargs, return_dict=True).logits.view(-1, ).float().tolist()
            return [
                Prediction(
                    values=[scores[i]]
                )
                for i in range(len(batch))
            ]
        else:
            scores = self.model(kwargs["input_ids"], kwargs["attention_mask"])[0].tolist()
            return [
                Prediction(
                    values=scores[i]
                )
                for i in range(len(batch))
            ]
from abc import ABC, abstractmethod from dataclasses import dataclass from opentelemetry import trace import torch from text_embeddings_server.pb import embed_pb2 from text_embeddings_server.pb.embed_pb2 import Embedding, Prediction, TokenEmbedding from ..utils.env import ENV if ENV.backend == 'atb': import torch_npu else: import mindietorch tracer = trace.get_tracer(__name__) class Batch(ABC): """Abstract base class for batching input data for embedding and prediction. This class provides an interface for batching input data and converting it from protocol buffer format (EmbedRequest) into a format suitable for the model. Methods: from_pb(pb, device): Converts a protocol buffer format EmbedRequest to a Batch instance. __len__(): Returns the number of sequences in the batch. """ @classmethod @abstractmethod def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device) -> "Batch": """Converts an EmbedRequest protocol buffer format to a Batch object. Args: pb: The protocol buffer message containing input data. device: The device to allocate tensors. Returns: Batch: A new instance of a batch (either PaddedBatch or FlashBatch). """ raise NotImplementedError @abstractmethod def __len__(self): """Returns the number of sequences in the batch.""" raise NotImplementedError @dataclass class PaddedBatch(Batch): """Represents a batch of padded input data. This class holds a batch of input sequences, padding them to the same length. The data is represented as tensors, and all sequences are padded to the `max_length`. Attributes: input_ids: Tensor containing input token IDs for each sequence in the batch. token_type_ids: Tensor containing token type IDs for each sequence. position_ids: Tensor containing position IDs for each sequence. attention_mask: Tensor for the attention mask, indicating valid tokens in each sequence. max_length: The maximum sequence length for padding. Methods: from_pb(pb, device): Converts a protocol buffer format EmbedRequest to a PaddedBatch instance. __len__(): Returns the number of sequences in the batch. """ input_ids: torch.Tensor token_type_ids: torch.Tensor position_ids: torch.Tensor attention_mask: torch.Tensor max_length: int @classmethod @tracer.start_as_current_span("from_pb") def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device) -> "PaddedBatch": """Converts an EmbedRequest protocol buffer format to a PaddedBatch instance. Args: pb: The protocol buffer message containing input data. device: The device to allocate tensors. Returns: PaddedBatch: A PaddedBatch instance containing padded input data. 
""" # Allocate padded tensors all at once all_tensors = torch.zeros( [4, len(pb.cu_seq_lengths) - 1, pb.max_length], dtype=torch.int32, device='cpu' ) max_length=pb.max_length for i, start_index in enumerate(pb.cu_seq_lengths[:-1]): end_index = pb.cu_seq_lengths[i + 1] input_length = end_index - start_index all_tensors[0, i, :input_length] = torch.tensor( pb.input_ids[start_index:end_index], dtype=torch.int32 ) all_tensors[1, i, :input_length] = torch.tensor( pb.token_type_ids[start_index:end_index], dtype=torch.int32 ) all_tensors[2, i, :input_length] = torch.tensor( pb.position_ids[start_index:end_index], dtype=torch.int32 ) all_tensors[3, i, :input_length] = 1 """ # Move padded tensors all at once all_tensors = all_tensors.to(device) """ return PaddedBatch( input_ids=all_tensors[0], token_type_ids=all_tensors[1], position_ids=all_tensors[2], attention_mask=all_tensors[3], max_length=max_length, ) def __len__(self): """Returns the number of sequences in the batch.""" return len(self.input_ids) @dataclass class FlashBatch(Batch): """Represents a batch of input data for flash inference. This class is used for models that support flash-based batching, where input data is packed into larger sequences for efficient processing. Attributes: input_ids: Tensor containing input token IDs for each sequence in the batch. token_type_ids: Tensor containing token type IDs for each sequence. position_ids: Tensor containing position IDs for each sequence. cu_seqlens: Tensor containing sequence lengths for each batch element. max_s: The maximum sequence length for this batch. size: The number of sequences in the batch. Methods: from_pb(pb, device): Converts a protocol buffer format EmbedRequest to a FlashBatch instance. __len__(): Returns the number of sequences in the batch. """ input_ids: torch.Tensor token_type_ids: torch.Tensor position_ids: torch.Tensor cu_seqlens: torch.Tensor max_s: int size: int @classmethod @tracer.start_as_current_span("from_pb") def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device) -> "FlashBatch": """Converts an EmbedRequest protocol buffer format to a FlashBatch instance. Args: pb: The protocol buffer message containing input data. device: The device to allocate tensors. Returns: FlashBatch: A FlashBatch instance containing input data for a FlashBert model. Raises: RuntimeError: If the device is not 'cuda', FlashBatch is not supported. """ if device.type != "cuda": raise RuntimeError(f"FlashBatch does not support device {device}") batch_input_ids = torch.tensor(pb.input_ids, dtype=torch.int32, device=device) batch_token_type_ids = torch.tensor( pb.token_type_ids, dtype=torch.int32, device=device ) batch_position_ids = torch.tensor( pb.position_ids, dtype=torch.int32, device=device ) cu_seqlens = torch.tensor(pb.cu_seq_lengths, dtype=torch.int32, device=device) return FlashBatch( input_ids=batch_input_ids, token_type_ids=batch_token_type_ids, position_ids=batch_position_ids, cu_seqlens=cu_seqlens, max_s=pb.max_length, size=len(cu_seqlens) - 1, ) def __len__(self): """Returns the number of sequences in the batch.""" return self.size