mirror of
https://github.com/tokio-rs/axum.git
synced 2025-10-02 15:24:54 +00:00
Add extractor_middleware
(#29)
Converts an extractor into a middleware so it can be run for many routes without having to repeat it in the function signature.
This commit is contained in:
parent
0c19fa4d52
commit
1cbd43cfc4
248
src/extract/extractor_middleware.rs
Normal file
248
src/extract/extractor_middleware.rs
Normal file
@ -0,0 +1,248 @@
|
|||||||
|
//! Convert an extractor into a middleware.
|
||||||
|
//!
|
||||||
|
//! See [`extractor_middleware`] for more details.
|
||||||
|
|
||||||
|
use super::{rejection::BodyAlreadyExtracted, BodyAlreadyExtractedExt, FromRequest};
|
||||||
|
use crate::{body::BoxBody, response::IntoResponse};
|
||||||
|
use bytes::Bytes;
|
||||||
|
use futures_util::{future::BoxFuture, ready};
|
||||||
|
use http::{Request, Response};
|
||||||
|
use pin_project::pin_project;
|
||||||
|
use std::{
|
||||||
|
fmt,
|
||||||
|
future::Future,
|
||||||
|
marker::PhantomData,
|
||||||
|
pin::Pin,
|
||||||
|
task::{Context, Poll},
|
||||||
|
};
|
||||||
|
use tower::{BoxError, Layer, Service};
|
||||||
|
|
||||||
|
/// Convert an extractor into a middleware.
|
||||||
|
///
|
||||||
|
/// If the extractor succeeds the value will be discarded and the inner service
|
||||||
|
/// will be called. If the extractor fails the rejection will be returned and
|
||||||
|
/// the inner service will _not_ be called.
|
||||||
|
///
|
||||||
|
/// This can be used to perform validation of requests if the validation doesn't
|
||||||
|
/// produce any useful output, and run the extractor for several handlers
|
||||||
|
/// without repeating it in the function signature.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use axum::{extract::extractor_middleware, prelude::*};
|
||||||
|
/// use http::StatusCode;
|
||||||
|
/// use async_trait::async_trait;
|
||||||
|
///
|
||||||
|
/// // An extractor that performs authorization.
|
||||||
|
/// struct RequireAuth;
|
||||||
|
///
|
||||||
|
/// #[async_trait]
|
||||||
|
/// impl<B> extract::FromRequest<B> for RequireAuth
|
||||||
|
/// where
|
||||||
|
/// B: Send,
|
||||||
|
/// {
|
||||||
|
/// type Rejection = StatusCode;
|
||||||
|
///
|
||||||
|
/// async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
|
||||||
|
/// if let Some(value) = req
|
||||||
|
/// .headers()
|
||||||
|
/// .get(http::header::AUTHORIZATION)
|
||||||
|
/// .and_then(|value| value.to_str().ok())
|
||||||
|
/// {
|
||||||
|
/// if value == "secret" {
|
||||||
|
/// return Ok(Self);
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// Err(StatusCode::UNAUTHORIZED)
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// async fn handler() {
|
||||||
|
/// // If we get here the request has been authorized
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// async fn other_handler() {
|
||||||
|
/// // If we get here the request has been authorized
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// let app = route("/", get(handler))
|
||||||
|
/// .route("/foo", post(other_handler))
|
||||||
|
/// // The extractor will run before all routes
|
||||||
|
/// .layer(extractor_middleware::<RequireAuth>());
|
||||||
|
/// # async {
|
||||||
|
/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
|
||||||
|
/// # };
|
||||||
|
/// ```
|
||||||
|
pub fn extractor_middleware<E>() -> ExtractorMiddlewareLayer<E> {
|
||||||
|
ExtractorMiddlewareLayer(PhantomData)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// [`Layer`] that applies [`ExtractorMiddleware`] that runs an extractor and
|
||||||
|
/// discards the value.
|
||||||
|
///
|
||||||
|
/// See [`extractor_middleware`] for more details.
|
||||||
|
///
|
||||||
|
/// [`Layer`]: tower::Layer
|
||||||
|
pub struct ExtractorMiddlewareLayer<E>(PhantomData<fn() -> E>);
|
||||||
|
|
||||||
|
impl<E> Clone for ExtractorMiddlewareLayer<E> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self(PhantomData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<E> Copy for ExtractorMiddlewareLayer<E> {}
|
||||||
|
|
||||||
|
impl<E> fmt::Debug for ExtractorMiddlewareLayer<E> {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
f.debug_struct("ExtractorMiddleware")
|
||||||
|
.field("extractor", &format_args!("{}", std::any::type_name::<E>()))
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<E, S> Layer<S> for ExtractorMiddlewareLayer<E> {
|
||||||
|
type Service = ExtractorMiddleware<S, E>;
|
||||||
|
|
||||||
|
fn layer(&self, inner: S) -> Self::Service {
|
||||||
|
ExtractorMiddleware {
|
||||||
|
inner,
|
||||||
|
_extractor: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Middleware that runs an extractor and discards the value.
|
||||||
|
///
|
||||||
|
/// See [`extractor_middleware`] for more details.
|
||||||
|
pub struct ExtractorMiddleware<S, E> {
|
||||||
|
inner: S,
|
||||||
|
_extractor: PhantomData<fn() -> E>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, E> Clone for ExtractorMiddleware<S, E>
|
||||||
|
where
|
||||||
|
S: Clone,
|
||||||
|
{
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
inner: self.inner.clone(),
|
||||||
|
_extractor: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, E> Copy for ExtractorMiddleware<S, E> where S: Copy {}
|
||||||
|
|
||||||
|
impl<S, E> fmt::Debug for ExtractorMiddleware<S, E>
|
||||||
|
where
|
||||||
|
S: fmt::Debug,
|
||||||
|
{
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
f.debug_struct("ExtractorMiddleware")
|
||||||
|
.field("inner", &self.inner)
|
||||||
|
.field("extractor", &format_args!("{}", std::any::type_name::<E>()))
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, E, ReqBody, ResBody> Service<Request<ReqBody>> for ExtractorMiddleware<S, E>
|
||||||
|
where
|
||||||
|
E: FromRequest<ReqBody> + 'static,
|
||||||
|
ReqBody: Send + 'static,
|
||||||
|
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
|
||||||
|
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
|
||||||
|
ResBody::Error: Into<BoxError>,
|
||||||
|
{
|
||||||
|
type Response = Response<BoxBody>;
|
||||||
|
type Error = S::Error;
|
||||||
|
type Future = ExtractorMiddlewareResponseFuture<ReqBody, S, E>;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
self.inner.poll_ready(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
|
||||||
|
let extract_future = Box::pin(async move {
|
||||||
|
let extracted = E::from_request(&mut req).await;
|
||||||
|
(req, extracted)
|
||||||
|
});
|
||||||
|
|
||||||
|
ExtractorMiddlewareResponseFuture {
|
||||||
|
state: State::Extracting(extract_future),
|
||||||
|
svc: Some(self.inner.clone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Response future for [`ExtractorMiddleware`].
|
||||||
|
#[allow(missing_debug_implementations)]
|
||||||
|
#[pin_project]
|
||||||
|
pub struct ExtractorMiddlewareResponseFuture<ReqBody, S, E>
|
||||||
|
where
|
||||||
|
E: FromRequest<ReqBody>,
|
||||||
|
S: Service<Request<ReqBody>>,
|
||||||
|
{
|
||||||
|
#[pin]
|
||||||
|
state: State<ReqBody, S, E>,
|
||||||
|
svc: Option<S>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pin_project(project = StateProj)]
|
||||||
|
enum State<ReqBody, S, E>
|
||||||
|
where
|
||||||
|
E: FromRequest<ReqBody>,
|
||||||
|
S: Service<Request<ReqBody>>,
|
||||||
|
{
|
||||||
|
Extracting(BoxFuture<'static, (Request<ReqBody>, Result<E, E::Rejection>)>),
|
||||||
|
Call(#[pin] S::Future),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<ReqBody, S, E, ResBody> Future for ExtractorMiddlewareResponseFuture<ReqBody, S, E>
|
||||||
|
where
|
||||||
|
E: FromRequest<ReqBody>,
|
||||||
|
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
|
||||||
|
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
|
||||||
|
ResBody::Error: Into<BoxError>,
|
||||||
|
{
|
||||||
|
type Output = Result<Response<BoxBody>, S::Error>;
|
||||||
|
|
||||||
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
|
loop {
|
||||||
|
let mut this = self.as_mut().project();
|
||||||
|
|
||||||
|
let new_state = match this.state.as_mut().project() {
|
||||||
|
StateProj::Extracting(future) => {
|
||||||
|
let (req, extracted) = ready!(future.as_mut().poll(cx));
|
||||||
|
|
||||||
|
match extracted {
|
||||||
|
Ok(_) => {
|
||||||
|
if req.extensions().get::<BodyAlreadyExtractedExt>().is_some() {
|
||||||
|
let res = BodyAlreadyExtracted.into_response().map(BoxBody::new);
|
||||||
|
return Poll::Ready(Ok(res));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut svc = this.svc.take().expect("future polled after completion");
|
||||||
|
let future = svc.call(req);
|
||||||
|
State::Call(future)
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
let res = err.into_response().map(BoxBody::new);
|
||||||
|
return Poll::Ready(Ok(res));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
StateProj::Call(future) => {
|
||||||
|
return future
|
||||||
|
.poll(cx)
|
||||||
|
.map(|result| result.map(|response| response.map(BoxBody::new)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
this.state.set(new_state);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -189,8 +189,12 @@ use std::{
|
|||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub mod extractor_middleware;
|
||||||
pub mod rejection;
|
pub mod rejection;
|
||||||
|
|
||||||
|
#[doc(inline)]
|
||||||
|
pub use self::extractor_middleware::extractor_middleware;
|
||||||
|
|
||||||
/// Types that can be created from requests.
|
/// Types that can be created from requests.
|
||||||
///
|
///
|
||||||
/// See the [module docs](crate::extract) for more details.
|
/// See the [module docs](crate::extract) for more details.
|
||||||
@ -840,12 +844,14 @@ macro_rules! impl_parse_url {
|
|||||||
|
|
||||||
impl_parse_url!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
|
impl_parse_url!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
|
||||||
|
|
||||||
|
/// Request extension used to indicate that body has been extracted and `Default` has been left in
|
||||||
|
/// its place.
|
||||||
|
struct BodyAlreadyExtractedExt;
|
||||||
|
|
||||||
fn take_body<B>(req: &mut Request<B>) -> Result<B, BodyAlreadyExtracted>
|
fn take_body<B>(req: &mut Request<B>) -> Result<B, BodyAlreadyExtracted>
|
||||||
where
|
where
|
||||||
B: Default,
|
B: Default,
|
||||||
{
|
{
|
||||||
struct BodyAlreadyExtractedExt;
|
|
||||||
|
|
||||||
if req
|
if req
|
||||||
.extensions_mut()
|
.extensions_mut()
|
||||||
.insert(BodyAlreadyExtractedExt)
|
.insert(BodyAlreadyExtractedExt)
|
||||||
|
67
src/tests.rs
67
src/tests.rs
@ -1,6 +1,6 @@
|
|||||||
use crate::{handler::on, prelude::*, response::IntoResponse, routing::MethodFilter, service};
|
use crate::{handler::on, prelude::*, response::IntoResponse, routing::MethodFilter, service};
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use http::{Request, Response, StatusCode};
|
use http::{header::AUTHORIZATION, Request, Response, StatusCode};
|
||||||
use hyper::{Body, Server};
|
use hyper::{Body, Server};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
@ -664,6 +664,71 @@ async fn service_in_bottom() {
|
|||||||
run_in_background(app).await;
|
run_in_background(app).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_extractor_middleware() {
|
||||||
|
struct RequireAuth;
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl<B> extract::FromRequest<B> for RequireAuth
|
||||||
|
where
|
||||||
|
B: Send,
|
||||||
|
{
|
||||||
|
type Rejection = StatusCode;
|
||||||
|
|
||||||
|
async fn from_request(req: &mut Request<B>) -> Result<Self, Self::Rejection> {
|
||||||
|
if let Some(auth) = req
|
||||||
|
.headers()
|
||||||
|
.get("authorization")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
{
|
||||||
|
if auth == "secret" {
|
||||||
|
return Ok(Self);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(StatusCode::UNAUTHORIZED)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handler() {}
|
||||||
|
|
||||||
|
let app = route(
|
||||||
|
"/",
|
||||||
|
get(handler.layer(extract::extractor_middleware::<RequireAuth>())),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/take-body-error",
|
||||||
|
post(handler.layer(extract::extractor_middleware::<Bytes>())),
|
||||||
|
);
|
||||||
|
|
||||||
|
let addr = run_in_background(app).await;
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
let res = client
|
||||||
|
.get(format!("http://{}/", addr))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
|
||||||
|
|
||||||
|
let res = client
|
||||||
|
.get(format!("http://{}/", addr))
|
||||||
|
.header(AUTHORIZATION, "secret")
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::OK);
|
||||||
|
|
||||||
|
let res = client
|
||||||
|
.post(format!("http://{}/take-body-error", addr))
|
||||||
|
.body("foobar")
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
/// Run a `tower::Service` in the background and get a URI for it.
|
/// Run a `tower::Service` in the background and get a URI for it.
|
||||||
async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
|
async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
|
||||||
where
|
where
|
||||||
|
Loading…
x
Reference in New Issue
Block a user