From bc27b09f5ca54a6d4dfae9faaccc5b22d154bc28 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sun, 8 Aug 2021 17:27:23 +0200 Subject: [PATCH] Make sure nested services still see full URI (#166) They'd previously see the nested URI as we mutated the request. Now we always route based on the nested URI (if present) without mutating the request. Also meant we could get rid of `OriginalUri` which is nice. --- src/extract/mod.rs | 8 ++------ src/routing/mod.rs | 46 ++++++++++++++++++++++++---------------------- src/tests/nest.rs | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 28 deletions(-) diff --git a/src/extract/mod.rs b/src/extract/mod.rs index faa4ddbb..635bfd89 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -244,7 +244,7 @@ //! //! [`body::Body`]: crate::body::Body -use crate::{response::IntoResponse, routing::OriginalUri}; +use crate::response::IntoResponse; use async_trait::async_trait; use http::{header, Extensions, HeaderMap, Method, Request, Uri, Version}; use rejection::*; @@ -378,7 +378,7 @@ impl RequestParts { let ( http::request::Parts { method, - mut uri, + uri, version, headers, extensions, @@ -387,10 +387,6 @@ impl RequestParts { body, ) = req.into_parts(); - if let Some(original_uri) = extensions.get::() { - uri = original_uri.0.clone(); - }; - RequestParts { method, uri, diff --git a/src/routing/mod.rs b/src/routing/mod.rs index c4c16209..3d6d4e90 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -459,7 +459,7 @@ where } fn call(&mut self, mut req: Request) -> Self::Future { - if let Some(captures) = self.pattern.full_match(req.uri().path()) { + if let Some(captures) = self.pattern.full_match(&req) { insert_url_params(&mut req, captures); let fut = self.svc.clone().oneshot(req); RouteFuture::a(fut) @@ -606,8 +606,8 @@ impl PathPattern { })) } - pub(crate) fn full_match(&self, path: &str) -> Option { - self.do_match(path).and_then(|match_| { + pub(crate) fn full_match(&self, req: &Request) -> Option { + self.do_match(req).and_then(|match_| { if match_.full_match { Some(match_.captures) } else { @@ -616,12 +616,18 @@ impl PathPattern { }) } - pub(crate) fn prefix_match<'a>(&self, path: &'a str) -> Option<(&'a str, Captures)> { - self.do_match(path) + pub(crate) fn prefix_match<'a, B>(&self, req: &'a Request) -> Option<(&'a str, Captures)> { + self.do_match(req) .map(|match_| (match_.matched, match_.captures)) } - fn do_match<'a>(&self, path: &'a str) -> Option> { + fn do_match<'a, B>(&self, req: &'a Request) -> Option> { + let path = if let Some(nested_uri) = req.extensions().get::() { + nested_uri.0.path() + } else { + req.uri().path() + }; + self.0.full_path_regex.captures(path).map(|captures| { let matched = captures.get(0).unwrap(); let full_match = matched.as_str() == path; @@ -864,16 +870,15 @@ where } fn call(&mut self, mut req: Request) -> Self::Future { - if req.extensions().get::().is_none() { - let original_uri = OriginalUri(req.uri().clone()); - req.extensions_mut().insert(original_uri); - } + let f = if let Some((prefix, captures)) = self.pattern.prefix_match(&req) { + let uri = if let Some(nested_uri) = req.extensions().get::() { + &nested_uri.0 + } else { + req.uri() + }; - let f = if let Some((prefix, captures)) = self.pattern.prefix_match(req.uri().path()) { - let without_prefix = strip_prefix(req.uri(), prefix); - req.extensions_mut() - .insert(NestedUri(without_prefix.clone())); - *req.uri_mut() = without_prefix; + let without_prefix = strip_prefix(uri, prefix); + req.extensions_mut().insert(NestedUri(without_prefix)); insert_url_params(&mut req, captures); let fut = self.svc.clone().oneshot(req); @@ -887,11 +892,6 @@ where } } -/// `Nested` changes the incoming requests URI. This will be saved as an -/// extension so extractors can still access the original URI. -#[derive(Clone)] -pub(crate) struct OriginalUri(pub(crate) Uri); - fn strip_prefix(uri: &Uri, prefix: &str) -> Uri { let path_and_query = if let Some(path_and_query) = uri.path_and_query() { let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) { @@ -953,8 +953,9 @@ mod tests { fn assert_match(route_spec: &'static str, path: &'static str) { let route = PathPattern::new(route_spec); + let req = Request::builder().uri(path).body(()).unwrap(); assert!( - route.full_match(path).is_some(), + route.full_match(&req).is_some(), "`{}` doesn't match `{}`", path, route_spec @@ -963,8 +964,9 @@ mod tests { fn refute_match(route_spec: &'static str, path: &'static str) { let route = PathPattern::new(route_spec); + let req = Request::builder().uri(path).body(()).unwrap(); assert!( - route.full_match(path).is_none(), + route.full_match(&req).is_none(), "`{}` did match `{}` (but shouldn't)", path, route_spec diff --git a/src/tests/nest.rs b/src/tests/nest.rs index bf2dea24..bde77a42 100644 --- a/src/tests/nest.rs +++ b/src/tests/nest.rs @@ -1,4 +1,5 @@ use super::*; +use crate::body::box_body; use std::collections::HashMap; #[tokio::test] @@ -149,6 +150,7 @@ async fn nested_url_extractor() { .send() .await .unwrap(); + assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await.unwrap(), "/foo/bar/baz"); let res = client @@ -156,6 +158,7 @@ async fn nested_url_extractor() { .send() .await .unwrap(); + assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await.unwrap(), "/foo/bar/qux"); } @@ -181,5 +184,35 @@ async fn nested_url_nested_extractor() { .send() .await .unwrap(); + assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await.unwrap(), "/baz"); } + +#[tokio::test] +async fn nested_service_sees_original_uri() { + let app = nest( + "/foo", + nest( + "/bar", + route( + "/baz", + service_fn(|req: Request| async move { + let body = box_body(Body::from(req.uri().to_string())); + Ok::<_, Infallible>(Response::new(body)) + }), + ), + ), + ); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/foo/bar/baz", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), "/foo/bar/baz"); +}