bge_m3_embedding_server/binpack.rs
// Copyright (c) 2026 J. Patrick Fulton
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Bin-packing algorithm that groups tokenized sequences into `session.run()`
//! calls that each fit within the per-worker workspace budget.
//!
//! The central type is [`CostModel`], which captures the quadratic memory
//! scaling of BGE-M3 attention and is used by `bin_pack` to partition an
//! incoming batch into chunks that are safe to run in a single ORT call.
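//!
//! A hedged end-to-end sketch (not compiled; `texts` and `tokenize_len` are
//! illustrative assumptions, not this crate's API):
//!
//! ```rust,ignore
//! // Budget model: probed at startup, or conservative defaults as a fallback.
//! let cost_model = CostModel::conservative(CostModel::DEFAULT_MAX_WORKSPACE);
//!
//! // One padded-token length per input text (hypothetical tokenizer call).
//! let seq_lengths: Vec<usize> = texts.iter().map(|t| tokenize_len(t)).collect();
//!
//! // Each chunk holds original indices that are safe to run in one ORT call.
//! for chunk in bin_pack(&seq_lengths, &cost_model) {
//!     // run session.run() on the texts at these indices...
//! }
//! ```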

/// Quadratic-aware workspace cost model for ONNX attention inference.
///
/// BGE-M3 uses multi-head attention whose intermediate tensor footprint scales
/// as `O(batch * seq^2)` (attention score matrix) plus `O(batch * seq)`
/// (FFN intermediates, projection matrices). The total peak workspace is
/// approximately:
///
/// ```text
/// peak ≈ a * (batch * seq) + b * (batch * seq^2)
/// ```
///
/// where `a` (bytes/token-position) captures the FFN / projection contribution
/// and `b` (bytes/token-position^2) captures the attention contribution.
///
/// At sequence length 512 the attention term is small relative to the FFN
/// term, so a linear approximation works. The quadratic-to-linear ratio
/// scales as `(b / a) * seq`, so at 8192 it is 16× what it was at 512, and
/// a model using only `a` calibrated at 512 would under-budget long inputs
/// by roughly that factor.
///
/// Coefficients are derived at startup by [`crate::probe`] or set
/// conservatively from compile-time defaults when measurement is unavailable.
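///
/// # Example
///
/// A worked comparison using the conservative coefficients defined below
/// (`a = 16_384`, `b = 8`) for a single text (`batch = 1`):
///
/// ```text
/// seq = 512:   linear = 16384 * 512  =   8 MiB   quad = 8 * 512^2  =   2 MiB
/// seq = 8192:  linear = 16384 * 8192 = 128 MiB   quad = 8 * 8192^2 = 512 MiB
/// ```
///
/// The quadratic term grows from 0.25× the linear term to 4× of it: the 16×
/// shift in relative weight described above.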
#[derive(Clone, Copy, Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub struct CostModel {
    /// Bytes per token-position (linear term: FFN intermediates, projections).
    pub a: f64,
    /// Bytes per token-position-squared (quadratic term: attention scores).
    pub b: f64,
    /// Maximum workspace bytes available per worker for a single `session.run()` call.
    pub max_workspace_bytes: usize,
}

impl CostModel {
    /// Conservative static defaults calibrated so a `(16, 512)` chunk lands at
    /// ~160 MB of workspace, matching the old static budget at the previous
    /// defaults `BGE_M3_ONNX_BATCH_SIZE = 16`, `MAX_SEQ_LENGTH = 512`.
    ///
    /// These are used when the probe cannot run (no ORT, no model, macOS without
    /// cgroup support) or when `BGE_M3_DISABLE_AUTO_BUDGET` is set.
    ///
    /// Formula check: 16 KiB/token × 16 × 512 + 8 B/token² × 16 × 512²
    /// = 16384 × 8192 + 8 × 16 × 262144
    /// = 134 217 728 + 33 554 432
    /// = 167 772 160 ≈ 160 MB per chunk (chunks run sequentially within one worker).
    pub const CONSERVATIVE_A: f64 = 16_384.0; // 16 KiB per token-position
    /// Conservative quadratic coefficient (bytes per token-position squared).
    pub const CONSERVATIVE_B: f64 = 8.0; // 8 bytes per token-position^2

    /// Default maximum workspace per worker when memory cannot be detected.
    ///
    /// 2 GiB is conservatively safe for the Fargate 28 GiB task with 7 workers
    /// (`28 GiB * 0.7 safety factor / 7 workers ≈ 2.8 GiB`); we round down for headroom.
    pub const DEFAULT_MAX_WORKSPACE: usize = 2 * 1024 * 1024 * 1024; // 2 GiB

    /// Constructs a `CostModel` with conservative defaults and the given workspace ceiling.
    #[must_use]
    pub fn conservative(max_workspace_bytes: usize) -> Self {
        Self {
            a: Self::CONSERVATIVE_A,
            b: Self::CONSERVATIVE_B,
            max_workspace_bytes,
        }
    }

    /// Estimated peak workspace (bytes) for a single `session.run()` call with
    /// `count` texts and `max_seq` as the padded sequence length.
    ///
    /// Uses saturating arithmetic on `u128` to avoid overflow at large inputs.
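    ///
    /// With the conservative coefficients, `chunk_cost(16, 512)` evaluates to
    /// `16_384 * 8_192 + 8 * 8_192 * 512 = 167_772_160` bytes, the same total
    /// as the formula check on [`Self::CONSERVATIVE_A`].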
    //
    // cast_precision_loss: n is u128, but realistic values (batch ≤ 256, seq ≤ 8192)
    // keep n ≤ 2_097_152, well below the 2^53 limit up to which f64 represents
    // integers exactly, so no bits are lost.
    // cast_possible_truncation: f64 → u128 intentionally floors fractional bytes;
    // this is a memory *budget estimate*, not an exact byte count.
    // cast_sign_loss: a and b are validated positive at construction, so the
    // products are always ≥ 0 before the cast.
    #[must_use]
    #[allow(
        clippy::cast_precision_loss,
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss
    )]
    pub fn chunk_cost(&self, count: usize, max_seq: usize) -> u128 {
        let n = count as u128 * max_seq as u128;
        let linear = (self.a * n as f64) as u128;
        let quad = (self.b * n as f64 * max_seq as f64) as u128;
        linear.saturating_add(quad)
    }

    /// Returns `true` if the chunk fits within the workspace budget.
    #[must_use]
    pub fn fits(&self, count: usize, max_seq: usize) -> bool {
        self.chunk_cost(count, max_seq) <= self.max_workspace_bytes as u128
    }
}

/// Length-sorted greedy bin-packer.
///
/// Partitions `seq_lengths` (indexed `0..n`) into chunks that are contiguous
/// in length-sorted order, where each chunk satisfies
/// `cost_model.fits(chunk.len(), max_seq_in_chunk)`.
///
/// If a single text exceeds the budget on its own (which can happen when
/// `max_workspace_bytes` is very small, or when the text is at `MAX_SEQ_LENGTH`
/// and the budget cannot cover even one maximum-length text), it still gets its
/// own single-element chunk. The caller (ORT session) will either succeed or
/// fail; we never silently truncate or discard inputs.
///
/// # Returns
///
/// `Vec<Vec<usize>>` where each inner `Vec` contains the **original indices**
/// of texts in that chunk, sorted ascending by sequence length. The outer vec
/// preserves the order chunks should be processed in. Callers scatter results
/// back to the original positions using these indices, as sketched below.
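///
/// A hedged sketch of the scatter step (not compiled; `Embedding` and
/// `run_chunk` are hypothetical stand-ins for the caller's types):
///
/// ```rust,ignore
/// let chunks = bin_pack(&seq_lengths, &cost_model);
/// let mut out: Vec<Option<Embedding>> = vec![None; seq_lengths.len()];
/// for chunk in &chunks {
///     let embeddings = run_chunk(chunk)?; // one session.run() per chunk
///     for (&orig_idx, emb) in chunk.iter().zip(embeddings) {
///         out[orig_idx] = Some(emb); // scatter back to input order
///     }
/// }
/// ```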
///
/// # Complexity
///
/// `O(n log n)` sort + `O(n)` scan.
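///
/// # Example
///
/// Illustrative lengths and budget (not measured values): given
/// `seq_lengths = [512, 64, 8192, 64]` and a budget that admits the three
/// short texts together but not the 8192-token text alongside them, the
/// sort yields `[1, 3, 0, 2]` (the two 64s in either order) and the packer
/// emits `[[1, 3, 0], [2]]`.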
pub(crate) fn bin_pack(seq_lengths: &[usize], cost_model: &CostModel) -> Vec<Vec<usize>> {
    if seq_lengths.is_empty() {
        return Vec::new();
    }

    // Sort indices by ascending sequence length so we can greedily pack
    // short texts together. Long texts naturally form their own small chunks.
    let mut sorted: Vec<usize> = (0..seq_lengths.len()).collect();
    sorted.sort_unstable_by_key(|&i| seq_lengths[i]);

    let mut chunks: Vec<Vec<usize>> = Vec::new();
    let mut current: Vec<usize> = Vec::new();
    let mut current_max_seq: usize = 0;

    for idx in sorted {
        let seq = seq_lengths[idx];
        let new_max = current_max_seq.max(seq);
        let new_count = current.len() + 1;

        if current.is_empty() || cost_model.fits(new_count, new_max) {
            // Adding this text keeps the chunk within budget.
            current.push(idx);
            current_max_seq = new_max;
        } else {
            // Flush the current chunk and start a new one.
            tracing::debug!(
                chunk_idx = chunks.len(),
                batch = current.len(),
                max_seq = current_max_seq,
                estimated_workspace_mb =
                    cost_model.chunk_cost(current.len(), current_max_seq) / (1024 * 1024),
                "bin_pack chunk decided"
            );
            chunks.push(std::mem::take(&mut current));
            current.push(idx);
            current_max_seq = seq;
        }
    }

175
176 if !current.is_empty() {
177 tracing::debug!(
178 chunk_idx = chunks.len(),
179 batch = current.len(),
180 max_seq = current_max_seq,
181 estimated_workspace_mb =
182 cost_model.chunk_cost(current.len(), current_max_seq) / (1024 * 1024),
183 "bin_pack chunk decided"
184 );
185 chunks.push(current);
186 }
187
188 chunks
189}

#[cfg(test)]
mod tests;