bge_m3_embedding_server/bootstrap/router.rs
1// Copyright (c) 2026 J. Patrick Fulton
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Axum router construction with the request-id, tracing, and body-limit
16//! layers attached.
17
18use std::sync::Arc;
19
20use axum::extract::DefaultBodyLimit;
21use axum::http::HeaderValue;
22use axum::{routing::get, routing::post, Router};
23use tower_http::request_id::{
24 MakeRequestId, PropagateRequestIdLayer, RequestId, SetRequestIdLayer,
25};
26use tower_http::trace::{DefaultOnFailure, DefaultOnResponse, MakeSpan, TraceLayer};
27use tracing::Level;
28
29use crate::handler;
30use crate::state::AppState;
31
32/// [`MakeRequestId`] implementation that assigns a random UUID v4 to every
33/// incoming request, attached as the `x-request-id` header.
34#[derive(Clone, Default)]
35pub(super) struct UuidRequestId;
36
37impl MakeRequestId for UuidRequestId {
38 fn make_request_id<B>(&mut self, _request: &axum::http::Request<B>) -> Option<RequestId> {
39 let id = uuid::Uuid::new_v4().to_string();
40 HeaderValue::from_str(&id).ok().map(RequestId::new)
41 }
42}
43
44/// Selects the tracing level for HTTP spans based on path.
45///
46/// `/health` and `/v1/models` are polled frequently by load balancers and the
47/// Docker `HEALTHCHECK`. Logging them at DEBUG rather than INFO keeps
48/// `CloudWatch` free of ~8,640 health-check records per container per day.
49#[derive(Clone)]
50pub(super) struct RouteAwareSpan;
51
52impl<B> MakeSpan<B> for RouteAwareSpan {
53 fn make_span(&mut self, request: &axum::http::Request<B>) -> tracing::Span {
54 let path = request.uri().path();
55 let is_noisy = matches!(path, "/health" | "/v1/models");
56 let method = request.method().as_str();
57 if is_noisy {
58 tracing::debug_span!(
59 "http_request",
60 method = method,
61 uri = %request.uri(),
62 version = ?request.version(),
63 )
64 } else {
65 tracing::info_span!(
66 "http_request",
67 method = method,
68 uri = %request.uri(),
69 version = ?request.version(),
70 )
71 }
72 }
73}
74
75/// Builds the Axum [`Router`] with all embedding, health, and fleet-discovery
76/// routes, a 2 MiB body limit, request-id propagation, and structured tracing.
77pub fn build_router(state: Arc<AppState>) -> Router {
78 Router::new()
79 .route("/v1/embeddings", post(handler::dense_embeddings))
80 .route("/v1/sparse-embeddings", post(handler::sparse_embeddings))
81 // The colon in `/v1/embeddings:both` is a valid `pchar` per RFC 3986
82 // ยง3.3, but some HTTP clients (and URI builders) percent-encode it
83 // anyway when it appears in a path segment. The router is built on
84 // `matchit`, which matches the raw URI path byte-for-byte, so the
85 // encoded forms are registered as alias routes pointing at the same
86 // handler. RFC 3986 percent-encoding is case-insensitive, hence both
87 // upper- and lowercase aliases.
88 .route("/v1/embeddings:both", post(handler::both_embeddings))
89 .route("/v1/embeddings%3Aboth", post(handler::both_embeddings))
90 .route("/v1/embeddings%3aboth", post(handler::both_embeddings))
91 .route("/v1/models", get(handler::models))
92 .route("/health", get(handler::health))
93 .layer(DefaultBodyLimit::max(2_097_152))
94 .layer(PropagateRequestIdLayer::x_request_id())
95 .layer(
96 TraceLayer::new_for_http()
97 .make_span_with(RouteAwareSpan)
98 .on_response(
99 DefaultOnResponse::new()
100 .level(Level::INFO)
101 .latency_unit(tower_http::LatencyUnit::Millis),
102 )
103 .on_failure(DefaultOnFailure::new().level(Level::ERROR)),
104 )
105 .layer(SetRequestIdLayer::x_request_id(UuidRequestId))
106 .with_state(state)
107}