mirror of
https://github.com/tokio-rs/axum.git
synced 2025-09-28 21:40:55 +00:00
Improve the error message for state type inference failure in FromRequest(Parts) derive macro (#1432)
* Add a dedicated error message for state type inference issues * Generate valid code even if state type can't be inferred * Also error on state type inference for debug_handler
This commit is contained in:
parent
ee0b71a4ac
commit
7cbacd1433
@ -1,8 +1,10 @@
|
|||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
attr_parsing::{parse_assignment_attribute, second},
|
attr_parsing::{parse_assignment_attribute, second},
|
||||||
with_position::{Position, WithPosition},
|
with_position::{Position, WithPosition},
|
||||||
};
|
};
|
||||||
use proc_macro2::TokenStream;
|
use proc_macro2::{Span, TokenStream};
|
||||||
use quote::{format_ident, quote, quote_spanned};
|
use quote::{format_ident, quote, quote_spanned};
|
||||||
use syn::{parse::Parse, parse_quote, spanned::Spanned, FnArg, ItemFn, Token, Type};
|
use syn::{parse::Parse, parse_quote, spanned::Spanned, FnArg, ItemFn, Token, Type};
|
||||||
|
|
||||||
@ -22,20 +24,38 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream {
|
|||||||
// If the function is generic, we can't reliably check its inputs or whether the future it
|
// If the function is generic, we can't reliably check its inputs or whether the future it
|
||||||
// returns is `Send`. Skip those checks to avoid unhelpful additional compiler errors.
|
// returns is `Send`. Skip those checks to avoid unhelpful additional compiler errors.
|
||||||
let check_inputs_and_future_send = if item_fn.sig.generics.params.is_empty() {
|
let check_inputs_and_future_send = if item_fn.sig.generics.params.is_empty() {
|
||||||
|
let mut err = None;
|
||||||
|
|
||||||
if state_ty.is_none() {
|
if state_ty.is_none() {
|
||||||
state_ty = state_type_from_args(&item_fn);
|
let state_types_from_args = state_types_from_args(&item_fn);
|
||||||
|
|
||||||
|
#[allow(clippy::comparison_chain)]
|
||||||
|
if state_types_from_args.len() == 1 {
|
||||||
|
state_ty = state_types_from_args.into_iter().next();
|
||||||
|
} else if state_types_from_args.len() > 1 {
|
||||||
|
err = Some(
|
||||||
|
syn::Error::new(
|
||||||
|
Span::call_site(),
|
||||||
|
"can't infer state type, please add set it explicitly, as in \
|
||||||
|
`#[debug_handler(state = MyStateType)]`",
|
||||||
|
)
|
||||||
|
.into_compile_error(),
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let state_ty = state_ty.unwrap_or_else(|| syn::parse_quote!(()));
|
err.unwrap_or_else(|| {
|
||||||
|
let state_ty = state_ty.unwrap_or_else(|| syn::parse_quote!(()));
|
||||||
|
|
||||||
let check_inputs_impls_from_request =
|
let check_inputs_impls_from_request =
|
||||||
check_inputs_impls_from_request(&item_fn, &body_ty, state_ty);
|
check_inputs_impls_from_request(&item_fn, &body_ty, state_ty);
|
||||||
let check_future_send = check_future_send(&item_fn);
|
let check_future_send = check_future_send(&item_fn);
|
||||||
|
|
||||||
quote! {
|
quote! {
|
||||||
#check_inputs_impls_from_request
|
#check_inputs_impls_from_request
|
||||||
#check_future_send
|
#check_future_send
|
||||||
}
|
}
|
||||||
|
})
|
||||||
} else {
|
} else {
|
||||||
syn::Error::new_spanned(
|
syn::Error::new_spanned(
|
||||||
&item_fn.sig.generics,
|
&item_fn.sig.generics,
|
||||||
@ -433,7 +453,7 @@ fn self_receiver(item_fn: &ItemFn) -> Option<TokenStream> {
|
|||||||
/// This will extract `AppState`.
|
/// This will extract `AppState`.
|
||||||
///
|
///
|
||||||
/// Returns `None` if there are no `State` args or multiple of different types.
|
/// Returns `None` if there are no `State` args or multiple of different types.
|
||||||
fn state_type_from_args(item_fn: &ItemFn) -> Option<Type> {
|
fn state_types_from_args(item_fn: &ItemFn) -> HashSet<Type> {
|
||||||
let types = item_fn
|
let types = item_fn
|
||||||
.sig
|
.sig
|
||||||
.inputs
|
.inputs
|
||||||
@ -443,7 +463,7 @@ fn state_type_from_args(item_fn: &ItemFn) -> Option<Type> {
|
|||||||
FnArg::Typed(pat_type) => Some(pat_type),
|
FnArg::Typed(pat_type) => Some(pat_type),
|
||||||
})
|
})
|
||||||
.map(|pat_type| &*pat_type.ty);
|
.map(|pat_type| &*pat_type.ty);
|
||||||
crate::infer_state_type(types)
|
crate::infer_state_types(types).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -47,6 +47,7 @@ impl fmt::Display for Trait {
|
|||||||
enum State {
|
enum State {
|
||||||
Custom(syn::Type),
|
Custom(syn::Type),
|
||||||
Default(syn::Type),
|
Default(syn::Type),
|
||||||
|
CannotInfer,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl State {
|
impl State {
|
||||||
@ -58,6 +59,7 @@ impl State {
|
|||||||
match self {
|
match self {
|
||||||
State::Default(inner) => Some(inner.clone()),
|
State::Default(inner) => Some(inner.clone()),
|
||||||
State::Custom(_) => None,
|
State::Custom(_) => None,
|
||||||
|
State::CannotInfer => Some(parse_quote!(S)),
|
||||||
}
|
}
|
||||||
.into_iter()
|
.into_iter()
|
||||||
}
|
}
|
||||||
@ -70,6 +72,7 @@ impl State {
|
|||||||
match self {
|
match self {
|
||||||
State::Default(inner) => iter::once(inner.clone()),
|
State::Default(inner) => iter::once(inner.clone()),
|
||||||
State::Custom(inner) => iter::once(inner.clone()),
|
State::Custom(inner) => iter::once(inner.clone()),
|
||||||
|
State::CannotInfer => iter::once(parse_quote!(S)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -79,6 +82,9 @@ impl State {
|
|||||||
State::Default(inner) => quote! {
|
State::Default(inner) => quote! {
|
||||||
#inner: ::std::marker::Send + ::std::marker::Sync,
|
#inner: ::std::marker::Send + ::std::marker::Sync,
|
||||||
},
|
},
|
||||||
|
State::CannotInfer => quote! {
|
||||||
|
S: ::std::marker::Send + ::std::marker::Sync,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -88,6 +94,7 @@ impl ToTokens for State {
|
|||||||
match self {
|
match self {
|
||||||
State::Custom(inner) => inner.to_tokens(tokens),
|
State::Custom(inner) => inner.to_tokens(tokens),
|
||||||
State::Default(inner) => inner.to_tokens(tokens),
|
State::Default(inner) => inner.to_tokens(tokens),
|
||||||
|
State::CannotInfer => quote! { S }.to_tokens(tokens),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -115,30 +122,60 @@ pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result<TokenStream> {
|
|||||||
|
|
||||||
let state = match state {
|
let state = match state {
|
||||||
Some((_, state)) => State::Custom(state),
|
Some((_, state)) => State::Custom(state),
|
||||||
None => infer_state_type_from_field_types(&fields)
|
None => {
|
||||||
.map(State::Custom)
|
let mut inferred_state_types: HashSet<_> =
|
||||||
.or_else(|| infer_state_type_from_field_attributes(&fields).map(State::Custom))
|
infer_state_type_from_field_types(&fields)
|
||||||
.or_else(|| {
|
.chain(infer_state_type_from_field_attributes(&fields))
|
||||||
let via = via.as_ref().map(|(_, via)| via)?;
|
.collect();
|
||||||
state_from_via(&ident, via).map(State::Custom)
|
|
||||||
})
|
if let Some((_, via)) = &via {
|
||||||
.unwrap_or_else(|| State::Default(syn::parse_quote!(S))),
|
inferred_state_types.extend(state_from_via(&ident, via));
|
||||||
|
}
|
||||||
|
|
||||||
|
match inferred_state_types.len() {
|
||||||
|
0 => State::Default(syn::parse_quote!(S)),
|
||||||
|
1 => State::Custom(inferred_state_types.iter().next().unwrap().to_owned()),
|
||||||
|
_ => State::CannotInfer,
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match (via.map(second), rejection.map(second)) {
|
let trait_impl = match (via.map(second), rejection.map(second)) {
|
||||||
(Some(via), rejection) => impl_struct_by_extracting_all_at_once(
|
(Some(via), rejection) => impl_struct_by_extracting_all_at_once(
|
||||||
ident,
|
ident,
|
||||||
fields,
|
fields,
|
||||||
via,
|
via,
|
||||||
rejection,
|
rejection,
|
||||||
generic_ident,
|
generic_ident,
|
||||||
state,
|
&state,
|
||||||
tr,
|
tr,
|
||||||
),
|
)?,
|
||||||
(None, rejection) => {
|
(None, rejection) => {
|
||||||
error_on_generic_ident(generic_ident, tr)?;
|
error_on_generic_ident(generic_ident, tr)?;
|
||||||
impl_struct_by_extracting_each_field(ident, fields, rejection, state, tr)
|
impl_struct_by_extracting_each_field(ident, fields, rejection, &state, tr)?
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if let State::CannotInfer = state {
|
||||||
|
let attr_name = match tr {
|
||||||
|
Trait::FromRequest => "from_request",
|
||||||
|
Trait::FromRequestParts => "from_request_parts",
|
||||||
|
};
|
||||||
|
let compile_error = syn::Error::new(
|
||||||
|
Span::call_site(),
|
||||||
|
format_args!(
|
||||||
|
"can't infer state type, please add \
|
||||||
|
`#[{attr_name}(state = MyStateType)]` attribute",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.into_compile_error();
|
||||||
|
|
||||||
|
Ok(quote! {
|
||||||
|
#trait_impl
|
||||||
|
#compile_error
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Ok(trait_impl)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
syn::Item::Enum(item) => {
|
syn::Item::Enum(item) => {
|
||||||
@ -308,10 +345,22 @@ fn impl_struct_by_extracting_each_field(
|
|||||||
ident: syn::Ident,
|
ident: syn::Ident,
|
||||||
fields: syn::Fields,
|
fields: syn::Fields,
|
||||||
rejection: Option<syn::Path>,
|
rejection: Option<syn::Path>,
|
||||||
state: State,
|
state: &State,
|
||||||
tr: Trait,
|
tr: Trait,
|
||||||
) -> syn::Result<TokenStream> {
|
) -> syn::Result<TokenStream> {
|
||||||
let extract_fields = extract_fields(&fields, &rejection, tr)?;
|
let trait_fn_body = match state {
|
||||||
|
State::CannotInfer => quote! {
|
||||||
|
::std::unimplemented!()
|
||||||
|
},
|
||||||
|
_ => {
|
||||||
|
let extract_fields = extract_fields(&fields, &rejection, tr)?;
|
||||||
|
quote! {
|
||||||
|
::std::result::Result::Ok(Self {
|
||||||
|
#(#extract_fields)*
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let rejection_ident = if let Some(rejection) = rejection {
|
let rejection_ident = if let Some(rejection) = rejection {
|
||||||
quote!(#rejection)
|
quote!(#rejection)
|
||||||
@ -350,9 +399,7 @@ fn impl_struct_by_extracting_each_field(
|
|||||||
mut req: ::axum::http::Request<B>,
|
mut req: ::axum::http::Request<B>,
|
||||||
state: &#state,
|
state: &#state,
|
||||||
) -> ::std::result::Result<Self, Self::Rejection> {
|
) -> ::std::result::Result<Self, Self::Rejection> {
|
||||||
::std::result::Result::Ok(Self {
|
#trait_fn_body
|
||||||
#(#extract_fields)*
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -369,9 +416,7 @@ fn impl_struct_by_extracting_each_field(
|
|||||||
parts: &mut ::axum::http::request::Parts,
|
parts: &mut ::axum::http::request::Parts,
|
||||||
state: &#state,
|
state: &#state,
|
||||||
) -> ::std::result::Result<Self, Self::Rejection> {
|
) -> ::std::result::Result<Self, Self::Rejection> {
|
||||||
::std::result::Result::Ok(Self {
|
#trait_fn_body
|
||||||
#(#extract_fields)*
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -661,7 +706,7 @@ fn impl_struct_by_extracting_all_at_once(
|
|||||||
via_path: syn::Path,
|
via_path: syn::Path,
|
||||||
rejection: Option<syn::Path>,
|
rejection: Option<syn::Path>,
|
||||||
generic_ident: Option<Ident>,
|
generic_ident: Option<Ident>,
|
||||||
state: State,
|
state: &State,
|
||||||
tr: Trait,
|
tr: Trait,
|
||||||
) -> syn::Result<TokenStream> {
|
) -> syn::Result<TokenStream> {
|
||||||
let fields = match fields {
|
let fields = match fields {
|
||||||
@ -952,15 +997,15 @@ fn impl_enum_by_extracting_all_at_once(
|
|||||||
/// ```
|
/// ```
|
||||||
///
|
///
|
||||||
/// We can infer the state type to be `AppState` because it appears inside a `State`
|
/// We can infer the state type to be `AppState` because it appears inside a `State`
|
||||||
fn infer_state_type_from_field_types(fields: &Fields) -> Option<Type> {
|
fn infer_state_type_from_field_types(fields: &Fields) -> impl Iterator<Item = Type> + '_ {
|
||||||
match fields {
|
match fields {
|
||||||
Fields::Named(fields_named) => {
|
Fields::Named(fields_named) => Box::new(crate::infer_state_types(
|
||||||
crate::infer_state_type(fields_named.named.iter().map(|field| &field.ty))
|
fields_named.named.iter().map(|field| &field.ty),
|
||||||
}
|
)) as Box<dyn Iterator<Item = Type>>,
|
||||||
Fields::Unnamed(fields_unnamed) => {
|
Fields::Unnamed(fields_unnamed) => Box::new(crate::infer_state_types(
|
||||||
crate::infer_state_type(fields_unnamed.unnamed.iter().map(|field| &field.ty))
|
fields_unnamed.unnamed.iter().map(|field| &field.ty),
|
||||||
}
|
)),
|
||||||
Fields::Unit => None,
|
Fields::Unit => Box::new(iter::empty()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -975,43 +1020,29 @@ fn infer_state_type_from_field_types(fields: &Fields) -> Option<Type> {
|
|||||||
///
|
///
|
||||||
/// We can infer the state type to be `AppState` because it has `via(State)` and thus can be
|
/// We can infer the state type to be `AppState` because it has `via(State)` and thus can be
|
||||||
/// extracted with `State<AppState>`
|
/// extracted with `State<AppState>`
|
||||||
fn infer_state_type_from_field_attributes(fields: &Fields) -> Option<Type> {
|
fn infer_state_type_from_field_attributes(fields: &Fields) -> impl Iterator<Item = Type> + '_ {
|
||||||
let state_inputs = match fields {
|
match fields {
|
||||||
Fields::Named(fields_named) => {
|
Fields::Named(fields_named) => {
|
||||||
fields_named
|
Box::new(fields_named.named.iter().filter_map(|field| {
|
||||||
.named
|
// TODO(david): its a little wasteful to parse the attributes again here
|
||||||
.iter()
|
// ideally we should parse things once and pass the data down
|
||||||
.filter_map(|field| {
|
let FromRequestFieldAttrs { via } =
|
||||||
// TODO(david): its a little wasteful to parse the attributes again here
|
parse_attrs("from_request", &field.attrs).ok()?;
|
||||||
// ideally we should parse things once and pass the data down
|
let (_, via_path) = via?;
|
||||||
let FromRequestFieldAttrs { via } =
|
path_ident_is_state(&via_path).then(|| field.ty.clone())
|
||||||
parse_attrs("from_request", &field.attrs).ok()?;
|
})) as Box<dyn Iterator<Item = Type>>
|
||||||
let (_, via_path) = via?;
|
|
||||||
path_ident_is_state(&via_path).then(|| &field.ty)
|
|
||||||
})
|
|
||||||
.collect::<HashSet<_>>()
|
|
||||||
}
|
}
|
||||||
Fields::Unnamed(fields_unnamed) => {
|
Fields::Unnamed(fields_unnamed) => {
|
||||||
fields_unnamed
|
Box::new(fields_unnamed.unnamed.iter().filter_map(|field| {
|
||||||
.unnamed
|
// TODO(david): its a little wasteful to parse the attributes again here
|
||||||
.iter()
|
// ideally we should parse things once and pass the data down
|
||||||
.filter_map(|field| {
|
let FromRequestFieldAttrs { via } =
|
||||||
// TODO(david): its a little wasteful to parse the attributes again here
|
parse_attrs("from_request", &field.attrs).ok()?;
|
||||||
// ideally we should parse things once and pass the data down
|
let (_, via_path) = via?;
|
||||||
let FromRequestFieldAttrs { via } =
|
path_ident_is_state(&via_path).then(|| field.ty.clone())
|
||||||
parse_attrs("from_request", &field.attrs).ok()?;
|
}))
|
||||||
let (_, via_path) = via?;
|
|
||||||
path_ident_is_state(&via_path).then(|| &field.ty)
|
|
||||||
})
|
|
||||||
.collect::<HashSet<_>>()
|
|
||||||
}
|
}
|
||||||
Fields::Unit => return None,
|
Fields::Unit => Box::new(iter::empty()),
|
||||||
};
|
|
||||||
|
|
||||||
if state_inputs.len() == 1 {
|
|
||||||
state_inputs.iter().next().map(|&ty| ty.clone())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,8 +43,6 @@
|
|||||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||||
#![cfg_attr(test, allow(clippy::float_cmp))]
|
#![cfg_attr(test, allow(clippy::float_cmp))]
|
||||||
|
|
||||||
use std::collections::HashSet;
|
|
||||||
|
|
||||||
use proc_macro::TokenStream;
|
use proc_macro::TokenStream;
|
||||||
use quote::{quote, ToTokens};
|
use quote::{quote, ToTokens};
|
||||||
use syn::{parse::Parse, Type};
|
use syn::{parse::Parse, Type};
|
||||||
@ -615,11 +613,11 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn infer_state_type<'a, I>(types: I) -> Option<Type>
|
fn infer_state_types<'a, I>(types: I) -> impl Iterator<Item = Type> + 'a
|
||||||
where
|
where
|
||||||
I: Iterator<Item = &'a Type>,
|
I: Iterator<Item = &'a Type> + 'a,
|
||||||
{
|
{
|
||||||
let state_inputs = types
|
types
|
||||||
.filter_map(|ty| {
|
.filter_map(|ty| {
|
||||||
if let Type::Path(path) = ty {
|
if let Type::Path(path) = ty {
|
||||||
Some(&path.path)
|
Some(&path.path)
|
||||||
@ -650,13 +648,7 @@ where
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.collect::<HashSet<_>>();
|
.cloned()
|
||||||
|
|
||||||
if state_inputs.len() == 1 {
|
|
||||||
state_inputs.iter().next().map(|&ty| ty.clone())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
@ -1,23 +1,7 @@
|
|||||||
error[E0277]: the trait bound `AppState: FromRef<S>` is not satisfied
|
error: can't infer state type, please add `#[from_request(state = MyStateType)]` attribute
|
||||||
--> tests/from_request/fail/state_infer_multiple_different_types.rs:6:18
|
--> tests/from_request/fail/state_infer_multiple_different_types.rs:4:10
|
||||||
|
|
|
|
||||||
6 | inner_state: State<AppState>,
|
4 | #[derive(FromRequest)]
|
||||||
| ^^^^^ the trait `FromRef<S>` is not implemented for `AppState`
|
| ^^^^^^^^^^^
|
||||||
|
|
|
|
||||||
= note: required because of the requirements on the impl of `FromRequestParts<S>` for `State<AppState>`
|
= note: this error originates in the derive macro `FromRequest` (in Nightly builds, run with -Z macro-backtrace for more info)
|
||||||
help: consider extending the `where` clause, but there might be an alternative better way to express this requirement
|
|
||||||
|
|
|
||||||
4 | #[derive(FromRequest, AppState: FromRef<S>)]
|
|
||||||
| ++++++++++++++++++++++
|
|
||||||
|
|
||||||
error[E0277]: the trait bound `OtherState: FromRef<S>` is not satisfied
|
|
||||||
--> tests/from_request/fail/state_infer_multiple_different_types.rs:7:18
|
|
||||||
|
|
|
||||||
7 | other_state: State<OtherState>,
|
|
||||||
| ^^^^^ the trait `FromRef<S>` is not implemented for `OtherState`
|
|
||||||
|
|
|
||||||
= note: required because of the requirements on the impl of `FromRequestParts<S>` for `State<OtherState>`
|
|
||||||
help: consider extending the `where` clause, but there might be an alternative better way to express this requirement
|
|
||||||
|
|
|
||||||
4 | #[derive(FromRequest, OtherState: FromRef<S>)]
|
|
||||||
| ++++++++++++++++++++++++
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user