Skip to main content

bge_m3_embedding_server/embedder/
dense.rs

1// Copyright (c) 2026 J. Patrick Fulton
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Dense embedding pipeline.
16
17use anyhow::Result;
18use ort::value::TensorRef;
19
20use super::error::ort_err;
21use super::math::normalize_l2;
22use super::tokenize::{build_chunk_arrays, tokenize_no_pad};
23use super::types::EmbedStats;
24use crate::binpack::{bin_pack, CostModel};
25use crate::config::ModelVariant;
26
27/// Produces L2-normalized dense embeddings.
28///
29/// Tokenizes once, then uses the cost model to bin-pack into chunks that fit
30/// within the workspace budget. Results are scattered back to the original
31/// input order.
32#[allow(clippy::cast_possible_truncation)]
33pub(super) fn embed_dense(
34    session: &mut ort::session::Session,
35    tokenizer: &tokenizers::Tokenizer,
36    texts: &[String],
37    cost_model: &CostModel,
38    model_variant: ModelVariant,
39) -> Result<(Vec<Vec<f32>>, EmbedStats)> {
40    let tokenize_start = std::time::Instant::now();
41    let encodings = tokenize_no_pad(tokenizer, texts)?;
42    let seq_lens: Vec<usize> = encodings.iter().map(|e| e.get_ids().len()).collect();
43    let tokenize_ms = u64::try_from(tokenize_start.elapsed().as_millis()).unwrap_or(u64::MAX);
44
45    let total_token_positions: usize = seq_lens.iter().sum();
46    let chunks = bin_pack(&seq_lens, cost_model);
47
48    // Pre-allocate output slots (one per input text, filled below).
49    let mut all_embeddings: Vec<Vec<f32>> = (0..texts.len()).map(|_| Vec::new()).collect();
50
51    let mut max_chunk_seq: usize = 0;
52    let mut inference_ms: u64 = 0;
53
54    for (chunk_idx, chunk_indices) in chunks.iter().enumerate() {
55        let chunk_max = chunk_indices
56            .iter()
57            .map(|&i| seq_lens[i])
58            .max()
59            .unwrap_or(1)
60            .max(1); // guard: at least 1 to avoid 0-dim tensors
61
62        max_chunk_seq = max_chunk_seq.max(chunk_max);
63
64        let (ids_array, mask_array) = build_chunk_arrays(&encodings, chunk_indices, chunk_max)?;
65        let batch_len = ids_array.nrows();
66
67        let ids_tensor = TensorRef::from_array_view(ids_array.view()).map_err(ort_err)?;
68        let mask_tensor = TensorRef::from_array_view(mask_array.view()).map_err(ort_err)?;
69
70        let chunk_start = std::time::Instant::now();
71        let outputs = {
72            let _span = tracing::debug_span!(
73                "chunk",
74                chunk_idx,
75                batch = chunk_indices.len(),
76                max_seq = chunk_max
77            )
78            .entered();
79            session
80                .run(ort::inputs! {
81                    "input_ids" => ids_tensor,
82                    "attention_mask" => mask_tensor,
83                })
84                .map_err(ort_err)?
85        };
86        let chunk_ms = u64::try_from(chunk_start.elapsed().as_millis()).unwrap_or(u64::MAX);
87        inference_ms = inference_ms.saturating_add(chunk_ms);
88        tracing::debug!(
89            chunk_idx,
90            batch = chunk_indices.len(),
91            max_seq = chunk_max,
92            elapsed_ms = chunk_ms,
93            "dense chunk inference complete"
94        );
95
96        // FP32: sentence_embedding [batch, 1024] — pre-pooled CLS output.
97        // FP16/INT8: last_hidden_state [batch, seq, 1024] — CLS token at position 0.
98        let emb: ndarray::ArrayD<f32> = match model_variant {
99            ModelVariant::Fp32 => outputs["sentence_embedding"]
100                .try_extract_array::<f32>()
101                .map_err(ort_err)?
102                .to_owned(),
103            ModelVariant::Fp16 | ModelVariant::Int8 => {
104                let lhs = outputs["last_hidden_state"]
105                    .try_extract_array::<f32>()
106                    .map_err(ort_err)?;
107                lhs.index_axis(ndarray::Axis(1), 0).to_owned()
108            }
109        };
110
111        for (chunk_pos, &orig_idx) in chunk_indices.iter().enumerate() {
112            debug_assert!(chunk_pos < batch_len, "chunk_pos must be within batch");
113            let row = emb.index_axis(ndarray::Axis(0), chunk_pos);
114            let mut vec = row
115                .as_slice()
116                .expect("embedding should be contiguous")
117                .to_vec();
118            normalize_l2(&mut vec);
119            all_embeddings[orig_idx] = vec;
120        }
121    }
122
123    let stats = EmbedStats {
124        chunks: chunks.len(),
125        max_chunk_seq,
126        total_token_positions,
127        tokenize_ms,
128        inference_ms,
129    };
130
131    Ok((all_embeddings, stats))
132}