A bit of clean up

This commit is contained in:
David Pedersen 2021-05-30 00:52:04 +02:00
parent 07294378b3
commit 33f2e5f661

View File

@ -29,7 +29,7 @@ Tests
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
use futures_util::future; use futures_util::{future, ready};
use http::{Method, Request, Response, StatusCode}; use http::{Method, Request, Response, StatusCode};
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{ use std::{
@ -47,26 +47,31 @@ pub fn app() -> App<EmptyRouter> {
} }
} }
#[derive(Clone)] #[derive(Debug, Clone)]
pub struct App<R> { pub struct App<R> {
router: R, router: R,
} }
impl<R> App<R> { impl<R> App<R> {
pub fn at(self, route_spec: &str) -> RouteBuilder<R> { pub fn at(self, route_spec: &str) -> RouteAt<R> {
RouteBuilder { self.at_bytes(Bytes::copy_from_slice(route_spec.as_bytes()))
}
fn at_bytes(self, route_spec: Bytes) -> RouteAt<R> {
RouteAt {
app: self, app: self,
route_spec: Bytes::copy_from_slice(route_spec.as_bytes()), route_spec,
} }
} }
} }
pub struct RouteBuilder<R> { #[derive(Debug, Clone)]
pub struct RouteAt<R> {
app: App<R>, app: App<R>,
route_spec: Bytes, route_spec: Bytes,
} }
impl<R> RouteBuilder<R> { impl<R> RouteAt<R> {
pub fn get<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>> pub fn get<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
where where
F: Handler<T>, F: Handler<T>,
@ -81,14 +86,6 @@ impl<R> RouteBuilder<R> {
self.add_route(handler_fn, Method::POST) self.add_route(handler_fn, Method::POST)
} }
pub fn at(self, route_spec: &str) -> Self {
self.app.at(route_spec)
}
pub fn into_service(self) -> App<R> {
self.app
}
fn add_route<H, T>(self, handler: H, method: Method) -> RouteBuilder<Route<HandlerSvc<H, T>, R>> fn add_route<H, T>(self, handler: H, method: Method) -> RouteBuilder<Route<HandlerSvc<H, T>, R>>
where where
H: Handler<T>, H: Handler<T>,
@ -104,6 +101,8 @@ impl<R> RouteBuilder<R> {
spec: self.route_spec.clone(), spec: self.route_spec.clone(),
}, },
fallback: self.app.router, fallback: self.app.router,
handler_ready: false,
fallback_ready: false,
}, },
}; };
@ -114,9 +113,47 @@ impl<R> RouteBuilder<R> {
} }
} }
#[derive(Clone)]
pub struct RouteBuilder<R> {
app: App<R>,
route_spec: Bytes,
}
impl<R> RouteBuilder<R> {
pub fn at(self, route_spec: &str) -> RouteAt<R> {
self.app.at(route_spec)
}
pub fn get<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
where
F: Handler<T>,
{
self.app.at_bytes(self.route_spec).get(handler_fn)
}
pub fn post<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
where
F: Handler<T>,
{
self.app.at_bytes(self.route_spec).post(handler_fn)
}
}
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
#[non_exhaustive] #[non_exhaustive]
pub enum Error {} pub enum Error {
#[error("failed to deserialize the request body")]
DeserializeRequestBody(#[source] serde_json::Error),
#[error("failed to consume the body")]
ConsumeBody(#[source] hyper::Error),
#[error("URI contained no query string")]
QueryStringMissing,
#[error("failed to deserialize query string")]
DeserializeQueryString(#[from] serde_urlencoded::de::Error),
}
#[async_trait] #[async_trait]
pub trait Handler<Out> { pub trait Handler<Out> {
@ -136,38 +173,50 @@ where
} }
} }
macro_rules! impl_handler {
( $head:ident $(,)? ) => {
#[async_trait] #[async_trait]
#[allow(non_snake_case)] #[allow(non_snake_case)]
impl<F, Fut, T1> Handler<(T1,)> for F impl<F, Fut, $head> Handler<($head,)> for F
where where
F: Fn(Request<Body>, T1) -> Fut + Send + Sync, F: Fn(Request<Body>, $head) -> Fut + Send + Sync,
Fut: Future<Output = Result<Response<Body>, Error>> + Send, Fut: Future<Output = Result<Response<Body>, Error>> + Send,
T1: FromRequest + Send, $head: FromRequest + Send,
{ {
async fn call(self, mut req: Request<Body>) -> Result<Response<Body>, Error> { async fn call(self, mut req: Request<Body>) -> Result<Response<Body>, Error> {
let T1 = T1::from_request(&mut req).await; let $head = $head::from_request(&mut req).await?;
let res = self(req, T1).await?; let res = self(req, $head).await?;
Ok(res)
}
}
};
( $head:ident, $($tail:ident),* $(,)? ) => {
#[async_trait]
#[allow(non_snake_case)]
impl<F, Fut, $head, $($tail,)*> Handler<($head, $($tail,)*)> for F
where
F: Fn(Request<Body>, $head, $($tail,)*) -> Fut + Send + Sync,
Fut: Future<Output = Result<Response<Body>, Error>> + Send,
$head: FromRequest + Send,
$( $tail: FromRequest + Send, )*
{
async fn call(self, mut req: Request<Body>) -> Result<Response<Body>, Error> {
let $head = $head::from_request(&mut req).await?;
$(
let $tail = $tail::from_request(&mut req).await?;
)*
let res = self(req, $head, $($tail,)*).await?;
Ok(res) Ok(res)
} }
} }
#[async_trait] impl_handler!($($tail,)*);
#[allow(non_snake_case)] };
impl<F, Fut, T1, T2> Handler<(T1, T2)> for F
where
F: Fn(Request<Body>, T1, T2) -> Fut + Send + Sync,
Fut: Future<Output = Result<Response<Body>, Error>> + Send,
T1: FromRequest + Send,
T2: FromRequest + Send,
{
async fn call(self, mut req: Request<Body>) -> Result<Response<Body>, Error> {
let T1 = T1::from_request(&mut req).await;
let T2 = T2::from_request(&mut req).await;
let res = self(req, T1, T2).await?;
Ok(res)
}
} }
impl_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
pub struct HandlerSvc<H, T> { pub struct HandlerSvc<H, T> {
handler: H, handler: H,
_input: PhantomData<fn() -> T>, _input: PhantomData<fn() -> T>,
@ -206,83 +255,71 @@ where
#[async_trait] #[async_trait]
pub trait FromRequest: Sized { pub trait FromRequest: Sized {
async fn from_request(req: &mut Request<Body>) -> Self; async fn from_request(req: &mut Request<Body>) -> Result<Self, Error>;
} }
pub struct Query<T>(Result<T, QueryError>); #[async_trait]
impl<T> FromRequest for Option<T>
where
T: FromRequest,
{
async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> {
Ok(T::from_request(req).await.ok())
}
}
#[derive(Debug, Clone, Copy)]
pub struct Query<T>(T);
impl<T> Query<T> { impl<T> Query<T> {
pub fn into_inner(self) -> Result<T, QueryError> { pub fn into_inner(self) -> T {
self.0 self.0
} }
} }
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum QueryError {
#[error("URI contained no query string")]
Missing,
#[error("failed to deserialize query string")]
Deserialize(#[from] serde_urlencoded::de::Error),
}
#[async_trait] #[async_trait]
impl<T> FromRequest for Query<T> impl<T> FromRequest for Query<T>
where where
T: DeserializeOwned, T: DeserializeOwned,
{ {
async fn from_request(req: &mut Request<Body>) -> Self { async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> {
let result = (|| { let query = req.uri().query().ok_or(Error::QueryStringMissing)?;
let query = req.uri().query().ok_or(QueryError::Missing)?;
let value = serde_urlencoded::from_str(query)?; let value = serde_urlencoded::from_str(query)?;
Ok(value) Ok(Query(value))
})();
Query(result)
} }
} }
pub struct Json<T>(Result<T, JsonError>); #[derive(Debug, Clone, Copy)]
pub struct Json<T>(T);
impl<T> Json<T> { impl<T> Json<T> {
pub fn into_inner(self) -> Result<T, JsonError> { pub fn into_inner(self) -> T {
self.0 self.0
} }
} }
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum JsonError {
#[error("failed to consume the body")]
ConsumeBody(#[from] hyper::Error),
#[error("failed to deserialize the body")]
Deserialize(#[from] serde_json::Error),
}
#[async_trait] #[async_trait]
impl<T> FromRequest for Json<T> impl<T> FromRequest for Json<T>
where where
T: DeserializeOwned, T: DeserializeOwned,
{ {
async fn from_request(req: &mut Request<Body>) -> Self { async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> {
// TODO(david): require the body to have `content-type: application/json` // TODO(david): require the body to have `content-type: application/json`
let body = std::mem::take(req.body_mut()); let body = std::mem::take(req.body_mut());
let result = async move { let bytes = hyper::body::to_bytes(body)
let bytes = hyper::body::to_bytes(body).await?; .await
let value = serde_json::from_slice(&bytes)?; .map_err(Error::ConsumeBody)?;
Ok(value) let value = serde_json::from_slice(&bytes).map_err(Error::DeserializeRequestBody)?;
} Ok(Json(value))
.await;
Json(result)
} }
} }
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
pub struct EmptyRouter(()); pub struct EmptyRouter(());
impl Service<Request<Body>> for EmptyRouter { impl<R> Service<R> for EmptyRouter {
type Response = Response<Body>; type Response = Response<Body>;
type Error = Error; type Error = Error;
type Future = future::Ready<Result<Self::Response, Self::Error>>; type Future = future::Ready<Result<Self::Response, Self::Error>>;
@ -291,7 +328,7 @@ impl Service<Request<Body>> for EmptyRouter {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, _req: Request<Body>) -> Self::Future { fn call(&mut self, _req: R) -> Self::Future {
let mut res = Response::new(Body::empty()); let mut res = Response::new(Body::empty());
*res.status_mut() = StatusCode::NOT_FOUND; *res.status_mut() = StatusCode::NOT_FOUND;
future::ready(Ok(res)) future::ready(Ok(res))
@ -303,6 +340,8 @@ pub struct Route<H, F> {
handler: H, handler: H,
route_spec: RouteSpec, route_spec: RouteSpec,
fallback: F, fallback: F,
handler_ready: bool,
fallback_ready: bool,
} }
#[derive(Clone)] #[derive(Clone)]
@ -320,51 +359,76 @@ impl RouteSpec {
impl<H, F> Service<Request<Body>> for Route<H, F> impl<H, F> Service<Request<Body>> for Route<H, F>
where where
H: Service<Request<Body>, Response = Response<Body>, Error = Error> + Clone + Send + 'static, H: Service<Request<Body>, Response = Response<Body>, Error = Error>,
H::Future: Send, F: Service<Request<Body>, Response = Response<Body>, Error = Error>,
F: Service<Request<Body>, Response = Response<Body>, Error = Error> + Clone + Send + 'static,
F::Future: Send,
{ {
type Response = Response<Body>; type Response = Response<Body>;
type Error = Error; type Error = Error;
type Future = future::BoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = future::Either<H::Future, F::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if !self.handler_ready {
ready!(self.handler.poll_ready(cx))?;
self.handler_ready = true;
}
if !self.fallback_ready {
ready!(self.fallback.poll_ready(cx))?;
self.fallback_ready = true;
}
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// TODO(david): do we need to drive readiness in `call`?
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, req: Request<Body>) -> Self::Future { fn call(&mut self, req: Request<Body>) -> Self::Future {
if self.route_spec.matches(&req) { if self.route_spec.matches(&req) {
let handler_clone = self.handler.clone(); self.handler_ready = false;
let mut handler = std::mem::replace(&mut self.handler, handler_clone); future::Either::Left(self.handler.call(req))
Box::pin(async move { handler.ready().await?.call(req).await })
} else { } else {
let fallback_clone = self.fallback.clone(); self.fallback_ready = false;
let mut fallback = std::mem::replace(&mut self.fallback, fallback_clone); future::Either::Right(self.fallback.call(req))
Box::pin(async move { fallback.ready().await?.call(req).await })
} }
} }
} }
impl<R> Service<Request<Body>> for App<R> impl<R, T> Service<T> for App<R>
where where
R: Service<Request<Body>, Response = Response<Body>, Error = Error> + Clone, R: Service<T>,
{ {
type Response = Response<Body>; type Response = R::Response;
type Error = Error; type Error = R::Error;
type Future = R::Future; type Future = R::Future;
// TODO(david): handle backpressure #[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.router.poll_ready(cx) self.router.poll_ready(cx)
} }
fn call(&mut self, req: Request<Body>) -> Self::Future { #[inline]
fn call(&mut self, req: T) -> Self::Future {
self.router.call(req) self.router.call(req)
} }
} }
impl<R, T> Service<T> for RouteBuilder<R>
where
App<R>: Service<T>,
{
type Response = <App<R> as Service<T>>::Response;
type Error = <App<R> as Service<T>>::Error;
type Future = <App<R> as Service<T>>::Future;
#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.app.poll_ready(cx)
}
#[inline]
fn call(&mut self, req: T) -> Self::Future {
self.app.call(req)
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
#![allow(warnings)] #![allow(warnings)]
@ -377,8 +441,7 @@ mod tests {
.get(root) .get(root)
.at("/users") .at("/users")
.get(users_index) .get(users_index)
.post(users_create) .post(users_create);
.into_service();
let req = Request::builder() let req = Request::builder()
.method(Method::POST) .method(Method::POST)