mirror of
https://github.com/tokio-rs/axum.git
synced 2025-10-04 00:04:48 +00:00
Parameterize ContentLengthLimit
This commit is contained in:
parent
09f76f3c87
commit
90c3e5ba74
@ -23,7 +23,7 @@ use tower_http::{
|
|||||||
};
|
};
|
||||||
use tower_web::{
|
use tower_web::{
|
||||||
body::{Body, BoxBody},
|
body::{Body, BoxBody},
|
||||||
extract::{BytesMaxLength, Extension, UrlParams},
|
extract::{ContentLengthLimit, Extension, UrlParams},
|
||||||
prelude::*,
|
prelude::*,
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
routing::BoxRoute,
|
routing::BoxRoute,
|
||||||
@ -88,10 +88,10 @@ async fn kv_get(
|
|||||||
async fn kv_set(
|
async fn kv_set(
|
||||||
_req: Request<Body>,
|
_req: Request<Body>,
|
||||||
UrlParams((key,)): UrlParams<(String,)>,
|
UrlParams((key,)): UrlParams<(String,)>,
|
||||||
BytesMaxLength(value): BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb
|
ContentLengthLimit(bytes): ContentLengthLimit<Bytes, { 1024 * 5_000 }>, // ~5mb
|
||||||
Extension(state): Extension<SharedState>,
|
Extension(state): Extension<SharedState>,
|
||||||
) {
|
) {
|
||||||
state.write().unwrap().db.insert(key, value);
|
state.write().unwrap().db.insert(key, bytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn list_keys(_req: Request<Body>, Extension(state): Extension<SharedState>) -> String {
|
async fn list_keys(_req: Request<Body>, Extension(state): Extension<SharedState>) -> String {
|
||||||
|
@ -388,14 +388,14 @@ impl FromRequest for Body {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extractor that will buffer request bodies up to a certain size.
|
/// Extractor that will reject requests with a body larger than some size.
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```rust,no_run
|
/// ```rust,no_run
|
||||||
/// use tower_web::prelude::*;
|
/// use tower_web::prelude::*;
|
||||||
///
|
///
|
||||||
/// async fn handler(req: Request<Body>, body: extract::BytesMaxLength<1024>) {
|
/// async fn handler(req: Request<Body>, body: extract::ContentLengthLimit<String, 1024>) {
|
||||||
/// // ...
|
/// // ...
|
||||||
/// }
|
/// }
|
||||||
///
|
///
|
||||||
@ -404,15 +404,17 @@ impl FromRequest for Body {
|
|||||||
///
|
///
|
||||||
/// This requires the request to have a `Content-Length` header.
|
/// This requires the request to have a `Content-Length` header.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct BytesMaxLength<const N: u64>(pub Bytes);
|
pub struct ContentLengthLimit<T, const N: u64>(pub T);
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<const N: u64> FromRequest for BytesMaxLength<N> {
|
impl<T, const N: u64> FromRequest for ContentLengthLimit<T, N>
|
||||||
|
where
|
||||||
|
T: FromRequest,
|
||||||
|
{
|
||||||
type Rejection = Response<Body>;
|
type Rejection = Response<Body>;
|
||||||
|
|
||||||
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
async fn from_request(req: &mut Request<Body>) -> Result<Self, Self::Rejection> {
|
||||||
let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned();
|
let content_length = req.headers().get(http::header::CONTENT_LENGTH).cloned();
|
||||||
let body = take_body(req).map_err(|reject| reject.into_response())?;
|
|
||||||
|
|
||||||
let content_length =
|
let content_length =
|
||||||
content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok());
|
content_length.and_then(|value| value.to_str().ok()?.parse::<u64>().ok());
|
||||||
@ -425,11 +427,11 @@ impl<const N: u64> FromRequest for BytesMaxLength<N> {
|
|||||||
return Err(LengthRequired.into_response());
|
return Err(LengthRequired.into_response());
|
||||||
};
|
};
|
||||||
|
|
||||||
let bytes = hyper::body::to_bytes(body)
|
let value = T::from_request(req)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| FailedToBufferBody::from_err(e).into_response())?;
|
.map_err(IntoResponse::into_response)?;
|
||||||
|
|
||||||
Ok(BytesMaxLength(bytes))
|
Ok(Self(value))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,16 +103,16 @@ define_rejection! {
|
|||||||
define_rejection! {
|
define_rejection! {
|
||||||
#[status = PAYLOAD_TOO_LARGE]
|
#[status = PAYLOAD_TOO_LARGE]
|
||||||
#[body = "Request payload is too large"]
|
#[body = "Request payload is too large"]
|
||||||
/// Rejection type for [`BytesMaxLength`](super::BytesMaxLength) if the
|
/// Rejection type for [`ContentLengthLimit`](super::ContentLengthLimit) if
|
||||||
/// request body is too large.
|
/// the request body is too large.
|
||||||
pub struct PayloadTooLarge;
|
pub struct PayloadTooLarge;
|
||||||
}
|
}
|
||||||
|
|
||||||
define_rejection! {
|
define_rejection! {
|
||||||
#[status = LENGTH_REQUIRED]
|
#[status = LENGTH_REQUIRED]
|
||||||
#[body = "Content length header is required"]
|
#[body = "Content length header is required"]
|
||||||
/// Rejection type for [`BytesMaxLength`](super::BytesMaxLength) if the
|
/// Rejection type for [`ContentLengthLimit`](super::ContentLengthLimit) if
|
||||||
/// request is missing the `Content-Length` header or it is invalid.
|
/// the request is missing the `Content-Length` header or it is invalid.
|
||||||
pub struct LengthRequired;
|
pub struct LengthRequired;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
//! Handler future types.
|
//! Handler future types.
|
||||||
|
|
||||||
|
use crate::body::BoxBody;
|
||||||
use http::Response;
|
use http::Response;
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
use crate::body::BoxBody;
|
|
||||||
|
|
||||||
opaque_future! {
|
opaque_future! {
|
||||||
/// The response future for [`IntoService`](super::IntoService).
|
/// The response future for [`IntoService`](super::IntoService).
|
||||||
|
@ -50,8 +50,8 @@ use crate::{
|
|||||||
service::HandleError,
|
service::HandleError,
|
||||||
};
|
};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures_util::future::Either;
|
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
|
use futures_util::future::Either;
|
||||||
use http::{Request, Response};
|
use http::{Request, Response};
|
||||||
use std::{
|
use std::{
|
||||||
convert::Infallible,
|
convert::Infallible,
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
use crate::{handler::on, prelude::*, routing::MethodFilter, service};
|
use crate::{handler::on, prelude::*, routing::MethodFilter, service};
|
||||||
|
use bytes::Bytes;
|
||||||
use http::{Request, Response, StatusCode};
|
use http::{Request, Response, StatusCode};
|
||||||
use hyper::{Body, Server};
|
use hyper::{Body, Server};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
@ -137,7 +138,7 @@ async fn body_with_length_limit() {
|
|||||||
let app = route(
|
let app = route(
|
||||||
"/",
|
"/",
|
||||||
post(
|
post(
|
||||||
|req: Request<Body>, _body: extract::BytesMaxLength<LIMIT>| async move {
|
|req: Request<Body>, _body: extract::ContentLengthLimit<Bytes, LIMIT>| async move {
|
||||||
dbg!(&req);
|
dbg!(&req);
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user