bge_m3_embedding_server/embedder/tokenize.rs

// Copyright (c) 2026 J. Patrick Fulton
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Tokenizer load + no-pad tokenization + chunk-array build helpers.

use std::path::Path;

use anyhow::Result;

/// Loads and configures the BGE-M3 tokenizer with truncation at `max_seq_length`
/// but **no** padding. Padding is applied per-chunk in [`build_chunk_arrays`].
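///
/// # Example
///
/// A minimal usage sketch (not part of the original file); the model
/// directory below is hypothetical, and 8192 is BGE-M3's maximum
/// sequence length:
///
/// ```ignore
/// use std::path::Path;
///
/// let tokenizer = load_tokenizer(Path::new("models/bge-m3/tokenizer.json"), 8192)?;
/// ```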
pub(super) fn load_tokenizer(
    tokenizer_path: &Path,
    max_seq_length: usize,
) -> Result<tokenizers::Tokenizer> {
    let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
        .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {e}"))?;

    tokenizer
        .with_truncation(Some(tokenizers::TruncationParams {
            max_length: max_seq_length,
            strategy: tokenizers::TruncationStrategy::LongestFirst,
            ..Default::default()
        }))
        .map_err(|e| anyhow::anyhow!("Failed to set truncation: {e}"))?;

    // No BatchLongest padding here; we pad manually in build_chunk_arrays
    // so each chunk only pads to its own longest sequence.
    tokenizer.with_padding(None);

    Ok(tokenizer)
}

/// Tokenizes `texts` without applying any padding. Returns one `Encoding` per text,
/// each truncated to the tokenizer's configured `max_length`.
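///
/// # Example
///
/// A minimal sketch (not part of the original file), assuming `tokenizer`
/// came from [`load_tokenizer`]:
///
/// ```ignore
/// let texts = vec!["hello world".to_string(), "bonjour".to_string()];
/// let encodings = tokenize_no_pad(&tokenizer, &texts)?;
/// assert_eq!(encodings.len(), texts.len());
/// ```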
pub(super) fn tokenize_no_pad(
    tokenizer: &tokenizers::Tokenizer,
    texts: &[String],
) -> Result<Vec<tokenizers::Encoding>> {
    let str_refs: Vec<&str> = texts.iter().map(String::as_str).collect();
    // The `true` flag asks the tokenizer to add special tokens around each
    // sequence (e.g. <s> / </s> for the XLM-RoBERTa-style vocabulary).
    let encodings = tokenizer
        .encode_batch_fast(str_refs, true)
        .map_err(|e| anyhow::anyhow!("Tokenization failed: {e}"))?;
    Ok(encodings)
}

/// Builds `input_ids` and `attention_mask` arrays for a single chunk.
///
/// `indices` selects which encodings from `all_encodings` belong to this chunk.
/// `pad_to` is the chunk-local maximum sequence length; all sequences are
/// right-padded with `pad_id = 1` (XLM-RoBERTa `<pad>` token).
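///
/// # Example
///
/// A minimal sketch (not part of the original file), assuming `all_encodings`
/// came from [`tokenize_no_pad`] and padding a two-element chunk to its own
/// longest sequence:
///
/// ```ignore
/// let indices = [0, 1];
/// let pad_to = indices
///     .iter()
///     .map(|&i| all_encodings[i].get_ids().len())
///     .max()
///     .unwrap_or(0);
/// let (input_ids, attention_mask) = build_chunk_arrays(&all_encodings, &indices, pad_to)?;
/// assert_eq!(input_ids.dim(), (indices.len(), pad_to));
/// ```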
#[allow(clippy::cast_possible_truncation)]
pub(super) fn build_chunk_arrays(
    all_encodings: &[tokenizers::Encoding],
    indices: &[usize],
    pad_to: usize,
) -> Result<(ndarray::Array2<i64>, ndarray::Array2<i64>)> {
    let batch = indices.len();
    let mut ids_flat: Vec<i64> = Vec::with_capacity(batch * pad_to);
    let mut mask_flat: Vec<i64> = Vec::with_capacity(batch * pad_to);

    for &idx in indices {
        let enc = &all_encodings[idx];
        let token_ids = enc.get_ids();
        let attn_mask = enc.get_attention_mask();
        let seq_len = token_ids.len();

        // Copy token ids and mask
        ids_flat.extend(token_ids.iter().map(|&id| i64::from(id)));
        mask_flat.extend(attn_mask.iter().map(|&m| i64::from(m)));

        // Right-pad with pad_id=1 / mask=0
        let pad = pad_to.saturating_sub(seq_len);
        ids_flat.extend(std::iter::repeat_n(1i64, pad));
        mask_flat.extend(std::iter::repeat_n(0i64, pad));
    }

    let ids_array = ndarray::Array2::from_shape_vec((batch, pad_to), ids_flat)?;
    let mask_array = ndarray::Array2::from_shape_vec((batch, pad_to), mask_flat)?;

    Ok((ids_array, mask_array))
}
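
// A usage sketch (not part of the original module), assuming a caller that
// sorts encodings by length so each chunk pads only to its own longest
// sequence. `CHUNK_SIZE` and `chunked_arrays_sketch` are hypothetical names.
#[cfg(test)]
#[allow(dead_code)]
fn chunked_arrays_sketch(
    tokenizer: &tokenizers::Tokenizer,
    texts: &[String],
) -> Result<Vec<(ndarray::Array2<i64>, ndarray::Array2<i64>)>> {
    const CHUNK_SIZE: usize = 32;
    let encodings = tokenize_no_pad(tokenizer, texts)?;

    // Sort indices by token count so similar-length texts share a chunk.
    let mut order: Vec<usize> = (0..encodings.len()).collect();
    order.sort_by_key(|&i| encodings[i].get_ids().len());

    // Note: results follow the sorted order; a real caller would carry the
    // chunk indices along to restore the original text order.
    order
        .chunks(CHUNK_SIZE)
        .map(|indices| {
            // Each chunk pads only to its own longest sequence.
            let pad_to = indices
                .iter()
                .map(|&i| encodings[i].get_ids().len())
                .max()
                .unwrap_or(0);
            build_chunk_arrays(&encodings, indices, pad_to)
        })
        .collect()
}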