Skip to main content

bge_m3_embedding_server/embedder/
session.rs

// Copyright (c) 2026 J. Patrick Fulton
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! ORT execution-provider configuration and session loading.
16
use std::path::Path;

use anyhow::{Context, Result};

use super::error::ort_err;
use super::model_files::download_model_files;
use super::tokenize::load_tokenizer;
use crate::config::ModelVariant;
25
26/// Returns the execution providers to use for this platform.
27///
28/// On macOS: uses the `CoreML` EP with `MLProgram` format and `FastPrediction`
29/// specialisation strategy (overridable via `BGE_M3_COREML_STRATEGY=default`).
30/// On all other platforms: returns an empty list, so ORT falls back to MLAS (CPU).
31pub(super) fn execution_providers(cache_dir: &Path) -> Vec<ort::ep::ExecutionProviderDispatch> {
32    #[cfg(target_os = "macos")]
33    {
34        let coreml_cache = cache_dir.join("coreml");
35        let strategy = match std::env::var("BGE_M3_COREML_STRATEGY").ok().as_deref() {
36            Some("default") => ort::ep::coreml::SpecializationStrategy::Default,
37            _ => ort::ep::coreml::SpecializationStrategy::FastPrediction,
38        };
39        let builder = ort::ep::CoreML::default()
40            .with_model_format(ort::ep::coreml::ModelFormat::MLProgram)
41            .with_specialization_strategy(strategy)
42            .with_model_cache_dir(coreml_cache.display().to_string());
43        #[cfg(feature = "coreml-profile")]
44        let builder = builder.with_profile_compute_plan(true);
45        vec![builder.build()]
46    }
47    #[cfg(not(target_os = "macos"))]
48    {
49        let _ = cache_dir;
50        vec![]
51    }
52}
53
54/// Builds an ORT session from the ONNX model file with the given execution providers.
55///
56/// `intra_threads` controls intra-op parallelism for matmul / attention kernels
57/// inside a single `session.run()` call. The default (`1`) keeps per-worker RSS
58/// predictable for the workspace probe; raise it to `floor(num_cpus / workers)`
59/// on under-utilized hosts to recover CPU headroom. See
60/// [`crate::config::Config::intra_threads`] for the operator-facing knob.
61pub(super) fn load_session(
62    model_path: &Path,
63    eps: Vec<ort::ep::ExecutionProviderDispatch>,
64    intra_threads: usize,
65) -> Result<ort::session::Session> {
66    let mut builder = ort::session::Session::builder().map_err(ort_err)?;
67    if !eps.is_empty() {
68        builder = builder.with_execution_providers(eps).map_err(ort_err)?;
69    }
70    let session = builder
71        .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)
72        .map_err(ort_err)?
73        .with_intra_threads(intra_threads.max(1))
74        .map_err(ort_err)?
75        .commit_from_file(model_path)
76        .map_err(ort_err)?;
77    Ok(session)
78}
79
80/// Downloads (if not already cached) and loads both the ORT session and the
81/// tokenizer for the given model variant, returning them as a pair.
82pub(super) fn load_models(
83    cache_dir: &Path,
84    show_download_progress: bool,
85    model_variant: ModelVariant,
86    max_seq_length: usize,
87    intra_threads: usize,
88) -> Result<(ort::session::Session, tokenizers::Tokenizer)> {
89    let files = download_model_files(cache_dir, show_download_progress, model_variant)?;
90    let tokenizer = load_tokenizer(&files.tokenizer_path, max_seq_length)?;
91    let eps = execution_providers(cache_dir);
92    let session = load_session(&files.onnx_path, eps, intra_threads)?;
93    Ok((session, tokenizer))
94}