Skip to main content

bge_m3_embedding_server/embedder/
worker.rs

1// Copyright (c) 2026 J. Patrick Fulton
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Blocking worker thread, request dispatch, and probe wiring.
16
17use std::path::PathBuf;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::sync::Arc;
20use std::time::Duration;
21
22use anyhow::Result;
23use arc_swap::ArcSwap;
24use ort::value::TensorRef;
25use tokio::runtime::Handle;
26use tokio::sync::{mpsc, Mutex};
27use tracing::{info, info_span};
28
29use super::dense::embed_dense;
30use super::dual::embed_both;
31use super::error::ort_err;
32use super::session::load_models;
33use super::sparse::embed_sparse;
34use super::tokenize::{build_chunk_arrays, tokenize_no_pad};
35use super::types::{EmbedRequest, ProbeResult};
36use crate::binpack::CostModel;
37use crate::config::ModelVariant;
38use crate::sysinfo;
39
40pub(super) struct WorkerGuard(pub Arc<AtomicUsize>);
41
42impl Drop for WorkerGuard {
43    fn drop(&mut self) {
44        let prev = self.0.fetch_sub(1, Ordering::AcqRel);
45        let live_after_drop = prev.saturating_sub(1);
46        if live_after_drop == 0 {
47            tracing::error!("All embedding workers have exited — pool is degraded");
48        } else {
49            tracing::warn!(live_after_drop, "Embedding worker exited");
50        }
51    }
52}
53
54/// Execution-policy configuration shared by all workers.
55///
56/// `cost_model` is an `Arc<ArcSwap<CostModel>>` so all workers share a single
57/// handle and the background probe can update the cost model atomically after
58/// fitting.  Each worker loads the current value lock-free at the start of
59/// every `session.run()` call via `config.cost_model.load()`.
60#[derive(Clone)]
61pub struct WorkerConfig {
62    /// Quadratic-aware workspace cost model and per-worker budget.
63    ///
64    /// Shared across all workers via `ArcSwap`.  The background probe updates
65    /// this handle once fitted coefficients are available; workers observe the
66    /// new model on their next request without any coordination or restart.
67    pub cost_model: Arc<ArcSwap<CostModel>>,
68    /// Duration of inactivity before workers unload their model instances.
69    pub idle_timeout: Option<Duration>,
70    /// ONNX model variant to load (FP32, FP16, or INT8).
71    pub model_variant: ModelVariant,
72    /// Maximum tokenized sequence length.
73    pub max_seq_length: usize,
74    /// Number of intra-op threads each ORT session may use for a single
75    /// `session.run()` call. Plumbed through to `load_session` at model load
76    /// time. See [`crate::config::Config::intra_threads`] for sizing guidance.
77    pub intra_threads: usize,
78}
79
80/// Runs a single `session.run()` for the probe, measuring RSS before and after.
81///
82/// The probe texts are already tokenized and padded to `pad_to` externally.
83/// This function just runs inference and returns RSS deltas so `probe.rs` can
84/// fit the cost model.
85pub(crate) fn probe_run_dense(
86    session: &mut ort::session::Session,
87    ids_array: &ndarray::Array2<i64>,
88    mask_array: &ndarray::Array2<i64>,
89) -> Result<ProbeResult> {
90    let rss_before = sysinfo::read_process_rss_bytes().unwrap_or(0);
91
92    let ids_tensor = TensorRef::from_array_view(ids_array.view()).map_err(ort_err)?;
93    let mask_tensor = TensorRef::from_array_view(mask_array.view()).map_err(ort_err)?;
94
95    // Run inference (output discarded — we only care about RSS).
96    let _outputs = session
97        .run(ort::inputs! {
98            "input_ids" => ids_tensor,
99            "attention_mask" => mask_tensor,
100        })
101        .map_err(ort_err)?;
102
103    let rss_after = sysinfo::read_process_rss_bytes().unwrap_or(rss_before);
104
105    Ok(ProbeResult {
106        rss_before,
107        rss_after,
108    })
109}
110
111/// Runs one probe batch: tokenize texts, build padded arrays, call `session.run()`,
112/// and return RSS deltas. Uses `embed_dense`'s no-pad tokenizer path.
113fn run_probe_batch(
114    session: &mut ort::session::Session,
115    tokenizer: &tokenizers::Tokenizer,
116    texts: &[String],
117) -> Result<ProbeResult> {
118    let encodings = tokenize_no_pad(tokenizer, texts)?;
119    let pad_to = encodings
120        .iter()
121        .map(|e| e.get_ids().len())
122        .max()
123        .unwrap_or(1)
124        .max(1);
125    let indices: Vec<usize> = (0..texts.len()).collect();
126    let (ids_array, mask_array) = build_chunk_arrays(&encodings, &indices, pad_to)?;
127    probe_run_dense(session, &ids_array, &mask_array)
128}
129
130#[allow(clippy::needless_pass_by_value, clippy::too_many_lines)]
131pub(super) fn run_worker(
132    id: usize,
133    cache_dir: PathBuf,
134    rx: Arc<Mutex<mpsc::Receiver<EmbedRequest>>>,
135    ready_tx: mpsc::Sender<Result<usize>>,
136    live_workers: Arc<AtomicUsize>,
137    loaded_workers: Arc<AtomicUsize>,
138    config: WorkerConfig,
139) -> Result<()> {
140    let _guard = WorkerGuard(Arc::clone(&live_workers));
141    let span = info_span!("worker", id = id);
142    let _span_guard = span.enter();
143
144    info!("Loading models (worker {id})...");
145    let load_start = std::time::Instant::now();
146    let rt = Handle::current();
147
148    // Measure RSS before loading so the delta accurately reflects this
149    // worker's model-weight + arena-baseline contribution, not accumulated
150    // OS noise.
151    let pre_load_rss = sysinfo::read_process_rss_bytes().unwrap_or(0);
152    let initial_models = match load_models(
153        &cache_dir,
154        id == 0,
155        config.model_variant,
156        config.max_seq_length,
157        config.intra_threads,
158    ) {
159        Ok(mut models) => {
160            // Prime the ORT session arena with a tiny session.run() BEFORE
161            // measuring post-load RSS. ORT lazily allocates ~1 GiB of arena
162            // bookkeeping on the first run() call regardless of input size;
163            // priming here folds that allocation into the per-worker model
164            // RSS measurement so the workspace-budget math on the main thread
165            // sees the realistic per-worker memory footprint, AND so the
166            // probe sweep's per-shape `rss_delta` readings reflect only the
167            // incremental workspace attributable to that shape.
168            //
169            // Without per-worker priming, the probe could dispatch shapes to
170            // workers that have not yet done a session.run(), and each such
171            // first-touch contributes ~1 GiB of arena init noise to its
172            // delta — which buries the per-shape workspace signal in the
173            // OLS fit.
174            let prime_ids = ndarray::Array2::<i64>::zeros((1, 8));
175            let prime_mask = ndarray::Array2::<i64>::ones((1, 8));
176            match probe_run_dense(&mut models.0, &prime_ids, &prime_mask) {
177                Ok(_) => {
178                    tracing::debug!("Worker {id} arena primed");
179                }
180                Err(e) => {
181                    tracing::warn!(
182                        error = %e,
183                        "Worker {id} arena prime failed; first probe shape on this \
184                         worker will include arena init delta"
185                    );
186                }
187            }
188
189            let post_load_rss = sysinfo::read_process_rss_bytes().unwrap_or(pre_load_rss);
190            tracing::info!(
191                elapsed_ms = load_start.elapsed().as_millis(),
192                rss_delta_mb = post_load_rss.saturating_sub(pre_load_rss) / (1024 * 1024),
193                "Models loaded (worker {id})"
194            );
195            models
196        }
197        Err(e) => {
198            let _ =
199                rt.block_on(ready_tx.send(Err(anyhow::anyhow!("Worker {id} failed to load: {e}"))));
200            return Err(e);
201        }
202    };
203
204    // Report the RSS delta so EmbedPool can derive the true per-worker
205    // model footprint for workspace-budget calculations.
206    let post_load_rss = sysinfo::read_process_rss_bytes().unwrap_or(pre_load_rss);
207    let rss_delta = post_load_rss.saturating_sub(pre_load_rss);
208    info!(
209        "Worker {id} models loaded — signaling ready (rss_delta_mb={})",
210        rss_delta / (1024 * 1024)
211    );
212    let _ = rt.block_on(ready_tx.send(Ok(rss_delta)));
213
214    let mut models: Option<(ort::session::Session, tokenizers::Tokenizer)> = Some(initial_models);
215
216    info!("Worker {id} entering request loop");
217    loop {
218        let msg = if let Some(timeout) = config.idle_timeout.filter(|_| models.is_some()) {
219            rt.block_on(async {
220                tokio::time::timeout(timeout, async { rx.lock().await.recv().await }).await
221            })
222        } else {
223            rt.block_on(async { Ok(rx.lock().await.recv().await) })
224        };
225
226        match msg {
227            Err(_elapsed) => {
228                models = None;
229                loaded_workers.fetch_sub(1, Ordering::AcqRel);
230                tracing::info!("Worker {id} unloaded models after idle timeout");
231            }
232            Ok(None) => {
233                if models.is_some() {
234                    loaded_workers.fetch_sub(1, Ordering::AcqRel);
235                }
236                info!("Worker {id} channel closed, shutting down");
237                break;
238            }
239            Ok(Some(request)) => {
240                if models.is_none() {
241                    tracing::info!("Worker {id} reloading models after idle...");
242                    let reload_start = std::time::Instant::now();
243                    match load_models(
244                        &cache_dir,
245                        false,
246                        config.model_variant,
247                        config.max_seq_length,
248                        config.intra_threads,
249                    ) {
250                        Ok(mut m) => {
251                            // Prime the freshly-loaded session arena so the
252                            // first incoming request after idle reload doesn't
253                            // pay the ~1 GiB lazy-arena-init cost. Same
254                            // rationale as the startup priming in the
255                            // load-models Ok arm above.
256                            let prime_ids = ndarray::Array2::<i64>::zeros((1, 8));
257                            let prime_mask = ndarray::Array2::<i64>::ones((1, 8));
258                            if let Err(e) = probe_run_dense(&mut m.0, &prime_ids, &prime_mask) {
259                                tracing::warn!(
260                                    error = %e,
261                                    "Worker {id} post-reload arena prime failed"
262                                );
263                            }
264                            models = Some(m);
265                            loaded_workers.fetch_add(1, Ordering::AcqRel);
266                            tracing::info!(
267                                elapsed_ms = reload_start.elapsed().as_millis(),
268                                "Worker {id} reloaded models"
269                            );
270                        }
271                        Err(e) => {
272                            tracing::error!(error = %e, "Worker {id} failed to reload models");
273                            let err = anyhow::anyhow!("Model reload failed: {e}");
274                            match request {
275                                EmbedRequest::Dense { reply, .. } => {
276                                    let _ = reply.send(Err(err));
277                                }
278                                EmbedRequest::Sparse { reply, .. } => {
279                                    let _ = reply.send(Err(err));
280                                }
281                                EmbedRequest::Both { reply, .. } => {
282                                    let _ = reply.send(Err(err));
283                                }
284                                EmbedRequest::Probe { reply, .. } => {
285                                    let _ = reply.send(Err(err));
286                                }
287                            }
288                            continue;
289                        }
290                    }
291                }
292
293                let (session, tokenizer) =
294                    models.as_mut().expect("models loaded after reload check");
295
296                match request {
297                    EmbedRequest::Dense { texts, reply } => {
298                        // Load the current cost model snapshot for this request.
299                        // ArcSwap::load() is lock-free; the guard keeps the Arc
300                        // alive for the duration of embed_dense.
301                        let cm_guard = config.cost_model.load();
302                        let result = embed_dense(
303                            session,
304                            tokenizer,
305                            &texts,
306                            &cm_guard,
307                            config.model_variant,
308                        )
309                        .map_err(|e| anyhow::anyhow!("Dense embed error: {e}"));
310                        if let Ok((_, ref stats)) = result {
311                            tracing::info!(
312                                worker_id = id,
313                                chunks = stats.chunks,
314                                max_chunk_seq = stats.max_chunk_seq,
315                                total_token_positions = stats.total_token_positions,
316                                tokenize_ms = stats.tokenize_ms,
317                                inference_ms = stats.inference_ms,
318                                "worker: dense embed complete"
319                            );
320                        }
321                        let _ = reply.send(result);
322                    }
323                    EmbedRequest::Sparse { texts, reply } => {
324                        let cm_guard = config.cost_model.load();
325                        let result = embed_sparse(
326                            session,
327                            tokenizer,
328                            &texts,
329                            &cm_guard,
330                            config.model_variant,
331                        )
332                        .map_err(|e| anyhow::anyhow!("Sparse embed error: {e}"));
333                        if let Ok((_, ref stats)) = result {
334                            tracing::info!(
335                                worker_id = id,
336                                chunks = stats.chunks,
337                                max_chunk_seq = stats.max_chunk_seq,
338                                total_token_positions = stats.total_token_positions,
339                                tokenize_ms = stats.tokenize_ms,
340                                inference_ms = stats.inference_ms,
341                                "worker: sparse embed complete"
342                            );
343                        }
344                        let _ = reply.send(result);
345                    }
346                    EmbedRequest::Both { texts, reply } => {
347                        let cm_guard = config.cost_model.load();
348                        let result =
349                            embed_both(session, tokenizer, &texts, &cm_guard, config.model_variant)
350                                .map_err(|e| anyhow::anyhow!("Dual embed error: {e}"));
351                        if let Ok((_, ref stats)) = result {
352                            tracing::info!(
353                                worker_id = id,
354                                chunks = stats.chunks,
355                                max_chunk_seq = stats.max_chunk_seq,
356                                total_token_positions = stats.total_token_positions,
357                                tokenize_ms = stats.tokenize_ms,
358                                inference_ms = stats.inference_ms,
359                                "worker: both embed complete"
360                            );
361                        }
362                        let _ = reply.send(result);
363                    }
364                    EmbedRequest::Probe { texts, reply } => {
365                        // Probe: tokenize once without padding, run dense inference
366                        // on a single flat batch at the chunk's natural max_seq.
367                        let result = run_probe_batch(session, tokenizer, &texts);
368                        let _ = reply.send(result);
369                    }
370                }
371            }
372        }
373    }
374
375    Ok(())
376}