diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index b81268f9..405876e9 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- Add `middleware::from_fn` for creating middleware from async functions ([#656]) + +[#656]: https://github.com/tokio-rs/axum/pull/656 # 0.1.0 (02. December, 2021) diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index 046cfe37..a35eb859 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -17,6 +17,10 @@ erased-json = ["serde", "serde_json"] axum = { path = "../axum", version = "0.4" } http = "0.2" mime = "0.3" +pin-project-lite = "0.2" +tower = { version = "0.4", features = ["util"] } +tower-http = { version = "0.2", features = ["util", "map-response-body"] } +tower-layer = "0.3" tower-service = "0.3" # optional dependencies diff --git a/axum-extra/src/lib.rs b/axum-extra/src/lib.rs index bb16461f..daf0a8d5 100644 --- a/axum-extra/src/lib.rs +++ b/axum-extra/src/lib.rs @@ -44,5 +44,6 @@ #![cfg_attr(test, allow(clippy::float_cmp))] pub mod extract; +pub mod middleware; pub mod response; pub mod routing; diff --git a/axum-extra/src/middleware/middleware_fn.rs b/axum-extra/src/middleware/middleware_fn.rs new file mode 100644 index 00000000..2bb9cb8a --- /dev/null +++ b/axum-extra/src/middleware/middleware_fn.rs @@ -0,0 +1,240 @@ +//! Create middleware from async functions. +//! +//! See [`from_fn`] for more details. + +use axum::{ + body::{self, Bytes, HttpBody}, + response::{IntoResponse, Response}, + BoxError, +}; +use http::Request; +use pin_project_lite::pin_project; +use std::{ + any::type_name, + convert::Infallible, + fmt, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tower::{util::BoxCloneService, ServiceBuilder}; +use tower_http::ServiceBuilderExt; +use tower_layer::Layer; +use tower_service::Service; + +/// Create a middleware from an async function. +/// +/// `from_fn` requires the function given to +/// +/// 1. Be an `async fn`. +/// 2. Take [`Request`](http::Request) as the first argument. +/// 3. Take [`Next`](Next) as the second argument. +/// 4. Return something that implements [`IntoResponse`]. +/// +/// # Example +/// +/// ```rust +/// use axum::{ +/// Router, +/// http::{Request, StatusCode}, +/// routing::get, +/// response::IntoResponse, +/// }; +/// use axum_extra::middleware::{self, Next}; +/// +/// async fn auth(req: Request, next: Next) -> impl IntoResponse { +/// let auth_header = req.headers().get(http::header::AUTHORIZATION); +/// +/// match auth_header { +/// Some(auth_header) if auth_header == "secret" => { +/// Ok(next.run(req).await) +/// } +/// _ => Err(StatusCode::UNAUTHORIZED), +/// } +/// } +/// +/// let app = Router::new() +/// .route("/", get(|| async { /* ... */ })) +/// .route_layer(middleware::from_fn(auth)); +/// # let app: Router = app; +/// ``` +pub fn from_fn(f: F) -> MiddlewareFnLayer { + MiddlewareFnLayer { f } +} + +/// A [`tower::Layer`] from an async function. +/// +/// [`tower::Layer`] is used to apply middleware to [`axum::Router`]s. +/// +/// Created with [`from_fn`]. See that function for more details. +#[derive(Clone, Copy)] +pub struct MiddlewareFnLayer { + f: F, +} + +impl Layer for MiddlewareFnLayer +where + F: Clone, +{ + type Service = MiddlewareFn; + + fn layer(&self, inner: S) -> Self::Service { + MiddlewareFn { + f: self.f.clone(), + inner, + } + } +} + +impl fmt::Debug for MiddlewareFnLayer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MiddlewareFnLayer") + // Write out the type name, without quoting it as `&type_name::()` would + .field("f", &format_args!("{}", type_name::())) + .finish() + } +} + +/// A middleware created from an async function. +/// +/// Created with [`from_fn`]. See that function for more details. +#[derive(Clone, Copy)] +pub struct MiddlewareFn { + f: F, + inner: S, +} + +impl Service> for MiddlewareFn +where + F: FnMut(Request, Next) -> Fut, + Fut: Future, + Out: IntoResponse, + S: Service, Response = Response, Error = Infallible> + + Clone + + Send + + 'static, + S::Future: Send + 'static, + ResBody: HttpBody + Send + 'static, + ResBody::Error: Into, +{ + type Response = Response; + type Error = Infallible; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let not_ready_inner = self.inner.clone(); + let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner); + + let inner = ServiceBuilder::new() + .boxed_clone() + .map_response_body(body::boxed) + .service(ready_inner); + let next = Next { inner }; + + ResponseFuture { + inner: (self.f)(req, next), + } + } +} + +impl fmt::Debug for MiddlewareFn +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MiddlewareFnLayer") + .field("f", &format_args!("{}", type_name::())) + .field("inner", &self.inner) + .finish() + } +} + +/// The remainder of a middleware stack, including the handler. +pub struct Next { + inner: BoxCloneService, Response, Infallible>, +} + +impl Next { + /// Execute the remaining middleware stack. + pub async fn run(mut self, req: Request) -> Response { + match self.inner.call(req).await { + Ok(res) => res, + Err(err) => match err {}, + } + } +} + +impl fmt::Debug for Next { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MiddlewareFnLayer") + .field("inner", &self.inner) + .finish() + } +} + +pin_project! { + /// Response future for [`MiddlewareFn`]. + pub struct ResponseFuture { + #[pin] + inner: F, + } +} + +impl Future for ResponseFuture +where + F: Future, + Out: IntoResponse, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project() + .inner + .poll(cx) + .map(IntoResponse::into_response) + .map(Ok) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use axum::{body::Empty, routing::get, Router}; + use http::{HeaderMap, StatusCode}; + use tower::ServiceExt; + + #[tokio::test] + async fn basic() { + async fn insert_header(mut req: Request, next: Next) -> impl IntoResponse { + req.headers_mut() + .insert("x-axum-test", "ok".parse().unwrap()); + + next.run(req).await + } + + async fn handle(headers: HeaderMap) -> String { + (&headers["x-axum-test"]).to_str().unwrap().to_owned() + } + + let app = Router::new() + .route("/", get(handle)) + .layer(from_fn(insert_header)); + + let res = app + .oneshot( + Request::builder() + .uri("/") + .body(body::boxed(Empty::new())) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + let body = hyper::body::to_bytes(res).await.unwrap(); + assert_eq!(&body[..], b"ok"); + } +} diff --git a/axum-extra/src/middleware/mod.rs b/axum-extra/src/middleware/mod.rs new file mode 100644 index 00000000..89420678 --- /dev/null +++ b/axum-extra/src/middleware/mod.rs @@ -0,0 +1,5 @@ +//! Additional types for creating middleware. + +pub mod middleware_fn; + +pub use self::middleware_fn::{from_fn, Next}; diff --git a/examples/print-request-response/Cargo.toml b/examples/print-request-response/Cargo.toml index 69781c11..83c4b5b5 100644 --- a/examples/print-request-response/Cargo.toml +++ b/examples/print-request-response/Cargo.toml @@ -6,6 +6,7 @@ publish = false [dependencies] axum = { path = "../../axum" } +axum-extra = { path = "../../axum-extra" } tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version="0.3", features = ["env-filter"] } diff --git a/examples/print-request-response/src/main.rs b/examples/print-request-response/src/main.rs index 6ebf4f70..97ce4355 100644 --- a/examples/print-request-response/src/main.rs +++ b/examples/print-request-response/src/main.rs @@ -6,14 +6,13 @@ use axum::{ body::{Body, Bytes}, - error_handling::HandleErrorLayer, http::{Request, StatusCode}, - response::Response, + response::{IntoResponse, Response}, routing::post, Router, }; +use axum_extra::middleware::{self, Next}; use std::net::SocketAddr; -use tower::{filter::AsyncFilterLayer, util::AndThenLayer, BoxError, ServiceBuilder}; #[tokio::main] async fn main() { @@ -28,17 +27,7 @@ async fn main() { let app = Router::new() .route("/", post(|| async move { "Hello from `POST /`" })) - .layer( - ServiceBuilder::new() - .layer(HandleErrorLayer::new(|error| async move { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Unhandled internal error: {}", error), - ) - })) - .layer(AndThenLayer::new(map_response)) - .layer(AsyncFilterLayer::new(map_request)), - ); + .layer(middleware::from_fn(print_request_response)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); tracing::debug!("listening on {}", addr); @@ -48,28 +37,41 @@ async fn main() { .unwrap(); } -async fn map_request(req: Request) -> Result, BoxError> { +async fn print_request_response( + req: Request, + next: Next, +) -> Result { let (parts, body) = req.into_parts(); let bytes = buffer_and_print("request", body).await?; let req = Request::from_parts(parts, Body::from(bytes)); - Ok(req) -} -async fn map_response(res: Response) -> Result, BoxError> { + let res = next.run(req).await; + let (parts, body) = res.into_parts(); let bytes = buffer_and_print("response", body).await?; let res = Response::from_parts(parts, Body::from(bytes)); + Ok(res) } -async fn buffer_and_print(direction: &str, body: B) -> Result +async fn buffer_and_print(direction: &str, body: B) -> Result where B: axum::body::HttpBody, - B::Error: Into, + B::Error: std::fmt::Display, { - let bytes = hyper::body::to_bytes(body).await.map_err(Into::into)?; + let bytes = match hyper::body::to_bytes(body).await { + Ok(bytes) => bytes, + Err(err) => { + return Err(( + StatusCode::BAD_REQUEST, + format!("failed to read {} body: {}", direction, err), + )); + } + }; + if let Ok(body) = std::str::from_utf8(&bytes) { tracing::debug!("{} body = {:?}", direction, body); } + Ok(bytes) }