mirror of
https://github.com/tokio-rs/axum.git
synced 2025-09-30 14:31:16 +00:00
Refactor TypedHeader
extractor (#189)
I should use `HeaderMapExt::typed_try_get` rather than implementing it manually.
This commit is contained in:
parent
48afd30491
commit
be7e9e9bc6
@ -341,40 +341,6 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Rejection used for [`TypedHeader`](super::TypedHeader).
|
|
||||||
#[cfg(feature = "headers")]
|
#[cfg(feature = "headers")]
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "headers")))]
|
#[cfg_attr(docsrs, doc(cfg(feature = "headers")))]
|
||||||
#[derive(Debug)]
|
pub use super::typed_header::TypedHeaderRejection;
|
||||||
pub struct TypedHeaderRejection {
|
|
||||||
pub(super) name: &'static http::header::HeaderName,
|
|
||||||
pub(super) err: headers::Error,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "headers")]
|
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "headers")))]
|
|
||||||
impl IntoResponse for TypedHeaderRejection {
|
|
||||||
type Body = Full<Bytes>;
|
|
||||||
type BodyError = Infallible;
|
|
||||||
|
|
||||||
fn into_response(self) -> http::Response<Self::Body> {
|
|
||||||
let mut res = self.to_string().into_response();
|
|
||||||
*res.status_mut() = http::StatusCode::BAD_REQUEST;
|
|
||||||
res
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "headers")]
|
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "headers")))]
|
|
||||||
impl std::fmt::Display for TypedHeaderRejection {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "{} ({})", self.err, self.name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "headers")]
|
|
||||||
#[cfg_attr(docsrs, doc(cfg(feature = "headers")))]
|
|
||||||
impl std::error::Error for TypedHeaderRejection {
|
|
||||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
|
||||||
Some(&self.err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
use super::{rejection::TypedHeaderRejection, FromRequest, RequestParts};
|
use super::{FromRequest, RequestParts};
|
||||||
|
use crate::response::IntoResponse;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use headers::HeaderMap;
|
use bytes::Bytes;
|
||||||
use std::ops::Deref;
|
use headers::HeaderMapExt;
|
||||||
|
use http_body::Full;
|
||||||
|
use std::{convert::Infallible, ops::Deref};
|
||||||
|
|
||||||
/// Extractor that extracts a typed header value from [`headers`].
|
/// Extractor that extracts a typed header value from [`headers`].
|
||||||
///
|
///
|
||||||
@ -36,19 +39,26 @@ where
|
|||||||
type Rejection = TypedHeaderRejection;
|
type Rejection = TypedHeaderRejection;
|
||||||
|
|
||||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||||
let empty_headers = HeaderMap::new();
|
let headers = if let Some(headers) = req.headers() {
|
||||||
let header_values = if let Some(headers) = req.headers() {
|
headers
|
||||||
headers.get_all(T::name())
|
|
||||||
} else {
|
} else {
|
||||||
empty_headers.get_all(T::name())
|
return Err(TypedHeaderRejection {
|
||||||
|
name: T::name(),
|
||||||
|
reason: Reason::Missing,
|
||||||
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
T::decode(&mut header_values.iter())
|
match headers.typed_try_get::<T>() {
|
||||||
.map(Self)
|
Ok(Some(value)) => Ok(Self(value)),
|
||||||
.map_err(|err| TypedHeaderRejection {
|
Ok(None) => Err(TypedHeaderRejection {
|
||||||
err,
|
|
||||||
name: T::name(),
|
name: T::name(),
|
||||||
})
|
reason: Reason::Missing,
|
||||||
|
}),
|
||||||
|
Err(err) => Err(TypedHeaderRejection {
|
||||||
|
name: T::name(),
|
||||||
|
reason: Reason::Error(err),
|
||||||
|
}),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -59,3 +69,85 @@ impl<T> Deref for TypedHeader<T> {
|
|||||||
&self.0
|
&self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Rejection used for [`TypedHeader`](super::TypedHeader).
|
||||||
|
#[cfg(feature = "headers")]
|
||||||
|
#[cfg_attr(docsrs, doc(cfg(feature = "headers")))]
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct TypedHeaderRejection {
|
||||||
|
name: &'static http::header::HeaderName,
|
||||||
|
reason: Reason,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
enum Reason {
|
||||||
|
Missing,
|
||||||
|
Error(headers::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for TypedHeaderRejection {
|
||||||
|
type Body = Full<Bytes>;
|
||||||
|
type BodyError = Infallible;
|
||||||
|
|
||||||
|
fn into_response(self) -> http::Response<Self::Body> {
|
||||||
|
let mut res = self.to_string().into_response();
|
||||||
|
*res.status_mut() = http::StatusCode::BAD_REQUEST;
|
||||||
|
res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for TypedHeaderRejection {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match &self.reason {
|
||||||
|
Reason::Missing => {
|
||||||
|
write!(f, "Header of type `{}` was missing", self.name)
|
||||||
|
}
|
||||||
|
Reason::Error(err) => {
|
||||||
|
write!(f, "{} ({})", err, self.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for TypedHeaderRejection {
|
||||||
|
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||||
|
match &self.reason {
|
||||||
|
Reason::Error(err) => Some(err),
|
||||||
|
Reason::Missing => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::{handler::get, response::IntoResponse, route, tests::*};
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn typed_header() {
|
||||||
|
async fn handle(
|
||||||
|
TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
user_agent.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
let app = route("/", get(handle));
|
||||||
|
|
||||||
|
let addr = run_in_background(app).await;
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
let res = client
|
||||||
|
.get(format!("http://{}", addr))
|
||||||
|
.header("user-agent", "foobar")
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let body = res.text().await.unwrap();
|
||||||
|
assert_eq!(body, "foobar");
|
||||||
|
|
||||||
|
let res = client.get(format!("http://{}", addr)).send().await.unwrap();
|
||||||
|
let body = res.text().await.unwrap();
|
||||||
|
assert_eq!(body, "Header of type `user-agent` was missing");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -89,7 +89,7 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
body::BoxBody,
|
body::BoxBody,
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
routing::{future::RouteFuture, EmptyRouter, MethodFilter},
|
routing::{EmptyRouter, MethodFilter},
|
||||||
};
|
};
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use http::{Request, Response};
|
use http::{Request, Response};
|
||||||
|
@ -45,7 +45,7 @@ mod for_services {
|
|||||||
async fn get_handles_head() {
|
async fn get_handles_head() {
|
||||||
let app = route(
|
let app = route(
|
||||||
"/",
|
"/",
|
||||||
get(service_fn(|req: Request<Body>| async move {
|
get(service_fn(|_req: Request<Body>| async move {
|
||||||
let res = Response::builder()
|
let res = Response::builder()
|
||||||
.header("x-some-header", "foobar".parse::<HeaderValue>().unwrap())
|
.header("x-some-header", "foobar".parse::<HeaderValue>().unwrap())
|
||||||
.body(Body::from("you shouldn't see this"))
|
.body(Body::from("you shouldn't see this"))
|
||||||
|
@ -420,35 +420,6 @@ async fn middleware_on_single_route() {
|
|||||||
assert_eq!(body, "Hello, World!");
|
assert_eq!(body, "Hello, World!");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
#[cfg(feature = "header")]
|
|
||||||
async fn typed_header() {
|
|
||||||
use crate::{extract::TypedHeader, response::IntoResponse};
|
|
||||||
|
|
||||||
async fn handle(TypedHeader(user_agent): TypedHeader<headers::UserAgent>) -> impl IntoResponse {
|
|
||||||
user_agent.to_string()
|
|
||||||
}
|
|
||||||
|
|
||||||
let app = route("/", get(handle));
|
|
||||||
|
|
||||||
let addr = run_in_background(app).await;
|
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
|
|
||||||
let res = client
|
|
||||||
.get(format!("http://{}", addr))
|
|
||||||
.header("user-agent", "foobar")
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let body = res.text().await.unwrap();
|
|
||||||
assert_eq!(body, "foobar");
|
|
||||||
|
|
||||||
let res = client.get(format!("http://{}", addr)).send().await.unwrap();
|
|
||||||
let body = res.text().await.unwrap();
|
|
||||||
assert_eq!(body, "invalid HTTP header (user-agent)");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn service_in_bottom() {
|
async fn service_in_bottom() {
|
||||||
async fn handler(_req: Request<hyper::Body>) -> Result<Response<hyper::Body>, hyper::Error> {
|
async fn handler(_req: Request<hyper::Body>) -> Result<Response<hyper::Body>, hyper::Error> {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user