Skip to main content

bge_m3_embedding_server/probe/
runner.rs

1// Copyright (c) 2026 J. Patrick Fulton
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Probe shape-sweep driver.
16//!
17//! Orchestrates the `(batch, seq)` shape sweep on the leader worker, applies
18//! the conservative-`fits()` gate and the absolute-RSS guard, and feeds the
19//! collected `DataPoint`s to [`super::fit::fit_cost_model`] for OLS fitting.
20
21use tracing::{info, warn};
22
23use super::corpus::{load_probe_texts, synthesize_texts};
24use super::fit::{fit_cost_model, DataPoint};
25use super::validate::validate_max_seq_shape;
26use crate::binpack::CostModel;
27use crate::embedder::EmbedPool;
28
/// Probe shape: `(batch_count, seq_length)`.
///
/// `batch_count` is the number of synthesized texts submitted in a single
/// probe call; `seq_length` is the approximate per-text token length that
/// `synthesize_texts` targets when building those texts.
pub(super) type Shape = (usize, usize);
31
/// Shapes swept by the probe.
///
/// 6 static shapes plus a dynamic `(1, max_seq)` shape added at runtime for
/// the quadratic anchor at the configured upper bound:
///
/// - `(1, 64)` and `(1, 256)` anchor the linear term at low seq.
/// - `(4, 64)` shares `x1 = batch*seq = 256` with `(1, 256)` but has a
///   different `x2 = batch*seq² = 16384` vs `65536`, giving a near-direct
///   measurement of `b` independent of `a`.
/// - `(1, 1024)` and `(1, 2048)` provide mid-range leverage.
/// - `(1, 4096)` anchors the quadratic regime.
///
/// Note: `run_probe` re-sorts the final shape list by ascending
/// `batch * seq` before sweeping, so the ordering here is documentation
/// only — the sweep always grows load gradually.
///
/// ## Safety against OOM
///
/// ORT's memory arena retains pages across `session.run()` calls, so
/// cumulative process RSS grows with each successive probe shape. Three
/// independent mechanisms keep the sweep within the container's cgroup limit:
///
/// 1. **Arena warm-up** at the start of `run_probe` runs a `(1, 64)`
///    `session.run()` BEFORE the sweep, so the lazy ORT arena initialisation
///    does not appear as a ~1 GB constant offset on every per-shape delta.
/// 2. **Conservative `fits()` gate** rejects any shape whose per-call
///    workspace estimate exceeds `rss_ceiling` (the safety-discounted budget).
/// 3. **Absolute-RSS guard** rejects any shape whose projected arena growth
///    would push process RSS above 87.5% of the cgroup ceiling, regardless
///    of the conservative model's estimate.
///
/// The dynamic `(1, max_seq)` shape is added at runtime by `run_probe`. If
/// the model variant cannot run at `max_seq`, the shape is skipped and the
/// error surfaces on the first real embedding request.
///
/// Estimated probe time: ~120 s on aarch64 MLAS fp16 at `max_seq=8192`.
pub(super) const PROBE_SHAPES: &[Shape] = &[
    (1, 64),   // linear anchor
    (4, 64),   // pairs with (1,256) for direct b isolation
    (1, 256),  // linear anchor
    (1, 1024), // mid-range
    (1, 2048), // mid-range, anchors quadratic stability
    (1, 4096), // quadratic anchor
               // (1, max_seq) is added dynamically based on configured max.
];
73
74/// Runs the startup probe on the already-warmed leader worker.
75///
76/// # Arguments
77///
78/// - `pool`: the `EmbedPool` whose leader worker has already loaded models.
79/// - `max_seq`: the configured `BGE_M3_MAX_SEQ_LENGTH` (determines the topmost
80///   probe shape).  The dynamic `(1, max_seq)` capability check has been
81///   removed — see `trim_probe_shapes` in the change log.
82/// - `rss_ceiling`: the per-worker workspace budget computed from sysinfo.
83///   Shapes estimated to exceed this are skipped to avoid OOM mid-probe
84///   (the conservative-model guard, unchanged).
85/// - `cgroup_limit_bytes`: the **actual kernel memory ceiling** (cgroup limit
86///   or host RAM, whichever was detected first). Used by the absolute-RSS
87///   guard: before each shape the current process RSS is measured and the
88///   shape is skipped if `rss + 4 × estimated_cost > cgroup_limit × 87.5%`.
89///   This prevents ORT session-arena retention from accumulating past the
90///   kernel ceiling across successive probe shapes.
91///
92/// # Returns
93///
94/// `(a, b)` where `a` and `b` are the fitted cost-model coefficients.
95/// Returns conservative defaults and logs a warning on any failure.
96#[allow(clippy::too_many_lines)]
97pub(crate) async fn run_probe(
98    pool: &EmbedPool,
99    max_seq: usize,
100    rss_ceiling: usize,
101    cgroup_limit_bytes: usize,
102) -> (f64, f64) {
103    info!(
104        max_seq,
105        rss_ceiling_mb = rss_ceiling / (1024 * 1024),
106        cgroup_limit_mb = cgroup_limit_bytes / (1024 * 1024),
107        "Starting memory probe"
108    );
109
110    // Validate that the model can accept inputs at max_seq without running
111    // attention. This checks tokenizer + ndarray shape construction only —
112    // no `session.run()` call, so no ORT arena allocation.
113    validate_max_seq_shape(max_seq);
114
115    // Build shape list from the static set + dynamic max_seq capability anchor.
116    let mut shapes: Vec<Shape> = PROBE_SHAPES.to_vec();
117    // Add a (1, max_seq) shape if max_seq is larger than any static shape.
118    // This anchors the quadratic coefficient at the configured upper bound.
119    // If the model cannot run at max_seq, the per-shape error path skips it
120    // (no fail-fast — the failure surfaces as an ORT error on the first real
121    // request, which is more actionable than a startup OOM).
122    if !shapes.iter().any(|&(_, s)| s == max_seq) {
123        shapes.push((1, max_seq));
124    }
125    // Remove any shapes whose seq > max_seq (out of range for this model).
126    shapes.retain(|&(_, s)| s <= max_seq);
127    // Sort by ascending total token-positions so we grow load gradually.
128    shapes.sort_by_key(|&(b, s)| b * s);
129
130    let probe_start = std::time::Instant::now();
131    let mut data: Vec<DataPoint> = Vec::with_capacity(shapes.len());
132    let conservative = CostModel::conservative(rss_ceiling);
133
134    // Per-shape outcome counters for precise diagnostics when data is empty.
135    let mut shapes_skipped: usize = 0;
136    let mut shapes_errored: usize = 0;
137    let total_shapes = shapes.len();
138
139    // Synthesize probe texts from corpus (already curated and pinned).
140    let corpus_texts = load_probe_texts();
141
142    // ----- Arena warm-up -----
143    //
144    // ORT lazily allocates its session arena on the first `session.run()`.
145    // The first call therefore reads as a ~1 GB RSS delta even at tiny
146    // shapes — that delta is arena bookkeeping, not per-call workspace, and
147    // it pollutes the cost-model fit because it appears as constant noise
148    // across all subsequent shapes.
149    //
150    // The warm-up runs a small `(1, 64)` `session.run()` BEFORE the actual
151    // sweep starts and discards the result. After the warm-up, subsequent
152    // per-shape `rss_delta` readings reflect only the incremental allocation
153    // attributable to that shape, giving the OLS fitter a meaningful signal.
154    //
155    // The warm-up is gated by the same RSS guard that protects the sweep —
156    // if `current_rss + 4 × chunk_cost(1, 64)` would breach the cgroup limit
157    // we skip the warm-up and continue with conservative defaults.
158    let warmup_texts = synthesize_texts(&corpus_texts, 1, 64);
159    let warmup_start = std::time::Instant::now();
160    match pool.probe(warmup_texts).await {
161        Ok(result) => {
162            let warmup_delta = result.rss_after.saturating_sub(result.rss_before);
163            let elapsed_ms = u64::try_from(warmup_start.elapsed().as_millis()).unwrap_or(u64::MAX);
164            info!(
165                warmup_delta_mb = warmup_delta / (1024 * 1024),
166                rss_after_mb = result.rss_after / (1024 * 1024),
167                elapsed_ms,
168                "Probe: arena warm-up complete (delta excluded from fit)"
169            );
170        }
171        Err(e) => {
172            warn!(
173                error = %e,
174                "Probe: warm-up failed — proceeding without warm-up; first shape's \
175                 rss_delta will include arena initialisation overhead"
176            );
177        }
178    }
179
180    for (batch, seq) in &shapes {
181        let batch = *batch;
182        let seq = *seq;
183
184        // Skip shapes estimated to exceed the rss_ceiling by more than
185        // conservative cost model says (avoids OOM mid-probe).
186        if !conservative.fits(batch, seq) {
187            info!(
188                batch,
189                seq,
190                rss_ceiling_mb = rss_ceiling / (1024 * 1024),
191                "Probe: skipping shape (estimated to exceed rss_ceiling)"
192            );
193            shapes_skipped += 1;
194            continue;
195        }
196
197        // Absolute-RSS guard: ORT session-arena retention accumulates across
198        // probe shapes — each `session.run()` grows the arena and retains the
199        // pages for subsequent calls. The conservative `fits()` check above
200        // only looks at per-call workspace, not cumulative process RSS, so it
201        // cannot protect against gradual exhaustion.
202        //
203        // Before each shape we read the live process RSS and project the
204        // additional arena growth at 4× the conservative per-call estimate
205        // (empirically observed ratio on aarch64 MLAS fp16 at shapes where
206        // arena retention dominates). If the projected total would consume
207        // more than 87.5% of the cgroup ceiling we skip the shape.
208        //
209        // This guard fires only when `cgroup_limit_bytes > 0` so it is a
210        // no-op when memory detection fell back to the 4 GiB constant or was
211        // overridden to 0 in tests.
212        if cgroup_limit_bytes > 0 {
213            let current_rss = crate::sysinfo::read_process_rss_bytes().unwrap_or(0);
214            let estimated_cost = conservative.chunk_cost(batch, seq) as usize;
215            // 12.5% safety headroom below the cgroup ceiling.
216            let headroom = cgroup_limit_bytes / 8;
217            let rss_limit = cgroup_limit_bytes.saturating_sub(headroom);
218            // 4× multiplier on the conservative estimate to account for arena
219            // retention observed across successive probe shapes.
220            let projected = current_rss.saturating_add(estimated_cost.saturating_mul(4));
221            if projected > rss_limit {
222                info!(
223                    batch,
224                    seq,
225                    current_rss_mb = current_rss / (1024 * 1024),
226                    projected_mb = projected / (1024 * 1024),
227                    rss_limit_mb = rss_limit / (1024 * 1024),
228                    "Probe: skipping shape (current RSS + estimated arena growth \
229                     would breach cgroup limit)"
230                );
231                shapes_skipped += 1;
232                continue;
233            }
234        }
235
236        // Synthesize texts of approximately `seq` tokens by repeating corpus
237        // texts and trimming. We can't tokenize here (no tokenizer), so we
238        // approximate: one word ≈ 1.3 tokens, one char ≈ 0.25 tokens.
239        // At a rough 4 chars/token, a `seq`-token input is ~4*seq characters.
240        let texts = synthesize_texts(&corpus_texts, batch, seq);
241
242        let shape_start = std::time::Instant::now();
243        match pool.probe(texts).await {
244            Ok(result) => {
245                let elapsed_ms =
246                    u64::try_from(shape_start.elapsed().as_millis()).unwrap_or(u64::MAX);
247                let delta = result.rss_after.saturating_sub(result.rss_before);
248                info!(
249                    batch,
250                    seq,
251                    rss_delta_mb = delta / (1024 * 1024),
252                    elapsed_ms,
253                    "Probe shape measured"
254                );
255                data.push(DataPoint {
256                    batch,
257                    seq,
258                    rss_delta: delta,
259                });
260            }
261            Err(e) => {
262                let elapsed_ms =
263                    u64::try_from(shape_start.elapsed().as_millis()).unwrap_or(u64::MAX);
264                warn!(batch, seq, elapsed_ms, error = %e, "Probe shape failed; skipping");
265                shapes_errored += 1;
266            }
267        }
268    }
269
270    if data.is_empty() {
271        // Emit a specific diagnostic based on what actually happened so the
272        // operator can distinguish between a broken budget (rss_ceiling=0),
273        // ORT/model errors, and a non-Linux platform where RSS is unavailable.
274        if shapes_skipped == total_shapes {
275            warn!(
276                rss_ceiling_mb = rss_ceiling / (1024 * 1024),
277                total_shapes,
278                "Probe: all shapes skipped because rss_ceiling is too small to fit even \
279                 (batch=1, seq=64); per_worker_workspace upstream is likely broken — \
280                 check model_rss_per_worker measurement and memory detection"
281            );
282        } else if shapes_errored == total_shapes {
283            warn!(
284                total_shapes,
285                "Probe: all shapes errored — check ORT session and model logs above"
286            );
287        } else {
288            warn!(
289                shapes_skipped,
290                shapes_errored,
291                total_shapes,
292                "Probe collected no usable data points (RSS measurement unavailable on \
293                 non-Linux platforms, or all shapes were skipped/errored); \
294                 using conservative defaults"
295            );
296        }
297        return (CostModel::CONSERVATIVE_A, CostModel::CONSERVATIVE_B);
298    }
299
300    // If all measured deltas are zero, RSS is unavailable (non-Linux).
301    if data.iter().all(|dp| dp.rss_delta == 0) {
302        warn!(
303            data_points = data.len(),
304            "Probe: all RSS deltas are zero — read_process_rss_bytes() returned 0; \
305             auto-budget requires Linux /proc/self/statm; using conservative defaults"
306        );
307        return (CostModel::CONSERVATIVE_A, CostModel::CONSERVATIVE_B);
308    }
309
310    if let Some((a, b)) = fit_cost_model(&data) {
311        let total_elapsed_ms = u64::try_from(probe_start.elapsed().as_millis()).unwrap_or(u64::MAX);
312        info!(
313            a = format!("{a:.0}"),
314            b = format!("{b:.4}"),
315            data_points = data.len(),
316            total_elapsed_ms,
317            "Probe: fitted cost model"
318        );
319        (a, b)
320    } else {
321        warn!(
322            data_points = data.len(),
323            "Probe: least-squares fit failed or produced invalid coefficients; \
324             using conservative defaults"
325        );
326        (CostModel::CONSERVATIVE_A, CostModel::CONSERVATIVE_B)
327    }
328}