bge_m3_embedding_server/handler/
sparse.rs1use std::sync::Arc;
18use std::time::Instant;
19
20use axum::{extract::State, Json};
21
22use super::common::{check_ready, validate_input};
23use crate::error::AppError;
24use crate::models::{SparseEmbeddingData, SparseRequest, SparseResponse, SparseValues};
25use crate::state::AppState;
26
27#[allow(clippy::cast_possible_truncation)]
40#[tracing::instrument(
41 skip(state, req),
42 fields(batch_size, chunks, max_chunk_seq, tokenize_ms, inference_ms, total_ms)
43)]
44pub async fn sparse_embeddings(
45 State(state): State<Arc<AppState>>,
46 Json(req): Json<SparseRequest>,
47) -> Result<Json<SparseResponse>, AppError> {
48 check_ready(&state)?;
49 let texts = req.input.0;
50 validate_input(&texts, state.max_batch)?;
51 let batch_size = texts.len();
52 tracing::Span::current().record("batch_size", batch_size);
53
54 let t0 = Instant::now();
55
56 let _permit = Arc::clone(&state.request_permits)
57 .acquire_owned()
58 .await
59 .expect("request semaphore is never closed");
60
61 let (embeddings, embed_stats) = state.pool.sparse(texts).await?;
62
63 let total_ms = u64::try_from(t0.elapsed().as_millis()).unwrap_or(u64::MAX);
64 tracing::Span::current()
65 .record("chunks", embed_stats.chunks)
66 .record("max_chunk_seq", embed_stats.max_chunk_seq)
67 .record("tokenize_ms", embed_stats.tokenize_ms)
68 .record("inference_ms", embed_stats.inference_ms)
69 .record("total_ms", total_ms);
70 tracing::info!(
71 route = "sparse",
72 batch_size,
73 chunks = embed_stats.chunks,
74 max_chunk_seq = embed_stats.max_chunk_seq,
75 total_token_positions = embed_stats.total_token_positions,
76 tokenize_ms = embed_stats.tokenize_ms,
77 inference_ms = embed_stats.inference_ms,
78 total_ms,
79 "embedding request complete"
80 );
81
82 let data = embeddings
83 .into_iter()
84 .enumerate()
85 .map(|(index, emb)| SparseEmbeddingData {
86 index,
87 sparse_values: SparseValues {
88 indices: emb.indices.iter().map(|i| *i as u32).collect(),
89 values: emb.values,
90 },
91 })
92 .collect();
93
94 Ok(Json(SparseResponse { data }))
95}