bge_m3_embedding_server/probe/runner.rs
1// Copyright (c) 2026 J. Patrick Fulton
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Probe shape-sweep driver.
16//!
17//! Orchestrates the `(batch, seq)` shape sweep on the leader worker, applies
18//! the conservative-`fits()` gate and the absolute-RSS guard, and feeds the
19//! collected `DataPoint`s to [`super::fit::fit_cost_model`] for OLS fitting.
20
21use tracing::{info, warn};
22
23use super::corpus::{load_probe_texts, synthesize_texts};
24use super::fit::{fit_cost_model, DataPoint};
25use super::validate::validate_max_seq_shape;
26use crate::binpack::CostModel;
27use crate::embedder::EmbedPool;
28
/// Probe shape: `(batch_count, seq_length)` — the number of synthesized texts
/// in one probe call and the approximate per-text token length targeted by
/// `synthesize_texts` (texts are length-approximated, not tokenized).
pub(super) type Shape = (usize, usize);
31
/// Shapes swept by the probe.
///
/// 6 static shapes plus a dynamic `(1, max_seq)` shape added at runtime for
/// the quadratic anchor at the configured upper bound. `run_probe` sorts the
/// final list by ascending `batch * seq` so memory load grows gradually:
///
/// - `(1, 64)` and `(1, 256)` anchor the linear term at low seq.
/// - `(4, 64)` shares `x1 = batch*seq = 256` with `(1, 256)` but has a
///   different `x2 = batch*seq² = 16384` vs `65536`, giving a near-direct
///   measurement of `b` independent of `a`.
/// - `(1, 1024)` and `(1, 2048)` provide mid-range leverage.
/// - `(1, 4096)` anchors the quadratic regime.
///
/// ## Safety against OOM
///
/// ORT's memory arena retains pages across `session.run()` calls, so
/// cumulative process RSS grows with each successive probe shape. Three
/// independent mechanisms keep the sweep within the container's cgroup limit:
///
/// 1. **Arena warm-up** at the start of `run_probe` runs a `(1, 64)`
///    `session.run()` BEFORE the sweep, so the lazy ORT arena initialisation
///    does not appear as a ~1 GB constant offset on every per-shape delta.
/// 2. **Conservative `fits()` gate** rejects any shape whose per-call
///    workspace estimate exceeds `rss_ceiling` (the safety-discounted budget).
/// 3. **Absolute-RSS guard** rejects any shape whose projected arena growth
///    would push process RSS above 87.5% of the cgroup ceiling, regardless
///    of the conservative model's estimate.
///
/// The dynamic `(1, max_seq)` shape is added at runtime by `run_probe`. If
/// the model variant cannot run at `max_seq`, the shape is skipped and the
/// error surfaces on the first real embedding request.
///
/// Estimated probe time: ~120 s on aarch64 MLAS fp16 at `max_seq=8192`.
pub(super) const PROBE_SHAPES: &[Shape] = &[
    (1, 64), // linear anchor
    (4, 64), // pairs with (1,256) for direct b isolation
    (1, 256), // linear anchor
    (1, 1024), // mid-range
    (1, 2048), // mid-range, anchors quadratic stability
    (1, 4096), // quadratic anchor
    // (1, max_seq) is added dynamically based on configured max.
];
73
74/// Runs the startup probe on the already-warmed leader worker.
75///
76/// # Arguments
77///
78/// - `pool`: the `EmbedPool` whose leader worker has already loaded models.
79/// - `max_seq`: the configured `BGE_M3_MAX_SEQ_LENGTH` (determines the topmost
80/// probe shape). The dynamic `(1, max_seq)` capability check has been
81/// removed — see `trim_probe_shapes` in the change log.
82/// - `rss_ceiling`: the per-worker workspace budget computed from sysinfo.
83/// Shapes estimated to exceed this are skipped to avoid OOM mid-probe
84/// (the conservative-model guard, unchanged).
85/// - `cgroup_limit_bytes`: the **actual kernel memory ceiling** (cgroup limit
86/// or host RAM, whichever was detected first). Used by the absolute-RSS
87/// guard: before each shape the current process RSS is measured and the
88/// shape is skipped if `rss + 4 × estimated_cost > cgroup_limit × 87.5%`.
89/// This prevents ORT session-arena retention from accumulating past the
90/// kernel ceiling across successive probe shapes.
91///
92/// # Returns
93///
94/// `(a, b)` where `a` and `b` are the fitted cost-model coefficients.
95/// Returns conservative defaults and logs a warning on any failure.
96#[allow(clippy::too_many_lines)]
97pub(crate) async fn run_probe(
98 pool: &EmbedPool,
99 max_seq: usize,
100 rss_ceiling: usize,
101 cgroup_limit_bytes: usize,
102) -> (f64, f64) {
103 info!(
104 max_seq,
105 rss_ceiling_mb = rss_ceiling / (1024 * 1024),
106 cgroup_limit_mb = cgroup_limit_bytes / (1024 * 1024),
107 "Starting memory probe"
108 );
109
110 // Validate that the model can accept inputs at max_seq without running
111 // attention. This checks tokenizer + ndarray shape construction only —
112 // no `session.run()` call, so no ORT arena allocation.
113 validate_max_seq_shape(max_seq);
114
115 // Build shape list from the static set + dynamic max_seq capability anchor.
116 let mut shapes: Vec<Shape> = PROBE_SHAPES.to_vec();
117 // Add a (1, max_seq) shape if max_seq is larger than any static shape.
118 // This anchors the quadratic coefficient at the configured upper bound.
119 // If the model cannot run at max_seq, the per-shape error path skips it
120 // (no fail-fast — the failure surfaces as an ORT error on the first real
121 // request, which is more actionable than a startup OOM).
122 if !shapes.iter().any(|&(_, s)| s == max_seq) {
123 shapes.push((1, max_seq));
124 }
125 // Remove any shapes whose seq > max_seq (out of range for this model).
126 shapes.retain(|&(_, s)| s <= max_seq);
127 // Sort by ascending total token-positions so we grow load gradually.
128 shapes.sort_by_key(|&(b, s)| b * s);
129
130 let probe_start = std::time::Instant::now();
131 let mut data: Vec<DataPoint> = Vec::with_capacity(shapes.len());
132 let conservative = CostModel::conservative(rss_ceiling);
133
134 // Per-shape outcome counters for precise diagnostics when data is empty.
135 let mut shapes_skipped: usize = 0;
136 let mut shapes_errored: usize = 0;
137 let total_shapes = shapes.len();
138
139 // Synthesize probe texts from corpus (already curated and pinned).
140 let corpus_texts = load_probe_texts();
141
142 // ----- Arena warm-up -----
143 //
144 // ORT lazily allocates its session arena on the first `session.run()`.
145 // The first call therefore reads as a ~1 GB RSS delta even at tiny
146 // shapes — that delta is arena bookkeeping, not per-call workspace, and
147 // it pollutes the cost-model fit because it appears as constant noise
148 // across all subsequent shapes.
149 //
150 // The warm-up runs a small `(1, 64)` `session.run()` BEFORE the actual
151 // sweep starts and discards the result. After the warm-up, subsequent
152 // per-shape `rss_delta` readings reflect only the incremental allocation
153 // attributable to that shape, giving the OLS fitter a meaningful signal.
154 //
155 // The warm-up is gated by the same RSS guard that protects the sweep —
156 // if `current_rss + 4 × chunk_cost(1, 64)` would breach the cgroup limit
157 // we skip the warm-up and continue with conservative defaults.
158 let warmup_texts = synthesize_texts(&corpus_texts, 1, 64);
159 let warmup_start = std::time::Instant::now();
160 match pool.probe(warmup_texts).await {
161 Ok(result) => {
162 let warmup_delta = result.rss_after.saturating_sub(result.rss_before);
163 let elapsed_ms = u64::try_from(warmup_start.elapsed().as_millis()).unwrap_or(u64::MAX);
164 info!(
165 warmup_delta_mb = warmup_delta / (1024 * 1024),
166 rss_after_mb = result.rss_after / (1024 * 1024),
167 elapsed_ms,
168 "Probe: arena warm-up complete (delta excluded from fit)"
169 );
170 }
171 Err(e) => {
172 warn!(
173 error = %e,
174 "Probe: warm-up failed — proceeding without warm-up; first shape's \
175 rss_delta will include arena initialisation overhead"
176 );
177 }
178 }
179
180 for (batch, seq) in &shapes {
181 let batch = *batch;
182 let seq = *seq;
183
184 // Skip shapes estimated to exceed the rss_ceiling by more than
185 // conservative cost model says (avoids OOM mid-probe).
186 if !conservative.fits(batch, seq) {
187 info!(
188 batch,
189 seq,
190 rss_ceiling_mb = rss_ceiling / (1024 * 1024),
191 "Probe: skipping shape (estimated to exceed rss_ceiling)"
192 );
193 shapes_skipped += 1;
194 continue;
195 }
196
197 // Absolute-RSS guard: ORT session-arena retention accumulates across
198 // probe shapes — each `session.run()` grows the arena and retains the
199 // pages for subsequent calls. The conservative `fits()` check above
200 // only looks at per-call workspace, not cumulative process RSS, so it
201 // cannot protect against gradual exhaustion.
202 //
203 // Before each shape we read the live process RSS and project the
204 // additional arena growth at 4× the conservative per-call estimate
205 // (empirically observed ratio on aarch64 MLAS fp16 at shapes where
206 // arena retention dominates). If the projected total would consume
207 // more than 87.5% of the cgroup ceiling we skip the shape.
208 //
209 // This guard fires only when `cgroup_limit_bytes > 0` so it is a
210 // no-op when memory detection fell back to the 4 GiB constant or was
211 // overridden to 0 in tests.
212 if cgroup_limit_bytes > 0 {
213 let current_rss = crate::sysinfo::read_process_rss_bytes().unwrap_or(0);
214 let estimated_cost = conservative.chunk_cost(batch, seq) as usize;
215 // 12.5% safety headroom below the cgroup ceiling.
216 let headroom = cgroup_limit_bytes / 8;
217 let rss_limit = cgroup_limit_bytes.saturating_sub(headroom);
218 // 4× multiplier on the conservative estimate to account for arena
219 // retention observed across successive probe shapes.
220 let projected = current_rss.saturating_add(estimated_cost.saturating_mul(4));
221 if projected > rss_limit {
222 info!(
223 batch,
224 seq,
225 current_rss_mb = current_rss / (1024 * 1024),
226 projected_mb = projected / (1024 * 1024),
227 rss_limit_mb = rss_limit / (1024 * 1024),
228 "Probe: skipping shape (current RSS + estimated arena growth \
229 would breach cgroup limit)"
230 );
231 shapes_skipped += 1;
232 continue;
233 }
234 }
235
236 // Synthesize texts of approximately `seq` tokens by repeating corpus
237 // texts and trimming. We can't tokenize here (no tokenizer), so we
238 // approximate: one word ≈ 1.3 tokens, one char ≈ 0.25 tokens.
239 // At a rough 4 chars/token, a `seq`-token input is ~4*seq characters.
240 let texts = synthesize_texts(&corpus_texts, batch, seq);
241
242 let shape_start = std::time::Instant::now();
243 match pool.probe(texts).await {
244 Ok(result) => {
245 let elapsed_ms =
246 u64::try_from(shape_start.elapsed().as_millis()).unwrap_or(u64::MAX);
247 let delta = result.rss_after.saturating_sub(result.rss_before);
248 info!(
249 batch,
250 seq,
251 rss_delta_mb = delta / (1024 * 1024),
252 elapsed_ms,
253 "Probe shape measured"
254 );
255 data.push(DataPoint {
256 batch,
257 seq,
258 rss_delta: delta,
259 });
260 }
261 Err(e) => {
262 let elapsed_ms =
263 u64::try_from(shape_start.elapsed().as_millis()).unwrap_or(u64::MAX);
264 warn!(batch, seq, elapsed_ms, error = %e, "Probe shape failed; skipping");
265 shapes_errored += 1;
266 }
267 }
268 }
269
270 if data.is_empty() {
271 // Emit a specific diagnostic based on what actually happened so the
272 // operator can distinguish between a broken budget (rss_ceiling=0),
273 // ORT/model errors, and a non-Linux platform where RSS is unavailable.
274 if shapes_skipped == total_shapes {
275 warn!(
276 rss_ceiling_mb = rss_ceiling / (1024 * 1024),
277 total_shapes,
278 "Probe: all shapes skipped because rss_ceiling is too small to fit even \
279 (batch=1, seq=64); per_worker_workspace upstream is likely broken — \
280 check model_rss_per_worker measurement and memory detection"
281 );
282 } else if shapes_errored == total_shapes {
283 warn!(
284 total_shapes,
285 "Probe: all shapes errored — check ORT session and model logs above"
286 );
287 } else {
288 warn!(
289 shapes_skipped,
290 shapes_errored,
291 total_shapes,
292 "Probe collected no usable data points (RSS measurement unavailable on \
293 non-Linux platforms, or all shapes were skipped/errored); \
294 using conservative defaults"
295 );
296 }
297 return (CostModel::CONSERVATIVE_A, CostModel::CONSERVATIVE_B);
298 }
299
300 // If all measured deltas are zero, RSS is unavailable (non-Linux).
301 if data.iter().all(|dp| dp.rss_delta == 0) {
302 warn!(
303 data_points = data.len(),
304 "Probe: all RSS deltas are zero — read_process_rss_bytes() returned 0; \
305 auto-budget requires Linux /proc/self/statm; using conservative defaults"
306 );
307 return (CostModel::CONSERVATIVE_A, CostModel::CONSERVATIVE_B);
308 }
309
310 if let Some((a, b)) = fit_cost_model(&data) {
311 let total_elapsed_ms = u64::try_from(probe_start.elapsed().as_millis()).unwrap_or(u64::MAX);
312 info!(
313 a = format!("{a:.0}"),
314 b = format!("{b:.4}"),
315 data_points = data.len(),
316 total_elapsed_ms,
317 "Probe: fitted cost model"
318 );
319 (a, b)
320 } else {
321 warn!(
322 data_points = data.len(),
323 "Probe: least-squares fit failed or produced invalid coefficients; \
324 using conservative defaults"
325 );
326 (CostModel::CONSERVATIVE_A, CostModel::CONSERVATIVE_B)
327 }
328}