// bge_m3_embedding_server/bootstrap/readiness.rs
//
// Copyright (c) 2026 J. Patrick Fulton
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15//! Foreground readiness probe.
16//!
17//! Waits for the worker pool's init handle, computes the per-worker workspace
18//! budget, writes static [`crate::state::TuningInfo`], resolves the cost model
19//! (override → EFS cache hit → background probe), and finally flips
20//! `state.ready` once the dense + sparse readiness checks succeed.
21
22use std::path::PathBuf;
23use std::sync::atomic::Ordering;
24use std::sync::Arc;
25
26use tracing::info;
27
28use super::budget::compute_workspace_budget;
29use super::probe_task::spawn_probe_task;
30use crate::binpack::CostModel;
31use crate::probe;
32use crate::state::{AppState, ProbeStatus, TuningInfo};
33use crate::sysinfo;
34
35/// Runs after all workers finish loading their model instances.
36///
37/// # Sequence
38///
39/// 1. Wait for worker pool initialisation to finish.
40/// 2. Read `pool.model_rss_per_worker_bytes()` — the median RSS delta measured
41/// inside each worker's `spawn_blocking` closure around `load_models()`.
42/// Workers load sequentially (one at a time), so each delta reflects only
43/// that worker's ORT session allocation with no parallel-load contamination.
44/// 3. Detect available memory; compute `per_worker_workspace` via
45/// `compute_workspace_budget`. Fail fast if the budget is below the
46/// physics-based floor (cannot fit even one text at `max_seq_length`).
47/// 4. Write static [`TuningInfo`] to `OnceLock`.
48/// 5. Resolve the cost model — one of three paths:
49/// - cost-model override set: apply immediately, `probe_status = Disabled`.
50/// - EFS cache hit: apply cached `(a, b)` via `ArcSwap`, `probe_status = CacheHit`.
51/// - cache miss: set `probe_status = Running`, launch background probe task.
52/// 6. Run dense + sparse readiness calls to confirm the worker pool is healthy.
53/// 7. Flip `state.ready = true` — `/health` returns `200 ok` from this point on.
54/// If the probe is still running in the background, the bin-packer uses
55/// conservative defaults until the `ArcSwap` is updated (typically ~120 s).
56///
57/// # Errors
58///
59/// - Worker pool init panicked (`JoinError`) or returned an error from model loading.
60/// - Per-worker workspace budget falls below the physics floor (cannot fit even one text
61/// at `max_seq_length` — container is restarted by the orchestrator).
62///
63// cast_possible_truncation: physics_floor is a u128 workspace estimate; truncating
64// to usize is safe because per_worker_workspace is itself bounded by available_bytes
65// which fits comfortably in usize on any 64-bit target.
66// cast_precision_loss / cast_sign_loss: delegated to compute_workspace_budget.
67#[allow(
68 clippy::cast_precision_loss,
69 clippy::cast_possible_truncation,
70 clippy::cast_sign_loss,
71 clippy::too_many_arguments,
72 clippy::too_many_lines
73)]
74pub async fn run_readiness_probe(
75 init_handle: tokio::task::JoinHandle<anyhow::Result<()>>,
76 state: Arc<AppState>,
77 cfg_max_seq: usize,
78 cfg_workers: usize,
79 cfg_safety: f64,
80 cost_model_override: Option<CostModel>,
81 cache_dir: PathBuf,
82 model_variant_str: String,
83 disable_probe_cache: bool,
84) -> anyhow::Result<()> {
85 init_handle
86 .await
87 .map_err(|e| anyhow::anyhow!("Worker pool task panicked: {e}"))?
88 .map_err(|e| anyhow::anyhow!("Worker pool initialization failed: {e}"))?;
89
90 // --- Memory detection ---
91 let mem = sysinfo::detect_available_memory();
92 info!(
93 available_bytes = mem.available_bytes,
94 source = %mem.source,
95 "Memory detected"
96 );
97
98 // Per-worker model RSS is the median of per-worker deltas collected by
99 // EmbedPool::spawn. Workers load sequentially (one at a time) so each
100 // delta reflects only that worker's ORT session allocation. The median
101 // is robust to one outlier from page-cache settling or ORT arena jitter.
102 let model_rss_per_worker = state.pool.model_rss_per_worker_bytes();
103 info!(
104 model_rss_per_worker_mb = model_rss_per_worker / (1024 * 1024),
105 "Measured model RSS per worker (median across all workers)"
106 );
107
108 // Compute per-worker workspace ceiling.
109 let (per_worker_workspace, worst_case_peak, utilization_pct) = compute_workspace_budget(
110 mem.available_bytes,
111 cfg_workers,
112 model_rss_per_worker,
113 cfg_safety,
114 );
115
116 info!(
117 worst_case_peak_mb = worst_case_peak / (1024 * 1024),
118 available_mb = mem.available_bytes / (1024 * 1024),
119 utilization_pct = format!("{utilization_pct:.1}"),
120 per_worker_workspace_mb = per_worker_workspace / (1024 * 1024),
121 "Workspace budget computed (worst-case all-workers-peak)"
122 );
123 if utilization_pct > 90.0 {
124 tracing::warn!(
125 utilization_pct = format!("{utilization_pct:.1}"),
126 "Worst-case workspace peak exceeds 90% of available memory; \
127 consider lowering BGE_M3_MEMORY_SAFETY_FACTOR or BGE_M3_WORKERS"
128 );
129 }
130
131 // Physics-based safety floor: the minimum workspace required to run a
132 // single text at the configured max sequence length under conservative
133 // cost-model coefficients. If the computed per_worker_workspace falls
134 // below this floor, the measurement upstream is broken (e.g. inflated
135 // model_rss_per_worker driving total_workspace to zero via saturating_sub).
136 // Continuing in this state degrades bin_pack to batch=1 and produces
137 // silent throughput collapse — fail fast instead so ECS restarts the task
138 // and the operator sees a clear error rather than a degraded service.
139 let physics_floor = CostModel::conservative(0).chunk_cost(1, cfg_max_seq) as usize;
140 if per_worker_workspace < physics_floor {
141 return Err(anyhow::anyhow!(
142 "Computed per_worker_workspace ({per_worker_workspace} B = {} MiB) is below the \
143 physics-based minimum ({physics_floor} B = {} MiB) needed to run one text at \
144 max_seq_length={cfg_max_seq}. Likely causes: model_rss_per_worker ({} MiB) is \
145 over-estimated (parallel-load contamination), BGE_M3_MEMORY_SAFETY_FACTOR too low \
146 ({cfg_safety}), BGE_M3_WORKERS too high ({cfg_workers}) for available memory \
147 ({} MiB), or BGE_M3_AVAILABLE_MEMORY_BYTES override too small.",
148 per_worker_workspace / (1024 * 1024),
149 physics_floor / (1024 * 1024),
150 model_rss_per_worker / (1024 * 1024),
151 mem.available_bytes / (1024 * 1024),
152 ));
153 }
154
155 // Write static memory + budget info now so /health always shows these fields
156 // even while the background probe is still running.
157 let _ = state.tuning.set(TuningInfo::new(
158 &mem,
159 model_rss_per_worker,
160 worst_case_peak,
161 utilization_pct,
162 ));
163
164 // The cgroup-limit byte count (the actual kernel ceiling, not the
165 // safety-discounted budget) is threaded into run_probe so the per-shape
166 // RSS guard can compare against the real ceiling rather than the
167 // discounted per_worker_workspace value.
168 let cgroup_limit_bytes = mem.available_bytes;
169
170 // --- Cost model resolution ---
171 if let Some(cm) = cost_model_override {
172 info!(
173 a = cm.a,
174 b = cm.b,
175 max_workspace_mb = cm.max_workspace_bytes / (1024 * 1024),
176 "Using pre-configured cost model (probe skipped)"
177 );
178 state.cost_model.store(Arc::new(cm));
179 state
180 .probe_status
181 .store(ProbeStatus::Disabled as u8, Ordering::Release);
182 // No probe — run readiness checks inline and open traffic.
183 run_readiness_checks_and_open(&state).await?;
184 } else if !disable_probe_cache {
185 // Try to load cached coefficients from EFS.
186 if let Some((a, b)) =
187 probe::try_load_probe_cache(&cache_dir, &model_variant_str, cfg_max_seq)
188 {
189 let cm = CostModel {
190 a,
191 b,
192 max_workspace_bytes: per_worker_workspace,
193 };
194 info!(
195 a,
196 b,
197 max_workspace_mb = cm.max_workspace_bytes / (1024 * 1024),
198 "Cost model loaded from EFS cache"
199 );
200 state.cost_model.store(Arc::new(cm));
201 state
202 .probe_status
203 .store(ProbeStatus::CacheHit as u8, Ordering::Release);
204 // Cache hit — run readiness checks inline and open traffic.
205 run_readiness_checks_and_open(&state).await?;
206 } else {
207 // Cache miss — probe must run. See `spawn_probe_task` for the
208 // serialisation protocol that holds all cfg_workers permits across
209 // the probe + readiness window.
210 spawn_probe_task(
211 Arc::clone(&state),
212 cfg_workers,
213 cfg_max_seq,
214 per_worker_workspace,
215 cgroup_limit_bytes,
216 cache_dir,
217 model_variant_str,
218 /* save_cache = */ true,
219 )
220 .await;
221 return Ok(());
222 }
223 } else {
224 // BGE_M3_DISABLE_PROBE_CACHE=1 but no override — run probe without caching.
225 spawn_probe_task(
226 Arc::clone(&state),
227 cfg_workers,
228 cfg_max_seq,
229 per_worker_workspace,
230 cgroup_limit_bytes,
231 cache_dir,
232 model_variant_str,
233 /* save_cache = */ false,
234 )
235 .await;
236 return Ok(());
237 }
238
239 Ok(())
240}
241
242/// Runs the dense + sparse readiness calls and flips `state.ready`.
243///
244/// Called from the override/cache-hit paths (inline, before returning from
245/// `run_readiness_probe`) and from the background probe task (after the probe
246/// completes) so that readiness checks never run concurrently with the probe.
247pub(super) async fn run_readiness_checks_and_open(state: &AppState) -> anyhow::Result<()> {
248 state
249 .pool
250 .dense(vec!["ready".into()])
251 .await
252 .map_err(|e| anyhow::anyhow!("Dense readiness probe failed: {e}"))?;
253
254 state
255 .pool
256 .sparse(vec!["ready".into()])
257 .await
258 .map_err(|e| anyhow::anyhow!("Sparse readiness probe failed: {e}"))?;
259
260 state
261 .ready
262 .store(true, std::sync::atomic::Ordering::Release);
263 tracing::info!("Models ready — accepting requests");
264 Ok(())
265}