bge_m3_embedding_server/bootstrap/probe_task.rs
1// Copyright (c) 2026 J. Patrick Fulton
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Background probe task spawned when the cost model has not been overridden
16//! and the EFS cache is empty (or disabled).
17
18use std::path::PathBuf;
19use std::sync::atomic::Ordering;
20use std::sync::Arc;
21
22use tracing::info;
23
24use super::readiness::run_readiness_checks_and_open;
25use crate::binpack::CostModel;
26use crate::probe;
27use crate::state::{AppState, ProbeStatus};
28
29/// Spawns the background probe task with proper permit ownership.
30///
31/// # Serialisation protocol
32///
33/// 1. Set `probe_status = Running`.
34/// 2. Acquire `cfg_workers - 1` permits via `acquire_many_owned` — combined
35/// with the 1 permit already reserved at startup, this drains the
36/// semaphore to 0 so all incoming `/v1/embeddings*` requests queue
37/// behind the gate while the probe is in flight.
38/// 3. Move the [`tokio::sync::OwnedSemaphorePermit`] into the spawned
39/// task. Its destructor is invoked just before `add_permits(cfg_workers)`
40/// at the end of the task, restoring full traffic concurrency.
41///
42/// **Rationale for `acquire_many_owned`:** `tokio::spawn` returns
43/// synchronously before the spawned task starts executing. A permit bound to
44/// a local variable in the parent function would be dropped immediately at
45/// the end of that function — before the probe begins — leaving the semaphore
46/// un-drained and allowing real traffic to contaminate per-shape RSS
47/// measurements. `acquire_many_owned` returns an `OwnedSemaphorePermit`
48/// independent of the source `Semaphore` lifetime, so it survives the move
49/// into the async closure and is held for the full duration of the probe.
50#[allow(clippy::too_many_arguments)]
51pub(super) async fn spawn_probe_task(
52 state: Arc<AppState>,
53 cfg_workers: usize,
54 cfg_max_seq: usize,
55 per_worker_workspace: usize,
56 cgroup_limit_bytes: usize,
57 cache_dir: PathBuf,
58 model_variant_str: String,
59 save_cache: bool,
60) {
61 state
62 .probe_status
63 .store(ProbeStatus::Running as u8, Ordering::Release);
64
65 // Drain all remaining permits. The semaphore starts with
66 // `max(cfg_workers - 1, 1)` permits at startup (one slot reserved for
67 // the probe worker); we acquire the remaining `cfg_workers - 1` here
68 // so the count drops to 0 for the duration of the probe.
69 //
70 // `acquire_many_owned` returns an `OwnedSemaphorePermit` that we move
71 // into the spawned task closure. The permit's drop handler returns the
72 // permits to the semaphore — we manually call `add_permits(cfg_workers)`
73 // in the task to also release the originally-reserved probe slot.
74 let probe_permit = Arc::clone(&state.request_permits)
75 .acquire_many_owned(u32::try_from(cfg_workers.saturating_sub(1)).unwrap_or(u32::MAX))
76 .await
77 .ok();
78
79 tokio::spawn(async move {
80 // Forget the OwnedSemaphorePermit at the end; we manually
81 // add_permits(cfg_workers) below so the count goes from 0
82 // straight to cfg_workers (releasing both the drained permits
83 // and the originally-reserved probe slot in one operation).
84 if let Some(p) = probe_permit {
85 p.forget();
86 }
87
88 let (a, b) = probe::run_probe(
89 &state.pool,
90 cfg_max_seq,
91 per_worker_workspace,
92 cgroup_limit_bytes,
93 )
94 .await;
95 let cm = CostModel {
96 a,
97 b,
98 max_workspace_bytes: per_worker_workspace,
99 };
100 info!(
101 a = cm.a,
102 b = cm.b,
103 max_workspace_mb = cm.max_workspace_bytes / (1024 * 1024),
104 "Probe complete — updating cost model"
105 );
106 state.cost_model.store(Arc::new(cm));
107 // Distinguish real fit from conservative fallback.
108 let status = if (a - CostModel::CONSERVATIVE_A).abs() < f64::EPSILON
109 && (b - CostModel::CONSERVATIVE_B).abs() < f64::EPSILON
110 {
111 ProbeStatus::Failed
112 } else {
113 if save_cache {
114 probe::save_probe_cache(&cache_dir, &model_variant_str, cfg_max_seq, a, b);
115 }
116 ProbeStatus::Complete
117 };
118 state.probe_status.store(status as u8, Ordering::Release);
119 info!(probe_status = status.as_str(), "Probe status updated");
120
121 // Readiness checks run inside the probe task so they do not
122 // contaminate the probe's RSS measurements.
123 if let Err(e) = run_readiness_checks_and_open(&state).await {
124 tracing::error!(error = %e, "Post-probe readiness check failed");
125 }
126 // Release the drained permits AND the originally-reserved probe
127 // slot in one operation. Net effect: semaphore count goes from 0
128 // back to cfg_workers, opening traffic at full concurrency.
129 state.request_permits.add_permits(cfg_workers);
130 });
131}