bge_m3_embedding_server/embedder/worker.rs
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;

use anyhow::Result;
use arc_swap::ArcSwap;
use ort::value::TensorRef;
use tokio::runtime::Handle;
use tokio::sync::{mpsc, Mutex};
use tracing::{info, info_span};

use super::dense::embed_dense;
use super::dual::embed_both;
use super::error::ort_err;
use super::session::load_models;
use super::sparse::embed_sparse;
use super::tokenize::{build_chunk_arrays, tokenize_no_pad};
use super::types::{EmbedRequest, ProbeResult};
use crate::binpack::CostModel;
use crate::config::ModelVariant;
use crate::sysinfo;

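/// Drop guard that keeps the pool's live-worker count honest: it decrements
/// the shared counter when a worker thread exits for any reason (including
/// unwinding after a panic) and logs loudly once the last worker is gone.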
pub(super) struct WorkerGuard(pub Arc<AtomicUsize>);

impl Drop for WorkerGuard {
    fn drop(&mut self) {
        // fetch_sub returns the value *before* the decrement.
        let prev = self.0.fetch_sub(1, Ordering::AcqRel);
        let live_after_drop = prev.saturating_sub(1);
        if live_after_drop == 0 {
            tracing::error!("All embedding workers have exited — pool is degraded");
        } else {
            tracing::warn!(live_after_drop, "Embedding worker exited");
        }
    }
}

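/// Per-worker configuration, cloned into each worker thread by the pool.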
#[derive(Clone)]
pub struct WorkerConfig {
    /// Shared chunk cost model, hot-swappable at runtime via `ArcSwap`.
    pub cost_model: Arc<ArcSwap<CostModel>>,
    /// How long a worker may sit idle before unloading its models;
    /// `None` disables unloading.
    pub idle_timeout: Option<Duration>,
    /// Which model variant to load.
    pub model_variant: ModelVariant,
    /// Maximum token sequence length per chunk.
    pub max_seq_length: usize,
    /// ONNX Runtime intra-op thread count for the session.
    pub intra_threads: usize,
}

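/// Runs a single dense forward pass and samples process RSS immediately
/// before and after, so callers can observe the memory delta of one
/// inference at a given input shape.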
pub(crate) fn probe_run_dense(
    session: &mut ort::session::Session,
    ids_array: &ndarray::Array2<i64>,
    mask_array: &ndarray::Array2<i64>,
) -> Result<ProbeResult> {
    let rss_before = sysinfo::read_process_rss_bytes().unwrap_or(0);

    let ids_tensor = TensorRef::from_array_view(ids_array.view()).map_err(ort_err)?;
    let mask_tensor = TensorRef::from_array_view(mask_array.view()).map_err(ort_err)?;

    // Outputs are discarded; only the RSS delta matters here.
    let _outputs = session
        .run(ort::inputs! {
            "input_ids" => ids_tensor,
            "attention_mask" => mask_tensor,
        })
        .map_err(ort_err)?;

    let rss_after = sysinfo::read_process_rss_bytes().unwrap_or(rss_before);

    Ok(ProbeResult {
        rss_before,
        rss_after,
    })
}

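/// Tokenizes a probe batch without padding, pads to the longest sequence in
/// the batch, and runs one dense pass through `probe_run_dense`.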
fn run_probe_batch(
    session: &mut ort::session::Session,
    tokenizer: &tokenizers::Tokenizer,
    texts: &[String],
) -> Result<ProbeResult> {
    let encodings = tokenize_no_pad(tokenizer, texts)?;
    // Pad to the longest encoding; clamp to at least 1 so an empty batch or
    // zero-length encodings still yield a valid tensor width.
    let pad_to = encodings
        .iter()
        .map(|e| e.get_ids().len())
        .max()
        .unwrap_or(1)
        .max(1);
    let indices: Vec<usize> = (0..texts.len()).collect();
    let (ids_array, mask_array) = build_chunk_arrays(&encodings, &indices, pad_to)?;
    probe_run_dense(session, &ids_array, &mask_array)
}

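/// Body of one embedding worker thread. Loads the session and tokenizer,
/// reports readiness (with the load's RSS delta) on `ready_tx`, then serves
/// requests from the shared receiver until the channel closes. With an idle
/// timeout configured, models are unloaded after inactivity and lazily
/// reloaded on the next request.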
#[allow(clippy::needless_pass_by_value, clippy::too_many_lines)]
pub(super) fn run_worker(
    id: usize,
    cache_dir: PathBuf,
    rx: Arc<Mutex<mpsc::Receiver<EmbedRequest>>>,
    ready_tx: mpsc::Sender<Result<usize>>,
    live_workers: Arc<AtomicUsize>,
    loaded_workers: Arc<AtomicUsize>,
    config: WorkerConfig,
) -> Result<()> {
    let _guard = WorkerGuard(Arc::clone(&live_workers));
    let span = info_span!("worker", id = id);
    let _span_guard = span.enter();

    info!("Loading models (worker {id})...");
    let load_start = std::time::Instant::now();
    let rt = Handle::current();

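    // Sample RSS before loading so the post-load delta can be reported to
    // the pool through `ready_tx`.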
    let pre_load_rss = sysinfo::read_process_rss_bytes().unwrap_or(0);
    let initial_models = match load_models(
        &cache_dir,
        id == 0,
        config.model_variant,
        config.max_seq_length,
        config.intra_threads,
    ) {
        Ok(mut models) => {
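            // Prime the ONNX Runtime memory arena with a tiny (1 x 8) dense
            // run so the first real probe on this worker does not absorb the
            // arena's one-time allocation.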
            let prime_ids = ndarray::Array2::<i64>::zeros((1, 8));
            let prime_mask = ndarray::Array2::<i64>::ones((1, 8));
            match probe_run_dense(&mut models.0, &prime_ids, &prime_mask) {
                Ok(_) => {
                    tracing::debug!("Worker {id} arena primed");
                }
                Err(e) => {
                    tracing::warn!(
                        error = %e,
                        "Worker {id} arena prime failed; first probe shape on this \
                         worker will include arena init delta"
                    );
                }
            }

            let post_load_rss = sysinfo::read_process_rss_bytes().unwrap_or(pre_load_rss);
            tracing::info!(
                elapsed_ms = load_start.elapsed().as_millis(),
                rss_delta_mb = post_load_rss.saturating_sub(pre_load_rss) / (1024 * 1024),
                "Models loaded (worker {id})"
            );
            models
        }
        Err(e) => {
            let _ =
                rt.block_on(ready_tx.send(Err(anyhow::anyhow!("Worker {id} failed to load: {e}"))));
            return Err(e);
        }
    };

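    // Re-sample after priming so the delta sent with the ready signal covers
    // the full load (weights + arena).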
    let post_load_rss = sysinfo::read_process_rss_bytes().unwrap_or(pre_load_rss);
    let rss_delta = post_load_rss.saturating_sub(pre_load_rss);
    info!(
        "Worker {id} models loaded — signaling ready (rss_delta_mb={})",
        rss_delta / (1024 * 1024)
    );
    let _ = rt.block_on(ready_tx.send(Ok(rss_delta)));

    let mut models: Option<(ort::session::Session, tokenizers::Tokenizer)> = Some(initial_models);

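    // Request loop: block on the shared receiver, bounded by the idle
    // timeout only while models are resident.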
    info!("Worker {id} entering request loop");
    loop {
        let msg = if let Some(timeout) = config.idle_timeout.filter(|_| models.is_some()) {
            rt.block_on(async {
                tokio::time::timeout(timeout, async { rx.lock().await.recv().await }).await
            })
        } else {
            rt.block_on(async { Ok(rx.lock().await.recv().await) })
        };

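        // Three outcomes: the idle timeout fired (unload models), the channel
        // closed (shut down), or a request arrived (serve it, reloading
        // models first if needed).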
        match msg {
            Err(_elapsed) => {
                models = None;
                loaded_workers.fetch_sub(1, Ordering::AcqRel);
                tracing::info!("Worker {id} unloaded models after idle timeout");
            }
            Ok(None) => {
                if models.is_some() {
                    loaded_workers.fetch_sub(1, Ordering::AcqRel);
                }
                info!("Worker {id} channel closed, shutting down");
                break;
            }
            Ok(Some(request)) => {
                if models.is_none() {
                    tracing::info!("Worker {id} reloading models after idle...");
                    let reload_start = std::time::Instant::now();
                    match load_models(
                        &cache_dir,
                        false,
                        config.model_variant,
                        config.max_seq_length,
                        config.intra_threads,
                    ) {
                        Ok(mut m) => {
                            // Re-prime the arena after a reload, mirroring the
                            // initial load above.
                            let prime_ids = ndarray::Array2::<i64>::zeros((1, 8));
                            let prime_mask = ndarray::Array2::<i64>::ones((1, 8));
                            if let Err(e) = probe_run_dense(&mut m.0, &prime_ids, &prime_mask) {
                                tracing::warn!(
                                    error = %e,
                                    "Worker {id} post-reload arena prime failed"
                                );
                            }
                            models = Some(m);
                            loaded_workers.fetch_add(1, Ordering::AcqRel);
                            tracing::info!(
                                elapsed_ms = reload_start.elapsed().as_millis(),
                                "Worker {id} reloaded models"
                            );
                        }
                        Err(e) => {
                            tracing::error!(error = %e, "Worker {id} failed to reload models");
                            let err = anyhow::anyhow!("Model reload failed: {e}");
                            match request {
                                EmbedRequest::Dense { reply, .. } => {
                                    let _ = reply.send(Err(err));
                                }
                                EmbedRequest::Sparse { reply, .. } => {
                                    let _ = reply.send(Err(err));
                                }
                                EmbedRequest::Both { reply, .. } => {
                                    let _ = reply.send(Err(err));
                                }
                                EmbedRequest::Probe { reply, .. } => {
                                    let _ = reply.send(Err(err));
                                }
                            }
                            continue;
                        }
                    }
                }

                let (session, tokenizer) =
                    models.as_mut().expect("models loaded after reload check");

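                // Dispatch to the embedding kernels; each arm sends its
                // result back on the request's reply channel, ignoring send
                // failures from callers that have gone away.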
                match request {
                    EmbedRequest::Dense { texts, reply } => {
                        let cm_guard = config.cost_model.load();
                        let result = embed_dense(
                            session,
                            tokenizer,
                            &texts,
                            &cm_guard,
                            config.model_variant,
                        )
                        .map_err(|e| anyhow::anyhow!("Dense embed error: {e}"));
                        if let Ok((_, ref stats)) = result {
                            tracing::info!(
                                worker_id = id,
                                chunks = stats.chunks,
                                max_chunk_seq = stats.max_chunk_seq,
                                total_token_positions = stats.total_token_positions,
                                tokenize_ms = stats.tokenize_ms,
                                inference_ms = stats.inference_ms,
                                "worker: dense embed complete"
                            );
                        }
                        let _ = reply.send(result);
                    }
                    EmbedRequest::Sparse { texts, reply } => {
                        let cm_guard = config.cost_model.load();
                        let result = embed_sparse(
                            session,
                            tokenizer,
                            &texts,
                            &cm_guard,
                            config.model_variant,
                        )
                        .map_err(|e| anyhow::anyhow!("Sparse embed error: {e}"));
                        if let Ok((_, ref stats)) = result {
                            tracing::info!(
                                worker_id = id,
                                chunks = stats.chunks,
                                max_chunk_seq = stats.max_chunk_seq,
                                total_token_positions = stats.total_token_positions,
                                tokenize_ms = stats.tokenize_ms,
                                inference_ms = stats.inference_ms,
                                "worker: sparse embed complete"
                            );
                        }
                        let _ = reply.send(result);
                    }
                    EmbedRequest::Both { texts, reply } => {
                        let cm_guard = config.cost_model.load();
                        let result =
                            embed_both(session, tokenizer, &texts, &cm_guard, config.model_variant)
                                .map_err(|e| anyhow::anyhow!("Dual embed error: {e}"));
                        if let Ok((_, ref stats)) = result {
                            tracing::info!(
                                worker_id = id,
                                chunks = stats.chunks,
                                max_chunk_seq = stats.max_chunk_seq,
                                total_token_positions = stats.total_token_positions,
                                tokenize_ms = stats.tokenize_ms,
                                inference_ms = stats.inference_ms,
                                "worker: both embed complete"
                            );
                        }
                        let _ = reply.send(result);
                    }
                    EmbedRequest::Probe { texts, reply } => {
                        let result = run_probe_batch(session, tokenizer, &texts);
                        let _ = reply.send(result);
                    }
                }
            }
        }
    }

    Ok(())
}