Skip to main content

bge_m3_embedding_server/bootstrap/
readiness.rs

1// Copyright (c) 2026 J. Patrick Fulton
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Foreground readiness probe.
//!
//! Waits for the worker pool's init handle, computes the per-worker workspace
//! budget, writes static [`crate::state::TuningInfo`], and resolves the cost model
//! (override → EFS cache hit → background probe). `state.ready` is flipped once the
//! dense + sparse readiness checks succeed: inline on the override / cache-hit
//! paths, or from the background probe task after a fresh probe completes.
21
22use std::path::PathBuf;
23use std::sync::atomic::Ordering;
24use std::sync::Arc;
25
26use tracing::info;
27
28use super::budget::compute_workspace_budget;
29use super::probe_task::spawn_probe_task;
30use crate::binpack::CostModel;
31use crate::probe;
32use crate::state::{AppState, ProbeStatus, TuningInfo};
33use crate::sysinfo;
34
35/// Runs after all workers finish loading their model instances.
36///
37/// # Sequence
38///
39/// 1. Wait for worker pool initialisation to finish.
40/// 2. Read `pool.model_rss_per_worker_bytes()` — the median RSS delta measured
41///    inside each worker's `spawn_blocking` closure around `load_models()`.
42///    Workers load sequentially (one at a time), so each delta reflects only
43///    that worker's ORT session allocation with no parallel-load contamination.
44/// 3. Detect available memory; compute `per_worker_workspace` via
45///    `compute_workspace_budget`. Fail fast if the budget is below the
46///    physics-based floor (cannot fit even one text at `max_seq_length`).
47/// 4. Write static [`TuningInfo`] to `OnceLock`.
48/// 5. Resolve the cost model — one of three paths:
49///    - cost-model override set: apply immediately, `probe_status = Disabled`.
50///    - EFS cache hit: apply cached `(a, b)` via `ArcSwap`, `probe_status = CacheHit`.
51///    - cache miss: set `probe_status = Running`, launch background probe task.
52/// 6. Run dense + sparse readiness calls to confirm the worker pool is healthy.
53/// 7. Flip `state.ready = true` — `/health` returns `200 ok` from this point on.
54///    If the probe is still running in the background, the bin-packer uses
55///    conservative defaults until the `ArcSwap` is updated (typically ~120 s).
56///
57/// # Errors
58///
59/// - Worker pool init panicked (`JoinError`) or returned an error from model loading.
60/// - Per-worker workspace budget falls below the physics floor (cannot fit even one text
61///   at `max_seq_length` — container is restarted by the orchestrator).
62///
63// cast_possible_truncation: physics_floor is a u128 workspace estimate; truncating
64//   to usize is safe because per_worker_workspace is itself bounded by available_bytes
65//   which fits comfortably in usize on any 64-bit target.
66// cast_precision_loss / cast_sign_loss: delegated to compute_workspace_budget.
67#[allow(
68    clippy::cast_precision_loss,
69    clippy::cast_possible_truncation,
70    clippy::cast_sign_loss,
71    clippy::too_many_arguments,
72    clippy::too_many_lines
73)]
74pub async fn run_readiness_probe(
75    init_handle: tokio::task::JoinHandle<anyhow::Result<()>>,
76    state: Arc<AppState>,
77    cfg_max_seq: usize,
78    cfg_workers: usize,
79    cfg_safety: f64,
80    cost_model_override: Option<CostModel>,
81    cache_dir: PathBuf,
82    model_variant_str: String,
83    disable_probe_cache: bool,
84) -> anyhow::Result<()> {
85    init_handle
86        .await
87        .map_err(|e| anyhow::anyhow!("Worker pool task panicked: {e}"))?
88        .map_err(|e| anyhow::anyhow!("Worker pool initialization failed: {e}"))?;
89
90    // --- Memory detection ---
91    let mem = sysinfo::detect_available_memory();
92    info!(
93        available_bytes = mem.available_bytes,
94        source = %mem.source,
95        "Memory detected"
96    );
97
98    // Per-worker model RSS is the median of per-worker deltas collected by
99    // EmbedPool::spawn. Workers load sequentially (one at a time) so each
100    // delta reflects only that worker's ORT session allocation. The median
101    // is robust to one outlier from page-cache settling or ORT arena jitter.
102    let model_rss_per_worker = state.pool.model_rss_per_worker_bytes();
103    info!(
104        model_rss_per_worker_mb = model_rss_per_worker / (1024 * 1024),
105        "Measured model RSS per worker (median across all workers)"
106    );
107
108    // Compute per-worker workspace ceiling.
109    let (per_worker_workspace, worst_case_peak, utilization_pct) = compute_workspace_budget(
110        mem.available_bytes,
111        cfg_workers,
112        model_rss_per_worker,
113        cfg_safety,
114    );
115
116    info!(
117        worst_case_peak_mb = worst_case_peak / (1024 * 1024),
118        available_mb = mem.available_bytes / (1024 * 1024),
119        utilization_pct = format!("{utilization_pct:.1}"),
120        per_worker_workspace_mb = per_worker_workspace / (1024 * 1024),
121        "Workspace budget computed (worst-case all-workers-peak)"
122    );
123    if utilization_pct > 90.0 {
124        tracing::warn!(
125            utilization_pct = format!("{utilization_pct:.1}"),
126            "Worst-case workspace peak exceeds 90% of available memory; \
127             consider lowering BGE_M3_MEMORY_SAFETY_FACTOR or BGE_M3_WORKERS"
128        );
129    }
130
131    // Physics-based safety floor: the minimum workspace required to run a
132    // single text at the configured max sequence length under conservative
133    // cost-model coefficients. If the computed per_worker_workspace falls
134    // below this floor, the measurement upstream is broken (e.g. inflated
135    // model_rss_per_worker driving total_workspace to zero via saturating_sub).
136    // Continuing in this state degrades bin_pack to batch=1 and produces
137    // silent throughput collapse — fail fast instead so ECS restarts the task
138    // and the operator sees a clear error rather than a degraded service.
139    let physics_floor = CostModel::conservative(0).chunk_cost(1, cfg_max_seq) as usize;
140    if per_worker_workspace < physics_floor {
141        return Err(anyhow::anyhow!(
142            "Computed per_worker_workspace ({per_worker_workspace} B = {} MiB) is below the \
143             physics-based minimum ({physics_floor} B = {} MiB) needed to run one text at \
144             max_seq_length={cfg_max_seq}. Likely causes: model_rss_per_worker ({} MiB) is \
145             over-estimated (parallel-load contamination), BGE_M3_MEMORY_SAFETY_FACTOR too low \
146             ({cfg_safety}), BGE_M3_WORKERS too high ({cfg_workers}) for available memory \
147             ({} MiB), or BGE_M3_AVAILABLE_MEMORY_BYTES override too small.",
148            per_worker_workspace / (1024 * 1024),
149            physics_floor / (1024 * 1024),
150            model_rss_per_worker / (1024 * 1024),
151            mem.available_bytes / (1024 * 1024),
152        ));
153    }
154
155    // Write static memory + budget info now so /health always shows these fields
156    // even while the background probe is still running.
157    let _ = state.tuning.set(TuningInfo::new(
158        &mem,
159        model_rss_per_worker,
160        worst_case_peak,
161        utilization_pct,
162    ));
163
164    // The cgroup-limit byte count (the actual kernel ceiling, not the
165    // safety-discounted budget) is threaded into run_probe so the per-shape
166    // RSS guard can compare against the real ceiling rather than the
167    // discounted per_worker_workspace value.
168    let cgroup_limit_bytes = mem.available_bytes;
169
170    // --- Cost model resolution ---
171    if let Some(cm) = cost_model_override {
172        info!(
173            a = cm.a,
174            b = cm.b,
175            max_workspace_mb = cm.max_workspace_bytes / (1024 * 1024),
176            "Using pre-configured cost model (probe skipped)"
177        );
178        state.cost_model.store(Arc::new(cm));
179        state
180            .probe_status
181            .store(ProbeStatus::Disabled as u8, Ordering::Release);
182        // No probe — run readiness checks inline and open traffic.
183        run_readiness_checks_and_open(&state).await?;
184    } else if !disable_probe_cache {
185        // Try to load cached coefficients from EFS.
186        if let Some((a, b)) =
187            probe::try_load_probe_cache(&cache_dir, &model_variant_str, cfg_max_seq)
188        {
189            let cm = CostModel {
190                a,
191                b,
192                max_workspace_bytes: per_worker_workspace,
193            };
194            info!(
195                a,
196                b,
197                max_workspace_mb = cm.max_workspace_bytes / (1024 * 1024),
198                "Cost model loaded from EFS cache"
199            );
200            state.cost_model.store(Arc::new(cm));
201            state
202                .probe_status
203                .store(ProbeStatus::CacheHit as u8, Ordering::Release);
204            // Cache hit — run readiness checks inline and open traffic.
205            run_readiness_checks_and_open(&state).await?;
206        } else {
207            // Cache miss — probe must run. See `spawn_probe_task` for the
208            // serialisation protocol that holds all cfg_workers permits across
209            // the probe + readiness window.
210            spawn_probe_task(
211                Arc::clone(&state),
212                cfg_workers,
213                cfg_max_seq,
214                per_worker_workspace,
215                cgroup_limit_bytes,
216                cache_dir,
217                model_variant_str,
218                /* save_cache = */ true,
219            )
220            .await;
221            return Ok(());
222        }
223    } else {
224        // BGE_M3_DISABLE_PROBE_CACHE=1 but no override — run probe without caching.
225        spawn_probe_task(
226            Arc::clone(&state),
227            cfg_workers,
228            cfg_max_seq,
229            per_worker_workspace,
230            cgroup_limit_bytes,
231            cache_dir,
232            model_variant_str,
233            /* save_cache = */ false,
234        )
235        .await;
236        return Ok(());
237    }
238
239    Ok(())
240}
241
242/// Runs the dense + sparse readiness calls and flips `state.ready`.
243///
244/// Called from the override/cache-hit paths (inline, before returning from
245/// `run_readiness_probe`) and from the background probe task (after the probe
246/// completes) so that readiness checks never run concurrently with the probe.
247pub(super) async fn run_readiness_checks_and_open(state: &AppState) -> anyhow::Result<()> {
248    state
249        .pool
250        .dense(vec!["ready".into()])
251        .await
252        .map_err(|e| anyhow::anyhow!("Dense readiness probe failed: {e}"))?;
253
254    state
255        .pool
256        .sparse(vec!["ready".into()])
257        .await
258        .map_err(|e| anyhow::anyhow!("Sparse readiness probe failed: {e}"))?;
259
260    state
261        .ready
262        .store(true, std::sync::atomic::Ordering::Release);
263    tracing::info!("Models ready — accepting requests");
264    Ok(())
265}