bge_m3_embedding_server/
weights.rs1use ndarray::Array1;
22use std::sync::OnceLock;
23
/// Raw bytes of `sparse_linear.safetensors`, embedded in the binary at compile
/// time by `include_bytes!` (the path is resolved relative to this source file).
static WEIGHTS_BYTES: &[u8] = include_bytes!("sparse_linear.safetensors");

/// One-time cache for the decoded `(weight, bias)` pair; populated by
/// `sparse_linear` on its first call and never replaced afterwards.
static SPARSE_LINEAR: OnceLock<(Array1<f32>, f32)> = OnceLock::new();
41
42pub(crate) fn sparse_linear() -> &'static (Array1<f32>, f32) {
47 SPARSE_LINEAR.get_or_init(|| {
48 let tensors = safetensors::SafeTensors::deserialize(WEIGHTS_BYTES)
49 .expect("embedded sparse_linear.safetensors must be valid");
50
51 let weight_view = tensors
52 .tensor("weight")
53 .expect("sparse_linear must contain 'weight' tensor");
54 let bias_view = tensors
55 .tensor("bias")
56 .expect("sparse_linear must contain 'bias' tensor");
57
58 let weight_data = weight_view.data();
59 assert_eq!(
60 weight_data.len() % 4,
61 0,
62 "weight tensor byte length must be a multiple of 4, got {}",
63 weight_data.len()
64 );
65 let weight: Vec<f32> = weight_data
66 .chunks_exact(4)
67 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
68 .collect();
69 let bias_data = bias_view.data();
70 assert_eq!(
71 bias_data.len(),
72 4,
73 "sparse_linear bias must be a scalar F32 (4 bytes), got {} bytes",
74 bias_data.len()
75 );
76 let bias = f32::from_le_bytes([bias_data[0], bias_data[1], bias_data[2], bias_data[3]]);
77
78 assert_eq!(weight.len(), 1024, "sparse_linear weight must be [1024]");
79 (Array1::from(weight), bias)
80 })
81}
82
#[cfg(test)]
mod tests {
    use super::*;

    /// The decoded parameters have the expected element count, the expected
    /// bias value, and contain only finite, not-all-zero weights.
    #[test]
    fn sparse_linear_loads_correct_shape() {
        let params = sparse_linear();
        let weight = &params.0;
        let bias = params.1;

        assert_eq!(weight.len(), 1024);
        assert!(
            (bias - 0.045_196_53).abs() < 1e-6,
            "bias should be ~0.04520, got {bias}"
        );
        assert!(bias.is_finite(), "bias must be finite");

        let every_weight_finite = weight.iter().copied().all(f32::is_finite);
        assert!(every_weight_finite, "all weight elements must be finite");

        let every_weight_zero = weight.iter().all(|&w| w == 0.0);
        assert!(!every_weight_zero, "weight vector must not be all-zero");
    }

    /// Two calls hand back the very same cached value, not merely equal ones.
    #[test]
    fn sparse_linear_is_idempotent() {
        let first = sparse_linear();
        let second = sparse_linear();
        let same_address = std::ptr::eq(first, second);
        assert!(same_address, "should return the same cached ref");
    }

    /// The embedded bytes parse as safetensors and expose both tensor names.
    #[test]
    fn bundled_file_is_valid_safetensors() {
        let parsed = safetensors::SafeTensors::deserialize(WEIGHTS_BYTES)
            .expect("WEIGHTS_BYTES must be valid safetensors");

        let has_weight = parsed.tensor("weight").is_ok();
        assert!(has_weight, "must contain 'weight'");

        let has_bias = parsed.tensor("bias").is_ok();
        assert!(has_bias, "must contain 'bias'");
    }

    /// The embedded asset has the exact byte length the test pins.
    #[test]
    fn bundled_file_size_matches() {
        let embedded_len = WEIGHTS_BYTES.len();
        assert_eq!(embedded_len, 4236, "expected 4,236 bytes");
    }

    /// The embedded asset hashes to the pinned SHA-256 digest.
    #[test]
    fn bundled_file_sha256_matches() {
        use sha2::Digest;
        use std::fmt::Write;

        const EXPECTED_SHA256: &str =
            "a2601321f01abbb696d171a58a65ff35be1603d9cbc22c647dfe34b4568dd690";

        let mut hasher = sha2::Sha256::new();
        hasher.update(WEIGHTS_BYTES);
        let digest = hasher.finalize();

        // Render the digest as lowercase hex, two characters per byte.
        let mut hex = String::new();
        for byte in digest.iter() {
            write!(hex, "{byte:02x}").expect("hex write");
        }

        assert_eq!(
            hex, EXPECTED_SHA256,
            "bundled sparse_linear.safetensors SHA-256 mismatch"
        );
    }
}