From f690e742758877525b5db76bfab6608a6a31f287 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 1 Jun 2021 11:23:56 +0200 Subject: [PATCH] Support nesting services with error handling --- Cargo.toml | 11 ++- examples/hello_world.rs | 23 +++-- examples/key_value_store.rs | 11 +-- src/body.rs | 64 ++++++-------- src/lib.rs | 164 +++++++++++++++++++++++++++++++----- src/response.rs | 2 + src/routing.rs | 23 ++--- src/tests.rs | 86 ++++++++++++++++++- 8 files changed, 288 insertions(+), 96 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a1376f2a..d417747c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,15 @@ reqwest = { version = "0.11", features = ["json", "stream"] } serde = { version = "1.0", features = ["derive"] } tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread"] } tower = { version = "0.4", features = ["util", "make", "timeout"] } -tower-http = { version = "0.1", features = ["trace", "compression", "add-extension"] } tracing = "0.1" tracing-subscriber = "0.2" + +[dev-dependencies.tower-http] +version = "0.1" +features = [ + "trace", + "compression", + "compression-full", + "add-extension", + "fs", +] diff --git a/examples/hello_world.rs b/examples/hello_world.rs index b6b57b19..848ae949 100644 --- a/examples/hello_world.rs +++ b/examples/hello_world.rs @@ -1,9 +1,8 @@ -use http::Request; +use http::{Request, StatusCode}; use hyper::Server; use std::net::SocketAddr; -use tower::{make::Shared, ServiceBuilder}; -use tower_http::trace::TraceLayer; -use tower_web::{body::Body, response::Html}; +use tower::make::Shared; +use tower_web::{body::Body, extract, response::Html}; #[tokio::main] async fn main() { @@ -13,14 +12,11 @@ async fn main() { let app = tower_web::app() .at("/") .get(handler) + .at("/greet/:name") + .get(greet) // convert it into a `Service` .into_service(); - // add some middleware - let app = ServiceBuilder::new() - .layer(TraceLayer::new_for_http()) - .service(app); - // run it with hyper let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); @@ -31,3 +27,12 @@ async fn main() { async fn handler(_req: Request) -> Html<&'static str> { Html("

Hello, World!

