mirror of
https://github.com/tokio-rs/axum.git
synced 2025-09-28 13:30:39 +00:00
Add middleware::from_fn
for creating middleware from async fns (#656)
* Add `middleware::from_fn` for creating middleware from async fns * More trait impls for `Next` * Make `Next::run` consume `self` * Use `.router_layer` in example, since middleware returns early * Actually `Next` probably shouldn't impl `Clone` and `Service` Has implications for backpressure and stuff * Simplify `print-request-response` example * Address review feedback * add changelog link
This commit is contained in:
parent
3841ef44d5
commit
f4716084a7
@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
# Unreleased
|
||||
|
||||
- None.
|
||||
- Add `middleware::from_fn` for creating middleware from async functions ([#656])
|
||||
|
||||
[#656]: https://github.com/tokio-rs/axum/pull/656
|
||||
|
||||
# 0.1.0 (02. December, 2021)
|
||||
|
||||
|
@ -17,6 +17,10 @@ erased-json = ["serde", "serde_json"]
|
||||
axum = { path = "../axum", version = "0.4" }
|
||||
http = "0.2"
|
||||
mime = "0.3"
|
||||
pin-project-lite = "0.2"
|
||||
tower = { version = "0.4", features = ["util"] }
|
||||
tower-http = { version = "0.2", features = ["util", "map-response-body"] }
|
||||
tower-layer = "0.3"
|
||||
tower-service = "0.3"
|
||||
|
||||
# optional dependencies
|
||||
|
@ -44,5 +44,6 @@
|
||||
#![cfg_attr(test, allow(clippy::float_cmp))]
|
||||
|
||||
pub mod extract;
|
||||
pub mod middleware;
|
||||
pub mod response;
|
||||
pub mod routing;
|
||||
|
240
axum-extra/src/middleware/middleware_fn.rs
Normal file
240
axum-extra/src/middleware/middleware_fn.rs
Normal file
@ -0,0 +1,240 @@
|
||||
//! Create middleware from async functions.
|
||||
//!
|
||||
//! See [`from_fn`] for more details.
|
||||
|
||||
use axum::{
|
||||
body::{self, Bytes, HttpBody},
|
||||
response::{IntoResponse, Response},
|
||||
BoxError,
|
||||
};
|
||||
use http::Request;
|
||||
use pin_project_lite::pin_project;
|
||||
use std::{
|
||||
any::type_name,
|
||||
convert::Infallible,
|
||||
fmt,
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tower::{util::BoxCloneService, ServiceBuilder};
|
||||
use tower_http::ServiceBuilderExt;
|
||||
use tower_layer::Layer;
|
||||
use tower_service::Service;
|
||||
|
||||
/// Create a middleware from an async function.
|
||||
///
|
||||
/// `from_fn` requires the function given to
|
||||
///
|
||||
/// 1. Be an `async fn`.
|
||||
/// 2. Take [`Request`](http::Request) as the first argument.
|
||||
/// 3. Take [`Next<B>`](Next) as the second argument.
|
||||
/// 4. Return something that implements [`IntoResponse`].
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use axum::{
|
||||
/// Router,
|
||||
/// http::{Request, StatusCode},
|
||||
/// routing::get,
|
||||
/// response::IntoResponse,
|
||||
/// };
|
||||
/// use axum_extra::middleware::{self, Next};
|
||||
///
|
||||
/// async fn auth<B>(req: Request<B>, next: Next<B>) -> impl IntoResponse {
|
||||
/// let auth_header = req.headers().get(http::header::AUTHORIZATION);
|
||||
///
|
||||
/// match auth_header {
|
||||
/// Some(auth_header) if auth_header == "secret" => {
|
||||
/// Ok(next.run(req).await)
|
||||
/// }
|
||||
/// _ => Err(StatusCode::UNAUTHORIZED),
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// let app = Router::new()
|
||||
/// .route("/", get(|| async { /* ... */ }))
|
||||
/// .route_layer(middleware::from_fn(auth));
|
||||
/// # let app: Router = app;
|
||||
/// ```
|
||||
pub fn from_fn<F>(f: F) -> MiddlewareFnLayer<F> {
|
||||
MiddlewareFnLayer { f }
|
||||
}
|
||||
|
||||
/// A [`tower::Layer`] from an async function.
|
||||
///
|
||||
/// [`tower::Layer`] is used to apply middleware to [`axum::Router`]s.
|
||||
///
|
||||
/// Created with [`from_fn`]. See that function for more details.
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct MiddlewareFnLayer<F> {
|
||||
f: F,
|
||||
}
|
||||
|
||||
impl<S, F> Layer<S> for MiddlewareFnLayer<F>
|
||||
where
|
||||
F: Clone,
|
||||
{
|
||||
type Service = MiddlewareFn<F, S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
MiddlewareFn {
|
||||
f: self.f.clone(),
|
||||
inner,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> fmt::Debug for MiddlewareFnLayer<F> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("MiddlewareFnLayer")
|
||||
// Write out the type name, without quoting it as `&type_name::<F>()` would
|
||||
.field("f", &format_args!("{}", type_name::<F>()))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// A middleware created from an async function.
|
||||
///
|
||||
/// Created with [`from_fn`]. See that function for more details.
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct MiddlewareFn<F, S> {
|
||||
f: F,
|
||||
inner: S,
|
||||
}
|
||||
|
||||
impl<F, Fut, Out, S, ReqBody, ResBody> Service<Request<ReqBody>> for MiddlewareFn<F, S>
|
||||
where
|
||||
F: FnMut(Request<ReqBody>, Next<ReqBody>) -> Fut,
|
||||
Fut: Future<Output = Out>,
|
||||
Out: IntoResponse,
|
||||
S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = Infallible>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ 'static,
|
||||
S::Future: Send + 'static,
|
||||
ResBody: HttpBody<Data = Bytes> + Send + 'static,
|
||||
ResBody::Error: Into<BoxError>,
|
||||
{
|
||||
type Response = Response;
|
||||
type Error = Infallible;
|
||||
type Future = ResponseFuture<Fut>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
|
||||
let not_ready_inner = self.inner.clone();
|
||||
let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
|
||||
|
||||
let inner = ServiceBuilder::new()
|
||||
.boxed_clone()
|
||||
.map_response_body(body::boxed)
|
||||
.service(ready_inner);
|
||||
let next = Next { inner };
|
||||
|
||||
ResponseFuture {
|
||||
inner: (self.f)(req, next),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, S> fmt::Debug for MiddlewareFn<F, S>
|
||||
where
|
||||
S: fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("MiddlewareFnLayer")
|
||||
.field("f", &format_args!("{}", type_name::<F>()))
|
||||
.field("inner", &self.inner)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// The remainder of a middleware stack, including the handler.
|
||||
pub struct Next<ReqBody> {
|
||||
inner: BoxCloneService<Request<ReqBody>, Response, Infallible>,
|
||||
}
|
||||
|
||||
impl<ReqBody> Next<ReqBody> {
|
||||
/// Execute the remaining middleware stack.
|
||||
pub async fn run(mut self, req: Request<ReqBody>) -> Response {
|
||||
match self.inner.call(req).await {
|
||||
Ok(res) => res,
|
||||
Err(err) => match err {},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<ReqBody> fmt::Debug for Next<ReqBody> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("MiddlewareFnLayer")
|
||||
.field("inner", &self.inner)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// Response future for [`MiddlewareFn`].
|
||||
pub struct ResponseFuture<F> {
|
||||
#[pin]
|
||||
inner: F,
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, Out> Future for ResponseFuture<F>
|
||||
where
|
||||
F: Future<Output = Out>,
|
||||
Out: IntoResponse,
|
||||
{
|
||||
type Output = Result<Response, Infallible>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.project()
|
||||
.inner
|
||||
.poll(cx)
|
||||
.map(IntoResponse::into_response)
|
||||
.map(Ok)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::{body::Empty, routing::get, Router};
|
||||
use http::{HeaderMap, StatusCode};
|
||||
use tower::ServiceExt;
|
||||
|
||||
#[tokio::test]
|
||||
async fn basic() {
|
||||
async fn insert_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
|
||||
req.headers_mut()
|
||||
.insert("x-axum-test", "ok".parse().unwrap());
|
||||
|
||||
next.run(req).await
|
||||
}
|
||||
|
||||
async fn handle(headers: HeaderMap) -> String {
|
||||
(&headers["x-axum-test"]).to_str().unwrap().to_owned()
|
||||
}
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", get(handle))
|
||||
.layer(from_fn(insert_header));
|
||||
|
||||
let res = app
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/")
|
||||
.body(body::boxed(Empty::new()))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = hyper::body::to_bytes(res).await.unwrap();
|
||||
assert_eq!(&body[..], b"ok");
|
||||
}
|
||||
}
|
5
axum-extra/src/middleware/mod.rs
Normal file
5
axum-extra/src/middleware/mod.rs
Normal file
@ -0,0 +1,5 @@
|
||||
//! Additional types for creating middleware.
|
||||
|
||||
pub mod middleware_fn;
|
||||
|
||||
pub use self::middleware_fn::{from_fn, Next};
|
@ -6,6 +6,7 @@ publish = false
|
||||
|
||||
[dependencies]
|
||||
axum = { path = "../../axum" }
|
||||
axum-extra = { path = "../../axum-extra" }
|
||||
tokio = { version = "1.0", features = ["full"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version="0.3", features = ["env-filter"] }
|
||||
|
@ -6,14 +6,13 @@
|
||||
|
||||
use axum::{
|
||||
body::{Body, Bytes},
|
||||
error_handling::HandleErrorLayer,
|
||||
http::{Request, StatusCode},
|
||||
response::Response,
|
||||
response::{IntoResponse, Response},
|
||||
routing::post,
|
||||
Router,
|
||||
};
|
||||
use axum_extra::middleware::{self, Next};
|
||||
use std::net::SocketAddr;
|
||||
use tower::{filter::AsyncFilterLayer, util::AndThenLayer, BoxError, ServiceBuilder};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
@ -28,17 +27,7 @@ async fn main() {
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", post(|| async move { "Hello from `POST /`" }))
|
||||
.layer(
|
||||
ServiceBuilder::new()
|
||||
.layer(HandleErrorLayer::new(|error| async move {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Unhandled internal error: {}", error),
|
||||
)
|
||||
}))
|
||||
.layer(AndThenLayer::new(map_response))
|
||||
.layer(AsyncFilterLayer::new(map_request)),
|
||||
);
|
||||
.layer(middleware::from_fn(print_request_response));
|
||||
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
|
||||
tracing::debug!("listening on {}", addr);
|
||||
@ -48,28 +37,41 @@ async fn main() {
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
async fn map_request(req: Request<Body>) -> Result<Request<Body>, BoxError> {
|
||||
async fn print_request_response(
|
||||
req: Request<Body>,
|
||||
next: Next<Body>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
let (parts, body) = req.into_parts();
|
||||
let bytes = buffer_and_print("request", body).await?;
|
||||
let req = Request::from_parts(parts, Body::from(bytes));
|
||||
Ok(req)
|
||||
}
|
||||
|
||||
async fn map_response(res: Response) -> Result<Response<Body>, BoxError> {
|
||||
let res = next.run(req).await;
|
||||
|
||||
let (parts, body) = res.into_parts();
|
||||
let bytes = buffer_and_print("response", body).await?;
|
||||
let res = Response::from_parts(parts, Body::from(bytes));
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
async fn buffer_and_print<B>(direction: &str, body: B) -> Result<Bytes, BoxError>
|
||||
async fn buffer_and_print<B>(direction: &str, body: B) -> Result<Bytes, (StatusCode, String)>
|
||||
where
|
||||
B: axum::body::HttpBody<Data = Bytes>,
|
||||
B::Error: Into<BoxError>,
|
||||
B::Error: std::fmt::Display,
|
||||
{
|
||||
let bytes = hyper::body::to_bytes(body).await.map_err(Into::into)?;
|
||||
let bytes = match hyper::body::to_bytes(body).await {
|
||||
Ok(bytes) => bytes,
|
||||
Err(err) => {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("failed to read {} body: {}", direction, err),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
if let Ok(body) = std::str::from_utf8(&bytes) {
|
||||
tracing::debug!("{} body = {:?}", direction, body);
|
||||
}
|
||||
|
||||
Ok(bytes)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user