diff --git a/src/extract/extractor_middleware.rs b/src/extract/extractor_middleware.rs new file mode 100644 index 00000000..7bb1de8d --- /dev/null +++ b/src/extract/extractor_middleware.rs @@ -0,0 +1,248 @@ +//! Convert an extractor into a middleware. +//! +//! See [`extractor_middleware`] for more details. + +use super::{rejection::BodyAlreadyExtracted, BodyAlreadyExtractedExt, FromRequest}; +use crate::{body::BoxBody, response::IntoResponse}; +use bytes::Bytes; +use futures_util::{future::BoxFuture, ready}; +use http::{Request, Response}; +use pin_project::pin_project; +use std::{ + fmt, + future::Future, + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, +}; +use tower::{BoxError, Layer, Service}; + +/// Convert an extractor into a middleware. +/// +/// If the extractor succeeds the value will be discarded and the inner service +/// will be called. If the extractor fails the rejection will be returned and +/// the inner service will _not_ be called. +/// +/// This can be used to perform validation of requests if the validation doesn't +/// produce any useful output, and run the extractor for several handlers +/// without repeating it in the function signature. +/// +/// # Example +/// +/// ```rust +/// use axum::{extract::extractor_middleware, prelude::*}; +/// use http::StatusCode; +/// use async_trait::async_trait; +/// +/// // An extractor that performs authorization. +/// struct RequireAuth; +/// +/// #[async_trait] +/// impl extract::FromRequest for RequireAuth +/// where +/// B: Send, +/// { +/// type Rejection = StatusCode; +/// +/// async fn from_request(req: &mut Request) -> Result { +/// if let Some(value) = req +/// .headers() +/// .get(http::header::AUTHORIZATION) +/// .and_then(|value| value.to_str().ok()) +/// { +/// if value == "secret" { +/// return Ok(Self); +/// } +/// } +/// +/// Err(StatusCode::UNAUTHORIZED) +/// } +/// } +/// +/// async fn handler() { +/// // If we get here the request has been authorized +/// } +/// +/// async fn other_handler() { +/// // If we get here the request has been authorized +/// } +/// +/// let app = route("/", get(handler)) +/// .route("/foo", post(other_handler)) +/// // The extractor will run before all routes +/// .layer(extractor_middleware::()); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +pub fn extractor_middleware() -> ExtractorMiddlewareLayer { + ExtractorMiddlewareLayer(PhantomData) +} + +/// [`Layer`] that applies [`ExtractorMiddleware`] that runs an extractor and +/// discards the value. +/// +/// See [`extractor_middleware`] for more details. +/// +/// [`Layer`]: tower::Layer +pub struct ExtractorMiddlewareLayer(PhantomData E>); + +impl Clone for ExtractorMiddlewareLayer { + fn clone(&self) -> Self { + Self(PhantomData) + } +} + +impl Copy for ExtractorMiddlewareLayer {} + +impl fmt::Debug for ExtractorMiddlewareLayer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ExtractorMiddleware") + .field("extractor", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Layer for ExtractorMiddlewareLayer { + type Service = ExtractorMiddleware; + + fn layer(&self, inner: S) -> Self::Service { + ExtractorMiddleware { + inner, + _extractor: PhantomData, + } + } +} + +/// Middleware that runs an extractor and discards the value. +/// +/// See [`extractor_middleware`] for more details. +pub struct ExtractorMiddleware { + inner: S, + _extractor: PhantomData E>, +} + +impl Clone for ExtractorMiddleware +where + S: Clone, +{ + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + _extractor: PhantomData, + } + } +} + +impl Copy for ExtractorMiddleware where S: Copy {} + +impl fmt::Debug for ExtractorMiddleware +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ExtractorMiddleware") + .field("inner", &self.inner) + .field("extractor", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Service> for ExtractorMiddleware +where + E: FromRequest + 'static, + ReqBody: Send + 'static, + S: Service, Response = Response> + Clone, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into, +{ + type Response = Response; + type Error = S::Error; + type Future = ExtractorMiddlewareResponseFuture; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + let extract_future = Box::pin(async move { + let extracted = E::from_request(&mut req).await; + (req, extracted) + }); + + ExtractorMiddlewareResponseFuture { + state: State::Extracting(extract_future), + svc: Some(self.inner.clone()), + } + } +} + +/// Response future for [`ExtractorMiddleware`]. +#[allow(missing_debug_implementations)] +#[pin_project] +pub struct ExtractorMiddlewareResponseFuture +where + E: FromRequest, + S: Service>, +{ + #[pin] + state: State, + svc: Option, +} + +#[pin_project(project = StateProj)] +enum State +where + E: FromRequest, + S: Service>, +{ + Extracting(BoxFuture<'static, (Request, Result)>), + Call(#[pin] S::Future), +} + +impl Future for ExtractorMiddlewareResponseFuture +where + E: FromRequest, + S: Service, Response = Response>, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into, +{ + type Output = Result, S::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + let mut this = self.as_mut().project(); + + let new_state = match this.state.as_mut().project() { + StateProj::Extracting(future) => { + let (req, extracted) = ready!(future.as_mut().poll(cx)); + + match extracted { + Ok(_) => { + if req.extensions().get::().is_some() { + let res = BodyAlreadyExtracted.into_response().map(BoxBody::new); + return Poll::Ready(Ok(res)); + } + + let mut svc = this.svc.take().expect("future polled after completion"); + let future = svc.call(req); + State::Call(future) + } + Err(err) => { + let res = err.into_response().map(BoxBody::new); + return Poll::Ready(Ok(res)); + } + } + } + StateProj::Call(future) => { + return future + .poll(cx) + .map(|result| result.map(|response| response.map(BoxBody::new))); + } + }; + + this.state.set(new_state); + } + } +} diff --git a/src/extract/mod.rs b/src/extract/mod.rs index e2293672..33e92961 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -189,8 +189,12 @@ use std::{ task::{Context, Poll}, }; +pub mod extractor_middleware; pub mod rejection; +#[doc(inline)] +pub use self::extractor_middleware::extractor_middleware; + /// Types that can be created from requests. /// /// See the [module docs](crate::extract) for more details. @@ -840,12 +844,14 @@ macro_rules! impl_parse_url { impl_parse_url!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); +/// Request extension used to indicate that body has been extracted and `Default` has been left in +/// its place. +struct BodyAlreadyExtractedExt; + fn take_body(req: &mut Request) -> Result where B: Default, { - struct BodyAlreadyExtractedExt; - if req .extensions_mut() .insert(BodyAlreadyExtractedExt) diff --git a/src/tests.rs b/src/tests.rs index fc8b9646..c2fec39e 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,6 +1,6 @@ use crate::{handler::on, prelude::*, response::IntoResponse, routing::MethodFilter, service}; use bytes::Bytes; -use http::{Request, Response, StatusCode}; +use http::{header::AUTHORIZATION, Request, Response, StatusCode}; use hyper::{Body, Server}; use serde::Deserialize; use serde_json::json; @@ -664,6 +664,71 @@ async fn service_in_bottom() { run_in_background(app).await; } +#[tokio::test] +async fn test_extractor_middleware() { + struct RequireAuth; + + #[async_trait::async_trait] + impl extract::FromRequest for RequireAuth + where + B: Send, + { + type Rejection = StatusCode; + + async fn from_request(req: &mut Request) -> Result { + if let Some(auth) = req + .headers() + .get("authorization") + .and_then(|v| v.to_str().ok()) + { + if auth == "secret" { + return Ok(Self); + } + } + + Err(StatusCode::UNAUTHORIZED) + } + } + + async fn handler() {} + + let app = route( + "/", + get(handler.layer(extract::extractor_middleware::())), + ) + .route( + "/take-body-error", + post(handler.layer(extract::extractor_middleware::())), + ); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + + let res = client + .get(format!("http://{}/", addr)) + .header(AUTHORIZATION, "secret") + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let res = client + .post(format!("http://{}/take-body-error", addr)) + .body("foobar") + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); +} + /// Run a `tower::Service` in the background and get a URI for it. async fn run_in_background(svc: S) -> SocketAddr where