From fdd889525d1f9ddaa42413130a9481c10fb8d8b2 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 9 Nov 2021 21:37:24 +0100 Subject: [PATCH] Implement `FromRequest` for `http::request::Parts` (#489) Its a convenient way to extract everything from a request except the body. Also makes sense for axum to provide this since other crates can't. --- axum/CHANGELOG.md | 4 ++ axum/src/extract/rejection.rs | 10 ++++ axum/src/extract/request_parts.rs | 79 ++++++++++++++++++++++++++++++- 3 files changed, 92 insertions(+), 1 deletion(-) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index cc9fdd2b..f3ea2008 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,9 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased +- Implement `FromRequest` for [`http::request::Parts`] so it can be used an + extractor ([#489]) - Implement `IntoResponse` for `http::response::Parts` ([#490]) +[#489]: https://github.com/tokio-rs/axum/pull/489 [#490]: https://github.com/tokio-rs/axum/pull/490 +[`http::request::Parts`]: https://docs.rs/http/latest/http/request/struct.Parts.html # 0.3.2 (08. November, 2021) diff --git a/axum/src/extract/rejection.rs b/axum/src/extract/rejection.rs index 92316389..657a2ca7 100644 --- a/axum/src/extract/rejection.rs +++ b/axum/src/extract/rejection.rs @@ -272,6 +272,16 @@ composite_rejection! { } } +composite_rejection! { + /// Rejection used for [`http::request::Parts`]. + /// + /// Contains one variant for each way the [`http::request::Parts`] extractor can fail. + pub enum RequestPartsAlreadyExtracted { + HeadersAlreadyExtracted, + ExtensionsAlreadyExtracted, + } +} + define_rejection! { #[status = INTERNAL_SERVER_ERROR] #[body = "No matched path found"] diff --git a/axum/src/extract/request_parts.rs b/axum/src/extract/request_parts.rs index 0a84dee2..560dc0b7 100644 --- a/axum/src/extract/request_parts.rs +++ b/axum/src/extract/request_parts.rs @@ -323,10 +323,49 @@ where } } +#[async_trait] +impl FromRequest for http::request::Parts +where + B: Send, +{ + type Rejection = RequestPartsAlreadyExtracted; + + async fn from_request(req: &mut RequestParts) -> Result { + let method = unwrap_infallible(Method::from_request(req).await); + let uri = unwrap_infallible(Uri::from_request(req).await); + let version = unwrap_infallible(Version::from_request(req).await); + let headers = HeaderMap::from_request(req).await?; + let extensions = Extensions::from_request(req).await?; + + let mut temp_request = Request::new(()); + *temp_request.method_mut() = method; + *temp_request.uri_mut() = uri; + *temp_request.version_mut() = version; + *temp_request.headers_mut() = headers; + *temp_request.extensions_mut() = extensions; + + let (parts, _) = temp_request.into_parts(); + + Ok(parts) + } +} + +fn unwrap_infallible(result: Result) -> T { + match result { + Ok(value) => value, + Err(err) => match err {}, + } +} + #[cfg(test)] mod tests { use super::*; - use crate::{body::Body, routing::post, test_helpers::*, Router}; + use crate::{ + body::Body, + routing::{get, post}, + test_helpers::*, + AddExtensionLayer, Router, + }; use http::StatusCode; #[tokio::test] @@ -344,4 +383,42 @@ mod tests { "Cannot have two request body extractors for a single handler" ); } + + #[tokio::test] + async fn extract_request_parts() { + #[derive(Clone)] + struct Ext; + + async fn handler(parts: http::request::Parts) { + assert_eq!(parts.method, Method::GET); + assert_eq!(parts.uri, "/"); + assert_eq!(parts.version, http::Version::HTTP_11); + assert_eq!(parts.headers["x-foo"], "123"); + parts.extensions.get::().unwrap(); + } + + let client = TestClient::new( + Router::new() + .route("/", get(handler)) + .layer(AddExtensionLayer::new(Ext)), + ); + + let res = client.get("/").header("x-foo", "123").send().await; + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn extract_request_parts_doesnt_consume_the_body() { + #[derive(Clone)] + struct Ext; + + async fn handler(_parts: http::request::Parts, body: String) { + assert_eq!(body, "foo"); + } + + let client = TestClient::new(Router::new().route("/", get(handler))); + + let res = client.get("/").body("foo").send().await; + assert_eq!(res.status(), StatusCode::OK); + } }