Improve the error message for state type inference failure in FromRequest(Parts) derive macro (#1432)

* Add a dedicated error message for state type inference issues

* Generate valid code even if state type can't be inferred

* Also error on state type inference for debug_handler
This commit is contained in:
Jonas Platte 2022-10-09 20:25:05 +02:00 committed by GitHub
parent ee0b71a4ac
commit 7cbacd1433
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 134 additions and 107 deletions

View File

@ -1,8 +1,10 @@
use std::collections::HashSet;
use crate::{ use crate::{
attr_parsing::{parse_assignment_attribute, second}, attr_parsing::{parse_assignment_attribute, second},
with_position::{Position, WithPosition}, with_position::{Position, WithPosition},
}; };
use proc_macro2::TokenStream; use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, quote_spanned}; use quote::{format_ident, quote, quote_spanned};
use syn::{parse::Parse, parse_quote, spanned::Spanned, FnArg, ItemFn, Token, Type}; use syn::{parse::Parse, parse_quote, spanned::Spanned, FnArg, ItemFn, Token, Type};
@ -22,20 +24,38 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream {
// If the function is generic, we can't reliably check its inputs or whether the future it // If the function is generic, we can't reliably check its inputs or whether the future it
// returns is `Send`. Skip those checks to avoid unhelpful additional compiler errors. // returns is `Send`. Skip those checks to avoid unhelpful additional compiler errors.
let check_inputs_and_future_send = if item_fn.sig.generics.params.is_empty() { let check_inputs_and_future_send = if item_fn.sig.generics.params.is_empty() {
let mut err = None;
if state_ty.is_none() { if state_ty.is_none() {
state_ty = state_type_from_args(&item_fn); let state_types_from_args = state_types_from_args(&item_fn);
#[allow(clippy::comparison_chain)]
if state_types_from_args.len() == 1 {
state_ty = state_types_from_args.into_iter().next();
} else if state_types_from_args.len() > 1 {
err = Some(
syn::Error::new(
Span::call_site(),
"can't infer state type, please add set it explicitly, as in \
`#[debug_handler(state = MyStateType)]`",
)
.into_compile_error(),
);
}
} }
let state_ty = state_ty.unwrap_or_else(|| syn::parse_quote!(())); err.unwrap_or_else(|| {
let state_ty = state_ty.unwrap_or_else(|| syn::parse_quote!(()));
let check_inputs_impls_from_request = let check_inputs_impls_from_request =
check_inputs_impls_from_request(&item_fn, &body_ty, state_ty); check_inputs_impls_from_request(&item_fn, &body_ty, state_ty);
let check_future_send = check_future_send(&item_fn); let check_future_send = check_future_send(&item_fn);
quote! { quote! {
#check_inputs_impls_from_request #check_inputs_impls_from_request
#check_future_send #check_future_send
} }
})
} else { } else {
syn::Error::new_spanned( syn::Error::new_spanned(
&item_fn.sig.generics, &item_fn.sig.generics,
@ -433,7 +453,7 @@ fn self_receiver(item_fn: &ItemFn) -> Option<TokenStream> {
/// This will extract `AppState`. /// This will extract `AppState`.
/// ///
/// Returns `None` if there are no `State` args or multiple of different types. /// Returns `None` if there are no `State` args or multiple of different types.
fn state_type_from_args(item_fn: &ItemFn) -> Option<Type> { fn state_types_from_args(item_fn: &ItemFn) -> HashSet<Type> {
let types = item_fn let types = item_fn
.sig .sig
.inputs .inputs
@ -443,7 +463,7 @@ fn state_type_from_args(item_fn: &ItemFn) -> Option<Type> {
FnArg::Typed(pat_type) => Some(pat_type), FnArg::Typed(pat_type) => Some(pat_type),
}) })
.map(|pat_type| &*pat_type.ty); .map(|pat_type| &*pat_type.ty);
crate::infer_state_type(types) crate::infer_state_types(types).collect()
} }
#[test] #[test]

View File

@ -47,6 +47,7 @@ impl fmt::Display for Trait {
enum State { enum State {
Custom(syn::Type), Custom(syn::Type),
Default(syn::Type), Default(syn::Type),
CannotInfer,
} }
impl State { impl State {
@ -58,6 +59,7 @@ impl State {
match self { match self {
State::Default(inner) => Some(inner.clone()), State::Default(inner) => Some(inner.clone()),
State::Custom(_) => None, State::Custom(_) => None,
State::CannotInfer => Some(parse_quote!(S)),
} }
.into_iter() .into_iter()
} }
@ -70,6 +72,7 @@ impl State {
match self { match self {
State::Default(inner) => iter::once(inner.clone()), State::Default(inner) => iter::once(inner.clone()),
State::Custom(inner) => iter::once(inner.clone()), State::Custom(inner) => iter::once(inner.clone()),
State::CannotInfer => iter::once(parse_quote!(S)),
} }
} }
@ -79,6 +82,9 @@ impl State {
State::Default(inner) => quote! { State::Default(inner) => quote! {
#inner: ::std::marker::Send + ::std::marker::Sync, #inner: ::std::marker::Send + ::std::marker::Sync,
}, },
State::CannotInfer => quote! {
S: ::std::marker::Send + ::std::marker::Sync,
},
} }
} }
} }
@ -88,6 +94,7 @@ impl ToTokens for State {
match self { match self {
State::Custom(inner) => inner.to_tokens(tokens), State::Custom(inner) => inner.to_tokens(tokens),
State::Default(inner) => inner.to_tokens(tokens), State::Default(inner) => inner.to_tokens(tokens),
State::CannotInfer => quote! { S }.to_tokens(tokens),
} }
} }
} }
@ -115,30 +122,60 @@ pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result<TokenStream> {
let state = match state { let state = match state {
Some((_, state)) => State::Custom(state), Some((_, state)) => State::Custom(state),
None => infer_state_type_from_field_types(&fields) None => {
.map(State::Custom) let mut inferred_state_types: HashSet<_> =
.or_else(|| infer_state_type_from_field_attributes(&fields).map(State::Custom)) infer_state_type_from_field_types(&fields)
.or_else(|| { .chain(infer_state_type_from_field_attributes(&fields))
let via = via.as_ref().map(|(_, via)| via)?; .collect();
state_from_via(&ident, via).map(State::Custom)
}) if let Some((_, via)) = &via {
.unwrap_or_else(|| State::Default(syn::parse_quote!(S))), inferred_state_types.extend(state_from_via(&ident, via));
}
match inferred_state_types.len() {
0 => State::Default(syn::parse_quote!(S)),
1 => State::Custom(inferred_state_types.iter().next().unwrap().to_owned()),
_ => State::CannotInfer,
}
}
}; };
match (via.map(second), rejection.map(second)) { let trait_impl = match (via.map(second), rejection.map(second)) {
(Some(via), rejection) => impl_struct_by_extracting_all_at_once( (Some(via), rejection) => impl_struct_by_extracting_all_at_once(
ident, ident,
fields, fields,
via, via,
rejection, rejection,
generic_ident, generic_ident,
state, &state,
tr, tr,
), )?,
(None, rejection) => { (None, rejection) => {
error_on_generic_ident(generic_ident, tr)?; error_on_generic_ident(generic_ident, tr)?;
impl_struct_by_extracting_each_field(ident, fields, rejection, state, tr) impl_struct_by_extracting_each_field(ident, fields, rejection, &state, tr)?
} }
};
if let State::CannotInfer = state {
let attr_name = match tr {
Trait::FromRequest => "from_request",
Trait::FromRequestParts => "from_request_parts",
};
let compile_error = syn::Error::new(
Span::call_site(),
format_args!(
"can't infer state type, please add \
`#[{attr_name}(state = MyStateType)]` attribute",
),
)
.into_compile_error();
Ok(quote! {
#trait_impl
#compile_error
})
} else {
Ok(trait_impl)
} }
} }
syn::Item::Enum(item) => { syn::Item::Enum(item) => {
@ -308,10 +345,22 @@ fn impl_struct_by_extracting_each_field(
ident: syn::Ident, ident: syn::Ident,
fields: syn::Fields, fields: syn::Fields,
rejection: Option<syn::Path>, rejection: Option<syn::Path>,
state: State, state: &State,
tr: Trait, tr: Trait,
) -> syn::Result<TokenStream> { ) -> syn::Result<TokenStream> {
let extract_fields = extract_fields(&fields, &rejection, tr)?; let trait_fn_body = match state {
State::CannotInfer => quote! {
::std::unimplemented!()
},
_ => {
let extract_fields = extract_fields(&fields, &rejection, tr)?;
quote! {
::std::result::Result::Ok(Self {
#(#extract_fields)*
})
}
}
};
let rejection_ident = if let Some(rejection) = rejection { let rejection_ident = if let Some(rejection) = rejection {
quote!(#rejection) quote!(#rejection)
@ -350,9 +399,7 @@ fn impl_struct_by_extracting_each_field(
mut req: ::axum::http::Request<B>, mut req: ::axum::http::Request<B>,
state: &#state, state: &#state,
) -> ::std::result::Result<Self, Self::Rejection> { ) -> ::std::result::Result<Self, Self::Rejection> {
::std::result::Result::Ok(Self { #trait_fn_body
#(#extract_fields)*
})
} }
} }
}, },
@ -369,9 +416,7 @@ fn impl_struct_by_extracting_each_field(
parts: &mut ::axum::http::request::Parts, parts: &mut ::axum::http::request::Parts,
state: &#state, state: &#state,
) -> ::std::result::Result<Self, Self::Rejection> { ) -> ::std::result::Result<Self, Self::Rejection> {
::std::result::Result::Ok(Self { #trait_fn_body
#(#extract_fields)*
})
} }
} }
}, },
@ -661,7 +706,7 @@ fn impl_struct_by_extracting_all_at_once(
via_path: syn::Path, via_path: syn::Path,
rejection: Option<syn::Path>, rejection: Option<syn::Path>,
generic_ident: Option<Ident>, generic_ident: Option<Ident>,
state: State, state: &State,
tr: Trait, tr: Trait,
) -> syn::Result<TokenStream> { ) -> syn::Result<TokenStream> {
let fields = match fields { let fields = match fields {
@ -952,15 +997,15 @@ fn impl_enum_by_extracting_all_at_once(
/// ``` /// ```
/// ///
/// We can infer the state type to be `AppState` because it appears inside a `State` /// We can infer the state type to be `AppState` because it appears inside a `State`
fn infer_state_type_from_field_types(fields: &Fields) -> Option<Type> { fn infer_state_type_from_field_types(fields: &Fields) -> impl Iterator<Item = Type> + '_ {
match fields { match fields {
Fields::Named(fields_named) => { Fields::Named(fields_named) => Box::new(crate::infer_state_types(
crate::infer_state_type(fields_named.named.iter().map(|field| &field.ty)) fields_named.named.iter().map(|field| &field.ty),
} )) as Box<dyn Iterator<Item = Type>>,
Fields::Unnamed(fields_unnamed) => { Fields::Unnamed(fields_unnamed) => Box::new(crate::infer_state_types(
crate::infer_state_type(fields_unnamed.unnamed.iter().map(|field| &field.ty)) fields_unnamed.unnamed.iter().map(|field| &field.ty),
} )),
Fields::Unit => None, Fields::Unit => Box::new(iter::empty()),
} }
} }
@ -975,43 +1020,29 @@ fn infer_state_type_from_field_types(fields: &Fields) -> Option<Type> {
/// ///
/// We can infer the state type to be `AppState` because it has `via(State)` and thus can be /// We can infer the state type to be `AppState` because it has `via(State)` and thus can be
/// extracted with `State<AppState>` /// extracted with `State<AppState>`
fn infer_state_type_from_field_attributes(fields: &Fields) -> Option<Type> { fn infer_state_type_from_field_attributes(fields: &Fields) -> impl Iterator<Item = Type> + '_ {
let state_inputs = match fields { match fields {
Fields::Named(fields_named) => { Fields::Named(fields_named) => {
fields_named Box::new(fields_named.named.iter().filter_map(|field| {
.named // TODO(david): its a little wasteful to parse the attributes again here
.iter() // ideally we should parse things once and pass the data down
.filter_map(|field| { let FromRequestFieldAttrs { via } =
// TODO(david): its a little wasteful to parse the attributes again here parse_attrs("from_request", &field.attrs).ok()?;
// ideally we should parse things once and pass the data down let (_, via_path) = via?;
let FromRequestFieldAttrs { via } = path_ident_is_state(&via_path).then(|| field.ty.clone())
parse_attrs("from_request", &field.attrs).ok()?; })) as Box<dyn Iterator<Item = Type>>
let (_, via_path) = via?;
path_ident_is_state(&via_path).then(|| &field.ty)
})
.collect::<HashSet<_>>()
} }
Fields::Unnamed(fields_unnamed) => { Fields::Unnamed(fields_unnamed) => {
fields_unnamed Box::new(fields_unnamed.unnamed.iter().filter_map(|field| {
.unnamed // TODO(david): its a little wasteful to parse the attributes again here
.iter() // ideally we should parse things once and pass the data down
.filter_map(|field| { let FromRequestFieldAttrs { via } =
// TODO(david): its a little wasteful to parse the attributes again here parse_attrs("from_request", &field.attrs).ok()?;
// ideally we should parse things once and pass the data down let (_, via_path) = via?;
let FromRequestFieldAttrs { via } = path_ident_is_state(&via_path).then(|| field.ty.clone())
parse_attrs("from_request", &field.attrs).ok()?; }))
let (_, via_path) = via?;
path_ident_is_state(&via_path).then(|| &field.ty)
})
.collect::<HashSet<_>>()
} }
Fields::Unit => return None, Fields::Unit => Box::new(iter::empty()),
};
if state_inputs.len() == 1 {
state_inputs.iter().next().map(|&ty| ty.clone())
} else {
None
} }
} }

View File

@ -43,8 +43,6 @@
#![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, feature(doc_cfg))]
#![cfg_attr(test, allow(clippy::float_cmp))] #![cfg_attr(test, allow(clippy::float_cmp))]
use std::collections::HashSet;
use proc_macro::TokenStream; use proc_macro::TokenStream;
use quote::{quote, ToTokens}; use quote::{quote, ToTokens};
use syn::{parse::Parse, Type}; use syn::{parse::Parse, Type};
@ -615,11 +613,11 @@ where
} }
} }
fn infer_state_type<'a, I>(types: I) -> Option<Type> fn infer_state_types<'a, I>(types: I) -> impl Iterator<Item = Type> + 'a
where where
I: Iterator<Item = &'a Type>, I: Iterator<Item = &'a Type> + 'a,
{ {
let state_inputs = types types
.filter_map(|ty| { .filter_map(|ty| {
if let Type::Path(path) = ty { if let Type::Path(path) = ty {
Some(&path.path) Some(&path.path)
@ -650,13 +648,7 @@ where
None None
} }
}) })
.collect::<HashSet<_>>(); .cloned()
if state_inputs.len() == 1 {
state_inputs.iter().next().map(|&ty| ty.clone())
} else {
None
}
} }
#[cfg(test)] #[cfg(test)]

