mirror of
https://github.com/tokio-rs/axum.git
synced 2025-10-02 15:24:54 +00:00

As described in https://github.com/tokio-rs/axum/pull/108#issuecomment-892811637, a `HandleError` created from `axum::ServiceExt::handle_error` should _not_ implement `RoutingDsl` as that leads to confusing routing behavior. The technique used here of adding another type parameter to `HandleError` isn't very clean, I think. But the alternative is duplicating `HandleError` and having two versions, which I think is less desirable.
953 lines
28 KiB
Rust
953 lines
28 KiB
Rust
//! Routing between [`Service`]s.
|
|
|
|
use self::future::{BoxRouteFuture, EmptyRouterFuture, RouteFuture};
|
|
use crate::{
|
|
body::{box_body, BoxBody},
|
|
buffer::MpscBuffer,
|
|
extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo},
|
|
response::IntoResponse,
|
|
service::HandleErrorFromRouter,
|
|
util::ByteStr,
|
|
};
|
|
use async_trait::async_trait;
|
|
use bytes::Bytes;
|
|
use http::{Method, Request, Response, StatusCode, Uri};
|
|
use regex::Regex;
|
|
use std::{
|
|
borrow::Cow,
|
|
convert::Infallible,
|
|
fmt,
|
|
marker::PhantomData,
|
|
sync::Arc,
|
|
task::{Context, Poll},
|
|
};
|
|
use tower::{
|
|
util::{BoxService, ServiceExt},
|
|
BoxError, Layer, Service, ServiceBuilder,
|
|
};
|
|
use tower_http::map_response_body::MapResponseBodyLayer;
|
|
|
|
/// Response futures (such as `RouteFuture`, `EmptyRouterFuture` and
/// `BoxRouteFuture`) returned by the routing services in this module.
pub mod future;
|
|
|
|
/// Selects which HTTP request methods a route accepts.
///
/// [`MethodFilter::Any`] accepts every method; each other variant accepts
/// exactly one method.
#[derive(Clone, Copy, Debug)]
pub enum MethodFilter {
    /// Accept requests with any method.
    Any,
    /// Accept only `CONNECT` requests.
    Connect,
    /// Accept only `DELETE` requests.
    Delete,
    /// Accept only `GET` requests.
    Get,
    /// Accept only `HEAD` requests.
    Head,
    /// Accept only `OPTIONS` requests.
    Options,
    /// Accept only `PATCH` requests.
    Patch,
    /// Accept only `POST` requests.
    Post,
    /// Accept only `PUT` requests.
    Put,
    /// Accept only `TRACE` requests.
    Trace,
}
|
|
|
|
impl MethodFilter {
|
|
#[allow(clippy::match_like_matches_macro)]
|
|
pub(crate) fn matches(self, method: &Method) -> bool {
|
|
match (self, method) {
|
|
(MethodFilter::Any, _)
|
|
| (MethodFilter::Connect, &Method::CONNECT)
|
|
| (MethodFilter::Delete, &Method::DELETE)
|
|
| (MethodFilter::Get, &Method::GET)
|
|
| (MethodFilter::Head, &Method::HEAD)
|
|
| (MethodFilter::Options, &Method::OPTIONS)
|
|
| (MethodFilter::Patch, &Method::PATCH)
|
|
| (MethodFilter::Post, &Method::POST)
|
|
| (MethodFilter::Put, &Method::PUT)
|
|
| (MethodFilter::Trace, &Method::TRACE) => true,
|
|
_ => false,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A route that sends requests to one of two [`Service`]s depending on the
|
|
/// path.
|
|
///
|
|
/// Created with [`route`](crate::route). See that function for more details.
|
|
#[derive(Debug, Clone)]
|
|
pub struct Route<S, F> {
|
|
pub(crate) pattern: PathPattern,
|
|
pub(crate) svc: S,
|
|
pub(crate) fallback: F,
|
|
}
|
|
|
|
/// Trait for building routers.
|
|
#[async_trait]
|
|
pub trait RoutingDsl: crate::sealed::Sealed + Sized {
|
|
/// Add another route to the router.
|
|
///
|
|
/// # Example
|
|
///
|
|
/// ```rust
|
|
/// use axum::prelude::*;
|
|
///
|
|
/// async fn first_handler() { /* ... */ }
|
|
///
|
|
/// async fn second_handler() { /* ... */ }
|
|
///
|
|
/// async fn third_handler() { /* ... */ }
|
|
///
|
|
/// // `GET /` goes to `first_handler`, `POST /` goes to `second_handler`,
|
|
/// // and `GET /foo` goes to third_handler.
|
|
/// let app = route("/", get(first_handler).post(second_handler))
|
|
/// .route("/foo", get(third_handler));
|
|
/// # async {
|
|
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
|
/// # };
|
|
/// ```
|
|
fn route<T, B>(self, description: &str, svc: T) -> Route<T, Self>
|
|
where
|
|
T: Service<Request<B>> + Clone,
|
|
{
|
|
Route {
|
|
pattern: PathPattern::new(description),
|
|
svc,
|
|
fallback: self,
|
|
}
|
|
}
|
|
|
|
/// Nest another service inside this router at the given path.
|
|
///
|
|
/// See [`nest`] for more details.
|
|
fn nest<T, B>(self, description: &str, svc: T) -> Nested<T, Self>
|
|
where
|
|
T: Service<Request<B>> + Clone,
|
|
{
|
|
Nested {
|
|
pattern: PathPattern::new(description),
|
|
svc,
|
|
fallback: self,
|
|
}
|
|
}
|
|
|
|
/// Create a boxed route trait object.
|
|
///
|
|
/// This makes it easier to name the types of routers to, for example,
|
|
/// return them from functions:
|
|
///
|
|
/// ```rust
|
|
/// use axum::{routing::BoxRoute, body::Body, prelude::*};
|
|
///
|
|
/// async fn first_handler() { /* ... */ }
|
|
///
|
|
/// async fn second_handler() { /* ... */ }
|
|
///
|
|
/// async fn third_handler() { /* ... */ }
|
|
///
|
|
/// fn app() -> BoxRoute<Body> {
|
|
/// route("/", get(first_handler).post(second_handler))
|
|
/// .route("/foo", get(third_handler))
|
|
/// .boxed()
|
|
/// }
|
|
/// ```
|
|
///
|
|
/// It also helps with compile times when you have a very large number of
|
|
/// routes.
|
|
fn boxed<ReqBody, ResBody>(self) -> BoxRoute<ReqBody, Self::Error>
|
|
where
|
|
Self: Service<Request<ReqBody>, Response = Response<ResBody>> + Send + 'static,
|
|
<Self as Service<Request<ReqBody>>>::Error: Into<BoxError> + Send + Sync,
|
|
<Self as Service<Request<ReqBody>>>::Future: Send,
|
|
ReqBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
|
|
ReqBody::Error: Into<BoxError> + Send + Sync + 'static,
|
|
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
|
|
ResBody::Error: Into<BoxError> + Send + Sync + 'static,
|
|
{
|
|
ServiceBuilder::new()
|
|
.layer_fn(BoxRoute)
|
|
.layer_fn(MpscBuffer::new)
|
|
.layer(BoxService::layer())
|
|
.layer(MapResponseBodyLayer::new(box_body))
|
|
.service(self)
|
|
}
|
|
|
|
/// Apply a [`tower::Layer`] to the router.
|
|
///
|
|
/// All requests to the router will be processed by the layer's
|
|
/// corresponding middleware.
|
|
///
|
|
/// This can be used to add additional processing to a request for a group
|
|
/// of routes.
|
|
///
|
|
/// Note this differs from [`handler::Layered`](crate::handler::Layered)
|
|
/// which adds a middleware to a single handler.
|
|
///
|
|
/// # Example
|
|
///
|
|
/// Adding the [`tower::limit::ConcurrencyLimit`] middleware to a group of
|
|
/// routes can be done like so:
|
|
///
|
|
/// ```rust
|
|
/// use axum::prelude::*;
|
|
/// use tower::limit::{ConcurrencyLimitLayer, ConcurrencyLimit};
|
|
///
|
|
/// async fn first_handler() { /* ... */ }
|
|
///
|
|
/// async fn second_handler() { /* ... */ }
|
|
///
|
|
/// async fn third_handler() { /* ... */ }
|
|
///
|
|
/// // All requests to `handler` and `other_handler` will be sent through
|
|
/// // `ConcurrencyLimit`
|
|
/// let app = route("/", get(first_handler))
|
|
/// .route("/foo", get(second_handler))
|
|
/// .layer(ConcurrencyLimitLayer::new(64))
|
|
/// // Request to `GET /bar` will go directly to `third_handler` and
|
|
/// // wont be sent through `ConcurrencyLimit`
|
|
/// .route("/bar", get(third_handler));
|
|
/// # async {
|
|
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
|
/// # };
|
|
/// ```
|
|
///
|
|
/// This is commonly used to add middleware such as tracing/logging to your
|
|
/// entire app:
|
|
///
|
|
/// ```rust
|
|
/// use axum::prelude::*;
|
|
/// use tower_http::trace::TraceLayer;
|
|
///
|
|
/// async fn first_handler() { /* ... */ }
|
|
///
|
|
/// async fn second_handler() { /* ... */ }
|
|
///
|
|
/// async fn third_handler() { /* ... */ }
|
|
///
|
|
/// let app = route("/", get(first_handler))
|
|
/// .route("/foo", get(second_handler))
|
|
/// .route("/bar", get(third_handler))
|
|
/// .layer(TraceLayer::new_for_http());
|
|
/// # async {
|
|
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
|
/// # };
|
|
/// ```
|
|
fn layer<L>(self, layer: L) -> Layered<L::Service>
|
|
where
|
|
L: Layer<Self>,
|
|
{
|
|
Layered::new(layer.layer(self))
|
|
}
|
|
|
|
/// Convert this router into a [`MakeService`], that is a [`Service`] who's
|
|
/// response is another service.
|
|
///
|
|
/// This is useful when running your application with hyper's
|
|
/// [`Server`](hyper::server::Server):
|
|
///
|
|
/// ```
|
|
/// use axum::prelude::*;
|
|
///
|
|
/// let app = route("/", get(|| async { "Hi!" }));
|
|
///
|
|
/// # async {
|
|
/// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
|
|
/// .serve(app.into_make_service())
|
|
/// .await
|
|
/// .expect("server failed");
|
|
/// # };
|
|
/// ```
|
|
///
|
|
/// [`MakeService`]: tower::make::MakeService
|
|
fn into_make_service(self) -> tower::make::Shared<Self>
|
|
where
|
|
Self: Clone,
|
|
{
|
|
tower::make::Shared::new(self)
|
|
}
|
|
|
|
/// Convert this router into a [`MakeService`], that will store `C`'s
|
|
/// associated `ConnectInfo` in a request extension such that [`ConnectInfo`]
|
|
/// can extract it.
|
|
///
|
|
/// This enables extracting things like the client's remote address.
|
|
///
|
|
/// Extracting [`std::net::SocketAddr`] is supported out of the box:
|
|
///
|
|
/// ```
|
|
/// use axum::{prelude::*, extract::ConnectInfo};
|
|
/// use std::net::SocketAddr;
|
|
///
|
|
/// let app = route("/", get(handler));
|
|
///
|
|
/// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
|
|
/// format!("Hello {}", addr)
|
|
/// }
|
|
///
|
|
/// # async {
|
|
/// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
|
|
/// .serve(
|
|
/// app.into_make_service_with_connect_info::<SocketAddr, _>()
|
|
/// )
|
|
/// .await
|
|
/// .expect("server failed");
|
|
/// # };
|
|
/// ```
|
|
///
|
|
/// You can implement custom a [`Connected`] like so:
|
|
///
|
|
/// ```
|
|
/// use axum::{
|
|
/// prelude::*,
|
|
/// extract::connect_info::{ConnectInfo, Connected},
|
|
/// };
|
|
/// use hyper::server::conn::AddrStream;
|
|
///
|
|
/// let app = route("/", get(handler));
|
|
///
|
|
/// async fn handler(
|
|
/// ConnectInfo(my_connect_info): ConnectInfo<MyConnectInfo>,
|
|
/// ) -> String {
|
|
/// format!("Hello {:?}", my_connect_info)
|
|
/// }
|
|
///
|
|
/// #[derive(Clone, Debug)]
|
|
/// struct MyConnectInfo {
|
|
/// // ...
|
|
/// }
|
|
///
|
|
/// impl Connected<&AddrStream> for MyConnectInfo {
|
|
/// type ConnectInfo = MyConnectInfo;
|
|
///
|
|
/// fn connect_info(target: &AddrStream) -> Self::ConnectInfo {
|
|
/// MyConnectInfo {
|
|
/// // ...
|
|
/// }
|
|
/// }
|
|
/// }
|
|
///
|
|
/// # async {
|
|
/// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
|
|
/// .serve(
|
|
/// app.into_make_service_with_connect_info::<MyConnectInfo, _>()
|
|
/// )
|
|
/// .await
|
|
/// .expect("server failed");
|
|
/// # };
|
|
/// ```
|
|
///
|
|
/// See the [unix domain socket example][uds] for an example of how to use
|
|
/// this to collect UDS connection info.
|
|
///
|
|
/// [`MakeService`]: tower::make::MakeService
|
|
/// [`Connected`]: crate::extract::connect_info::Connected
|
|
/// [`ConnectInfo`]: crate::extract::connect_info::ConnectInfo
|
|
/// [uds]: https://github.com/tokio-rs/axum/blob/main/examples/unix_domain_socket.rs
|
|
fn into_make_service_with_connect_info<C, Target>(
|
|
self,
|
|
) -> IntoMakeServiceWithConnectInfo<Self, C>
|
|
where
|
|
Self: Clone,
|
|
C: Connected<Target>,
|
|
{
|
|
IntoMakeServiceWithConnectInfo::new(self)
|
|
}
|
|
}
|
|
|
|
impl<S, F> RoutingDsl for Route<S, F> {}
|
|
|
|
impl<S, F> crate::sealed::Sealed for Route<S, F> {}
|
|
|
|
impl<S, F, B> Service<Request<B>> for Route<S, F>
|
|
where
|
|
S: Service<Request<B>, Response = Response<BoxBody>> + Clone,
|
|
F: Service<Request<B>, Response = Response<BoxBody>, Error = S::Error> + Clone,
|
|
{
|
|
type Response = Response<BoxBody>;
|
|
type Error = S::Error;
|
|
type Future = RouteFuture<S, F, B>;
|
|
|
|
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
|
|
fn call(&mut self, mut req: Request<B>) -> Self::Future {
|
|
if let Some(captures) = self.pattern.full_match(req.uri().path()) {
|
|
insert_url_params(&mut req, captures);
|
|
let fut = self.svc.clone().oneshot(req);
|
|
RouteFuture::a(fut)
|
|
} else {
|
|
let fut = self.fallback.clone().oneshot(req);
|
|
RouteFuture::b(fut)
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub(crate) struct UrlParams(pub(crate) Vec<(ByteStr, ByteStr)>);
|
|
|
|
fn insert_url_params<B>(req: &mut Request<B>, params: Vec<(String, String)>) {
|
|
let params = params
|
|
.into_iter()
|
|
.map(|(k, v)| (ByteStr::new(k), ByteStr::new(v)));
|
|
|
|
if let Some(current) = req.extensions_mut().get_mut::<Option<UrlParams>>() {
|
|
let mut current = current.take().unwrap();
|
|
current.0.extend(params);
|
|
req.extensions_mut().insert(Some(current));
|
|
} else {
|
|
req.extensions_mut()
|
|
.insert(Some(UrlParams(params.collect())));
|
|
}
|
|
}
|
|
|
|
/// A [`Service`] that responds with `404 Not Found` or `405 Method not allowed`
|
|
/// to all requests.
|
|
///
|
|
/// This is used as the bottom service in a router stack. You shouldn't have to
|
|
/// use to manually.
|
|
pub struct EmptyRouter<E = Infallible> {
|
|
status: StatusCode,
|
|
_marker: PhantomData<fn() -> E>,
|
|
}
|
|
|
|
impl<E> EmptyRouter<E> {
|
|
pub(crate) fn not_found() -> Self {
|
|
Self {
|
|
status: StatusCode::NOT_FOUND,
|
|
_marker: PhantomData,
|
|
}
|
|
}
|
|
|
|
pub(crate) fn method_not_allowed() -> Self {
|
|
Self {
|
|
status: StatusCode::METHOD_NOT_ALLOWED,
|
|
_marker: PhantomData,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<E> Clone for EmptyRouter<E> {
|
|
fn clone(&self) -> Self {
|
|
Self {
|
|
status: self.status,
|
|
_marker: PhantomData,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<E> fmt::Debug for EmptyRouter<E> {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
f.debug_tuple("EmptyRouter").finish()
|
|
}
|
|
}
|
|
|
|
impl<E> RoutingDsl for EmptyRouter<E> {}
|
|
|
|
impl<E> crate::sealed::Sealed for EmptyRouter<E> {}
|
|
|
|
impl<B, E> Service<Request<B>> for EmptyRouter<E> {
|
|
type Response = Response<BoxBody>;
|
|
type Error = E;
|
|
type Future = EmptyRouterFuture<E>;
|
|
|
|
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
|
|
fn call(&mut self, _req: Request<B>) -> Self::Future {
|
|
let mut res = Response::new(crate::body::empty());
|
|
*res.status_mut() = self.status;
|
|
EmptyRouterFuture {
|
|
future: futures_util::future::ok(res),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub(crate) struct PathPattern(Arc<Inner>);
|
|
|
|
#[derive(Debug)]
|
|
struct Inner {
|
|
full_path_regex: Regex,
|
|
capture_group_names: Box<[Bytes]>,
|
|
}
|
|
|
|
impl PathPattern {
|
|
pub(crate) fn new(pattern: &str) -> Self {
|
|
assert!(
|
|
pattern.starts_with('/'),
|
|
"Route description must start with a `/`"
|
|
);
|
|
|
|
let mut capture_group_names = Vec::new();
|
|
|
|
let pattern = pattern
|
|
.split('/')
|
|
.map(|part| {
|
|
if let Some(key) = part.strip_prefix(':') {
|
|
capture_group_names.push(Bytes::copy_from_slice(key.as_bytes()));
|
|
|
|
Cow::Owned(format!("(?P<{}>[^/]*)", key))
|
|
} else {
|
|
Cow::Borrowed(part)
|
|
}
|
|
})
|
|
.collect::<Vec<_>>()
|
|
.join("/");
|
|
|
|
let full_path_regex =
|
|
Regex::new(&format!("^{}", pattern)).expect("invalid regex generated from route");
|
|
|
|
Self(Arc::new(Inner {
|
|
full_path_regex,
|
|
capture_group_names: capture_group_names.into(),
|
|
}))
|
|
}
|
|
|
|
pub(crate) fn full_match(&self, path: &str) -> Option<Captures> {
|
|
self.do_match(path).and_then(|match_| {
|
|
if match_.full_match {
|
|
Some(match_.captures)
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
}
|
|
|
|
pub(crate) fn prefix_match<'a>(&self, path: &'a str) -> Option<(&'a str, Captures)> {
|
|
self.do_match(path)
|
|
.map(|match_| (match_.matched, match_.captures))
|
|
}
|
|
|
|
fn do_match<'a>(&self, path: &'a str) -> Option<Match<'a>> {
|
|
self.0.full_path_regex.captures(path).map(|captures| {
|
|
let matched = captures.get(0).unwrap();
|
|
let full_match = matched.as_str() == path;
|
|
|
|
let captures = self
|
|
.0
|
|
.capture_group_names
|
|
.iter()
|
|
.map(|bytes| {
|
|
std::str::from_utf8(bytes)
|
|
.expect("bytes were created from str so is valid utf-8")
|
|
})
|
|
.filter_map(|name| captures.name(name).map(|value| (name, value.as_str())))
|
|
.map(|(key, value)| (key.to_string(), value.to_string()))
|
|
.collect::<Vec<_>>();
|
|
|
|
Match {
|
|
captures,
|
|
full_match,
|
|
matched: matched.as_str(),
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Result of running a [`PathPattern`] against a request path.
struct Match<'a> {
    /// `(name, value)` pairs for each named segment that captured a value.
    captures: Captures,
    /// `true` if the regex matched the whole path, `false` if it matched only
    /// a prefix.
    full_match: bool,
    /// Part of the path the regex matched. The regex is anchored with `^`, so
    /// this is always a prefix of the path.
    matched: &'a str,
}

/// `(name, value)` pairs of captured URL parameters.
type Captures = Vec<(String, String)>;
|
|
|
|
/// A boxed route trait object.
|
|
///
|
|
/// See [`RoutingDsl::boxed`] for more details.
|
|
pub struct BoxRoute<B, E = Infallible>(
|
|
MpscBuffer<BoxService<Request<B>, Response<BoxBody>, E>, Request<B>>,
|
|
);
|
|
|
|
impl<B, E> Clone for BoxRoute<B, E> {
|
|
fn clone(&self) -> Self {
|
|
Self(self.0.clone())
|
|
}
|
|
}
|
|
|
|
impl<B, E> fmt::Debug for BoxRoute<B, E> {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
f.debug_struct("BoxRoute").finish()
|
|
}
|
|
}
|
|
|
|
impl<B, E> RoutingDsl for BoxRoute<B, E> {}
|
|
|
|
impl<B, E> crate::sealed::Sealed for BoxRoute<B, E> {}
|
|
|
|
impl<B, E> Service<Request<B>> for BoxRoute<B, E>
|
|
where
|
|
E: Into<BoxError>,
|
|
{
|
|
type Response = Response<BoxBody>;
|
|
type Error = E;
|
|
type Future = BoxRouteFuture<B, E>;
|
|
|
|
#[inline]
|
|
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
|
|
#[inline]
|
|
fn call(&mut self, req: Request<B>) -> Self::Future {
|
|
BoxRouteFuture {
|
|
inner: self.0.clone().oneshot(req),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A [`Service`] created from a router by applying a Tower middleware.
|
|
///
|
|
/// Created with [`RoutingDsl::layer`]. See that method for more details.
|
|
pub struct Layered<S> {
|
|
inner: S,
|
|
}
|
|
|
|
impl<S> Layered<S> {
|
|
fn new(inner: S) -> Self {
|
|
Self { inner }
|
|
}
|
|
}
|
|
|
|
impl<S> Clone for Layered<S>
|
|
where
|
|
S: Clone,
|
|
{
|
|
fn clone(&self) -> Self {
|
|
Self::new(self.inner.clone())
|
|
}
|
|
}
|
|
|
|
impl<S> fmt::Debug for Layered<S>
|
|
where
|
|
S: fmt::Debug,
|
|
{
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
f.debug_struct("Layered")
|
|
.field("inner", &self.inner)
|
|
.finish()
|
|
}
|
|
}
|
|
|
|
impl<S> RoutingDsl for Layered<S> {}
|
|
|
|
impl<S> crate::sealed::Sealed for Layered<S> {}
|
|
|
|
impl<S> Layered<S> {
|
|
/// Create a new [`Layered`] service where errors will be handled using the
|
|
/// given closure.
|
|
///
|
|
/// This is used to convert errors to responses rather than simply
|
|
/// terminating the connection.
|
|
///
|
|
/// That can be done using `handle_error` like so:
|
|
///
|
|
/// ```rust
|
|
/// use axum::prelude::*;
|
|
/// use http::StatusCode;
|
|
/// use tower::{BoxError, timeout::TimeoutLayer};
|
|
/// use std::{convert::Infallible, time::Duration};
|
|
///
|
|
/// async fn handler() { /* ... */ }
|
|
///
|
|
/// // `Timeout` will fail with `BoxError` if the timeout elapses...
|
|
/// let layered_app = route("/", get(handler))
|
|
/// .layer(TimeoutLayer::new(Duration::from_secs(30)));
|
|
///
|
|
/// // ...so we should handle that error
|
|
/// let with_errors_handled = layered_app.handle_error(|error: BoxError| {
|
|
/// if error.is::<tower::timeout::error::Elapsed>() {
|
|
/// Ok::<_, Infallible>((
|
|
/// StatusCode::REQUEST_TIMEOUT,
|
|
/// "request took too long".to_string(),
|
|
/// ))
|
|
/// } else {
|
|
/// Ok::<_, Infallible>((
|
|
/// StatusCode::INTERNAL_SERVER_ERROR,
|
|
/// format!("Unhandled internal error: {}", error),
|
|
/// ))
|
|
/// }
|
|
/// });
|
|
/// # async {
|
|
/// # axum::Server::bind(&"".parse().unwrap())
|
|
/// # .serve(with_errors_handled.into_make_service())
|
|
/// # .await
|
|
/// # .unwrap();
|
|
/// # };
|
|
/// ```
|
|
///
|
|
/// The closure must return `Result<T, E>` where `T` implements [`IntoResponse`].
|
|
///
|
|
/// You can also return `Err(_)` if you don't wish to handle the error:
|
|
///
|
|
/// ```rust
|
|
/// use axum::prelude::*;
|
|
/// use http::StatusCode;
|
|
/// use tower::{BoxError, timeout::TimeoutLayer};
|
|
/// use std::time::Duration;
|
|
///
|
|
/// async fn handler() { /* ... */ }
|
|
///
|
|
/// let layered_app = route("/", get(handler))
|
|
/// .layer(TimeoutLayer::new(Duration::from_secs(30)));
|
|
///
|
|
/// let with_errors_handled = layered_app.handle_error(|error: BoxError| {
|
|
/// if error.is::<tower::timeout::error::Elapsed>() {
|
|
/// Ok((
|
|
/// StatusCode::REQUEST_TIMEOUT,
|
|
/// "request took too long".to_string(),
|
|
/// ))
|
|
/// } else {
|
|
/// // keep the error as is
|
|
/// Err(error)
|
|
/// }
|
|
/// });
|
|
/// # async {
|
|
/// # axum::Server::bind(&"".parse().unwrap())
|
|
/// # .serve(with_errors_handled.into_make_service())
|
|
/// # .await
|
|
/// # .unwrap();
|
|
/// # };
|
|
/// ```
|
|
pub fn handle_error<F, ReqBody, ResBody, Res, E>(
|
|
self,
|
|
f: F,
|
|
) -> crate::service::HandleError<S, F, ReqBody, HandleErrorFromRouter>
|
|
where
|
|
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
|
|
F: FnOnce(S::Error) -> Result<Res, E>,
|
|
Res: IntoResponse,
|
|
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
|
|
ResBody::Error: Into<BoxError> + Send + Sync + 'static,
|
|
{
|
|
crate::service::HandleError::new(self.inner, f)
|
|
}
|
|
}
|
|
|
|
impl<S, R> Service<R> for Layered<S>
|
|
where
|
|
S: Service<R>,
|
|
{
|
|
type Response = S::Response;
|
|
type Error = S::Error;
|
|
type Future = S::Future;
|
|
|
|
#[inline]
|
|
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
self.inner.poll_ready(cx)
|
|
}
|
|
|
|
#[inline]
|
|
fn call(&mut self, req: R) -> Self::Future {
|
|
self.inner.call(req)
|
|
}
|
|
}
|
|
|
|
/// Nest a group of routes (or a [`Service`]) at some path.
|
|
///
|
|
/// This allows you to break your application into smaller pieces and compose
|
|
/// them together. This will strip the matching prefix from the URL so the
|
|
/// nested route will only see the part of URL:
|
|
///
|
|
/// ```
|
|
/// use axum::{routing::nest, prelude::*};
|
|
/// use http::Uri;
|
|
///
|
|
/// async fn users_get(uri: Uri) {
|
|
/// // `users_get` doesn't see the whole URL. `nest` will strip the matching
|
|
/// // `/api` prefix.
|
|
/// assert_eq!(uri.path(), "/users");
|
|
/// }
|
|
///
|
|
/// async fn users_post() {}
|
|
///
|
|
/// async fn careers() {}
|
|
///
|
|
/// let users_api = route("/users", get(users_get).post(users_post));
|
|
///
|
|
/// let app = nest("/api", users_api).route("/careers", get(careers));
|
|
/// # async {
|
|
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
|
/// # };
|
|
/// ```
|
|
///
|
|
/// Take care when using `nest` together with dynamic routes as nesting also
|
|
/// captures from the outer routes:
|
|
///
|
|
/// ```
|
|
/// use axum::{routing::nest, prelude::*};
|
|
///
|
|
/// async fn users_get(params: extract::UrlParamsMap) {
|
|
/// // Both `version` and `id` were captured even though `users_api` only
|
|
/// // explicitly captures `id`.
|
|
/// let version = params.get("version");
|
|
/// let id = params.get("id");
|
|
/// }
|
|
///
|
|
/// let users_api = route("/users/:id", get(users_get));
|
|
///
|
|
/// let app = nest("/:version/api", users_api);
|
|
/// # async {
|
|
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
|
/// # };
|
|
/// ```
|
|
///
|
|
/// `nest` also accepts any [`Service`]. This can for example be used with
|
|
/// [`tower_http::services::ServeDir`] to serve static files from a directory:
|
|
///
|
|
/// ```
|
|
/// use axum::{
|
|
/// routing::nest, service::{get, ServiceExt}, prelude::*,
|
|
/// };
|
|
/// use tower_http::services::ServeDir;
|
|
///
|
|
/// // Serves files inside the `public` directory at `GET /public/*`
|
|
/// let serve_dir_service = ServeDir::new("public");
|
|
///
|
|
/// let app = nest("/public", get(serve_dir_service));
|
|
/// # async {
|
|
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
|
/// # };
|
|
/// ```
|
|
///
|
|
/// If necessary you can use [`RoutingDsl::boxed`] to box a group of routes
|
|
/// making the type easier to name. This is sometimes useful when working with
|
|
/// `nest`.
|
|
pub fn nest<S, B>(description: &str, svc: S) -> Nested<S, EmptyRouter<S::Error>>
|
|
where
|
|
S: Service<Request<B>> + Clone,
|
|
{
|
|
Nested {
|
|
pattern: PathPattern::new(description),
|
|
svc,
|
|
fallback: EmptyRouter::not_found(),
|
|
}
|
|
}
|
|
|
|
/// A [`Service`] that has been nested inside a router at some path.
|
|
///
|
|
/// Created with [`nest`] or [`RoutingDsl::nest`].
|
|
#[derive(Debug, Clone)]
|
|
pub struct Nested<S, F> {
|
|
pattern: PathPattern,
|
|
svc: S,
|
|
fallback: F,
|
|
}
|
|
|
|
impl<S, F> RoutingDsl for Nested<S, F> {}
|
|
|
|
impl<S, F> crate::sealed::Sealed for Nested<S, F> {}
|
|
|
|
impl<S, F, B> Service<Request<B>> for Nested<S, F>
|
|
where
|
|
S: Service<Request<B>, Response = Response<BoxBody>> + Clone,
|
|
F: Service<Request<B>, Response = Response<BoxBody>, Error = S::Error> + Clone,
|
|
{
|
|
type Response = Response<BoxBody>;
|
|
type Error = S::Error;
|
|
type Future = RouteFuture<S, F, B>;
|
|
|
|
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
|
|
fn call(&mut self, mut req: Request<B>) -> Self::Future {
|
|
if let Some((prefix, captures)) = self.pattern.prefix_match(req.uri().path()) {
|
|
let without_prefix = strip_prefix(req.uri(), prefix);
|
|
*req.uri_mut() = without_prefix;
|
|
|
|
insert_url_params(&mut req, captures);
|
|
let fut = self.svc.clone().oneshot(req);
|
|
RouteFuture::a(fut)
|
|
} else {
|
|
let fut = self.fallback.clone().oneshot(req);
|
|
RouteFuture::b(fut)
|
|
}
|
|
}
|
|
}
|
|
|
|
fn strip_prefix(uri: &Uri, prefix: &str) -> Uri {
|
|
let path_and_query = if let Some(path_and_query) = uri.path_and_query() {
|
|
let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) {
|
|
path
|
|
} else {
|
|
path_and_query.path()
|
|
};
|
|
|
|
let new_path = if new_path.starts_with('/') {
|
|
Cow::Borrowed(new_path)
|
|
} else {
|
|
Cow::Owned(format!("/{}", new_path))
|
|
};
|
|
|
|
if let Some(query) = path_and_query.query() {
|
|
Some(
|
|
format!("{}?{}", new_path, query)
|
|
.parse::<http::uri::PathAndQuery>()
|
|
.unwrap(),
|
|
)
|
|
} else {
|
|
Some(new_path.parse().unwrap())
|
|
}
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let mut parts = http::uri::Parts::default();
|
|
parts.scheme = uri.scheme().cloned();
|
|
parts.authority = uri.authority().cloned();
|
|
parts.path_and_query = path_and_query;
|
|
|
|
Uri::from_parts(parts).unwrap()
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// `(route spec, request path, should the spec fully match the path?)`,
    /// checked in order.
    const CASES: &[(&str, &str, bool)] = &[
        ("/", "/", true),
        ("/foo", "/foo", true),
        ("/foo/", "/foo/", true),
        ("/foo", "/foo/", false),
        ("/foo/", "/foo", false),
        ("/foo/bar", "/foo/bar", true),
        ("/foo/bar/", "/foo/bar", false),
        ("/foo/bar", "/foo/bar/", false),
        ("/:value", "/foo", true),
        ("/users/:id", "/users/1", true),
        ("/users/:id/action", "/users/42/action", true),
        ("/users/:id/action", "/users/42", false),
        ("/users/:id", "/users/42/action", false),
    ];

    #[test]
    fn test_routing() {
        for &(route_spec, path, should_match) in CASES {
            if should_match {
                assert_match(route_spec, path);
            } else {
                refute_match(route_spec, path);
            }
        }
    }

    fn assert_match(route_spec: &'static str, path: &'static str) {
        let matched = PathPattern::new(route_spec).full_match(path);
        assert!(
            matched.is_some(),
            "`{}` doesn't match `{}`",
            path,
            route_spec
        );
    }

    fn refute_match(route_spec: &'static str, path: &'static str) {
        let matched = PathPattern::new(route_spec).full_match(path);
        assert!(
            matched.is_none(),
            "`{}` did match `{}` (but shouldn't)",
            path,
            route_spec
        );
    }
}
|