Skip to main content

bge_m3_embedding_server/
binpack.rs

1// Copyright (c) 2026 J. Patrick Fulton
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Bin-packing algorithm that groups tokenized sequences into `session.run()`
16//! calls that each fit within the per-worker workspace budget.
17//!
18//! The central type is [`CostModel`], which captures the quadratic memory
19//! scaling of BGE-M3 attention and is used by `bin_pack` to partition an
20//! incoming batch into chunks that are safe to run in a single ORT call.
21
/// Quadratic-aware workspace cost model for ONNX attention inference.
///
/// BGE-M3 uses multi-head attention whose intermediate tensor footprint scales
/// as `O(batch * seq^2)` (attention score matrix) plus `O(batch * seq)`
/// (FFN intermediates, projection matrices). The total peak workspace is
/// approximately:
///
/// ```text
/// peak ≈ a * (batch * seq) + b * (batch * seq^2)
/// ```
///
/// where `a` (bytes/token-position) captures the FFN / projection contribution
/// and `b` (bytes/token-position^2) captures the attention contribution.
///
/// At sequence length 512 attention is small relative to FFN
/// (`b * seq / a = 8 * 512 / 16384 = 0.25`), so a linear approximation works.
/// At 8192 the quadratic term is ~4× the linear term — its weight relative to
/// `a` has grown 16× (8192 / 512) — so a linear model calibrated at 512 would
/// under-budget long sequences severely.
///
/// Coefficients are derived at startup by [`crate::probe`] or set
/// conservatively from compile-time defaults when measurement is unavailable.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub struct CostModel {
    /// Bytes per token-position (linear term: FFN intermediates, projections).
    pub a: f64,
    /// Bytes per token-position-squared (quadratic term: attention scores).
    pub b: f64,
    /// Maximum workspace bytes available per worker for a single `session.run()` call.
    pub max_workspace_bytes: usize,
}

impl CostModel {
    /// Conservative static defaults calibrated so a `(16, 512)` chunk lands at
    /// ~160 MB workspace — on the order of the old static budget at the
    /// previous defaults `BGE_M3_ONNX_BATCH_SIZE = 16`, `MAX_SEQ_LENGTH = 512`.
    ///
    /// These are used when the probe cannot run (no ORT, no model, macOS without
    /// cgroup support) or when `BGE_M3_DISABLE_AUTO_BUDGET` is set.
    ///
    /// Formula check: 16 KiB/token × 16 × 512 + 8 B/token² × 16 × 512²
    ///   = 16384 × 8192 + 8 × 16 × 262144
    ///   = 134 217 728 + 33 554 432
    ///   = 167 772 160 ≈ 160 MB per chunk (chunks run sequentially inside one worker).
    pub const CONSERVATIVE_A: f64 = 16_384.0; // 16 KiB per token-position
    /// Conservative quadratic coefficient (bytes per token-position squared).
    pub const CONSERVATIVE_B: f64 = 8.0; // 8 bytes per token-position^2

    /// Default maximum workspace per worker when memory cannot be detected.
    ///
    /// 2 GiB is conservatively safe for the Fargate 28 GiB task with 7 workers
    /// (`28 GB * 0.7 safety / 7 workers ≈ 2.8 GB`); we round down for headroom.
    pub const DEFAULT_MAX_WORKSPACE: usize = 2 * 1024 * 1024 * 1024; // 2 GiB

    /// Constructs a `CostModel` with conservative defaults and the given workspace ceiling.
    #[must_use]
    pub fn conservative(max_workspace_bytes: usize) -> Self {
        Self {
            a: Self::CONSERVATIVE_A,
            b: Self::CONSERVATIVE_B,
            max_workspace_bytes,
        }
    }

