// bge_m3_embedding_server/embedder/pool.rs

// Copyright (c) 2026 J. Patrick Fulton
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! `EmbedPool` async wrapper around the worker thread pool.
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use anyhow::Result;
use tokio::sync::{mpsc, oneshot, Mutex};
use tokio::task::JoinHandle;
use tracing::{info, info_span, Instrument};

use super::math::median_usize;
use super::types::{DualEmbedding, EmbedRequest, EmbedStats, ProbeResult, SparseEmbedding};
use super::worker::{run_worker, WorkerConfig};
29
30/// Async handle to the embedding worker thread pool.
31///
32/// Wraps a bounded `mpsc` channel shared by `n` `spawn_blocking` worker threads.
33/// Each worker owns its own ORT session and tokenizer; the pool dispatches
34/// `EmbedRequest` variants to whichever worker is free next.
35///
36/// Clone is cheap — the underlying channel sender and atomic counters are
37/// reference-counted.
38#[derive(Clone)]
39pub struct EmbedPool {
40    tx: mpsc::Sender<EmbedRequest>,
41    live_workers: Arc<AtomicUsize>,
42    /// Number of workers that currently have model instances loaded in memory.
43    loaded_workers: Arc<AtomicUsize>,
44    /// Median RSS delta (bytes) measured across all workers during sequential
45    /// model load.
46    ///
47    /// Workers load one at a time (leader first, then followers in sequence).
48    /// Each reports its own RSS before/after `load_models()` via `ready_tx`.
49    /// The pool stores the median once all workers have signaled ready — robust
50    /// to one outlier from page-cache settling or ORT arena init jitter.
51    ///
52    /// Used by `run_readiness_probe` to correctly deduct the model-weight
53    /// footprint from the available workspace before computing per-worker
54    /// budget. Returns `0` on non-Linux targets where RSS measurement is
55    /// unavailable, or before the init task has completed.
56    model_rss_per_worker_bytes: Arc<AtomicUsize>,
57}
58
59impl EmbedPool {
60    /// Spawns `n` embedding worker threads and returns the pool plus an init
61    /// handle that resolves once all workers have finished loading their models.
62    pub fn spawn(
63        n: usize,
64        cache_dir: PathBuf,
65        config: WorkerConfig,
66    ) -> (Self, JoinHandle<Result<()>>) {
67        let capacity = n * 4;
68        let (tx, rx) = mpsc::channel::<EmbedRequest>(capacity);
69        let rx = Arc::new(Mutex::new(rx));
70
71        // Channel carries Result<usize> where the Ok variant is the RSS delta
72        // (bytes) measured by each worker around load_models().
73        let (ready_tx, mut ready_rx) = mpsc::channel::<Result<usize>>(n);
74
75        let live_workers = Arc::new(AtomicUsize::new(n));
76        let loaded_workers = Arc::new(AtomicUsize::new(0));
77        let model_rss_per_worker_bytes = Arc::new(AtomicUsize::new(0));
78        let live_workers_for_init = Arc::clone(&live_workers);
79        let loaded_workers_for_init = Arc::clone(&loaded_workers);
80        let model_rss_for_init = Arc::clone(&model_rss_per_worker_bytes);
81
82        let init_handle = tokio::task::spawn(
83            async move {
84                let mut worker_handles = Vec::with_capacity(n);
85
86                let spawn_worker = |id: usize,
87                                    ready_tx_clone: mpsc::Sender<Result<usize>>,
88                                    worker_config: WorkerConfig|
89                 -> JoinHandle<Result<()>> {
90                    let rx_clone = Arc::clone(&rx);
91                    let cache_dir_clone = cache_dir.clone();
92                    let live_for_worker = Arc::clone(&live_workers_for_init);
93                    let loaded_for_worker = Arc::clone(&loaded_workers_for_init);
94                    tokio::task::spawn_blocking(move || {
95                        run_worker(
96                            id,
97                            cache_dir_clone,
98                            rx_clone,
99                            ready_tx_clone,
100                            live_for_worker,
101                            loaded_for_worker,
102                            worker_config,
103                        )
104                    })
105                };
106
107                // Collect per-worker RSS deltas for median aggregation.
108                // Median is robust to one outlier from transient kernel snapshot
109                // quirk (page-cache settling, ORT arena init jitter) while still
110                // using all N independent measurements.
111                let mut rss_deltas: Vec<usize> = Vec::with_capacity(n);
112
113                // --- Phase 1: spawn leader worker (may download models) ---
114                worker_handles.push(spawn_worker(0, ready_tx.clone(), config.clone()));
115
116                match ready_rx.recv().await {
117                    Some(Ok(delta)) => {
118                        loaded_workers_for_init.fetch_add(1, Ordering::AcqRel);
119                        rss_deltas.push(delta);
120                        info!(
121                            rss_delta_mb = delta / (1024 * 1024),
122                            "Leader worker ready, model cache warm (1/{n})"
123                        );
124                    }
125                    Some(Err(e)) => {
126                        return Err(anyhow::anyhow!("Leader worker failed to load models: {e}"));
127                    }
128                    None => {
129                        return Err(anyhow::anyhow!(
130                            "Leader worker exited before signaling readiness"
131                        ));
132                    }
133                }
134
135                // --- Phase 2: spawn follower workers one at a time.
136                //
137                // Workers load sequentially: spawn one, await its ready signal,
138                // then spawn the next. This ensures each worker's RSS delta
139                // (pre/post load_models) reflects only that worker's ORT session
140                // allocation — not the cumulative effect of other workers loading
141                // in parallel. Parallel loading caused the 2026-05-09 measurement
142                // contamination bug: all followers read post_load_rss after most
143                // other sessions had already mmap'd, producing an inflated
144                // rss_delta ≈ N × model_size and driving per_worker_workspace to 0.
145                //
146                // Startup cost: ~4-6s per worker × 6 followers ≈ 24-36s total,
147                // well within the configured startPeriod (300s).
148                for id in 1..n {
149                    worker_handles.push(spawn_worker(id, ready_tx.clone(), config.clone()));
150
151                    match ready_rx.recv().await {
152                        Some(Ok(delta)) => {
153                            loaded_workers_for_init.fetch_add(1, Ordering::AcqRel);
154                            rss_deltas.push(delta);
155                            info!(
156                                rss_delta_mb = delta / (1024 * 1024),
157                                "Follower worker signaled ready ({}/{n})",
158                                id + 1
159                            );
160                        }
161                        Some(Err(e)) => {
162                            return Err(anyhow::anyhow!("Worker {id} failed to load models: {e}"));
163                        }
164                        None => {
165                            return Err(anyhow::anyhow!(
166                                "Worker {id} exited before signaling readiness ({id}/{n})"
167                            ));
168                        }
169                    }
170                }
171
172                drop(ready_tx);
173                drop(worker_handles);
174
175                // Compute and store the median delta as the per-worker model footprint.
176                let median = median_usize(&mut rss_deltas);
177                model_rss_for_init.store(median, Ordering::Release);
178                info!(
179                    median_rss_mb = median / (1024 * 1024),
180                    samples = rss_deltas.len(),
181                    "All workers ready — per-worker model RSS median computed"
182                );
183
184                Ok(())
185            }
186            .instrument(info_span!("embed_pool")),
187        );
188
189        (
190            Self {
191                tx,
192                live_workers,
193                loaded_workers,
194                model_rss_per_worker_bytes,
195            },
196            init_handle,
197        )
198    }
199
200    /// Runs dense (float32) embedding inference on `texts`.
201    ///
202    /// # Errors
203    ///
204    /// - Returns `Err` if the worker channel has closed (pool shut down).
205    /// - Returns `Err` if the worker drops the reply sender before responding.
206    /// - Returns `Err` if the ORT session fails during inference.
207    pub async fn dense(&self, texts: Vec<String>) -> Result<(Vec<Vec<f32>>, EmbedStats)> {
208        let (reply_tx, reply_rx) = oneshot::channel();
209        self.tx
210            .send(EmbedRequest::Dense {
211                texts,
212                reply: reply_tx,
213            })
214            .await
215            .map_err(|_| anyhow::anyhow!("EmbedPool channel closed"))?;
216        reply_rx
217            .await
218            .map_err(|_| anyhow::anyhow!("Worker dropped reply sender"))?
219    }
220
221    /// Runs sparse (SPLADE-style) embedding inference on `texts`.
222    ///
223    /// # Errors
224    ///
225    /// - Returns `Err` if the worker channel has closed (pool shut down).
226    /// - Returns `Err` if the worker drops the reply sender before responding.
227    /// - Returns `Err` if the ORT session fails during inference.
228    pub async fn sparse(&self, texts: Vec<String>) -> Result<(Vec<SparseEmbedding>, EmbedStats)> {
229        let (reply_tx, reply_rx) = oneshot::channel();
230        self.tx
231            .send(EmbedRequest::Sparse {
232                texts,
233                reply: reply_tx,
234            })
235            .await
236            .map_err(|_| anyhow::anyhow!("EmbedPool channel closed"))?;
237        reply_rx
238            .await
239            .map_err(|_| anyhow::anyhow!("Worker dropped reply sender"))?
240    }
241
242    /// Runs a single forward pass that yields both dense and sparse embeddings.
243    ///
244    /// Equivalent to calling [`Self::dense`] and [`Self::sparse`] back-to-back,
245    /// but uses one `session.run()` per chunk instead of two — at near-zero
246    /// marginal GPU cost.
247    ///
248    /// # Errors
249    ///
250    /// - Returns `Err` if the worker channel has closed (pool shut down).
251    /// - Returns `Err` if the worker drops the reply sender before responding.
252    /// - Returns `Err` if the ORT session fails during inference.
253    pub async fn both(&self, texts: Vec<String>) -> Result<(Vec<DualEmbedding>, EmbedStats)> {
254        let (reply_tx, reply_rx) = oneshot::channel();
255        self.tx
256            .send(EmbedRequest::Both {
257                texts,
258                reply: reply_tx,
259            })
260            .await
261            .map_err(|_| anyhow::anyhow!("EmbedPool channel closed"))?;
262        reply_rx
263            .await
264            .map_err(|_| anyhow::anyhow!("Worker dropped reply sender"))?
265    }
266
267    /// Sends a probe request to a single worker and returns the result.
268    /// Only called during init before `ready` is set.
269    pub(crate) async fn probe(&self, texts: Vec<String>) -> Result<ProbeResult> {
270        let (reply_tx, reply_rx) = oneshot::channel();
271        self.tx
272            .send(EmbedRequest::Probe {
273                texts,
274                reply: reply_tx,
275            })
276            .await
277            .map_err(|_| anyhow::anyhow!("EmbedPool channel closed"))?;
278        reply_rx
279            .await
280            .map_err(|_| anyhow::anyhow!("Worker dropped reply sender"))?
281    }
282
283    #[must_use]
284    /// Returns the number of worker threads currently alive (not yet exited).
285    pub fn live_worker_count(&self) -> usize {
286        self.live_workers.load(Ordering::Acquire)
287    }
288
289    #[must_use]
290    /// Returns the number of workers that currently have model instances loaded in memory.
291    ///
292    /// A worker transitions from loaded to unloaded after the [`crate::config::Config::idle_timeout`]
293    /// elapses with no incoming requests, and back to loaded on the next request.
294    pub fn loaded_worker_count(&self) -> usize {
295        self.loaded_workers.load(Ordering::Acquire)
296    }
297
298    /// Returns the number of requests currently queued but not yet picked up
299    /// by a worker. Uses the channel's current vs max capacity.
300    #[must_use]
301    pub fn queue_depth(&self) -> usize {
302        self.tx.max_capacity().saturating_sub(self.tx.capacity())
303    }
304
305    /// Returns the median RSS delta (bytes) measured across all workers during
306    /// sequential model load.
307    ///
308    /// This is the per-worker model-weight footprint used by
309    /// `run_readiness_probe` to compute the per-worker workspace budget.
310    /// Returns `0` on non-Linux targets where RSS measurement is unavailable,
311    /// or before the init task has completed.
312    #[must_use]
313    pub fn model_rss_per_worker_bytes(&self) -> usize {
314        self.model_rss_per_worker_bytes.load(Ordering::Acquire)
315    }
316}
317
// ---------------------------------------------------------------------------
// Test helpers (cfg(test)-gated)
// ---------------------------------------------------------------------------

