diff --git a/src/handler/future.rs b/src/handler/future.rs index 999b96a9..68f139c8 100644 --- a/src/handler/future.rs +++ b/src/handler/future.rs @@ -2,10 +2,7 @@ use crate::body::{box_body, BoxBody}; use crate::util::{Either, EitherProj}; -use futures_util::{ - future::{BoxFuture, Map}, - ready, -}; +use futures_util::{future::BoxFuture, ready}; use http::{Method, Request, Response}; use http_body::Empty; use pin_project_lite::pin_project; @@ -67,8 +64,5 @@ where opaque_future! { /// The response future for [`IntoService`](super::IntoService). pub type IntoServiceFuture = - Map< - BoxFuture<'static, Response>, - fn(Response) -> Result, Infallible>, - >; + BoxFuture<'static, Result, Infallible>>; } diff --git a/src/handler/into_service.rs b/src/handler/into_service.rs index ffc0fde0..c5f6bcfc 100644 --- a/src/handler/into_service.rs +++ b/src/handler/into_service.rs @@ -1,5 +1,9 @@ use super::Handler; -use crate::body::BoxBody; +use crate::{ + body::{box_body, BoxBody}, + extract::{FromRequest, RequestParts}, + response::IntoResponse, +}; use http::{Request, Response}; use std::{ convert::Infallible, @@ -12,12 +16,12 @@ use tower::Service; /// An adapter that makes a [`Handler`] into a [`Service`]. /// /// Created with [`Handler::into_service`]. -pub struct IntoService { +pub struct IntoService { handler: H, - _marker: PhantomData (B, T)>, + _marker: PhantomData T>, } -impl IntoService { +impl IntoService { pub(super) fn new(handler: H) -> Self { Self { handler, @@ -26,7 +30,7 @@ impl IntoService { } } -impl fmt::Debug for IntoService { +impl fmt::Debug for IntoService { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_tuple("IntoService") .field(&format_args!("...")) @@ -34,7 +38,7 @@ impl fmt::Debug for IntoService { } } -impl Clone for IntoService +impl Clone for IntoService where H: Clone, { @@ -46,9 +50,11 @@ where } } -impl Service> for IntoService +impl Service> for IntoService where - H: Handler + Clone + Send + 'static, + H: Handler, + T: FromRequest + Send, + T::Rejection: Send, B: Send + 'static, { type Response = Response; @@ -63,10 +69,16 @@ where } fn call(&mut self, req: Request) -> Self::Future { - use futures_util::future::FutureExt; - let handler = self.handler.clone(); - let future = Handler::call(handler, req).map(Ok::<_, Infallible> as _); + let future = Box::pin(async move { + let mut req = RequestParts::new(req); + let input = T::from_request(&mut req).await; + let res = match input { + Ok(input) => Handler::call(handler, input).await, + Err(rejection) => rejection.into_response().map(box_body), + }; + Ok::<_, Infallible>(res) + }); super::future::IntoServiceFuture { future } } diff --git a/src/handler/mod.rs b/src/handler/mod.rs index 8fb0630c..a64e7c29 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -2,7 +2,7 @@ use crate::{ body::{box_body, BoxBody}, - extract::FromRequest, + extract::{FromRequest, RequestParts}, response::IntoResponse, routing::{EmptyRouter, MethodFilter}, service::HandleError, @@ -44,9 +44,9 @@ pub use self::into_service::IntoService; /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` -pub fn any(handler: H) -> OnMethod +pub fn any(handler: H) -> OnMethod where - H: Handler, + H: Handler, { on(MethodFilter::all(), handler) } @@ -54,9 +54,9 @@ where /// Route `CONNECT` requests to the given handler. /// /// See [`get`] for an example. -pub fn connect(handler: H) -> OnMethod +pub fn connect(handler: H) -> OnMethod where - H: Handler, + H: Handler, { on(MethodFilter::CONNECT, handler) } @@ -64,9 +64,9 @@ where /// Route `DELETE` requests to the given handler. /// /// See [`get`] for an example. -pub fn delete(handler: H) -> OnMethod +pub fn delete(handler: H) -> OnMethod where - H: Handler, + H: Handler, { on(MethodFilter::DELETE, handler) } @@ -93,9 +93,9 @@ where /// Note that `get` routes will also be called for `HEAD` requests but will have /// the response body removed. Make sure to add explicit `HEAD` routes /// afterwards. -pub fn get(handler: H) -> OnMethod +pub fn get(handler: H) -> OnMethod where - H: Handler, + H: Handler, { on(MethodFilter::GET | MethodFilter::HEAD, handler) } @@ -103,9 +103,9 @@ where /// Route `HEAD` requests to the given handler. /// /// See [`get`] for an example. -pub fn head(handler: H) -> OnMethod +pub fn head(handler: H) -> OnMethod where - H: Handler, + H: Handler, { on(MethodFilter::HEAD, handler) } @@ -113,9 +113,9 @@ where /// Route `OPTIONS` requests to the given handler. /// /// See [`get`] for an example. -pub fn options(handler: H) -> OnMethod +pub fn options(handler: H) -> OnMethod where - H: Handler, + H: Handler, { on(MethodFilter::OPTIONS, handler) } @@ -123,9 +123,9 @@ where /// Route `PATCH` requests to the given handler. /// /// See [`get`] for an example. -pub fn patch(handler: H) -> OnMethod +pub fn patch(handler: H) -> OnMethod where - H: Handler, + H: Handler, { on(MethodFilter::PATCH, handler) } @@ -133,9 +133,9 @@ where /// Route `POST` requests to the given handler. /// /// See [`get`] for an example. -pub fn post(handler: H) -> OnMethod +pub fn post(handler: H) -> OnMethod where - H: Handler, + H: Handler, { on(MethodFilter::POST, handler) } @@ -143,9 +143,9 @@ where /// Route `PUT` requests to the given handler. /// /// See [`get`] for an example. -pub fn put(handler: H) -> OnMethod +pub fn put(handler: H) -> OnMethod where - H: Handler, + H: Handler, { on(MethodFilter::PUT, handler) } @@ -153,9 +153,9 @@ where /// Route `TRACE` requests to the given handler. /// /// See [`get`] for an example. -pub fn trace(handler: H) -> OnMethod +pub fn trace(handler: H) -> OnMethod where - H: Handler, + H: Handler, { on(MethodFilter::TRACE, handler) } @@ -179,9 +179,9 @@ where /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` -pub fn on(method: MethodFilter, handler: H) -> OnMethod +pub fn on(method: MethodFilter, handler: H) -> OnMethod where - H: Handler, + H: Handler, { OnMethod { method, @@ -206,14 +206,14 @@ pub(crate) mod sealed { /// /// See the [module docs](crate::handler) for more details. #[async_trait] -pub trait Handler: Clone + Send + Sized + 'static { +pub trait Handler: Clone + Send + Sync + Sized + 'static { // This seals the trait. We cannot use the regular "sealed super trait" // approach due to coherence. #[doc(hidden)] type Sealed: sealed::HiddentTrait; - /// Call the handler with the given request. - async fn call(self, req: Request) -> Response; + /// Call the handler. + async fn call(self, input: T) -> Response; /// Apply a [`tower::Layer`] to the handler. /// @@ -251,7 +251,7 @@ pub trait Handler: Clone + Send + Sized + 'static { /// errors. See [`Layered::handle_error`] for more details. fn layer(self, layer: L) -> Layered where - L: Layer>, + L: Layer>, { Layered::new(layer.layer(any(self))) } @@ -280,22 +280,21 @@ pub trait Handler: Clone + Send + Sized + 'static { /// # Ok::<_, hyper::Error>(()) /// # }; /// ``` - fn into_service(self) -> IntoService { + fn into_service(self) -> IntoService { IntoService::new(self) } } #[async_trait] -impl Handler for F +impl Handler<()> for F where F: FnOnce() -> Fut + Clone + Send + Sync + 'static, Fut: Future + Send, Res: IntoResponse, - B: Send + 'static, { type Sealed = sealed::Hidden; - async fn call(self, _req: Request) -> Response { + async fn call(self, _: ()) -> Response { self().await.into_response().map(box_body) } } @@ -307,35 +306,18 @@ macro_rules! impl_handler { ( $head:ident, $($tail:ident),* $(,)? ) => { #[async_trait] #[allow(non_snake_case)] - impl Handler for F + impl Handler<($head, $($tail,)*)> for F where F: FnOnce($head, $($tail,)*) -> Fut + Clone + Send + Sync + 'static, Fut: Future + Send, - B: Send + 'static, Res: IntoResponse, - B: Send + 'static, - $head: FromRequest + Send, - $( $tail: FromRequest + Send,)* + $head: Send + 'static, + $( $tail: Send + 'static ),* { type Sealed = sealed::Hidden; - async fn call(self, req: Request) -> Response { - let mut req = crate::extract::RequestParts::new(req); - - let $head = match $head::from_request(&mut req).await { - Ok(value) => value, - Err(rejection) => return rejection.into_response().map(box_body), - }; - - $( - let $tail = match $tail::from_request(&mut req).await { - Ok(value) => value, - Err(rejection) => return rejection.into_response().map(box_body), - }; - )* - + async fn call(self, ($head, $($tail,)*): ($head, $($tail,)*)) -> Response { let res = self($head, $($tail,)*).await; - res.into_response().map(crate::body::box_body) } } @@ -373,9 +355,9 @@ where } #[async_trait] -impl Handler for Layered +impl Handler<(Request,)> for Layered where - S: Service, Response = Response> + Clone + Send + 'static, + S: Service, Response = Response> + Clone + Send + Sync + 'static, S::Error: IntoResponse, S::Future: Send, T: 'static, @@ -385,7 +367,7 @@ where { type Sealed = sealed::Hidden; - async fn call(self, req: Request) -> Response { + async fn call(self, (req,): (Request,)) -> Response { match self .svc .oneshot(req) @@ -431,14 +413,14 @@ impl Layered { /// A handler [`Service`] that accepts requests based on a [`MethodFilter`] and /// allows chaining additional handlers. -pub struct OnMethod { +pub struct OnMethod { pub(crate) method: MethodFilter, pub(crate) handler: H, pub(crate) fallback: F, - pub(crate) _marker: PhantomData (B, T)>, + pub(crate) _marker: PhantomData T>, } -impl fmt::Debug for OnMethod +impl fmt::Debug for OnMethod where T: fmt::Debug, F: fmt::Debug, @@ -452,7 +434,7 @@ where } } -impl Clone for OnMethod +impl Clone for OnMethod where H: Clone, F: Clone, @@ -467,21 +449,21 @@ where } } -impl Copy for OnMethod +impl Copy for OnMethod where H: Copy, F: Copy, { } -impl OnMethod { +impl OnMethod { /// Chain an additional handler that will accept all requests regardless of /// its HTTP method. /// /// See [`OnMethod::get`] for an example. - pub fn any(self, handler: H2) -> OnMethod + pub fn any(self, handler: H2) -> OnMethod where - H2: Handler, + H2: Handler, { self.on(MethodFilter::all(), handler) } @@ -489,9 +471,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `CONNECT` requests. /// /// See [`OnMethod::get`] for an example. - pub fn connect(self, handler: H2) -> OnMethod + pub fn connect(self, handler: H2) -> OnMethod where - H2: Handler, + H2: Handler, { self.on(MethodFilter::CONNECT, handler) } @@ -499,9 +481,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `DELETE` requests. /// /// See [`OnMethod::get`] for an example. - pub fn delete(self, handler: H2) -> OnMethod + pub fn delete(self, handler: H2) -> OnMethod where - H2: Handler, + H2: Handler, { self.on(MethodFilter::DELETE, handler) } @@ -528,9 +510,9 @@ impl OnMethod { /// Note that `get` routes will also be called for `HEAD` requests but will have /// the response body removed. Make sure to add explicit `HEAD` routes /// afterwards. - pub fn get(self, handler: H2) -> OnMethod + pub fn get(self, handler: H2) -> OnMethod where - H2: Handler, + H2: Handler, { self.on(MethodFilter::GET | MethodFilter::HEAD, handler) } @@ -538,9 +520,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `HEAD` requests. /// /// See [`OnMethod::get`] for an example. - pub fn head(self, handler: H2) -> OnMethod + pub fn head(self, handler: H2) -> OnMethod where - H2: Handler, + H2: Handler, { self.on(MethodFilter::HEAD, handler) } @@ -548,9 +530,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `OPTIONS` requests. /// /// See [`OnMethod::get`] for an example. - pub fn options(self, handler: H2) -> OnMethod + pub fn options(self, handler: H2) -> OnMethod where - H2: Handler, + H2: Handler, { self.on(MethodFilter::OPTIONS, handler) } @@ -558,9 +540,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `PATCH` requests. /// /// See [`OnMethod::get`] for an example. - pub fn patch(self, handler: H2) -> OnMethod + pub fn patch(self, handler: H2) -> OnMethod where - H2: Handler, + H2: Handler, { self.on(MethodFilter::PATCH, handler) } @@ -568,9 +550,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `POST` requests. /// /// See [`OnMethod::get`] for an example. - pub fn post(self, handler: H2) -> OnMethod + pub fn post(self, handler: H2) -> OnMethod where - H2: Handler, + H2: Handler, { self.on(MethodFilter::POST, handler) } @@ -578,9 +560,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `PUT` requests. /// /// See [`OnMethod::get`] for an example. - pub fn put(self, handler: H2) -> OnMethod + pub fn put(self, handler: H2) -> OnMethod where - H2: Handler, + H2: Handler, { self.on(MethodFilter::PUT, handler) } @@ -588,9 +570,9 @@ impl OnMethod { /// Chain an additional handler that will only accept `TRACE` requests. /// /// See [`OnMethod::get`] for an example. - pub fn trace(self, handler: H2) -> OnMethod + pub fn trace(self, handler: H2) -> OnMethod where - H2: Handler, + H2: Handler, { self.on(MethodFilter::TRACE, handler) } @@ -618,9 +600,9 @@ impl OnMethod { /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` - pub fn on(self, method: MethodFilter, handler: H2) -> OnMethod + pub fn on(self, method: MethodFilter, handler: H2) -> OnMethod where - H2: Handler, + H2: Handler, { OnMethod { method, @@ -631,10 +613,12 @@ impl OnMethod { } } -impl Service> for OnMethod +impl Service> for OnMethod where - H: Handler, - F: Service, Response = Response, Error = Infallible> + Clone, + H: Handler, + T: FromRequest + Send, + T::Rejection: Send, + F: Service, Response = Response, Error = Infallible> + Clone + Send, B: Send + 'static, { type Response = Response; @@ -649,7 +633,15 @@ where let req_method = req.method().clone(); let fut = if self.method.matches(req.method()) { - let fut = Handler::call(self.handler.clone(), req); + let handler = self.handler.clone(); + let fut = Box::pin(async move { + let mut req = RequestParts::new(req); + let input = T::from_request(&mut req).await; + match input { + Ok(input) => Handler::call(handler, input).await, + Err(rejection) => rejection.into_response().map(box_body), + } + }) as _; Either::A { inner: fut } } else { let fut = self.fallback.clone().oneshot(req); diff --git a/src/tests/handle_error.rs b/src/tests/handle_error.rs index cc93fe25..92a1bdb6 100644 --- a/src/tests/handle_error.rs +++ b/src/tests/handle_error.rs @@ -191,9 +191,9 @@ fn service_handle_on_router_still_impls_routing_dsl() { #[test] fn layered() { let app = Router::new() - .route("/echo", get::<_, Body, _>(unit)) + .route("/echo", get(unit)) .layer(timeout()) - .handle_error(handle_error::); + .handle_error::(handle_error::); check_make_svc::<_, _, _, Infallible>(app.into_make_service()); } @@ -201,9 +201,9 @@ fn layered() { #[tokio::test] // async because of `.boxed()` async fn layered_boxed() { let app = Router::new() - .route("/echo", get::<_, Body, _>(unit)) + .route("/echo", get(unit)) .layer(timeout()) - .boxed() + .boxed::() .handle_error(handle_error::); check_make_svc::<_, _, _, Infallible>(app.into_make_service());