Add middleware::from_fn for creating middleware from async fns (#656)

* Add `middleware::from_fn` for creating middleware from async fns

* More trait impls for `Next`

* Make `Next::run` consume `self`

* Use `.router_layer` in example, since middleware returns early

* Actually `Next` probably shouldn't impl `Clone` and `Service`

Has implications for backpressure and stuff

* Simplify `print-request-response` example

* Address review feedback

* add changelog link
This commit is contained in:
David Pedersen 2021-12-27 14:01:26 +01:00 committed by GitHub
parent 3841ef44d5
commit f4716084a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 277 additions and 22 deletions

View File

@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased
- None.
- Add `middleware::from_fn` for creating middleware from async functions ([#656])
[#656]: https://github.com/tokio-rs/axum/pull/656
# 0.1.0 (02. December, 2021)

View File

@ -17,6 +17,10 @@ erased-json = ["serde", "serde_json"]
axum = { path = "../axum", version = "0.4" }
http = "0.2"
mime = "0.3"
pin-project-lite = "0.2"
tower = { version = "0.4", features = ["util"] }
tower-http = { version = "0.2", features = ["util", "map-response-body"] }
tower-layer = "0.3"
tower-service = "0.3"
# optional dependencies

View File

@ -44,5 +44,6 @@
#![cfg_attr(test, allow(clippy::float_cmp))]
pub mod extract;
pub mod middleware;
pub mod response;
pub mod routing;

View File

@ -0,0 +1,240 @@
//! Create middleware from async functions.
//!
//! See [`from_fn`] for more details.
use axum::{
body::{self, Bytes, HttpBody},
response::{IntoResponse, Response},
BoxError,
};
use http::Request;
use pin_project_lite::pin_project;
use std::{
any::type_name,
convert::Infallible,
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower::{util::BoxCloneService, ServiceBuilder};
use tower_http::ServiceBuilderExt;
use tower_layer::Layer;
use tower_service::Service;
/// Create a middleware from an async function.
///
/// `from_fn` requires the function given to
///
/// 1. Be an `async fn`.
/// 2. Take [`Request`](http::Request) as the first argument.
/// 3. Take [`Next<B>`](Next) as the second argument.
/// 4. Return something that implements [`IntoResponse`].
///
/// # Example
///
/// ```rust
/// use axum::{
/// Router,
/// http::{Request, StatusCode},
/// routing::get,
/// response::IntoResponse,
/// };
/// use axum_extra::middleware::{self, Next};
///
/// async fn auth<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
/// let auth_header = req.headers().get(http::header::AUTHORIZATION);
///
/// match auth_header {
/// Some(auth_header) if auth_header == "secret" => {
/// Ok(next.run(req).await)
/// }
/// _ => Err(StatusCode::UNAUTHORIZED),
/// }
/// }
///
/// let app = Router::new()
/// .route("/", get(|| async { /* ... */ }))
/// .route_layer(middleware::from_fn(auth));
/// # let app: Router = app;
/// ```
pub fn from_fn<F>(f: F) -> MiddlewareFnLayer<F> {
MiddlewareFnLayer { f }
}
/// A [`tower::Layer`] from an async function.
///
/// [`tower::Layer`] is used to apply middleware to [`axum::Router`]s.
///
/// Created with [`from_fn`]. See that function for more details.
#[derive(Clone, Copy)]
pub struct MiddlewareFnLayer<F> {
f: F,
}
impl<S, F> Layer<S> for MiddlewareFnLayer<F>
where
F: Clone,
{
type Service = MiddlewareFn<F, S>;
fn layer(&self, inner: S) -> Self::Service {
MiddlewareFn {
f: self.f.clone(),
inner,
}
}
}
impl<F> fmt::Debug for MiddlewareFnLayer<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MiddlewareFnLayer")
// Write out the type name, without quoting it as `&type_name::<F>()` would
.field("f", &format_args!("{}", type_name::<F>()))
.finish()
}
}
/// A middleware created from an async function.
///
/// Created with [`from_fn`]. See that function for more details.
#[derive(Clone, Copy)]
pub struct MiddlewareFn<F, S> {
f: F,
inner: S,
}
impl<F, Fut, Out, S, ReqBody, ResBody> Service<Request<ReqBody>> for MiddlewareFn<F, S>
where
F: FnMut(Request<ReqBody>, Next<ReqBody>) -> Fut,
Fut: Future<Output = Out>,
Out: IntoResponse,
S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = Infallible>
+ Clone
+ Send
+ 'static,
S::Future: Send + 'static,
ResBody: HttpBody<Data = Bytes> + Send + 'static,
ResBody::Error: Into<BoxError>,
{
type Response = Response;
type Error = Infallible;
type Future = ResponseFuture<Fut>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let not_ready_inner = self.inner.clone();
let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
let inner = ServiceBuilder::new()
.boxed_clone()
.map_response_body(body::boxed)
.service(ready_inner);
let next = Next { inner };
ResponseFuture {
inner: (self.f)(req, next),
}
}
}
impl<F, S> fmt::Debug for MiddlewareFn<F, S>
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MiddlewareFnLayer")
.field("f", &format_args!("{}", type_name::<F>()))
.field("inner", &self.inner)
.finish()
}
}
/// The remainder of a middleware stack, including the handler.
pub struct Next<ReqBody> {
inner: BoxCloneService<Request<ReqBody>, Response, Infallible>,
}
impl<ReqBody> Next<ReqBody> {
/// Execute the remaining middleware stack.
pub async fn run(mut self, req: Request<ReqBody>) -> Response {
match self.inner.call(req).await {
Ok(res) => res,
Err(err) => match err {},
}
}
}
impl<ReqBody> fmt::Debug for Next<ReqBody> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MiddlewareFnLayer")
.field("inner", &self.inner)
.finish()
}
}
pin_project! {
/// Response future for [`MiddlewareFn`].
pub struct ResponseFuture<F> {
#[pin]
inner: F,
}
}
impl<F, Out> Future for ResponseFuture<F>
where
F: Future<Output = Out>,
Out: IntoResponse,
{
type Output = Result<Response, Infallible>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project()
.inner
.poll(cx)
.map(IntoResponse::into_response)
.map(Ok)
}
}
#[cfg(test)]
mod tests {
use super::*;
use axum::{body::Empty, routing::get, Router};
use http::{HeaderMap, StatusCode};
use tower::ServiceExt;
#[tokio::test]
async fn basic() {
async fn insert_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
req.headers_mut()
.insert("x-axum-test", "ok".parse().unwrap());
next.run(req).await
}
async fn handle(headers: HeaderMap) -> String {
(&headers["x-axum-test"]).to_str().unwrap().to_owned()
}
let app = Router::new()
.route("/", get(handle))
.layer(from_fn(insert_header));
let res = app
.oneshot(
Request::builder()
.uri("/")
.body(body::boxed(Empty::new()))
.unwrap(),
)
.await
.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let body = hyper::body::to_bytes(res).await.unwrap();
assert_eq!(&body[..], b"ok");
}
}

