Parameterize ContentLengthLimit

This commit is contained in:
David Pedersen 2021-06-09 08:14:20 +02:00
parent 09f76f3c87
commit 90c3e5ba74
6 changed files with 21 additions and 18 deletions

View File

@ -23,7 +23,7 @@ use tower_http::{
}; };
use tower_web::{ use tower_web::{
body::{Body, BoxBody}, body::{Body, BoxBody},
extract::{BytesMaxLength, Extension, UrlParams}, extract::{ContentLengthLimit, Extension, UrlParams},
prelude::*, prelude::*,
response::IntoResponse, response::IntoResponse,
routing::BoxRoute, routing::BoxRoute,
@ -88,10 +88,10 @@ async fn kv_get(
async fn kv_set( async fn kv_set(
_req: Request<Body>, _req: Request<Body>,
UrlParams((key,)): UrlParams<(String,)>, UrlParams((key,)): UrlParams<(String,)>,
BytesMaxLength(value): BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb ContentLengthLimit(bytes): ContentLengthLimit<Bytes, { 1024 * 5_000 }>, // ~5mb
Extension(state): Extension<SharedState>, Extension(state): Extension<SharedState>,
) { ) {
state.write().unwrap().db.insert(key, value); state.write().unwrap().db.insert(key, bytes);
} }
async fn list_keys(_req: Request<Body>, Extension(state): Extension<SharedState>) -> String { async fn list_keys(_req: Request<Body>, Extension(state): Extension<SharedState>) -> String {

View File

@ -388,14 +388,14 @@ impl FromRequest for Body {
} }
} }
/// Extractor that will buffer request bodies up to a certain size. /// Extractor that will reject requests with a body larger than some size.
/// ///
/// # Example /// # Example
/// ///
/// ```rust,no_run /// ```rust,no_run
/// use tower_web::prelude::*; /// use tower_web::prelude::*;
/// ///
/// async fn handler(req: Request<Body>, body: extract::BytesMaxLength<1024>) { /// async fn handler(req: Request<Body>, body: extract::ContentLengthLimit<String, 1024>) {
/// // ... /// // ...
/// } /// }
/// ///
@ -404,15 +404,17 @@ impl FromRequest for Body {
/// ///
/// This requires the request to have a `Content-Length` header. /// This requires the request to have a `Content-Length` header.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct BytesMaxLength<const N: u64>(pub Bytes); pub struct ContentLengthLimit<T, const N: u64>(pub T);
#[async_trait] #[async_trait]
impl<const N: u64> FromRequest for BytesMaxLength<N> { impl<T, const N: u64> FromRequest for ContentLengthLimit<T, N>
where
T: FromRequest,
{
type Rejection = Response<Body>; type Rejection = Response<Body>;
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned(); let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned();
let body = take_body(req).map_err(|reject| reject.into_response())?;
let content_length = let content_length =
content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok()); content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok());
@ -425,11 +427,11 @@ impl<const N: u64> FromRequest for BytesMaxLength<N> {
return Err(LengthRequired.into_response()); return Err(LengthRequired.into_response());
}; };
let bytes = hyper::body::to_bytes(body) let value = T::from_request(req)
.await .await
.map_err(|e| FailedToBufferBody::from_err(e).into_response())?; .map_err(IntoResponse::into_response)?;
Ok(BytesMaxLength(bytes)) Ok(Self(value))
} }
} }

View File

@ -103,16 +103,16 @@ define_rejection! {
define_rejection! { define_rejection! {
#[status = PAYLOAD_TOO_LARGE] #[status = PAYLOAD_TOO_LARGE]
#[body = "Request payload is too large"] #[body = "Request payload is too large"]
/// Rejection type for [`BytesMaxLength`](super::BytesMaxLength) if the /// Rejection type for [`ContentLengthLimit`](super::ContentLengthLimit) if
/// request body is too large. /// the request body is too large.
pub struct PayloadTooLarge; pub struct PayloadTooLarge;
} }
define_rejection! { define_rejection! {
#[status = LENGTH_REQUIRED] #[status = LENGTH_REQUIRED]
#[body = "Content length header is required"] #[body = "Content length header is required"]
/// Rejection type for [`BytesMaxLength`](super::BytesMaxLength) if the /// Rejection type for [`ContentLengthLimit`](super::ContentLengthLimit) if
/// request is missing the `Content-Length` header or it is invalid. /// the request is missing the `Content-Length` header or it is invalid.
pub struct LengthRequired; pub struct LengthRequired;
} }

View File

@ -1,8 +1,8 @@
//! Handler future types. //! Handler future types.
use crate::body::BoxBody;
use http::Response; use http::Response;
use std::convert::Infallible; use std::convert::Infallible;
use crate::body::BoxBody;
opaque_future! { opaque_future! {
/// The response future for [`IntoService`](super::IntoService). /// The response future for [`IntoService`](super::IntoService).

View File

@ -50,8 +50,8 @@ use crate::{
service::HandleError, service::HandleError,
}; };
use async_trait::async_trait; use async_trait::async_trait;
use futures_util::future::Either;
use bytes::Bytes; use bytes::Bytes;
use futures_util::future::Either;
use http::{Request, Response}; use http::{Request, Response};
use std::{ use std::{
convert::Infallible, convert::Infallible,

View File

@ -1,4 +1,5 @@
use crate::{handler::on, prelude::*, routing::MethodFilter, service}; use crate::{handler::on, prelude::*, routing::MethodFilter, service};
use bytes::Bytes;
use http::{Request, Response, StatusCode}; use http::{Request, Response, StatusCode};
use hyper::{Body, Server}; use hyper::{Body, Server};
use serde::Deserialize; use serde::Deserialize;
@ -137,7 +138,7 @@ async fn body_with_length_limit() {
let app = route( let app = route(
"/", "/",
post( post(
|req: Request<Body>, _body: extract::BytesMaxLength<LIMIT>| async move { |req: Request<Body>, _body: extract::ContentLengthLimit<Bytes, LIMIT>| async move {
dbg!(&req); dbg!(&req);
}, },
), ),