diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index 5a327d5f..aaa628e6 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased +- **added:** Add type safe routing. See `axum_extra::routing::typed` for more details ([#756]) - **breaking:** `CachedRejection` has been removed ([#699]) - **breaking:** ` as FromRequest>::Rejection` is now `T::Rejection`. ([#699]) - **breaking:** `middleware::from_fn` has been moved into the main axum crate ([#719]) @@ -14,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#666]: https://github.com/tokio-rs/axum/pull/666 [#699]: https://github.com/tokio-rs/axum/pull/699 [#719]: https://github.com/tokio-rs/axum/pull/719 +[#756]: https://github.com/tokio-rs/axum/pull/756 # 0.1.2 (13. January, 2021) diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index 7da51d18..a2a14cca 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -11,7 +11,9 @@ repository = "https://github.com/tokio-rs/axum" version = "0.1.2" [features] -erased-json = ["serde", "serde_json"] +default = [] +erased-json = ["serde_json", "serde"] +typed-routing = ["axum-macros", "serde", "percent-encoding"] [dependencies] axum = { path = "../axum", version = "0.4" } @@ -25,11 +27,14 @@ tower-layer = "0.3" tower-service = "0.3" # optional dependencies -serde = { version = "1.0.130", optional = true } +axum-macros = { path = "../axum-macros", version = "0.1", optional = true } +serde = { version = "1.0", optional = true } serde_json = { version = "1.0.71", optional = true } +percent-encoding = { version = "2.1", optional = true } [dev-dependencies] hyper = "0.14" +serde = { version = "1.0", features = ["derive"] } tokio = { version = "1.14", features = ["full"] } tower = { version = "0.4", features = ["util"] } diff --git a/axum-extra/src/lib.rs b/axum-extra/src/lib.rs index bb16461f..144c137c 100644 --- a/axum-extra/src/lib.rs +++ b/axum-extra/src/lib.rs @@ -40,9 +40,24 @@ #![deny(unreachable_pub, private_in_public)] #![allow(elided_lifetimes_in_paths, clippy::type_complexity)] #![forbid(unsafe_code)] -#![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] #![cfg_attr(test, allow(clippy::float_cmp))] pub mod extract; pub mod response; pub mod routing; + +#[cfg(feature = "typed-routing")] +#[doc(hidden)] +pub mod __private { + //! _not_ public API + + use percent_encoding::{AsciiSet, CONTROLS}; + + pub use percent_encoding::utf8_percent_encode; + + // from https://github.com/servo/rust-url/blob/master/url/src/parser.rs + const FRAGMENT: &AsciiSet = &CONTROLS.add(b' ').add(b'"').add(b'<').add(b'>').add(b'`'); + const PATH: &AsciiSet = &FRAGMENT.add(b'#').add(b'?').add(b'{').add(b'}'); + pub const PATH_SEGMENT: &AsciiSet = &PATH.add(b'/').add(b'%'); +} diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index 112b67f4..884cc484 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -1,11 +1,20 @@ //! Additional types for defining routes. -use axum::{body::Body, Router}; +use axum::{body::Body, handler::Handler, Router}; mod resource; +#[cfg(feature = "typed-routing")] +mod typed; + pub use self::resource::Resource; +#[cfg(feature = "typed-routing")] +pub use axum_macros::TypedPath; + +#[cfg(feature = "typed-routing")] +pub use self::typed::{FirstElementIs, TypedPath}; + /// Extension trait that adds additional methods to [`Router`]. pub trait RouterExt: sealed::Sealed { /// Add the routes from `T`'s [`HasRoutes::routes`] to this router. @@ -32,6 +41,110 @@ pub trait RouterExt: sealed::Sealed { fn with(self, routes: T) -> Self where T: HasRoutes; + + /// Add a typed `GET` route to the router. + /// + /// The path will be inferred from the first argument to the handler function which must + /// implement [`TypedPath`]. + /// + /// See [`TypedPath`] for more details and examples. + #[cfg(feature = "typed-routing")] + fn typed_get(self, handler: H) -> Self + where + H: Handler, + T: FirstElementIs

+ 'static, + P: TypedPath; + + /// Add a typed `DELETE` route to the router. + /// + /// The path will be inferred from the first argument to the handler function which must + /// implement [`TypedPath`]. + /// + /// See [`TypedPath`] for more details and examples. + #[cfg(feature = "typed-routing")] + fn typed_delete(self, handler: H) -> Self + where + H: Handler, + T: FirstElementIs

+ 'static, + P: TypedPath; + + /// Add a typed `HEAD` route to the router. + /// + /// The path will be inferred from the first argument to the handler function which must + /// implement [`TypedPath`]. + /// + /// See [`TypedPath`] for more details and examples. + #[cfg(feature = "typed-routing")] + fn typed_head(self, handler: H) -> Self + where + H: Handler, + T: FirstElementIs

+ 'static, + P: TypedPath; + + /// Add a typed `OPTIONS` route to the router. + /// + /// The path will be inferred from the first argument to the handler function which must + /// implement [`TypedPath`]. + /// + /// See [`TypedPath`] for more details and examples. + #[cfg(feature = "typed-routing")] + fn typed_options(self, handler: H) -> Self + where + H: Handler, + T: FirstElementIs

+ 'static, + P: TypedPath; + + /// Add a typed `PATCH` route to the router. + /// + /// The path will be inferred from the first argument to the handler function which must + /// implement [`TypedPath`]. + /// + /// See [`TypedPath`] for more details and examples. + #[cfg(feature = "typed-routing")] + fn typed_patch(self, handler: H) -> Self + where + H: Handler, + T: FirstElementIs

+ 'static, + P: TypedPath; + + /// Add a typed `POST` route to the router. + /// + /// The path will be inferred from the first argument to the handler function which must + /// implement [`TypedPath`]. + /// + /// See [`TypedPath`] for more details and examples. + #[cfg(feature = "typed-routing")] + fn typed_post(self, handler: H) -> Self + where + H: Handler, + T: FirstElementIs

+ 'static, + P: TypedPath; + + /// Add a typed `PUT` route to the router. + /// + /// The path will be inferred from the first argument to the handler function which must + /// implement [`TypedPath`]. + /// + /// See [`TypedPath`] for more details and examples. + #[cfg(feature = "typed-routing")] + fn typed_put(self, handler: H) -> Self + where + H: Handler, + T: FirstElementIs

+ 'static, + P: TypedPath; + + /// Add a typed `TRACE` route to the router. + /// + /// The path will be inferred from the first argument to the handler function which must + /// implement [`TypedPath`]. + /// + /// See [`TypedPath`] for more details and examples. + #[cfg(feature = "typed-routing")] + fn typed_trace(self, handler: H) -> Self + where + H: Handler, + T: FirstElementIs

+ 'static, + P: TypedPath; } impl RouterExt for Router @@ -44,6 +157,86 @@ where { self.merge(routes.routes()) } + + #[cfg(feature = "typed-routing")] + fn typed_get(self, handler: H) -> Self + where + H: Handler, + T: FirstElementIs

+ 'static, + P: TypedPath, + { + self.route(P::PATH, axum::routing::get(handler)) + } + + #[cfg(feature = "typed-routing")] + fn typed_delete(self, handler: H) -> Self + where + H: Handler, + T: FirstElementIs

+ 'static, + P: TypedPath, + { + self.route(P::PATH, axum::routing::delete(handler)) + } + + #[cfg(feature = "typed-routing")] + fn typed_head(self, handler: H) -> Self + where + H: Handler, + T: FirstElementIs

+ 'static, + P: TypedPath, + { + self.route(P::PATH, axum::routing::head(handler)) + } + + #[cfg(feature = "typed-routing")] + fn typed_options(self, handler: H) -> Self + where + H: Handler, + T: FirstElementIs

+ 'static, + P: TypedPath, + { + self.route(P::PATH, axum::routing::options(handler)) + } + + #[cfg(feature = "typed-routing")] + fn typed_patch(self, handler: H) -> Self + where + H: Handler, + T: FirstElementIs

+ 'static, + P: TypedPath, + { + self.route(P::PATH, axum::routing::patch(handler)) + } + + #[cfg(feature = "typed-routing")] + fn typed_post(self, handler: H) -> Self + where + H: Handler, + T: FirstElementIs

+ 'static, + P: TypedPath, + { + self.route(P::PATH, axum::routing::post(handler)) + } + + #[cfg(feature = "typed-routing")] + fn typed_put(self, handler: H) -> Self + where + H: Handler, + T: FirstElementIs

+ 'static, + P: TypedPath, + { + self.route(P::PATH, axum::routing::put(handler)) + } + + #[cfg(feature = "typed-routing")] + fn typed_trace(self, handler: H) -> Self + where + H: Handler, + T: FirstElementIs

+ 'static, + P: TypedPath, + { + self.route(P::PATH, axum::routing::trace(handler)) + } } /// Trait for things that can provide routes. diff --git a/axum-extra/src/routing/typed.rs b/axum-extra/src/routing/typed.rs new file mode 100644 index 00000000..5c855453 --- /dev/null +++ b/axum-extra/src/routing/typed.rs @@ -0,0 +1,196 @@ +use super::sealed::Sealed; + +/// A type safe path. +/// +/// This is used to statically connect a path to its corresponding handler using +/// [`RouterExt::typed_get`], [`RouterExt::typed_post`], etc. +/// +/// # Example +/// +/// ```rust +/// use serde::Deserialize; +/// use axum::{Router, extract::Json}; +/// use axum_extra::routing::{ +/// TypedPath, +/// RouterExt, // for `Router::typed_*` +/// }; +/// +/// // A type safe route with `/users/:id` as its associated path. +/// #[derive(TypedPath, Deserialize)] +/// #[typed_path("/users/:id")] +/// struct UsersMember { +/// id: u32, +/// } +/// +/// // A regular handler function that takes `UsersMember` as the first argument +/// // and thus creates a typed connection between this handler and the `/users/:id` path. +/// // +/// // The `TypedPath` must be the first argument to the function. +/// async fn users_show( +/// UsersMember { id }: UsersMember, +/// ) { +/// // ... +/// } +/// +/// let app = Router::new() +/// // Add our typed route to the router. +/// // +/// // The path will be inferred to `/users/:id` since `users_show`'s +/// // first argument is `UsersMember` which implements `TypedPath` +/// .typed_get(users_show) +/// .typed_post(users_create) +/// .typed_delete(users_destroy); +/// +/// #[derive(TypedPath)] +/// #[typed_path("/users")] +/// struct UsersCollection; +/// +/// #[derive(Deserialize)] +/// struct UsersCreatePayload { /* ... */ } +/// +/// async fn users_create( +/// _: UsersCollection, +/// // Our handlers can accept other extractors. +/// Json(payload): Json, +/// ) { +/// // ... +/// } +/// +/// async fn users_destroy(_: UsersCollection) { /* ... */ } +/// +/// # +/// # let app: Router = app; +/// ``` +/// +/// # Using `#[derive(TypedPath)]` +/// +/// While `TypedPath` can be implemented manually, it's _highly_ recommended to derive it: +/// +/// ``` +/// use serde::Deserialize; +/// use axum_extra::routing::TypedPath; +/// +/// #[derive(TypedPath, Deserialize)] +/// #[typed_path("/users/:id")] +/// struct UsersMember { +/// id: u32, +/// } +/// ``` +/// +/// The macro expands to: +/// +/// - A `TypedPath` implementation. +/// - A [`FromRequest`] implementation compatible with [`RouterExt::typed_get`], +/// [`RouterExt::typed_post`], etc. This implementation uses [`Path`] and thus your struct must +/// also implement [`serde::Deserialize`], unless it's a unit struct. +/// - A [`Display`] implementation that interpolates the captures. This can be used to, among other +/// things, create links to known paths and have them verified statically. Note that the +/// [`Display`] implementation for each field must return something that's compatible with its +/// [`Deserialize`] implementation. +/// +/// Additionally the macro will verify the captures in the path matches the fields of the struct. +/// For example this fails to compile since the struct doesn't have a `team_id` field: +/// +/// ```compile_fail +/// use serde::Deserialize; +/// use axum_extra::routing::TypedPath; +/// +/// #[derive(TypedPath, Deserialize)] +/// #[typed_path("/users/:id/teams/:team_id")] +/// struct UsersMember { +/// id: u32, +/// } +/// ``` +/// +/// Unit and tuple structs are also supported: +/// +/// ``` +/// use serde::Deserialize; +/// use axum_extra::routing::TypedPath; +/// +/// #[derive(TypedPath)] +/// #[typed_path("/users")] +/// struct UsersCollection; +/// +/// #[derive(TypedPath, Deserialize)] +/// #[typed_path("/users/:id")] +/// struct UsersMember(u32); +/// ``` +/// +/// ## Percent encoding +/// +/// The generated [`Display`] implementation will automatically percent-encode the arguments: +/// +/// ``` +/// use serde::Deserialize; +/// use axum_extra::routing::TypedPath; +/// +/// #[derive(TypedPath, Deserialize)] +/// #[typed_path("/users/:id")] +/// struct UsersMember { +/// id: String, +/// } +/// +/// assert_eq!( +/// UsersMember { +/// id: "foo bar".to_string(), +/// }.to_string(), +/// "/users/foo%20bar", +/// ); +/// ``` +/// +/// [`FromRequest`]: axum::extract::FromRequest +/// [`RouterExt::typed_get`]: super::RouterExt::typed_get +/// [`RouterExt::typed_post`]: super::RouterExt::typed_post +/// [`Path`]: axum::extract::Path +/// [`Display`]: std::fmt::Display +/// [`Deserialize`]: serde::Deserialize +pub trait TypedPath: std::fmt::Display { + /// The path with optional captures such as `/users/:id`. + const PATH: &'static str; +} + +/// Utility trait used with [`RouterExt`] to ensure the first element of a tuple type is a +/// given type. +/// +/// If you see it in type errors its most likely because the first argument to your handler doesn't +/// implement [`TypedPath`]. +/// +/// You normally shouldn't have to use this trait directly. +/// +/// It is sealed such that it cannot be implemented outside this crate. +/// +/// [`RouterExt`]: super::RouterExt +pub trait FirstElementIs

: Sealed {} + +macro_rules! impl_first_element_is { + ( $($ty:ident),* $(,)? ) => { + impl FirstElementIs

for (P, $($ty,)*) + where + P: TypedPath + {} + + impl Sealed for (P, $($ty,)*) + where + P: TypedPath + {} + }; +} + +impl_first_element_is!(); +impl_first_element_is!(T1); +impl_first_element_is!(T1, T2); +impl_first_element_is!(T1, T2, T3); +impl_first_element_is!(T1, T2, T3, T4); +impl_first_element_is!(T1, T2, T3, T4, T5); +impl_first_element_is!(T1, T2, T3, T4, T5, T6); +impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7); +impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8); +impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9); +impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10); +impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11); +impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12); +impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13); +impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14); +impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15); +impl_first_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); diff --git a/axum-macros/CHANGELOG.md b/axum-macros/CHANGELOG.md index 22638d5c..da3338d0 100644 --- a/axum-macros/CHANGELOG.md +++ b/axum-macros/CHANGELOG.md @@ -7,8 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- Add `#[derive(TypedPath)]` for use with axum-extra's new "type safe" routing API ([#756]) # 0.1.0 (31. January, 2022) - Initial release. + +[#756]: https://github.com/tokio-rs/axum/pull/756 diff --git a/axum-macros/Cargo.toml b/axum-macros/Cargo.toml index 383fbced..4a7e02bd 100644 --- a/axum-macros/Cargo.toml +++ b/axum-macros/Cargo.toml @@ -21,6 +21,7 @@ syn = { version = "1.0", features = ["full"] } [dev-dependencies] axum = { path = "../axum", version = "0.4", features = ["headers"] } +axum-extra = { path = "../axum-extra", version = "0.1", features = ["typed-routing"] } rustversion = "1.0" serde = { version = "1.0", features = ["derive"] } tokio = { version = "1.0", features = ["full"] } diff --git a/axum-macros/src/lib.rs b/axum-macros/src/lib.rs index ae826426..22e9a99a 100644 --- a/axum-macros/src/lib.rs +++ b/axum-macros/src/lib.rs @@ -49,6 +49,7 @@ use syn::parse::Parse; mod debug_handler; mod from_request; +mod typed_path; /// Derive an implementation of [`FromRequest`]. /// @@ -385,6 +386,16 @@ pub fn debug_handler(_attr: TokenStream, input: TokenStream) -> TokenStream { return expand_attr_with(_attr, input, debug_handler::expand); } +/// Derive an implementation of [`axum_extra::routing::TypedPath`]. +/// +/// See that trait for more details. +/// +/// [`axum_extra::routing::TypedPath`]: https://docs.rs/axum-extra/latest/axum_extra/routing/trait.TypedPath.html +#[proc_macro_derive(TypedPath, attributes(typed_path))] +pub fn derive_typed_path(input: TokenStream) -> TokenStream { + expand_with(input, typed_path::expand) +} + fn expand_with(input: TokenStream, f: F) -> TokenStream where F: FnOnce(I) -> syn::Result, diff --git a/axum-macros/src/typed_path.rs b/axum-macros/src/typed_path.rs new file mode 100644 index 00000000..d7b1414f --- /dev/null +++ b/axum-macros/src/typed_path.rs @@ -0,0 +1,324 @@ +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote, quote_spanned}; +use syn::{ItemStruct, LitStr}; + +pub(crate) fn expand(item_struct: ItemStruct) -> syn::Result { + let ItemStruct { + attrs, + ident, + generics, + fields, + .. + } = &item_struct; + + if !generics.params.is_empty() || generics.where_clause.is_some() { + return Err(syn::Error::new_spanned( + generics, + "`#[derive(TypedPath)]` doesn't support generics", + )); + } + + let Attrs { path } = parse_attrs(attrs)?; + + match fields { + syn::Fields::Named(_) => { + let segments = parse_path(&path)?; + Ok(expand_named_fields(ident, path, &segments)) + } + syn::Fields::Unnamed(fields) => { + let segments = parse_path(&path)?; + expand_unnamed_fields(fields, ident, path, &segments) + } + syn::Fields::Unit => Ok(expand_unit_fields(ident, path)?), + } +} + +struct Attrs { + path: LitStr, +} + +fn parse_attrs(attrs: &[syn::Attribute]) -> syn::Result { + let mut path = None; + + for attr in attrs { + if attr.path.is_ident("typed_path") { + if path.is_some() { + return Err(syn::Error::new_spanned( + attr, + "`typed_path` specified more than once", + )); + } else { + path = Some(attr.parse_args()?); + } + } + } + + Ok(Attrs { + path: path.ok_or_else(|| { + syn::Error::new( + Span::call_site(), + "missing `#[typed_path(\"...\")]` attribute", + ) + })?, + }) +} + +fn expand_named_fields(ident: &syn::Ident, path: LitStr, segments: &[Segment]) -> TokenStream { + let format_str = format_str_from_path(segments); + let captures = captures_from_path(segments); + + let typed_path_impl = quote_spanned! {path.span()=> + #[automatically_derived] + impl ::axum_extra::routing::TypedPath for #ident { + const PATH: &'static str = #path; + } + }; + + let display_impl = quote_spanned! {path.span()=> + #[automatically_derived] + impl ::std::fmt::Display for #ident { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + let Self { #(#captures,)* } = self; + write!( + f, + #format_str, + #(#captures = ::axum_extra::__private::utf8_percent_encode(&#captures.to_string(), ::axum_extra::__private::PATH_SEGMENT)),* + ) + } + } + }; + + let from_request_impl = quote! { + #[::axum::async_trait] + #[automatically_derived] + impl ::axum::extract::FromRequest for #ident + where + B: Send, + { + type Rejection = <::axum::extract::Path as ::axum::extract::FromRequest>::Rejection; + + async fn from_request(req: &mut ::axum::extract::RequestParts) -> Result { + ::axum::extract::Path::from_request(req).await.map(|path| path.0) + } + } + }; + + quote! { + #typed_path_impl + #display_impl + #from_request_impl + } +} + +fn expand_unnamed_fields( + fields: &syn::FieldsUnnamed, + ident: &syn::Ident, + path: LitStr, + segments: &[Segment], +) -> syn::Result { + let num_captures = segments + .iter() + .filter(|segment| match segment { + Segment::Capture(_, _) => true, + Segment::Static(_) => false, + }) + .count(); + let num_fields = fields.unnamed.len(); + if num_fields != num_captures { + return Err(syn::Error::new_spanned( + fields, + format!( + "Mismatch in number of captures and fields. Path has {} but struct has {}", + simple_pluralize(num_captures, "capture"), + simple_pluralize(num_fields, "field"), + ), + )); + } + + let destructure_self = segments + .iter() + .filter_map(|segment| match segment { + Segment::Capture(capture, _) => Some(capture), + Segment::Static(_) => None, + }) + .enumerate() + .map(|(idx, capture)| { + let idx = syn::Index { + index: idx as _, + span: Span::call_site(), + }; + let capture = format_ident!("{}", capture, span = path.span()); + quote_spanned! {path.span()=> + #idx: #capture, + } + }); + + let format_str = format_str_from_path(segments); + let captures = captures_from_path(segments); + + let typed_path_impl = quote_spanned! {path.span()=> + #[automatically_derived] + impl ::axum_extra::routing::TypedPath for #ident { + const PATH: &'static str = #path; + } + }; + + let display_impl = quote_spanned! {path.span()=> + #[automatically_derived] + impl ::std::fmt::Display for #ident { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + let Self { #(#destructure_self)* } = self; + write!( + f, + #format_str, + #(#captures = ::axum_extra::__private::utf8_percent_encode(&#captures.to_string(), ::axum_extra::__private::PATH_SEGMENT)),* + ) + } + } + }; + + let from_request_impl = quote! { + #[::axum::async_trait] + #[automatically_derived] + impl ::axum::extract::FromRequest for #ident + where + B: Send, + { + type Rejection = <::axum::extract::Path as ::axum::extract::FromRequest>::Rejection; + + async fn from_request(req: &mut ::axum::extract::RequestParts) -> Result { + ::axum::extract::Path::from_request(req).await.map(|path| path.0) + } + } + }; + + Ok(quote! { + #typed_path_impl + #display_impl + #from_request_impl + }) +} + +fn simple_pluralize(count: usize, word: &str) -> String { + if count == 1 { + format!("{} {}", count, word) + } else { + format!("{} {}s", count, word) + } +} + +fn expand_unit_fields(ident: &syn::Ident, path: LitStr) -> syn::Result { + for segment in parse_path(&path)? { + match segment { + Segment::Capture(_, span) => { + return Err(syn::Error::new( + span, + "Typed paths for unit structs cannot contain captures", + )); + } + Segment::Static(_) => {} + } + } + + let typed_path_impl = quote_spanned! {path.span()=> + #[automatically_derived] + impl ::axum_extra::routing::TypedPath for #ident { + const PATH: &'static str = #path; + } + }; + + let display_impl = quote_spanned! {path.span()=> + #[automatically_derived] + impl ::std::fmt::Display for #ident { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + write!(f, #path) + } + } + }; + + let from_request_impl = quote! { + #[::axum::async_trait] + #[automatically_derived] + impl ::axum::extract::FromRequest for #ident + where + B: Send, + { + type Rejection = ::axum::http::StatusCode; + + async fn from_request(req: &mut ::axum::extract::RequestParts) -> Result { + if req.uri().path() == ::PATH { + Ok(Self) + } else { + Err(::axum::http::StatusCode::NOT_FOUND) + } + } + } + }; + + Ok(quote! { + #typed_path_impl + #display_impl + #from_request_impl + }) +} + +fn format_str_from_path(segments: &[Segment]) -> String { + segments + .iter() + .map(|segment| match segment { + Segment::Capture(capture, _) => format!("{{{}}}", capture), + Segment::Static(segment) => segment.to_owned(), + }) + .collect::>() + .join("/") +} + +fn captures_from_path(segments: &[Segment]) -> Vec { + segments + .iter() + .filter_map(|segment| match segment { + Segment::Capture(capture, span) => Some(format_ident!("{}", capture, span = *span)), + Segment::Static(_) => None, + }) + .collect::>() +} + +fn parse_path(path: &LitStr) -> syn::Result> { + path.value() + .split('/') + .map(|segment| { + if segment.contains('*') { + return Err(syn::Error::new_spanned( + path, + "`typed_path` cannot contain wildcards", + )); + } + + if let Some(capture) = segment.strip_prefix(':') { + Ok(Segment::Capture(capture.to_owned(), path.span())) + } else { + Ok(Segment::Static(segment.to_owned())) + } + }) + .collect() +} + +enum Segment { + Capture(String, Span), + Static(String), +} + +#[test] +fn ui() { + #[rustversion::stable] + fn go() { + let t = trybuild::TestCases::new(); + t.compile_fail("tests/typed_path/fail/*.rs"); + t.pass("tests/typed_path/pass/*.rs"); + } + + #[rustversion::not(stable)] + fn go() {} + + go(); +} diff --git a/axum-macros/tests/typed_path/fail/missing_capture.rs b/axum-macros/tests/typed_path/fail/missing_capture.rs new file mode 100644 index 00000000..8ecf7d45 --- /dev/null +++ b/axum-macros/tests/typed_path/fail/missing_capture.rs @@ -0,0 +1,10 @@ +use axum_macros::TypedPath; +use serde::Deserialize; + +#[derive(TypedPath, Deserialize)] +#[typed_path("/users")] +struct MyPath { + id: u32, +} + +fn main() {} diff --git a/axum-macros/tests/typed_path/fail/missing_capture.stderr b/axum-macros/tests/typed_path/fail/missing_capture.stderr new file mode 100644 index 00000000..85865a52 --- /dev/null +++ b/axum-macros/tests/typed_path/fail/missing_capture.stderr @@ -0,0 +1,14 @@ +error[E0027]: pattern does not mention field `id` + --> tests/typed_path/fail/missing_capture.rs:5:14 + | +5 | #[typed_path("/users")] + | ^^^^^^^^ missing field `id` + | +help: include the missing field in the pattern + | +5 | #[typed_path("/users" { id })] + | ++++++ +help: if you don't care about this missing field, you can explicitly ignore it + | +5 | #[typed_path("/users" { .. })] + | ++++++ diff --git a/axum-macros/tests/typed_path/fail/missing_field.rs b/axum-macros/tests/typed_path/fail/missing_field.rs new file mode 100644 index 00000000..2e211769 --- /dev/null +++ b/axum-macros/tests/typed_path/fail/missing_field.rs @@ -0,0 +1,9 @@ +use axum_macros::TypedPath; +use serde::Deserialize; + +#[derive(TypedPath, Deserialize)] +#[typed_path("/users/:id")] +struct MyPath {} + +fn main() { +} diff --git a/axum-macros/tests/typed_path/fail/missing_field.stderr b/axum-macros/tests/typed_path/fail/missing_field.stderr new file mode 100644 index 00000000..faf2d4b6 --- /dev/null +++ b/axum-macros/tests/typed_path/fail/missing_field.stderr @@ -0,0 +1,5 @@ +error[E0026]: struct `MyPath` does not have a field named `id` + --> tests/typed_path/fail/missing_field.rs:5:14 + | +5 | #[typed_path("/users/:id")] + | ^^^^^^^^^^^^ struct `MyPath` does not have this field diff --git a/axum-macros/tests/typed_path/fail/not_deserialize.rs b/axum-macros/tests/typed_path/fail/not_deserialize.rs new file mode 100644 index 00000000..b5691866 --- /dev/null +++ b/axum-macros/tests/typed_path/fail/not_deserialize.rs @@ -0,0 +1,9 @@ +use axum_macros::TypedPath; + +#[derive(TypedPath)] +#[typed_path("/users/:id")] +struct MyPath { + id: u32, +} + +fn main() {} diff --git a/axum-macros/tests/typed_path/fail/not_deserialize.stderr b/axum-macros/tests/typed_path/fail/not_deserialize.stderr new file mode 100644 index 00000000..3c6f6258 --- /dev/null +++ b/axum-macros/tests/typed_path/fail/not_deserialize.stderr @@ -0,0 +1,9 @@ +error[E0277]: the trait bound `for<'de> MyPath: serde::de::Deserialize<'de>` is not satisfied + --> tests/typed_path/fail/not_deserialize.rs:3:10 + | +3 | #[derive(TypedPath)] + | ^^^^^^^^^ the trait `for<'de> serde::de::Deserialize<'de>` is not implemented for `MyPath` + | + = note: required because of the requirements on the impl of `serde::de::DeserializeOwned` for `MyPath` + = note: required because of the requirements on the impl of `FromRequest` for `axum::extract::Path` + = note: this error originates in the derive macro `TypedPath` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/axum-macros/tests/typed_path/fail/unit_with_capture.rs b/axum-macros/tests/typed_path/fail/unit_with_capture.rs new file mode 100644 index 00000000..49979cf7 --- /dev/null +++ b/axum-macros/tests/typed_path/fail/unit_with_capture.rs @@ -0,0 +1,8 @@ +use axum_macros::TypedPath; +use serde::Deserialize; + +#[derive(TypedPath, Deserialize)] +#[typed_path("/users/:id")] +struct MyPath; + +fn main() {} diff --git a/axum-macros/tests/typed_path/fail/unit_with_capture.stderr b/axum-macros/tests/typed_path/fail/unit_with_capture.stderr new file mode 100644 index 00000000..d290308c --- /dev/null +++ b/axum-macros/tests/typed_path/fail/unit_with_capture.stderr @@ -0,0 +1,5 @@ +error: Typed paths for unit structs cannot contain captures + --> tests/typed_path/fail/unit_with_capture.rs:5:14 + | +5 | #[typed_path("/users/:id")] + | ^^^^^^^^^^^^ diff --git a/axum-macros/tests/typed_path/fail/wildcard.rs b/axum-macros/tests/typed_path/fail/wildcard.rs new file mode 100644 index 00000000..f9d64555 --- /dev/null +++ b/axum-macros/tests/typed_path/fail/wildcard.rs @@ -0,0 +1,7 @@ +use axum_extra::routing::TypedPath; + +#[derive(TypedPath)] +#[typed_path("/users/*rest")] +struct MyPath; + +fn main() {} diff --git a/axum-macros/tests/typed_path/fail/wildcard.stderr b/axum-macros/tests/typed_path/fail/wildcard.stderr new file mode 100644 index 00000000..bf897f40 --- /dev/null +++ b/axum-macros/tests/typed_path/fail/wildcard.stderr @@ -0,0 +1,5 @@ +error: `typed_path` cannot contain wildcards + --> tests/typed_path/fail/wildcard.rs:4:14 + | +4 | #[typed_path("/users/*rest")] + | ^^^^^^^^^^^^^^ diff --git a/axum-macros/tests/typed_path/pass/named_fields_struct.rs b/axum-macros/tests/typed_path/pass/named_fields_struct.rs new file mode 100644 index 00000000..6942bd33 --- /dev/null +++ b/axum-macros/tests/typed_path/pass/named_fields_struct.rs @@ -0,0 +1,25 @@ +use axum_extra::routing::TypedPath; +use serde::Deserialize; + +#[derive(TypedPath, Deserialize)] +#[typed_path("/users/:user_id/teams/:team_id")] +struct MyPath { + user_id: u32, + team_id: u32, +} + +fn main() { + axum::Router::::new().route("/", axum::routing::get(|_: MyPath| async {})); + + assert_eq!(MyPath::PATH, "/users/:user_id/teams/:team_id"); + assert_eq!( + format!( + "{}", + MyPath { + user_id: 1, + team_id: 2 + } + ), + "/users/1/teams/2" + ); +} diff --git a/axum-macros/tests/typed_path/pass/tuple_struct.rs b/axum-macros/tests/typed_path/pass/tuple_struct.rs new file mode 100644 index 00000000..a0b2e609 --- /dev/null +++ b/axum-macros/tests/typed_path/pass/tuple_struct.rs @@ -0,0 +1,13 @@ +use axum_extra::routing::TypedPath; +use serde::Deserialize; + +#[derive(TypedPath, Deserialize)] +#[typed_path("/users/:user_id/teams/:team_id")] +struct MyPath(u32, u32); + +fn main() { + axum::Router::::new().route("/", axum::routing::get(|_: MyPath| async {})); + + assert_eq!(MyPath::PATH, "/users/:user_id/teams/:team_id"); + assert_eq!(format!("{}", MyPath(1, 2)), "/users/1/teams/2"); +} diff --git a/axum-macros/tests/typed_path/pass/unit_struct.rs b/axum-macros/tests/typed_path/pass/unit_struct.rs new file mode 100644 index 00000000..9b6a0f6e --- /dev/null +++ b/axum-macros/tests/typed_path/pass/unit_struct.rs @@ -0,0 +1,13 @@ +use axum_extra::routing::TypedPath; + +#[derive(TypedPath)] +#[typed_path("/users")] +struct MyPath; + +fn main() { + axum::Router::::new() + .route("/", axum::routing::get(|_: MyPath| async {})); + + assert_eq!(MyPath::PATH, "/users"); + assert_eq!(format!("{}", MyPath), "/users"); +} diff --git a/axum-macros/tests/typed_path/pass/url_encoding.rs b/axum-macros/tests/typed_path/pass/url_encoding.rs new file mode 100644 index 00000000..db1c3700 --- /dev/null +++ b/axum-macros/tests/typed_path/pass/url_encoding.rs @@ -0,0 +1,32 @@ +use axum_extra::routing::TypedPath; +use serde::Deserialize; + +#[derive(TypedPath, Deserialize)] +#[typed_path("/:param")] +struct Named { + param: String, +} + +#[derive(TypedPath, Deserialize)] +#[typed_path("/:param")] +struct Unnamed(String); + +fn main() { + assert_eq!( + format!( + "{}", + Named { + param: "a b".to_string() + } + ), + "/a%20b" + ); + + assert_eq!( + format!( + "{}", + Unnamed("a b".to_string()), + ), + "/a%20b" + ); +} diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 7d0d2660..bf2f5663 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -60,7 +60,6 @@ impl RouteId { } /// The router type for composing handlers and services. -#[derive(Debug)] pub struct Router { routes: HashMap>, node: Node, @@ -88,6 +87,17 @@ where } } +impl fmt::Debug for Router { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Router") + .field("routes", &self.routes) + .field("node", &self.node) + .field("fallback", &self.fallback) + .field("nested_at_root", &self.nested_at_root) + .finish() + } +} + pub(crate) const NEST_TAIL_PARAM: &str = "axum_nest"; const NEST_TAIL_PARAM_CAPTURE: &str = "/*axum_nest";