Skip to main content

bge_m3_embedding_server/embedder/
sparse.rs

1// Copyright (c) 2026 J. Patrick Fulton
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! BGE-M3 SPLADE-style sparse embedding pipeline.
16
17use anyhow::Result;
18use ort::value::TensorRef;
19
20use super::error::ort_err;
21use super::math::{sparse_maxpool, sparse_project};
22use super::tokenize::{build_chunk_arrays, tokenize_no_pad};
23use super::types::{EmbedStats, SparseEmbedding};
24use crate::binpack::{bin_pack, CostModel};
25use crate::config::ModelVariant;
26
27/// Produces sparse embeddings via the BGE-M3 sparse-linear projection.
28///
29/// Tokenizes once, then uses the cost model to bin-pack into chunks. Results
30/// are scattered back to the original input order.
31#[allow(clippy::cast_possible_truncation)]
32pub(super) fn embed_sparse(
33    session: &mut ort::session::Session,
34    tokenizer: &tokenizers::Tokenizer,
35    texts: &[String],
36    cost_model: &CostModel,
37    model_variant: ModelVariant,
38) -> Result<(Vec<SparseEmbedding>, EmbedStats)> {
39    let (weight, bias) = crate::weights::sparse_linear();
40    let weight_view = weight.view();
41
42    let tokenize_start = std::time::Instant::now();
43    let encodings = tokenize_no_pad(tokenizer, texts)?;
44    let seq_lens: Vec<usize> = encodings.iter().map(|e| e.get_ids().len()).collect();
45    let tokenize_ms = u64::try_from(tokenize_start.elapsed().as_millis()).unwrap_or(u64::MAX);
46
47    let total_token_positions: usize = seq_lens.iter().sum();
48    let chunks = bin_pack(&seq_lens, cost_model);
49
50    let mut all_sparse: Vec<Option<SparseEmbedding>> = (0..texts.len()).map(|_| None).collect();
51
52    let mut max_chunk_seq: usize = 0;
53    let mut inference_ms: u64 = 0;
54
55    for (chunk_idx, chunk_indices) in chunks.iter().enumerate() {
56        let chunk_max = chunk_indices
57            .iter()
58            .map(|&i| seq_lens[i])
59            .max()
60            .unwrap_or(1)
61            .max(1);
62
63        max_chunk_seq = max_chunk_seq.max(chunk_max);
64
65        let (ids_array, mask_array) = build_chunk_arrays(&encodings, chunk_indices, chunk_max)?;
66
67        let ids_tensor = TensorRef::from_array_view(ids_array.view()).map_err(ort_err)?;
68        let mask_tensor = TensorRef::from_array_view(mask_array.view()).map_err(ort_err)?;
69
70        let chunk_start = std::time::Instant::now();
71        let outputs = {
72            let _span = tracing::debug_span!(
73                "chunk",
74                chunk_idx,
75                batch = chunk_indices.len(),
76                max_seq = chunk_max
77            )
78            .entered();
79            session
80                .run(ort::inputs! {
81                    "input_ids" => ids_tensor,
82                    "attention_mask" => mask_tensor,
83                })
84                .map_err(ort_err)?
85        };
86        let chunk_ms = u64::try_from(chunk_start.elapsed().as_millis()).unwrap_or(u64::MAX);
87        inference_ms = inference_ms.saturating_add(chunk_ms);
88        tracing::debug!(
89            chunk_idx,
90            batch = chunk_indices.len(),
91            max_seq = chunk_max,
92            elapsed_ms = chunk_ms,
93            "sparse chunk inference complete"
94        );
95
96        // FP32: token_embeddings [batch, seq, 1024].
97        // FP16/INT8: last_hidden_state [batch, seq, 1024] — same shape, different key.
98        let token_emb = match model_variant {
99            ModelVariant::Fp32 => outputs["token_embeddings"]
100                .try_extract_array::<f32>()
101                .map_err(ort_err)?,
102            ModelVariant::Fp16 | ModelVariant::Int8 => outputs["last_hidden_state"]
103                .try_extract_array::<f32>()
104                .map_err(ort_err)?,
105        };
106
107        for (chunk_pos, &orig_idx) in chunk_indices.iter().enumerate() {
108            let enc = &encodings[orig_idx];
109            let ids = enc.get_ids();
110            let mask = enc.get_attention_mask();
111            let batch_hidden = token_emb.index_axis(ndarray::Axis(0), chunk_pos);
112
113            let scores: Vec<f32> = (0..ids.len())
114                .map(|j| {
115                    let hidden = batch_hidden.index_axis(ndarray::Axis(0), j);
116                    let hidden_slice = hidden
117                        .as_slice()
118                        .expect("hidden state should be contiguous");
119                    sparse_project(hidden_slice, &weight_view, *bias)
120                })
121                .collect();
122
123            let (indices, values) = sparse_maxpool(ids, mask, &scores);
124            all_sparse[orig_idx] = Some(SparseEmbedding { indices, values });
125        }
126    }
127
128    let stats = EmbedStats {
129        chunks: chunks.len(),
130        max_chunk_seq,
131        total_token_positions,
132        tokenize_ms,
133        inference_ms,
134    };
135
136    Ok((
137        all_sparse
138            .into_iter()
139            .map(|s| s.expect("every slot must be filled"))
140            .collect(),
141        stats,
142    ))
143}