bge_m3_embedding_server/embedder/
model_files.rs1use std::path::{Path, PathBuf};
18
19use anyhow::Result;
20use tracing::info;
21
22use crate::config::ModelVariant;
23
/// HuggingFace Hub repository that `download_model_files` uses for the
/// `ModelVariant::Fp32` model.
const REPO_ID: &str = "BAAI/bge-m3";
/// Revision of [`REPO_ID`] passed to `hf_hub::Repo::with_revision`.
// NOTE(review): looks like a full commit SHA pinning the download — confirm.
const REPO_REVISION: &str = "5617a9f61b028005a4858fdac845db406aefb181";

/// HuggingFace Hub repository that `download_model_files` uses for the
/// `ModelVariant::Fp16` and `ModelVariant::Int8` models.
const XENOVA_REPO_ID: &str = "Xenova/bge-m3";
/// Revision of [`XENOVA_REPO_ID`] passed to `hf_hub::Repo::with_revision`.
// NOTE(review): looks like a full commit SHA pinning the download — confirm.
const XENOVA_REPO_REVISION: &str = "4de13258303883538bd53b696b452bf8099f0858";
34
35pub(super) struct ModelFiles {
37 pub onnx_path: PathBuf,
39 pub tokenizer_path: PathBuf,
41}
42
/// Reports whether `onnx_filename` is already present on disk for the given
/// repository revision.
///
/// The probed location is
/// `<cache_dir>/models--<repo_id with '/' replaced by "--">/snapshots/<revision>/<onnx_filename>`.
/// Returns `true` only when that path exists.
fn is_model_cached(cache_dir: &Path, repo_id: &str, revision: &str, onnx_filename: &str) -> bool {
    // Repository ids contain a '/', which is flattened to "--" in the folder name.
    let repo_folder = format!("models--{}", repo_id.replace('/', "--"));

    let mut snapshot_file: PathBuf = cache_dir.to_path_buf();
    snapshot_file.push(repo_folder);
    snapshot_file.push("snapshots");
    snapshot_file.push(revision);
    snapshot_file.push(onnx_filename);

    snapshot_file.exists()
}
62
63pub(super) fn download_model_files(
69 cache_dir: &Path,
70 show_progress: bool,
71 variant: ModelVariant,
72) -> Result<ModelFiles> {
73 let (repo_id, repo_revision) = match variant {
74 ModelVariant::Fp32 => (REPO_ID, REPO_REVISION),
75 ModelVariant::Fp16 | ModelVariant::Int8 => (XENOVA_REPO_ID, XENOVA_REPO_REVISION),
76 };
77
78 let onnx_filename = match variant {
82 ModelVariant::Fp32 => "onnx/model.onnx",
83 ModelVariant::Fp16 => "onnx/model_fp16.onnx",
84 ModelVariant::Int8 => "onnx/model_int8.onnx",
85 };
86 let cached = is_model_cached(cache_dir, repo_id, repo_revision, onnx_filename);
87 if cached {
88 info!(
89 repo_id,
90 revision = repo_revision,
91 model_variant = %variant,
92 "Model files found in local cache — no download needed"
93 );
94 } else {
95 info!(
96 repo_id,
97 revision = repo_revision,
98 model_variant = %variant,
99 "Model files not in local cache — downloading from HuggingFace Hub"
100 );
101 }
102
103 let api = hf_hub::api::sync::ApiBuilder::new()
104 .with_cache_dir(cache_dir.to_path_buf())
105 .with_progress(show_progress)
106 .build()
107 .map_err(|e| anyhow::anyhow!("Failed to build hf-hub API: {e}"))?;
108
109 let repo = api.repo(hf_hub::Repo::with_revision(
110 repo_id.to_string(),
111 hf_hub::RepoType::Model,
112 repo_revision.to_string(),
113 ));
114
115 let onnx_path = match variant {
116 ModelVariant::Fp32 => {
117 let path = repo
118 .get("onnx/model.onnx")
119 .map_err(|e| anyhow::anyhow!("Failed to get onnx/model.onnx: {e}"))?;
120 repo.get("onnx/model.onnx_data")
121 .map_err(|e| anyhow::anyhow!("Failed to get onnx/model.onnx_data: {e}"))?;
122 repo.get("onnx/Constant_7_attr__value")
123 .map_err(|e| anyhow::anyhow!("Failed to get onnx/Constant_7_attr__value: {e}"))?;
124 path
125 }
126 ModelVariant::Fp16 => repo
127 .get("onnx/model_fp16.onnx")
128 .map_err(|e| anyhow::anyhow!("Failed to get onnx/model_fp16.onnx: {e}"))?,
129 ModelVariant::Int8 => repo
130 .get("onnx/model_int8.onnx")
131 .map_err(|e| anyhow::anyhow!("Failed to get onnx/model_int8.onnx: {e}"))?,
132 };
133
134 let tokenizer_path = repo
135 .get("tokenizer.json")
136 .map_err(|e| anyhow::anyhow!("Failed to get tokenizer.json: {e}"))?;
137
138 Ok(ModelFiles {
139 onnx_path,
140 tokenizer_path,
141 })
142}