From 6a078ddb719e9a8fac72b5b69c0fce72fb7994a8 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 2 Aug 2021 22:40:33 +0200 Subject: [PATCH] Fix stripping prefix when nesting at `/` (#91) * Fix stripping prefix when nesting at `/` Fixes https://github.com/tokio-rs/axum/issues/88 * changelog --- CHANGELOG.md | 2 +- src/routing.rs | 10 ++-- src/tests.rs | 101 ++----------------------------------- src/tests/nest.rs | 123 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 135 insertions(+), 101 deletions(-) create mode 100644 src/tests/nest.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index eb7c63c2..1247e3fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -None. +- Fix stripping prefix when nesting services at `/` ([#91](https://github.com/tokio-rs/axum/pull/91)) ## Breaking changes diff --git a/src/routing.rs b/src/routing.rs index a7ed9ab2..21f79d41 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -958,15 +958,17 @@ where fn strip_prefix(uri: &Uri, prefix: &str) -> Uri { let path_and_query = if let Some(path_and_query) = uri.path_and_query() { - let mut new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) { + let new_path = if let Some(path) = path_and_query.path().strip_prefix(prefix) { path } else { path_and_query.path() }; - if new_path.is_empty() { - new_path = "/"; - } + let new_path = if new_path.starts_with('/') { + Cow::Borrowed(new_path) + } else { + Cow::Owned(format!("/{}", new_path)) + }; if let Some(query) = path_and_query.query() { Some( diff --git a/src/tests.rs b/src/tests.rs index d476109d..6d8e5973 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,10 +1,10 @@ use crate::{ - extract::RequestParts, handler::on, prelude::*, response::IntoResponse, routing::MethodFilter, - service, + extract::RequestParts, handler::on, prelude::*, response::IntoResponse, routing::nest, + routing::MethodFilter, service, }; use bytes::Bytes; use futures_util::future::Ready; -use http::{header::AUTHORIZATION, Request, Response, StatusCode}; +use http::{header::AUTHORIZATION, Request, Response, StatusCode, Uri}; use hyper::{Body, Server}; use serde::Deserialize; use serde_json::json; @@ -17,6 +17,8 @@ use std::{ use tower::{make::Shared, service_fn, BoxError, Service, ServiceBuilder}; use tower_http::{compression::CompressionLayer, trace::TraceLayer}; +mod nest; + #[tokio::test] async fn hello_world() { async fn root(_: Request) -> &'static str { @@ -521,72 +523,6 @@ async fn layer_on_whole_router() { assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); } -#[tokio::test] -async fn disjunction() { - let api_routes = route( - "/users", - get(|| async { "users#index" }).post(|| async { "users#create" }), - ) - .route( - "/users/:id", - get(|params: extract::UrlParamsMap| async move { - format!( - "{}: users#show ({})", - params.get("version").unwrap(), - params.get("id").unwrap() - ) - }), - ) - .route( - "/games/:id", - get(|params: extract::UrlParamsMap| async move { - format!( - "{}: games#show ({})", - params.get("version").unwrap(), - params.get("id").unwrap() - ) - }), - ); - - let app = route("/", get(|| async { "hi" })).nest("/:version/api", api_routes); - - let addr = run_in_background(app).await; - - let client = reqwest::Client::new(); - - let res = client - .get(format!("http://{}/", addr)) - .send() - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await.unwrap(), "hi"); - - let res = client - .get(format!("http://{}/v0/api/users", addr)) - .send() - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await.unwrap(), "users#index"); - - let res = client - .get(format!("http://{}/v0/api/users/123", addr)) - .send() - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await.unwrap(), "v0: users#show (123)"); - - let res = client - .get(format!("http://{}/v0/api/games/123", addr)) - .send() - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await.unwrap(), "v0: games#show (123)"); -} - #[tokio::test] async fn typed_header() { use extract::TypedHeader; @@ -716,33 +652,6 @@ async fn wrong_method_handler() { assert_eq!(res.status(), StatusCode::NOT_FOUND); } -#[tokio::test] -async fn wrong_method_nest() { - let nested_app = route("/", get(|| async {})); - let app = crate::routing::nest("/", nested_app); - - let addr = run_in_background(app).await; - - let client = reqwest::Client::new(); - - let res = client.get(format!("http://{}", addr)).send().await.unwrap(); - assert_eq!(res.status(), StatusCode::OK); - - let res = client - .post(format!("http://{}", addr)) - .send() - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); - - let res = client - .patch(format!("http://{}/foo", addr)) - .send() - .await - .unwrap(); - assert_eq!(res.status(), StatusCode::NOT_FOUND); -} - #[tokio::test] async fn wrong_method_service() { #[derive(Clone)] diff --git a/src/tests/nest.rs b/src/tests/nest.rs new file mode 100644 index 00000000..d1b0ac37 --- /dev/null +++ b/src/tests/nest.rs @@ -0,0 +1,123 @@ +use super::*; + +#[tokio::test] +async fn nesting_apps() { + let api_routes = route( + "/users", + get(|| async { "users#index" }).post(|| async { "users#create" }), + ) + .route( + "/users/:id", + get(|params: extract::UrlParamsMap| async move { + format!( + "{}: users#show ({})", + params.get("version").unwrap(), + params.get("id").unwrap() + ) + }), + ) + .route( + "/games/:id", + get(|params: extract::UrlParamsMap| async move { + format!( + "{}: games#show ({})", + params.get("version").unwrap(), + params.get("id").unwrap() + ) + }), + ); + + let app = route("/", get(|| async { "hi" })).nest("/:version/api", api_routes); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), "hi"); + + let res = client + .get(format!("http://{}/v0/api/users", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), "users#index"); + + let res = client + .get(format!("http://{}/v0/api/users/123", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), "v0: users#show (123)"); + + let res = client + .get(format!("http://{}/v0/api/games/123", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), "v0: games#show (123)"); +} + +#[tokio::test] +async fn wrong_method_nest() { + let nested_app = route("/", get(|| async {})); + let app = crate::routing::nest("/", nested_app); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client.get(format!("http://{}", addr)).send().await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let res = client + .post(format!("http://{}", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); + + let res = client + .patch(format!("http://{}/foo", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::NOT_FOUND); +} + +#[tokio::test] +async fn nesting_at_root() { + let app = nest("/", get(|uri: Uri| async move { uri.to_string() })); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client.get(format!("http://{}", addr)).send().await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), "/"); + + let res = client + .get(format!("http://{}/foo", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), "/foo"); + + let res = client + .get(format!("http://{}/foo/bar", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), "/foo/bar"); +}