From 8faed8120f476d4227259f2e4c8d8c1c5ae725f2 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 22 Jul 2021 15:00:33 +0200 Subject: [PATCH] Docs improvements (#37) --- examples/key_value_store.rs | 13 +- examples/static_file_server.rs | 4 +- examples/websocket.rs | 4 +- src/extract/mod.rs | 22 ++- src/handler/mod.rs | 79 +------- src/lib.rs | 326 ++++++++++++++++----------------- src/routing.rs | 54 ++++-- src/service/future.rs | 13 +- src/service/mod.rs | 26 ++- src/tests.rs | 11 +- src/ws/mod.rs | 2 +- 11 files changed, 263 insertions(+), 291 deletions(-) diff --git a/examples/key_value_store.rs b/examples/key_value_store.rs index f10bd7fb..9d83b910 100644 --- a/examples/key_value_store.rs +++ b/examples/key_value_store.rs @@ -19,6 +19,7 @@ use http::StatusCode; use std::{ borrow::Cow, collections::HashMap, + convert::Infallible, net::SocketAddr, sync::{Arc, RwLock}, time::Duration, @@ -152,20 +153,20 @@ where } } -fn handle_error(error: BoxError) -> impl IntoResponse { +fn handle_error(error: BoxError) -> Result { if error.is::() { - return (StatusCode::REQUEST_TIMEOUT, Cow::from("request timed out")); + return Ok((StatusCode::REQUEST_TIMEOUT, Cow::from("request timed out"))); } if error.is::() { - return ( + return Ok(( StatusCode::SERVICE_UNAVAILABLE, Cow::from("service is overloaded, try again later"), - ); + )); } - ( + Ok(( StatusCode::INTERNAL_SERVER_ERROR, Cow::from(format!("Unhandled internal error: {}", error)), - ) + )) } diff --git a/examples/static_file_server.rs b/examples/static_file_server.rs index 55ee29e6..fe8af5e4 100644 --- a/examples/static_file_server.rs +++ b/examples/static_file_server.rs @@ -10,10 +10,10 @@ async fn main() { let app = axum::routing::nest( "/static", axum::service::get(ServeDir::new(".").handle_error(|error: std::io::Error| { - ( + Ok::<_, std::convert::Infallible>(( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled interal error: {}", error), - ) + )) })), ) .layer(TraceLayer::new_for_http()); diff --git a/examples/websocket.rs b/examples/websocket.rs index ddc24ed3..59474958 100644 --- a/examples/websocket.rs +++ b/examples/websocket.rs @@ -33,10 +33,10 @@ async fn main() { ServeDir::new("examples/websocket") .append_index_html_on_directories(true) .handle_error(|error: std::io::Error| { - ( + Ok::<_, std::convert::Infallible>(( StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled interal error: {}", error), - ) + )) }), ), ) diff --git a/src/extract/mod.rs b/src/extract/mod.rs index 8f25d812..a828892b 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -796,12 +796,22 @@ where type Rejection = RequestAlreadyExtracted; async fn from_request(req: &mut RequestParts) -> Result { - let all_parts = req - .method() - .zip(req.uri()) - .zip(req.headers()) - .zip(req.extensions()) - .zip(req.body()); + let RequestParts { + method, + uri, + version, + headers, + extensions, + body, + } = req; + + let all_parts = method + .as_ref() + .zip(version.as_ref()) + .zip(uri.as_ref()) + .zip(extensions.as_ref()) + .zip(body.as_ref()) + .zip(headers.as_ref()); if all_parts.is_some() { Ok(req.into_request()) diff --git a/src/handler/mod.rs b/src/handler/mod.rs index d97afaf6..6ae1f7c2 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -1,42 +1,4 @@ //! Async functions that can be used to handle requests. -//! -//! # What is a handler? -//! -//! In axum a "handler" is an async function that accepts zero or more -//! ["extractors"](crate::extract) as arguments and returns something that -//! implements [`IntoResponse`]. -//! -//! # Example -//! -//! Some examples of handlers: -//! -//! ```rust -//! use axum::prelude::*; -//! use bytes::Bytes; -//! use http::StatusCode; -//! -//! // Handler that immediately returns an empty `200 OK` response. -//! async fn unit_handler() {} -//! -//! // Handler that immediately returns an empty `200 Ok` response with a plain -//! /// text body. -//! async fn string_handler() -> String { -//! "Hello, World!".to_string() -//! } -//! -//! // Handler that buffers the request body and returns it if it is valid UTF-8 -//! async fn buffer_body(body: Bytes) -> Result { -//! if let Ok(string) = String::from_utf8(body.to_vec()) { -//! Ok(string) -//! } else { -//! Err(StatusCode::BAD_REQUEST) -//! } -//! } -//! ``` -//! -//! For more details on generating responses see the -//! [`response`](crate::response) module and for more details on extractors see -//! the [`extract`](crate::extract) module. use crate::{ body::{box_body, BoxBody}, @@ -402,49 +364,16 @@ impl Layered { /// This is used to convert errors to responses rather than simply /// terminating the connection. /// - /// `handle_error` can be used like so: + /// It works similarly to [`routing::Layered::handle_error`]. See that for more details. /// - /// ```rust - /// use axum::prelude::*; - /// use http::StatusCode; - /// use tower::{BoxError, timeout::TimeoutLayer}; - /// use std::time::Duration; - /// - /// async fn handler() { /* ... */ } - /// - /// // `Timeout` will fail with `BoxError` if the timeout elapses... - /// let layered_handler = handler - /// .layer(TimeoutLayer::new(Duration::from_secs(30))); - /// - /// // ...so we should handle that error - /// let layered_handler = layered_handler.handle_error(|error: BoxError| { - /// if error.is::() { - /// ( - /// StatusCode::REQUEST_TIMEOUT, - /// "request took too long".to_string(), - /// ) - /// } else { - /// ( - /// StatusCode::INTERNAL_SERVER_ERROR, - /// format!("Unhandled internal error: {}", error), - /// ) - /// } - /// }); - /// - /// let app = route("/", get(layered_handler)); - /// # async { - /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); - /// # }; - /// ``` - /// - /// The closure can return any type that implements [`IntoResponse`]. - pub fn handle_error( + /// [`routing::Layered::handle_error`]: crate::routing::Layered::handle_error + pub fn handle_error( self, f: F, ) -> Layered, T> where S: Service, Response = Response>, - F: FnOnce(S::Error) -> Res, + F: FnOnce(S::Error) -> Result, Res: IntoResponse, { let svc = HandleError::new(self.svc, f); diff --git a/src/lib.rs b/src/lib.rs index f4b8129f..2f981b33 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,31 @@ //! axum is a web application framework that focuses on ergonomics and modularity. //! -//! ## Goals +//! # Table of contents +//! +//! - [Goals](#goals) +//! - [Compatibility](#compatibility) +//! - [Handlers](#handlers) +//! - [Routing](#routing) +//! - [Extractors](#extractors) +//! - [Building responses](#building-responses) +//! - [Applying middleware](#applying-middleware) +//! - [To individual handlers](#to-individual-handlers) +//! - [To groups of routes](#to-groups-of-routes) +//! - [Error handling](#error-handling) +//! - [Sharing state with handlers](#sharing-state-with-handlers) +//! - [Routing to any `Service`](#routing-to-any-service) +//! - [Nesting applications](#nesting-applications) +//! - [Feature flags](#feature-flags) +//! +//! # Goals //! //! - Ease of use. Building web apps in Rust should be as easy as `async fn //! handle(Request) -> Response`. //! - Solid foundation. axum is built on top of [tower] and [hyper] and makes it //! easy to plug in any middleware from the [tower] and [tower-http] ecosystem. -//! - Focus on routing, extracting data from requests, and generating responses. +//! This improves compatibility since axum doesn't have its own custom +//! middleware system. +//! - Focus on routing, extracting data from requests, and building responses. //! tower middleware can handle the rest. //! - Macro free core. Macro frameworks have their place but axum focuses //! on providing a core that is macro free. @@ -39,6 +58,41 @@ //! } //! ``` //! +//! # Handlers +//! +//! In axum a "handler" is an async function that accepts zero or more +//! ["extractors"](#extractors) as arguments and returns something that +//! can be converted [into a response](#building-responses). +//! +//! Handlers is where you custom domain logic lives and axum applications are +//! built by routing between handlers. +//! +//! Some examples of handlers: +//! +//! ```rust +//! use axum::prelude::*; +//! use bytes::Bytes; +//! use http::StatusCode; +//! +//! // Handler that immediately returns an empty `200 OK` response. +//! async fn unit_handler() {} +//! +//! // Handler that immediately returns an empty `200 Ok` response with a plain +//! // text body. +//! async fn string_handler() -> String { +//! "Hello, World!".to_string() +//! } +//! +//! // Handler that buffers the request body and returns it if it is valid UTF-8 +//! async fn buffer_body(body: Bytes) -> Result { +//! if let Ok(string) = String::from_utf8(body.to_vec()) { +//! Ok(string) +//! } else { +//! Err(StatusCode::BAD_REQUEST) +//! } +//! } +//! ``` +//! //! # Routing //! //! Routing between handlers looks like this: @@ -65,92 +119,12 @@ //! # }; //! ``` //! -//! Routes can also be dynamic like `/users/:id`. See ["Extracting data from -//! requests"](#extracting-data-from-requests) for more details on that. +//! Routes can also be dynamic like `/users/:id`. //! -//! # Responses +//! # Extractors //! -//! Anything that implements [`IntoResponse`](response::IntoResponse) can be -//! returned from a handler: -//! -//! ```rust,no_run -//! use axum::{body::Body, response::{Html, Json}, prelude::*}; -//! use http::{StatusCode, Response, Uri}; -//! use serde_json::{Value, json}; -//! -//! // We've already seen returning &'static str -//! async fn plain_text() -> &'static str { -//! "foo" -//! } -//! -//! // String works too and will get a text/plain content-type -//! async fn plain_text_string(uri: Uri) -> String { -//! format!("Hi from {}", uri.path()) -//! } -//! -//! // Bytes will get a `application/octet-stream` content-type -//! async fn bytes() -> Vec { -//! vec![1, 2, 3, 4] -//! } -//! -//! // `()` gives an empty response -//! async fn empty() {} -//! -//! // `StatusCode` gives an empty response with that status code -//! async fn empty_with_status() -> StatusCode { -//! StatusCode::NOT_FOUND -//! } -//! -//! // A tuple of `StatusCode` and something that implements `IntoResponse` can -//! // be used to override the status code -//! async fn with_status() -> (StatusCode, &'static str) { -//! (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong") -//! } -//! -//! // `Html` gives a content-type of `text/html` -//! async fn html() -> Html<&'static str> { -//! Html("

Hello, World!

") -//! } -//! -//! // `Json` gives a content-type of `application/json` and works with any type -//! // that implements `serde::Serialize` -//! async fn json() -> Json { -//! Json(json!({ "data": 42 })) -//! } -//! -//! // `Result` where `T` and `E` implement `IntoResponse` is useful for -//! // returning errors -//! async fn result() -> Result<&'static str, StatusCode> { -//! Ok("all good") -//! } -//! -//! // `Response` gives full control -//! async fn response() -> Response { -//! Response::builder().body(Body::empty()).unwrap() -//! } -//! -//! let app = route("/plain_text", get(plain_text)) -//! .route("/plain_text_string", get(plain_text_string)) -//! .route("/bytes", get(bytes)) -//! .route("/empty", get(empty)) -//! .route("/empty_with_status", get(empty_with_status)) -//! .route("/with_status", get(with_status)) -//! .route("/html", get(html)) -//! .route("/json", get(json)) -//! .route("/result", get(result)) -//! .route("/response", get(response)); -//! # async { -//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -//! # }; -//! ``` -//! -//! See the [`response`] module for more details. -//! -//! # Extracting data from requests -//! -//! A handler function is an async function take takes any number of -//! "extractors" as arguments. An extractor is a type that implements -//! [`FromRequest`](crate::extract::FromRequest). +//! An extractor is a type that implements [`FromRequest`]. Extractors is how +//! you pick apart the incoming request to get the parts your handler needs. //! //! For example, [`extract::Json`] is an extractor that consumes the request //! body and deserializes it as JSON into some target type: @@ -256,13 +230,90 @@ //! See the [`extract`] module for more details. //! //! [`Uuid`]: https://docs.rs/uuid/latest/uuid/ +//! [`FromRequest`]: crate::extract::FromRequest +//! +//! # Building responses +//! +//! Anything that implements [`IntoResponse`](response::IntoResponse) can be +//! returned from a handler: +//! +//! ```rust,no_run +//! use axum::{body::Body, response::{Html, Json}, prelude::*}; +//! use http::{StatusCode, Response, Uri}; +//! use serde_json::{Value, json}; +//! +//! // We've already seen returning &'static str +//! async fn plain_text() -> &'static str { +//! "foo" +//! } +//! +//! // String works too and will get a text/plain content-type +//! async fn plain_text_string(uri: Uri) -> String { +//! format!("Hi from {}", uri.path()) +//! } +//! +//! // Bytes will get a `application/octet-stream` content-type +//! async fn bytes() -> Vec { +//! vec![1, 2, 3, 4] +//! } +//! +//! // `()` gives an empty response +//! async fn empty() {} +//! +//! // `StatusCode` gives an empty response with that status code +//! async fn empty_with_status() -> StatusCode { +//! StatusCode::NOT_FOUND +//! } +//! +//! // A tuple of `StatusCode` and something that implements `IntoResponse` can +//! // be used to override the status code +//! async fn with_status() -> (StatusCode, &'static str) { +//! (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong") +//! } +//! +//! // `Html` gives a content-type of `text/html` +//! async fn html() -> Html<&'static str> { +//! Html("

