axum-extra: implement FromRequest for Either*

This commit is contained in:
David Mládek 2025-04-25 19:13:10 +02:00
parent 58f63bb4d0
commit b7f815f10d

View File

@ -1,6 +1,6 @@
//! `Either*` types for combining extractors or responses into a single type.
//!
//! # As an extractor
//! # As an `FromRequestParts` extractor
//!
//! ```
//! use axum_extra::either::Either3;
@ -54,6 +54,42 @@
//! Note that if all the inner extractors reject the request, the rejection from the last
//! extractor will be returned. For the example above that would be [`BytesRejection`].
//!
//! # As an `FromRequest` extractor
//!
//! In the following example, we can first try to deserialize the payload as JSON, if that fails try
//! to interpret it as a UTF-8 string, and lastly just take the raw bytes.
//!
//! It might be preferable to instead extract `Bytes` directly and then fallibly convert them to
//! `String` and then deserialize the data inside the handler.
//!
//! ```
//! use axum_extra::either::Either3;
//! use axum::{
//! body::Bytes,
//! Json,
//! Router,
//! routing::get,
//! extract::FromRequestParts,
//! };
//!
//! #[derive(serde::Deserialize)]
//! struct Payload {
//! user: String,
//! request_id: u32,
//! }
//!
//! async fn handler(
//! body: Either3<Json<Payload>, String, Bytes>,
//! ) {
//! match body {
//! Either3::E1(json) => { /* ... */ }
//! Either3::E2(string) => { /* ... */ }
//! Either3::E3(bytes) => { /* ... */ }
//! }
//! }
//! #
//! # let _: axum::routing::MethodRouter = axum::routing::get(handler);
//! ```
//! # As a response
//!
//! ```
@ -93,9 +129,10 @@
use std::task::{Context, Poll};
use axum::{
extract::FromRequestParts,
extract::{rejection::BytesRejection, FromRequest, FromRequestParts, Request},
response::{IntoResponse, Response},
};
use bytes::Bytes;
use http::request::Parts;
use tower_layer::Layer;
use tower_service::Service;
@ -226,6 +263,28 @@ pub enum Either8<E1, E2, E3, E4, E5, E6, E7, E8> {
E8(E8),
}
/// Rejection used for [`Either`], [`Either3`], etc.
///
/// Contains one variant for a case when the whole request could not be loaded and one variant
/// containing the rejection of the last variant if all extractors failed..
#[derive(Debug)]
pub enum EitherRejection<E> {
/// Buffering of the request body failed.
Bytes(BytesRejection),
/// All extractors failed. This contains the error returned by the last extractor.
LastRejection(E),
}
impl<E: IntoResponse> IntoResponse for EitherRejection<E> {
fn into_response(self) -> Response {
match self {
EitherRejection::Bytes(rejection) => rejection.into_response(),
EitherRejection::LastRejection(rejection) => rejection.into_response(),
}
}
}
macro_rules! impl_traits_for_either {
(
$either:ident =>
@ -251,6 +310,43 @@ macro_rules! impl_traits_for_either {
}
}
impl<S, $($ident),*, $last> FromRequest<S> for $either<$($ident),*, $last>
where
S: Send + Sync,
$($ident: FromRequest<S>),*,
$last: FromRequest<S>,
$($ident::Rejection: Send),*,
$last::Rejection: IntoResponse + Send,
{
type Rejection = EitherRejection<$last::Rejection>;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let (parts, body) = req.into_parts();
let bytes = Bytes::from_request(Request::from_parts(parts.clone(), body), state)
.await
.map_err(EitherRejection::Bytes)?;
$(
let req = Request::from_parts(
parts.clone(),
axum::body::Body::new(http_body_util::Full::new(bytes.clone())),
);
if let Ok(extracted) = $ident::from_request(req, state).await {
return Ok(Self::$ident(extracted));
}
)*
let req = Request::from_parts(
parts.clone(),
axum::body::Body::new(http_body_util::Full::new(bytes.clone())),
);
match $last::from_request(req, state).await {
Ok(extracted) => Ok(Self::$last(extracted)),
Err(error) => Err(EitherRejection::LastRejection(error)),
}
}
}
impl<$($ident),*, $last> IntoResponse for $either<$($ident),*, $last>
where
$($ident: IntoResponse),*,
@ -312,3 +408,67 @@ where
}
}
}
#[cfg(test)]
mod tests {
use std::future::Future;
use axum::body::Body;
use axum::extract::rejection::StringRejection;
use axum::extract::{FromRequest, Request, State};
use bytes::Bytes;
use http_body_util::Full;
use super::*;
struct False;
impl<S> FromRequestParts<S> for False {
type Rejection = ();
fn from_request_parts(
_parts: &mut Parts,
_state: &S,
) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
std::future::ready(Err(()))
}
}
#[tokio::test]
async fn either_from_request() {
// The body is by design not valid UTF-8.
let request = Request::new(Body::new(Full::new(Bytes::from_static(&[255]))));
let either = Either4::<String, String, Request, Bytes>::from_request(request, &())
.await
.unwrap();
assert!(matches!(either, Either4::E3(_)));
}
#[tokio::test]
async fn either_from_request_rejection() {
// The body is by design not valid UTF-8.
let request = Request::new(Body::new(Full::new(Bytes::from_static(&[255]))));
let either = Either::<String, String>::from_request(request, &())
.await
.unwrap_err();
assert!(matches!(
either,
EitherRejection::LastRejection(StringRejection::InvalidUtf8(_))
));
}
#[tokio::test]
async fn either_from_request_parts() {
let (mut parts, _) = Request::new(Body::empty()).into_parts();
let either = Either3::<False, False, State<()>>::from_request_parts(&mut parts, &())
.await
.unwrap();
assert!(matches!(either, Either3::E3(State(()))));
}
}