mirror of
https://github.com/tokio-rs/axum.git
synced 2025-10-02 23:34:47 +00:00
A bit of clean up
This commit is contained in:
parent
07294378b3
commit
33f2e5f661
257
src/lib.rs
257
src/lib.rs
@ -29,7 +29,7 @@ Tests
|
|||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use futures_util::future;
|
use futures_util::{future, ready};
|
||||||
use http::{Method, Request, Response, StatusCode};
|
use http::{Method, Request, Response, StatusCode};
|
||||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||||
use std::{
|
use std::{
|
||||||
@ -47,26 +47,31 @@ pub fn app() -> App<EmptyRouter> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct App<R> {
|
pub struct App<R> {
|
||||||
router: R,
|
router: R,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R> App<R> {
|
impl<R> App<R> {
|
||||||
pub fn at(self, route_spec: &str) -> RouteBuilder<R> {
|
pub fn at(self, route_spec: &str) -> RouteAt<R> {
|
||||||
RouteBuilder {
|
self.at_bytes(Bytes::copy_from_slice(route_spec.as_bytes()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn at_bytes(self, route_spec: Bytes) -> RouteAt<R> {
|
||||||
|
RouteAt {
|
||||||
app: self,
|
app: self,
|
||||||
route_spec: Bytes::copy_from_slice(route_spec.as_bytes()),
|
route_spec,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct RouteBuilder<R> {
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RouteAt<R> {
|
||||||
app: App<R>,
|
app: App<R>,
|
||||||
route_spec: Bytes,
|
route_spec: Bytes,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R> RouteBuilder<R> {
|
impl<R> RouteAt<R> {
|
||||||
pub fn get<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
|
pub fn get<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
|
||||||
where
|
where
|
||||||
F: Handler<T>,
|
F: Handler<T>,
|
||||||
@ -81,14 +86,6 @@ impl<R> RouteBuilder<R> {
|
|||||||
self.add_route(handler_fn, Method::POST)
|
self.add_route(handler_fn, Method::POST)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn at(self, route_spec: &str) -> Self {
|
|
||||||
self.app.at(route_spec)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn into_service(self) -> App<R> {
|
|
||||||
self.app
|
|
||||||
}
|
|
||||||
|
|
||||||
fn add_route<H, T>(self, handler: H, method: Method) -> RouteBuilder<Route<HandlerSvc<H, T>, R>>
|
fn add_route<H, T>(self, handler: H, method: Method) -> RouteBuilder<Route<HandlerSvc<H, T>, R>>
|
||||||
where
|
where
|
||||||
H: Handler<T>,
|
H: Handler<T>,
|
||||||
@ -104,6 +101,8 @@ impl<R> RouteBuilder<R> {
|
|||||||
spec: self.route_spec.clone(),
|
spec: self.route_spec.clone(),
|
||||||
},
|
},
|
||||||
fallback: self.app.router,
|
fallback: self.app.router,
|
||||||
|
handler_ready: false,
|
||||||
|
fallback_ready: false,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -114,9 +113,47 @@ impl<R> RouteBuilder<R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct RouteBuilder<R> {
|
||||||
|
app: App<R>,
|
||||||
|
route_spec: Bytes,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R> RouteBuilder<R> {
|
||||||
|
pub fn at(self, route_spec: &str) -> RouteAt<R> {
|
||||||
|
self.app.at(route_spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
|
||||||
|
where
|
||||||
|
F: Handler<T>,
|
||||||
|
{
|
||||||
|
self.app.at_bytes(self.route_spec).get(handler_fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post<F, T>(self, handler_fn: F) -> RouteBuilder<Route<HandlerSvc<F, T>, R>>
|
||||||
|
where
|
||||||
|
F: Handler<T>,
|
||||||
|
{
|
||||||
|
self.app.at_bytes(self.route_spec).post(handler_fn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
#[non_exhaustive]
|
#[non_exhaustive]
|
||||||
pub enum Error {}
|
pub enum Error {
|
||||||
|
#[error("failed to deserialize the request body")]
|
||||||
|
DeserializeRequestBody(#[source] serde_json::Error),
|
||||||
|
|
||||||
|
#[error("failed to consume the body")]
|
||||||
|
ConsumeBody(#[source] hyper::Error),
|
||||||
|
|
||||||
|
#[error("URI contained no query string")]
|
||||||
|
QueryStringMissing,
|
||||||
|
|
||||||
|
#[error("failed to deserialize query string")]
|
||||||
|
DeserializeQueryString(#[from] serde_urlencoded::de::Error),
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Handler<Out> {
|
pub trait Handler<Out> {
|
||||||
@ -136,38 +173,50 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
macro_rules! impl_handler {
|
||||||
|
( $head:ident $(,)? ) => {
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
impl<F, Fut, T1> Handler<(T1,)> for F
|
impl<F, Fut, $head> Handler<($head,)> for F
|
||||||
where
|
where
|
||||||
F: Fn(Request<Body>, T1) -> Fut + Send + Sync,
|
F: Fn(Request<Body>, $head) -> Fut + Send + Sync,
|
||||||
Fut: Future<Output = Result<Response<Body>, Error>> + Send,
|
Fut: Future<Output = Result<Response<Body>, Error>> + Send,
|
||||||
T1: FromRequest + Send,
|
$head: FromRequest + Send,
|
||||||
{
|
{
|
||||||
async fn call(self, mut req: Request<Body>) -> Result<Response<Body>, Error> {
|
async fn call(self, mut req: Request<Body>) -> Result<Response<Body>, Error> {
|
||||||
let T1 = T1::from_request(&mut req).await;
|
let $head = $head::from_request(&mut req).await?;
|
||||||
let res = self(req, T1).await?;
|
let res = self(req, $head).await?;
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
( $head:ident, $($tail:ident),* $(,)? ) => {
|
||||||
|
#[async_trait]
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
impl<F, Fut, $head, $($tail,)*> Handler<($head, $($tail,)*)> for F
|
||||||
|
where
|
||||||
|
F: Fn(Request<Body>, $head, $($tail,)*) -> Fut + Send + Sync,
|
||||||
|
Fut: Future<Output = Result<Response<Body>, Error>> + Send,
|
||||||
|
$head: FromRequest + Send,
|
||||||
|
$( $tail: FromRequest + Send, )*
|
||||||
|
{
|
||||||
|
async fn call(self, mut req: Request<Body>) -> Result<Response<Body>, Error> {
|
||||||
|
let $head = $head::from_request(&mut req).await?;
|
||||||
|
$(
|
||||||
|
let $tail = $tail::from_request(&mut req).await?;
|
||||||
|
)*
|
||||||
|
let res = self(req, $head, $($tail,)*).await?;
|
||||||
Ok(res)
|
Ok(res)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
impl_handler!($($tail,)*);
|
||||||
#[allow(non_snake_case)]
|
};
|
||||||
impl<F, Fut, T1, T2> Handler<(T1, T2)> for F
|
|
||||||
where
|
|
||||||
F: Fn(Request<Body>, T1, T2) -> Fut + Send + Sync,
|
|
||||||
Fut: Future<Output = Result<Response<Body>, Error>> + Send,
|
|
||||||
T1: FromRequest + Send,
|
|
||||||
T2: FromRequest + Send,
|
|
||||||
{
|
|
||||||
async fn call(self, mut req: Request<Body>) -> Result<Response<Body>, Error> {
|
|
||||||
let T1 = T1::from_request(&mut req).await;
|
|
||||||
let T2 = T2::from_request(&mut req).await;
|
|
||||||
let res = self(req, T1, T2).await?;
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl_handler!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
|
||||||
|
|
||||||
pub struct HandlerSvc<H, T> {
|
pub struct HandlerSvc<H, T> {
|
||||||
handler: H,
|
handler: H,
|
||||||
_input: PhantomData<fn() -> T>,
|
_input: PhantomData<fn() -> T>,
|
||||||
@ -206,83 +255,71 @@ where
|
|||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait FromRequest: Sized {
|
pub trait FromRequest: Sized {
|
||||||
async fn from_request(req: &mut Request<Body>) -> Self;
|
async fn from_request(req: &mut Request<Body>) -> Result<Self, Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Query<T>(Result<T, QueryError>);
|
#[async_trait]
|
||||||
|
impl<T> FromRequest for Option<T>
|
||||||
|
where
|
||||||
|
T: FromRequest,
|
||||||
|
{
|
||||||
|
async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> {
|
||||||
|
Ok(T::from_request(req).await.ok())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct Query<T>(T);
|
||||||
|
|
||||||
impl<T> Query<T> {
|
impl<T> Query<T> {
|
||||||
pub fn into_inner(self) -> Result<T, QueryError> {
|
pub fn into_inner(self) -> T {
|
||||||
self.0
|
self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
|
||||||
#[non_exhaustive]
|
|
||||||
pub enum QueryError {
|
|
||||||
#[error("URI contained no query string")]
|
|
||||||
Missing,
|
|
||||||
#[error("failed to deserialize query string")]
|
|
||||||
Deserialize(#[from] serde_urlencoded::de::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<T> FromRequest for Query<T>
|
impl<T> FromRequest for Query<T>
|
||||||
where
|
where
|
||||||
T: DeserializeOwned,
|
T: DeserializeOwned,
|
||||||
{
|
{
|
||||||
async fn from_request(req: &mut Request<Body>) -> Self {
|
async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> {
|
||||||
let result = (|| {
|
let query = req.uri().query().ok_or(Error::QueryStringMissing)?;
|
||||||
let query = req.uri().query().ok_or(QueryError::Missing)?;
|
|
||||||
let value = serde_urlencoded::from_str(query)?;
|
let value = serde_urlencoded::from_str(query)?;
|
||||||
Ok(value)
|
Ok(Query(value))
|
||||||
})();
|
|
||||||
Query(result)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Json<T>(Result<T, JsonError>);
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct Json<T>(T);
|
||||||
|
|
||||||
impl<T> Json<T> {
|
impl<T> Json<T> {
|
||||||
pub fn into_inner(self) -> Result<T, JsonError> {
|
pub fn into_inner(self) -> T {
|
||||||
self.0
|
self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
|
||||||
#[non_exhaustive]
|
|
||||||
pub enum JsonError {
|
|
||||||
#[error("failed to consume the body")]
|
|
||||||
ConsumeBody(#[from] hyper::Error),
|
|
||||||
#[error("failed to deserialize the body")]
|
|
||||||
Deserialize(#[from] serde_json::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<T> FromRequest for Json<T>
|
impl<T> FromRequest for Json<T>
|
||||||
where
|
where
|
||||||
T: DeserializeOwned,
|
T: DeserializeOwned,
|
||||||
{
|
{
|
||||||
async fn from_request(req: &mut Request<Body>) -> Self {
|
async fn from_request(req: &mut Request<Body>) -> Result<Self, Error> {
|
||||||
// TODO(david): require the body to have `content-type: application/json`
|
// TODO(david): require the body to have `content-type: application/json`
|
||||||
|
|
||||||
let body = std::mem::take(req.body_mut());
|
let body = std::mem::take(req.body_mut());
|
||||||
|
|
||||||
let result = async move {
|
let bytes = hyper::body::to_bytes(body)
|
||||||
let bytes = hyper::body::to_bytes(body).await?;
|
.await
|
||||||
let value = serde_json::from_slice(&bytes)?;
|
.map_err(Error::ConsumeBody)?;
|
||||||
Ok(value)
|
let value = serde_json::from_slice(&bytes).map_err(Error::DeserializeRequestBody)?;
|
||||||
}
|
Ok(Json(value))
|
||||||
.await;
|
|
||||||
|
|
||||||
Json(result)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy)]
|
#[derive(Clone, Copy)]
|
||||||
pub struct EmptyRouter(());
|
pub struct EmptyRouter(());
|
||||||
|
|
||||||
impl Service<Request<Body>> for EmptyRouter {
|
impl<R> Service<R> for EmptyRouter {
|
||||||
type Response = Response<Body>;
|
type Response = Response<Body>;
|
||||||
type Error = Error;
|
type Error = Error;
|
||||||
type Future = future::Ready<Result<Self::Response, Self::Error>>;
|
type Future = future::Ready<Result<Self::Response, Self::Error>>;
|
||||||
@ -291,7 +328,7 @@ impl Service<Request<Body>> for EmptyRouter {
|
|||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call(&mut self, _req: Request<Body>) -> Self::Future {
|
fn call(&mut self, _req: R) -> Self::Future {
|
||||||
let mut res = Response::new(Body::empty());
|
let mut res = Response::new(Body::empty());
|
||||||
*res.status_mut() = StatusCode::NOT_FOUND;
|
*res.status_mut() = StatusCode::NOT_FOUND;
|
||||||
future::ready(Ok(res))
|
future::ready(Ok(res))
|
||||||
@ -303,6 +340,8 @@ pub struct Route<H, F> {
|
|||||||
handler: H,
|
handler: H,
|
||||||
route_spec: RouteSpec,
|
route_spec: RouteSpec,
|
||||||
fallback: F,
|
fallback: F,
|
||||||
|
handler_ready: bool,
|
||||||
|
fallback_ready: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@ -320,51 +359,76 @@ impl RouteSpec {
|
|||||||
|
|
||||||
impl<H, F> Service<Request<Body>> for Route<H, F>
|
impl<H, F> Service<Request<Body>> for Route<H, F>
|
||||||
where
|
where
|
||||||
H: Service<Request<Body>, Response = Response<Body>, Error = Error> + Clone + Send + 'static,
|
H: Service<Request<Body>, Response = Response<Body>, Error = Error>,
|
||||||
H::Future: Send,
|
F: Service<Request<Body>, Response = Response<Body>, Error = Error>,
|
||||||
F: Service<Request<Body>, Response = Response<Body>, Error = Error> + Clone + Send + 'static,
|
|
||||||
F::Future: Send,
|
|
||||||
{
|
{
|
||||||
type Response = Response<Body>;
|
type Response = Response<Body>;
|
||||||
type Error = Error;
|
type Error = Error;
|
||||||
type Future = future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
|
type Future = future::Either<H::Future, F::Future>;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
if !self.handler_ready {
|
||||||
|
ready!(self.handler.poll_ready(cx))?;
|
||||||
|
self.handler_ready = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !self.fallback_ready {
|
||||||
|
ready!(self.fallback.poll_ready(cx))?;
|
||||||
|
self.fallback_ready = true;
|
||||||
|
}
|
||||||
|
|
||||||
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
||||||
// TODO(david): do we need to drive readiness in `call`?
|
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
||||||
if self.route_spec.matches(&req) {
|
if self.route_spec.matches(&req) {
|
||||||
let handler_clone = self.handler.clone();
|
self.handler_ready = false;
|
||||||
let mut handler = std::mem::replace(&mut self.handler, handler_clone);
|
future::Either::Left(self.handler.call(req))
|
||||||
Box::pin(async move { handler.ready().await?.call(req).await })
|
|
||||||
} else {
|
} else {
|
||||||
let fallback_clone = self.fallback.clone();
|
self.fallback_ready = false;
|
||||||
let mut fallback = std::mem::replace(&mut self.fallback, fallback_clone);
|
future::Either::Right(self.fallback.call(req))
|
||||||
Box::pin(async move { fallback.ready().await?.call(req).await })
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R> Service<Request<Body>> for App<R>
|
impl<R, T> Service<T> for App<R>
|
||||||
where
|
where
|
||||||
R: Service<Request<Body>, Response = Response<Body>, Error = Error> + Clone,
|
R: Service<T>,
|
||||||
{
|
{
|
||||||
type Response = Response<Body>;
|
type Response = R::Response;
|
||||||
type Error = Error;
|
type Error = R::Error;
|
||||||
type Future = R::Future;
|
type Future = R::Future;
|
||||||
|
|
||||||
// TODO(david): handle backpressure
|
#[inline]
|
||||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
self.router.poll_ready(cx)
|
self.router.poll_ready(cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
#[inline]
|
||||||
|
fn call(&mut self, req: T) -> Self::Future {
|
||||||
self.router.call(req)
|
self.router.call(req)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<R, T> Service<T> for RouteBuilder<R>
|
||||||
|
where
|
||||||
|
App<R>: Service<T>,
|
||||||
|
{
|
||||||
|
type Response = <App<R> as Service<T>>::Response;
|
||||||
|
type Error = <App<R> as Service<T>>::Error;
|
||||||
|
type Future = <App<R> as Service<T>>::Future;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
self.app.poll_ready(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn call(&mut self, req: T) -> Self::Future {
|
||||||
|
self.app.call(req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
#![allow(warnings)]
|
#![allow(warnings)]
|
||||||
@ -377,8 +441,7 @@ mod tests {
|
|||||||
.get(root)
|
.get(root)
|
||||||
.at("/users")
|
.at("/users")
|
||||||
.get(users_index)
|
.get(users_index)
|
||||||
.post(users_create)
|
.post(users_create);
|
||||||
.into_service();
|
|
||||||
|
|
||||||
let req = Request::builder()
|
let req = Request::builder()
|
||||||
.method(Method::POST)
|
.method(Method::POST)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user