diff --git a/examples/key_value_store.rs b/examples/key_value_store.rs
index 4460854b..f76ab510 100644
--- a/examples/key_value_store.rs
+++ b/examples/key_value_store.rs
@@ -23,7 +23,7 @@ use tower_http::{
};
use tower_web::{
body::{Body, BoxBody},
- extract::{BytesMaxLength, Extension, UrlParams},
+ extract::{ContentLengthLimit, Extension, UrlParams},
prelude::*,
response::IntoResponse,
routing::BoxRoute,
@@ -88,10 +88,10 @@ async fn kv_get(
async fn kv_set(
_req: Request
,
UrlParams((key,)): UrlParams<(String,)>,
- BytesMaxLength(value): BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb
+ ContentLengthLimit(bytes): ContentLengthLimit, // ~5mb
Extension(state): Extension,
) {
- state.write().unwrap().db.insert(key, value);
+ state.write().unwrap().db.insert(key, bytes);
}
async fn list_keys(_req: Request, Extension(state): Extension) -> String {
diff --git a/src/extract/mod.rs b/src/extract/mod.rs
index 6ba954e6..bebb5257 100644
--- a/src/extract/mod.rs
+++ b/src/extract/mod.rs
@@ -388,14 +388,14 @@ impl FromRequest for Body {
}
}
-/// Extractor that will buffer request bodies up to a certain size.
+/// Extractor that will reject requests with a body larger than some size.
///
/// # Example
///
/// ```rust,no_run
/// use tower_web::prelude::*;
///
-/// async fn handler(req: Request, body: extract::BytesMaxLength<1024>) {
+/// async fn handler(req: Request, body: extract::ContentLengthLimit) {
/// // ...
/// }
///
@@ -404,15 +404,17 @@ impl FromRequest for Body {
///
/// This requires the request to have a `Content-Length` header.
#[derive(Debug, Clone)]
-pub struct BytesMaxLength(pub Bytes);
+pub struct ContentLengthLimit(pub T);
#[async_trait]
-impl FromRequest for BytesMaxLength {
+impl FromRequest for ContentLengthLimit
+where
+ T: FromRequest,
+{
type Rejection = Response;
async fn from_request(req: &mut Request) -> Result {
let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned();
- let body = take_body(req).map_err(|reject| reject.into_response())?;
let content_length =
content_length.and_then(|value| value.to_str().ok()?.parse::().ok());
@@ -425,11 +427,11 @@ impl FromRequest for BytesMaxLength {
return Err(LengthRequired.into_response());
};
- let bytes = hyper::body::to_bytes(body)
+ let value = T::from_request(req)
.await
- .map_err(|e| FailedToBufferBody::from_err(e).into_response())?;
+ .map_err(IntoResponse::into_response)?;
- Ok(BytesMaxLength(bytes))
+ Ok(Self(value))
}
}
diff --git a/src/extract/rejection.rs b/src/extract/rejection.rs
index f1b13ca3..af388d77 100644
--- a/src/extract/rejection.rs
+++ b/src/extract/rejection.rs
@@ -103,16 +103,16 @@ define_rejection! {
define_rejection! {
#[status = PAYLOAD_TOO_LARGE]
#[body = "Request payload is too large"]
- /// Rejection type for [`BytesMaxLength`](super::BytesMaxLength) if the
- /// request body is too large.
+ /// Rejection type for [`ContentLengthLimit`](super::ContentLengthLimit) if
+ /// the request body is too large.
pub struct PayloadTooLarge;
}
define_rejection! {
#[status = LENGTH_REQUIRED]
#[body = "Content length header is required"]
- /// Rejection type for [`BytesMaxLength`](super::BytesMaxLength) if the
- /// request is missing the `Content-Length` header or it is invalid.
+ /// Rejection type for [`ContentLengthLimit`](super::ContentLengthLimit) if
+ /// the request is missing the `Content-Length` header or it is invalid.
pub struct LengthRequired;
}
diff --git a/src/handler/future.rs b/src/handler/future.rs
index 0556c539..f593fde1 100644
--- a/src/handler/future.rs
+++ b/src/handler/future.rs
@@ -1,8 +1,8 @@
//! Handler future types.
+use crate::body::BoxBody;
use http::Response;
use std::convert::Infallible;
-use crate::body::BoxBody;
opaque_future! {
/// The response future for [`IntoService`](super::IntoService).
diff --git a/src/handler/mod.rs b/src/handler/mod.rs
index 9be1468a..32d5b57c 100644
--- a/src/handler/mod.rs
+++ b/src/handler/mod.rs
@@ -50,8 +50,8 @@ use crate::{
service::HandleError,
};
use async_trait::async_trait;
-use futures_util::future::Either;
use bytes::Bytes;
+use futures_util::future::Either;
use http::{Request, Response};
use std::{
convert::Infallible,
diff --git a/src/tests.rs b/src/tests.rs
index 49c13beb..cb776da1 100644
--- a/src/tests.rs
+++ b/src/tests.rs
@@ -1,4 +1,5 @@
use crate::{handler::on, prelude::*, routing::MethodFilter, service};
+use bytes::Bytes;
use http::{Request, Response, StatusCode};
use hyper::{Body, Server};
use serde::Deserialize;
@@ -137,7 +138,7 @@ async fn body_with_length_limit() {
let app = route(
"/",
post(
- |req: Request, _body: extract::BytesMaxLength| async move {
+ |req: Request, _body: extract::ContentLengthLimit| async move {
dbg!(&req);
},
),