Skip to main content

bge_m3_embedding_server/handler/
both.rs

// Copyright (c) 2026 J. Patrick Fulton
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! `POST /v1/embeddings:both` handler — dense + sparse embeddings in one pass.

17use std::sync::Arc;
18use std::time::Instant;
19
20use axum::{extract::State, Json};
21
22use super::common::{check_ready, validate_input};
23use crate::error::AppError;
24use crate::models::{DualEmbeddingData, DualRequest, DualResponse, SparseValues, Usage};
25use crate::state::AppState;
26
27/// Handles `POST /v1/embeddings:both` — returns dense and sparse embeddings in one pass.
28///
29/// # Errors
30///
31/// - [`AppError::ServiceUnavailable`] if the model is not ready or no workers are live.
32/// - [`AppError::InvalidRequest`] if the batch is empty, exceeds `max_batch`, or any
33///   text exceeds the per-string character limit.
34/// - [`AppError::Internal`] if the embedding pool returns an inference error.
35///
36/// # Panics
37///
38/// Panics if the request semaphore has been closed — should not occur in normal operation.
39#[allow(clippy::cast_possible_truncation)]
40#[tracing::instrument(
41    skip(state, req),
42    fields(
43        batch_size,
44        prompt_tokens,
45        chunks,
46        max_chunk_seq,
47        tokenize_ms,
48        inference_ms,
49        total_ms
50    )
51)]
52pub async fn both_embeddings(
53    State(state): State<Arc<AppState>>,
54    Json(req): Json<DualRequest>,
55) -> Result<Json<DualResponse>, AppError> {
56    check_ready(&state)?;
57    let texts = req.input.0;
58    drop(req.model);
59    validate_input(&texts, state.max_batch)?;
60    let batch_size = texts.len();
61    tracing::Span::current().record("batch_size", batch_size);
62
63    let prompt_tokens: usize = texts.iter().map(|t| t.chars().count() / 4 + 1).sum();
64    tracing::Span::current().record("prompt_tokens", prompt_tokens);
65
66    let t0 = Instant::now();
67
68    let _permit = Arc::clone(&state.request_permits)
69        .acquire_owned()
70        .await
71        .expect("request semaphore is never closed");
72
73    let (pairs, embed_stats) = state.pool.both(texts).await?;
74
75    let total_ms = u64::try_from(t0.elapsed().as_millis()).unwrap_or(u64::MAX);
76    tracing::Span::current()
77        .record("chunks", embed_stats.chunks)
78        .record("max_chunk_seq", embed_stats.max_chunk_seq)
79        .record("tokenize_ms", embed_stats.tokenize_ms)
80        .record("inference_ms", embed_stats.inference_ms)
81        .record("total_ms", total_ms);
82    tracing::info!(
83        route = "both",
84        batch_size,
85        prompt_tokens,
86        chunks = embed_stats.chunks,
87        max_chunk_seq = embed_stats.max_chunk_seq,
88        total_token_positions = embed_stats.total_token_positions,
89        tokenize_ms = embed_stats.tokenize_ms,
90        inference_ms = embed_stats.inference_ms,
91        total_ms,
92        "embedding request complete"
93    );
94
95    let data = pairs
96        .into_iter()
97        .enumerate()
98        .map(|(index, pair)| DualEmbeddingData {
99            index,
100            embedding: pair.dense,
101            sparse_values: SparseValues {
102                indices: pair.sparse.indices.iter().map(|i| *i as u32).collect(),
103                values: pair.sparse.values,
104            },
105        })
106        .collect();
107
108    Ok(Json(DualResponse {
109        object: "list",
110        model: "bge-m3",
111        data,
112        usage: Usage {
113            prompt_tokens,
114            total_tokens: prompt_tokens,
115        },
116    }))
117}