164 changes: 91 additions & 73 deletions Dockerfile
@@ -1,123 +1,141 @@
FROM lukemathwalker/cargo-chef:latest-rust-1.85-bookworm AS chef
WORKDIR /usr/src
# Dockerfile for TEI with Python backend and CUDA support
# Supports: L40s (sm_89), RTX 3090 (sm_86)

# =============================================================================
# Stage 1: Rust Builder
# =============================================================================
FROM nvidia/cuda:12.4.0-devel-ubuntu22.04 AS rust-builder

ENV SCCACHE=0.10.0
ENV RUSTC_WRAPPER=/usr/local/bin/sccache
ENV PATH="/root/.cargo/bin:${PATH}"
ENV CARGO_CHEF=0.1.71

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
curl \
libssl-dev \
pkg-config \
protobuf-compiler \
&& rm -rf /var/lib/apt/lists/*

# Download and configure sccache
RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
chmod +x /usr/local/bin/sccache

FROM chef AS planner
RUN curl https://sh.rustup.rs -sSf | bash -s -- -y
RUN cargo install cargo-chef --version $CARGO_CHEF --locked

# =============================================================================
# Stage 2: Recipe Planner
# =============================================================================
FROM rust-builder AS planner

WORKDIR /usr/src

COPY backends backends
COPY core core
COPY router router
COPY Cargo.toml ./
COPY Cargo.lock ./

RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder
# =============================================================================
# Stage 3: Dependency Builder
# =============================================================================
FROM rust-builder AS builder

ARG GIT_SHA
ARG DOCKER_LABEL

# sccache specific variables
ARG SCCACHE_GHA_ENABLED

RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | \
tee /etc/apt/sources.list.d/oneAPI.list

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
intel-oneapi-mkl-devel=2024.0.0-49656 \
build-essential \
&& rm -rf /var/lib/apt/lists/*

RUN echo "int mkl_serv_intel_cpu_true() {return 1;}" > fakeintel.c && \
gcc -shared -fPIC -o libfakeintel.so fakeintel.c
WORKDIR /usr/src

COPY --from=planner /usr/src/recipe.json recipe.json

RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
cargo chef cook --release --features ort,candle,mkl,static-linking --no-default-features --recipe-path recipe.json && sccache -s
RUN cargo chef cook --release --features python --features http --recipe-path recipe.json && sccache -s

COPY backends backends
COPY core core
COPY router router
COPY Cargo.toml ./
COPY Cargo.lock ./

FROM builder AS http-builder
RUN cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s

# =============================================================================
# Stage 4: Python Environment
# =============================================================================
FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04 AS python-builder

RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
cargo build --release --bin text-embeddings-router --features ort,candle,mkl,static-linking,http --no-default-features && sccache -s
ENV DEBIAN_FRONTEND=noninteractive

FROM builder AS grpc-builder
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.10 \
python3.10-dev \
python3-pip \
git \
&& rm -rf /var/lib/apt/lists/*

RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP
RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \
ln -sf /usr/bin/python3.10 /usr/bin/python3

COPY proto proto
RUN pip install --no-cache-dir --upgrade pip setuptools wheel

RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
cargo build --release --bin text-embeddings-router --features ort,candle,mkl,static-linking,grpc --no-default-features && sccache -s
WORKDIR /opt/server

FROM debian:bookworm-slim AS base
COPY backends/proto /opt/proto
COPY backends/python/server /opt/server

ENV HUGGINGFACE_HUB_CACHE=/data \
PORT=80 \
MKL_ENABLE_INSTRUCTIONS=AVX512_E4 \
RAYON_NUM_THREADS=8 \
LD_PRELOAD=/usr/local/libfakeintel.so \
LD_LIBRARY_PATH=/usr/local/lib
RUN pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir && \
mkdir -p text_embeddings_server/pb && \
python -m grpc_tools.protoc -I/opt/proto --python_out=text_embeddings_server/pb \
--grpc_python_out=text_embeddings_server/pb --mypy_out=text_embeddings_server/pb /opt/proto/embed.proto && \
find text_embeddings_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; && \
touch text_embeddings_server/pb/__init__.py

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
libomp-dev \
RUN pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu124

RUN pip install --no-cache-dir -r requirements.txt

RUN pip install --no-cache-dir .

# =============================================================================
# Stage 5: Final Image
# =============================================================================
FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04

ENV DEBIAN_FRONTEND=noninteractive
ENV HUGGINGFACE_HUB_CACHE=/data
ENV PORT=80
ENV TQDM_DISABLE=1

RUN apt-get update && apt-get install -y --no-install-recommends \
python3.10 \
python3-pip \
ca-certificates \
libssl-dev \
curl \
&& rm -rf /var/lib/apt/lists/*

# Copy a lot of the Intel shared objects because of the mkl_serv_intel_cpu_true patch...
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_lp64.so.2 /usr/local/lib/libmkl_intel_lp64.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_thread.so.2 /usr/local/lib/libmkl_intel_thread.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_core.so.2 /usr/local/lib/libmkl_core.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_def.so.2 /usr/local/lib/libmkl_vml_def.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_def.so.2 /usr/local/lib/libmkl_def.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx2.so.2 /usr/local/lib/libmkl_vml_avx2.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx512.so.2 /usr/local/lib/libmkl_vml_avx512.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx2.so.2 /usr/local/lib/libmkl_avx2.so.2
COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx512.so.2 /usr/local/lib/libmkl_avx512.so.2
COPY --from=builder /usr/src/libfakeintel.so /usr/local/libfakeintel.so

FROM base AS grpc

COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

ENTRYPOINT ["text-embeddings-router"]
CMD ["--json-output"]
RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \
ln -sf /usr/bin/python3.10 /usr/bin/python3

FROM base AS http
COPY --from=python-builder /usr/local/lib/python3.10/dist-packages /usr/local/lib/python3.10/dist-packages
COPY --from=python-builder /usr/local/bin/python-text-embeddings-server /usr/local/bin/python-text-embeddings-server
COPY --from=python-builder /opt/server /opt/server

COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
COPY --from=builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

# Amazon SageMaker compatible image
FROM http AS sagemaker
COPY --chmod=775 sagemaker-entrypoint.sh entrypoint.sh
ENV PATH="/usr/local/bin:${PATH}"
ENV PYTHONPATH="/opt/server:${PYTHONPATH}"

ENTRYPOINT ["./entrypoint.sh"]
# Download the spaCy model in the final image (ensures it's available at runtime)
# This is needed because spaCy models may not be fully copied from the builder stage
RUN pip install --no-cache-dir 'spacy>=3.7.0' && \
python -m spacy download xx_sent_ud_sm && \
python -c "import spacy; spacy.load('xx_sent_ud_sm')" && \
echo "Spacy model verified successfully"

# Default image
FROM http
WORKDIR /opt/server

ENTRYPOINT ["text-embeddings-router"]
CMD ["--json-output"]
3 changes: 0 additions & 3 deletions assets/bs1-lat.png

This file was deleted.

3 changes: 0 additions & 3 deletions assets/bs1-tp.png

This file was deleted.

3 changes: 0 additions & 3 deletions assets/bs32-lat.png

This file was deleted.

3 changes: 0 additions & 3 deletions assets/bs32-tp.png

This file was deleted.

7 changes: 5 additions & 2 deletions backends/candle/src/lib.rs
@@ -14,7 +14,7 @@ use serde::{de::Deserializer, Deserialize};
use std::collections::HashMap;
use std::path::Path;
use text_embeddings_backend_core::{
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Prediction, Predictions,
};

#[cfg(feature = "cuda")]
@@ -653,7 +653,10 @@ impl Backend for CandleBackend {
let mut predictions =
HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
for (i, r) in results.into_iter().enumerate() {
predictions.insert(i, r);
predictions.insert(i, Prediction {
scores: r,
pruned_text: None,
});
}

Ok(predictions)
15 changes: 14 additions & 1 deletion backends/core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -14,6 +14,10 @@ pub struct Batch {
pub max_length: u32,
pub pooled_indices: Vec<u32>,
pub raw_indices: Vec<u32>,
/// XProvence: raw query texts for context pruning
pub raw_queries: Vec<Option<String>>,
/// XProvence: raw context texts for context pruning
pub raw_texts: Vec<Option<String>>,
}

impl Batch {
@@ -32,7 +36,16 @@ pub enum Embedding {
}

pub type Embeddings = IntMap<usize, Embedding>;
pub type Predictions = IntMap<usize, Vec<f32>>;

/// XProvence: Prediction result containing scores and optional pruned text
#[derive(Debug, Clone)]
pub struct Prediction {
pub scores: Vec<f32>,
/// XProvence: pruned context text after removing irrelevant sentences
pub pruned_text: Option<String>,
}

pub type Predictions = IntMap<usize, Prediction>;

pub trait Backend {
fn health(&self) -> Result<(), BackendError>;
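
The candle and ort hunks above now wrap their raw score vectors in this new `Prediction` struct. A minimal sketch of the pattern (the helper name is illustrative, not from the PR):

```rust
use nohash_hasher::BuildNoHashHasher;
use std::collections::HashMap;
use text_embeddings_backend_core::{Prediction, Predictions};

// Sketch: wrap per-item score vectors in the new Prediction struct.
// Backends that do not prune (candle, ort) leave pruned_text as None;
// only the XProvence path on the Python server fills it in.
fn to_predictions(results: Vec<Vec<f32>>) -> Predictions {
    let mut predictions: Predictions =
        HashMap::with_capacity_and_hasher(results.len(), BuildNoHashHasher::default());
    for (i, scores) in results.into_iter().enumerate() {
        predictions.insert(
            i,
            Prediction {
                scores,
                pruned_text: None, // Some(pruned) once irrelevant sentences are removed
            },
        );
    }
    predictions
}
```

Because `pruned_text` defaults to `None`, existing classifier backends keep their behavior unchanged; only callers that care about pruning need to inspect the new field.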
6 changes: 6 additions & 0 deletions backends/grpc-client/src/client.rs
@@ -59,6 +59,8 @@ impl Client {
position_ids,
max_length,
cu_seq_lengths,
raw_queries: vec![],
raw_texts: vec![],
})
.inject_context();
let response = self.stub.embed(request).await?.into_inner();
@@ -73,13 +75,17 @@
position_ids: Vec<u32>,
cu_seq_lengths: Vec<u32>,
max_length: u32,
raw_queries: Vec<String>,
raw_texts: Vec<String>,
) -> Result<Vec<Score>> {
let request = tonic::Request::new(EmbedRequest {
input_ids,
token_type_ids,
position_ids,
max_length,
cu_seq_lengths,
raw_queries,
raw_texts,
})
.inject_context();
let response = self.stub.predict(request).await?.into_inner();
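
For illustration, a hedged sketch of calling the extended `predict` (variable names, error plumbing, and the `Client::connect` setup are illustrative, not from this PR): the two new vectors carry one raw string per batch item so the Python server can prune the context before scoring.

```rust
// Sketch only: assumes `client` is this crate's connected gRPC Client and the
// batch holds a single query/passage pair; imports omitted.
async fn rerank_one(
    client: &mut Client,
    input_ids: Vec<u32>,
    token_type_ids: Vec<u32>,
    position_ids: Vec<u32>,
    cu_seq_lengths: Vec<u32>,
    max_length: u32,
    query: String,
    passage: String,
) -> Result<(), Box<dyn std::error::Error>> {
    let scores = client
        .predict(
            input_ids,
            token_type_ids,
            position_ids,
            cu_seq_lengths,
            max_length,
            vec![query],           // raw_queries: one per batch item
            vec![passage.clone()], // raw_texts: one per batch item
        )
        .await?;
    for score in scores {
        // pruned_text is Some(...) only when the server removed sentences
        let context = score.pruned_text.as_deref().unwrap_or(&passage);
        println!("scores={:?} context={}", score.values, context);
    }
    Ok(())
}
```

Note that the `embed` path above passes empty `raw_queries`/`raw_texts` vectors, so embedding requests are unaffected by the new fields.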
7 changes: 5 additions & 2 deletions backends/ort/src/lib.rs
@@ -8,7 +8,7 @@ use std::ops::{Div, Mul};
use std::path::Path;
use std::sync::Mutex;
use text_embeddings_backend_core::{
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Prediction, Predictions,
};

#[derive(Debug, Clone, Deserialize)]
@@ -679,7 +679,10 @@ impl Backend for OrtBackend {
let mut predictions =
HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
for (i, r) in outputs.rows().into_iter().enumerate() {
predictions.insert(i, r.to_vec());
predictions.insert(i, Prediction {
scores: r.to_vec(),
pruned_text: None,
});
}

Ok(predictions)
6 changes: 6 additions & 0 deletions backends/proto/embed.proto
@@ -21,6 +21,10 @@
repeated uint32 cu_seq_lengths = 4;
/// Length of the longest request
uint32 max_length = 5;
/// XProvence: raw query texts for context pruning (one per batch item)
repeated string raw_queries = 6;
/// XProvence: raw context texts for context pruning (one per batch item)
repeated string raw_texts = 7;
}

message Embedding {
@@ -33,6 +37,8 @@ message EmbedResponse {

message Score {
repeated float values = 1;
/// XProvence: pruned context text after removing irrelevant sentences
optional string pruned_text = 2;
}

message PredictResponse {
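
Since the router talks to the Python server through tonic/prost (see the `tonic::Request` call in the client hunk above), the new fields surface on the Rust side roughly as below. This is a sketch of the generated shapes under standard proto3 codegen rules, not code from the PR:

```rust
// Approximate prost output for the new fields:
// proto3 `repeated string` -> Vec<String>, `optional string` -> Option<String>.
pub struct EmbedRequest {
    // ... existing fields: input_ids, token_type_ids, position_ids,
    // cu_seq_lengths, max_length ...
    pub raw_queries: Vec<String>, // one raw query per batch item
    pub raw_texts: Vec<String>,   // one raw context per batch item
}

pub struct Score {
    pub values: Vec<f32>,
    pub pruned_text: Option<String>, // set only when the server pruned the context
}
```

An absent `raw_queries` on the wire decodes as an empty `Vec`, and an unset `pruned_text` as `None`, so older clients and the embedding path remain wire-compatible.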
1 change: 1 addition & 0 deletions backends/python/server/requirements.txt
@@ -52,6 +52,7 @@ safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scikit-learn==1.5.2 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentence-transformers==3.3.1 ; python_version >= "3.9" and python_version < "3.13"
spacy>=3.7.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.6.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
threadpoolctl==3.5.0 ; python_version >= "3.9" and python_version < "3.13"