") } + +async fn greet(_req: Request, params: extract::UrlParamsMap) -> Result { + if let Some(name) = params.get("name") { + Ok(format!("Hello {}!", name)) + } else { + // if the route matches "name" will be present + Err(StatusCode::INTERNAL_SERVER_ERROR) + } +} diff --git a/examples/key_value_store.rs b/examples/key_value_store.rs index 9a2edd36..98c9c51d 100644 --- a/examples/key_value_store.rs +++ b/examples/key_value_store.rs @@ -11,7 +11,7 @@ use tower::{make::Shared, ServiceBuilder}; use tower_http::{ add_extension::AddExtensionLayer, compression::CompressionLayer, trace::TraceLayer, }; -use tower_web::{body::Body, extract}; +use tower_web::{body::Body, extract, handler::Handler}; #[tokio::main] async fn main() { @@ -20,7 +20,7 @@ async fn main() { // build our application with some routes let app = tower_web::app() .at("/:key") - .get(get) + .get(get.layer(CompressionLayer::new())) .post(set) // convert it into a `Service` .into_service(); @@ -29,7 +29,6 @@ async fn main() { let app = ServiceBuilder::new() .timeout(Duration::from_secs(10)) .layer(TraceLayer::new_for_http()) - .layer(CompressionLayer::new()) .layer(AddExtensionLayer::new(SharedState::default())) .service(app); @@ -51,10 +50,6 @@ async fn get( _req: Request, params: extract::UrlParams<(String,)>, state: extract::Extension, - // Anything that implements `IntoResponse` can be used a response - // - // Handlers cannot return errors. Everything will be converted - // into a response. `BoxError` becomes `500 Internal server error` ) -> Result { let state = state.into_inner(); let db = &state.lock().unwrap().db; @@ -73,8 +68,6 @@ async fn set( params: extract::UrlParams<(String,)>, value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb state: extract::Extension, - // `()` also implements `IntoResponse` so we can use that to return - // an empty response ) { let state = state.into_inner(); let db = &mut state.lock().unwrap().db; diff --git a/src/body.rs b/src/body.rs index f1a30085..b03f7c1c 100644 --- a/src/body.rs +++ b/src/body.rs @@ -1,78 +1,64 @@ -use bytes::Buf; -use futures_util::ready; -use http_body::{Body as _, Empty}; +use bytes::Bytes; +use http_body::{Empty, Full}; use std::{ fmt, pin::Pin, task::{Context, Poll}, }; +use tower::BoxError; pub use hyper::body::Body; use crate::BoxStdError; /// A boxed [`Body`] trait object. -pub struct BoxBody { - inner: Pin + Send + Sync + 'static>>, +pub struct BoxBody { + // when we've gotten rid of `BoxStdError` we should be able to change the error type to + // `BoxError` + inner: Pin + Send + Sync + 'static>>, } -impl BoxBody { +impl BoxBody { /// Create a new `BoxBody`. pub fn new(body: B) -> Self where - B: http_body::Body + Send + Sync + 'static, - D: Buf, + B: http_body::Body + Send + Sync + 'static, + B::Error: Into + Send + Sync + 'static, { Self { - inner: Box::pin(body), + inner: Box::pin(body.map_err(|error| BoxStdError(error.into()))), } } } -// TODO: upstream this to http-body? -impl Default for BoxBody -where - D: bytes::Buf + 'static, -{ +impl Default for BoxBody { fn default() -> Self { - BoxBody::new(Empty::::new().map_err(|err| match err {})) + BoxBody::new(Empty::::new()) } } -impl fmt::Debug for BoxBody { +impl fmt::Debug for BoxBody { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("BoxBody").finish() } } -// when we've gotten rid of `BoxStdError` then we can remove this -impl http_body::Body for BoxBody -where - D: Buf, - E: Into, -{ - type Data = D; +impl http_body::Body for BoxBody { + type Data = Bytes; type Error = BoxStdError; fn poll_data( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - match ready!(self.inner.as_mut().poll_data(cx)) { - Some(Ok(chunk)) => Some(Ok(chunk)).into(), - Some(Err(err)) => Some(Err(BoxStdError(err.into()))).into(), - None => None.into(), - } + self.inner.as_mut().poll_data(cx) } fn poll_trailers( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>> { - match ready!(self.inner.as_mut().poll_trailers(cx)) { - Ok(trailers) => Ok(trailers).into(), - Err(err) => Err(BoxStdError(err.into())).into(), - } + self.inner.as_mut().poll_trailers(cx) } fn is_end_stream(&self) -> bool { @@ -84,13 +70,11 @@ where } } -impl From for BoxBody { - fn from(s: String) -> Self { - let body = hyper::Body::from(s); - let body = body.map_err(Into::::into); - - BoxBody { - inner: Box::pin(body), - } +impl From for BoxBody +where + B: Into, +{ + fn from(s: B) -> Self { + BoxBody::new(Full::from(s.into())) } } diff --git a/src/lib.rs b/src/lib.rs index 3fb265ea..5658d43f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,17 +2,19 @@ use self::{ body::Body, routing::{AlwaysNotFound, RouteAt}, }; +use body::BoxBody; use bytes::Bytes; use futures_util::ready; -use http::Response; +use http::{Request, Response}; use pin_project::pin_project; use std::{ convert::Infallible, + fmt, future::Future, pin::Pin, task::{Context, Poll}, }; -use tower::Service; +use tower::{BoxError, Service}; pub mod body; pub mod extract; @@ -69,7 +71,7 @@ where { type Response = Response; type Error = Infallible; - type Future = HandleErrorFuture; + type Future = R::Future; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { match ready!(self.app.service_tree.poll_ready(cx)) { @@ -79,22 +81,7 @@ where } fn call(&mut self, req: T) -> Self::Future { - HandleErrorFuture(self.app.service_tree.call(req)) - } -} - -#[pin_project] -pub struct HandleErrorFuture(#[pin] F); - -impl Future for HandleErrorFuture -where - F: Future, Infallible>>, - B: Default, -{ - type Output = Result, Infallible>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.project().0.poll(cx) + self.app.service_tree.call(req) } } @@ -121,4 +108,141 @@ impl ResultExt for Result { // totally fix it at some point. #[derive(Debug, thiserror::Error)] #[error("{0}")] -pub struct BoxStdError(#[source] tower::BoxError); +pub struct BoxStdError(#[source] pub(crate) tower::BoxError); + +pub trait ServiceExt: Service, Response = Response> { + fn handle_error(self, f: F) -> HandleError + where + Self: Sized, + F: FnOnce(Self::Error) -> Response, + B: http_body::Body + Send + Sync + 'static, + B::Error: Into + Send + Sync + 'static, + NewBody: http_body::Body + Send + Sync + 'static, + NewBody::Error: Into + Send + Sync + 'static, + { + HandleError { + inner: self, + f, + poll_ready_error: None, + } + } +} + +impl ServiceExt for S where S: Service, Response = Response> {} + +pub struct HandleError { + inner: S, + f: F, + poll_ready_error: Option, +} + +impl fmt::Debug for HandleError +where + S: fmt::Debug, + E: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HandleError") + .field("inner", &self.inner) + .field("f", &format_args!("{}", std::any::type_name::())) + .field("poll_ready_error", &self.poll_ready_error) + .finish() + } +} + +impl Clone for HandleError +where + S: Clone, + F: Clone, +{ + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + f: self.f.clone(), + poll_ready_error: None, + } + } +} + +impl Service> for HandleError +where + S: Service, Response = Response>, + F: FnOnce(S::Error) -> Response + Clone, + B: http_body::Body + Send + Sync + 'static, + B::Error: Into + Send + Sync + 'static, + NewBody: http_body::Body + Send + Sync + 'static, + NewBody::Error: Into + Send + Sync + 'static, +{ + type Response = Response; + type Error = Infallible; + type Future = HandleErrorFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + match ready!(self.inner.poll_ready(cx)) { + Ok(_) => Poll::Ready(Ok(())), + Err(err) => { + self.poll_ready_error = Some(err); + Poll::Ready(Ok(())) + } + } + } + + fn call(&mut self, req: Request) -> Self::Future { + if let Some(err) = self.poll_ready_error.take() { + return HandleErrorFuture { + f: Some(self.f.clone()), + kind: Kind::Error(Some(err)), + }; + } + + HandleErrorFuture { + f: Some(self.f.clone()), + kind: Kind::Future(self.inner.call(req)), + } + } +} + +#[pin_project] +pub struct HandleErrorFuture { + #[pin] + kind: Kind, + f: Option, +} + +#[pin_project(project = KindProj)] +enum Kind { + Future(#[pin] Fut), + Error(Option), +} + +impl Future for HandleErrorFuture +where + Fut: Future, E>>, + F: FnOnce(E) -> Response, + B: http_body::Body + Send + Sync + 'static, + B::Error: Into + Send + Sync + 'static, + NewBody: http_body::Body + Send + Sync + 'static, + NewBody::Error: Into + Send + Sync + 'static, +{ + type Output = Result, Infallible>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + match this.kind.project() { + KindProj::Future(future) => match ready!(future.poll(cx)) { + Ok(res) => Ok(res.map(BoxBody::new)).into(), + Err(err) => { + let f = this.f.take().unwrap(); + let res = f(err); + Ok(res.map(BoxBody::new)).into() + } + }, + KindProj::Error(err) => { + let f = this.f.take().unwrap(); + let res = f(err.take().unwrap()); + Ok(res.map(BoxBody::new)).into() + } + } + } +} diff --git a/src/response.rs b/src/response.rs index 377a3364..16e9beb5 100644 --- a/src/response.rs +++ b/src/response.rs @@ -8,6 +8,8 @@ use tower::util::Either; pub trait IntoResponse { fn into_response(self) -> Response; + // TODO(david): remove this an return return `Response` instead. That is what this method + // does anyway. fn boxed(self) -> BoxIntoResponse where Self: Sized + 'static, diff --git a/src/routing.rs b/src/routing.rs index f59296ad..68d11427 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -55,8 +55,7 @@ impl RouteAt { pub fn get_service(self, service: S) -> RouteBuilder> where - S: Service, Response = Response> + Clone, - S::Error: Into, + S: Service, Response = Response, Error = Infallible> + Clone, { self.add_route_service(service, Method::GET) } @@ -70,8 +69,7 @@ impl RouteAt { pub fn post_service(self, service: S) -> RouteBuilder> where - S: Service, Response = Response> + Clone, - S::Error: Into, + S: Service, Response = Response, Error = Infallible> + Clone, { self.add_route_service(service, Method::POST) } @@ -141,8 +139,7 @@ impl RouteBuilder { pub fn get_service(self, service: S) -> RouteBuilder> where - S: Service, Response = Response> + Clone, - S::Error: Into, + S: Service, Response = Response, Error = Infallible> + Clone, { self.app.at_bytes(self.route_spec).get_service(service) } @@ -156,8 +153,7 @@ impl RouteBuilder { pub fn post_service(self, service: S) -> RouteBuilder> where - S: Service, Response = Response> + Clone, - S::Error: Into, + S: Service, Response = Response, Error = Infallible> + Clone, { self.app.at_bytes(self.route_spec).post_service(service) } @@ -173,7 +169,6 @@ impl RouteBuilder { where R: Service, Response = Response, Error = Infallible> + Send + 'static, R::Future: Send, - // TODO(david): do we still need default here B: From + 'static, { let svc = ServiceBuilder::new() @@ -274,14 +269,14 @@ impl RouteSpec { impl Service> for Or where H: Service, Response = Response, Error = Infallible>, - HB: http_body::Body + Send + Sync + 'static, + HB: http_body::Body + Send + Sync + 'static, HB::Error: Into, F: Service, Response = Response, Error = Infallible>, - FB: http_body::Body + Send + Sync + 'static, + FB: http_body::Body + Send + Sync + 'static, FB::Error: Into, { - type Response = Response>; + type Response = Response; type Error = Infallible; type Future = future::Either, BoxResponseBody>; @@ -337,10 +332,10 @@ pub struct BoxResponseBody(#[pin] F); impl Future for BoxResponseBody where F: Future, Infallible>>, - B: http_body::Body + Send + Sync + 'static, + B: http_body::Body + Send + Sync + 'static, B::Error: Into, { - type Output = Result>, Infallible>; + type Output = Result, Infallible>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let response: Response = ready!(self.project().0.poll(cx)).unwrap_infallible(); diff --git a/src/tests.rs b/src/tests.rs index fa7e1bfe..92977bd1 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,4 +1,4 @@ -use crate::{app, extract}; +use crate::{app, extract, handler::Handler}; use http::{Request, Response, StatusCode}; use hyper::{Body, Server}; use serde::Deserialize; @@ -273,9 +273,89 @@ async fn boxing() { assert_eq!(res.text().await.unwrap(), "hi from POST"); } -// TODO(david): tests for adding middleware to single services +#[tokio::test] +async fn service_handlers() { + use crate::{body::BoxBody, ServiceExt as _}; + use std::convert::Infallible; + use tower::service_fn; + use tower_http::services::ServeFile; -// TODO(david): tests for nesting services + let app = app() + .at("/echo") + .post_service(service_fn(|req: Request| async move { + Ok::<_, Infallible>(Response::new(req.into_body())) + })) + // calling boxed isn't necessary here but done so + // we're sure it compiles + .boxed() + .at("/static/Cargo.toml") + .get_service( + ServeFile::new("Cargo.toml").handle_error(|error: std::io::Error| { + // `ServeFile` internally maps some errors to `404` so we don't have + // to handle those here + let body = BoxBody::from(error.to_string()); + Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(body) + .unwrap() + }), + ) + // calling boxed isn't necessary here but done so + // we're sure it compiles + .boxed() + .into_service(); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .post(format!("http://{}/echo", addr)) + .body("foobar") + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), "foobar"); + + let res = client + .get(format!("http://{}/static/Cargo.toml", addr)) + .body("foobar") + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert!(res.text().await.unwrap().contains("edition =")); +} + +#[tokio::test] +async fn middleware_on_single_route() { + use tower::ServiceBuilder; + use tower_http::{compression::CompressionLayer, trace::TraceLayer}; + + async fn handle(_: Request) -> &'static str { + "Hello, World!" + } + + let app = app() + .at("/") + .get( + handle.layer( + ServiceBuilder::new() + .layer(TraceLayer::new_for_http()) + .layer(CompressionLayer::new()) + .into_inner(), + ), + ) + .into_service(); + + let addr = run_in_background(app).await; + + let res = reqwest::get(format!("http://{}", addr)).await.unwrap(); + let body = res.text().await.unwrap(); + + assert_eq!(body, "Hello, World!"); +} /// Run a `tower::Service` in the background and get a URI for it. pub async fn run_in_background(svc: S) -> SocketAddr