diff --git a/src/extract/mod.rs b/src/extract/mod.rs index a828892b..4367568f 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -172,6 +172,77 @@ //! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); //! # }; //! ``` +//! +//! # Request body extractors +//! +//! Most of the time your request body type will be [`body::Body`] (a re-export +//! of [`hyper::Body`]), which is directly supported by all extractors. +//! +//! However if you're applying a tower middleware that changes the response you +//! might have to apply a different body type to some extractors: +//! +//! ```rust +//! use std::{ +//! task::{Context, Poll}, +//! pin::Pin, +//! }; +//! use tower_http::map_request_body::MapRequestBodyLayer; +//! use axum::prelude::*; +//! +//! struct MyBody(B); +//! +//! impl http_body::Body for MyBody +//! where +//! B: http_body::Body + Unpin, +//! { +//! type Data = B::Data; +//! type Error = B::Error; +//! +//! fn poll_data( +//! mut self: Pin<&mut Self>, +//! cx: &mut Context<'_>, +//! ) -> Poll>> { +//! Pin::new(&mut self.0).poll_data(cx) +//! } +//! +//! fn poll_trailers( +//! mut self: Pin<&mut Self>, +//! cx: &mut Context<'_>, +//! ) -> Poll, Self::Error>> { +//! Pin::new(&mut self.0).poll_trailers(cx) +//! } +//! } +//! +//! let app = +//! // `String` works directly with any body type +//! route( +//! "/string", +//! get(|_: String| async {}) +//! ) +//! .route( +//! "/body", +//! // `extract::Body` defaults to `axum::body::Body` +//! // but can be customized +//! get(|_: extract::Body>| async {}) +//! ) +//! .route( +//! "/body-stream", +//! // same for `extract::BodyStream` +//! get(|_: extract::BodyStream>| async {}), +//! ) +//! .route( +//! // and `Request<_>` +//! "/request", +//! get(|_: Request>| async {}) +//! ) +//! // middleware that changes the request body type +//! .layer(MapRequestBodyLayer::new(MyBody)); +//! # async { +//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +//! # }; +//! ``` +//! +//! [`body::Body`]: crate::body::Body use crate::{response::IntoResponse, util::ByteStr}; use async_trait::async_trait; @@ -788,6 +859,39 @@ where } } +/// Extractor that extracts the request body. +/// +/// # Example +/// +/// ```rust,no_run +/// use axum::prelude::*; +/// use futures::StreamExt; +/// +/// async fn handler(extract::Body(body): extract::Body) { +/// // ... +/// } +/// +/// let app = route("/users", get(handler)); +/// # async { +/// # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +#[derive(Debug, Default, Clone)] +pub struct Body(pub B); + +#[async_trait] +impl FromRequest for Body +where + B: Send, +{ + type Rejection = BodyAlreadyExtracted; + + async fn from_request(req: &mut RequestParts) -> Result { + let body = take_body(req)?; + Ok(Self(body)) + } +} + #[async_trait] impl FromRequest for Request where diff --git a/src/lib.rs b/src/lib.rs index 06f83a60..6e8be474 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -487,7 +487,7 @@ //! "/", //! service::any(service_fn(|_: Request| async { //! let res = Response::new(Body::from("Hi from `GET /`")); -//! Ok::<_, Infallible>(res) +//! Ok(res) //! })) //! ).route( //! // GET `/static/Cargo.toml` goes to a service from tower-http diff --git a/src/tests.rs b/src/tests.rs index d0b3565c..9d8fb2b5 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -612,53 +612,6 @@ async fn typed_header() { assert_eq!(body, "invalid HTTP header (user-agent)"); } -#[tokio::test] -async fn different_request_body_types() { - use http_body::{Empty, Full}; - use std::convert::Infallible; - use tower_http::map_request_body::MapRequestBodyLayer; - - async fn handler(body: String) -> String { - body - } - - async fn svc_handler(req: Request) -> Result, Infallible> - where - B: http_body::Body, - B::Error: std::fmt::Debug, - { - let body = hyper::body::to_bytes(req.into_body()).await.unwrap(); - Ok(Response::new(Body::from(body))) - } - - let app = route("/", service::get(service_fn(svc_handler))) - .route( - "/foo", - get(handler.layer(MapRequestBodyLayer::new(|_| Full::::from("foo")))), - ) - .layer(MapRequestBodyLayer::new(|_| Empty::::new())); - - let addr = run_in_background(app).await; - - let client = reqwest::Client::new(); - - let res = client - .get(format!("http://{}/", addr)) - .send() - .await - .unwrap(); - let body = res.text().await.unwrap(); - assert_eq!(body, ""); - - let res = client - .get(format!("http://{}/foo", addr)) - .send() - .await - .unwrap(); - let body = res.text().await.unwrap(); - assert_eq!(body, "foo"); -} - #[tokio::test] async fn service_in_bottom() { async fn handler(_req: Request) -> Result, hyper::Error> {