Skip to main content

bge_m3_embedding_server/
lib.rs

1// Copyright (c) 2026 J. Patrick Fulton
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Library crate for the bge-m3 embedding server.
16//!
17//! `main.rs` is a 20–30 line entry point that calls [`run`]; all real
18//! orchestration logic lives here so it can be unit-tested and reused from
19//! integration tests without spawning the binary.
20
21// Rustdoc lints — enforce documentation quality
22#![warn(missing_docs)]
23#![warn(rustdoc::missing_crate_level_docs)]
24#![warn(rustdoc::unescaped_backticks)]
25#![deny(rustdoc::broken_intra_doc_links)]
26#![deny(rustdoc::invalid_html_tags)]
27#![deny(rustdoc::bare_urls)]
28#![warn(rustdoc::redundant_explicit_links)]
29#![warn(rustdoc::private_doc_tests)]
30
31pub mod binpack;
32pub mod bootstrap;
33pub mod config;
34pub mod embedder;
35pub mod error;
36pub mod handler;
37pub mod models;
38pub mod probe;
39pub mod state;
40pub mod sysinfo;
41pub mod weights;
42
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
use std::sync::Arc;
use std::time::Duration;

use anyhow::Context;
use arc_swap::ArcSwap;
use tokio::sync::Semaphore;
use tracing::info;

use crate::binpack::CostModel;
use crate::bootstrap::{build_router, run_readiness_probe};
use crate::config::Config;
use crate::embedder::{EmbedPool, WorkerConfig};
use crate::state::{AppState, ProbeStatus};
57
58/// Runs the embedding server end-to-end: load config, spawn the worker pool,
59/// install the readiness probe, start the heartbeat, and serve HTTP traffic.
60///
61/// Background tasks log and call `process::exit(1)` on their own unrecoverable
62/// failures so the container is restarted by the orchestrator.
63///
64/// # Errors
65///
66/// Returns `Err` if the TCP listener cannot bind to the configured address.
67#[allow(clippy::too_many_lines)]
68pub async fn run() -> anyhow::Result<()> {
69    info!(
70        version = env!("CARGO_PKG_VERSION"),
71        git_sha = env!("BGE_M3_GIT_SHA"),
72        target_arch = std::env::consts::ARCH,
73        target_os = std::env::consts::OS,
74        profile = if cfg!(debug_assertions) {
75            "debug"
76        } else {
77            "release"
78        },
79        "bge-m3-embedding-server build info"
80    );
81
82    let cfg = Config::from_env();
83
84    let disable_probe_cache = std::env::var("BGE_M3_DISABLE_PROBE_CACHE")
85        .is_ok_and(|v| matches!(v.as_str(), "1" | "true" | "yes"));
86
87    info!(
88        bind = %cfg.bind_addr,
89        workers = cfg.workers,
90        max_batch = cfg.max_batch,
91        max_seq_length = cfg.max_seq_length,
92        cache_dir = %cfg.cache_dir,
93        idle_timeout_secs = cfg.idle_timeout.map(|d| d.as_secs()),
94        model_variant = ?cfg.model_variant,
95        memory_safety_factor = cfg.memory_safety_factor,
96        auto_budget = cfg.cost_model_override.is_none(),
97        disable_probe_cache,
98        "Starting bge-m3-embedding-server"
99    );
100
101    // Allocate one shared cost-model handle.  Conservative defaults are used
102    // until the background probe (or cache hit) updates the handle via ArcSwap.
103    // All workers share the same Arc<ArcSwap<CostModel>> so a single store()
104    // call in the probe task is immediately visible to every worker.
105    let initial_cost_model = cfg
106        .cost_model_override
107        .unwrap_or_else(|| CostModel::conservative(CostModel::DEFAULT_MAX_WORKSPACE));
108    let cost_model_handle = Arc::new(ArcSwap::from_pointee(initial_cost_model));
109
110    // Request concurrency limiter.  Start with cfg_workers - 1 permits so the
111    // background probe always has a worker slot free.  The probe (or any terminal
112    // probe bypass) calls add_permits(1) to raise to cfg_workers once the probe
113    // lifecycle ends.  Minimum is 1 so a single-worker deployment always accepts
114    // at least one concurrent request (at the cost of a shared probe slot).
115    let initial_permits = cfg.workers.saturating_sub(1).max(1);
116    let request_permits = Arc::new(Semaphore::new(initial_permits));
117
118    let (pool, init_handle) = EmbedPool::spawn(
119        cfg.workers,
120        PathBuf::from(&cfg.cache_dir),
121        WorkerConfig {
122            cost_model: Arc::clone(&cost_model_handle),
123            idle_timeout: cfg.idle_timeout,
124            model_variant: cfg.model_variant,
125            max_seq_length: cfg.max_seq_length,
126            intra_threads: cfg.intra_threads,
127        },
128    );
129
130    let state = Arc::new(AppState {
131        pool,
132        ready: AtomicBool::new(false),
133        max_batch: cfg.max_batch,
134        total_workers: cfg.workers,
135        max_seq_length: cfg.max_seq_length,
136        tuning: std::sync::OnceLock::new(),
137        cost_model: cost_model_handle,
138        probe_status: AtomicU8::new(ProbeStatus::Running as u8),
139        request_permits,
140    });
141
142    let app = build_router(Arc::clone(&state));
143
144    let listener = tokio::net::TcpListener::bind(&cfg.bind_addr).await?;
145    info!(bind = %cfg.bind_addr, "Listening");
146
147    let state_for_readiness = Arc::clone(&state);
148    let cfg_max_seq = cfg.max_seq_length;
149    let cfg_workers = cfg.workers;
150    let cfg_safety = cfg.memory_safety_factor;
151    let cost_model_override = cfg.cost_model_override;
152    let cache_dir = PathBuf::from(&cfg.cache_dir);
153    let model_variant_str = cfg.model_variant.to_string();
154
155    tokio::spawn(async move {
156        if let Err(e) = run_readiness_probe(
157            init_handle,
158            state_for_readiness,
159            cfg_max_seq,
160            cfg_workers,
161            cfg_safety,
162            cost_model_override,
163            cache_dir,
164            model_variant_str,
165            disable_probe_cache,
166        )
167        .await
168        {
169            tracing::error!("{e}");
170            std::process::exit(1);
171        }
172    });
173
174    // Periodic heartbeat — logs RSS, worker counts, queue depth, and permits
175    // at a fixed interval so dashboards can detect slow leaks or saturation.
176    let heartbeat_secs = cfg.heartbeat_secs;
177    if heartbeat_secs > 0 {
178        let state_hb = Arc::clone(&state);
179        tokio::spawn(async move {
180            let mut tick = tokio::time::interval(Duration::from_secs(heartbeat_secs));
181            // Skip the first (immediate) tick so we don't log at t=0 before
182            // the server has finished starting up.
183            tick.tick().await;
184            loop {
185                tick.tick().await;
186                let rss_mb = sysinfo::read_process_rss_bytes().unwrap_or(0) / (1024 * 1024);
187                info!(
188                    rss_mb,
189                    live_workers = state_hb.pool.live_worker_count(),
190                    loaded_workers = state_hb.pool.loaded_worker_count(),
191                    queue_depth = state_hb.pool.queue_depth(),
192                    available_permits = state_hb.request_permits.available_permits(),
193                    probe_status =
194                        ProbeStatus::from_u8(state_hb.probe_status.load(Ordering::Acquire))
195                            .as_str(),
196                    "heartbeat"
197                );
198            }
199        });
200    }
201
202    axum::serve(listener, app).await?;
203    Ok(())
204}