Hello, World!

") +//! } +//! +//! // `Json` gives a content-type of `application/json` and works with any type +//! // that implements `serde::Serialize` +//! async fn json() -> Json { +//! Json(json!({ "data": 42 })) +//! } +//! +//! // `Result` where `T` and `E` implement `IntoResponse` is useful for +//! // returning errors +//! async fn result() -> Result<&'static str, StatusCode> { +//! Ok("all good") +//! } +//! +//! // `Response` gives full control +//! async fn response() -> Response { +//! Response::builder().body(Body::empty()).unwrap() +//! } +//! +//! let app = route("/plain_text", get(plain_text)) +//! .route("/plain_text_string", get(plain_text_string)) +//! .route("/bytes", get(bytes)) +//! .route("/empty", get(empty)) +//! .route("/empty_with_status", get(empty_with_status)) +//! .route("/with_status", get(with_status)) +//! .route("/html", get(html)) +//! .route("/json", get(json)) +//! .route("/result", get(result)) +//! .route("/response", get(response)); +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +//! # }; +//! ``` //! //! # Applying middleware //! //! axum is designed to take full advantage of the tower and tower-http //! ecosystem of middleware: //! -//! ## Applying middleware to individual handlers +//! ## To individual handlers //! //! A middleware can be applied to a single handler like so: //! @@ -281,7 +332,7 @@ //! # }; //! ``` //! -//! ## Applying middleware to groups of routes +//! ## To groups of routes //! //! Middleware can also be applied to a group of routes like so: //! @@ -320,7 +371,7 @@ //! use tower::{ //! BoxError, timeout::{TimeoutLayer, error::Elapsed}, //! }; -//! use std::{borrow::Cow, time::Duration}; +//! use std::{borrow::Cow, time::Duration, convert::Infallible}; //! use http::StatusCode; //! //! let app = route( @@ -332,16 +383,19 @@ //! // Check if the actual error type is `Elapsed` which //! // `Timeout` returns //! if error.is::() { -//! return (StatusCode::REQUEST_TIMEOUT, "Request took too long".into()); +//! return Ok::<_, Infallible>(( +//! StatusCode::REQUEST_TIMEOUT, +//! "Request took too long".into(), +//! )); //! } //! //! // If we encounter some error we don't handle return a generic //! // error -//! return ( +//! return Ok::<_, Infallible>(( //! StatusCode::INTERNAL_SERVER_ERROR, //! // `Cow` lets us return either `&str` or `String` //! Cow::from(format!("Unhandled internal error: {}", error)), -//! ); +//! )); //! })), //! ); //! @@ -352,30 +406,10 @@ //! ``` //! //! The closure passed to [`handle_error`](handler::Layered::handle_error) must -//! return something that implements [`IntoResponse`](response::IntoResponse). +//! return `Result` where `T` implements +//! [`IntoResponse`](response::IntoResponse). //! -//! [`handle_error`](routing::Layered::handle_error) is also available on a -//! group of routes with middleware applied: -//! -//! ```rust,no_run -//! use axum::prelude::*; -//! use tower::{BoxError, timeout::TimeoutLayer}; -//! use std::time::Duration; -//! -//! let app = route("/", get(handle)) -//! .route("/foo", post(other_handle)) -//! .layer(TimeoutLayer::new(Duration::from_secs(30))) -//! .handle_error(|error: BoxError| { -//! // ... -//! }); -//! -//! async fn handle() {} -//! -//! async fn other_handle() {} -//! # async { -//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -//! # }; -//! ``` +//! See [`routing::Layered::handle_error`] fo more details. //! //! ## Applying multiple middleware //! @@ -383,14 +417,9 @@ //! //! ```rust,no_run //! use axum::prelude::*; -//! use tower::{ -//! ServiceBuilder, BoxError, -//! load_shed::error::Overloaded, -//! timeout::error::Elapsed, -//! }; +//! use tower::ServiceBuilder; //! use tower_http::compression::CompressionLayer; //! use std::{borrow::Cow, time::Duration}; -//! use http::StatusCode; //! //! let middleware_stack = ServiceBuilder::new() //! // Return an error after 30 seconds @@ -404,27 +433,7 @@ //! .into_inner(); //! //! let app = route("/", get(|_: Request| async { /* ... */ })) -//! .layer(middleware_stack) -//! .handle_error(|error: BoxError| { -//! if error.is::() { -//! return ( -//! StatusCode::SERVICE_UNAVAILABLE, -//! "Try again later".into(), -//! ); -//! } -//! -//! if error.is::() { -//! return ( -//! StatusCode::REQUEST_TIMEOUT, -//! "Request took too long".into(), -//! ); -//! }; -//! -//! return ( -//! StatusCode::INTERNAL_SERVER_ERROR, -//! Cow::from(format!("Unhandled internal error: {}", error)), -//! ); -//! }); +//! .layer(middleware_stack); //! # async { //! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; @@ -466,10 +475,7 @@ //! axum also supports routing to general [`Service`]s: //! //! ```rust,no_run -//! use axum::{ -//! // `ServiceExt` adds `handle_error` to any `Service` -//! service::{self, ServiceExt}, prelude::*, -//! }; +//! use axum::{service, prelude::*}; //! use tower_http::services::ServeFile; //! use http::Response; //! use std::convert::Infallible; @@ -485,10 +491,7 @@ //! ).route( //! // GET `/static/Cargo.toml` goes to a service from tower-http //! "/static/Cargo.toml", -//! service::get( -//! ServeFile::new("Cargo.toml") -//! .handle_error(|error: std::io::Error| { /* ... */ }) -//! ) +//! service::get(ServeFile::new("Cargo.toml")) //! ); //! # async { //! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); @@ -518,26 +521,12 @@ //! # }; //! ``` //! -//! [`nest`](routing::nest) can also be used to serve static files from a directory: +//! # Examples //! -//! ```rust,no_run -//! use axum::{prelude::*, service::ServiceExt, routing::nest}; -//! use tower_http::services::ServeDir; -//! use http::Response; -//! use tower::{service_fn, BoxError}; +//! The axum repo contains [a number of examples][examples] that show how to put all the +//! pieces togehter. //! -//! let app = nest( -//! "/images", -//! ServeDir::new("public/images").handle_error(|error: std::io::Error| { -//! // ... -//! }) -//! ); -//! # async { -//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -//! # }; -//! ``` -//! -//! # Features +//! # Feature flags //! //! axum uses a set of [feature flags] to reduce the amount of compiled and //! optional dependencies. @@ -555,6 +544,7 @@ //! [feature flags]: https://doc.rust-lang.org/cargo/reference/features.html#the-features-section //! [`IntoResponse`]: crate::response::IntoResponse //! [`Timeout`]: tower::timeout::Timeout +//! [examples]: https://github.com/davidpdrsn/axum/tree/main/examples #![doc(html_root_url = "https://docs.rs/tower-http/0.1.0")] #![warn( diff --git a/src/routing.rs b/src/routing.rs index 8d0e9e98..8871fb95 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -28,7 +28,7 @@ use tower::{ }; use tower_http::map_response_body::MapResponseBodyLayer; -/// A filter that matches one or more HTTP method. +/// A filter that matches one or more HTTP methods. #[derive(Debug, Copy, Clone)] pub enum MethodFilter { /// Match any method. @@ -632,7 +632,7 @@ impl Layered { /// use axum::prelude::*; /// use http::StatusCode; /// use tower::{BoxError, timeout::TimeoutLayer}; - /// use std::time::Duration; + /// use std::{convert::Infallible, time::Duration}; /// /// async fn handler() { /* ... */ } /// @@ -643,18 +643,17 @@ impl Layered { /// // ...so we should handle that error /// let with_errors_handled = layered_app.handle_error(|error: BoxError| { /// if error.is::() { - /// ( + /// Ok::<_, Infallible>(( /// StatusCode::REQUEST_TIMEOUT, /// "request took too long".to_string(), - /// ) + /// )) /// } else { - /// ( + /// Ok::<_, Infallible>(( /// StatusCode::INTERNAL_SERVER_ERROR, /// format!("Unhandled internal error: {}", error), - /// ) + /// )) /// } /// }); - /// /// # async { /// # hyper::Server::bind(&"".parse().unwrap()) /// # .serve(with_errors_handled.into_make_service()) @@ -663,14 +662,46 @@ impl Layered { /// # }; /// ``` /// - /// The closure can return any type that implements [`IntoResponse`]. - pub fn handle_error( + /// The closure must return `Result` where `T` implements [`IntoResponse`]. + /// + /// You can also return `Err(_)` if you don't wish to handle the error: + /// + /// ```rust + /// use axum::prelude::*; + /// use http::StatusCode; + /// use tower::{BoxError, timeout::TimeoutLayer}; + /// use std::time::Duration; + /// + /// async fn handler() { /* ... */ } + /// + /// let layered_app = route("/", get(handler)) + /// .layer(TimeoutLayer::new(Duration::from_secs(30))); + /// + /// let with_errors_handled = layered_app.handle_error(|error: BoxError| { + /// if error.is::() { + /// Ok(( + /// StatusCode::REQUEST_TIMEOUT, + /// "request took too long".to_string(), + /// )) + /// } else { + /// // keep the error as is + /// Err(error) + /// } + /// }); + /// # async { + /// # hyper::Server::bind(&"".parse().unwrap()) + /// # .serve(with_errors_handled.into_make_service()) + /// # .await + /// # .unwrap(); + /// # }; + /// ``` + pub fn handle_error( self, f: F, ) -> crate::service::HandleError where S: Service, Response = Response> + Clone, - F: FnOnce(S::Error) -> Res, + F: FnOnce(S::Error) -> Result, Res: IntoResponse, ResBody: http_body::Body + Send + Sync + 'static, ResBody::Error: Into + Send + Sync + 'static, @@ -757,8 +788,7 @@ where /// use tower_http::services::ServeDir; /// /// // Serves files inside the `public` directory at `GET /public/*` -/// let serve_dir_service = ServeDir::new("public") -/// .handle_error(|error: std::io::Error| { /* ... */ }); +/// let serve_dir_service = ServeDir::new("public"); /// /// let app = nest("/public", get(serve_dir_service)); /// # async { diff --git a/src/service/future.rs b/src/service/future.rs index 8bc4bbb4..7e6cd421 100644 --- a/src/service/future.rs +++ b/src/service/future.rs @@ -9,7 +9,6 @@ use futures_util::ready; use http::Response; use pin_project::pin_project; use std::{ - convert::Infallible, future::Future, pin::Pin, task::{Context, Poll}, @@ -25,15 +24,15 @@ pub struct HandleErrorFuture { pub(super) f: Option, } -impl Future for HandleErrorFuture +impl Future for HandleErrorFuture where Fut: Future, E>>, - F: FnOnce(E) -> Res, + F: FnOnce(E) -> Result, Res: IntoResponse, B: http_body::Body + Send + Sync + 'static, B::Error: Into + Send + Sync + 'static, { - type Output = Result, Infallible>; + type Output = Result, E2>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); @@ -42,8 +41,10 @@ where Ok(res) => Ok(res.map(box_body)).into(), Err(err) => { let f = this.f.take().unwrap(); - let res = f(err).into_response(); - Ok(res.map(box_body)).into() + match f(err) { + Ok(res) => Ok(res.into_response().map(box_body)).into(), + Err(err) => Err(err).into(), + } } } } diff --git a/src/service/mod.rs b/src/service/mod.rs index ae1587b3..6a614bee 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -507,16 +507,16 @@ where } } -impl Service> for HandleError +impl Service> for HandleError where S: Service, Response = Response> + Clone, - F: FnOnce(S::Error) -> Res + Clone, + F: FnOnce(S::Error) -> Result + Clone, Res: IntoResponse, ResBody: http_body::Body + Send + Sync + 'static, ResBody::Error: Into + Send + Sync + 'static, { type Response = Response; - type Error = Infallible; + type Error = E; type Future = future::HandleErrorFuture>, F>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { @@ -538,15 +538,16 @@ pub trait ServiceExt: /// Handle errors from a service. /// /// `handle_error` takes a closure that will map errors from the service - /// into responses. The closure's return type must implement - /// [`IntoResponse`]. + /// into responses. The closure's return type must be `Result` where + /// `T` implements [`IntoIntoResponse`](crate::response::IntoResponse). /// /// # Example /// /// ```rust,no_run /// use axum::{service::{self, ServiceExt}, prelude::*}; - /// use http::Response; + /// use http::{Response, StatusCode}; /// use tower::{service_fn, BoxError}; + /// use std::convert::Infallible; /// /// // A service that might fail with `std::io::Error` /// let service = service_fn(|_: Request| async { @@ -557,7 +558,10 @@ pub trait ServiceExt: /// let app = route( /// "/", /// service.handle_error(|error: std::io::Error| { - /// // Handle error by returning something that implements `IntoResponse` + /// Ok::<_, Infallible>(( + /// StatusCode::INTERNAL_SERVER_ERROR, + /// error.to_string(), + /// )) /// }), /// ); /// # @@ -565,10 +569,14 @@ pub trait ServiceExt: /// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` - fn handle_error(self, f: F) -> HandleError + /// + /// It works similarly to [`routing::Layered::handle_error`]. See that for more details. + /// + /// [`routing::Layered::handle_error`]: crate::routing::Layered::handle_error + fn handle_error(self, f: F) -> HandleError where Self: Sized, - F: FnOnce(Self::Error) -> Res, + F: FnOnce(Self::Error) -> Result, Res: IntoResponse, ResBody: http_body::Body + Send + Sync + 'static, ResBody::Error: Into + Send + Sync + 'static, diff --git a/src/tests.rs b/src/tests.rs index 7b5ad308..d0b3565c 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -8,6 +8,7 @@ use hyper::{Body, Server}; use serde::Deserialize; use serde_json::json; use std::{ + convert::Infallible, net::{SocketAddr, TcpListener}, time::Duration, }; @@ -340,7 +341,7 @@ async fn service_handlers() { service_fn(|req: Request| async move { Ok::<_, BoxError>(Response::new(req.into_body())) }) - .handle_error(|_error: BoxError| StatusCode::INTERNAL_SERVER_ERROR), + .handle_error(|_error: BoxError| Ok(StatusCode::INTERNAL_SERVER_ERROR)), ), ) .route( @@ -348,7 +349,7 @@ async fn service_handlers() { service::on( MethodFilter::Get, ServeFile::new("Cargo.toml").handle_error(|error: std::io::Error| { - (StatusCode::INTERNAL_SERVER_ERROR, error.to_string()) + Ok::<_, Infallible>((StatusCode::INTERNAL_SERVER_ERROR, error.to_string())) }), ), ); @@ -485,7 +486,9 @@ async fn handling_errors_from_layered_single_routes() { .layer(TraceLayer::new_for_http()) .into_inner(), ) - .handle_error(|_error: BoxError| StatusCode::INTERNAL_SERVER_ERROR)), + .handle_error(|_error: BoxError| { + Ok::<_, Infallible>(StatusCode::INTERNAL_SERVER_ERROR) + })), ); let addr = run_in_background(app).await; @@ -508,7 +511,7 @@ async fn layer_on_whole_router() { .timeout(Duration::from_millis(100)) .into_inner(), ) - .handle_error(|_err: BoxError| StatusCode::INTERNAL_SERVER_ERROR); + .handle_error(|_err: BoxError| Ok::<_, Infallible>(StatusCode::INTERNAL_SERVER_ERROR)); let addr = run_in_background(app).await; diff --git a/src/ws/mod.rs b/src/ws/mod.rs index fe00149b..cc2ac234 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -56,7 +56,7 @@ where crate::service::get::<_, B>(svc) } -/// [`Service`] that ugprades connections to websockets and spawns a task to +/// [`Service`] that upgrades connections to websockets and spawns a task to /// handle the stream. /// /// Created with [`ws`].