// bge_m3_embedding_server/probe/fit.rs
// Copyright (c) 2026 J. Patrick Fulton
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Least-squares cost-model fit for the startup probe.
use tracing::warn;

/// A single observation from the probe sweep: one (batch, seq) shape and the
/// memory growth measured while running it.
#[derive(Debug, Clone, Copy)]
pub(crate) struct DataPoint {
    /// Number of sequences in the probed batch.
    pub batch: usize,
    /// Sequence length used for this probe run.
    pub seq: usize,
    /// Measured RSS increase attributed to this shape (used as the
    /// least-squares response `y`; treated as bytes by the fitter's clamps).
    pub rss_delta: usize,
}
26
27/// Fits `peak = a * (batch * seq) + b * (batch * seq^2)` via ordinary least
28/// squares (no intercept — workspace at batch=0 is 0 by definition).
29///
30/// The design matrix `X` has columns `[batch*seq, batch*seq^2]` and the
31/// response `y` is `rss_delta` for each observation.
32///
33/// **Normalization**: columns are scaled to `[0, 1]` before solving
34/// (`ξ1 = x1 / max(x1)`, `ξ2 = x2 / max(x2)`).  Without this, `x2` at
35/// `max_seq=8192` exceeds `x1` by ~8000×, making the Gram matrix effectively
36/// rank-1 under the naïve det threshold and causing the fit to silently fall
37/// back to conservative defaults despite valid data.
38///
39/// Normal equations solved in normalized space via 2×2 matrix inverse
40/// (Cramer's rule), then unscaled: `a = α / x1_max`, `b = β / x2_max`.
41///
42/// Returns `None` when:
43/// - Fewer than 2 data points (under-determined system).
44/// - `x1_max` or `x2_max` is zero (degenerate data).
45/// - The normalized Gram matrix is nearly singular
46///   (det < 1e-6 of max diagonal²).
47/// - Either coefficient is negative (physically impossible workspace).
48//
49// cast_precision_loss: batch (≤ 16), seq (≤ 8192), and rss_delta (≤ ~28 GB) are
50//   all well within f64's 2^52 mantissa (~4.5 PB). Coefficients are computed via
51//   ordinary least squares where sub-integer precision in the inputs is irrelevant.
52#[allow(clippy::cast_precision_loss)]
53pub(crate) fn fit_cost_model(data: &[DataPoint]) -> Option<(f64, f64)> {
54    if data.len() < 2 {
55        return None;
56    }
57
58    // Compute scale factors so both design-matrix columns lie in [0, 1].
59    // Without normalization the x2 column (batch*seq²) at max_seq=8192 is
60    // ~8000× larger than x1 (batch*seq), making the Gram matrix near-singular
61    // under the det threshold even with 16 well-distributed data points.
62    let x1_max = data
63        .iter()
64        .map(|dp| (dp.batch * dp.seq) as f64)
65        .fold(0.0_f64, f64::max);
66    let x2_max = data
67        .iter()
68        .map(|dp| (dp.batch * dp.seq * dp.seq) as f64)
69        .fold(0.0_f64, f64::max);
70
71    if x1_max == 0.0 || x2_max == 0.0 {
72        return None;
73    }
74
75    // Build normalized Gram matrix: n1 = x1/x1_max, n2 = x2/x2_max ∈ [0,1].
76    // Variable names use single-letter prefixes to avoid clippy::similar_names
77    // on the longer accumulator names (g11, g12, g22, gy1, gy2).
78    let mut g11 = 0.0_f64; // sum(n1²)
79    let mut g12 = 0.0_f64; // sum(n1*n2)
80    let mut g22 = 0.0_f64; // sum(n2²)
81    let mut gy1 = 0.0_f64; // sum(n1*y)
82    let mut gy2 = 0.0_f64; // sum(n2*y)
83
84    for dp in data {
85        let n1 = (dp.batch * dp.seq) as f64 / x1_max;
86        let n2 = (dp.batch * dp.seq * dp.seq) as f64 / x2_max;
87        let y = dp.rss_delta as f64;
88
89        g11 += n1 * n1;
90        g12 += n1 * n2;
91        g22 += n2 * n2;
92        gy1 += n1 * y;
93        gy2 += n2 * y;
94    }
95
96    // 2×2 determinant in normalized space.
97    // With n1, n2 ∈ [0,1], max_diag ≤ N and det is directly comparable.
98    let det = g11 * g22 - g12 * g12;
99    let max_diag_sq = g11.max(g22).powi(2);
100    if det.abs() < 1e-6 * max_diag_sq {
101        // Nearly singular — likely all data points at the same shape or
102        // concentrated along one direction in design space.
103        return None;
104    }
105
106    // Cramer's rule in normalized space → normalized coefficients.
107    let alpha = (g22 * gy1 - g12 * gy2) / det; // coefficient of n1
108    let beta = (g11 * gy2 - g12 * gy1) / det; // coefficient of n2
109
110    // Unscale: a = alpha / x1_max, b = beta / x2_max.
111    let a_raw = alpha / x1_max;
112    let b_raw = beta / x2_max;
113
114    // ----- Negative-coefficient handling (rc8) -----
115    //
116    // OLS can return a negative `a_raw` when the data has a sharp
117    // discontinuity in y across the seq axis — for example, when ORT
118    // switches attention kernels between seq=2048 and seq=4096 (small
119    // memory-frugal fused kernel below the threshold; full O(N²) score
120    // matrix above). The two-coefficient quadratic model `y = a·N + b·N²`
121    // cannot describe a step function — the fitter has to drive `a` strongly
122    // negative to subtract the quadratic prediction back out at low seq
123    // where y ≈ 0.
124    //
125    // In that regime, `b_raw` is fine (the high-seq points fit a clean
126    // quadratic) but `a_raw` is non-physical. The rc7 production data was
127    // exactly this: `a_raw ≈ -109,000`, `b_raw ≈ 117`. Returning `None`
128    // and falling back to conservative defaults under-budgets ORT
129    // workspace by ~12× at high seq, which causes batched real-traffic
130    // OOMs (see CLAUDE.md gotcha "rc7 production capacity at max_seq=8192").
131    //
132    // Fix: when `a_raw` is negative, raise it to 0 and let the existing
133    // `.clamp(4_096.0, ...)` lower bound floor it at 4 KiB/token. That
134    // produces a fitted `b` that correctly predicts high-seq workspace
135    // (so the bin-packer rejects oversize batches) and an `a` that
136    // slightly over-predicts low-seq workspace (which is the safe
137    // direction — bin-packer might split low-seq batches more
138    // aggressively than ideal, but never accepts unsafe ones).
139    //
140    // A negative `b_raw` still fails fast: that would require a quadratic
141    // model to predict workspace *decreasing* as seq grows, which is
142    // genuinely non-physical and signals a measurement bug.
143    if b_raw < 0.0 {
144        return None;
145    }
146    let a_raw = a_raw.max(0.0);
147
148    // Clamp to sane operational ranges.
149    // a: [4 KiB, 256 KiB] per token-position
150    let a = a_raw.clamp(4_096.0, 262_144.0);
151    // b: [0.01, 50_000] bytes per token-position^2
152    let b = b_raw.clamp(0.01, 50_000.0);
153
154    // Log if clamping was significant.
155    let a_clamped = (a - a_raw).abs() > 0.01 * a_raw.abs();
156    let b_clamped = (b - b_raw).abs() > 0.01 * b_raw.abs();
157    if a_clamped || b_clamped {
158        warn!(
159            a_raw = format!("{a_raw:.0}"),
160            b_raw = format!("{b_raw:.4}"),
161            a_clamped = format!("{a:.0}"),
162            b_clamped = format!("{b:.4}"),
163            "Probe: fitted coefficients were clamped to sane range"
164        );
165    }
166
167    Some((a, b))
168}