diff --git a/README.md b/README.md index b00fa977..6c00ca4d 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ This is *not* https://github.com/carllerche/tower-web even though the name is the same. Its just a prototype of a minimal HTTP framework I've been toying -with. +with. Will probably change the name to something else. # What is this? diff --git a/examples/key_value_store.rs b/examples/key_value_store.rs index bbd2a0d4..9a2edd36 100644 --- a/examples/key_value_store.rs +++ b/examples/key_value_store.rs @@ -1,9 +1,6 @@ -#![allow(warnings)] - use bytes::Bytes; -use http::{Request, Response, StatusCode}; +use http::{Request, StatusCode}; use hyper::Server; -use serde::Deserialize; use std::{ collections::HashMap, net::SocketAddr, @@ -14,12 +11,7 @@ use tower::{make::Shared, ServiceBuilder}; use tower_http::{ add_extension::AddExtensionLayer, compression::CompressionLayer, trace::TraceLayer, }; -use tower_web::{ - body::Body, - extract, - response::{self, IntoResponse}, - Error, -}; +use tower_web::{body::Body, extract}; #[tokio::main] async fn main() { @@ -59,7 +51,11 @@ async fn get( _req: Request, params: extract::UrlParams<(String,)>, state: extract::Extension, -) -> Result { + // Anything that implements `IntoResponse` can be used a response + // + // Handlers cannot return errors. Everything will be converted + // into a response. `BoxError` becomes `500 Internal server error` +) -> Result { let state = state.into_inner(); let db = &state.lock().unwrap().db; @@ -68,7 +64,7 @@ async fn get( if let Some(value) = db.get(&key) { Ok(value.clone()) } else { - Err(NotFound) + Err(StatusCode::NOT_FOUND) } } @@ -77,6 +73,8 @@ async fn set( params: extract::UrlParams<(String,)>, value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb state: extract::Extension, + // `()` also implements `IntoResponse` so we can use that to return + // an empty response ) { let state = state.into_inner(); let db = &mut state.lock().unwrap().db; @@ -84,16 +82,5 @@ async fn set( let key = params.into_inner(); let value = value.into_inner(); - db.insert(key.to_string(), value); -} - -struct NotFound; - -impl IntoResponse for NotFound { - fn into_response(self) -> Response { - Response::builder() - .status(StatusCode::NOT_FOUND) - .body(Body::empty()) - .unwrap() - } + db.insert(key, value); } diff --git a/src/body.rs b/src/body.rs index 0075fdfc..f1a30085 100644 --- a/src/body.rs +++ b/src/body.rs @@ -1,4 +1,5 @@ use bytes::Buf; +use futures_util::ready; use http_body::{Body as _, Empty}; use std::{ fmt, @@ -8,6 +9,8 @@ use std::{ pub use hyper::body::Body; +use crate::BoxStdError; + /// A boxed [`Body`] trait object. pub struct BoxBody { inner: Pin + Send + Sync + 'static>>, @@ -42,25 +45,34 @@ impl fmt::Debug for BoxBody { } } +// when we've gotten rid of `BoxStdError` then we can remove this impl http_body::Body for BoxBody where D: Buf, + E: Into, { type Data = D; - type Error = E; + type Error = BoxStdError; fn poll_data( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - self.inner.as_mut().poll_data(cx) + match ready!(self.inner.as_mut().poll_data(cx)) { + Some(Ok(chunk)) => Some(Ok(chunk)).into(), + Some(Err(err)) => Some(Err(BoxStdError(err.into()))).into(), + None => None.into(), + } } fn poll_trailers( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll, Self::Error>> { - self.inner.as_mut().poll_trailers(cx) + match ready!(self.inner.as_mut().poll_trailers(cx)) { + Ok(trailers) => Ok(trailers).into(), + Err(err) => Err(BoxStdError(err.into())).into(), + } } fn is_end_stream(&self) -> bool { @@ -71,3 +83,14 @@ where self.inner.size_hint() } } + +impl From for BoxBody { + fn from(s: String) -> Self { + let body = hyper::Body::from(s); + let body = body.map_err(Into::::into); + + BoxBody { + inner: Box::pin(body), + } + } +} diff --git a/src/error.rs b/src/error.rs deleted file mode 100644 index 98e3f0ca..00000000 --- a/src/error.rs +++ /dev/null @@ -1,104 +0,0 @@ -use std::convert::Infallible; - -use http::{Response, StatusCode}; -use tower::BoxError; - -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum Error { - #[error("failed to deserialize the request body")] - DeserializeRequestBody(#[source] serde_json::Error), - - #[error("failed to serialize the response body")] - SerializeResponseBody(#[source] serde_json::Error), - - #[error("failed to consume the body")] - ConsumeRequestBody(#[source] hyper::Error), - - #[error("URI contained no query string")] - QueryStringMissing, - - #[error("failed to deserialize query string")] - DeserializeQueryString(#[source] serde_urlencoded::de::Error), - - #[error("failed generating the response body")] - ResponseBody(#[source] BoxError), - - #[error("some dynamic error happened")] - Dynamic(#[source] BoxError), - - #[error("request extension of type `{type_name}` was not set")] - MissingExtension { type_name: &'static str }, - - #[error("`Content-Length` header is missing but was required")] - LengthRequired, - - #[error("response body was too large")] - PayloadTooLarge, - - #[error("response failed with status {0}")] - Status(StatusCode), - - #[error("invalid URL param. Expected something of type `{type_name}`")] - InvalidUrlParam { type_name: &'static str }, - - #[error("unknown URL param `{0}`")] - UnknownUrlParam(String), - - #[error("response body didn't contain valid UTF-8")] - InvalidUtf8, -} - -impl From for Error { - fn from(err: BoxError) -> Self { - match err.downcast::() { - Ok(err) => *err, - Err(err) => Error::Dynamic(err), - } - } -} - -impl From for Error { - fn from(err: Infallible) -> Self { - match err {} - } -} - -pub(crate) fn handle_error(error: Error) -> Result, Error> -where - B: Default, -{ - fn make_response(status: StatusCode) -> Result, Error> - where - B: Default, - { - let mut res = Response::new(B::default()); - *res.status_mut() = status; - Ok(res) - } - - match error { - Error::DeserializeRequestBody(_) - | Error::QueryStringMissing - | Error::DeserializeQueryString(_) - | Error::InvalidUrlParam { .. } - | Error::InvalidUtf8 => make_response(StatusCode::BAD_REQUEST), - - Error::Status(status) => make_response(status), - - Error::LengthRequired => make_response(StatusCode::LENGTH_REQUIRED), - Error::PayloadTooLarge => make_response(StatusCode::PAYLOAD_TOO_LARGE), - - Error::MissingExtension { .. } - | Error::SerializeResponseBody(_) - | Error::UnknownUrlParam(_) => make_response(StatusCode::INTERNAL_SERVER_ERROR), - - Error::Dynamic(err) => match err.downcast::() { - Ok(err) => Err(*err), - Err(err) => Err(Error::Dynamic(err)), - }, - - err @ Error::ConsumeRequestBody(_) => Err(err), - err @ Error::ResponseBody(_) => Err(err), - } -} diff --git a/src/extract.rs b/src/extract.rs index 9f0aa718..fd70e254 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -1,7 +1,6 @@ use crate::{ body::Body, response::{BoxIntoResponse, IntoResponse}, - Error, }; use async_trait::async_trait; use bytes::Bytes; @@ -315,21 +314,15 @@ define_rejection! { pub struct UrlParamsMap(HashMap); impl UrlParamsMap { - pub fn get(&self, key: &str) -> Result<&str, Error> { - if let Some(value) = self.0.get(key) { - Ok(value) - } else { - Err(Error::UnknownUrlParam(key.to_string())) - } + pub fn get(&self, key: &str) -> Option<&str> { + self.0.get(key).map(|s| &**s) } - pub fn get_typed(&self, key: &str) -> Result + pub fn get_typed(&self, key: &str) -> Option where T: FromStr, { - self.get(key)?.parse().map_err(|_| Error::InvalidUrlParam { - type_name: std::any::type_name::(), - }) + self.get(key)?.parse().ok() } } diff --git a/src/handler.rs b/src/handler.rs index e26a303e..0f36a260 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,4 +1,4 @@ -use crate::{body::Body, error::Error, extract::FromRequest, response::IntoResponse}; +use crate::{body::Body, extract::FromRequest, response::IntoResponse}; use async_trait::async_trait; use futures_util::future; use http::{Request, Response}; @@ -8,7 +8,7 @@ use std::{ marker::PhantomData, task::{Context, Poll}, }; -use tower::{BoxError, Layer, Service, ServiceExt}; +use tower::{Layer, Service, ServiceExt}; mod sealed { pub trait HiddentTrait {} @@ -30,6 +30,8 @@ pub trait Handler: Sized { fn layer(self, layer: L) -> Layered where L: Layer>, + >>::Service: Service>, + <>>::Service as Service>>::Error: IntoResponse, { Layered::new(layer.layer(HandlerSvc::new(self))) } diff --git a/src/lib.rs b/src/lib.rs index 8f815236..3fb265ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ use futures_util::ready; use http::Response; use pin_project::pin_project; use std::{ + convert::Infallible, future::Future, pin::Pin, task::{Context, Poll}, @@ -19,13 +20,9 @@ pub mod handler; pub mod response; pub mod routing; -mod error; - #[cfg(test)] mod tests; -pub use self::error::Error; - pub fn app() -> App { App { service_tree: AlwaysNotFound(()), @@ -52,7 +49,6 @@ impl App { pub struct IntoService { app: App, - poll_ready_error: Option, } impl Clone for IntoService @@ -62,71 +58,67 @@ where fn clone(&self) -> Self { Self { app: self.app.clone(), - poll_ready_error: None, } } } impl Service for IntoService where - R: Service>, - R::Error: Into, + R: Service, Error = Infallible>, B: Default, { type Response = Response; - type Error = Error; - type Future = HandleErrorFuture; + type Error = Infallible; + type Future = HandleErrorFuture; - #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - if let Err(err) = ready!(self.app.service_tree.poll_ready(cx)).map_err(Into::into) { - self.poll_ready_error = Some(err); + match ready!(self.app.service_tree.poll_ready(cx)) { + Ok(_) => Poll::Ready(Ok(())), + Err(err) => match err {}, } - - Poll::Ready(Ok(())) } fn call(&mut self, req: T) -> Self::Future { - if let Some(poll_ready_error) = self.poll_ready_error.take() { - match error::handle_error::(poll_ready_error) { - Ok(res) => { - return HandleErrorFuture(Kind::Response(Some(res))); - } - Err(err) => { - return HandleErrorFuture(Kind::Error(Some(err))); - } - } - } - HandleErrorFuture(Kind::Future(self.app.service_tree.call(req))) + HandleErrorFuture(self.app.service_tree.call(req)) } } #[pin_project] -pub struct HandleErrorFuture(#[pin] Kind); +pub struct HandleErrorFuture(#[pin] F); -#[pin_project(project = KindProj)] -enum Kind { - Response(Option>), - Error(Option), - Future(#[pin] F), -} - -impl Future for HandleErrorFuture +impl Future for HandleErrorFuture where - F: Future, E>>, - E: Into, + F: Future, Infallible>>, B: Default, { - type Output = Result, Error>; + type Output = Result, Infallible>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.project().0.project() { - KindProj::Response(res) => Poll::Ready(Ok(res.take().unwrap())), - KindProj::Error(err) => Poll::Ready(Err(err.take().unwrap())), - KindProj::Future(fut) => match ready!(fut.poll(cx)) { - Ok(res) => Poll::Ready(Ok(res)), - Err(err) => Poll::Ready(error::handle_error(err.into())), - }, + self.project().0.poll(cx) + } +} + +pub(crate) trait ResultExt { + fn unwrap_infallible(self) -> T; +} + +impl ResultExt for Result { + fn unwrap_infallible(self) -> T { + match self { + Ok(value) => value, + Err(err) => match err {}, } } } + +// work around for `BoxError` not implementing `std::error::Error` +// +// This is currently required since tower-http's Compression middleware's body type's +// error only implements error when the inner error type does: +// https://github.com/tower-rs/tower-http/blob/master/tower-http/src/lib.rs#L310 +// +// Fixing that is a breaking change to tower-http so we should wait a bit, but should +// totally fix it at some point. +#[derive(Debug, thiserror::Error)] +#[error("{0}")] +pub struct BoxStdError(#[source] tower::BoxError); diff --git a/src/response.rs b/src/response.rs index c1580a05..5935d70d 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,6 +1,6 @@ use crate::Body; use bytes::Bytes; -use http::{header, HeaderValue, Response, StatusCode}; +use http::{HeaderMap, HeaderValue, Response, StatusCode, header}; use serde::Serialize; use std::convert::Infallible; use tower::{util::Either, BoxError}; @@ -176,9 +176,45 @@ impl IntoResponse for BoxIntoResponse { impl IntoResponse for BoxError { fn into_response(self) -> Response { + // TODO(david): test for know error types like std::io::Error + // or common errors types from tower and map those more appropriately + Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(Body::from(self.to_string())) .unwrap() } } + +impl IntoResponse for StatusCode +where + B: Default, +{ + fn into_response(self) -> Response { + Response::builder().status(self).body(B::default()).unwrap() + } +} + +impl IntoResponse for (StatusCode, T) +where + T: Into, +{ + fn into_response(self) -> Response { + Response::builder() + .status(self.0) + .body(self.1.into()) + .unwrap() + } +} + +impl IntoResponse for (StatusCode, HeaderMap, T) +where + T: Into, +{ + fn into_response(self) -> Response { + let mut res = Response::new(self.2.into()); + *res.status_mut() = self.0; + *res.headers_mut() = self.1; + res + } +} diff --git a/src/routing.rs b/src/routing.rs index 8a0abdab..77364959 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -1,8 +1,7 @@ use crate::{ body::{Body, BoxBody}, - error::Error, handler::{Handler, HandlerSvc}, - App, IntoService, + App, IntoService, ResultExt, }; use bytes::Bytes; use futures_util::{future, ready}; @@ -164,17 +163,18 @@ impl RouteBuilder { } pub fn into_service(self) -> IntoService { - IntoService { - app: self.app, - poll_ready_error: None, - } + IntoService { app: self.app } } + // TODO(david): Add `layer` method here that applies a `tower::Layer` inside the service tree + // that way we get to map errors + pub fn boxed(self) -> RouteBuilder> where - R: Service, Response = Response, Error = Error> + Send + 'static, + R: Service, Response = Response, Error = Infallible> + Send + 'static, R::Future: Send, - B: Default + 'static, + // TODO(david): do we still need default here + B: Default + From + 'static, { let svc = ServiceBuilder::new() .layer(BufferLayer::new(1024)) @@ -182,7 +182,10 @@ impl RouteBuilder { .service(self.app.service_tree); let app = App { - service_tree: BoxServiceTree { inner: svc }, + service_tree: BoxServiceTree { + inner: svc, + poll_ready_error: None, + }, }; RouteBuilder { @@ -270,29 +273,27 @@ impl RouteSpec { impl Service> for Or where - H: Service, Response = Response>, - H::Error: Into, + H: Service, Response = Response, Error = Infallible>, HB: http_body::Body + Send + Sync + 'static, HB::Error: Into, - F: Service, Response = Response>, - F::Error: Into, + F: Service, Response = Response, Error = Infallible>, FB: http_body::Body + Send + Sync + 'static, FB::Error: Into, { - type Response = Response>; - type Error = Error; + type Response = Response>; + type Error = Infallible; type Future = future::Either, BoxResponseBody>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { loop { if !self.handler_ready { - ready!(self.service.poll_ready(cx)).map_err(Into::into)?; + ready!(self.service.poll_ready(cx)).unwrap_infallible(); self.handler_ready = true; } if !self.fallback_ready { - ready!(self.fallback.poll_ready(cx)).map_err(Into::into)?; + ready!(self.fallback.poll_ready(cx)).unwrap_infallible(); self.fallback_ready = true; } @@ -333,20 +334,18 @@ pub(crate) struct UrlParams(pub(crate) Vec<(String, String)>); #[pin_project] pub struct BoxResponseBody(#[pin] F); -impl Future for BoxResponseBody +impl Future for BoxResponseBody where - F: Future, E>>, - E: Into, + F: Future, Infallible>>, B: http_body::Body + Send + Sync + 'static, B::Error: Into, { - type Output = Result>, Error>; + type Output = Result>, Infallible>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let response: Response = ready!(self.project().0.poll(cx)).map_err(Into::into)?; + let response: Response = ready!(self.project().0.poll(cx)).unwrap_infallible(); let response = response.map(|body| { - // TODO(david): attempt to downcast this into `Error` - let body = body.map_err(|err| Error::ResponseBody(err.into())); + let body = body.map_err(Into::into); BoxBody::new(body) }); Poll::Ready(Ok(response)) @@ -354,13 +353,15 @@ where } pub struct BoxServiceTree { - inner: Buffer, Response, Error>, Request>, + inner: Buffer, Response, Infallible>, Request>, + poll_ready_error: Option, } impl Clone for BoxServiceTree { fn clone(&self) -> Self { Self { inner: self.inner.clone(), + poll_ready_error: None, } } } @@ -373,21 +374,36 @@ impl fmt::Debug for BoxServiceTree { impl Service> for BoxServiceTree where - B: 'static, + B: From + 'static, { type Response = Response; - type Error = Error; + type Error = Infallible; type Future = BoxServiceTreeResponseFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx).map_err(Error::from) + // TODO(david): downcast this into one of the cases in `tower::buffer::error` + // and convert the error into a response. `ServiceError` should never be able to happen + // since all inner services use `Infallible` as the error type. + match ready!(self.inner.poll_ready(cx)) { + Ok(_) => Poll::Ready(Ok(())), + Err(err) => { + self.poll_ready_error = Some(err); + Poll::Ready(Ok(())) + } + } } #[inline] fn call(&mut self, req: Request) -> Self::Future { + if let Some(err) = self.poll_ready_error.take() { + return BoxServiceTreeResponseFuture { + kind: Kind::Response(Some(handle_buffer_error(err))), + }; + } + BoxServiceTreeResponseFuture { - inner: self.inner.call(req), + kind: Kind::Future(self.inner.call(req)), } } } @@ -395,24 +411,71 @@ where #[pin_project] pub struct BoxServiceTreeResponseFuture { #[pin] - inner: InnerFuture, + kind: Kind, +} + +#[pin_project(project = KindProj)] +enum Kind { + Response(Option>), + Future(#[pin] InnerFuture), } type InnerFuture = tower::buffer::future::ResponseFuture< - Pin, Error>> + Send + 'static>>, + Pin, Infallible>> + Send + 'static>>, >; -impl Future for BoxServiceTreeResponseFuture { - type Output = Result, Error>; +impl Future for BoxServiceTreeResponseFuture +where + B: From, +{ + type Output = Result, Infallible>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.project() - .inner - .poll(cx) - .map_err(Error::from) + match self.project().kind.project() { + KindProj::Response(res) => Poll::Ready(Ok(res.take().unwrap())), + KindProj::Future(future) => match ready!(future.poll(cx)) { + Ok(res) => Poll::Ready(Ok(res)), + Err(err) => Poll::Ready(Ok(handle_buffer_error(err))), + }, + } } } +fn handle_buffer_error(error: BoxError) -> Response +where + B: From, +{ + use tower::buffer::error::{Closed, ServiceError}; + + let error = match error.downcast::() { + Ok(closed) => { + return Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(B::from(closed.to_string())) + .unwrap(); + } + Err(e) => e, + }; + + let error = match error.downcast::() { + Ok(service_error) => { + return Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(B::from(format!("Service error: {}. This is a bug in tower-web. All inner services should be infallible. Please file an issue", service_error))) + .unwrap(); + } + Err(e) => e, + }; + + Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(B::from(format!( + "Uncountered an unknown error: {}. This should never happen. Please file an issue", + error + ))) + .unwrap() +} + #[cfg(test)] mod tests { #[allow(unused_imports)] diff --git a/src/tests.rs b/src/tests.rs index 280b0393..d332256d 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -273,6 +273,8 @@ async fn boxing() { assert_eq!(res.text().await.unwrap(), "hi from POST"); } +// TODO(david): tests for adding middleware to single services + /// Run a `tower::Service` in the background and get a URI for it. pub async fn run_in_background(svc: S) -> SocketAddr where