View File

@ -1,23 +1,7 @@
error[E0277]: the trait bound `AppState: FromRef<S>` is not satisfied error: can't infer state type, please add `#[from_request(state = MyStateType)]` attribute
--> tests/from_request/fail/state_infer_multiple_different_types.rs:6:18 --> tests/from_request/fail/state_infer_multiple_different_types.rs:4:10
| |
6 | inner_state: State<AppState>, 4 | #[derive(FromRequest)]
| ^^^^^ the trait `FromRef<S>` is not implemented for `AppState` | ^^^^^^^^^^^
| |
= note: required because of the requirements on the impl of `FromRequestParts<S>` for `State<AppState>` = note: this error originates in the derive macro `FromRequest` (in Nightly builds, run with -Z macro-backtrace for more info)
help: consider extending the `where` clause, but there might be an alternative better way to express this requirement
|
4 | #[derive(FromRequest, AppState: FromRef<S>)]
| ++++++++++++++++++++++
error[E0277]: the trait bound `OtherState: FromRef<S>` is not satisfied
--> tests/from_request/fail/state_infer_multiple_different_types.rs:7:18
|
7 | other_state: State<OtherState>,
| ^^^^^ the trait `FromRef<S>` is not implemented for `OtherState`
|
= note: required because of the requirements on the impl of `FromRequestParts<S>` for `State<OtherState>`
help: consider extending the `where` clause, but there might be an alternative better way to express this requirement
|
4 | #[derive(FromRequest, OtherState: FromRef<S>)]
| ++++++++++++++++++++++++