// bge_m3_embedding_server/embedder/pool.rs
1// Copyright (c) 2026 J. Patrick Fulton
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! `EmbedPool` async wrapper around the worker thread pool.
16
17use std::path::PathBuf;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::Arc;
20
21use anyhow::Result;
22use tokio::sync::{mpsc, oneshot, Mutex};
23use tokio::task::JoinHandle;
24use tracing::{info, info_span, Instrument};
25
26use super::math::median_usize;
27use super::types::{DualEmbedding, EmbedRequest, EmbedStats, ProbeResult, SparseEmbedding};
28use super::worker::{run_worker, WorkerConfig};
29
30/// Async handle to the embedding worker thread pool.
31///
32/// Wraps a bounded `mpsc` channel shared by `n` `spawn_blocking` worker threads.
33/// Each worker owns its own ORT session and tokenizer; the pool dispatches
34/// `EmbedRequest` variants to whichever worker is free next.
35///
36/// Clone is cheap — the underlying channel sender and atomic counters are
37/// reference-counted.
#[derive(Clone)]
pub struct EmbedPool {
    /// Bounded request channel shared by all worker threads; capacity is
    /// `n * 4` (set in `spawn`), so senders see backpressure once four
    /// requests per worker are queued.
    tx: mpsc::Sender<EmbedRequest>,
    /// Number of worker threads currently alive. Initialized to `n` in
    /// `spawn`; presumably decremented by `run_worker` when a thread exits —
    /// see `super::worker` to confirm.
    live_workers: Arc<AtomicUsize>,
    /// Number of workers that currently have model instances loaded in memory.
    loaded_workers: Arc<AtomicUsize>,
    /// Median RSS delta (bytes) measured across all workers during sequential
    /// model load.
    ///
    /// Workers load one at a time (leader first, then followers in sequence).
    /// Each reports its own RSS before/after `load_models()` via `ready_tx`.
    /// The pool stores the median once all workers have signaled ready — robust
    /// to one outlier from page-cache settling or ORT arena init jitter.
    ///
    /// Used by `run_readiness_probe` to correctly deduct the model-weight
    /// footprint from the available workspace before computing per-worker
    /// budget. Returns `0` on non-Linux targets where RSS measurement is
    /// unavailable, or before the init task has completed.
    model_rss_per_worker_bytes: Arc<AtomicUsize>,
}
58
59impl EmbedPool {
60 /// Spawns `n` embedding worker threads and returns the pool plus an init
61 /// handle that resolves once all workers have finished loading their models.
62 pub fn spawn(
63 n: usize,
64 cache_dir: PathBuf,
65 config: WorkerConfig,
66 ) -> (Self, JoinHandle<Result<()>>) {
67 let capacity = n * 4;
68 let (tx, rx) = mpsc::channel::<EmbedRequest>(capacity);
69 let rx = Arc::new(Mutex::new(rx));
70
71 // Channel carries Result<usize> where the Ok variant is the RSS delta
72 // (bytes) measured by each worker around load_models().
73 let (ready_tx, mut ready_rx) = mpsc::channel::<Result<usize>>(n);
74
75 let live_workers = Arc::new(AtomicUsize::new(n));
76 let loaded_workers = Arc::new(AtomicUsize::new(0));
77 let model_rss_per_worker_bytes = Arc::new(AtomicUsize::new(0));
78 let live_workers_for_init = Arc::clone(&live_workers);
79 let loaded_workers_for_init = Arc::clone(&loaded_workers);
80 let model_rss_for_init = Arc::clone(&model_rss_per_worker_bytes);
81
82 let init_handle = tokio::task::spawn(
83 async move {
84 let mut worker_handles = Vec::with_capacity(n);
85
86 let spawn_worker = |id: usize,
87 ready_tx_clone: mpsc::Sender<Result<usize>>,
88 worker_config: WorkerConfig|
89 -> JoinHandle<Result<()>> {
90 let rx_clone = Arc::clone(&rx);
91 let cache_dir_clone = cache_dir.clone();
92 let live_for_worker = Arc::clone(&live_workers_for_init);
93 let loaded_for_worker = Arc::clone(&loaded_workers_for_init);
94 tokio::task::spawn_blocking(move || {
95 run_worker(
96 id,
97 cache_dir_clone,
98 rx_clone,
99 ready_tx_clone,
100 live_for_worker,
101 loaded_for_worker,
102 worker_config,
103 )
104 })
105 };
106
107 // Collect per-worker RSS deltas for median aggregation.
108 // Median is robust to one outlier from transient kernel snapshot
109 // quirk (page-cache settling, ORT arena init jitter) while still
110 // using all N independent measurements.
111 let mut rss_deltas: Vec<usize> = Vec::with_capacity(n);
112
113 // --- Phase 1: spawn leader worker (may download models) ---
114 worker_handles.push(spawn_worker(0, ready_tx.clone(), config.clone()));
115
116 match ready_rx.recv().await {
117 Some(Ok(delta)) => {
118 loaded_workers_for_init.fetch_add(1, Ordering::AcqRel);
119 rss_deltas.push(delta);
120 info!(
121 rss_delta_mb = delta / (1024 * 1024),
122 "Leader worker ready, model cache warm (1/{n})"
123 );
124 }
125 Some(Err(e)) => {
126 return Err(anyhow::anyhow!("Leader worker failed to load models: {e}"));
127 }
128 None => {
129 return Err(anyhow::anyhow!(
130 "Leader worker exited before signaling readiness"
131 ));
132 }
133 }
134
135 // --- Phase 2: spawn follower workers one at a time.
136 //
137 // Workers load sequentially: spawn one, await its ready signal,
138 // then spawn the next. This ensures each worker's RSS delta
139 // (pre/post load_models) reflects only that worker's ORT session
140 // allocation — not the cumulative effect of other workers loading
141 // in parallel. Parallel loading caused the 2026-05-09 measurement
142 // contamination bug: all followers read post_load_rss after most
143 // other sessions had already mmap'd, producing an inflated
144 // rss_delta ≈ N × model_size and driving per_worker_workspace to 0.
145 //
146 // Startup cost: ~4-6s per worker × 6 followers ≈ 24-36s total,
147 // well within the configured startPeriod (300s).
148 for id in 1..n {
149 worker_handles.push(spawn_worker(id, ready_tx.clone(), config.clone()));
150
151 match ready_rx.recv().await {
152 Some(Ok(delta)) => {
153 loaded_workers_for_init.fetch_add(1, Ordering::AcqRel);
154 rss_deltas.push(delta);
155 info!(
156 rss_delta_mb = delta / (1024 * 1024),
157 "Follower worker signaled ready ({}/{n})",
158 id + 1
159 );
160 }
161 Some(Err(e)) => {
162 return Err(anyhow::anyhow!("Worker {id} failed to load models: {e}"));
163 }
164 None => {
165 return Err(anyhow::anyhow!(
166 "Worker {id} exited before signaling readiness ({id}/{n})"
167 ));
168 }
169 }
170 }
171
172 drop(ready_tx);
173 drop(worker_handles);
174
175 // Compute and store the median delta as the per-worker model footprint.
176 let median = median_usize(&mut rss_deltas);
177 model_rss_for_init.store(median, Ordering::Release);
178 info!(
179 median_rss_mb = median / (1024 * 1024),
180 samples = rss_deltas.len(),
181 "All workers ready — per-worker model RSS median computed"
182 );
183
184 Ok(())
185 }
186 .instrument(info_span!("embed_pool")),
187 );
188
189 (
190 Self {
191 tx,
192 live_workers,
193 loaded_workers,
194 model_rss_per_worker_bytes,
195 },
196 init_handle,
197 )
198 }
199
200 /// Runs dense (float32) embedding inference on `texts`.
201 ///
202 /// # Errors
203 ///
204 /// - Returns `Err` if the worker channel has closed (pool shut down).
205 /// - Returns `Err` if the worker drops the reply sender before responding.
206 /// - Returns `Err` if the ORT session fails during inference.
207 pub async fn dense(&self, texts: Vec<String>) -> Result<(Vec<Vec<f32>>, EmbedStats)> {
208 let (reply_tx, reply_rx) = oneshot::channel();
209 self.tx
210 .send(EmbedRequest::Dense {
211 texts,
212 reply: reply_tx,
213 })
214 .await
215 .map_err(|_| anyhow::anyhow!("EmbedPool channel closed"))?;
216 reply_rx
217 .await
218 .map_err(|_| anyhow::anyhow!("Worker dropped reply sender"))?
219 }
220
221 /// Runs sparse (SPLADE-style) embedding inference on `texts`.
222 ///
223 /// # Errors
224 ///
225 /// - Returns `Err` if the worker channel has closed (pool shut down).
226 /// - Returns `Err` if the worker drops the reply sender before responding.
227 /// - Returns `Err` if the ORT session fails during inference.
228 pub async fn sparse(&self, texts: Vec<String>) -> Result<(Vec<SparseEmbedding>, EmbedStats)> {
229 let (reply_tx, reply_rx) = oneshot::channel();
230 self.tx
231 .send(EmbedRequest::Sparse {
232 texts,
233 reply: reply_tx,
234 })
235 .await
236 .map_err(|_| anyhow::anyhow!("EmbedPool channel closed"))?;
237 reply_rx
238 .await
239 .map_err(|_| anyhow::anyhow!("Worker dropped reply sender"))?
240 }
241
242 /// Runs a single forward pass that yields both dense and sparse embeddings.
243 ///
244 /// Equivalent to calling [`Self::dense`] and [`Self::sparse`] back-to-back,
245 /// but uses one `session.run()` per chunk instead of two — at near-zero
246 /// marginal GPU cost.
247 ///
248 /// # Errors
249 ///
250 /// - Returns `Err` if the worker channel has closed (pool shut down).
251 /// - Returns `Err` if the worker drops the reply sender before responding.
252 /// - Returns `Err` if the ORT session fails during inference.
253 pub async fn both(&self, texts: Vec<String>) -> Result<(Vec<DualEmbedding>, EmbedStats)> {
254 let (reply_tx, reply_rx) = oneshot::channel();
255 self.tx
256 .send(EmbedRequest::Both {
257 texts,
258 reply: reply_tx,
259 })
260 .await
261 .map_err(|_| anyhow::anyhow!("EmbedPool channel closed"))?;
262 reply_rx
263 .await
264 .map_err(|_| anyhow::anyhow!("Worker dropped reply sender"))?
265 }
266
267 /// Sends a probe request to a single worker and returns the result.
268 /// Only called during init before `ready` is set.
269 pub(crate) async fn probe(&self, texts: Vec<String>) -> Result<ProbeResult> {
270 let (reply_tx, reply_rx) = oneshot::channel();
271 self.tx
272 .send(EmbedRequest::Probe {
273 texts,
274 reply: reply_tx,
275 })
276 .await
277 .map_err(|_| anyhow::anyhow!("EmbedPool channel closed"))?;
278 reply_rx
279 .await
280 .map_err(|_| anyhow::anyhow!("Worker dropped reply sender"))?
281 }
282
283 #[must_use]
284 /// Returns the number of worker threads currently alive (not yet exited).
285 pub fn live_worker_count(&self) -> usize {
286 self.live_workers.load(Ordering::Acquire)
287 }
288
289 #[must_use]
290 /// Returns the number of workers that currently have model instances loaded in memory.
291 ///
292 /// A worker transitions from loaded to unloaded after the [`crate::config::Config::idle_timeout`]
293 /// elapses with no incoming requests, and back to loaded on the next request.
294 pub fn loaded_worker_count(&self) -> usize {
295 self.loaded_workers.load(Ordering::Acquire)
296 }
297
298 /// Returns the number of requests currently queued but not yet picked up
299 /// by a worker. Uses the channel's current vs max capacity.
300 #[must_use]
301 pub fn queue_depth(&self) -> usize {
302 self.tx.max_capacity().saturating_sub(self.tx.capacity())
303 }
304
305 /// Returns the median RSS delta (bytes) measured across all workers during
306 /// sequential model load.
307 ///
308 /// This is the per-worker model-weight footprint used by
309 /// `run_readiness_probe` to compute the per-worker workspace budget.
310 /// Returns `0` on non-Linux targets where RSS measurement is unavailable,
311 /// or before the init task has completed.
312 #[must_use]
313 pub fn model_rss_per_worker_bytes(&self) -> usize {
314 self.model_rss_per_worker_bytes.load(Ordering::Acquire)
315 }
316}
317
318// ---------------------------------------------------------------------------
319// Test helpers (cfg(test)-gated)
320// ---------------------------------------------------------------------------
321
#[cfg(test)]
impl EmbedPool {
    /// Returns a pool whose request channel is already closed, so every
    /// request method fails with "EmbedPool channel closed".
    pub(crate) fn closed_for_test() -> Self {
        let (tx, rx) = mpsc::channel::<EmbedRequest>(1);
        drop(rx);
        Self {
            tx,
            live_workers: Arc::new(AtomicUsize::new(0)),
            loaded_workers: Arc::new(AtomicUsize::new(0)),
            model_rss_per_worker_bytes: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Returns a pool backed by a single async responder task that answers
    /// every request from the given fixtures instead of running inference.
    ///
    /// The responder exits when the pool (and thus the channel sender) is
    /// dropped. Must be called from within a tokio runtime.
    pub(crate) fn with_fixed_responses(
        dense_fixture: Vec<Vec<f32>>,
        sparse_fixture: Vec<SparseEmbedding>,
    ) -> Self {
        let (tx, mut rx) = mpsc::channel::<EmbedRequest>(8);
        // The fixtures have exactly one owner — the responder task — so no
        // Arc wrapping is needed; per-request copies are plain clones.
        tokio::spawn(async move {
            let dense = dense_fixture;
            let sparse = sparse_fixture;
            while let Some(req) = rx.recv().await {
                match req {
                    EmbedRequest::Dense { reply, .. } => {
                        let _ = reply.send(Ok((dense.clone(), EmbedStats::default())));
                    }
                    EmbedRequest::Sparse { reply, .. } => {
                        let _ = reply.send(Ok((sparse.clone(), EmbedStats::default())));
                    }
                    EmbedRequest::Both { reply, .. } => {
                        // Pair dense_fixture[i] with sparse_fixture[i]
                        // elementwise; `zip` truncates to the shorter of the
                        // two so the test fixture is self-consistent.
                        let pairs: Vec<DualEmbedding> = dense
                            .iter()
                            .zip(sparse.iter())
                            .map(|(d, s)| DualEmbedding {
                                dense: d.clone(),
                                sparse: s.clone(),
                            })
                            .collect();
                        let _ = reply.send(Ok((pairs, EmbedStats::default())));
                    }
                    EmbedRequest::Probe { reply, .. } => {
                        let _ = reply.send(Ok(ProbeResult {
                            rss_before: 0,
                            rss_after: 0,
                        }));
                    }
                }
            }
        });
        Self {
            tx,
            live_workers: Arc::new(AtomicUsize::new(1)),
            loaded_workers: Arc::new(AtomicUsize::new(1)),
            model_rss_per_worker_bytes: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Returns a pool that reports one live worker and zero loaded workers.
    ///
    /// NOTE(review): `_rx` is dropped when this function returns, so the
    /// channel is actually closed and any send fails — this helper is only
    /// suitable for metrics/state assertions, not request flows. Confirm
    /// that is the intended semantics of "idle".
    pub(crate) fn idle_for_test() -> Self {
        let (tx, _rx) = mpsc::channel::<EmbedRequest>(1);
        Self {
            tx,
            live_workers: Arc::new(AtomicUsize::new(1)),
            loaded_workers: Arc::new(AtomicUsize::new(0)),
            model_rss_per_worker_bytes: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Returns the raw `Arc<AtomicUsize>` backing `model_rss_per_worker_bytes`.
    ///
    /// Test-only; allows injecting a specific value to assert aggregation logic
    /// without running actual model loads.
    pub(crate) fn model_rss_per_worker_bytes_atomic(&self) -> Arc<AtomicUsize> {
        Arc::clone(&self.model_rss_per_worker_bytes)
    }
}
402
403#[cfg(test)]
404mod tests;