// bge_m3_embedding_server/embedder/dense.rs
use anyhow::Result;
18use ort::value::TensorRef;
19
20use super::error::ort_err;
21use super::math::normalize_l2;
22use super::tokenize::{build_chunk_arrays, tokenize_no_pad};
23use super::types::EmbedStats;
24use crate::binpack::{bin_pack, CostModel};
25use crate::config::ModelVariant;
26
27#[allow(clippy::cast_possible_truncation)]
33pub(super) fn embed_dense(
34 session: &mut ort::session::Session,
35 tokenizer: &tokenizers::Tokenizer,
36 texts: &[String],
37 cost_model: &CostModel,
38 model_variant: ModelVariant,
39) -> Result<(Vec<Vec<f32>>, EmbedStats)> {
40 let tokenize_start = std::time::Instant::now();
41 let encodings = tokenize_no_pad(tokenizer, texts)?;
42 let seq_lens: Vec<usize> = encodings.iter().map(|e| e.get_ids().len()).collect();
43 let tokenize_ms = u64::try_from(tokenize_start.elapsed().as_millis()).unwrap_or(u64::MAX);
44
45 let total_token_positions: usize = seq_lens.iter().sum();
46 let chunks = bin_pack(&seq_lens, cost_model);
47
48 let mut all_embeddings: Vec<Vec<f32>> = (0..texts.len()).map(|_| Vec::new()).collect();
50
51 let mut max_chunk_seq: usize = 0;
52 let mut inference_ms: u64 = 0;
53
54 for (chunk_idx, chunk_indices) in chunks.iter().enumerate() {
55 let chunk_max = chunk_indices
56 .iter()
57 .map(|&i| seq_lens[i])
58 .max()
59 .unwrap_or(1)
60 .max(1); max_chunk_seq = max_chunk_seq.max(chunk_max);
63
64 let (ids_array, mask_array) = build_chunk_arrays(&encodings, chunk_indices, chunk_max)?;
65 let batch_len = ids_array.nrows();
66
67 let ids_tensor = TensorRef::from_array_view(ids_array.view()).map_err(ort_err)?;
68 let mask_tensor = TensorRef::from_array_view(mask_array.view()).map_err(ort_err)?;
69
70 let chunk_start = std::time::Instant::now();
71 let outputs = {
72 let _span = tracing::debug_span!(
73 "chunk",
74 chunk_idx,
75 batch = chunk_indices.len(),
76 max_seq = chunk_max
77 )
78 .entered();
79 session
80 .run(ort::inputs! {
81 "input_ids" => ids_tensor,
82 "attention_mask" => mask_tensor,
83 })
84 .map_err(ort_err)?
85 };
86 let chunk_ms = u64::try_from(chunk_start.elapsed().as_millis()).unwrap_or(u64::MAX);
87 inference_ms = inference_ms.saturating_add(chunk_ms);
88 tracing::debug!(
89 chunk_idx,
90 batch = chunk_indices.len(),
91 max_seq = chunk_max,
92 elapsed_ms = chunk_ms,
93 "dense chunk inference complete"
94 );
95
96 let emb: ndarray::ArrayD<f32> = match model_variant {
99 ModelVariant::Fp32 => outputs["sentence_embedding"]
100 .try_extract_array::<f32>()
101 .map_err(ort_err)?
102 .to_owned(),
103 ModelVariant::Fp16 | ModelVariant::Int8 => {
104 let lhs = outputs["last_hidden_state"]
105 .try_extract_array::<f32>()
106 .map_err(ort_err)?;
107 lhs.index_axis(ndarray::Axis(1), 0).to_owned()
108 }
109 };
110
111 for (chunk_pos, &orig_idx) in chunk_indices.iter().enumerate() {
112 debug_assert!(chunk_pos < batch_len, "chunk_pos must be within batch");
113 let row = emb.index_axis(ndarray::Axis(0), chunk_pos);
114 let mut vec = row
115 .as_slice()
116 .expect("embedding should be contiguous")
117 .to_vec();
118 normalize_l2(&mut vec);
119 all_embeddings[orig_idx] = vec;
120 }
121 }
122
123 let stats = EmbedStats {
124 chunks: chunks.len(),
125 max_chunk_seq,
126 total_token_positions,
127 tokenize_ms,
128 inference_ms,
129 };
130
131 Ok((all_embeddings, stats))
132}