diff --git a/Cargo.toml b/Cargo.toml index 308762c2..f8e13fc5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,6 @@ tower = { version = "0.4", features = ["util"] } [dev-dependencies] tokio = { version = "1.6.1", features = ["macros", "rt"] } serde = { version = "1.0", features = ["derive"] } -tower = { version = "0.4", features = ["util", "make"] } +tower = { version = "0.4", features = ["util", "make", "timeout"] } tower-http = { version = "0.1", features = ["trace"] } hyper = { version = "0.14", features = ["full"] } diff --git a/src/lib.rs b/src/lib.rs index 4a92aae9..ea5a2cf5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,7 +41,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tower::{BoxError, Service, ServiceExt}; +use tower::{BoxError, Layer, Service, ServiceExt}; pub use hyper::body::Body; @@ -83,6 +83,14 @@ impl RouteAt { self.add_route(handler_fn, Method::GET) } + pub fn get_service(self, service: S) -> RouteBuilder> + where + S: Service, Response = Response> + Clone, + S::Error: Into, + { + self.add_route_service(service, Method::GET) + } + pub fn post(self, handler_fn: F) -> RouteBuilder, R>> where F: Handler, @@ -90,16 +98,25 @@ impl RouteAt { self.add_route(handler_fn, Method::POST) } + pub fn post_service(self, service: S) -> RouteBuilder> + where + S: Service, Response = Response> + Clone, + S::Error: Into, + { + self.add_route_service(service, Method::POST) + } + fn add_route(self, handler: H, method: Method) -> RouteBuilder, R>> where H: Handler, { + self.add_route_service(HandlerSvc::new(handler), method) + } + + fn add_route_service(self, service: S, method: Method) -> RouteBuilder> { let new_app = App { router: Route { - handler: HandlerSvc { - handler, - _input: PhantomData, - }, + service, route_spec: RouteSpec { method, spec: self.route_spec.clone(), @@ -135,12 +152,28 @@ impl RouteBuilder { self.app.at_bytes(self.route_spec).get(handler_fn) } + pub fn get_service(self, service: S) -> RouteBuilder> + where + S: Service, Response = Response> + Clone, + S::Error: Into, + { + self.app.at_bytes(self.route_spec).get_service(service) + } + pub fn post(self, handler_fn: F) -> RouteBuilder, R>> where F: Handler, { self.app.at_bytes(self.route_spec).post(handler_fn) } + + pub fn post_service(self, service: S) -> RouteBuilder> + where + S: Service, Response = Response> + Clone, + S::Error: Into, + { + self.app.at_bytes(self.route_spec).post_service(service) + } } #[derive(Debug, thiserror::Error)] @@ -160,6 +193,18 @@ pub enum Error { #[error("failed generating the response body")] ResponseBody(#[source] BoxError), + + #[error("handler service returned an error")] + Service(#[source] BoxError), +} + +impl From for Error { + fn from(err: BoxError) -> Self { + match err.downcast::() { + Ok(err) => *err, + Err(err) => Error::Service(err), + } + } } impl From for Error { @@ -170,10 +215,17 @@ impl From for Error { // TODO(david): make this trait sealed #[async_trait] -pub trait Handler { +pub trait Handler: Sized { type ResponseBody; async fn call(self, req: Request) -> Result, Error>; + + fn layer(self, layer: L) -> Layered + where + L: Layer>, + { + Layered::new(layer.layer(HandlerSvc::new(self))) + } } #[async_trait] @@ -237,11 +289,60 @@ macro_rules! impl_handler { impl_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); +pub struct Layered { + svc: S, + _input: PhantomData T>, +} + +impl Clone for Layered +where + S: Clone, +{ + fn clone(&self) -> Self { + Self::new(self.svc.clone()) + } +} + +#[async_trait] +impl Handler for Layered +where + S: Service, Response = Response> + Send, + S::Error: Into, + S::Future: Send, +{ + type ResponseBody = B; + + async fn call(self, req: Request) -> Result, Error> { + self.svc + .oneshot(req) + .await + .map_err(|err| Error::from(err.into())) + } +} + +impl Layered { + fn new(svc: S) -> Self { + Self { + svc, + _input: PhantomData, + } + } +} + pub struct HandlerSvc { handler: H, _input: PhantomData T>, } +impl HandlerSvc { + fn new(handler: H) -> Self { + Self { + handler, + _input: PhantomData, + } + } +} + impl Clone for HandlerSvc where H: Clone, @@ -382,7 +483,7 @@ impl Service for EmptyRouter { } pub struct Route { - handler: H, + service: H, route_spec: RouteSpec, fallback: F, handler_ready: bool, @@ -396,7 +497,7 @@ where { fn clone(&self) -> Self { Self { - handler: self.handler.clone(), + service: self.service.clone(), fallback: self.fallback.clone(), route_spec: self.route_spec.clone(), // important to reset readiness when cloning @@ -438,7 +539,7 @@ where fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { loop { if !self.handler_ready { - ready!(self.handler.poll_ready(cx)).map_err(Into::into)?; + ready!(self.service.poll_ready(cx)).map_err(Into::into)?; self.handler_ready = true; } @@ -460,7 +561,7 @@ where "handler not ready. Did you forget to call `poll_ready`?" ); self.handler_ready = false; - future::Either::Left(BoxResponseBody(self.handler.call(req))) + future::Either::Left(BoxResponseBody(self.service.call(req))) } else { assert!( self.fallback_ready, @@ -542,8 +643,11 @@ mod tests { #![allow(warnings)] use super::*; use hyper::Server; + use std::time::Duration; use std::{fmt, net::SocketAddr}; - use tower::{make::Shared, ServiceBuilder}; + use tower::{ + layer::util::Identity, make::Shared, service_fn, timeout::TimeoutLayer, ServiceBuilder, + }; use tower_http::trace::TraceLayer; #[tokio::test] @@ -559,11 +663,13 @@ mod tests { username: String, } + async fn root(_: Request) -> Result, Error> { + Ok(Response::new(Body::from("Hello, World!"))) + } + let mut app = app() .at("/") - .get(|_: Request| async { - Ok::<_, Error>(Response::new(Body::from("Hello, World!"))) - }) + .get(root.layer(TimeoutLayer::new(Duration::from_secs(30)))) .at("/users") .get(|_: Request, pagination: Query| async { let pagination = pagination.into_inner(); @@ -577,7 +683,9 @@ mod tests { assert_eq!(payload.username, "bob"); Ok::<_, Error>(Response::new(Body::from("users#create"))) - }); + }) + .at("/service") + .get_service(service_fn(root)); let res = app .ready()