bge_m3_embedding_server/embedder/
tokenize.rs1use std::path::Path;
18
19use anyhow::Result;
20
21pub(super) fn load_tokenizer(
24 tokenizer_path: &Path,
25 max_seq_length: usize,
26) -> Result<tokenizers::Tokenizer> {
27 let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
28 .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {e}"))?;
29
30 tokenizer
31 .with_truncation(Some(tokenizers::TruncationParams {
32 max_length: max_seq_length,
33 strategy: tokenizers::TruncationStrategy::LongestFirst,
34 ..Default::default()
35 }))
36 .map_err(|e| anyhow::anyhow!("Failed to set truncation: {e}"))?;
37
38 tokenizer.with_padding(None);
41
42 Ok(tokenizer)
43}
44
45pub(super) fn tokenize_no_pad(
48 tokenizer: &tokenizers::Tokenizer,
49 texts: &[String],
50) -> Result<Vec<tokenizers::Encoding>> {
51 let str_refs: Vec<&str> = texts.iter().map(String::as_str).collect();
52 let encodings = tokenizer
53 .encode_batch_fast(str_refs, true)
54 .map_err(|e| anyhow::anyhow!("Tokenization failed: {e}"))?;
55 Ok(encodings)
56}
57
58#[allow(clippy::cast_possible_truncation)]
64pub(super) fn build_chunk_arrays(
65 all_encodings: &[tokenizers::Encoding],
66 indices: &[usize],
67 pad_to: usize,
68) -> Result<(ndarray::Array2<i64>, ndarray::Array2<i64>)> {
69 let batch = indices.len();
70 let mut ids_flat: Vec<i64> = Vec::with_capacity(batch * pad_to);
71 let mut mask_flat: Vec<i64> = Vec::with_capacity(batch * pad_to);
72
73 for &idx in indices {
74 let enc = &all_encodings[idx];
75 let token_ids = enc.get_ids();
76 let attn_mask = enc.get_attention_mask();
77 let seq_len = token_ids.len();
78
79 ids_flat.extend(token_ids.iter().map(|&id| i64::from(id)));
81 mask_flat.extend(attn_mask.iter().map(|&m| i64::from(m)));
82
83 let pad = pad_to.saturating_sub(seq_len);
85 ids_flat.extend(std::iter::repeat_n(1i64, pad));
86 mask_flat.extend(std::iter::repeat_n(0i64, pad));
87 }
88
89 let ids_array = ndarray::Array2::from_shape_vec((batch, pad_to), ids_flat)?;
90 let mask_array = ndarray::Array2::from_shape_vec((batch, pad_to), mask_flat)?;
91
92 Ok((ids_array, mask_array))
93}