bge_m3_embedding_server/embedder/
session.rs1use std::path::Path;
18
19use anyhow::Result;
20
21use super::error::ort_err;
22use super::model_files::download_model_files;
23use super::tokenize::load_tokenizer;
24use crate::config::ModelVariant;
25
26pub(super) fn execution_providers(cache_dir: &Path) -> Vec<ort::ep::ExecutionProviderDispatch> {
32 #[cfg(target_os = "macos")]
33 {
34 let coreml_cache = cache_dir.join("coreml");
35 let strategy = match std::env::var("BGE_M3_COREML_STRATEGY").ok().as_deref() {
36 Some("default") => ort::ep::coreml::SpecializationStrategy::Default,
37 _ => ort::ep::coreml::SpecializationStrategy::FastPrediction,
38 };
39 let builder = ort::ep::CoreML::default()
40 .with_model_format(ort::ep::coreml::ModelFormat::MLProgram)
41 .with_specialization_strategy(strategy)
42 .with_model_cache_dir(coreml_cache.display().to_string());
43 #[cfg(feature = "coreml-profile")]
44 let builder = builder.with_profile_compute_plan(true);
45 vec![builder.build()]
46 }
47 #[cfg(not(target_os = "macos"))]
48 {
49 let _ = cache_dir;
50 vec![]
51 }
52}
53
54pub(super) fn load_session(
62 model_path: &Path,
63 eps: Vec<ort::ep::ExecutionProviderDispatch>,
64 intra_threads: usize,
65) -> Result<ort::session::Session> {
66 let mut builder = ort::session::Session::builder().map_err(ort_err)?;
67 if !eps.is_empty() {
68 builder = builder.with_execution_providers(eps).map_err(ort_err)?;
69 }
70 let session = builder
71 .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)
72 .map_err(ort_err)?
73 .with_intra_threads(intra_threads.max(1))
74 .map_err(ort_err)?
75 .commit_from_file(model_path)
76 .map_err(ort_err)?;
77 Ok(session)
78}
79
80pub(super) fn load_models(
83 cache_dir: &Path,
84 show_download_progress: bool,
85 model_variant: ModelVariant,
86 max_seq_length: usize,
87 intra_threads: usize,
88) -> Result<(ort::session::Session, tokenizers::Tokenizer)> {
89 let files = download_model_files(cache_dir, show_download_progress, model_variant)?;
90 let tokenizer = load_tokenizer(&files.tokenizer_path, max_seq_length)?;
91 let eps = execution_providers(cache_dir);
92 let session = load_session(&files.onnx_path, eps, intra_threads)?;
93 Ok((session, tokenizer))
94}