From 02a035fb140dd336d3bb7482f733143beef48867 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 25 Oct 2021 23:38:29 +0200 Subject: [PATCH] Add `MatchedPath` extractor (#412) Fixes #386 --- CHANGELOG.md | 3 ++ Cargo.toml | 1 + src/extract/matched_path.rs | 86 +++++++++++++++++++++++++++++++++++++ src/extract/mod.rs | 2 + src/extract/rejection.rs | 17 ++++++++ src/routing/mod.rs | 11 +++-- src/tests/mod.rs | 22 ++++++++++ 7 files changed, 139 insertions(+), 3 deletions(-) create mode 100644 src/extract/matched_path.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 67fd703a..6c22c799 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -136,6 +136,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 without trailing a slash. - **breaking:** `EmptyRouter` has been renamed to `MethodNotAllowed` as its only used in method routers and not in path routers (`Router`) +- **added:** Add `extract::MatchedPath` for accessing path in router that + matched request ([#412]) [#339]: https://github.com/tokio-rs/axum/pull/339 [#286]: https://github.com/tokio-rs/axum/pull/286 @@ -147,6 +149,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#404]: https://github.com/tokio-rs/axum/pull/404 [#405]: https://github.com/tokio-rs/axum/pull/405 [#408]: https://github.com/tokio-rs/axum/pull/408 +[#412]: https://github.com/tokio-rs/axum/pull/412 # 0.2.8 (07. October, 2021) diff --git a/Cargo.toml b/Cargo.toml index 6a6aec92..ecd331b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,6 +60,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" tokio = { version = "1.6.1", features = ["macros", "rt", "rt-multi-thread", "net"] } tokio-stream = "0.1" +tracing = "0.1" uuid = { version = "0.8", features = ["serde", "v4"] } [dev-dependencies.tower] diff --git a/src/extract/matched_path.rs b/src/extract/matched_path.rs new file mode 100644 index 00000000..80947612 --- /dev/null +++ b/src/extract/matched_path.rs @@ -0,0 +1,86 @@ +use super::{rejection::*, FromRequest, RequestParts}; +use async_trait::async_trait; +use std::sync::Arc; + +/// Access the path in the router that matches the request. +/// +/// ``` +/// use axum::{ +/// Router, +/// extract::MatchedPath, +/// routing::get, +/// }; +/// +/// let app = Router::new().route( +/// "/users/:id", +/// get(|path: MatchedPath| async move { +/// let path = path.as_str(); +/// // `path` will be "/users/:id" +/// }) +/// ); +/// # async { +/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +/// +/// `MatchedPath` can also be accessed from middleware via request extensions. +/// This is useful for example with [`Trace`](tower_http::trace::Trace) to +/// create a span that contains the matched path: +/// +/// ``` +/// use axum::{ +/// Router, +/// extract::MatchedPath, +/// http::Request, +/// routing::get, +/// }; +/// use tower_http::trace::TraceLayer; +/// +/// let app = Router::new() +/// .route("/users/:id", get(|| async { /* ... */ })) +/// .layer( +/// TraceLayer::new_for_http().make_span_with(|req: &Request<_>| { +/// let path = if let Some(path) = req.extensions().get::() { +/// path.as_str() +/// } else { +/// req.uri().path() +/// }; +/// tracing::info_span!("http-request", %path) +/// }), +/// ); +/// # async { +/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +#[derive(Clone, Debug)] +pub struct MatchedPath(pub(crate) Arc); + +impl MatchedPath { + /// Returns a `str` representation of the path. + pub fn as_str(&self) -> &str { + &*self.0 + } +} + +#[async_trait] +impl FromRequest for MatchedPath +where + B: Send, +{ + type Rejection = MatchedPathRejection; + + async fn from_request(req: &mut RequestParts) -> Result { + let extensions = + req.extensions() + .ok_or(MatchedPathRejection::ExtensionsAlreadyExtracted( + ExtensionsAlreadyExtracted, + ))?; + + let matched_path = extensions + .get::() + .ok_or(MatchedPathRejection::MatchedPathMissing(MatchedPathMissing))? + .clone(); + + Ok(matched_path) + } +} diff --git a/src/extract/mod.rs b/src/extract/mod.rs index 0ab95e09..97869d3a 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -172,6 +172,7 @@ pub mod ws; mod content_length_limit; mod extension; mod form; +mod matched_path; mod path; mod query; mod raw_query; @@ -186,6 +187,7 @@ pub use self::{ extension::Extension, extractor_middleware::extractor_middleware, form::Form, + matched_path::MatchedPath, path::Path, query::Query, raw_query::RawQuery, diff --git a/src/extract/rejection.rs b/src/extract/rejection.rs index 856146b3..a65520b8 100644 --- a/src/extract/rejection.rs +++ b/src/extract/rejection.rs @@ -273,6 +273,23 @@ composite_rejection! { } } +define_rejection! { + #[status = INTERNAL_SERVER_ERROR] + #[body = "No matched path found"] + /// Rejection if no matched path could be found. + /// + /// See [`MatchedPath`](super::MatchedPath) for more details. + pub struct MatchedPathMissing; +} + +composite_rejection! { + /// Rejection used for [`MatchedPath`](super::MatchedPath). + pub enum MatchedPathRejection { + ExtensionsAlreadyExtracted, + MatchedPathMissing, + } +} + /// Rejection used for [`ContentLengthLimit`](super::ContentLengthLimit). /// /// Contains one variant for each way the diff --git a/src/routing/mod.rs b/src/routing/mod.rs index 3a6403ad..a4ae77d4 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -738,6 +738,11 @@ where let id = *match_.value; req.extensions_mut().insert(id); + if let Some(matched_path) = self.node.paths.get(&id) { + req.extensions_mut() + .insert(crate::extract::MatchedPath(matched_path.clone())); + } + let params = match_ .params .iter() @@ -1059,7 +1064,7 @@ impl Service> for Route { #[derive(Clone, Default)] struct Node { inner: matchit::Node, - paths: Vec<(Arc, RouteId)>, + paths: HashMap>, } impl Node { @@ -1070,12 +1075,12 @@ impl Node { ) -> Result<(), matchit::InsertError> { let path = path.into(); self.inner.insert(&path, val)?; - self.paths.push((path.into(), val)); + self.paths.insert(val, path.into()); Ok(()) } fn merge(&mut self, other: Node) -> Result<(), matchit::InsertError> { - for (path, id) in other.paths { + for (id, path) in other.paths { self.insert(&*path, id)?; } Ok(()) diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 4df09fe0..b7513cf6 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -1,6 +1,7 @@ #![allow(clippy::blacklisted_name)] use crate::error_handling::HandleErrorLayer; +use crate::extract::MatchedPath; use crate::BoxError; use crate::{ extract::{self, Path}, @@ -27,6 +28,7 @@ use std::{ }; use tower::{service_fn, timeout::TimeoutLayer, ServiceBuilder}; use tower_http::auth::RequireAuthorizationLayer; +use tower_http::trace::TraceLayer; use tower_service::Service; pub(crate) use helpers::*; @@ -618,6 +620,26 @@ async fn with_and_without_trailing_slash() { assert_eq!(res.text().await, "without tsr"); } +#[tokio::test] +async fn access_matched_path() { + let app = Router::new() + .route( + "/:key", + get(|path: MatchedPath| async move { path.as_str().to_string() }), + ) + .layer( + TraceLayer::new_for_http().make_span_with(|req: &Request<_>| { + let path = req.extensions().get::().unwrap().as_str(); + tracing::info_span!("http-request", %path) + }), + ); + + let client = TestClient::new(app); + + let res = client.get("/foo").send().await; + assert_eq!(res.text().await, "/:key"); +} + pub(crate) fn assert_send() {} pub(crate) fn assert_sync() {} pub(crate) fn assert_unpin() {}