bge_m3_embedding_server/embedder/
dual.rs1use anyhow::Result;
18use ort::value::TensorRef;
19
20use super::error::ort_err;
21use super::math::{normalize_l2, sparse_maxpool, sparse_project};
22use super::tokenize::{build_chunk_arrays, tokenize_no_pad};
23use super::types::{DualEmbedding, EmbedStats, SparseEmbedding};
24use crate::binpack::{bin_pack, CostModel};
25use crate::config::ModelVariant;
26
/// Embeds every text in `texts` with a single model pass per bin-packed
/// chunk, producing both a dense (L2-normalized) vector and a sparse
/// lexical-weight embedding per input.
///
/// Inputs are tokenized without padding, then grouped into chunks by
/// `bin_pack` using `cost_model`; each chunk is padded only to its own
/// longest sequence before inference. Results are written back into the
/// original `texts` order regardless of chunk grouping.
///
/// Returns one `DualEmbedding` per input (same order as `texts`) plus
/// `EmbedStats` with tokenization/inference timings and size counters.
///
/// # Errors
/// Propagates tokenization failures and ONNX Runtime errors (mapped
/// through `ort_err`).
#[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
pub(super) fn embed_both(
    session: &mut ort::session::Session,
    tokenizer: &tokenizers::Tokenizer,
    texts: &[String],
    cost_model: &CostModel,
    model_variant: ModelVariant,
) -> Result<(Vec<DualEmbedding>, EmbedStats)> {
    // Linear head used to project each token's hidden state to a single
    // sparse relevance score (see `sparse_project` below).
    let (weight, bias) = crate::weights::sparse_linear();
    let weight_view = weight.view();

    let tokenize_start = std::time::Instant::now();
    // No padding here: padding happens per chunk in `build_chunk_arrays`.
    let encodings = tokenize_no_pad(tokenizer, texts)?;
    let seq_lens: Vec<usize> = encodings.iter().map(|e| e.get_ids().len()).collect();
    // Millisecond counters saturate rather than overflow on pathological durations.
    let tokenize_ms = u64::try_from(tokenize_start.elapsed().as_millis()).unwrap_or(u64::MAX);

    let total_token_positions: usize = seq_lens.iter().sum();
    // Each chunk is a set of indices into `texts`/`encodings`; together the
    // chunks are expected to cover every index exactly once (relied on by the
    // `expect` at the end of this function).
    let chunks = bin_pack(&seq_lens, cost_model);

    // Slot per input text, filled as its chunk finishes; `None` until then.
    let mut all_dual: Vec<Option<DualEmbedding>> = (0..texts.len()).map(|_| None).collect();

    let mut max_chunk_seq: usize = 0;
    let mut inference_ms: u64 = 0;

    for (chunk_idx, chunk_indices) in chunks.iter().enumerate() {
        // Pad length for this chunk: its longest sequence, but at least 1 so
        // an all-empty chunk still yields a valid (non-zero-width) tensor.
        let chunk_max = chunk_indices
            .iter()
            .map(|&i| seq_lens[i])
            .max()
            .unwrap_or(1)
            .max(1);

        max_chunk_seq = max_chunk_seq.max(chunk_max);

        let (ids_array, mask_array) = build_chunk_arrays(&encodings, chunk_indices, chunk_max)?;

        // Zero-copy tensor views over the chunk arrays (arrays must outlive the run).
        let ids_tensor = TensorRef::from_array_view(ids_array.view()).map_err(ort_err)?;
        let mask_tensor = TensorRef::from_array_view(mask_array.view()).map_err(ort_err)?;

        let chunk_start = std::time::Instant::now();
        let outputs = {
            // Span guard is scoped to this block so it covers exactly the
            // `session.run` call in traces.
            let _span = tracing::debug_span!(
                "chunk",
                chunk_idx,
                batch = chunk_indices.len(),
                max_seq = chunk_max
            )
            .entered();
            session
                .run(ort::inputs! {
                    "input_ids" => ids_tensor,
                    "attention_mask" => mask_tensor,
                })
                .map_err(ort_err)?
        };
        let chunk_ms = u64::try_from(chunk_start.elapsed().as_millis()).unwrap_or(u64::MAX);
        inference_ms = inference_ms.saturating_add(chunk_ms);
        tracing::debug!(
            chunk_idx,
            batch = chunk_indices.len(),
            max_seq = chunk_max,
            elapsed_ms = chunk_ms,
            "both chunk inference complete"
        );

        // The exported graphs differ by variant in which outputs they expose:
        let (dense_emb, token_emb) = match model_variant {
            ModelVariant::Fp32 => {
                // Fp32 export provides a pre-pooled sentence embedding plus
                // per-token embeddings as separate outputs.
                let dense = outputs["sentence_embedding"]
                    .try_extract_array::<f32>()
                    .map_err(ort_err)?
                    .to_owned();
                let tokens = outputs["token_embeddings"]
                    .try_extract_array::<f32>()
                    .map_err(ort_err)?
                    .to_owned();
                (dense, tokens)
            }
            ModelVariant::Fp16 | ModelVariant::Int8 => {
                // These exports expose only the last hidden state; the dense
                // vector is taken from sequence position 0 (presumably CLS
                // pooling — TODO(review): confirm against the export script).
                let lhs = outputs["last_hidden_state"]
                    .try_extract_array::<f32>()
                    .map_err(ort_err)?;
                let dense = lhs.index_axis(ndarray::Axis(1), 0).to_owned();
                let tokens = lhs.to_owned();
                (dense, tokens)
            }
        };

        // Map each position in this chunk's batch back to its original text index.
        for (chunk_pos, &orig_idx) in chunk_indices.iter().enumerate() {
            let dense_row = dense_emb.index_axis(ndarray::Axis(0), chunk_pos);
            // Rows of an owned, standard-layout array are contiguous, so
            // `as_slice` cannot fail here.
            let mut dense_vec = dense_row
                .as_slice()
                .expect("dense embedding should be contiguous")
                .to_vec();
            normalize_l2(&mut dense_vec);

            let enc = &encodings[orig_idx];
            let ids = enc.get_ids();
            let mask = enc.get_attention_mask();
            let batch_hidden = token_emb.index_axis(ndarray::Axis(0), chunk_pos);

            // One sparse score per real token position; padding beyond
            // `ids.len()` (up to `chunk_max`) is never visited.
            let scores: Vec<f32> = (0..ids.len())
                .map(|j| {
                    let hidden = batch_hidden.index_axis(ndarray::Axis(0), j);
                    let hidden_slice = hidden
                        .as_slice()
                        .expect("hidden state should be contiguous");
                    sparse_project(hidden_slice, &weight_view, *bias)
                })
                .collect();

            // Collapse per-position scores to per-token-id weights
            // (max over repeated ids, masked positions handled by `mask`).
            let (indices, values) = sparse_maxpool(ids, mask, &scores);

            all_dual[orig_idx] = Some(DualEmbedding {
                dense: dense_vec,
                sparse: SparseEmbedding { indices, values },
            });
        }
    }

    let stats = EmbedStats {
        chunks: chunks.len(),
        max_chunk_seq,
        total_token_positions,
        tokenize_ms,
        inference_ms,
    };

    Ok((
        all_dual
            .into_iter()
            // Invariant: `bin_pack` partitions all input indices, so every
            // slot was filled in the loop above.
            .map(|d| d.expect("every slot must be filled"))
            .collect(),
        stats,
    ))
}