bge_m3_embedding_server/probe/fit.rs
1// Copyright (c) 2026 J. Patrick Fulton
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Least-squares cost-model fit for the startup probe.
16
17use tracing::warn;
18
/// One measured data point from the probe sweep.
#[derive(Debug, Clone, Copy)]
pub(crate) struct DataPoint {
    /// Number of sequences in the probed batch.
    pub batch: usize,
    /// Per-sequence length (tokens) used for this probe shape.
    pub seq: usize,
    /// Observed RSS growth for this shape — bytes, per the 4 KiB/token clamp
    /// floor and the ~28 GB bound documented in `fit_cost_model`.
    pub rss_delta: usize,
}
26
27/// Fits `peak = a * (batch * seq) + b * (batch * seq^2)` via ordinary least
28/// squares (no intercept — workspace at batch=0 is 0 by definition).
29///
30/// The design matrix `X` has columns `[batch*seq, batch*seq^2]` and the
31/// response `y` is `rss_delta` for each observation.
32///
33/// **Normalization**: columns are scaled to `[0, 1]` before solving
34/// (`ξ1 = x1 / max(x1)`, `ξ2 = x2 / max(x2)`). Without this, `x2` at
35/// `max_seq=8192` exceeds `x1` by ~8000×, making the Gram matrix effectively
36/// rank-1 under the naïve det threshold and causing the fit to silently fall
37/// back to conservative defaults despite valid data.
38///
39/// Normal equations solved in normalized space via 2×2 matrix inverse
40/// (Cramer's rule), then unscaled: `a = α / x1_max`, `b = β / x2_max`.
41///
42/// Returns `None` when:
43/// - Fewer than 2 data points (under-determined system).
44/// - `x1_max` or `x2_max` is zero (degenerate data).
45/// - The normalized Gram matrix is nearly singular
46/// (det < 1e-6 of max diagonal²).
47/// - Either coefficient is negative (physically impossible workspace).
48//
49// cast_precision_loss: batch (≤ 16), seq (≤ 8192), and rss_delta (≤ ~28 GB) are
50// all well within f64's 2^52 mantissa (~4.5 PB). Coefficients are computed via
51// ordinary least squares where sub-integer precision in the inputs is irrelevant.
52#[allow(clippy::cast_precision_loss)]
53pub(crate) fn fit_cost_model(data: &[DataPoint]) -> Option<(f64, f64)> {
54 if data.len() < 2 {
55 return None;
56 }
57
58 // Compute scale factors so both design-matrix columns lie in [0, 1].
59 // Without normalization the x2 column (batch*seq²) at max_seq=8192 is
60 // ~8000× larger than x1 (batch*seq), making the Gram matrix near-singular
61 // under the det threshold even with 16 well-distributed data points.
62 let x1_max = data
63 .iter()
64 .map(|dp| (dp.batch * dp.seq) as f64)
65 .fold(0.0_f64, f64::max);
66 let x2_max = data
67 .iter()
68 .map(|dp| (dp.batch * dp.seq * dp.seq) as f64)
69 .fold(0.0_f64, f64::max);
70
71 if x1_max == 0.0 || x2_max == 0.0 {
72 return None;
73 }
74
75 // Build normalized Gram matrix: n1 = x1/x1_max, n2 = x2/x2_max ∈ [0,1].
76 // Variable names use single-letter prefixes to avoid clippy::similar_names
77 // on the longer accumulator names (g11, g12, g22, gy1, gy2).
78 let mut g11 = 0.0_f64; // sum(n1²)
79 let mut g12 = 0.0_f64; // sum(n1*n2)
80 let mut g22 = 0.0_f64; // sum(n2²)
81 let mut gy1 = 0.0_f64; // sum(n1*y)
82 let mut gy2 = 0.0_f64; // sum(n2*y)
83
84 for dp in data {
85 let n1 = (dp.batch * dp.seq) as f64 / x1_max;
86 let n2 = (dp.batch * dp.seq * dp.seq) as f64 / x2_max;
87 let y = dp.rss_delta as f64;
88
89 g11 += n1 * n1;
90 g12 += n1 * n2;
91 g22 += n2 * n2;
92 gy1 += n1 * y;
93 gy2 += n2 * y;
94 }
95
96 // 2×2 determinant in normalized space.
97 // With n1, n2 ∈ [0,1], max_diag ≤ N and det is directly comparable.
98 let det = g11 * g22 - g12 * g12;
99 let max_diag_sq = g11.max(g22).powi(2);
100 if det.abs() < 1e-6 * max_diag_sq {
101 // Nearly singular — likely all data points at the same shape or
102 // concentrated along one direction in design space.
103 return None;
104 }
105
106 // Cramer's rule in normalized space → normalized coefficients.
107 let alpha = (g22 * gy1 - g12 * gy2) / det; // coefficient of n1
108 let beta = (g11 * gy2 - g12 * gy1) / det; // coefficient of n2
109
110 // Unscale: a = alpha / x1_max, b = beta / x2_max.
111 let a_raw = alpha / x1_max;
112 let b_raw = beta / x2_max;
113
114 // ----- Negative-coefficient handling (rc8) -----
115 //
116 // OLS can return a negative `a_raw` when the data has a sharp
117 // discontinuity in y across the seq axis — for example, when ORT
118 // switches attention kernels between seq=2048 and seq=4096 (small
119 // memory-frugal fused kernel below the threshold; full O(N²) score
120 // matrix above). The two-coefficient quadratic model `y = a·N + b·N²`
121 // cannot describe a step function — the fitter has to drive `a` strongly
122 // negative to subtract the quadratic prediction back out at low seq
123 // where y ≈ 0.
124 //
125 // In that regime, `b_raw` is fine (the high-seq points fit a clean
126 // quadratic) but `a_raw` is non-physical. The rc7 production data was
127 // exactly this: `a_raw ≈ -109,000`, `b_raw ≈ 117`. Returning `None`
128 // and falling back to conservative defaults under-budgets ORT
129 // workspace by ~12× at high seq, which causes batched real-traffic
130 // OOMs (see CLAUDE.md gotcha "rc7 production capacity at max_seq=8192").
131 //
132 // Fix: when `a_raw` is negative, raise it to 0 and let the existing
133 // `.clamp(4_096.0, ...)` lower bound floor it at 4 KiB/token. That
134 // produces a fitted `b` that correctly predicts high-seq workspace
135 // (so the bin-packer rejects oversize batches) and an `a` that
136 // slightly over-predicts low-seq workspace (which is the safe
137 // direction — bin-packer might split low-seq batches more
138 // aggressively than ideal, but never accepts unsafe ones).
139 //
140 // A negative `b_raw` still fails fast: that would require a quadratic
141 // model to predict workspace *decreasing* as seq grows, which is
142 // genuinely non-physical and signals a measurement bug.
143 if b_raw < 0.0 {
144 return None;
145 }
146 let a_raw = a_raw.max(0.0);
147
148 // Clamp to sane operational ranges.
149 // a: [4 KiB, 256 KiB] per token-position
150 let a = a_raw.clamp(4_096.0, 262_144.0);
151 // b: [0.01, 50_000] bytes per token-position^2
152 let b = b_raw.clamp(0.01, 50_000.0);
153
154 // Log if clamping was significant.
155 let a_clamped = (a - a_raw).abs() > 0.01 * a_raw.abs();
156 let b_clamped = (b - b_raw).abs() > 0.01 * b_raw.abs();
157 if a_clamped || b_clamped {
158 warn!(
159 a_raw = format!("{a_raw:.0}"),
160 b_raw = format!("{b_raw:.4}"),
161 a_clamped = format!("{a:.0}"),
162 b_clamped = format!("{b:.4}"),
163 "Probe: fitted coefficients were clamped to sane range"
164 );
165 }
166
167 Some((a, b))
168}