From 433128102bbc88f14c7fd12f6f77346df5108e14 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sun, 30 May 2021 02:29:41 +0200 Subject: [PATCH] More flexible error types --- src/lib.rs | 172 +++++++++++++++++++++++++++++++++-------------------- 1 file changed, 108 insertions(+), 64 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 7718ed46..4a92aae9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,6 +35,7 @@ use http_body::{combinators::BoxBody, Body as _}; use pin_project::pin_project; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::{ + convert::Infallible, future::Future, marker::PhantomData, pin::Pin, @@ -161,6 +162,12 @@ pub enum Error { ResponseBody(#[source] BoxError), } +impl From for Error { + fn from(err: Infallible) -> Self { + match err {} + } +} + // TODO(david): make this trait sealed #[async_trait] pub trait Handler { @@ -360,7 +367,7 @@ pub struct EmptyRouter(()); impl Service for EmptyRouter { type Response = Response; - type Error = Error; + type Error = Infallible; type Future = future::Ready>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { @@ -414,27 +421,29 @@ impl RouteSpec { impl Service> for Route where - H: Service, Response = Response, Error = Error>, - F: Service, Response = Response, Error = Error>, + H: Service, Response = Response>, + H::Error: Into, HB: http_body::Body + Send + Sync + 'static, HB::Error: Into, + + F: Service, Response = Response>, + F::Error: Into, FB: http_body::Body + Send + Sync + 'static, FB::Error: Into, { type Response = Response>; type Error = Error; - // type Future = future::BoxFuture<'static, Result>; type Future = future::Either, BoxResponseBody>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { loop { if !self.handler_ready { - ready!(self.handler.poll_ready(cx))?; + ready!(self.handler.poll_ready(cx)).map_err(Into::into)?; self.handler_ready = true; } if !self.fallback_ready { - ready!(self.fallback.poll_ready(cx))?; + ready!(self.fallback.poll_ready(cx)).map_err(Into::into)?; self.fallback_ready = true; } @@ -467,16 +476,17 @@ where #[pin_project] pub struct BoxResponseBody(#[pin] F); -impl Future for BoxResponseBody +impl Future for BoxResponseBody where - F: Future, Error>>, + F: Future, E>>, + E: Into, B: http_body::Body + Send + Sync + 'static, B::Error: Into, { type Output = Result>, Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let response: Response = ready!(self.project().0.poll(cx))?; + let response: Response = ready!(self.project().0.poll(cx)).map_err(Into::into)?; let response = response.map(|body| body.map_err(|err| Error::ResponseBody(err.into())).boxed()); Poll::Ready(Ok(response)) @@ -486,6 +496,7 @@ where impl Service for App where R: Service, + R::Error: Into, { type Response = R::Response; type Error = R::Error; @@ -493,11 +504,13 @@ where #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + // TODO(david): map error to response self.router.poll_ready(cx) } #[inline] fn call(&mut self, req: T) -> Self::Future { + // TODO(david): map error to response self.router.call(req) } } @@ -505,6 +518,7 @@ where impl Service for RouteBuilder where App: Service, + as Service>::Error: Into, { type Response = as Service>::Response; type Error = as Service>::Error; @@ -512,11 +526,13 @@ where #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + // TODO(david): map error to response self.app.poll_ready(cx) } #[inline] fn call(&mut self, req: T) -> Self::Future { + // TODO(david): map error to response self.app.call(req) } } @@ -532,41 +548,84 @@ mod tests { #[tokio::test] async fn basic() { + #[derive(Debug, Deserialize)] + struct Pagination { + page: usize, + per_page: usize, + } + + #[derive(Debug, Deserialize)] + struct UsersCreate { + username: String, + } + let mut app = app() .at("/") - .get(root) + .get(|_: Request| async { + Ok::<_, Error>(Response::new(Body::from("Hello, World!"))) + }) .at("/users") - .get(users_index) - .post(users_create); + .get(|_: Request, pagination: Query| async { + let pagination = pagination.into_inner(); + assert_eq!(pagination.page, 1); + assert_eq!(pagination.per_page, 30); - let req = Request::builder() - .method(Method::POST) - .uri("/users") - .body(Body::from(r#"{ "username": "bob" }"#)) + Ok::<_, Error>(Response::new(Body::from("users#index"))) + }) + .post(|_: Request, payload: Json| async { + let payload = payload.into_inner(); + assert_eq!(payload.username, "bob"); + + Ok::<_, Error>(Response::new(Body::from("users#create"))) + }); + + let res = app + .ready() + .await + .unwrap() + .call( + Request::builder() + .method(Method::GET) + .uri("/") + .body(Body::empty()) + .unwrap(), + ) + .await .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(body_to_string(res).await, "Hello, World!"); - let res = app.ready().await.unwrap().call(req).await.unwrap(); - let body = body_to_string(res).await; - dbg!(&body); - } + let res = app + .ready() + .await + .unwrap() + .call( + Request::builder() + .method(Method::GET) + .uri("/users?page=1&per_page=30") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(body_to_string(res).await, "users#index"); - #[allow(dead_code)] - // this should just compile - async fn compatible_with_hyper_and_tower_http() { - let app = app() - .at("/") - .get(root) - .at("/users") - .get(users_index) - .post(users_create); - - let app = ServiceBuilder::new() - .layer(TraceLayer::new_for_http()) - .service(app); - - let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); - let server = Server::bind(&addr).serve(Shared::new(app)); - server.await.unwrap(); + let res = app + .ready() + .await + .unwrap() + .call( + Request::builder() + .method(Method::POST) + .uri("/users") + .body(Body::from(r#"{ "username": "bob" }"#)) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(body_to_string(res).await, "users#create"); } async fn body_to_string(res: Response) -> String @@ -578,34 +637,19 @@ mod tests { String::from_utf8(bytes.to_vec()).unwrap() } - async fn root(req: Request) -> Result, Error> { - Ok(Response::new(Body::from("Hello, World!"))) - } + #[allow(dead_code)] + // this should just compile + async fn compatible_with_hyper_and_tower_http() { + let app = app().at("/").get(|_: Request| async { + Ok::<_, Error>(Response::new(Body::from("Hello, World!"))) + }); - async fn users_index( - req: Request, - pagination: Query, - ) -> Result, Error> { - dbg!(pagination.into_inner()); - Ok(Response::new(Body::from("users#index"))) - } + let app = ServiceBuilder::new() + .layer(TraceLayer::new_for_http()) + .service(app); - #[derive(Debug, Deserialize)] - struct Pagination { - page: usize, - per_page: usize, - } - - async fn users_create( - req: Request, - payload: Json, - ) -> Result, Error> { - dbg!(payload.into_inner()); - Ok(Response::new(Body::from("users#create"))) - } - - #[derive(Debug, Deserialize)] - struct UsersCreate { - username: String, + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + let server = Server::bind(&addr).serve(Shared::new(app)); + server.await.unwrap(); } }