From f32d325e55fe002fdfcad8fd32475c2e3222c225 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Thu, 22 Jul 2021 13:23:50 +0200 Subject: [PATCH] Make extractors easier to write (#36) Previously extractors worked directly on `Request` which meant you had to do weird tricks like `mem::take(req.headers_mut())` to get owned parts of the request. This changes that instead to use a new `RequestParts` type that have methods to "take" each part of the request. Without having to do weird tricks. Also removed the need to have `B: Default` for body extractors. --- examples/key_value_store.rs | 38 ++- examples/versioning.rs | 8 +- src/body.rs | 84 +------ src/extract/extractor_middleware.rs | 28 ++- src/extract/mod.rs | 362 ++++++++++++++++++++++------ src/extract/multipart.rs | 8 +- src/extract/rejection.rs | 54 +++++ src/handler/mod.rs | 18 +- src/routing.rs | 11 +- src/service/future.rs | 9 +- src/service/mod.rs | 4 +- src/tests.rs | 10 +- 12 files changed, 441 insertions(+), 193 deletions(-) diff --git a/examples/key_value_store.rs b/examples/key_value_store.rs index 8a457979..f10bd7fb 100644 --- a/examples/key_value_store.rs +++ b/examples/key_value_store.rs @@ -7,7 +7,8 @@ //! ``` use axum::{ - extract::{ContentLengthLimit, Extension, UrlParams}, + async_trait, + extract::{extractor_middleware, ContentLengthLimit, Extension, RequestParts, UrlParams}, prelude::*, response::IntoResponse, routing::BoxRoute, @@ -24,8 +25,7 @@ use std::{ }; use tower::{BoxError, ServiceBuilder}; use tower_http::{ - add_extension::AddExtensionLayer, auth::RequireAuthorizationLayer, - compression::CompressionLayer, trace::TraceLayer, + add_extension::AddExtensionLayer, compression::CompressionLayer, trace::TraceLayer, }; #[tokio::main] @@ -118,10 +118,40 @@ fn admin_routes() -> BoxRoute { route("/keys", delete(delete_all_keys)) .route("/key/:key", delete(remove_key)) // Require beare auth for all admin routes - .layer(RequireAuthorizationLayer::bearer("secret-token")) + .layer(extractor_middleware::()) .boxed() } +/// An extractor that performs authorization. +// TODO: when https://github.com/hyperium/http-body/pull/46 is merged we can use +// `tower_http::auth::RequireAuthorization` instead +struct RequireAuth; + +#[async_trait] +impl extract::FromRequest for RequireAuth +where + B: Send, +{ + type Rejection = StatusCode; + + async fn from_request(req: &mut RequestParts) -> Result { + let auth_header = req + .headers() + .and_then(|headers| headers.get(http::header::AUTHORIZATION)) + .and_then(|value| value.to_str().ok()); + + if let Some(value) = auth_header { + if let Some(token) = value.strip_prefix("Bearer ") { + if token == "secret-token" { + return Ok(Self); + } + } + } + + Err(StatusCode::UNAUTHORIZED) + } +} + fn handle_error(error: BoxError) -> impl IntoResponse { if error.is::() { return (StatusCode::REQUEST_TIMEOUT, Cow::from("request timed out")); diff --git a/examples/versioning.rs b/examples/versioning.rs index 88447006..61fd5051 100644 --- a/examples/versioning.rs +++ b/examples/versioning.rs @@ -1,5 +1,9 @@ use axum::response::IntoResponse; -use axum::{async_trait, extract::FromRequest, prelude::*}; +use axum::{ + async_trait, + extract::{FromRequest, RequestParts}, + prelude::*, +}; use http::Response; use http::StatusCode; use std::net::SocketAddr; @@ -36,7 +40,7 @@ where { type Rejection = Response; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let params = extract::UrlParamsMap::from_request(req) .await .map_err(IntoResponse::into_response)?; diff --git a/src/body.rs b/src/body.rs index a1952b3b..667d544e 100644 --- a/src/body.rs +++ b/src/body.rs @@ -1,13 +1,8 @@ //! HTTP body utilities. use bytes::Bytes; -use http_body::{Empty, Full}; -use std::{ - error::Error as StdError, - fmt, - pin::Pin, - task::{Context, Poll}, -}; +use http_body::Body as _; +use std::{error::Error as StdError, fmt}; use tower::BoxError; pub use hyper::body::Body; @@ -16,75 +11,18 @@ pub use hyper::body::Body; /// /// This is used in axum as the response body type for applications. Its necessary to unify /// multiple response bodies types into one. -pub struct BoxBody { - // when we've gotten rid of `BoxStdError` we should be able to change the error type to - // `BoxError` - inner: Pin + Send + Sync + 'static>>, -} +pub type BoxBody = http_body::combinators::BoxBody; -impl BoxBody { - /// Create a new `BoxBody`. - pub fn new(body: B) -> Self - where - B: http_body::Body + Send + Sync + 'static, - B::Error: Into, - { - Self { - inner: Box::pin(body.map_err(|error| BoxStdError(error.into()))), - } - } - - pub(crate) fn empty() -> Self { - Self::new(Empty::new()) - } -} - -impl Default for BoxBody { - fn default() -> Self { - BoxBody::empty() - } -} - -impl fmt::Debug for BoxBody { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("BoxBody").finish() - } -} - -impl http_body::Body for BoxBody { - type Data = Bytes; - type Error = BoxStdError; - - fn poll_data( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - self.inner.as_mut().poll_data(cx) - } - - fn poll_trailers( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - self.inner.as_mut().poll_trailers(cx) - } - - fn is_end_stream(&self) -> bool { - self.inner.is_end_stream() - } - - fn size_hint(&self) -> http_body::SizeHint { - self.inner.size_hint() - } -} - -impl From for BoxBody +pub(crate) fn box_body(body: B) -> BoxBody where - B: Into, + B: http_body::Body + Send + Sync + 'static, + B::Error: Into, { - fn from(s: B) -> Self { - BoxBody::new(Full::from(s.into())) - } + body.map_err(|err| BoxStdError(err.into())).boxed() +} + +pub(crate) fn empty() -> BoxBody { + box_body(http_body::Empty::new()) } /// A boxed error trait object that implements [`std::error::Error`]. diff --git a/src/extract/extractor_middleware.rs b/src/extract/extractor_middleware.rs index d075f253..a8fa6301 100644 --- a/src/extract/extractor_middleware.rs +++ b/src/extract/extractor_middleware.rs @@ -2,7 +2,7 @@ //! //! See [`extractor_middleware`] for more details. -use super::FromRequest; +use super::{FromRequest, RequestParts}; use crate::{body::BoxBody, response::IntoResponse}; use bytes::Bytes; use futures_util::{future::BoxFuture, ready}; @@ -34,7 +34,7 @@ use tower::{BoxError, Layer, Service}; /// # Example /// /// ```rust -/// use axum::{extract::extractor_middleware, prelude::*}; +/// use axum::{extract::{extractor_middleware, RequestParts}, prelude::*}; /// use http::StatusCode; /// use async_trait::async_trait; /// @@ -48,12 +48,13 @@ use tower::{BoxError, Layer, Service}; /// { /// type Rejection = StatusCode; /// -/// async fn from_request(req: &mut Request) -> Result { -/// if let Some(value) = req +/// async fn from_request(req: &mut RequestParts) -> Result { +/// let auth_header = req /// .headers() -/// .get(http::header::AUTHORIZATION) -/// .and_then(|value| value.to_str().ok()) -/// { +/// .and_then(|headers| headers.get(http::header::AUTHORIZATION)) +/// .and_then(|value| value.to_str().ok()); +/// +/// if let Some(value) = auth_header { /// if value == "secret" { /// return Ok(Self); /// } @@ -169,8 +170,9 @@ where self.inner.poll_ready(cx) } - fn call(&mut self, mut req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { let extract_future = Box::pin(async move { + let mut req = super::RequestParts::new(req); let extracted = E::from_request(&mut req).await; (req, extracted) }); @@ -201,7 +203,7 @@ where E: FromRequest, S: Service>, { - Extracting(BoxFuture<'static, (Request, Result)>), + Extracting(BoxFuture<'static, (RequestParts, Result)>), Call(#[pin] S::Future), } @@ -220,16 +222,16 @@ where let new_state = match this.state.as_mut().project() { StateProj::Extracting(future) => { - let (req, extracted) = ready!(future.as_mut().poll(cx)); + let (mut req, extracted) = ready!(future.as_mut().poll(cx)); match extracted { Ok(_) => { let mut svc = this.svc.take().expect("future polled after completion"); - let future = svc.call(req); + let future = svc.call(req.into_request()); State::Call(future) } Err(err) => { - let res = err.into_response().map(BoxBody::new); + let res = err.into_response().map(crate::body::box_body); return Poll::Ready(Ok(res)); } } @@ -237,7 +239,7 @@ where StateProj::Call(future) => { return future .poll(cx) - .map(|result| result.map(|response| response.map(BoxBody::new))); + .map(|result| result.map(|response| response.map(crate::body::box_body))); } }; diff --git a/src/extract/mod.rs b/src/extract/mod.rs index 1cbc8989..8f25d812 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -34,7 +34,7 @@ //! You can also define your own extractors by implementing [`FromRequest`]: //! //! ```rust,no_run -//! use axum::{async_trait, extract::FromRequest, prelude::*}; +//! use axum::{async_trait, extract::{FromRequest, RequestParts}, prelude::*}; //! use http::{StatusCode, header::{HeaderValue, USER_AGENT}}; //! //! struct ExtractUserAgent(HeaderValue); @@ -46,8 +46,10 @@ //! { //! type Rejection = (StatusCode, &'static str); //! -//! async fn from_request(req: &mut Request) -> Result { -//! if let Some(user_agent) = req.headers().get(USER_AGENT) { +//! async fn from_request(req: &mut RequestParts) -> Result { +//! let user_agent = req.headers().and_then(|headers| headers.get(USER_AGENT)); +//! +//! if let Some(user_agent) = user_agent { //! Ok(ExtractUserAgent(user_agent.clone())) //! } else { //! Err((StatusCode::BAD_REQUEST, "`User-Agent` header is missing")) @@ -175,13 +177,12 @@ use crate::{response::IntoResponse, util::ByteStr}; use async_trait::async_trait; use bytes::{Buf, Bytes}; use futures_util::stream::Stream; -use http::{header, HeaderMap, Method, Request, Uri, Version}; +use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version}; use rejection::*; use serde::de::DeserializeOwned; use std::{ collections::HashMap, convert::Infallible, - mem, pin::Pin, str::FromStr, task::{Context, Poll}, @@ -212,7 +213,195 @@ pub trait FromRequest: Sized { type Rejection: IntoResponse; /// Perform the extraction. - async fn from_request(req: &mut Request) -> Result; + async fn from_request(req: &mut RequestParts) -> Result; +} + +/// The type used with [`FromRequest`] to extract data from requests. +/// +/// Has several convenience methods for getting owned parts of the request. +#[derive(Debug)] +pub struct RequestParts { + method: Option, + uri: Option, + version: Option, + headers: Option, + extensions: Option, + body: Option, +} + +impl RequestParts { + pub(crate) fn new(req: Request) -> Self { + let ( + http::request::Parts { + method, + uri, + version, + headers, + extensions, + .. + }, + body, + ) = req.into_parts(); + + RequestParts { + method: Some(method), + uri: Some(uri), + version: Some(version), + headers: Some(headers), + extensions: Some(extensions), + body: Some(body), + } + } + + #[allow(clippy::wrong_self_convention)] + pub(crate) fn into_request(&mut self) -> Request { + let Self { + method, + uri, + version, + headers, + extensions, + body, + } = self; + + let mut req = Request::new(body.take().expect("body already extracted")); + + if let Some(method) = method.take() { + *req.method_mut() = method; + } + + if let Some(uri) = uri.take() { + *req.uri_mut() = uri; + } + + if let Some(version) = version.take() { + *req.version_mut() = version; + } + + if let Some(headers) = headers.take() { + *req.headers_mut() = headers; + } + + if let Some(extensions) = extensions.take() { + *req.extensions_mut() = extensions; + } + + req + } + + /// Gets a reference to the request method. + /// + /// Returns `None` if the method has been taken by another extractor. + pub fn method(&self) -> Option<&Method> { + self.method.as_ref() + } + + /// Gets a mutable reference to the request method. + /// + /// Returns `None` if the method has been taken by another extractor. + pub fn method_mut(&mut self) -> Option<&mut Method> { + self.method.as_mut() + } + + /// Takes the method out of the request, leaving a `None` in its place. + pub fn take_method(&mut self) -> Option { + self.method.take() + } + + /// Gets a reference to the request URI. + /// + /// Returns `None` if the URI has been taken by another extractor. + pub fn uri(&self) -> Option<&Uri> { + self.uri.as_ref() + } + + /// Gets a mutable reference to the request URI. + /// + /// Returns `None` if the URI has been taken by another extractor. + pub fn uri_mut(&mut self) -> Option<&mut Uri> { + self.uri.as_mut() + } + + /// Takes the URI out of the request, leaving a `None` in its place. + pub fn take_uri(&mut self) -> Option { + self.uri.take() + } + + /// Gets a reference to the request HTTP version. + /// + /// Returns `None` if the HTTP version has been taken by another extractor. + pub fn version(&self) -> Option { + self.version + } + + /// Gets a mutable reference to the request HTTP version. + /// + /// Returns `None` if the HTTP version has been taken by another extractor. + pub fn version_mut(&mut self) -> Option<&mut Version> { + self.version.as_mut() + } + + /// Takes the HTTP version out of the request, leaving a `None` in its place. + pub fn take_version(&mut self) -> Option { + self.version.take() + } + + /// Gets a reference to the request headers. + /// + /// Returns `None` if the headers has been taken by another extractor. + pub fn headers(&self) -> Option<&HeaderMap> { + self.headers.as_ref() + } + + /// Gets a mutable reference to the request headers. + /// + /// Returns `None` if the headers has been taken by another extractor. + pub fn headers_mut(&mut self) -> Option<&mut HeaderMap> { + self.headers.as_mut() + } + + /// Takes the headers out of the request, leaving a `None` in its place. + pub fn take_headers(&mut self) -> Option { + self.headers.take() + } + + /// Gets a reference to the request extensions. + /// + /// Returns `None` if the extensions has been taken by another extractor. + pub fn extensions(&self) -> Option<&Extensions> { + self.extensions.as_ref() + } + + /// Gets a mutable reference to the request extensions. + /// + /// Returns `None` if the extensions has been taken by another extractor. + pub fn extensions_mut(&mut self) -> Option<&mut Extensions> { + self.extensions.as_mut() + } + + /// Takes the extensions out of the request, leaving a `None` in its place. + pub fn take_extensions(&mut self) -> Option { + self.extensions.take() + } + + /// Gets a reference to the request body. + /// + /// Returns `None` if the body has been taken by another extractor. + pub fn body(&self) -> Option<&B> { + self.body.as_ref() + } + + /// Gets a mutable reference to the request body. + /// + /// Returns `None` if the body has been taken by another extractor. + pub fn body_mut(&mut self) -> Option<&mut B> { + self.body.as_mut() + } + + /// Takes the body out of the request, leaving a `None` in its place. + pub fn take_body(&mut self) -> Option { + self.body.take() + } } #[async_trait] @@ -223,7 +412,7 @@ where { type Rejection = Infallible; - async fn from_request(req: &mut Request) -> Result, Self::Rejection> { + async fn from_request(req: &mut RequestParts) -> Result, Self::Rejection> { Ok(T::from_request(req).await.ok()) } } @@ -236,7 +425,7 @@ where { type Rejection = Infallible; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { Ok(T::from_request(req).await) } } @@ -284,8 +473,12 @@ where { type Rejection = QueryRejection; - async fn from_request(req: &mut Request) -> Result { - let query = req.uri().query().ok_or(QueryStringMissing)?; + async fn from_request(req: &mut RequestParts) -> Result { + let query = req + .uri() + .ok_or(UriAlreadyExtracted)? + .query() + .ok_or(QueryStringMissing)?; let value = serde_urlencoded::from_str(query) .map_err(FailedToDeserializeQueryString::new::)?; Ok(Query(value)) @@ -329,20 +522,24 @@ pub struct Form(pub T); impl FromRequest for Form where T: DeserializeOwned, - B: http_body::Body + Default + Send, + B: http_body::Body + Send, B::Data: Send, B::Error: Into, { type Rejection = FormRejection; #[allow(warnings)] - async fn from_request(req: &mut Request) -> Result { - if !has_content_type(&req, "application/x-www-form-urlencoded") { + async fn from_request(req: &mut RequestParts) -> Result { + if !has_content_type(&req, "application/x-www-form-urlencoded")? { Err(InvalidFormContentType)?; } - if req.method() == Method::GET { - let query = req.uri().query().ok_or(QueryStringMissing)?; + if req.method().ok_or(MethodAlreadyExtracted)? == Method::GET { + let query = req + .uri() + .ok_or(UriAlreadyExtracted)? + .query() + .ok_or(QueryStringMissing)?; let value = serde_urlencoded::from_str(query) .map_err(FailedToDeserializeQueryString::new::)?; Ok(Form(value)) @@ -398,16 +595,16 @@ pub struct Json(pub T); impl FromRequest for Json where T: DeserializeOwned, - B: http_body::Body + Default + Send, + B: http_body::Body + Send, B::Data: Send, B::Error: Into, { type Rejection = JsonRejection; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { use bytes::Buf; - if has_content_type(req, "application/json") { + if has_content_type(req, "application/json")? { let body = take_body(req)?; let buf = hyper::body::aggregate(body) @@ -423,20 +620,27 @@ where } } -fn has_content_type(req: &Request, expected_content_type: &str) -> bool { - let content_type = if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) { +fn has_content_type( + req: &RequestParts, + expected_content_type: &str, +) -> Result { + let content_type = if let Some(content_type) = req + .headers() + .ok_or(HeadersAlreadyExtracted)? + .get(header::CONTENT_TYPE) + { content_type } else { - return false; + return Ok(false); }; let content_type = if let Ok(content_type) = content_type.to_str() { content_type } else { - return false; + return Ok(false); }; - content_type.starts_with(expected_content_type) + Ok(content_type.starts_with(expected_content_type)) } /// Extractor that gets a value from request extensions. @@ -480,11 +684,12 @@ where T: Clone + Send + Sync + 'static, B: Send, { - type Rejection = MissingExtension; + type Rejection = ExtensionRejection; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let value = req .extensions() + .ok_or(ExtensionsAlreadyExtracted)? .get::() .ok_or(MissingExtension) .map(|x| x.clone())?; @@ -496,13 +701,13 @@ where #[async_trait] impl FromRequest for Bytes where - B: http_body::Body + Default + Send, + B: http_body::Body + Send, B::Data: Send, B::Error: Into, { type Rejection = BytesRejection; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let body = take_body(req)?; let bytes = hyper::body::to_bytes(body) @@ -516,13 +721,13 @@ where #[async_trait] impl FromRequest for String where - B: http_body::Body + Default + Send, + B: http_body::Body + Send, B::Data: Send, B::Error: Into, { type Rejection = StringRejection; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let body = take_body(req)?; let bytes = hyper::body::to_bytes(body) @@ -572,11 +777,11 @@ where #[async_trait] impl FromRequest for BodyStream where - B: http_body::Body + Default + Unpin + Send, + B: http_body::Body + Unpin + Send, { type Rejection = BodyAlreadyExtracted; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let body = take_body(req)?; let stream = BodyStream(body); Ok(stream) @@ -586,21 +791,22 @@ where #[async_trait] impl FromRequest for Request where - B: Default + Send, + B: Send, { type Rejection = RequestAlreadyExtracted; - async fn from_request(req: &mut Request) -> Result { - struct RequestAlreadyExtractedExt; + async fn from_request(req: &mut RequestParts) -> Result { + let all_parts = req + .method() + .zip(req.uri()) + .zip(req.headers()) + .zip(req.extensions()) + .zip(req.body()); - if req - .extensions_mut() - .insert(RequestAlreadyExtractedExt) - .is_some() - { - Err(RequestAlreadyExtracted) + if all_parts.is_some() { + Ok(req.into_request()) } else { - Ok(mem::take(req)) + Err(RequestAlreadyExtracted) } } } @@ -610,10 +816,10 @@ impl FromRequest for Method where B: Send, { - type Rejection = Infallible; + type Rejection = MethodAlreadyExtracted; - async fn from_request(req: &mut Request) -> Result { - Ok(req.method().clone()) + async fn from_request(req: &mut RequestParts) -> Result { + req.take_method().ok_or(MethodAlreadyExtracted) } } @@ -622,10 +828,10 @@ impl FromRequest for Uri where B: Send, { - type Rejection = Infallible; + type Rejection = UriAlreadyExtracted; - async fn from_request(req: &mut Request) -> Result { - Ok(req.uri().clone()) + async fn from_request(req: &mut RequestParts) -> Result { + req.take_uri().ok_or(UriAlreadyExtracted) } } @@ -634,10 +840,10 @@ impl FromRequest for Version where B: Send, { - type Rejection = Infallible; + type Rejection = VersionAlreadyExtracted; - async fn from_request(req: &mut Request) -> Result { - Ok(req.version()) + async fn from_request(req: &mut RequestParts) -> Result { + req.take_version().ok_or(VersionAlreadyExtracted) } } @@ -646,10 +852,10 @@ impl FromRequest for HeaderMap where B: Send, { - type Rejection = Infallible; + type Rejection = HeadersAlreadyExtracted; - async fn from_request(req: &mut Request) -> Result { - Ok(mem::take(req.headers_mut())) + async fn from_request(req: &mut RequestParts) -> Result { + req.take_headers().ok_or(HeadersAlreadyExtracted) } } @@ -682,8 +888,13 @@ where { type Rejection = ContentLengthLimitRejection; - async fn from_request(req: &mut Request) -> Result { - let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned(); + async fn from_request(req: &mut RequestParts) -> Result { + let content_length = req + .headers() + .ok_or(ContentLengthLimitRejection::HeadersAlreadyExtracted( + HeadersAlreadyExtracted, + ))? + .get(http::header::CONTENT_LENGTH); let content_length = content_length.and_then(|value| value.to_str().ok()?.parse::().ok()); @@ -752,10 +963,10 @@ where { type Rejection = MissingRouteParams; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if let Some(params) = req .extensions_mut() - .get_mut::>() + .and_then(|ext| ext.get_mut::>()) { if let Some(params) = params { Ok(Self(params.0.iter().cloned().collect())) @@ -810,10 +1021,12 @@ macro_rules! impl_parse_url { type Rejection = UrlParamsRejection; #[allow(non_snake_case)] - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let params = if let Some(params) = req .extensions_mut() - .get_mut::>() + .and_then(|ext| { + ext.get_mut::>() + }) { if let Some(params) = params { params.0.clone() @@ -852,23 +1065,8 @@ macro_rules! impl_parse_url { impl_parse_url!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); -/// Request extension used to indicate that body has been extracted and `Default` has been left in -/// its place. -struct BodyAlreadyExtractedExt; - -fn take_body(req: &mut Request) -> Result -where - B: Default, -{ - if req - .extensions_mut() - .insert(BodyAlreadyExtractedExt) - .is_some() - { - Err(BodyAlreadyExtracted) - } else { - Ok(mem::take(req.body_mut())) - } +fn take_body(req: &mut RequestParts) -> Result { + req.take_body().ok_or(BodyAlreadyExtracted) } /// Extractor that extracts a typed header value from [`headers`]. @@ -903,10 +1101,16 @@ where T: headers::Header, B: Send, { - type Rejection = rejection::TypedHeaderRejection; + type Rejection = TypedHeaderRejection; + + async fn from_request(req: &mut RequestParts) -> Result { + let empty_headers = HeaderMap::new(); + let header_values = if let Some(headers) = req.headers() { + headers.get_all(T::name()) + } else { + empty_headers.get_all(T::name()) + }; - async fn from_request(req: &mut Request) -> Result { - let header_values = req.headers().get_all(T::name()); T::decode(&mut header_values.iter()) .map(Self) .map_err(|err| rejection::TypedHeaderRejection { diff --git a/src/extract/multipart.rs b/src/extract/multipart.rs index b0c39bc7..acae08b8 100644 --- a/src/extract/multipart.rs +++ b/src/extract/multipart.rs @@ -2,7 +2,7 @@ //! //! See [`Multipart`] for more details. -use super::{rejection::*, BodyStream, FromRequest}; +use super::{rejection::*, BodyStream, FromRequest, RequestParts}; use async_trait::async_trait; use bytes::Bytes; use futures_util::stream::Stream; @@ -53,9 +53,10 @@ where { type Rejection = MultipartRejection; - async fn from_request(req: &mut http::Request) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { let stream = BodyStream::from_request(req).await?; - let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?; + let headers = req.headers().ok_or(HeadersAlreadyExtracted)?; + let boundary = parse_boundary(headers).ok_or(InvalidBoundary)?; let multipart = multer::Multipart::new(stream, boundary); Ok(Self { inner: multipart }) } @@ -175,6 +176,7 @@ composite_rejection! { pub enum MultipartRejection { BodyAlreadyExtracted, InvalidBoundary, + HeadersAlreadyExtracted, } } diff --git a/src/extract/rejection.rs b/src/extract/rejection.rs index 6107f2c8..2668616c 100644 --- a/src/extract/rejection.rs +++ b/src/extract/rejection.rs @@ -4,6 +4,41 @@ use super::IntoResponse; use crate::body::Body; use tower::BoxError; +define_rejection! { + #[status = INTERNAL_SERVER_ERROR] + #[body = "Version taken by other extractor"] + /// Rejection used if the HTTP version has been taken by another extractor. + pub struct VersionAlreadyExtracted; +} + +define_rejection! { + #[status = INTERNAL_SERVER_ERROR] + #[body = "URI taken by other extractor"] + /// Rejection used if the URI has been taken by another extractor. + pub struct UriAlreadyExtracted; +} + +define_rejection! { + #[status = INTERNAL_SERVER_ERROR] + #[body = "Method taken by other extractor"] + /// Rejection used if the method has been taken by another extractor. + pub struct MethodAlreadyExtracted; +} + +define_rejection! { + #[status = INTERNAL_SERVER_ERROR] + #[body = "Extensions taken by other extractor"] + /// Rejection used if the method has been taken by another extractor. + pub struct ExtensionsAlreadyExtracted; +} + +define_rejection! { + #[status = INTERNAL_SERVER_ERROR] + #[body = "Headers taken by other extractor"] + /// Rejection used if the URI has been taken by another extractor. + pub struct HeadersAlreadyExtracted; +} + define_rejection! { #[status = BAD_REQUEST] #[body = "Query string was invalid or missing"] @@ -160,6 +195,7 @@ composite_rejection! { /// Contains one variant for each way the [`Query`](super::Query) extractor /// can fail. pub enum QueryRejection { + UriAlreadyExtracted, QueryStringMissing, FailedToDeserializeQueryString, } @@ -176,6 +212,9 @@ composite_rejection! { FailedToDeserializeQueryString, FailedToBufferBody, BodyAlreadyExtracted, + UriAlreadyExtracted, + HeadersAlreadyExtracted, + MethodAlreadyExtracted, } } @@ -188,6 +227,18 @@ composite_rejection! { InvalidJsonBody, MissingJsonContentType, BodyAlreadyExtracted, + HeadersAlreadyExtracted, + } +} + +composite_rejection! { + /// Rejection used for [`Extension`](super::Extension). + /// + /// Contains one variant for each way the [`Extension`](super::Extension) extractor + /// can fail. + pub enum ExtensionRejection { + MissingExtension, + ExtensionsAlreadyExtracted, } } @@ -236,6 +287,8 @@ pub enum ContentLengthLimitRejection { #[allow(missing_docs)] LengthRequired(LengthRequired), #[allow(missing_docs)] + HeadersAlreadyExtracted(HeadersAlreadyExtracted), + #[allow(missing_docs)] Inner(T), } @@ -247,6 +300,7 @@ where match self { Self::PayloadTooLarge(inner) => inner.into_response(), Self::LengthRequired(inner) => inner.into_response(), + Self::HeadersAlreadyExtracted(inner) => inner.into_response(), Self::Inner(inner) => inner.into_response(), } } diff --git a/src/handler/mod.rs b/src/handler/mod.rs index 8c3f84aa..d97afaf6 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -39,7 +39,7 @@ //! the [`extract`](crate::extract) module. use crate::{ - body::BoxBody, + body::{box_body, BoxBody}, extract::FromRequest, response::IntoResponse, routing::{EmptyRouter, MethodFilter, RouteFuture}, @@ -289,7 +289,7 @@ where type Sealed = sealed::Hidden; async fn call(self, _req: Request) -> Response { - self().await.into_response().map(BoxBody::new) + self().await.into_response().map(box_body) } } @@ -310,22 +310,24 @@ macro_rules! impl_handler { { type Sealed = sealed::Hidden; - async fn call(self, mut req: Request) -> Response { + async fn call(self, req: Request) -> Response { + let mut req = crate::extract::RequestParts::new(req); + let $head = match $head::from_request(&mut req).await { Ok(value) => value, - Err(rejection) => return rejection.into_response().map(BoxBody::new), + Err(rejection) => return rejection.into_response().map(crate::body::box_body), }; $( let $tail = match $tail::from_request(&mut req).await { Ok(value) => value, - Err(rejection) => return rejection.into_response().map(BoxBody::new), + Err(rejection) => return rejection.into_response().map(crate::body::box_body), }; )* let res = self($head, $($tail,)*).await; - res.into_response().map(BoxBody::new) + res.into_response().map(crate::body::box_body) } } @@ -380,8 +382,8 @@ where .await .map_err(IntoResponse::into_response) { - Ok(res) => res.map(BoxBody::new), - Err(res) => res.map(BoxBody::new), + Ok(res) => res.map(box_body), + Err(res) => res.map(box_body), } } } diff --git a/src/routing.rs b/src/routing.rs index eb589f61..8d0e9e98 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -1,6 +1,11 @@ //! Routing between [`Service`]s. -use crate::{body::BoxBody, buffer::MpscBuffer, response::IntoResponse, util::ByteStr}; +use crate::{ + body::{box_body, BoxBody}, + buffer::MpscBuffer, + response::IntoResponse, + util::ByteStr, +}; use async_trait::async_trait; use bytes::Bytes; use futures_util::future; @@ -165,7 +170,7 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { .layer_fn(BoxRoute) .layer_fn(MpscBuffer::new) .layer(BoxService::layer()) - .layer(MapResponseBodyLayer::new(BoxBody::new)) + .layer(MapResponseBodyLayer::new(box_body)) .service(self) } @@ -399,7 +404,7 @@ impl Service> for EmptyRouter { } fn call(&mut self, _req: Request) -> Self::Future { - let mut res = Response::new(BoxBody::empty()); + let mut res = Response::new(crate::body::empty()); *res.status_mut() = StatusCode::NOT_FOUND; EmptyRouterFuture(future::ok(res)) } diff --git a/src/service/future.rs b/src/service/future.rs index 2190666e..8bc4bbb4 100644 --- a/src/service/future.rs +++ b/src/service/future.rs @@ -1,6 +1,9 @@ //! [`Service`](tower::Service) future types. -use crate::{body::BoxBody, response::IntoResponse}; +use crate::{ + body::{box_body, BoxBody}, + response::IntoResponse, +}; use bytes::Bytes; use futures_util::ready; use http::Response; @@ -36,11 +39,11 @@ where let this = self.project(); match ready!(this.inner.poll(cx)) { - Ok(res) => Ok(res.map(BoxBody::new)).into(), + Ok(res) => Ok(res.map(box_body)).into(), Err(err) => { let f = this.f.take().unwrap(); let res = f(err).into_response(); - Ok(res.map(BoxBody::new)).into() + Ok(res.map(box_body)).into() } } } diff --git a/src/service/mod.rs b/src/service/mod.rs index 7d1f0bf7..ae1587b3 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -87,7 +87,7 @@ //! [load shed]: tower::load_shed use crate::{ - body::BoxBody, + body::{box_body, BoxBody}, response::IntoResponse, routing::{EmptyRouter, MethodFilter, RouteFuture}, }; @@ -656,7 +656,7 @@ where fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let res = ready!(self.project().0.poll(cx))?; - let res = res.map(BoxBody::new); + let res = res.map(box_body); Poll::Ready(Ok(res)) } } diff --git a/src/tests.rs b/src/tests.rs index a9ba386c..7b5ad308 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,4 +1,7 @@ -use crate::{handler::on, prelude::*, response::IntoResponse, routing::MethodFilter, service}; +use crate::{ + extract::RequestParts, handler::on, prelude::*, response::IntoResponse, routing::MethodFilter, + service, +}; use bytes::Bytes; use http::{header::AUTHORIZATION, Request, Response, StatusCode}; use hyper::{Body, Server}; @@ -105,7 +108,7 @@ async fn consume_body_to_json_requires_json_content_type() { let app = route( "/", - post(|_: Request, input: extract::Json| async { input.0.foo }), + post(|input: extract::Json| async { input.0.foo }), ); let addr = run_in_background(app).await; @@ -675,9 +678,10 @@ async fn test_extractor_middleware() { { type Rejection = StatusCode; - async fn from_request(req: &mut Request) -> Result { + async fn from_request(req: &mut RequestParts) -> Result { if let Some(auth) = req .headers() + .expect("headers already extracted") .get("authorization") .and_then(|v| v.to_str().ok()) {