mirror of
https://github.com/tokio-rs/axum.git
synced 2025-10-02 23:34:47 +00:00
Support changing state type in #[debug_handler]
(#1271)
* support setting body type for #[debug_handler] * Use lookahead1 to give better errors and detect duplicate arguments * fix docs link
This commit is contained in:
parent
e7f1c88cd4
commit
568394a28e
@ -18,7 +18,11 @@ proc-macro = true
|
|||||||
heck = "0.4"
|
heck = "0.4"
|
||||||
proc-macro2 = "1.0"
|
proc-macro2 = "1.0"
|
||||||
quote = "1.0"
|
quote = "1.0"
|
||||||
syn = { version = "1.0", features = ["full"] }
|
syn = { version = "1.0", features = [
|
||||||
|
"full",
|
||||||
|
# needed for `Hash` impls
|
||||||
|
"extra-traits",
|
||||||
|
] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
axum = { path = "../axum", version = "0.5", features = ["headers"] }
|
axum = { path = "../axum", version = "0.5", features = ["headers"] }
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
use proc_macro2::TokenStream;
|
use proc_macro2::TokenStream;
|
||||||
use quote::{format_ident, quote, quote_spanned};
|
use quote::{format_ident, quote, quote_spanned};
|
||||||
use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, Token, Type};
|
use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, Token, Type};
|
||||||
|
|
||||||
pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream {
|
pub(crate) fn expand(mut attr: Attrs, item_fn: ItemFn) -> TokenStream {
|
||||||
let check_extractor_count = check_extractor_count(&item_fn);
|
let check_extractor_count = check_extractor_count(&item_fn);
|
||||||
let check_request_last_extractor = check_request_last_extractor(&item_fn);
|
let check_request_last_extractor = check_request_last_extractor(&item_fn);
|
||||||
let check_path_extractor = check_path_extractor(&item_fn);
|
let check_path_extractor = check_path_extractor(&item_fn);
|
||||||
@ -12,8 +14,14 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream {
|
|||||||
// If the function is generic, we can't reliably check its inputs or whether the future it
|
// If the function is generic, we can't reliably check its inputs or whether the future it
|
||||||
// returns is `Send`. Skip those checks to avoid unhelpful additional compiler errors.
|
// returns is `Send`. Skip those checks to avoid unhelpful additional compiler errors.
|
||||||
let check_inputs_and_future_send = if item_fn.sig.generics.params.is_empty() {
|
let check_inputs_and_future_send = if item_fn.sig.generics.params.is_empty() {
|
||||||
|
if attr.state_ty.is_none() {
|
||||||
|
attr.state_ty = state_type_from_args(&item_fn);
|
||||||
|
}
|
||||||
|
|
||||||
|
let state_ty = attr.state_ty.unwrap_or_else(|| syn::parse_quote!(()));
|
||||||
|
|
||||||
let check_inputs_impls_from_request =
|
let check_inputs_impls_from_request =
|
||||||
check_inputs_impls_from_request(&item_fn, &attr.body_ty);
|
check_inputs_impls_from_request(&item_fn, &attr.body_ty, state_ty);
|
||||||
let check_future_send = check_future_send(&item_fn);
|
let check_future_send = check_future_send(&item_fn);
|
||||||
|
|
||||||
quote! {
|
quote! {
|
||||||
@ -39,21 +47,46 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mod kw {
|
||||||
|
syn::custom_keyword!(body);
|
||||||
|
syn::custom_keyword!(state);
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) struct Attrs {
|
pub(crate) struct Attrs {
|
||||||
body_ty: Type,
|
body_ty: Type,
|
||||||
|
state_ty: Option<Type>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Parse for Attrs {
|
impl Parse for Attrs {
|
||||||
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
|
||||||
let mut body_ty = None;
|
let mut body_ty = None;
|
||||||
|
let mut state_ty = None;
|
||||||
|
|
||||||
while !input.is_empty() {
|
while !input.is_empty() {
|
||||||
let ident = input.parse::<syn::Ident>()?;
|
let lh = input.lookahead1();
|
||||||
if ident == "body" {
|
|
||||||
|
if lh.peek(kw::body) {
|
||||||
|
let kw = input.parse::<kw::body>()?;
|
||||||
|
if body_ty.is_some() {
|
||||||
|
return Err(syn::Error::new_spanned(
|
||||||
|
kw,
|
||||||
|
"`body` specified more than once",
|
||||||
|
));
|
||||||
|
}
|
||||||
input.parse::<Token![=]>()?;
|
input.parse::<Token![=]>()?;
|
||||||
body_ty = Some(input.parse()?);
|
body_ty = Some(input.parse()?);
|
||||||
|
} else if lh.peek(kw::state) {
|
||||||
|
let kw = input.parse::<kw::state>()?;
|
||||||
|
if state_ty.is_some() {
|
||||||
|
return Err(syn::Error::new_spanned(
|
||||||
|
kw,
|
||||||
|
"`state` specified more than once",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
input.parse::<Token![=]>()?;
|
||||||
|
state_ty = Some(input.parse()?);
|
||||||
} else {
|
} else {
|
||||||
return Err(syn::Error::new_spanned(ident, "unknown argument"));
|
return Err(lh.error());
|
||||||
}
|
}
|
||||||
|
|
||||||
let _ = input.parse::<Token![,]>();
|
let _ = input.parse::<Token![,]>();
|
||||||
@ -61,7 +94,7 @@ impl Parse for Attrs {
|
|||||||
|
|
||||||
let body_ty = body_ty.unwrap_or_else(|| syn::parse_quote!(axum::body::Body));
|
let body_ty = body_ty.unwrap_or_else(|| syn::parse_quote!(axum::body::Body));
|
||||||
|
|
||||||
Ok(Self { body_ty })
|
Ok(Self { body_ty, state_ty })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -167,7 +200,11 @@ fn check_multiple_body_extractors(item_fn: &ItemFn) -> TokenStream {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_inputs_impls_from_request(item_fn: &ItemFn, body_ty: &Type) -> TokenStream {
|
fn check_inputs_impls_from_request(
|
||||||
|
item_fn: &ItemFn,
|
||||||
|
body_ty: &Type,
|
||||||
|
state_ty: Type,
|
||||||
|
) -> TokenStream {
|
||||||
item_fn
|
item_fn
|
||||||
.sig
|
.sig
|
||||||
.inputs
|
.inputs
|
||||||
@ -203,7 +240,7 @@ fn check_inputs_impls_from_request(item_fn: &ItemFn, body_ty: &Type) -> TokenStr
|
|||||||
#[allow(warnings)]
|
#[allow(warnings)]
|
||||||
fn #name()
|
fn #name()
|
||||||
where
|
where
|
||||||
#ty: ::axum::extract::FromRequest<(), #body_ty> + Send,
|
#ty: ::axum::extract::FromRequest<#state_ty, #body_ty> + Send,
|
||||||
{}
|
{}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -371,6 +408,68 @@ fn self_receiver(item_fn: &ItemFn) -> Option<TokenStream> {
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Given a signature like
|
||||||
|
///
|
||||||
|
/// ```skip
|
||||||
|
/// #[debug_handler]
|
||||||
|
/// async fn handler(
|
||||||
|
/// _: axum::extract::State<AppState>,
|
||||||
|
/// _: State<AppState>,
|
||||||
|
/// ) {}
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// This will extract `AppState`.
|
||||||
|
///
|
||||||
|
/// Returns `None` if there are no `State` args or multiple of different types.
|
||||||
|
fn state_type_from_args(item_fn: &ItemFn) -> Option<Type> {
|
||||||
|
let state_inputs = item_fn
|
||||||
|
.sig
|
||||||
|
.inputs
|
||||||
|
.iter()
|
||||||
|
.filter_map(|input| match input {
|
||||||
|
FnArg::Receiver(_) => None,
|
||||||
|
FnArg::Typed(pat_type) => Some(pat_type),
|
||||||
|
})
|
||||||
|
.map(|pat_type| &pat_type.ty)
|
||||||
|
.filter_map(|ty| {
|
||||||
|
if let Type::Path(path) = &**ty {
|
||||||
|
Some(&path.path)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.filter_map(|path| {
|
||||||
|
if let Some(last_segment) = path.segments.last() {
|
||||||
|
if last_segment.ident != "State" {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
match &last_segment.arguments {
|
||||||
|
syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => {
|
||||||
|
Some(args.args.first().unwrap())
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.filter_map(|generic_arg| {
|
||||||
|
if let syn::GenericArgument::Type(ty) = generic_arg {
|
||||||
|
Some(ty)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<HashSet<_>>();
|
||||||
|
|
||||||
|
if state_inputs.len() == 1 {
|
||||||
|
state_inputs.iter().next().map(|&ty| ty.clone())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn ui() {
|
fn ui() {
|
||||||
#[rustversion::stable]
|
#[rustversion::stable]
|
||||||
|
@ -513,12 +513,60 @@ pub fn derive_from_request(item: TokenStream) -> TokenStream {
|
|||||||
/// async fn handler(request: Request<BoxBody>) {}
|
/// async fn handler(request: Request<BoxBody>) {}
|
||||||
/// ```
|
/// ```
|
||||||
///
|
///
|
||||||
|
/// # Changing state type
|
||||||
|
///
|
||||||
|
/// By default `#[debug_handler]` assumes your state type is `()` unless your handler has a
|
||||||
|
/// [`axum::extract::State`] argument:
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// use axum::extract::State;
|
||||||
|
/// # use axum_macros::debug_handler;
|
||||||
|
///
|
||||||
|
/// #[debug_handler]
|
||||||
|
/// async fn handler(
|
||||||
|
/// // this makes `#[debug_handler]` use `AppState`
|
||||||
|
/// State(state): State<AppState>,
|
||||||
|
/// ) {}
|
||||||
|
///
|
||||||
|
/// #[derive(Clone)]
|
||||||
|
/// struct AppState {}
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// If your handler takes multiple [`axum::extract::State`] arguments or you need to otherwise
|
||||||
|
/// customize the state type you can set it with `#[debug_handler(state = ...)]`:
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// use axum::extract::{State, FromRef};
|
||||||
|
/// # use axum_macros::debug_handler;
|
||||||
|
///
|
||||||
|
/// #[debug_handler(state = AppState)]
|
||||||
|
/// async fn handler(
|
||||||
|
/// State(app_state): State<AppState>,
|
||||||
|
/// State(inner_state): State<InnerState>,
|
||||||
|
/// ) {}
|
||||||
|
///
|
||||||
|
/// #[derive(Clone)]
|
||||||
|
/// struct AppState {
|
||||||
|
/// inner: InnerState,
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// #[derive(Clone)]
|
||||||
|
/// struct InnerState {}
|
||||||
|
///
|
||||||
|
/// impl FromRef<AppState> for InnerState {
|
||||||
|
/// fn from_ref(state: &AppState) -> Self {
|
||||||
|
/// state.inner.clone()
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
/// # Performance
|
/// # Performance
|
||||||
///
|
///
|
||||||
/// This macro has no effect when compiled with the release profile. (eg. `cargo build --release`)
|
/// This macro has no effect when compiled with the release profile. (eg. `cargo build --release`)
|
||||||
///
|
///
|
||||||
/// [`axum`]: https://docs.rs/axum/latest
|
/// [`axum`]: https://docs.rs/axum/latest
|
||||||
/// [`Handler`]: https://docs.rs/axum/latest/axum/handler/trait.Handler.html
|
/// [`Handler`]: https://docs.rs/axum/latest/axum/handler/trait.Handler.html
|
||||||
|
/// [`axum::extract::State`]: https://docs.rs/axum/0.6/axum/extract/struct.State.html
|
||||||
/// [`debug_handler`]: macro@debug_handler
|
/// [`debug_handler`]: macro@debug_handler
|
||||||
#[proc_macro_attribute]
|
#[proc_macro_attribute]
|
||||||
pub fn debug_handler(_attr: TokenStream, input: TokenStream) -> TokenStream {
|
pub fn debug_handler(_attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||||
|
9
axum-macros/tests/debug_handler/fail/duplicate_args.rs
Normal file
9
axum-macros/tests/debug_handler/fail/duplicate_args.rs
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
use axum_macros::debug_handler;
|
||||||
|
|
||||||
|
#[debug_handler(body = BoxBody, body = BoxBody)]
|
||||||
|
async fn handler() {}
|
||||||
|
|
||||||
|
#[debug_handler(state = (), state = ())]
|
||||||
|
async fn handler_2() {}
|
||||||
|
|
||||||
|
fn main() {}
|
11
axum-macros/tests/debug_handler/fail/duplicate_args.stderr
Normal file
11
axum-macros/tests/debug_handler/fail/duplicate_args.stderr
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
error: `body` specified more than once
|
||||||
|
--> tests/debug_handler/fail/duplicate_args.rs:3:33
|
||||||
|
|
|
||||||
|
3 | #[debug_handler(body = BoxBody, body = BoxBody)]
|
||||||
|
| ^^^^
|
||||||
|
|
||||||
|
error: `state` specified more than once
|
||||||
|
--> tests/debug_handler/fail/duplicate_args.rs:6:29
|
||||||
|
|
|
||||||
|
6 | #[debug_handler(state = (), state = ())]
|
||||||
|
| ^^^^^
|
@ -1,4 +1,4 @@
|
|||||||
error: unknown argument
|
error: expected `body` or `state`
|
||||||
--> tests/debug_handler/fail/invalid_attrs.rs:3:17
|
--> tests/debug_handler/fail/invalid_attrs.rs:3:17
|
||||||
|
|
|
|
||||||
3 | #[debug_handler(foo)]
|
3 | #[debug_handler(foo)]
|
||||||
|
31
axum-macros/tests/debug_handler/pass/infer_state.rs
Normal file
31
axum-macros/tests/debug_handler/pass/infer_state.rs
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
use axum_macros::debug_handler;
|
||||||
|
use axum::extract::State;
|
||||||
|
|
||||||
|
#[debug_handler]
|
||||||
|
async fn handler(_: State<AppState>) {}
|
||||||
|
|
||||||
|
#[debug_handler]
|
||||||
|
async fn handler_2(_: axum::extract::State<AppState>) {}
|
||||||
|
|
||||||
|
#[debug_handler]
|
||||||
|
async fn handler_3(
|
||||||
|
_: axum::extract::State<AppState>,
|
||||||
|
_: axum::extract::State<AppState>,
|
||||||
|
) {}
|
||||||
|
|
||||||
|
#[debug_handler]
|
||||||
|
async fn handler_4(
|
||||||
|
_: State<AppState>,
|
||||||
|
_: State<AppState>,
|
||||||
|
) {}
|
||||||
|
|
||||||
|
#[debug_handler]
|
||||||
|
async fn handler_5(
|
||||||
|
_: axum::extract::State<AppState>,
|
||||||
|
_: State<AppState>,
|
||||||
|
) {}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct AppState;
|
||||||
|
|
||||||
|
fn main() {}
|
27
axum-macros/tests/debug_handler/pass/set_state.rs
Normal file
27
axum-macros/tests/debug_handler/pass/set_state.rs
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
use axum_macros::debug_handler;
|
||||||
|
use axum::extract::{FromRef, FromRequest, RequestParts};
|
||||||
|
use axum::async_trait;
|
||||||
|
|
||||||
|
#[debug_handler(state = AppState)]
|
||||||
|
async fn handler(_: A) {}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct AppState;
|
||||||
|
|
||||||
|
struct A;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<S, B> FromRequest<S, B> for A
|
||||||
|
where
|
||||||
|
B: Send,
|
||||||
|
S: Send + Sync,
|
||||||
|
AppState: FromRef<S>,
|
||||||
|
{
|
||||||
|
type Rejection = ();
|
||||||
|
|
||||||
|
async fn from_request(_req: &mut RequestParts<S, B>) -> Result<Self, Self::Rejection> {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
10
axum-macros/tests/debug_handler/pass/state_and_body.rs
Normal file
10
axum-macros/tests/debug_handler/pass/state_and_body.rs
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
use axum_macros::debug_handler;
|
||||||
|
use axum::{body::BoxBody, extract::State, http::Request};
|
||||||
|
|
||||||
|
#[debug_handler(state = AppState, body = BoxBody)]
|
||||||
|
async fn handler(_: State<AppState>, _: Request<BoxBody>) {}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct AppState;
|
||||||
|
|
||||||
|
fn main() {}
|
Loading…
x
Reference in New Issue
Block a user