mirror of
https://github.com/tokio-rs/axum.git
synced 2025-10-02 15:24:54 +00:00
Clean up RequestParts
API (#167)
In http-body 0.4.3 `BoxBody` implements `Default`. This allows us to clean up the API of `RequestParts` quite a bit.
This commit is contained in:
parent
bc27b09f5c
commit
6b218c7150
@ -30,6 +30,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
- Removed `extract::UrlParams` and `extract::UrlParamsMap`. Use `extract::Path` instead
|
- Removed `extract::UrlParams` and `extract::UrlParamsMap`. Use `extract::Path` instead
|
||||||
- `EmptyRouter` now requires the response body to implement `Send + Sync + 'static'` ([#108](https://github.com/tokio-rs/axum/pull/108))
|
- `EmptyRouter` now requires the response body to implement `Send + Sync + 'static'` ([#108](https://github.com/tokio-rs/axum/pull/108))
|
||||||
- `ServiceExt` has been removed and its methods have been moved to `RoutingDsl` ([#160](https://github.com/tokio-rs/axum/pull/160))
|
- `ServiceExt` has been removed and its methods have been moved to `RoutingDsl` ([#160](https://github.com/tokio-rs/axum/pull/160))
|
||||||
|
- `extractor_middleware` now requires `RequestBody: Default` ([#167](https://github.com/tokio-rs/axum/pull/167))
|
||||||
|
- Convert `RequestAlreadyExtracted` to an enum with each possible error variant ([#167](https://github.com/tokio-rs/axum/pull/167))
|
||||||
- These future types have been moved
|
- These future types have been moved
|
||||||
- `extract::extractor_middleware::ExtractorMiddlewareResponseFuture` moved
|
- `extract::extractor_middleware::ExtractorMiddlewareResponseFuture` moved
|
||||||
to `extract::extractor_middleware::future::ResponseFuture` ([#133](https://github.com/tokio-rs/axum/pull/133))
|
to `extract::extractor_middleware::future::ResponseFuture` ([#133](https://github.com/tokio-rs/axum/pull/133))
|
||||||
|
@ -22,7 +22,7 @@ bitflags = "1.0"
|
|||||||
bytes = "1.0"
|
bytes = "1.0"
|
||||||
futures-util = "0.3"
|
futures-util = "0.3"
|
||||||
http = "0.2"
|
http = "0.2"
|
||||||
http-body = "0.4.2"
|
http-body = "0.4.3"
|
||||||
hyper = { version = "0.14", features = ["server", "tcp", "http1", "stream"] }
|
hyper = { version = "0.14", features = ["server", "tcp", "http1", "stream"] }
|
||||||
pin-project-lite = "0.2.7"
|
pin-project-lite = "0.2.7"
|
||||||
regex = "1.5"
|
regex = "1.5"
|
||||||
|
10
src/error.rs
10
src/error.rs
@ -13,6 +13,16 @@ impl Error {
|
|||||||
inner: error.into(),
|
inner: error.into(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn downcast<T>(self) -> Result<T, Self>
|
||||||
|
where
|
||||||
|
T: StdError + 'static,
|
||||||
|
{
|
||||||
|
match self.inner.downcast::<T>() {
|
||||||
|
Ok(t) => Ok(*t),
|
||||||
|
Err(err) => Err(*err.downcast().unwrap()),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for Error {
|
impl fmt::Display for Error {
|
||||||
|
@ -152,7 +152,7 @@ where
|
|||||||
impl<S, E, ReqBody, ResBody> Service<Request<ReqBody>> for ExtractorMiddleware<S, E>
|
impl<S, E, ReqBody, ResBody> Service<Request<ReqBody>> for ExtractorMiddleware<S, E>
|
||||||
where
|
where
|
||||||
E: FromRequest<ReqBody> + 'static,
|
E: FromRequest<ReqBody> + 'static,
|
||||||
ReqBody: Send + 'static,
|
ReqBody: Default + Send + 'static,
|
||||||
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
|
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone,
|
||||||
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
|
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
|
||||||
ResBody::Error: Into<BoxError>,
|
ResBody::Error: Into<BoxError>,
|
||||||
@ -212,6 +212,7 @@ impl<ReqBody, S, E, ResBody> Future for ResponseFuture<ReqBody, S, E>
|
|||||||
where
|
where
|
||||||
E: FromRequest<ReqBody>,
|
E: FromRequest<ReqBody>,
|
||||||
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
|
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
|
||||||
|
ReqBody: Default,
|
||||||
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
|
ResBody: http_body::Body<Data = Bytes> + Send + Sync + 'static,
|
||||||
ResBody::Error: Into<BoxError>,
|
ResBody::Error: Into<BoxError>,
|
||||||
{
|
{
|
||||||
@ -223,12 +224,13 @@ where
|
|||||||
|
|
||||||
let new_state = match this.state.as_mut().project() {
|
let new_state = match this.state.as_mut().project() {
|
||||||
StateProj::Extracting { future } => {
|
StateProj::Extracting { future } => {
|
||||||
let (mut req, extracted) = ready!(future.as_mut().poll(cx));
|
let (req, extracted) = ready!(future.as_mut().poll(cx));
|
||||||
|
|
||||||
match extracted {
|
match extracted {
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
let mut svc = this.svc.take().expect("future polled after completion");
|
let mut svc = this.svc.take().expect("future polled after completion");
|
||||||
let future = svc.call(req.into_request());
|
let req = req.try_into_request().unwrap_or_default();
|
||||||
|
let future = svc.call(req);
|
||||||
State::Call { future }
|
State::Call { future }
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
|
@ -244,7 +244,7 @@
|
|||||||
//!
|
//!
|
||||||
//! [`body::Body`]: crate::body::Body
|
//! [`body::Body`]: crate::body::Body
|
||||||
|
|
||||||
use crate::response::IntoResponse;
|
use crate::{response::IntoResponse, Error};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version};
|
use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version};
|
||||||
use rejection::*;
|
use rejection::*;
|
||||||
@ -397,32 +397,47 @@ impl<B> RequestParts<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::wrong_self_convention)]
|
// this method uses `Error` since we might make this method public one day and then
|
||||||
pub(crate) fn into_request(&mut self) -> Request<B> {
|
// `Error` is more flexible.
|
||||||
|
pub(crate) fn try_into_request(self) -> Result<Request<B>, Error> {
|
||||||
let Self {
|
let Self {
|
||||||
method,
|
method,
|
||||||
uri,
|
uri,
|
||||||
version,
|
version,
|
||||||
headers,
|
mut headers,
|
||||||
extensions,
|
mut extensions,
|
||||||
body,
|
mut body,
|
||||||
} = self;
|
} = self;
|
||||||
|
|
||||||
let mut req = Request::new(body.take().expect("body already extracted"));
|
let mut req = if let Some(body) = body.take() {
|
||||||
|
Request::new(body)
|
||||||
|
} else {
|
||||||
|
return Err(Error::new(RequestAlreadyExtracted::BodyAlreadyExtracted(
|
||||||
|
BodyAlreadyExtracted,
|
||||||
|
)));
|
||||||
|
};
|
||||||
|
|
||||||
*req.method_mut() = method.clone();
|
*req.method_mut() = method;
|
||||||
*req.uri_mut() = uri.clone();
|
*req.uri_mut() = uri;
|
||||||
*req.version_mut() = *version;
|
*req.version_mut() = version;
|
||||||
|
|
||||||
if let Some(headers) = headers.take() {
|
if let Some(headers) = headers.take() {
|
||||||
*req.headers_mut() = headers;
|
*req.headers_mut() = headers;
|
||||||
|
} else {
|
||||||
|
return Err(Error::new(
|
||||||
|
RequestAlreadyExtracted::HeadersAlreadyExtracted(HeadersAlreadyExtracted),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(extensions) = extensions.take() {
|
if let Some(extensions) = extensions.take() {
|
||||||
*req.extensions_mut() = extensions;
|
*req.extensions_mut() = extensions;
|
||||||
|
} else {
|
||||||
|
return Err(Error::new(
|
||||||
|
RequestAlreadyExtracted::ExtensionsAlreadyExtracted(ExtensionsAlreadyExtracted),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
req
|
Ok(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gets a reference the request method.
|
/// Gets a reference the request method.
|
||||||
|
@ -13,14 +13,15 @@ use tower::BoxError;
|
|||||||
define_rejection! {
|
define_rejection! {
|
||||||
#[status = INTERNAL_SERVER_ERROR]
|
#[status = INTERNAL_SERVER_ERROR]
|
||||||
#[body = "Extensions taken by other extractor"]
|
#[body = "Extensions taken by other extractor"]
|
||||||
/// Rejection used if the method has been taken by another extractor.
|
/// Rejection used if the request extension has been taken by another
|
||||||
|
/// extractor.
|
||||||
pub struct ExtensionsAlreadyExtracted;
|
pub struct ExtensionsAlreadyExtracted;
|
||||||
}
|
}
|
||||||
|
|
||||||
define_rejection! {
|
define_rejection! {
|
||||||
#[status = INTERNAL_SERVER_ERROR]
|
#[status = INTERNAL_SERVER_ERROR]
|
||||||
#[body = "Headers taken by other extractor"]
|
#[body = "Headers taken by other extractor"]
|
||||||
/// Rejection used if the URI has been taken by another extractor.
|
/// Rejection used if the headers has been taken by another extractor.
|
||||||
pub struct HeadersAlreadyExtracted;
|
pub struct HeadersAlreadyExtracted;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -94,13 +95,6 @@ define_rejection! {
|
|||||||
pub struct BodyAlreadyExtracted;
|
pub struct BodyAlreadyExtracted;
|
||||||
}
|
}
|
||||||
|
|
||||||
define_rejection! {
|
|
||||||
#[status = INTERNAL_SERVER_ERROR]
|
|
||||||
#[body = "Cannot have two `Request<_>` extractors for a single handler"]
|
|
||||||
/// Rejection type used if you try and extract the request more than once.
|
|
||||||
pub struct RequestAlreadyExtracted;
|
|
||||||
}
|
|
||||||
|
|
||||||
define_rejection! {
|
define_rejection! {
|
||||||
#[status = BAD_REQUEST]
|
#[status = BAD_REQUEST]
|
||||||
#[body = "Form requests must have `Content-Type: x-www-form-urlencoded`"]
|
#[body = "Form requests must have `Content-Type: x-www-form-urlencoded`"]
|
||||||
@ -272,6 +266,19 @@ composite_rejection! {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
composite_rejection! {
|
||||||
|
/// Rejection used for [`Request<_>`].
|
||||||
|
///
|
||||||
|
/// Contains one variant for each way the [`Request<_>`] extractor can fail.
|
||||||
|
///
|
||||||
|
/// [`Request<_>`]: http::Request
|
||||||
|
pub enum RequestAlreadyExtracted {
|
||||||
|
BodyAlreadyExtracted,
|
||||||
|
HeadersAlreadyExtracted,
|
||||||
|
ExtensionsAlreadyExtracted,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Rejection used for [`ContentLengthLimit`](super::ContentLengthLimit).
|
/// Rejection used for [`ContentLengthLimit`](super::ContentLengthLimit).
|
||||||
///
|
///
|
||||||
/// Contains one variant for each way the
|
/// Contains one variant for each way the
|
||||||
|
@ -18,21 +18,29 @@ where
|
|||||||
type Rejection = RequestAlreadyExtracted;
|
type Rejection = RequestAlreadyExtracted;
|
||||||
|
|
||||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||||
let RequestParts {
|
let req = std::mem::replace(
|
||||||
method: _,
|
req,
|
||||||
uri: _,
|
RequestParts {
|
||||||
version: _,
|
method: req.method.clone(),
|
||||||
headers,
|
version: req.version,
|
||||||
extensions,
|
uri: req.uri.clone(),
|
||||||
body,
|
headers: None,
|
||||||
} = req;
|
extensions: None,
|
||||||
|
body: None,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
let all_parts = extensions.as_ref().zip(body.as_ref()).zip(headers.as_ref());
|
let err = match req.try_into_request() {
|
||||||
|
Ok(req) => return Ok(req),
|
||||||
|
Err(err) => err,
|
||||||
|
};
|
||||||
|
|
||||||
if all_parts.is_some() {
|
match err.downcast::<RequestAlreadyExtracted>() {
|
||||||
Ok(req.into_request())
|
Ok(err) => return Err(err),
|
||||||
} else {
|
Err(err) => unreachable!(
|
||||||
Err(RequestAlreadyExtracted)
|
"Unexpected error type from `try_into_request`: `{:?}`. This is a bug in axum, please file an issue",
|
||||||
|
err,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -251,3 +259,33 @@ where
|
|||||||
Ok(string)
|
Ok(string)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::{body::Body, prelude::*, tests::*};
|
||||||
|
use http::StatusCode;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn multiple_request_extractors() {
|
||||||
|
async fn handler(_: Request<Body>, _: Request<Body>) {}
|
||||||
|
|
||||||
|
let app = route("/", post(handler));
|
||||||
|
|
||||||
|
let addr = run_in_background(app).await;
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
let res = client
|
||||||
|
.post(format!("http://{}", addr))
|
||||||
|
.body("hi there")
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
assert_eq!(
|
||||||
|
res.text().await.unwrap(),
|
||||||
|
"Cannot have two request body extractors for a single handler"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -605,7 +605,7 @@ async fn wrong_method_service() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Run a `tower::Service` in the background and get a URI for it.
|
/// Run a `tower::Service` in the background and get a URI for it.
|
||||||
async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
|
pub(crate) async fn run_in_background<S, ResBody>(svc: S) -> SocketAddr
|
||||||
where
|
where
|
||||||
S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
|
S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
|
||||||
ResBody: http_body::Body + Send + 'static,
|
ResBody: http_body::Body + Send + 'static,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user