// bge_m3_embedding_server/embedder/math.rs

// Copyright (c) 2026 J. Patrick Fulton
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Pure dense and sparse math helpers, free of any ORT dependency so they can be unit-tested on their own.

17use std::collections::HashMap;
18
19use ndarray::ArrayView1;
20
21/// CLS, PAD, SEP/EOS, UNK — excluded from sparse output.
22pub(super) const SPECIAL_TOKENS: [u32; 4] = [0, 1, 2, 3];
23
24/// L2-normalizes `vec` in place. If the norm is zero, leaves the vector unchanged.
25pub(super) fn normalize_l2(vec: &mut [f32]) {
26    let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
27    if norm > 0.0 {
28        for x in vec.iter_mut() {
29            *x /= norm;
30        }
31    }
32}
33
34/// Projects a single token's hidden state through the sparse-linear layer.
35///
36/// Returns `max(0, dot(hidden, weight) + bias)` (ReLU-gated score).
37pub(super) fn sparse_project(hidden: &[f32], weight: &ArrayView1<f32>, bias: f32) -> f32 {
38    let hidden_view = ArrayView1::from(hidden);
39    (hidden_view.dot(weight) + bias).max(0.0)
40}
41
42/// Max-pools sparse scores by vocabulary token ID, excluding special tokens
43/// and tokens masked by the attention mask.
44///
45/// Returns sorted `(indices, values)` vectors suitable for `SparseEmbedding`.
46pub(super) fn sparse_maxpool(ids: &[u32], mask: &[u32], scores: &[f32]) -> (Vec<usize>, Vec<f32>) {
47    let mut token_weights: HashMap<usize, f32> = HashMap::new();
48
49    for (j, &token_id) in ids.iter().enumerate() {
50        if mask[j] == 0 {
51            continue;
52        }
53        if SPECIAL_TOKENS.contains(&token_id) {
54            continue;
55        }
56        let score = scores[j];
57        if score > 0.0 {
58            token_weights
59                .entry(token_id as usize)
60                .and_modify(|w| *w = w.max(score))
61                .or_insert(score);
62        }
63    }
64
65    let mut indices: Vec<usize> = token_weights.keys().copied().collect();
66    indices.sort_unstable();
67    let values: Vec<f32> = indices.iter().map(|k| token_weights[k]).collect();
68    (indices, values)
69}
70
71/// Computes the median of a `Vec<usize>` in-place (sorts the slice).
72///
73/// Returns `0` for empty input. For even-length inputs returns the lower
74/// of the two middle elements (no floating-point required).
75pub(super) fn median_usize(values: &mut [usize]) -> usize {
76    if values.is_empty() {
77        return 0;
78    }
79    values.sort_unstable();
80    values[values.len() / 2]
81}