More flexible error types

This commit is contained in:
David Pedersen 2021-05-30 02:29:41 +02:00
parent a04c98dd42
commit 433128102b

View File

@ -35,6 +35,7 @@ use http_body::{combinators::BoxBody, Body as _};
use pin_project::pin_project; use pin_project::pin_project;
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{ use std::{
convert::Infallible,
future::Future, future::Future,
marker::PhantomData, marker::PhantomData,
pin::Pin, pin::Pin,
@ -161,6 +162,12 @@ pub enum Error {
ResponseBody(#[source] BoxError), ResponseBody(#[source] BoxError),
} }
impl From<Infallible> for Error {
fn from(err: Infallible) -> Self {
match err {}
}
}
// TODO(david): make this trait sealed // TODO(david): make this trait sealed
#[async_trait] #[async_trait]
pub trait Handler<Out> { pub trait Handler<Out> {
@ -360,7 +367,7 @@ pub struct EmptyRouter(());
impl<R> Service<R> for EmptyRouter { impl<R> Service<R> for EmptyRouter {
type Response = Response<Body>; type Response = Response<Body>;
type Error = Error; type Error = Infallible;
type Future = future::Ready<Result<Self::Response, Self::Error>>; type Future = future::Ready<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -414,27 +421,29 @@ impl RouteSpec {
impl<H, F, HB, FB> Service<Request<Body>> for Route<H, F> impl<H, F, HB, FB> Service<Request<Body>> for Route<H, F>
where where
H: Service<Request<Body>, Response = Response<HB>, Error = Error>, H: Service<Request<Body>, Response = Response<HB>>,
F: Service<Request<Body>, Response = Response<FB>, Error = Error>, H::Error: Into<Error>,
HB: http_body::Body + Send + Sync + 'static, HB: http_body::Body + Send + Sync + 'static,
HB::Error: Into<BoxError>, HB::Error: Into<BoxError>,
F: Service<Request<Body>, Response = Response<FB>>,
F::Error: Into<Error>,
FB: http_body::Body<Data = HB::Data> + Send + Sync + 'static, FB: http_body::Body<Data = HB::Data> + Send + Sync + 'static,
FB::Error: Into<BoxError>, FB::Error: Into<BoxError>,
{ {
type Response = Response<BoxBody<HB::Data, Error>>; type Response = Response<BoxBody<HB::Data, Error>>;
type Error = Error; type Error = Error;
// type Future = future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
type Future = future::Either<BoxResponseBody<H::Future>, BoxResponseBody<F::Future>>; type Future = future::Either<BoxResponseBody<H::Future>, BoxResponseBody<F::Future>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
loop { loop {
if !self.handler_ready { if !self.handler_ready {
ready!(self.handler.poll_ready(cx))?; ready!(self.handler.poll_ready(cx)).map_err(Into::into)?;
self.handler_ready = true; self.handler_ready = true;
} }
if !self.fallback_ready { if !self.fallback_ready {
ready!(self.fallback.poll_ready(cx))?; ready!(self.fallback.poll_ready(cx)).map_err(Into::into)?;
self.fallback_ready = true; self.fallback_ready = true;
} }
@ -467,16 +476,17 @@ where
#[pin_project] #[pin_project]
pub struct BoxResponseBody<F>(#[pin] F); pub struct BoxResponseBody<F>(#[pin] F);
impl<F, B> Future for BoxResponseBody<F> impl<F, B, E> Future for BoxResponseBody<F>
where where
F: Future<Output = Result<Response<B>, Error>>, F: Future<Output = Result<Response<B>, E>>,
E: Into<Error>,
B: http_body::Body + Send + Sync + 'static, B: http_body::Body + Send + Sync + 'static,
B::Error: Into<BoxError>, B::Error: Into<BoxError>,
{ {
type Output = Result<Response<BoxBody<B::Data, Error>>, Error>; type Output = Result<Response<BoxBody<B::Data, Error>>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let response: Response<B> = ready!(self.project().0.poll(cx))?; let response: Response<B> = ready!(self.project().0.poll(cx)).map_err(Into::into)?;
let response = let response =
response.map(|body| body.map_err(|err| Error::ResponseBody(err.into())).boxed()); response.map(|body| body.map_err(|err| Error::ResponseBody(err.into())).boxed());
Poll::Ready(Ok(response)) Poll::Ready(Ok(response))
@ -486,6 +496,7 @@ where
impl<R, T> Service<T> for App<R> impl<R, T> Service<T> for App<R>
where where
R: Service<T>, R: Service<T>,
R::Error: Into<Error>,
{ {
type Response = R::Response; type Response = R::Response;
type Error = R::Error; type Error = R::Error;
@ -493,11 +504,13 @@ where
#[inline] #[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// TODO(david): map error to response
self.router.poll_ready(cx) self.router.poll_ready(cx)
} }
#[inline] #[inline]
fn call(&mut self, req: T) -> Self::Future { fn call(&mut self, req: T) -> Self::Future {
// TODO(david): map error to response
self.router.call(req) self.router.call(req)
} }
} }
@ -505,6 +518,7 @@ where
impl<R, T> Service<T> for RouteBuilder<R> impl<R, T> Service<T> for RouteBuilder<R>
where where
App<R>: Service<T>, App<R>: Service<T>,
<App<R> as Service<T>>::Error: Into<Error>,
{ {
type Response = <App<R> as Service<T>>::Response; type Response = <App<R> as Service<T>>::Response;
type Error = <App<R> as Service<T>>::Error; type Error = <App<R> as Service<T>>::Error;
@ -512,11 +526,13 @@ where
#[inline] #[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// TODO(david): map error to response
self.app.poll_ready(cx) self.app.poll_ready(cx)
} }
#[inline] #[inline]
fn call(&mut self, req: T) -> Self::Future { fn call(&mut self, req: T) -> Self::Future {
// TODO(david): map error to response
self.app.call(req) self.app.call(req)
} }
} }
@ -532,41 +548,84 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn basic() { async fn basic() {
#[derive(Debug, Deserialize)]
struct Pagination {
page: usize,
per_page: usize,
}
#[derive(Debug, Deserialize)]
struct UsersCreate {
username: String,
}
let mut app = app() let mut app = app()
.at("/") .at("/")
.get(root) .get(|_: Request<Body>| async {
Ok::<_, Error>(Response::new(Body::from("Hello, World!")))
})
.at("/users") .at("/users")
.get(users_index) .get(|_: Request<Body>, pagination: Query<Pagination>| async {
.post(users_create); let pagination = pagination.into_inner();
assert_eq!(pagination.page, 1);
assert_eq!(pagination.per_page, 30);
let req = Request::builder() Ok::<_, Error>(Response::new(Body::from("users#index")))
})
.post(|_: Request<Body>, payload: Json<UsersCreate>| async {
let payload = payload.into_inner();
assert_eq!(payload.username, "bob");
Ok::<_, Error>(Response::new(Body::from("users#create")))
});
let res = app
.ready()
.await
.unwrap()
.call(
Request::builder()
.method(Method::GET)
.uri("/")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(body_to_string(res).await, "Hello, World!");
let res = app
.ready()
.await
.unwrap()
.call(
Request::builder()
.method(Method::GET)
.uri("/users?page=1&per_page=30")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(body_to_string(res).await, "users#index");
let res = app
.ready()
.await
.unwrap()
.call(
Request::builder()
.method(Method::POST) .method(Method::POST)
.uri("/users") .uri("/users")
.body(Body::from(r#"{ "username": "bob" }"#)) .body(Body::from(r#"{ "username": "bob" }"#))
.unwrap(),
)
.await
.unwrap(); .unwrap();
assert_eq!(res.status(), StatusCode::OK);
let res = app.ready().await.unwrap().call(req).await.unwrap(); assert_eq!(body_to_string(res).await, "users#create");
let body = body_to_string(res).await;
dbg!(&body);
}
#[allow(dead_code)]
// this should just compile
async fn compatible_with_hyper_and_tower_http() {
let app = app()
.at("/")
.get(root)
.at("/users")
.get(users_index)
.post(users_create);
let app = ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.service(app);
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
let server = Server::bind(&addr).serve(Shared::new(app));
server.await.unwrap();
} }
async fn body_to_string<B>(res: Response<B>) -> String async fn body_to_string<B>(res: Response<B>) -> String
@ -578,34 +637,19 @@ mod tests {
String::from_utf8(bytes.to_vec()).unwrap() String::from_utf8(bytes.to_vec()).unwrap()
} }
async fn root(req: Request<Body>) -> Result<Response<Body>, Error> { #[allow(dead_code)]
Ok(Response::new(Body::from("Hello, World!"))) // this should just compile
} async fn compatible_with_hyper_and_tower_http() {
let app = app().at("/").get(|_: Request<Body>| async {
Ok::<_, Error>(Response::new(Body::from("Hello, World!")))
});
async fn users_index( let app = ServiceBuilder::new()
req: Request<Body>, .layer(TraceLayer::new_for_http())
pagination: Query<Pagination>, .service(app);
) -> Result<Response<Body>, Error> {
dbg!(pagination.into_inner());
Ok(Response::new(Body::from("users#index")))
}
#[derive(Debug, Deserialize)] let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
struct Pagination { let server = Server::bind(&addr).serve(Shared::new(app));
page: usize, server.await.unwrap();
per_page: usize,
}
async fn users_create(
req: Request<Body>,
payload: Json<UsersCreate>,
) -> Result<Response<Body>, Error> {
dbg!(payload.into_inner());
Ok(Response::new(Body::from("users#create")))
}
#[derive(Debug, Deserialize)]
struct UsersCreate {
username: String,
} }
} }