Skip to main content

bge_m3_embedding_server/bootstrap/
probe_task.rs

1// Copyright (c) 2026 J. Patrick Fulton
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Background probe task spawned when the cost model has not been overridden
16//! and the EFS cache is empty (or disabled).
17
18use std::path::PathBuf;
19use std::sync::atomic::Ordering;
20use std::sync::Arc;
21
22use tracing::info;
23
24use super::readiness::run_readiness_checks_and_open;
25use crate::binpack::CostModel;
26use crate::probe;
27use crate::state::{AppState, ProbeStatus};
28
29/// Spawns the background probe task with proper permit ownership.
30///
31/// # Serialisation protocol
32///
33/// 1. Set `probe_status = Running`.
34/// 2. Acquire `cfg_workers - 1` permits via `acquire_many_owned` — combined
35///    with the 1 permit already reserved at startup, this drains the
36///    semaphore to 0 so all incoming `/v1/embeddings*` requests queue
37///    behind the gate while the probe is in flight.
38/// 3. Move the [`tokio::sync::OwnedSemaphorePermit`] into the spawned
39///    task. Its destructor is invoked just before `add_permits(cfg_workers)`
40///    at the end of the task, restoring full traffic concurrency.
41///
42/// **Rationale for `acquire_many_owned`:** `tokio::spawn` returns
43/// synchronously before the spawned task starts executing. A permit bound to
44/// a local variable in the parent function would be dropped immediately at
45/// the end of that function — before the probe begins — leaving the semaphore
46/// un-drained and allowing real traffic to contaminate per-shape RSS
47/// measurements. `acquire_many_owned` returns an `OwnedSemaphorePermit`
48/// independent of the source `Semaphore` lifetime, so it survives the move
49/// into the async closure and is held for the full duration of the probe.
50#[allow(clippy::too_many_arguments)]
51pub(super) async fn spawn_probe_task(
52    state: Arc<AppState>,
53    cfg_workers: usize,
54    cfg_max_seq: usize,
55    per_worker_workspace: usize,
56    cgroup_limit_bytes: usize,
57    cache_dir: PathBuf,
58    model_variant_str: String,
59    save_cache: bool,
60) {
61    state
62        .probe_status
63        .store(ProbeStatus::Running as u8, Ordering::Release);
64
65    // Drain all remaining permits. The semaphore starts with
66    // `max(cfg_workers - 1, 1)` permits at startup (one slot reserved for
67    // the probe worker); we acquire the remaining `cfg_workers - 1` here
68    // so the count drops to 0 for the duration of the probe.
69    //
70    // `acquire_many_owned` returns an `OwnedSemaphorePermit` that we move
71    // into the spawned task closure. The permit's drop handler returns the
72    // permits to the semaphore — we manually call `add_permits(cfg_workers)`
73    // in the task to also release the originally-reserved probe slot.
74    let probe_permit = Arc::clone(&state.request_permits)
75        .acquire_many_owned(u32::try_from(cfg_workers.saturating_sub(1)).unwrap_or(u32::MAX))
76        .await
77        .ok();
78
79    tokio::spawn(async move {
80        // Forget the OwnedSemaphorePermit at the end; we manually
81        // add_permits(cfg_workers) below so the count goes from 0
82        // straight to cfg_workers (releasing both the drained permits
83        // and the originally-reserved probe slot in one operation).
84        if let Some(p) = probe_permit {
85            p.forget();
86        }
87
88        let (a, b) = probe::run_probe(
89            &state.pool,
90            cfg_max_seq,
91            per_worker_workspace,
92            cgroup_limit_bytes,
93        )
94        .await;
95        let cm = CostModel {
96            a,
97            b,
98            max_workspace_bytes: per_worker_workspace,
99        };
100        info!(
101            a = cm.a,
102            b = cm.b,
103            max_workspace_mb = cm.max_workspace_bytes / (1024 * 1024),
104            "Probe complete — updating cost model"
105        );
106        state.cost_model.store(Arc::new(cm));
107        // Distinguish real fit from conservative fallback.
108        let status = if (a - CostModel::CONSERVATIVE_A).abs() < f64::EPSILON
109            && (b - CostModel::CONSERVATIVE_B).abs() < f64::EPSILON
110        {
111            ProbeStatus::Failed
112        } else {
113            if save_cache {
114                probe::save_probe_cache(&cache_dir, &model_variant_str, cfg_max_seq, a, b);
115            }
116            ProbeStatus::Complete
117        };
118        state.probe_status.store(status as u8, Ordering::Release);
119        info!(probe_status = status.as_str(), "Probe status updated");
120
121        // Readiness checks run inside the probe task so they do not
122        // contaminate the probe's RSS measurements.
123        if let Err(e) = run_readiness_checks_and_open(&state).await {
124            tracing::error!(error = %e, "Post-probe readiness check failed");
125        }
126        // Release the drained permits AND the originally-reserved probe
127        // slot in one operation. Net effect: semaphore count goes from 0
128        // back to cfg_workers, opening traffic at full concurrency.
129        state.request_permits.add_permits(cfg_workers);
130    });
131}