#[cfg(test)]
impl EmbedPool {
    /// Returns a pool whose request channel is already closed, so every
    /// request method fails immediately with "EmbedPool channel closed".
    pub(crate) fn closed_for_test() -> Self {
        let (tx, rx) = mpsc::channel::<EmbedRequest>(1);
        drop(rx);
        Self {
            tx,
            live_workers: Arc::new(AtomicUsize::new(0)),
            loaded_workers: Arc::new(AtomicUsize::new(0)),
            model_rss_per_worker_bytes: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Returns a pool backed by a single in-process task that answers every
    /// request from the given fixtures instead of running real inference.
    pub(crate) fn with_fixed_responses(
        dense_fixture: Vec<Vec<f32>>,
        sparse_fixture: Vec<SparseEmbedding>,
    ) -> Self {
        let (tx, mut rx) = mpsc::channel::<EmbedRequest>(8);
        let dense = Arc::new(dense_fixture);
        let sparse = Arc::new(sparse_fixture);
        let dense_both = Arc::clone(&dense);
        let sparse_both = Arc::clone(&sparse);
        tokio::spawn(async move {
            while let Some(req) = rx.recv().await {
                match req {
                    EmbedRequest::Dense { reply, .. } => {
                        let _ = reply.send(Ok(((*dense).clone(), EmbedStats::default())));
                    }
                    EmbedRequest::Sparse { reply, .. } => {
                        let _ = reply.send(Ok(((*sparse).clone(), EmbedStats::default())));
                    }
                    EmbedRequest::Both { reply, .. } => {
                        // Pair dense_fixture[i] with sparse_fixture[i] elementwise.
                        // zip truncates to the shorter of the two so the test
                        // fixture is self-consistent.
                        let pairs: Vec<DualEmbedding> = dense_both
                            .iter()
                            .zip(sparse_both.iter())
                            .map(|(d, s)| DualEmbedding {
                                dense: d.clone(),
                                sparse: s.clone(),
                            })
                            .collect();
                        let _ = reply.send(Ok((pairs, EmbedStats::default())));
                    }
                    EmbedRequest::Probe { reply, .. } => {
                        let _ = reply.send(Ok(ProbeResult {
                            rss_before: 0,
                            rss_after: 0,
                        }));
                    }
                }
            }
        });
        Self {
            tx,
            live_workers: Arc::new(AtomicUsize::new(1)),
            loaded_workers: Arc::new(AtomicUsize::new(1)),
            model_rss_per_worker_bytes: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Returns a pool with one "live" but unloaded worker and an open channel
    /// that nothing reads — requests queue but are never answered.
    pub(crate) fn idle_for_test() -> Self {
        let (tx, _rx) = mpsc::channel::<EmbedRequest>(1);
        Self {
            tx,
            live_workers: Arc::new(AtomicUsize::new(1)),
            loaded_workers: Arc::new(AtomicUsize::new(0)),
            model_rss_per_worker_bytes: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Returns the raw `Arc<AtomicUsize>` backing `model_rss_per_worker_bytes`.
    ///
    /// Test-only; allows injecting a specific value to assert aggregation logic
    /// without running actual model loads.
    pub(crate) fn model_rss_per_worker_bytes_atomic(&self) -> Arc<AtomicUsize> {
        Arc::clone(&self.model_rss_per_worker_bytes)
    }
}
402
// Unit tests live in a sibling file (`pool/tests.rs` or `tests.rs` in this dir).
#[cfg(test)]
mod tests;