diff --git a/axum/src/boxed.rs b/axum/src/boxed.rs index 3612d211..eeb70970 100644 --- a/axum/src/boxed.rs +++ b/axum/src/boxed.rs @@ -1,6 +1,14 @@ use std::{convert::Infallible, fmt}; -use crate::{body::HttpBody, handler::Handler, routing::Route, Router}; +use http::Request; +use tower::Service; + +use crate::{ + body::HttpBody, + handler::Handler, + routing::{future::RouteFuture, Route}, + Router, +}; pub(crate) struct BoxedIntoRoute(Box>); @@ -30,6 +38,14 @@ where into_route: |router, state| Route::new(router.with_state(state)), })) } + + pub(crate) fn call_with_state( + self, + request: Request, + state: S, + ) -> RouteFuture { + self.0.call_with_state(request, state) + } } impl BoxedIntoRoute { @@ -39,7 +55,7 @@ impl BoxedIntoRoute { B: 'static, E: 'static, F: FnOnce(Route) -> Route + Clone + Send + 'static, - B2: 'static, + B2: HttpBody + 'static, E2: 'static, { BoxedIntoRoute(Box::new(Map { @@ -69,6 +85,8 @@ pub(crate) trait ErasedIntoRoute: Send { fn clone_box(&self) -> Box>; fn into_route(self: Box, state: S) -> Route; + + fn call_with_state(self: Box, request: Request, state: S) -> RouteFuture; } pub(crate) struct MakeErasedHandler { @@ -89,6 +107,14 @@ where fn into_route(self: Box, state: S) -> Route { (self.into_route)(self.handler, state) } + + fn call_with_state( + self: Box, + request: Request, + state: S, + ) -> RouteFuture { + todo!() + } } impl Clone for MakeErasedHandler @@ -110,8 +136,8 @@ pub(crate) struct MakeErasedRouter { impl ErasedIntoRoute for MakeErasedRouter where - S: Clone + Send + 'static, - B: 'static, + S: Clone + Send + Sync + 'static, + B: HttpBody + Send + 'static, { fn clone_box(&self) -> Box> { Box::new(self.clone()) @@ -120,6 +146,14 @@ where fn into_route(self: Box, state: S) -> Route { (self.into_route)(self.router, state) } + + fn call_with_state( + mut self: Box, + request: Request, + state: S, + ) -> RouteFuture { + self.router.call_with_state(request, state) + } } impl Clone for MakeErasedRouter @@ -144,7 +178,7 @@ where S: 'static, B: 'static, E: 'static, - B2: 'static, + B2: HttpBody + 'static, E2: 'static, { fn clone_box(&self) -> Box> { @@ -157,6 +191,11 @@ where fn into_route(self: Box, state: S) -> Route { (self.layer)(self.inner.into_route(state)) } + + fn call_with_state(self: Box, request: Request, state: S) -> RouteFuture { + let route = (self.layer)(self.inner.into_route(state)); + route.call(request) + } } pub(crate) trait LayerFn: FnOnce(Route) -> Route + Send { diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 7861c403..7f08f3fd 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -407,6 +407,102 @@ where pub fn with_state(self, state: S) -> RouterService { RouterService::new(self, state) } + + pub(crate) fn call_with_state( + &mut self, + mut req: Request, + state: S, + ) -> RouteFuture { + #[cfg(feature = "original-uri")] + { + use crate::extract::OriginalUri; + + if req.extensions().get::().is_none() { + let original_uri = OriginalUri(req.uri().clone()); + req.extensions_mut().insert(original_uri); + } + } + + let path = req.uri().path().to_owned(); + + match self.node.at(&path) { + Ok(match_) => { + match &self.fallback { + Fallback::Default(_) => {} + Fallback::Service(fallback) => { + req.extensions_mut() + .insert(SuperFallback(SyncWrapper::new(fallback.clone()))); + } + Fallback::BoxedHandler(fallback) => { + req.extensions_mut().insert(SuperFallback(SyncWrapper::new( + fallback.clone().into_route(state.clone()), + ))); + } + } + + self.call_route(match_, req, state) + } + Err( + MatchError::NotFound + | MatchError::ExtraTrailingSlash + | MatchError::MissingTrailingSlash, + ) => { + match &mut self.fallback { + Fallback::Default(fallback) => { + if let Some(super_fallback) = + req.extensions_mut().remove::>() + { + let mut super_fallback = super_fallback.0.into_inner(); + super_fallback.call(req) + } else { + fallback.call(req) + } + } + Fallback::Service(fallback) => fallback.call(req), + Fallback::BoxedHandler(handler) => { + todo!() + // handler.clone().into_route(state).call(req) + } + } + } + } + } + + // TODO(david): fix duplication + #[inline] + fn call_route( + &self, + match_: matchit::Match<&RouteId>, + mut req: Request, + state: S, + ) -> RouteFuture { + let id = *match_.value; + + #[cfg(feature = "matched-path")] + crate::extract::matched_path::set_matched_path_for_request( + id, + &self.node.route_id_to_path, + req.extensions_mut(), + ); + + url_params::insert_url_params(req.extensions_mut(), match_.params); + + let endpont = self + .routes + .get(&id) + .expect("no route for id. This is a bug in axum. Please file an issue") + .clone(); + + match endpont { + Endpoint::MethodRouter(method_router) => { + // method_router.call(req) + todo!() + } + Endpoint::Route(mut route) => route.call(req), + // TODO(david): optimize? + Endpoint::NestedRouter(router) => router.call_with_state(req, state), + } + } } impl Router<(), B> @@ -460,38 +556,6 @@ where ) -> IntoMakeServiceWithConnectInfo, C> { IntoMakeServiceWithConnectInfo::new(self.into_service()) } - - // TODO(david): fix duplication - #[inline] - fn call_route( - &self, - match_: matchit::Match<&RouteId>, - mut req: Request, - ) -> RouteFuture { - let id = *match_.value; - - #[cfg(feature = "matched-path")] - crate::extract::matched_path::set_matched_path_for_request( - id, - &self.node.route_id_to_path, - req.extensions_mut(), - ); - - url_params::insert_url_params(req.extensions_mut(), match_.params); - - let endpont = self - .routes - .get(&id) - .expect("no route for id. This is a bug in axum. Please file an issue") - .clone(); - - match endpont { - Endpoint::MethodRouter(mut method_router) => method_router.call(req), - Endpoint::Route(mut route) => route.call(req), - // TODO(david): optimize? - Endpoint::NestedRouter(router) => router.into_route(()).call(req), - } - } } // TODO(david): fix duplication @@ -509,58 +573,8 @@ where } #[inline] - fn call(&mut self, mut req: Request) -> Self::Future { - #[cfg(feature = "original-uri")] - { - use crate::extract::OriginalUri; - - if req.extensions().get::().is_none() { - let original_uri = OriginalUri(req.uri().clone()); - req.extensions_mut().insert(original_uri); - } - } - - let path = req.uri().path().to_owned(); - - match self.node.at(&path) { - Ok(match_) => { - match &self.fallback { - Fallback::Default(_) => {} - Fallback::Service(fallback) => { - req.extensions_mut() - .insert(SuperFallback(SyncWrapper::new(fallback.clone()))); - } - Fallback::BoxedHandler(fallback) => { - req.extensions_mut().insert(SuperFallback(SyncWrapper::new( - fallback.clone().into_route(()), - ))); - } - } - - self.call_route(match_, req) - } - Err( - MatchError::NotFound - | MatchError::ExtraTrailingSlash - | MatchError::MissingTrailingSlash, - ) => { - // - match &mut self.fallback { - Fallback::Default(fallback) => { - if let Some(super_fallback) = - req.extensions_mut().remove::>() - { - let mut super_fallback = super_fallback.0.into_inner(); - super_fallback.call(req) - } else { - fallback.call(req) - } - } - Fallback::Service(fallback) => fallback.call(req), - Fallback::BoxedHandler(handler) => handler.clone().into_route(()).call(req), - } - } - } + fn call(&mut self, req: Request) -> Self::Future { + self.call_with_state(req, ()) } } @@ -639,7 +653,7 @@ where B: 'static, E: 'static, F: FnOnce(Route) -> Route + Clone + Send + 'static, - B2: 'static, + B2: HttpBody + 'static, E2: 'static, { match self {