bge_m3_embedding_server/embedder/math.rs
1// Copyright (c) 2026 J. Patrick Fulton
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Pure dense/sparse math helpers (testable without ORT).
16
17use std::collections::HashMap;
18
19use ndarray::ArrayView1;
20
21/// CLS, PAD, SEP/EOS, UNK — excluded from sparse output.
22pub(super) const SPECIAL_TOKENS: [u32; 4] = [0, 1, 2, 3];
23
24/// L2-normalizes `vec` in place. If the norm is zero, leaves the vector unchanged.
25pub(super) fn normalize_l2(vec: &mut [f32]) {
26 let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
27 if norm > 0.0 {
28 for x in vec.iter_mut() {
29 *x /= norm;
30 }
31 }
32}
33
34/// Projects a single token's hidden state through the sparse-linear layer.
35///
36/// Returns `max(0, dot(hidden, weight) + bias)` (ReLU-gated score).
37pub(super) fn sparse_project(hidden: &[f32], weight: &ArrayView1<f32>, bias: f32) -> f32 {
38 let hidden_view = ArrayView1::from(hidden);
39 (hidden_view.dot(weight) + bias).max(0.0)
40}
41
42/// Max-pools sparse scores by vocabulary token ID, excluding special tokens
43/// and tokens masked by the attention mask.
44///
45/// Returns sorted `(indices, values)` vectors suitable for `SparseEmbedding`.
46pub(super) fn sparse_maxpool(ids: &[u32], mask: &[u32], scores: &[f32]) -> (Vec<usize>, Vec<f32>) {
47 let mut token_weights: HashMap<usize, f32> = HashMap::new();
48
49 for (j, &token_id) in ids.iter().enumerate() {
50 if mask[j] == 0 {
51 continue;
52 }
53 if SPECIAL_TOKENS.contains(&token_id) {
54 continue;
55 }
56 let score = scores[j];
57 if score > 0.0 {
58 token_weights
59 .entry(token_id as usize)
60 .and_modify(|w| *w = w.max(score))
61 .or_insert(score);
62 }
63 }
64
65 let mut indices: Vec<usize> = token_weights.keys().copied().collect();
66 indices.sort_unstable();
67 let values: Vec<f32> = indices.iter().map(|k| token_weights[k]).collect();
68 (indices, values)
69}
70
71/// Computes the median of a `Vec<usize>` in-place (sorts the slice).
72///
73/// Returns `0` for empty input. For even-length inputs returns the lower
74/// of the two middle elements (no floating-point required).
75pub(super) fn median_usize(values: &mut [usize]) -> usize {
76 if values.is_empty() {
77 return 0;
78 }
79 values.sort_unstable();
80 values[values.len() / 2]
81}