diff --git a/src/extract/mod.rs b/src/extract/mod.rs index faa4ddbb..635bfd89 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -244,7 +244,7 @@ //! //! [`body::Body`]: crate::body::Body -use crate::{response::IntoResponse, routing::OriginalUri}; +use crate::response::IntoResponse; use async_trait::async_trait; use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version}; use rejection::*; @@ -378,7 +378,7 @@ impl RequestParts { let ( http::request::Parts { method, - mut uri, + uri, version, headers, extensions, @@ -387,10 +387,6 @@ impl RequestParts { body, ) = req.into_parts(); - if let Some(original_uri) = extensions.get::() { - uri = original_uri.0.clone(); - }; - RequestParts { method, uri, diff --git a/src/routing/mod.rs b/src/routing/mod.rs index c4c16209..3d6d4e90 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -459,7 +459,7 @@ where } fn call(&mut self, mut req: Request) -> Self::Future { - if let Some(captures) = self.pattern.full_match(req.uri().path()) { + if let Some(captures) = self.pattern.full_match(&req) { insert_url_params(&mut req, captures); let fut = self.svc.clone().oneshot(req); RouteFuture::a(fut) @@ -606,8 +606,8 @@ impl PathPattern { })) } - pub(crate) fn full_match(&self, path: &str) -> Option { - self.do_match(path).and_then(|match_| { + pub(crate) fn full_match(&self, req: &Request) -> Option { + self.do_match(req).and_then(|match_| { if match_.full_match { Some(match_.captures) } else { @@ -616,12 +616,18 @@ impl PathPattern { }) } - pub(crate) fn prefix_match<'a>(&self, path: &'a str) -> Option<(&'a str, Captures)> { - self.do_match(path) + pub(crate) fn prefix_match<'a, B>(&self, req: &'a Request) -> Option<(&'a str, Captures)> { + self.do_match(req) .map(|match_| (match_.matched, match_.captures)) } - fn do_match<'a>(&self, path: &'a str) -> Option> { + fn do_match<'a, B>(&self, req: &'a Request) -> Option> { + let path = if let Some(nested_uri) = req.extensions().get::() { + nested_uri.0.path() + } else { + req.uri().path() + }; + self.0.full_path_regex.captures(path).map(|captures| { let matched = captures.get(0).unwrap(); let full_match = matched.as_str() == path; @@ -864,16 +870,15 @@ where } fn call(&mut self, mut req: Request) -> Self::Future { - if req.extensions().get::().is_none() { - let original_uri = OriginalUri(req.uri().clone()); - req.extensions_mut().insert(original_uri); - } + let f = if let Some((prefix, captures)) = self.pattern.prefix_match(&req) { + let uri = if let Some(nested_uri) = req.extensions().get::() { + &nested_uri.0 + } else { + req.uri() + }; - let f = if let Some((prefix, captures)) = self.pattern.prefix_match(req.uri().path()) { - let without_prefix = strip_prefix(req.uri(), prefix); - req.extensions_mut() - .insert(NestedUri(without_prefix.clone())); - *req.uri_mut() = without_prefix; + let without_prefix = strip_prefix(uri, prefix); + req.extensions_mut().insert(NestedUri(without_prefix)); insert_url_params(&mut req, captures); let fut = self.svc.clone().oneshot(req); @@ -887,11 +892,6 @@ where } } -/// `Nested` changes the incoming requests URI. This will be saved as an -/// extension so extractors can still access the original URI. -#[derive(Clone)] -pub(crate) struct OriginalUri(pub(crate) Uri); - fn strip_prefix(uri: &Uri, prefix: &str) -> Uri { let path_and_query = if let Some(path_and_query) = uri.path_and_query() { let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) { @@ -953,8 +953,9 @@ mod tests { fn assert_match(route_spec: &'static str, path: &'static str) { let route = PathPattern::new(route_spec); + let req = Request::builder().uri(path).body(()).unwrap(); assert!( - route.full_match(path).is_some(), + route.full_match(&req).is_some(), "`{}` doesn't match `{}`", path, route_spec @@ -963,8 +964,9 @@ mod tests { fn refute_match(route_spec: &'static str, path: &'static str) { let route = PathPattern::new(route_spec); + let req = Request::builder().uri(path).body(()).unwrap(); assert!( - route.full_match(path).is_none(), + route.full_match(&req).is_none(), "`{}` did match `{}` (but shouldn't)", path, route_spec diff --git a/src/tests/nest.rs b/src/tests/nest.rs index bf2dea24..bde77a42 100644 --- a/src/tests/nest.rs +++ b/src/tests/nest.rs @@ -1,4 +1,5 @@ use super::*; +use crate::body::box_body; use std::collections::HashMap; #[tokio::test] @@ -149,6 +150,7 @@ async fn nested_url_extractor() { .send() .await .unwrap(); + assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await.unwrap(), "/foo/bar/baz"); let res = client @@ -156,6 +158,7 @@ async fn nested_url_extractor() { .send() .await .unwrap(); + assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await.unwrap(), "/foo/bar/qux"); } @@ -181,5 +184,35 @@ async fn nested_url_nested_extractor() { .send() .await .unwrap(); + assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await.unwrap(), "/baz"); } + +#[tokio::test] +async fn nested_service_sees_original_uri() { + let app = nest( + "/foo", + nest( + "/bar", + route( + "/baz", + service_fn(|req: Request| async move { + let body = box_body(Body::from(req.uri().to_string())); + Ok::<_, Infallible>(Response::new(body)) + }), + ), + ), + ); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/foo/bar/baz", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), "/foo/bar/baz"); +}