    /// Estimated peak workspace (bytes) for a single `session.run()` call with
    /// `count` texts and `max_seq` as the padded sequence length.
    ///
    /// Uses saturating arithmetic on `u128` to avoid overflow at large inputs.
    //
    // cast_precision_loss: n is u128, but realistic values (batch ≤ 256, seq ≤ 8192)
    //   keep n ≤ 2_097_152 — well within f64's 2^52 mantissa — so no bits are lost.
    // cast_possible_truncation: f64 → u128 intentionally floors fractional bytes;
    //   this is a memory *budget estimate*, not an exact byte count.
    // cast_sign_loss: a and b are expected to be positive (positive constants or
    //   probe output); even if one were negative, Rust float→int `as` casts
    //   saturate, so a negative product clamps to 0 rather than wrapping.
    #[must_use]
    #[allow(
        clippy::cast_precision_loss,
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss
    )]
    pub fn chunk_cost(&self, count: usize, max_seq: usize) -> u128 {
        // Total padded token-positions in the chunk; u128 so the product
        // cannot overflow for any pair of usize inputs.
        let n = count as u128 * max_seq as u128;
        let linear = (self.a * n as f64) as u128;
        let quad = (self.b * n as f64 * max_seq as f64) as u128;
        linear.saturating_add(quad)
    }

    /// Returns `true` if the chunk fits within the workspace budget.
    #[must_use]
    pub fn fits(&self, count: usize, max_seq: usize) -> bool {
        self.chunk_cost(count, max_seq) <= self.max_workspace_bytes as u128
    }
}
115
116/// Length-sorted greedy bin-packer.
117///
118/// Partitions `seq_lengths` (indexed 0..n) into contiguous groups (chunks)
119/// where each chunk satisfies `cost_model.fits(chunk.len(), max_seq_in_chunk)`.
120///
121/// If a single text exceeds the budget on its own — which can happen when
122/// `max_workspace_bytes` is very small or the text is at `MAX_SEQ_LENGTH` and
123/// the budget is tighter than one text — it gets its own single-element chunk.
124/// The caller (ORT session) will either succeed or fail; we never silently
125/// truncate or discard inputs.
126///
127/// # Returns
128///
129/// `Vec<Vec<usize>>` where each inner `Vec` contains the **original indices**
130/// of texts in that chunk, sorted ascending by sequence length. The outer vec
131/// preserves the order chunks should be processed in. Callers scatter results
132/// back to the original positions using these indices.
133///
134/// # Complexity
135///
136/// `O(n log n)` sort + `O(n)` scan.
137pub(crate) fn bin_pack(seq_lengths: &[usize], cost_model: &CostModel) -> Vec<Vec<usize>> {
138    if seq_lengths.is_empty() {
139        return Vec::new();
140    }
141
142    // Sort indices by ascending sequence length so we can greedily pack
143    // short texts together. Long texts naturally form their own small chunks.
144    let mut sorted: Vec<usize> = (0..seq_lengths.len()).collect();
145    sorted.sort_unstable_by_key(|&i| seq_lengths[i]);
146
147    let mut chunks: Vec<Vec<usize>> = Vec::new();
148    let mut current: Vec<usize> = Vec::new();
149    let mut current_max_seq: usize = 0;
150
151    for idx in sorted {
152        let seq = seq_lengths[idx];
153        let new_max = current_max_seq.max(seq);
154        let new_count = current.len() + 1;
155
156        if current.is_empty() || cost_model.fits(new_count, new_max) {
157            // Adding this text keeps the chunk within budget.
158            current.push(idx);
159            current_max_seq = new_max;
160        } else {
161            // Flush the current chunk and start a new one.
162            tracing::debug!(
163                chunk_idx = chunks.len(),
164                batch = current.len(),
165                max_seq = current_max_seq,
166                estimated_workspace_mb =
167                    cost_model.chunk_cost(current.len(), current_max_seq) / (1024 * 1024),
168                "bin_pack chunk decided"
169            );
170            chunks.push(std::mem::take(&mut current));
171            current.push(idx);
172            current_max_seq = seq;
173        }
174    }
175
176    if !current.is_empty() {
177        tracing::debug!(
178            chunk_idx = chunks.len(),
179            batch = current.len(),
180            max_seq = current_max_seq,
181            estimated_workspace_mb =
182                cost_model.chunk_cost(current.len(), current_max_seq) / (1024 * 1024),
183            "bin_pack chunk decided"
184        );
185        chunks.push(current);
186    }
187
188    chunks
189}
190
// Unit tests live in the adjacent `tests` module file; compiled only under `cfg(test)`.
#[cfg(test)]
mod tests;