Fix ServiceExt::handle_error footgun (#120)

As described in
https://github.com/tokio-rs/axum/pull/108#issuecomment-892811637, a
`HandleError` created from `axum::ServiceExt::handle_error` should _not_
implement `RoutingDsl` as that leads to confusing routing behavior.

The technique used here of adding another type parameter to
`HandleError` isn't very clean, I think. But the alternative is
duplicating `HandleError` and having two versions, which I think is less
desirable.
This commit is contained in:
David Pedersen 2021-08-07 16:44:12 +02:00 committed by GitHub
parent b5b9db47db
commit 95d7582d28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 46 additions and 12 deletions

View File

@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Breaking changes ## Breaking changes
- Ensure a `HandleError` service created from `axum::ServiceExt::handle_error`
_does not_ implement `RoutingDsl` as that could lead to confusing routing
behavior. ([#120](https://github.com/tokio-rs/axum/pull/120))
- Remove `QueryStringMissing` as it was no longer being used - Remove `QueryStringMissing` as it was no longer being used
- `extract::extractor_middleware::ExtractorMiddlewareResponseFuture` moved - `extract::extractor_middleware::ExtractorMiddlewareResponseFuture` moved
to `extract::extractor_middleware::future::ResponseFuture` ([#133](https://github.com/tokio-rs/axum/pull/133)) to `extract::extractor_middleware::future::ResponseFuture` ([#133](https://github.com/tokio-rs/axum/pull/133))

View File

@ -5,7 +5,7 @@ use crate::{
extract::FromRequest, extract::FromRequest,
response::IntoResponse, response::IntoResponse,
routing::{future::RouteFuture, EmptyRouter, MethodFilter}, routing::{future::RouteFuture, EmptyRouter, MethodFilter},
service::HandleError, service::{HandleError, HandleErrorFromRouter},
}; };
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
@ -371,7 +371,7 @@ impl<S, T> Layered<S, T> {
pub fn handle_error<F, ReqBody, ResBody, Res, E>( pub fn handle_error<F, ReqBody, ResBody, Res, E>(
self, self,
f: F, f: F,
) -> Layered<HandleError<S, F, ReqBody>, T> ) -> Layered<HandleError<S, F, ReqBody, HandleErrorFromRouter>, T>
where where
S: Service<Request<ReqBody>, Response = Response<ResBody>>, S: Service<Request<ReqBody>, Response = Response<ResBody>>,
F: FnOnce(S::Error) -> Result<Res, E>, F: FnOnce(S::Error) -> Result<Res, E>,

View File

@ -6,6 +6,7 @@ use crate::{
buffer::MpscBuffer, buffer::MpscBuffer,
extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo}, extract::connect_info::{Connected, IntoMakeServiceWithConnectInfo},
response::IntoResponse, response::IntoResponse,
service::HandleErrorFromRouter,
util::ByteStr, util::ByteStr,
}; };
use async_trait::async_trait; use async_trait::async_trait;
@ -716,7 +717,7 @@ impl<S> Layered<S> {
pub fn handle_error<F, ReqBody, ResBody, Res, E>( pub fn handle_error<F, ReqBody, ResBody, Res, E>(
self, self,
f: F, f: F,
) -> crate::service::HandleError<S, F, ReqBody> ) -> crate::service::HandleError<S, F, ReqBody, HandleErrorFromRouter>
where where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone, S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
F: FnOnce(S::Error) -> Result<Res, E>, F: FnOnce(S::Error) -> Result<Res, E>,

View File

@ -462,13 +462,13 @@ where
/// [`handler::Layered::handle_error`](crate::handler::Layered::handle_error) or /// [`handler::Layered::handle_error`](crate::handler::Layered::handle_error) or
/// [`routing::Layered::handle_error`](crate::routing::Layered::handle_error). /// [`routing::Layered::handle_error`](crate::routing::Layered::handle_error).
/// See those methods for more details. /// See those methods for more details.
pub struct HandleError<S, F, B> { pub struct HandleError<S, F, B, T> {
inner: S, inner: S,
f: F, f: F,
_marker: PhantomData<fn() -> B>, _marker: PhantomData<fn() -> (B, T)>,
} }
impl<S, F, B> Clone for HandleError<S, F, B> impl<S, F, B, T> Clone for HandleError<S, F, B, T>
where where
S: Clone, S: Clone,
F: Clone, F: Clone,
@ -478,11 +478,23 @@ where
} }
} }
impl<S, F, B> crate::routing::RoutingDsl for HandleError<S, F, B> {} /// Maker type used for [`HandleError`] to indicate that it should implement
/// [`RoutingDsl`](crate::routing::RoutingDsl).
#[non_exhaustive]
#[derive(Debug)]
pub struct HandleErrorFromRouter;
impl<S, F, B> crate::sealed::Sealed for HandleError<S, F, B> {} /// Maker type used for [`HandleError`] to indicate that it should _not_ implement
/// [`RoutingDsl`](crate::routing::RoutingDsl).
#[non_exhaustive]
#[derive(Debug)]
pub struct HandleErrorFromService;
impl<S, F, B> HandleError<S, F, B> { impl<S, F, B> crate::routing::RoutingDsl for HandleError<S, F, B, HandleErrorFromRouter> {}
impl<S, F, B> crate::sealed::Sealed for HandleError<S, F, B, HandleErrorFromRouter> {}
impl<S, F, B, T> HandleError<S, F, B, T> {
pub(crate) fn new(inner: S, f: F) -> Self { pub(crate) fn new(inner: S, f: F) -> Self {
Self { Self {
inner, inner,
@ -492,7 +504,7 @@ impl<S, F, B> HandleError<S, F, B> {
} }
} }
impl<S, F, B> fmt::Debug for HandleError<S, F, B> impl<S, F, B, T> fmt::Debug for HandleError<S, F, B, T>
where where
S: fmt::Debug, S: fmt::Debug,
{ {
@ -504,7 +516,7 @@ where
} }
} }
impl<S, F, ReqBody, ResBody, Res, E> Service<Request<ReqBody>> for HandleError<S, F, ReqBody> impl<S, F, ReqBody, ResBody, Res, E, T> Service<Request<ReqBody>> for HandleError<S, F, ReqBody, T>
where where
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone, S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
F: FnOnce(S::Error) -> Result<Res, E> + Clone, F: FnOnce(S::Error) -> Result<Res, E> + Clone,
@ -570,7 +582,7 @@ pub trait ServiceExt<ReqBody, ResBody>:
/// It works similarly to [`routing::Layered::handle_error`]. See that for more details. /// It works similarly to [`routing::Layered::handle_error`]. See that for more details.
/// ///
/// [`routing::Layered::handle_error`]: crate::routing::Layered::handle_error /// [`routing::Layered::handle_error`]: crate::routing::Layered::handle_error
fn handle_error<F, Res, E>(self, f: F) -> HandleError<Self, F, ReqBody> fn handle_error<F, Res, E>(self, f: F) -> HandleError<Self, F, ReqBody, HandleErrorFromService>
where where
Self: Sized, Self: Sized,
F: FnOnce(Self::Error) -> Result<Res, E>, F: FnOnce(Self::Error) -> Result<Res, E>,
@ -645,3 +657,21 @@ where
future::BoxResponseBodyFuture { future: fut } future::BoxResponseBodyFuture { future: fut }
} }
} }
/// ```compile_fail
/// use crate::{service::ServiceExt, prelude::*};
/// use tower::service_fn;
/// use hyper::Body;
/// use http::{Request, Response, StatusCode};
///
/// let svc = service_fn(|_: Request<Body>| async {
/// Ok::<_, hyper::Error>(Response::new(Body::empty()))
/// })
/// .handle_error::<_, _, hyper::Error>(|_| Ok(StatusCode::INTERNAL_SERVER_ERROR));
///
/// // `.route` should not compile, ie `HandleError` created from any
/// // random service should not implement `RoutingDsl`
/// svc.route::<_, Body>("/", get(|| async {}));
/// ```
#[allow(dead_code)]
fn compile_fail_tests() {}