//! Routing between [`Service`]s. use self::future::{BoxRouteFuture, EmptyRouterFuture, RouteFuture}; use crate::{ body::{box_body, BoxBody}, buffer::MpscBuffer, extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo}, response::IntoResponse, service::HandleErrorFromRouter, util::ByteStr, }; use async_trait::async_trait; use bytes::Bytes; use http::{Method, Request, Response, StatusCode, Uri}; use regex::Regex; use std::{ borrow::Cow, convert::Infallible, fmt, marker::PhantomData, sync::Arc, task::{Context, Poll}, }; use tower::{ util::{BoxService, ServiceExt}, BoxError, Layer, Service, ServiceBuilder, }; use tower_http::map_response_body::MapResponseBodyLayer; pub mod future; /// A filter that matches one or more HTTP methods. #[derive(Debug, Copy, Clone)] pub enum MethodFilter { /// Match any method. Any, /// Match `CONNECT` requests. Connect, /// Match `DELETE` requests. Delete, /// Match `GET` requests. Get, /// Match `HEAD` requests. Head, /// Match `OPTIONS` requests. Options, /// Match `PATCH` requests. Patch, /// Match `POST` requests. Post, /// Match `PUT` requests. Put, /// Match `TRACE` requests. Trace, } impl MethodFilter { #[allow(clippy::match_like_matches_macro)] pub(crate) fn matches(self, method: &Method) -> bool { match (self, method) { (MethodFilter::Any, _) | (MethodFilter::Connect, &Method::CONNECT) | (MethodFilter::Delete, &Method::DELETE) | (MethodFilter::Get, &Method::GET) | (MethodFilter::Head, &Method::HEAD) | (MethodFilter::Options, &Method::OPTIONS) | (MethodFilter::Patch, &Method::PATCH) | (MethodFilter::Post, &Method::POST) | (MethodFilter::Put, &Method::PUT) | (MethodFilter::Trace, &Method::TRACE) => true, _ => false, } } } /// A route that sends requests to one of two [`Service`]s depending on the /// path. /// /// Created with [`route`](crate::route). See that function for more details. #[derive(Debug, Clone)] pub struct Route { pub(crate) pattern: PathPattern, pub(crate) svc: S, pub(crate) fallback: F, } /// Trait for building routers. #[async_trait] pub trait RoutingDsl: crate::sealed::Sealed + Sized { /// Add another route to the router. /// /// # Example /// /// ```rust /// use axum::prelude::*; /// /// async fn first_handler() { /* ... */ } /// /// async fn second_handler() { /* ... */ } /// /// async fn third_handler() { /* ... */ } /// /// // `GET /` goes to `first_handler`, `POST /` goes to `second_handler`, /// // and `GET /foo` goes to third_handler. /// let app = route("/", get(first_handler).post(second_handler)) /// .route("/foo", get(third_handler)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` fn route(self, description: &str, svc: T) -> Route where T: Service> + Clone, { Route { pattern: PathPattern::new(description), svc, fallback: self, } } /// Nest another service inside this router at the given path. /// /// See [`nest`] for more details. fn nest(self, description: &str, svc: T) -> Nested where T: Service> + Clone, { Nested { pattern: PathPattern::new(description), svc, fallback: self, } } /// Create a boxed route trait object. /// /// This makes it easier to name the types of routers to, for example, /// return them from functions: /// /// ```rust /// use axum::{routing::BoxRoute, body::Body, prelude::*}; /// /// async fn first_handler() { /* ... */ } /// /// async fn second_handler() { /* ... */ } /// /// async fn third_handler() { /* ... */ } /// /// fn app() -> BoxRoute { /// route("/", get(first_handler).post(second_handler)) /// .route("/foo", get(third_handler)) /// .boxed() /// } /// ``` /// /// It also helps with compile times when you have a very large number of /// routes. fn boxed(self) -> BoxRoute where Self: Service, Response = Response> + Send + 'static, >>::Error: Into + Send + Sync, >>::Future: Send, ReqBody: http_body::Body + Send + Sync + 'static, ReqBody::Error: Into + Send + Sync + 'static, ResBody: http_body::Body + Send + Sync + 'static, ResBody::Error: Into + Send + Sync + 'static, { ServiceBuilder::new() .layer_fn(BoxRoute) .layer_fn(MpscBuffer::new) .layer(BoxService::layer()) .layer(MapResponseBodyLayer::new(box_body)) .service(self) } /// Apply a [`tower::Layer`] to the router. /// /// All requests to the router will be processed by the layer's /// corresponding middleware. /// /// This can be used to add additional processing to a request for a group /// of routes. /// /// Note this differs from [`handler::Layered`](crate::handler::Layered) /// which adds a middleware to a single handler. /// /// # Example /// /// Adding the [`tower::limit::ConcurrencyLimit`] middleware to a group of /// routes can be done like so: /// /// ```rust /// use axum::prelude::*; /// use tower::limit::{ConcurrencyLimitLayer, ConcurrencyLimit}; /// /// async fn first_handler() { /* ... */ } /// /// async fn second_handler() { /* ... */ } /// /// async fn third_handler() { /* ... */ } /// /// // All requests to `handler` and `other_handler` will be sent through /// // `ConcurrencyLimit` /// let app = route("/", get(first_handler)) /// .route("/foo", get(second_handler)) /// .layer(ConcurrencyLimitLayer::new(64)) /// // Request to `GET /bar` will go directly to `third_handler` and /// // wont be sent through `ConcurrencyLimit` /// .route("/bar", get(third_handler)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// This is commonly used to add middleware such as tracing/logging to your /// entire app: /// /// ```rust /// use axum::prelude::*; /// use tower_http::trace::TraceLayer; /// /// async fn first_handler() { /* ... */ } /// /// async fn second_handler() { /* ... */ } /// /// async fn third_handler() { /* ... */ } /// /// let app = route("/", get(first_handler)) /// .route("/foo", get(second_handler)) /// .route("/bar", get(third_handler)) /// .layer(TraceLayer::new_for_http()); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` fn layer(self, layer: L) -> Layered where L: Layer, { Layered::new(layer.layer(self)) } /// Convert this router into a [`MakeService`], that is a [`Service`] who's /// response is another service. /// /// This is useful when running your application with hyper's /// [`Server`](hyper::server::Server): /// /// ``` /// use axum::prelude::*; /// /// let app = route("/", get(|| async { "Hi!" })); /// /// # async { /// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) /// .serve(app.into_make_service()) /// .await /// .expect("server failed"); /// # }; /// ``` /// /// [`MakeService`]: tower::make::MakeService fn into_make_service(self) -> tower::make::Shared where Self: Clone, { tower::make::Shared::new(self) } /// Convert this router into a [`MakeService`], that will store `C`'s /// associated `ConnectInfo` in a request extension such that [`ConnectInfo`] /// can extract it. /// /// This enables extracting things like the client's remote address. /// /// Extracting [`std::net::SocketAddr`] is supported out of the box: /// /// ``` /// use axum::{prelude::*, extract::ConnectInfo}; /// use std::net::SocketAddr; /// /// let app = route("/", get(handler)); /// /// async fn handler(ConnectInfo(addr): ConnectInfo) -> String { /// format!("Hello {}", addr) /// } /// /// # async { /// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) /// .serve( /// app.into_make_service_with_connect_info::() /// ) /// .await /// .expect("server failed"); /// # }; /// ``` /// /// You can implement custom a [`Connected`] like so: /// /// ``` /// use axum::{ /// prelude::*, /// extract::connect_info::{ConnectInfo, Connected}, /// }; /// use hyper::server::conn::AddrStream; /// /// let app = route("/", get(handler)); /// /// async fn handler( /// ConnectInfo(my_connect_info): ConnectInfo, /// ) -> String { /// format!("Hello {:?}", my_connect_info) /// } /// /// #[derive(Clone, Debug)] /// struct MyConnectInfo { /// // ... /// } /// /// impl Connected<&AddrStream> for MyConnectInfo { /// type ConnectInfo = MyConnectInfo; /// /// fn connect_info(target: &AddrStream) -> Self::ConnectInfo { /// MyConnectInfo { /// // ... /// } /// } /// } /// /// # async { /// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) /// .serve( /// app.into_make_service_with_connect_info::() /// ) /// .await /// .expect("server failed"); /// # }; /// ``` /// /// See the [unix domain socket example][uds] for an example of how to use /// this to collect UDS connection info. /// /// [`MakeService`]: tower::make::MakeService /// [`Connected`]: crate::extract::connect_info::Connected /// [`ConnectInfo`]: crate::extract::connect_info::ConnectInfo /// [uds]: https://github.com/tokio-rs/axum/blob/main/examples/unix_domain_socket.rs fn into_make_service_with_connect_info( self, ) -> IntoMakeServiceWithConnectInfo where Self: Clone, C: Connected, { IntoMakeServiceWithConnectInfo::new(self) } } impl RoutingDsl for Route {} impl crate::sealed::Sealed for Route {} impl Service> for Route where S: Service, Response = Response> + Clone, F: Service, Response = Response, Error = S::Error> + Clone, { type Response = Response; type Error = S::Error; type Future = RouteFuture; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, mut req: Request) -> Self::Future { if let Some(captures) = self.pattern.full_match(req.uri().path()) { insert_url_params(&mut req, captures); let fut = self.svc.clone().oneshot(req); RouteFuture::a(fut) } else { let fut = self.fallback.clone().oneshot(req); RouteFuture::b(fut) } } } #[derive(Debug)] pub(crate) struct UrlParams(pub(crate) Vec<(ByteStr, ByteStr)>); fn insert_url_params(req: &mut Request, params: Vec<(String, String)>) { let params = params .into_iter() .map(|(k, v)| (ByteStr::new(k), ByteStr::new(v))); if let Some(current) = req.extensions_mut().get_mut::>() { let mut current = current.take().unwrap(); current.0.extend(params); req.extensions_mut().insert(Some(current)); } else { req.extensions_mut() .insert(Some(UrlParams(params.collect()))); } } /// A [`Service`] that responds with `404 Not Found` or `405 Method not allowed` /// to all requests. /// /// This is used as the bottom service in a router stack. You shouldn't have to /// use to manually. pub struct EmptyRouter { status: StatusCode, _marker: PhantomData E>, } impl EmptyRouter { pub(crate) fn not_found() -> Self { Self { status: StatusCode::NOT_FOUND, _marker: PhantomData, } } pub(crate) fn method_not_allowed() -> Self { Self { status: StatusCode::METHOD_NOT_ALLOWED, _marker: PhantomData, } } } impl Clone for EmptyRouter { fn clone(&self) -> Self { Self { status: self.status, _marker: PhantomData, } } } impl fmt::Debug for EmptyRouter { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_tuple("EmptyRouter").finish() } } impl RoutingDsl for EmptyRouter {} impl crate::sealed::Sealed for EmptyRouter {} impl Service> for EmptyRouter { type Response = Response; type Error = E; type Future = EmptyRouterFuture; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, _req: Request) -> Self::Future { let mut res = Response::new(crate::body::empty()); *res.status_mut() = self.status; EmptyRouterFuture { future: futures_util::future::ok(res), } } } #[derive(Debug, Clone)] pub(crate) struct PathPattern(Arc); #[derive(Debug)] struct Inner { full_path_regex: Regex, capture_group_names: Box<[Bytes]>, } impl PathPattern { pub(crate) fn new(pattern: &str) -> Self { assert!( pattern.starts_with('/'), "Route description must start with a `/`" ); let mut capture_group_names = Vec::new(); let pattern = pattern .split('/') .map(|part| { if let Some(key) = part.strip_prefix(':') { capture_group_names.push(Bytes::copy_from_slice(key.as_bytes())); Cow::Owned(format!("(?P<{}>[^/]*)", key)) } else { Cow::Borrowed(part) } }) .collect::>() .join("/"); let full_path_regex = Regex::new(&format!("^{}", pattern)).expect("invalid regex generated from route"); Self(Arc::new(Inner { full_path_regex, capture_group_names: capture_group_names.into(), })) } pub(crate) fn full_match(&self, path: &str) -> Option { self.do_match(path).and_then(|match_| { if match_.full_match { Some(match_.captures) } else { None } }) } pub(crate) fn prefix_match<'a>(&self, path: &'a str) -> Option<(&'a str, Captures)> { self.do_match(path) .map(|match_| (match_.matched, match_.captures)) } fn do_match<'a>(&self, path: &'a str) -> Option> { self.0.full_path_regex.captures(path).map(|captures| { let matched = captures.get(0).unwrap(); let full_match = matched.as_str() == path; let captures = self .0 .capture_group_names .iter() .map(|bytes| { std::str::from_utf8(bytes) .expect("bytes were created from str so is valid utf-8") }) .filter_map(|name| captures.name(name).map(|value| (name, value.as_str()))) .map(|(key, value)| (key.to_string(), value.to_string())) .collect::>(); Match { captures, full_match, matched: matched.as_str(), } }) } } struct Match<'a> { captures: Captures, // true if regex matched whole path, false if it only matched a prefix full_match: bool, matched: &'a str, } type Captures = Vec<(String, String)>; /// A boxed route trait object. /// /// See [`RoutingDsl::boxed`] for more details. pub struct BoxRoute( MpscBuffer, Response, E>, Request>, ); impl Clone for BoxRoute { fn clone(&self) -> Self { Self(self.0.clone()) } } impl fmt::Debug for BoxRoute { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("BoxRoute").finish() } } impl RoutingDsl for BoxRoute {} impl crate::sealed::Sealed for BoxRoute {} impl Service> for BoxRoute where E: Into, { type Response = Response; type Error = E; type Future = BoxRouteFuture; #[inline] fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } #[inline] fn call(&mut self, req: Request) -> Self::Future { BoxRouteFuture { inner: self.0.clone().oneshot(req), } } } /// A [`Service`] created from a router by applying a Tower middleware. /// /// Created with [`RoutingDsl::layer`]. See that method for more details. pub struct Layered { inner: S, } impl Layered { fn new(inner: S) -> Self { Self { inner } } } impl Clone for Layered where S: Clone, { fn clone(&self) -> Self { Self::new(self.inner.clone()) } } impl fmt::Debug for Layered where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Layered") .field("inner", &self.inner) .finish() } } impl RoutingDsl for Layered {} impl crate::sealed::Sealed for Layered {} impl Layered { /// Create a new [`Layered`] service where errors will be handled using the /// given closure. /// /// This is used to convert errors to responses rather than simply /// terminating the connection. /// /// That can be done using `handle_error` like so: /// /// ```rust /// use axum::prelude::*; /// use http::StatusCode; /// use tower::{BoxError, timeout::TimeoutLayer}; /// use std::{convert::Infallible, time::Duration}; /// /// async fn handler() { /* ... */ } /// /// // `Timeout` will fail with `BoxError` if the timeout elapses... /// let layered_app = route("/", get(handler)) /// .layer(TimeoutLayer::new(Duration::from_secs(30))); /// /// // ...so we should handle that error /// let with_errors_handled = layered_app.handle_error(|error: BoxError| { /// if error.is::() { /// Ok::<_, Infallible>(( /// StatusCode::REQUEST_TIMEOUT, /// "request took too long".to_string(), /// )) /// } else { /// Ok::<_, Infallible>(( /// StatusCode::INTERNAL_SERVER_ERROR, /// format!("Unhandled internal error: {}", error), /// )) /// } /// }); /// # async { /// # axum::Server::bind(&"".parse().unwrap()) /// # .serve(with_errors_handled.into_make_service()) /// # .await /// # .unwrap(); /// # }; /// ``` /// /// The closure must return `Result` where `T` implements [`IntoResponse`]. /// /// You can also return `Err(_)` if you don't wish to handle the error: /// /// ```rust /// use axum::prelude::*; /// use http::StatusCode; /// use tower::{BoxError, timeout::TimeoutLayer}; /// use std::time::Duration; /// /// async fn handler() { /* ... */ } /// /// let layered_app = route("/", get(handler)) /// .layer(TimeoutLayer::new(Duration::from_secs(30))); /// /// let with_errors_handled = layered_app.handle_error(|error: BoxError| { /// if error.is::() { /// Ok(( /// StatusCode::REQUEST_TIMEOUT, /// "request took too long".to_string(), /// )) /// } else { /// // keep the error as is /// Err(error) /// } /// }); /// # async { /// # axum::Server::bind(&"".parse().unwrap()) /// # .serve(with_errors_handled.into_make_service()) /// # .await /// # .unwrap(); /// # }; /// ``` pub fn handle_error( self, f: F, ) -> crate::service::HandleError where S: Service, Response = Response> + Clone, F: FnOnce(S::Error) -> Result, Res: IntoResponse, ResBody: http_body::Body + Send + Sync + 'static, ResBody::Error: Into + Send + Sync + 'static, { crate::service::HandleError::new(self.inner, f) } } impl Service for Layered where S: Service, { type Response = S::Response; type Error = S::Error; type Future = S::Future; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } #[inline] fn call(&mut self, req: R) -> Self::Future { self.inner.call(req) } } /// Nest a group of routes (or a [`Service`]) at some path. /// /// This allows you to break your application into smaller pieces and compose /// them together. This will strip the matching prefix from the URL so the /// nested route will only see the part of URL: /// /// ``` /// use axum::{routing::nest, prelude::*}; /// use http::Uri; /// /// async fn users_get(uri: Uri) { /// // `users_get` doesn't see the whole URL. `nest` will strip the matching /// // `/api` prefix. /// assert_eq!(uri.path(), "/users"); /// } /// /// async fn users_post() {} /// /// async fn careers() {} /// /// let users_api = route("/users", get(users_get).post(users_post)); /// /// let app = nest("/api", users_api).route("/careers", get(careers)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// Take care when using `nest` together with dynamic routes as nesting also /// captures from the outer routes: /// /// ``` /// use axum::{routing::nest, prelude::*}; /// /// async fn users_get(params: extract::UrlParamsMap) { /// // Both `version` and `id` were captured even though `users_api` only /// // explicitly captures `id`. /// let version = params.get("version"); /// let id = params.get("id"); /// } /// /// let users_api = route("/users/:id", get(users_get)); /// /// let app = nest("/:version/api", users_api); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// `nest` also accepts any [`Service`]. This can for example be used with /// [`tower_http::services::ServeDir`] to serve static files from a directory: /// /// ``` /// use axum::{ /// routing::nest, service::{get, ServiceExt}, prelude::*, /// }; /// use tower_http::services::ServeDir; /// /// // Serves files inside the `public` directory at `GET /public/*` /// let serve_dir_service = ServeDir::new("public"); /// /// let app = nest("/public", get(serve_dir_service)); /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// If necessary you can use [`RoutingDsl::boxed`] to box a group of routes /// making the type easier to name. This is sometimes useful when working with /// `nest`. pub fn nest(description: &str, svc: S) -> Nested> where S: Service> + Clone, { Nested { pattern: PathPattern::new(description), svc, fallback: EmptyRouter::not_found(), } } /// A [`Service`] that has been nested inside a router at some path. /// /// Created with [`nest`] or [`RoutingDsl::nest`]. #[derive(Debug, Clone)] pub struct Nested { pattern: PathPattern, svc: S, fallback: F, } impl RoutingDsl for Nested {} impl crate::sealed::Sealed for Nested {} impl Service> for Nested where S: Service, Response = Response> + Clone, F: Service, Response = Response, Error = S::Error> + Clone, { type Response = Response; type Error = S::Error; type Future = RouteFuture; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, mut req: Request) -> Self::Future { if let Some((prefix, captures)) = self.pattern.prefix_match(req.uri().path()) { let without_prefix = strip_prefix(req.uri(), prefix); *req.uri_mut() = without_prefix; insert_url_params(&mut req, captures); let fut = self.svc.clone().oneshot(req); RouteFuture::a(fut) } else { let fut = self.fallback.clone().oneshot(req); RouteFuture::b(fut) } } } fn strip_prefix(uri: &Uri, prefix: &str) -> Uri { let path_and_query = if let Some(path_and_query) = uri.path_and_query() { let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) { path } else { path_and_query.path() }; let new_path = if new_path.starts_with('/') { Cow::Borrowed(new_path) } else { Cow::Owned(format!("/{}", new_path)) }; if let Some(query) = path_and_query.query() { Some( format!("{}?{}", new_path, query) .parse::() .unwrap(), ) } else { Some(new_path.parse().unwrap()) } } else { None }; let mut parts = http::uri::Parts::default(); parts.scheme = uri.scheme().cloned(); parts.authority = uri.authority().cloned(); parts.path_and_query = path_and_query; Uri::from_parts(parts).unwrap() } #[cfg(test)] mod tests { use super::*; #[test] fn test_routing() { assert_match("/", "/"); assert_match("/foo", "/foo"); assert_match("/foo/", "/foo/"); refute_match("/foo", "/foo/"); refute_match("/foo/", "/foo"); assert_match("/foo/bar", "/foo/bar"); refute_match("/foo/bar/", "/foo/bar"); refute_match("/foo/bar", "/foo/bar/"); assert_match("/:value", "/foo"); assert_match("/users/:id", "/users/1"); assert_match("/users/:id/action", "/users/42/action"); refute_match("/users/:id/action", "/users/42"); refute_match("/users/:id", "/users/42/action"); } fn assert_match(route_spec: &'static str, path: &'static str) { let route = PathPattern::new(route_spec); assert!( route.full_match(path).is_some(), "`{}` doesn't match `{}`", path, route_spec ); } fn refute_match(route_spec: &'static str, path: &'static str) { let route = PathPattern::new(route_spec); assert!( route.full_match(path).is_none(), "`{}` did match `{}` (but shouldn't)", path, route_spec ); } }