// bge_m3_embedding_server/handler/both.rs
use std::sync::Arc;
use std::time::Instant;

use axum::{extract::State, Json};

use super::common::{check_ready, validate_input};
use crate::error::AppError;
use crate::models::{DualEmbeddingData, DualRequest, DualResponse, SparseValues, Usage};
use crate::state::AppState;
26
27#[allow(clippy::cast_possible_truncation)]
40#[tracing::instrument(
41 skip(state, req),
42 fields(
43 batch_size,
44 prompt_tokens,
45 chunks,
46 max_chunk_seq,
47 tokenize_ms,
48 inference_ms,
49 total_ms
50 )
51)]
52pub async fn both_embeddings(
53 State(state): State<Arc<AppState>>,
54 Json(req): Json<DualRequest>,
55) -> Result<Json<DualResponse>, AppError> {
56 check_ready(&state)?;
57 let texts = req.input.0;
58 drop(req.model);
59 validate_input(&texts, state.max_batch)?;
60 let batch_size = texts.len();
61 tracing::Span::current().record("batch_size", batch_size);
62
63 let prompt_tokens: usize = texts.iter().map(|t| t.chars().count() / 4 + 1).sum();
64 tracing::Span::current().record("prompt_tokens", prompt_tokens);
65
66 let t0 = Instant::now();
67
68 let _permit = Arc::clone(&state.request_permits)
69 .acquire_owned()
70 .await
71 .expect("request semaphore is never closed");
72
73 let (pairs, embed_stats) = state.pool.both(texts).await?;
74
75 let total_ms = u64::try_from(t0.elapsed().as_millis()).unwrap_or(u64::MAX);
76 tracing::Span::current()
77 .record("chunks", embed_stats.chunks)
78 .record("max_chunk_seq", embed_stats.max_chunk_seq)
79 .record("tokenize_ms", embed_stats.tokenize_ms)
80 .record("inference_ms", embed_stats.inference_ms)
81 .record("total_ms", total_ms);
82 tracing::info!(
83 route = "both",
84 batch_size,
85 prompt_tokens,
86 chunks = embed_stats.chunks,
87 max_chunk_seq = embed_stats.max_chunk_seq,
88 total_token_positions = embed_stats.total_token_positions,
89 tokenize_ms = embed_stats.tokenize_ms,
90 inference_ms = embed_stats.inference_ms,
91 total_ms,
92 "embedding request complete"
93 );
94
95 let data = pairs
96 .into_iter()
97 .enumerate()
98 .map(|(index, pair)| DualEmbeddingData {
99 index,
100 embedding: pair.dense,
101 sparse_values: SparseValues {
102 indices: pair.sparse.indices.iter().map(|i| *i as u32).collect(),
103 values: pair.sparse.values,
104 },
105 })
106 .collect();
107
108 Ok(Json(DualResponse {
109 object: "list",
110 model: "bge-m3",
111 data,
112 usage: Usage {
113 prompt_tokens,
114 total_tokens: prompt_tokens,
115 },
116 }))
117}