mirror of
https://github.com/tokio-rs/axum.git
synced 2025-09-27 04:50:31 +00:00
axum-extra: implement FromRequest
for Either*
This commit is contained in:
parent
58f63bb4d0
commit
b7f815f10d
@ -1,6 +1,6 @@
|
||||
//! `Either*` types for combining extractors or responses into a single type.
|
||||
//!
|
||||
//! # As an extractor
|
||||
//! # As an `FromRequestParts` extractor
|
||||
//!
|
||||
//! ```
|
||||
//! use axum_extra::either::Either3;
|
||||
@ -54,6 +54,42 @@
|
||||
//! Note that if all the inner extractors reject the request, the rejection from the last
|
||||
//! extractor will be returned. For the example above that would be [`BytesRejection`].
|
||||
//!
|
||||
//! # As an `FromRequest` extractor
|
||||
//!
|
||||
//! In the following example, we can first try to deserialize the payload as JSON, if that fails try
|
||||
//! to interpret it as a UTF-8 string, and lastly just take the raw bytes.
|
||||
//!
|
||||
//! It might be preferable to instead extract `Bytes` directly and then fallibly convert them to
|
||||
//! `String` and then deserialize the data inside the handler.
|
||||
//!
|
||||
//! ```
|
||||
//! use axum_extra::either::Either3;
|
||||
//! use axum::{
|
||||
//! body::Bytes,
|
||||
//! Json,
|
||||
//! Router,
|
||||
//! routing::get,
|
||||
//! extract::FromRequestParts,
|
||||
//! };
|
||||
//!
|
||||
//! #[derive(serde::Deserialize)]
|
||||
//! struct Payload {
|
||||
//! user: String,
|
||||
//! request_id: u32,
|
||||
//! }
|
||||
//!
|
||||
//! async fn handler(
|
||||
//! body: Either3<Json<Payload>, String, Bytes>,
|
||||
//! ) {
|
||||
//! match body {
|
||||
//! Either3::E1(json) => { /* ... */ }
|
||||
//! Either3::E2(string) => { /* ... */ }
|
||||
//! Either3::E3(bytes) => { /* ... */ }
|
||||
//! }
|
||||
//! }
|
||||
//! #
|
||||
//! # let _: axum::routing::MethodRouter = axum::routing::get(handler);
|
||||
//! ```
|
||||
//! # As a response
|
||||
//!
|
||||
//! ```
|
||||
@ -93,9 +129,10 @@
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use axum::{
|
||||
extract::FromRequestParts,
|
||||
extract::{rejection::BytesRejection, FromRequest, FromRequestParts, Request},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use http::request::Parts;
|
||||
use tower_layer::Layer;
|
||||
use tower_service::Service;
|
||||
@ -226,6 +263,28 @@ pub enum Either8<E1, E2, E3, E4, E5, E6, E7, E8> {
|
||||
E8(E8),
|
||||
}
|
||||
|
||||
/// Rejection used for [`Either`], [`Either3`], etc.
|
||||
///
|
||||
/// Contains one variant for a case when the whole request could not be loaded and one variant
|
||||
/// containing the rejection of the last variant if all extractors failed..
|
||||
#[derive(Debug)]
|
||||
pub enum EitherRejection<E> {
|
||||
/// Buffering of the request body failed.
|
||||
Bytes(BytesRejection),
|
||||
|
||||
/// All extractors failed. This contains the error returned by the last extractor.
|
||||
LastRejection(E),
|
||||
}
|
||||
|
||||
impl<E: IntoResponse> IntoResponse for EitherRejection<E> {
|
||||
fn into_response(self) -> Response {
|
||||
match self {
|
||||
EitherRejection::Bytes(rejection) => rejection.into_response(),
|
||||
EitherRejection::LastRejection(rejection) => rejection.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_traits_for_either {
|
||||
(
|
||||
$either:ident =>
|
||||
@ -251,6 +310,43 @@ macro_rules! impl_traits_for_either {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, $($ident),*, $last> FromRequest<S> for $either<$($ident),*, $last>
|
||||
where
|
||||
S: Send + Sync,
|
||||
$($ident: FromRequest<S>),*,
|
||||
$last: FromRequest<S>,
|
||||
$($ident::Rejection: Send),*,
|
||||
$last::Rejection: IntoResponse + Send,
|
||||
{
|
||||
type Rejection = EitherRejection<$last::Rejection>;
|
||||
|
||||
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
|
||||
let (parts, body) = req.into_parts();
|
||||
let bytes = Bytes::from_request(Request::from_parts(parts.clone(), body), state)
|
||||
.await
|
||||
.map_err(EitherRejection::Bytes)?;
|
||||
|
||||
$(
|
||||
let req = Request::from_parts(
|
||||
parts.clone(),
|
||||
axum::body::Body::new(http_body_util::Full::new(bytes.clone())),
|
||||
);
|
||||
if let Ok(extracted) = $ident::from_request(req, state).await {
|
||||
return Ok(Self::$ident(extracted));
|
||||
}
|
||||
)*
|
||||
|
||||
let req = Request::from_parts(
|
||||
parts.clone(),
|
||||
axum::body::Body::new(http_body_util::Full::new(bytes.clone())),
|
||||
);
|
||||
match $last::from_request(req, state).await {
|
||||
Ok(extracted) => Ok(Self::$last(extracted)),
|
||||
Err(error) => Err(EitherRejection::LastRejection(error)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<$($ident),*, $last> IntoResponse for $either<$($ident),*, $last>
|
||||
where
|
||||
$($ident: IntoResponse),*,
|
||||
@ -312,3 +408,67 @@ where
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::future::Future;
|
||||
|
||||
use axum::body::Body;
|
||||
use axum::extract::rejection::StringRejection;
|
||||
use axum::extract::{FromRequest, Request, State};
|
||||
use bytes::Bytes;
|
||||
use http_body_util::Full;
|
||||
|
||||
use super::*;
|
||||
|
||||
struct False;
|
||||
|
||||
impl<S> FromRequestParts<S> for False {
|
||||
type Rejection = ();
|
||||
|
||||
fn from_request_parts(
|
||||
_parts: &mut Parts,
|
||||
_state: &S,
|
||||
) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
|
||||
std::future::ready(Err(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn either_from_request() {
|
||||
// The body is by design not valid UTF-8.
|
||||
let request = Request::new(Body::new(Full::new(Bytes::from_static(&[255]))));
|
||||
|
||||
let either = Either4::<String, String, Request, Bytes>::from_request(request, &())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(matches!(either, Either4::E3(_)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn either_from_request_rejection() {
|
||||
// The body is by design not valid UTF-8.
|
||||
let request = Request::new(Body::new(Full::new(Bytes::from_static(&[255]))));
|
||||
|
||||
let either = Either::<String, String>::from_request(request, &())
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
assert!(matches!(
|
||||
either,
|
||||
EitherRejection::LastRejection(StringRejection::InvalidUtf8(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn either_from_request_parts() {
|
||||
let (mut parts, _) = Request::new(Body::empty()).into_parts();
|
||||
|
||||
let either = Either3::<False, False, State<()>>::from_request_parts(&mut parts, &())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(matches!(either, Either3::E3(State(()))));
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user