View File

@ -0,0 +1,5 @@
//! Additional types for creating middleware.
pub mod middleware_fn;
pub use self::middleware_fn::{from_fn, Next};

View File

@ -6,6 +6,7 @@ publish = false
[dependencies]
axum = { path = "../../axum" }
axum-extra = { path = "../../axum-extra" }
tokio = { version = "1.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version="0.3", features = ["env-filter"] }

View File

@ -6,14 +6,13 @@
use axum::{
body::{Body, Bytes},
error_handling::HandleErrorLayer,
http::{Request, StatusCode},
response::Response,
response::{IntoResponse, Response},
routing::post,
Router,
};
use axum_extra::middleware::{self, Next};
use std::net::SocketAddr;
use tower::{filter::AsyncFilterLayer, util::AndThenLayer, BoxError, ServiceBuilder};
#[tokio::main]
async fn main() {
@ -28,17 +27,7 @@ async fn main() {
let app = Router::new()
.route("/", post(|| async move { "Hello from `POST /`" }))
.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
)
}))
.layer(AndThenLayer::new(map_response))
.layer(AsyncFilterLayer::new(map_request)),
);
.layer(middleware::from_fn(print_request_response));
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
@ -48,28 +37,41 @@ async fn main() {
.unwrap();
}
async fn map_request(req: Request<Body>) -> Result<Request<Body>, BoxError> {
async fn print_request_response(
req: Request<Body>,
next: Next<Body>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
let (parts, body) = req.into_parts();
let bytes = buffer_and_print("request", body).await?;
let req = Request::from_parts(parts, Body::from(bytes));
Ok(req)
}
async fn map_response(res: Response) -> Result<Response<Body>, BoxError> {
let res = next.run(req).await;
let (parts, body) = res.into_parts();
let bytes = buffer_and_print("response", body).await?;
let res = Response::from_parts(parts, Body::from(bytes));
Ok(res)
}
async fn buffer_and_print<B>(direction: &str, body: B) -> Result<Bytes, BoxError>
async fn buffer_and_print<B>(direction: &str, body: B) -> Result<Bytes, (StatusCode, String)>
where
B: axum::body::HttpBody<Data = Bytes>,
B::Error: Into<BoxError>,
B::Error: std::fmt::Display,
{
let bytes = hyper::body::to_bytes(body).await.map_err(Into::into)?;
let bytes = match hyper::body::to_bytes(body).await {
Ok(bytes) => bytes,
Err(err) => {
return Err((
StatusCode::BAD_REQUEST,
format!("failed to read {} body: {}", direction, err),
));
}
};
if let Ok(body) = std::str::from_utf8(&bytes) {
tracing::debug!("{} body = {:?}", direction, body);
}
Ok(bytes)
}