bge_m3_embedding_server/
lib.rs1#![warn(missing_docs)]
23#![warn(rustdoc::missing_crate_level_docs)]
24#![warn(rustdoc::unescaped_backticks)]
25#![deny(rustdoc::broken_intra_doc_links)]
26#![deny(rustdoc::invalid_html_tags)]
27#![deny(rustdoc::bare_urls)]
28#![warn(rustdoc::redundant_explicit_links)]
29#![warn(rustdoc::private_doc_tests)]
30
31pub mod binpack;
32pub mod bootstrap;
33pub mod config;
34pub mod embedder;
35pub mod error;
36pub mod handler;
37pub mod models;
38pub mod probe;
39pub mod state;
40pub mod sysinfo;
41pub mod weights;
42
43use std::path::PathBuf;
44use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
45use std::sync::Arc;
46use std::time::Duration;
47
48use arc_swap::ArcSwap;
49use tokio::sync::Semaphore;
50use tracing::info;
51
52use crate::binpack::CostModel;
53use crate::bootstrap::{build_router, run_readiness_probe};
54use crate::config::Config;
55use crate::embedder::{EmbedPool, WorkerConfig};
56use crate::state::{AppState, ProbeStatus};
57
58#[allow(clippy::too_many_lines)]
68pub async fn run() -> anyhow::Result<()> {
69 info!(
70 version = env!("CARGO_PKG_VERSION"),
71 git_sha = env!("BGE_M3_GIT_SHA"),
72 target_arch = std::env::consts::ARCH,
73 target_os = std::env::consts::OS,
74 profile = if cfg!(debug_assertions) {
75 "debug"
76 } else {
77 "release"
78 },
79 "bge-m3-embedding-server build info"
80 );
81
82 let cfg = Config::from_env();
83
84 let disable_probe_cache = std::env::var("BGE_M3_DISABLE_PROBE_CACHE")
85 .is_ok_and(|v| matches!(v.as_str(), "1" | "true" | "yes"));
86
87 info!(
88 bind = %cfg.bind_addr,
89 workers = cfg.workers,
90 max_batch = cfg.max_batch,
91 max_seq_length = cfg.max_seq_length,
92 cache_dir = %cfg.cache_dir,
93 idle_timeout_secs = cfg.idle_timeout.map(|d| d.as_secs()),
94 model_variant = ?cfg.model_variant,
95 memory_safety_factor = cfg.memory_safety_factor,
96 auto_budget = cfg.cost_model_override.is_none(),
97 disable_probe_cache,
98 "Starting bge-m3-embedding-server"
99 );
100
101 let initial_cost_model = cfg
106 .cost_model_override
107 .unwrap_or_else(|| CostModel::conservative(CostModel::DEFAULT_MAX_WORKSPACE));
108 let cost_model_handle = Arc::new(ArcSwap::from_pointee(initial_cost_model));
109
110 let initial_permits = cfg.workers.saturating_sub(1).max(1);
116 let request_permits = Arc::new(Semaphore::new(initial_permits));
117
118 let (pool, init_handle) = EmbedPool::spawn(
119 cfg.workers,
120 PathBuf::from(&cfg.cache_dir),
121 WorkerConfig {
122 cost_model: Arc::clone(&cost_model_handle),
123 idle_timeout: cfg.idle_timeout,
124 model_variant: cfg.model_variant,
125 max_seq_length: cfg.max_seq_length,
126 intra_threads: cfg.intra_threads,
127 },
128 );
129
130 let state = Arc::new(AppState {
131 pool,
132 ready: AtomicBool::new(false),
133 max_batch: cfg.max_batch,
134 total_workers: cfg.workers,
135 max_seq_length: cfg.max_seq_length,
136 tuning: std::sync::OnceLock::new(),
137 cost_model: cost_model_handle,
138 probe_status: AtomicU8::new(ProbeStatus::Running as u8),
139 request_permits,
140 });
141
142 let app = build_router(Arc::clone(&state));
143
144 let listener = tokio::net::TcpListener::bind(&cfg.bind_addr).await?;
145 info!(bind = %cfg.bind_addr, "Listening");
146
147 let state_for_readiness = Arc::clone(&state);
148 let cfg_max_seq = cfg.max_seq_length;
149 let cfg_workers = cfg.workers;
150 let cfg_safety = cfg.memory_safety_factor;
151 let cost_model_override = cfg.cost_model_override;
152 let cache_dir = PathBuf::from(&cfg.cache_dir);
153 let model_variant_str = cfg.model_variant.to_string();
154
155 tokio::spawn(async move {
156 if let Err(e) = run_readiness_probe(
157 init_handle,
158 state_for_readiness,
159 cfg_max_seq,
160 cfg_workers,
161 cfg_safety,
162 cost_model_override,
163 cache_dir,
164 model_variant_str,
165 disable_probe_cache,
166 )
167 .await
168 {
169 tracing::error!("{e}");
170 std::process::exit(1);
171 }
172 });
173
174 let heartbeat_secs = cfg.heartbeat_secs;
177 if heartbeat_secs > 0 {
178 let state_hb = Arc::clone(&state);
179 tokio::spawn(async move {
180 let mut tick = tokio::time::interval(Duration::from_secs(heartbeat_secs));
181 tick.tick().await;
184 loop {
185 tick.tick().await;
186 let rss_mb = sysinfo::read_process_rss_bytes().unwrap_or(0) / (1024 * 1024);
187 info!(
188 rss_mb,
189 live_workers = state_hb.pool.live_worker_count(),
190 loaded_workers = state_hb.pool.loaded_worker_count(),
191 queue_depth = state_hb.pool.queue_depth(),
192 available_permits = state_hb.request_permits.available_permits(),
193 probe_status =
194 ProbeStatus::from_u8(state_hb.probe_status.load(Ordering::Acquire))
195 .as_str(),
196 "heartbeat"
197 );
198 }
199 });
200 }
201
202 axum::serve(listener, app).await?;
203 Ok(())
204}