diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d3c1187..47702ddf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Breaking changes +- Ensure a `HandleError` service created from `axum::ServiceExt::handle_error` + _does not_ implement `RoutingDsl` as that could lead to confusing routing + behavior. ([#120](https://github.com/tokio-rs/axum/pull/120)) - Remove `QueryStringMissing` as it was no longer being used - `extract::extractor_middleware::ExtractorMiddlewareResponseFuture` moved to `extract::extractor_middleware::future::ResponseFuture` ([#133](https://github.com/tokio-rs/axum/pull/133)) diff --git a/src/handler/mod.rs b/src/handler/mod.rs index d38edb89..f723df6f 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -5,7 +5,7 @@ use crate::{ extract::FromRequest, response::IntoResponse, routing::{future::RouteFuture, EmptyRouter, MethodFilter}, - service::HandleError, + service::{HandleError, HandleErrorFromRouter}, }; use async_trait::async_trait; use bytes::Bytes; @@ -371,7 +371,7 @@ impl Layered { pub fn handle_error( self, f: F, - ) -> Layered, T> + ) -> Layered, T> where S: Service, Response = Response>, F: FnOnce(S::Error) -> Result, diff --git a/src/routing.rs b/src/routing.rs index 30f9ccf7..72b6c2e8 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -6,6 +6,7 @@ use crate::{ buffer::MpscBuffer, extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo}, response::IntoResponse, + service::HandleErrorFromRouter, util::ByteStr, }; use async_trait::async_trait; @@ -716,7 +717,7 @@ impl Layered { pub fn handle_error( self, f: F, - ) -> crate::service::HandleError + ) -> crate::service::HandleError where S: Service, Response = Response> + Clone, F: FnOnce(S::Error) -> Result, diff --git a/src/service/mod.rs b/src/service/mod.rs index 3d6508b6..1012db85 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -462,13 +462,13 @@ where /// [`handler::Layered::handle_error`](crate::handler::Layered::handle_error) or /// [`routing::Layered::handle_error`](crate::routing::Layered::handle_error). /// See those methods for more details. -pub struct HandleError { +pub struct HandleError { inner: S, f: F, - _marker: PhantomData B>, + _marker: PhantomData (B, T)>, } -impl Clone for HandleError +impl Clone for HandleError where S: Clone, F: Clone, @@ -478,11 +478,23 @@ where } } -impl crate::routing::RoutingDsl for HandleError {} +/// Maker type used for [`HandleError`] to indicate that it should implement +/// [`RoutingDsl`](crate::routing::RoutingDsl). +#[non_exhaustive] +#[derive(Debug)] +pub struct HandleErrorFromRouter; -impl crate::sealed::Sealed for HandleError {} +/// Maker type used for [`HandleError`] to indicate that it should _not_ implement +/// [`RoutingDsl`](crate::routing::RoutingDsl). +#[non_exhaustive] +#[derive(Debug)] +pub struct HandleErrorFromService; -impl HandleError { +impl crate::routing::RoutingDsl for HandleError {} + +impl crate::sealed::Sealed for HandleError {} + +impl HandleError { pub(crate) fn new(inner: S, f: F) -> Self { Self { inner, @@ -492,7 +504,7 @@ impl HandleError { } } -impl fmt::Debug for HandleError +impl fmt::Debug for HandleError where S: fmt::Debug, { @@ -504,7 +516,7 @@ where } } -impl Service> for HandleError +impl Service> for HandleError where S: Service, Response = Response> + Clone, F: FnOnce(S::Error) -> Result + Clone, @@ -570,7 +582,7 @@ pub trait ServiceExt: /// It works similarly to [`routing::Layered::handle_error`]. See that for more details. /// /// [`routing::Layered::handle_error`]: crate::routing::Layered::handle_error - fn handle_error(self, f: F) -> HandleError + fn handle_error(self, f: F) -> HandleError where Self: Sized, F: FnOnce(Self::Error) -> Result, @@ -645,3 +657,21 @@ where future::BoxResponseBodyFuture { future: fut } } } + +/// ```compile_fail +/// use crate::{service::ServiceExt, prelude::*}; +/// use tower::service_fn; +/// use hyper::Body; +/// use http::{Request, Response, StatusCode}; +/// +/// let svc = service_fn(|_: Request| async { +/// Ok::<_, hyper::Error>(Response::new(Body::empty())) +/// }) +/// .handle_error::<_, _, hyper::Error>(|_| Ok(StatusCode::INTERNAL_SERVER_ERROR)); +/// +/// // `.route` should not compile, ie `HandleError` created from any +/// // random service should not implement `RoutingDsl` +/// svc.route::<_, Body>("/", get(|| async {})); +/// ``` +#[allow(dead_code)] +fn compile_fail_tests() {}