diff --git a/axum-macros/CHANGELOG.md b/axum-macros/CHANGELOG.md index 0af5e8f4..c1b438d9 100644 --- a/axum-macros/CHANGELOG.md +++ b/axum-macros/CHANGELOG.md @@ -7,7 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- **added:** In `debug_handler`, check if `Request` is used as non-final extractor ([#1035]) +- **added:** In `debug_handler`, check if multiple `Path` extractors are used ([#1035]) + +[#1035]: https://github.com/tokio-rs/axum/pull/1035 # 0.2.1 (10. May, 2022) diff --git a/axum-macros/Cargo.toml b/axum-macros/Cargo.toml index 73740472..96f6d91f 100644 --- a/axum-macros/Cargo.toml +++ b/axum-macros/Cargo.toml @@ -24,5 +24,6 @@ axum = { path = "../axum", version = "0.5", features = ["headers"] } axum-extra = { path = "../axum-extra", version = "0.3", features = ["typed-routing"] } rustversion = "1.0" serde = { version = "1.0", features = ["derive"] } +syn = { version = "1.0", features = ["full", "extra-traits"] } tokio = { version = "1.0", features = ["full"] } trybuild = "1.0" diff --git a/axum-macros/src/debug_handler.rs b/axum-macros/src/debug_handler.rs index e55ed081..3404cda6 100644 --- a/axum-macros/src/debug_handler.rs +++ b/axum-macros/src/debug_handler.rs @@ -2,21 +2,24 @@ use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned}; use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, Token, Type}; -pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> syn::Result { - check_extractor_count(&item_fn)?; +pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { + let check_extractor_count = check_extractor_count(&item_fn); + let check_request_last_extractor = check_request_last_extractor(&item_fn); + let check_path_extractor = check_path_extractor(&item_fn); let check_inputs_impls_from_request = check_inputs_impls_from_request(&item_fn, &attr.body_ty); let check_output_impls_into_response = check_output_impls_into_response(&item_fn); let check_future_send = check_future_send(&item_fn); - let tokens = quote! { + quote! { #item_fn + #check_extractor_count + #check_request_last_extractor + #check_path_extractor #check_inputs_impls_from_request #check_output_impls_into_response #check_future_send - }; - - Ok(tokens) + } } pub(crate) struct Attrs { @@ -45,18 +48,79 @@ impl Parse for Attrs { } } -fn check_extractor_count(item_fn: &ItemFn) -> syn::Result<()> { +fn check_extractor_count(item_fn: &ItemFn) -> Option { let max_extractors = 16; if item_fn.sig.inputs.len() <= max_extractors { - Ok(()) + None } else { - Err(syn::Error::new_spanned( - &item_fn.sig.inputs, - format!( - "Handlers cannot take more than {} arguments. Use `(a, b): (ExtractorA, ExtractorA)` to further nest extractors", - max_extractors, + let error_message = format!( + "Handlers cannot take more than {} arguments. \ + Use `(a, b): (ExtractorA, ExtractorA)` to further nest extractors", + max_extractors, + ); + let error = syn::Error::new_spanned(&item_fn.sig.inputs, error_message).to_compile_error(); + Some(error) + } +} + +fn extractor_idents(item_fn: &ItemFn) -> impl Iterator { + item_fn + .sig + .inputs + .iter() + .enumerate() + .filter_map(|(idx, fn_arg)| match fn_arg { + FnArg::Receiver(_) => None, + FnArg::Typed(pat_type) => { + if let Type::Path(type_path) = &*pat_type.ty { + type_path + .path + .segments + .last() + .map(|segment| (idx, fn_arg, &segment.ident)) + } else { + None + } + } + }) +} + +fn check_request_last_extractor(item_fn: &ItemFn) -> Option { + let request_extractor_ident = + extractor_idents(item_fn).find(|(_, _, ident)| *ident == "Request"); + + if let Some((idx, fn_arg, _)) = request_extractor_ident { + if idx != item_fn.sig.inputs.len() - 1 { + return Some( + syn::Error::new_spanned(fn_arg, "`Request` extractor should always be last") + .to_compile_error(), + ); + } + } + + None +} + +fn check_path_extractor(item_fn: &ItemFn) -> TokenStream { + let path_extractors = extractor_idents(item_fn) + .filter(|(_, _, ident)| *ident == "Path") + .collect::>(); + + if path_extractors.len() > 1 { + path_extractors + .into_iter() + .map(|(_, arg, _)| { + syn::Error::new_spanned( + arg, + "Multiple parameters must be extracted with a tuple \ + `Path<(_, _)>` or a struct `Path`, not by applying \ + multiple `Path<_>` extractors", ) - )) + .to_compile_error() + }) + .collect() + } else { + quote! {} } } diff --git a/axum-macros/src/lib.rs b/axum-macros/src/lib.rs index 22e9a99a..a44a14db 100644 --- a/axum-macros/src/lib.rs +++ b/axum-macros/src/lib.rs @@ -407,7 +407,7 @@ where fn expand_attr_with(attr: TokenStream, input: TokenStream, f: F) -> TokenStream where - F: FnOnce(A, I) -> syn::Result, + F: FnOnce(A, I) -> K, A: Parse, I: Parse, K: ToTokens, @@ -415,7 +415,7 @@ where let expand_result = (|| { let attr = syn::parse(attr)?; let input = syn::parse(input)?; - f(attr, input) + Ok(f(attr, input)) })(); expand(expand_result) } diff --git a/axum-macros/tests/debug_handler/fail/multiple_paths.rs b/axum-macros/tests/debug_handler/fail/multiple_paths.rs new file mode 100644 index 00000000..0f4fd90e --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/multiple_paths.rs @@ -0,0 +1,7 @@ +use axum::extract::Path; +use axum_macros::debug_handler; + +#[debug_handler] +async fn handler(_: Path, _: Path) {} + +fn main() {} diff --git a/axum-macros/tests/debug_handler/fail/multiple_paths.stderr b/axum-macros/tests/debug_handler/fail/multiple_paths.stderr new file mode 100644 index 00000000..3eb57c9e --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/multiple_paths.stderr @@ -0,0 +1,11 @@ +error: Multiple parameters must be extracted with a tuple `Path<(_, _)>` or a struct `Path`, not by applying multiple `Path<_>` extractors + --> tests/debug_handler/fail/multiple_paths.rs:5:18 + | +5 | async fn handler(_: Path, _: Path) {} + | ^^^^^^^^^^^^^^^ + +error: Multiple parameters must be extracted with a tuple `Path<(_, _)>` or a struct `Path`, not by applying multiple `Path<_>` extractors + --> tests/debug_handler/fail/multiple_paths.rs:5:35 + | +5 | async fn handler(_: Path, _: Path) {} + | ^^^^^^^^^^^^^^^ diff --git a/axum-macros/tests/debug_handler/fail/request_not_last.rs b/axum-macros/tests/debug_handler/fail/request_not_last.rs new file mode 100644 index 00000000..153d35ef --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/request_not_last.rs @@ -0,0 +1,7 @@ +use axum::{body::Body, extract::Extension, http::Request}; +use axum_macros::debug_handler; + +#[debug_handler] +async fn handler(_: Request, _: Extension) {} + +fn main() {} diff --git a/axum-macros/tests/debug_handler/fail/request_not_last.stderr b/axum-macros/tests/debug_handler/fail/request_not_last.stderr new file mode 100644 index 00000000..a3482e64 --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/request_not_last.stderr @@ -0,0 +1,5 @@ +error: `Request` extractor should always be last + --> tests/debug_handler/fail/request_not_last.rs:5:18 + | +5 | async fn handler(_: Request, _: Extension) {} + | ^^^^^^^^^^^^^^^^ diff --git a/axum-macros/tests/debug_handler/pass/request_last.rs b/axum-macros/tests/debug_handler/pass/request_last.rs new file mode 100644 index 00000000..bbfb53d2 --- /dev/null +++ b/axum-macros/tests/debug_handler/pass/request_last.rs @@ -0,0 +1,7 @@ +use axum::{body::Body, extract::Extension, http::Request}; +use axum_macros::debug_handler; + +#[debug_handler] +async fn handler(_: Extension, _: Request) {} + +fn main() {}