diff --git a/axum-macros/src/attr_parsing.rs b/axum-macros/src/attr_parsing.rs index 1f2f16f0..1ee76de9 100644 --- a/axum-macros/src/attr_parsing.rs +++ b/axum-macros/src/attr_parsing.rs @@ -1,5 +1,8 @@ use quote::ToTokens; -use syn::parse::{Parse, ParseStream}; +use syn::{ + parse::{Parse, ParseStream}, + Token, +}; pub(crate) fn parse_parenthesized_attribute( input: ParseStream, @@ -26,6 +29,29 @@ where Ok(()) } +pub(crate) fn parse_assignment_attribute( + input: ParseStream, + out: &mut Option<(K, T)>, +) -> syn::Result<()> +where + K: Parse + ToTokens, + T: Parse, +{ + let kw = input.parse()?; + input.parse::()?; + let inner = input.parse()?; + + if out.is_some() { + let kw_name = std::any::type_name::().split("::").last().unwrap(); + let msg = format!("`{}` specified more than once", kw_name); + return Err(syn::Error::new_spanned(kw, msg)); + } + + *out = Some((kw, inner)); + + Ok(()) +} + pub(crate) trait Combine: Sized { fn combine(self, other: Self) -> syn::Result; } diff --git a/axum-macros/src/debug_handler.rs b/axum-macros/src/debug_handler.rs index 8d69129a..4ed0a5a2 100644 --- a/axum-macros/src/debug_handler.rs +++ b/axum-macros/src/debug_handler.rs @@ -1,10 +1,21 @@ -use crate::with_position::{Position, WithPosition}; +use crate::{ + attr_parsing::{parse_assignment_attribute, second}, + with_position::{Position, WithPosition}, +}; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned}; use std::collections::HashSet; -use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, Token, Type}; +use syn::{parse::Parse, parse_quote, spanned::Spanned, FnArg, ItemFn, Token, Type}; + +pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { + let Attrs { body_ty, state_ty } = attr; + + let body_ty = body_ty + .map(second) + .unwrap_or_else(|| parse_quote!(axum::body::Body)); + + let mut state_ty = state_ty.map(second); -pub(crate) fn expand(mut attr: Attrs, item_fn: ItemFn) -> TokenStream { let check_extractor_count = check_extractor_count(&item_fn); let check_path_extractor = check_path_extractor(&item_fn); let check_output_impls_into_response = check_output_impls_into_response(&item_fn); @@ -12,14 +23,14 @@ pub(crate) fn expand(mut attr: Attrs, item_fn: ItemFn) -> TokenStream { // If the function is generic, we can't reliably check its inputs or whether the future it // returns is `Send`. Skip those checks to avoid unhelpful additional compiler errors. let check_inputs_and_future_send = if item_fn.sig.generics.params.is_empty() { - if attr.state_ty.is_none() { - attr.state_ty = state_type_from_args(&item_fn); + if state_ty.is_none() { + state_ty = state_type_from_args(&item_fn); } - let state_ty = attr.state_ty.unwrap_or_else(|| syn::parse_quote!(())); + let state_ty = state_ty.unwrap_or_else(|| syn::parse_quote!(())); let check_inputs_impls_from_request = - check_inputs_impls_from_request(&item_fn, &attr.body_ty, state_ty); + check_inputs_impls_from_request(&item_fn, &body_ty, state_ty); let check_future_send = check_future_send(&item_fn); quote! { @@ -49,8 +60,8 @@ mod kw { } pub(crate) struct Attrs { - body_ty: Type, - state_ty: Option, + body_ty: Option<(kw::body, Type)>, + state_ty: Option<(kw::state, Type)>, } impl Parse for Attrs { @@ -60,27 +71,10 @@ impl Parse for Attrs { while !input.is_empty() { let lh = input.lookahead1(); - if lh.peek(kw::body) { - let kw = input.parse::()?; - if body_ty.is_some() { - return Err(syn::Error::new_spanned( - kw, - "`body` specified more than once", - )); - } - input.parse::()?; - body_ty = Some(input.parse()?); + parse_assignment_attribute(input, &mut body_ty)?; } else if lh.peek(kw::state) { - let kw = input.parse::()?; - if state_ty.is_some() { - return Err(syn::Error::new_spanned( - kw, - "`state` specified more than once", - )); - } - input.parse::()?; - state_ty = Some(input.parse()?); + parse_assignment_attribute(input, &mut state_ty)?; } else { return Err(lh.error()); } @@ -88,8 +82,6 @@ impl Parse for Attrs { let _ = input.parse::(); } - let body_ty = body_ty.unwrap_or_else(|| syn::parse_quote!(axum::body::Body)); - Ok(Self { body_ty, state_ty }) } } diff --git a/axum-macros/src/typed_path.rs b/axum-macros/src/typed_path.rs index ede7a581..6a50886f 100644 --- a/axum-macros/src/typed_path.rs +++ b/axum-macros/src/typed_path.rs @@ -2,6 +2,8 @@ use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote, quote_spanned}; use syn::{parse::Parse, ItemStruct, LitStr, Token}; +use crate::attr_parsing::{combine_attribute, parse_parenthesized_attribute, second, Combine}; + pub(crate) fn expand(item_struct: ItemStruct) -> syn::Result { let ItemStruct { attrs, @@ -18,7 +20,16 @@ pub(crate) fn expand(item_struct: ItemStruct) -> syn::Result { )); } - let Attrs { path, rejection } = parse_attrs(attrs)?; + let Attrs { path, rejection } = crate::attr_parsing::parse_attrs("typed_path", attrs)?; + + let path = path.ok_or_else(|| { + syn::Error::new( + Span::call_site(), + "Missing path: `#[typed_path(\"/foo/bar\")]`", + ) + })?; + + let rejection = rejection.map(second); match fields { syn::Fields::Named(_) => { @@ -37,52 +48,49 @@ mod kw { syn::custom_keyword!(rejection); } +#[derive(Default)] struct Attrs { - path: LitStr, - rejection: Option, + path: Option, + rejection: Option<(kw::rejection, syn::Path)>, } impl Parse for Attrs { fn parse(input: syn::parse::ParseStream) -> syn::Result { - let path = input.parse()?; + let mut path = None; + let mut rejection = None; - let rejection = if input.is_empty() { - None - } else { - let _: Token![,] = input.parse()?; - let _: kw::rejection = input.parse()?; + while !input.is_empty() { + let lh = input.lookahead1(); + if lh.peek(LitStr) { + path = Some(input.parse()?); + } else if lh.peek(kw::rejection) { + parse_parenthesized_attribute(input, &mut rejection)?; + } else { + return Err(lh.error()); + } - let content; - syn::parenthesized!(content in input); - Some(content.parse()?) - }; + let _ = input.parse::(); + } Ok(Self { path, rejection }) } } -fn parse_attrs(attrs: &[syn::Attribute]) -> syn::Result { - let mut out = None; - - for attr in attrs { - if attr.path.is_ident("typed_path") { - if out.is_some() { +impl Combine for Attrs { + fn combine(mut self, other: Self) -> syn::Result { + let Self { path, rejection } = other; + if let Some(path) = path { + if self.path.is_some() { return Err(syn::Error::new_spanned( - attr, - "`typed_path` specified more than once", + path, + "path specified more than once", )); - } else { - out = Some(attr.parse_args()?); } + self.path = Some(path); } + combine_attribute(&mut self.rejection, rejection)?; + Ok(self) } - - out.ok_or_else(|| { - syn::Error::new( - Span::call_site(), - "missing `#[typed_path(\"...\")]` attribute", - ) - }) } fn expand_named_fields(