Skip to main content

bge_m3_embedding_server/embedder/
dual.rs

// Copyright (c) 2026 J. Patrick Fulton
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
//! Paired dense + sparse embedding pipeline (one forward pass per chunk).
16
17use anyhow::Result;
18use ort::value::TensorRef;
19
20use super::error::ort_err;
21use super::math::{normalize_l2, sparse_maxpool, sparse_project};
22use super::tokenize::{build_chunk_arrays, tokenize_no_pad};
23use super::types::{DualEmbedding, EmbedStats, SparseEmbedding};
24use crate::binpack::{bin_pack, CostModel};
25use crate::config::ModelVariant;
26
27/// Produces paired dense + sparse embeddings using **one** `session.run()` per chunk.
28///
29/// Both projections are derived from the same forward pass:
30/// - **FP32**: extracts both `sentence_embedding` (dense) and `token_embeddings`
31///   (sparse base) from the model's dual outputs.
32/// - **FP16/INT8**: extracts dense from the CLS token (position 0) of
33///   `last_hidden_state`, and sparse from the full hidden states of the same
34///   tensor. This avoids a second forward pass.
35///
36/// Numerically equivalent to calling [`super::dense::embed_dense`] and
37/// [`super::sparse::embed_sparse`] separately, within FP rounding tolerance.
38#[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
39pub(super) fn embed_both(
40    session: &mut ort::session::Session,
41    tokenizer: &tokenizers::Tokenizer,
42    texts: &[String],
43    cost_model: &CostModel,
44    model_variant: ModelVariant,
45) -> Result<(Vec<DualEmbedding>, EmbedStats)> {
46    let (weight, bias) = crate::weights::sparse_linear();
47    let weight_view = weight.view();
48
49    let tokenize_start = std::time::Instant::now();
50    let encodings = tokenize_no_pad(tokenizer, texts)?;
51    let seq_lens: Vec<usize> = encodings.iter().map(|e| e.get_ids().len()).collect();
52    let tokenize_ms = u64::try_from(tokenize_start.elapsed().as_millis()).unwrap_or(u64::MAX);
53
54    let total_token_positions: usize = seq_lens.iter().sum();
55    let chunks = bin_pack(&seq_lens, cost_model);
56
57    let mut all_dual: Vec<Option<DualEmbedding>> = (0..texts.len()).map(|_| None).collect();
58
59    let mut max_chunk_seq: usize = 0;
60    let mut inference_ms: u64 = 0;
61
62    for (chunk_idx, chunk_indices) in chunks.iter().enumerate() {
63        let chunk_max = chunk_indices
64            .iter()
65            .map(|&i| seq_lens[i])
66            .max()
67            .unwrap_or(1)
68            .max(1);
69
70        max_chunk_seq = max_chunk_seq.max(chunk_max);
71
72        let (ids_array, mask_array) = build_chunk_arrays(&encodings, chunk_indices, chunk_max)?;
73
74        let ids_tensor = TensorRef::from_array_view(ids_array.view()).map_err(ort_err)?;
75        let mask_tensor = TensorRef::from_array_view(mask_array.view()).map_err(ort_err)?;
76
77        let chunk_start = std::time::Instant::now();
78        let outputs = {
79            let _span = tracing::debug_span!(
80                "chunk",
81                chunk_idx,
82                batch = chunk_indices.len(),
83                max_seq = chunk_max
84            )
85            .entered();
86            session
87                .run(ort::inputs! {
88                    "input_ids" => ids_tensor,
89                    "attention_mask" => mask_tensor,
90                })
91                .map_err(ort_err)?
92        };
93        let chunk_ms = u64::try_from(chunk_start.elapsed().as_millis()).unwrap_or(u64::MAX);
94        inference_ms = inference_ms.saturating_add(chunk_ms);
95        tracing::debug!(
96            chunk_idx,
97            batch = chunk_indices.len(),
98            max_seq = chunk_max,
99            elapsed_ms = chunk_ms,
100            "both chunk inference complete"
101        );
102
103        // Extract dense + token-level hidden states from the same outputs.
104        // FP32: separate sentence_embedding + token_embeddings outputs.
105        // FP16/INT8: derive dense (CLS) and sparse-base from last_hidden_state.
106        let (dense_emb, token_emb) = match model_variant {
107            ModelVariant::Fp32 => {
108                let dense = outputs["sentence_embedding"]
109                    .try_extract_array::<f32>()
110                    .map_err(ort_err)?
111                    .to_owned();
112                let tokens = outputs["token_embeddings"]
113                    .try_extract_array::<f32>()
114                    .map_err(ort_err)?
115                    .to_owned();
116                (dense, tokens)
117            }
118            ModelVariant::Fp16 | ModelVariant::Int8 => {
119                let lhs = outputs["last_hidden_state"]
120                    .try_extract_array::<f32>()
121                    .map_err(ort_err)?;
122                let dense = lhs.index_axis(ndarray::Axis(1), 0).to_owned();
123                let tokens = lhs.to_owned();
124                (dense, tokens)
125            }
126        };
127
128        for (chunk_pos, &orig_idx) in chunk_indices.iter().enumerate() {
129            // Dense: CLS row, L2-normalized.
130            let dense_row = dense_emb.index_axis(ndarray::Axis(0), chunk_pos);
131            let mut dense_vec = dense_row
132                .as_slice()
133                .expect("dense embedding should be contiguous")
134                .to_vec();
135            normalize_l2(&mut dense_vec);
136
137            // Sparse: project each token's hidden state, then max-pool.
138            let enc = &encodings[orig_idx];
139            let ids = enc.get_ids();
140            let mask = enc.get_attention_mask();
141            let batch_hidden = token_emb.index_axis(ndarray::Axis(0), chunk_pos);
142
143            let scores: Vec<f32> = (0..ids.len())
144                .map(|j| {
145                    let hidden = batch_hidden.index_axis(ndarray::Axis(0), j);
146                    let hidden_slice = hidden
147                        .as_slice()
148                        .expect("hidden state should be contiguous");
149                    sparse_project(hidden_slice, &weight_view, *bias)
150                })
151                .collect();
152
153            let (indices, values) = sparse_maxpool(ids, mask, &scores);
154
155            all_dual[orig_idx] = Some(DualEmbedding {
156                dense: dense_vec,
157                sparse: SparseEmbedding { indices, values },
158            });
159        }
160    }
161
162    let stats = EmbedStats {
163        chunks: chunks.len(),
164        max_chunk_seq,
165        total_token_positions,
166        tokenize_ms,
167        inference_ms,
168    };
169
170    Ok((
171        all_dual
172            .into_iter()
173            .map(|d| d.expect("every slot must be filled"))
174            .collect(),
175        stats,
176    ))
177}