diff --git a/examples/key_value_store.rs b/examples/key_value_store.rs index 4460854b..f76ab510 100644 --- a/examples/key_value_store.rs +++ b/examples/key_value_store.rs @@ -23,7 +23,7 @@ use tower_http::{ }; use tower_web::{ body::{Body, BoxBody}, - extract::{BytesMaxLength, Extension, UrlParams}, + extract::{ContentLengthLimit, Extension, UrlParams}, prelude::*, response::IntoResponse, routing::BoxRoute, @@ -88,10 +88,10 @@ async fn kv_get( async fn kv_set( _req: Request, UrlParams((key,)): UrlParams<(String,)>, - BytesMaxLength(value): BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb + ContentLengthLimit(bytes): ContentLengthLimit, // ~5mb Extension(state): Extension, ) { - state.write().unwrap().db.insert(key, value); + state.write().unwrap().db.insert(key, bytes); } async fn list_keys(_req: Request, Extension(state): Extension) -> String { diff --git a/src/extract/mod.rs b/src/extract/mod.rs index 6ba954e6..bebb5257 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -388,14 +388,14 @@ impl FromRequest for Body { } } -/// Extractor that will buffer request bodies up to a certain size. +/// Extractor that will reject requests with a body larger than some size. /// /// # Example /// /// ```rust,no_run /// use tower_web::prelude::*; /// -/// async fn handler(req: Request, body: extract::BytesMaxLength<1024>) { +/// async fn handler(req: Request, body: extract::ContentLengthLimit) { /// // ... /// } /// @@ -404,15 +404,17 @@ impl FromRequest for Body { /// /// This requires the request to have a `Content-Length` header. #[derive(Debug, Clone)] -pub struct BytesMaxLength(pub Bytes); +pub struct ContentLengthLimit(pub T); #[async_trait] -impl FromRequest for BytesMaxLength { +impl FromRequest for ContentLengthLimit +where + T: FromRequest, +{ type Rejection = Response; async fn from_request(req: &mut Request) -> Result { let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned(); - let body = take_body(req).map_err(|reject| reject.into_response())?; let content_length = content_length.and_then(|value| value.to_str().ok()?.parse::().ok()); @@ -425,11 +427,11 @@ impl FromRequest for BytesMaxLength { return Err(LengthRequired.into_response()); }; - let bytes = hyper::body::to_bytes(body) + let value = T::from_request(req) .await - .map_err(|e| FailedToBufferBody::from_err(e).into_response())?; + .map_err(IntoResponse::into_response)?; - Ok(BytesMaxLength(bytes)) + Ok(Self(value)) } } diff --git a/src/extract/rejection.rs b/src/extract/rejection.rs index f1b13ca3..af388d77 100644 --- a/src/extract/rejection.rs +++ b/src/extract/rejection.rs @@ -103,16 +103,16 @@ define_rejection! { define_rejection! { #[status = PAYLOAD_TOO_LARGE] #[body = "Request payload is too large"] - /// Rejection type for [`BytesMaxLength`](super::BytesMaxLength) if the - /// request body is too large. + /// Rejection type for [`ContentLengthLimit`](super::ContentLengthLimit) if + /// the request body is too large. pub struct PayloadTooLarge; } define_rejection! { #[status = LENGTH_REQUIRED] #[body = "Content length header is required"] - /// Rejection type for [`BytesMaxLength`](super::BytesMaxLength) if the - /// request is missing the `Content-Length` header or it is invalid. + /// Rejection type for [`ContentLengthLimit`](super::ContentLengthLimit) if + /// the request is missing the `Content-Length` header or it is invalid. pub struct LengthRequired; } diff --git a/src/handler/future.rs b/src/handler/future.rs index 0556c539..f593fde1 100644 --- a/src/handler/future.rs +++ b/src/handler/future.rs @@ -1,8 +1,8 @@ //! Handler future types. +use crate::body::BoxBody; use http::Response; use std::convert::Infallible; -use crate::body::BoxBody; opaque_future! { /// The response future for [`IntoService`](super::IntoService). diff --git a/src/handler/mod.rs b/src/handler/mod.rs index 9be1468a..32d5b57c 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -50,8 +50,8 @@ use crate::{ service::HandleError, }; use async_trait::async_trait; -use futures_util::future::Either; use bytes::Bytes; +use futures_util::future::Either; use http::{Request, Response}; use std::{ convert::Infallible, diff --git a/src/tests.rs b/src/tests.rs index 49c13beb..cb776da1 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,4 +1,5 @@ use crate::{handler::on, prelude::*, routing::MethodFilter, service}; +use bytes::Bytes; use http::{Request, Response, StatusCode}; use hyper::{Body, Server}; use serde::Deserialize; @@ -137,7 +138,7 @@ async fn body_with_length_limit() { let app = route( "/", post( - |req: Request, _body: extract::BytesMaxLength| async move { + |req: Request, _body: extract::ContentLengthLimit| async move { dbg!(&req); }, ),