// bge_m3_embedding_server/embedder/sparse.rs
use anyhow::Result;
use ort::value::TensorRef;

use super::error::ort_err;
use super::math::{sparse_maxpool, sparse_project};
use super::tokenize::{build_chunk_arrays, tokenize_no_pad};
use super::types::{EmbedStats, SparseEmbedding};
use crate::binpack::{bin_pack, CostModel};
use crate::config::ModelVariant;
27#[allow(clippy::cast_possible_truncation)]
32pub(super) fn embed_sparse(
33 session: &mut ort::session::Session,
34 tokenizer: &tokenizers::Tokenizer,
35 texts: &[String],
36 cost_model: &CostModel,
37 model_variant: ModelVariant,
38) -> Result<(Vec<SparseEmbedding>, EmbedStats)> {
39 let (weight, bias) = crate::weights::sparse_linear();
40 let weight_view = weight.view();
41
42 let tokenize_start = std::time::Instant::now();
43 let encodings = tokenize_no_pad(tokenizer, texts)?;
44 let seq_lens: Vec<usize> = encodings.iter().map(|e| e.get_ids().len()).collect();
45 let tokenize_ms = u64::try_from(tokenize_start.elapsed().as_millis()).unwrap_or(u64::MAX);
46
47 let total_token_positions: usize = seq_lens.iter().sum();
48 let chunks = bin_pack(&seq_lens, cost_model);
49
50 let mut all_sparse: Vec<Option<SparseEmbedding>> = (0..texts.len()).map(|_| None).collect();
51
52 let mut max_chunk_seq: usize = 0;
53 let mut inference_ms: u64 = 0;
54
55 for (chunk_idx, chunk_indices) in chunks.iter().enumerate() {
56 let chunk_max = chunk_indices
57 .iter()
58 .map(|&i| seq_lens[i])
59 .max()
60 .unwrap_or(1)
61 .max(1);
62
63 max_chunk_seq = max_chunk_seq.max(chunk_max);
64
65 let (ids_array, mask_array) = build_chunk_arrays(&encodings, chunk_indices, chunk_max)?;
66
67 let ids_tensor = TensorRef::from_array_view(ids_array.view()).map_err(ort_err)?;
68 let mask_tensor = TensorRef::from_array_view(mask_array.view()).map_err(ort_err)?;
69
70 let chunk_start = std::time::Instant::now();
71 let outputs = {
72 let _span = tracing::debug_span!(
73 "chunk",
74 chunk_idx,
75 batch = chunk_indices.len(),
76 max_seq = chunk_max
77 )
78 .entered();
79 session
80 .run(ort::inputs! {
81 "input_ids" => ids_tensor,
82 "attention_mask" => mask_tensor,
83 })
84 .map_err(ort_err)?
85 };
86 let chunk_ms = u64::try_from(chunk_start.elapsed().as_millis()).unwrap_or(u64::MAX);
87 inference_ms = inference_ms.saturating_add(chunk_ms);
88 tracing::debug!(
89 chunk_idx,
90 batch = chunk_indices.len(),
91 max_seq = chunk_max,
92 elapsed_ms = chunk_ms,
93 "sparse chunk inference complete"
94 );
95
96 let token_emb = match model_variant {
99 ModelVariant::Fp32 => outputs["token_embeddings"]
100 .try_extract_array::<f32>()
101 .map_err(ort_err)?,
102 ModelVariant::Fp16 | ModelVariant::Int8 => outputs["last_hidden_state"]
103 .try_extract_array::<f32>()
104 .map_err(ort_err)?,
105 };
106
107 for (chunk_pos, &orig_idx) in chunk_indices.iter().enumerate() {
108 let enc = &encodings[orig_idx];
109 let ids = enc.get_ids();
110 let mask = enc.get_attention_mask();
111 let batch_hidden = token_emb.index_axis(ndarray::Axis(0), chunk_pos);
112
113 let scores: Vec<f32> = (0..ids.len())
114 .map(|j| {
115 let hidden = batch_hidden.index_axis(ndarray::Axis(0), j);
116 let hidden_slice = hidden
117 .as_slice()
118 .expect("hidden state should be contiguous");
119 sparse_project(hidden_slice, &weight_view, *bias)
120 })
121 .collect();
122
123 let (indices, values) = sparse_maxpool(ids, mask, &scores);
124 all_sparse[orig_idx] = Some(SparseEmbedding { indices, values });
125 }
126 }
127
128 let stats = EmbedStats {
129 chunks: chunks.len(),
130 max_chunk_seq,
131 total_token_positions,
132 tokenize_ms,
133 inference_ms,
134 };
135
136 Ok((
137 all_sparse
138 .into_iter()
139 .map(|s| s.expect("every slot must be filled"))
140 .collect(),
141 stats,
142 ))
143}