Skip to main content

bge_m3_embedding_server/
weights.rs

1// Copyright (c) 2026 J. Patrick Fulton
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Bundled BGE-M3 sparse-linear projection weights.
//!
//! The weights file (`sparse_linear.safetensors`) is embedded at compile time
//! via `include_bytes!` and parsed lazily, exactly once, on the first call to
//! [`sparse_linear`]. The parsed result is cached, so every caller (including
//! concurrent workers) receives the same `&'static` reference.
20
21use ndarray::Array1;
22use std::sync::OnceLock;
23
24/// Bundled sparse-linear projection weights for BGE-M3 sparse embedding.
25///
26/// # Provenance
27///
28/// Extracted from the fastembed-rs crate (v4) which bundles the same file
29/// from the BAAI/bge-m3 checkpoint. The weights implement the sparse-linear
30/// layer described in the BGE-M3 paper: a single linear projection
31/// `hidden_size → 1` that maps each token's 1024-d hidden state to a scalar
32/// relevance score, followed by `ReLU` activation and max-pooling by vocab ID.
33///
34/// - **Source checkpoint**: `BAAI/bge-m3` (HF commit `5617a9f61b02800`)
35/// - **Tensors**: `weight` shape `[1024]` (F32), `bias` scalar (F32)
36/// - **File SHA-256**: `a2601321f01abbb696d171a58a65ff35be1603d9cbc22c647dfe34b4568dd690`
37/// - **File size**: 4,236 bytes
38static WEIGHTS_BYTES: &[u8] = include_bytes!("sparse_linear.safetensors");
39
40static SPARSE_LINEAR: OnceLock<(Array1<f32>, f32)> = OnceLock::new();
41
42/// Returns the sparse-linear projection weights used by BGE-M3 sparse embedding.
43///
44/// The safetensors file contains a weight vector `[1024]` and a scalar bias.
45/// Parsed once on first call and cached for the lifetime of the process.
46pub(crate) fn sparse_linear() -> &'static (Array1<f32>, f32) {
47    SPARSE_LINEAR.get_or_init(|| {
48        let tensors = safetensors::SafeTensors::deserialize(WEIGHTS_BYTES)
49            .expect("embedded sparse_linear.safetensors must be valid");
50
51        let weight_view = tensors
52            .tensor("weight")
53            .expect("sparse_linear must contain 'weight' tensor");
54        let bias_view = tensors
55            .tensor("bias")
56            .expect("sparse_linear must contain 'bias' tensor");
57
58        let weight_data = weight_view.data();
59        assert_eq!(
60            weight_data.len() % 4,
61            0,
62            "weight tensor byte length must be a multiple of 4, got {}",
63            weight_data.len()
64        );
65        let weight: Vec<f32> = weight_data
66            .chunks_exact(4)
67            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
68            .collect();
69        let bias_data = bias_view.data();
70        assert_eq!(
71            bias_data.len(),
72            4,
73            "sparse_linear bias must be a scalar F32 (4 bytes), got {} bytes",
74            bias_data.len()
75        );
76        let bias = f32::from_le_bytes([bias_data[0], bias_data[1], bias_data[2], bias_data[3]]);
77
78        assert_eq!(weight.len(), 1024, "sparse_linear weight must be [1024]");
79        (Array1::from(weight), bias)
80    })
81}
82
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sparse_linear_loads_correct_shape() {
        let (weight, bias) = sparse_linear();
        assert_eq!(weight.len(), 1024);
        // Check finiteness first: a NaN/inf bias would otherwise only surface
        // as a confusing tolerance failure in the comparison below.
        assert!(bias.is_finite(), "bias must be finite, got {bias}");
        // Known bias value from BAAI/bge-m3 sparse_linear.safetensors
        assert!(
            (*bias - 0.045_196_53).abs() < 1e-6,
            "bias should be ~0.04520, got {bias}"
        );
        // TST-6: verify all weights are finite and not all-zero
        assert!(
            weight.iter().all(|w| w.is_finite()),
            "all weight elements must be finite"
        );
        assert!(
            weight.iter().any(|&w| w != 0.0),
            "weight vector must not be all-zero"
        );
    }

    #[test]
    fn sparse_linear_is_idempotent() {
        let a = sparse_linear();
        let b = sparse_linear();
        assert!(std::ptr::eq(a, b), "should return the same cached ref");
    }

    #[test]
    fn bundled_file_is_valid_safetensors() {
        // Verify the embedded bytes parse, contain the expected tensors, and
        // declare them as F32 (the dtype documented in the provenance notes).
        let tensors = safetensors::SafeTensors::deserialize(WEIGHTS_BYTES)
            .expect("WEIGHTS_BYTES must be valid safetensors");
        for name in ["weight", "bias"] {
            let view = tensors
                .tensor(name)
                .unwrap_or_else(|err| panic!("must contain '{name}': {err:?}"));
            assert_eq!(
                view.dtype(),
                safetensors::Dtype::F32,
                "'{name}' must be declared F32"
            );
        }
    }

    #[test]
    fn bundled_file_size_matches() {
        // Size pinned to detect accidental replacement or corruption.
        assert_eq!(WEIGHTS_BYTES.len(), 4236, "expected 4,236 bytes");
    }

    #[test]
    fn bundled_file_sha256_matches() {
        use sha2::Digest;
        use std::fmt::Write;
        // Documented provenance hash — any change to the bundled file must update this.
        const EXPECTED_SHA256: &str =
            "a2601321f01abbb696d171a58a65ff35be1603d9cbc22c647dfe34b4568dd690";
        let digest = {
            let mut hasher = sha2::Sha256::new();
            hasher.update(WEIGHTS_BYTES);
            hasher.finalize()
        };
        let hex = digest.iter().fold(String::new(), |mut s, b| {
            write!(s, "{b:02x}").expect("hex write");
            s
        });
        assert_eq!(
            hex, EXPECTED_SHA256,
            "bundled sparse_linear.safetensors SHA-256 mismatch"
        );
    }
}