mirror of
https://github.com/tokio-rs/axum.git
synced 2025-10-02 07:20:38 +00:00
Support State
with #[derive(FromRequest[Parts])]
(#1391)
* Support `State` with `#[derive(FromRequest[Parts])]` Fixes https://github.com/tokio-rs/axum/issues/1314 This makes it possible to extract things via `State` in `#[derive(FromRequet)]`: ```rust struct Foo { state: State<AppState>, } ``` The state can also be inferred in a lot of cases so you only need to write: ```rust struct Foo { // since we're using `State<AppState>` we know the state has to be // `AppState` state: State<AppState>, } ``` Same for ```rust struct Foo { #[from_request(via(State))] state: AppState, } ``` And ```rust struct AppState {} ``` I think I've covered all the edge cases but there are (unsurprisingly) a few. * make sure things can be combined with other extractors * main functions in ui tests don't need to be async * Add test for multiple identicaly state types * Add failing test for multiple states
This commit is contained in:
parent
e3a17c1249
commit
c3f3db79ec
@ -26,7 +26,7 @@ syn = { version = "1.0", features = [
|
|||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
axum = { path = "../axum", version = "0.6.0-rc.2", features = ["headers"] }
|
axum = { path = "../axum", version = "0.6.0-rc.2", features = ["headers"] }
|
||||||
axum-extra = { path = "../axum-extra", version = "0.4.0-rc.1", features = ["typed-routing"] }
|
axum-extra = { path = "../axum-extra", version = "0.4.0-rc.1", features = ["typed-routing", "cookie-private"] }
|
||||||
rustversion = "1.0"
|
rustversion = "1.0"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
|
@ -4,7 +4,6 @@ use crate::{
|
|||||||
};
|
};
|
||||||
use proc_macro2::TokenStream;
|
use proc_macro2::TokenStream;
|
||||||
use quote::{format_ident, quote, quote_spanned};
|
use quote::{format_ident, quote, quote_spanned};
|
||||||
use std::collections::HashSet;
|
|
||||||
use syn::{parse::Parse, parse_quote, spanned::Spanned, FnArg, ItemFn, Token, Type};
|
use syn::{parse::Parse, parse_quote, spanned::Spanned, FnArg, ItemFn, Token, Type};
|
||||||
|
|
||||||
pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream {
|
pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream {
|
||||||
@ -435,7 +434,7 @@ fn self_receiver(item_fn: &ItemFn) -> Option<TokenStream> {
|
|||||||
///
|
///
|
||||||
/// Returns `None` if there are no `State` args or multiple of different types.
|
/// Returns `None` if there are no `State` args or multiple of different types.
|
||||||
fn state_type_from_args(item_fn: &ItemFn) -> Option<Type> {
|
fn state_type_from_args(item_fn: &ItemFn) -> Option<Type> {
|
||||||
let state_inputs = item_fn
|
let types = item_fn
|
||||||
.sig
|
.sig
|
||||||
.inputs
|
.inputs
|
||||||
.iter()
|
.iter()
|
||||||
@ -443,44 +442,8 @@ fn state_type_from_args(item_fn: &ItemFn) -> Option<Type> {
|
|||||||
FnArg::Receiver(_) => None,
|
FnArg::Receiver(_) => None,
|
||||||
FnArg::Typed(pat_type) => Some(pat_type),
|
FnArg::Typed(pat_type) => Some(pat_type),
|
||||||
})
|
})
|
||||||
.map(|pat_type| &pat_type.ty)
|
.map(|pat_type| &*pat_type.ty);
|
||||||
.filter_map(|ty| {
|
crate::infer_state_type(types)
|
||||||
if let Type::Path(path) = &**ty {
|
|
||||||
Some(&path.path)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.filter_map(|path| {
|
|
||||||
if let Some(last_segment) = path.segments.last() {
|
|
||||||
if last_segment.ident != "State" {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
match &last_segment.arguments {
|
|
||||||
syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => {
|
|
||||||
Some(args.args.first().unwrap())
|
|
||||||
}
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.filter_map(|generic_arg| {
|
|
||||||
if let syn::GenericArgument::Type(ty) = generic_arg {
|
|
||||||
Some(ty)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<HashSet<_>>();
|
|
||||||
|
|
||||||
if state_inputs.len() == 1 {
|
|
||||||
state_inputs.iter().next().map(|&ty| ty.clone())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -4,9 +4,11 @@ use crate::{
|
|||||||
from_request::attr::FromRequestFieldAttrs,
|
from_request::attr::FromRequestFieldAttrs,
|
||||||
};
|
};
|
||||||
use proc_macro2::{Span, TokenStream};
|
use proc_macro2::{Span, TokenStream};
|
||||||
use quote::{quote, quote_spanned};
|
use quote::{quote, quote_spanned, ToTokens};
|
||||||
use std::fmt;
|
use std::{collections::HashSet, fmt, iter};
|
||||||
use syn::{punctuated::Punctuated, spanned::Spanned, Ident, Token};
|
use syn::{
|
||||||
|
parse_quote, punctuated::Punctuated, spanned::Spanned, Fields, Ident, Path, Token, Type,
|
||||||
|
};
|
||||||
|
|
||||||
mod attr;
|
mod attr;
|
||||||
|
|
||||||
@ -16,6 +18,22 @@ pub(crate) enum Trait {
|
|||||||
FromRequestParts,
|
FromRequestParts,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Trait {
|
||||||
|
fn body_type(&self) -> impl Iterator<Item = Type> {
|
||||||
|
match self {
|
||||||
|
Trait::FromRequest => Some(parse_quote!(B)).into_iter(),
|
||||||
|
Trait::FromRequestParts => None.into_iter(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn via_marker_type(&self) -> Option<Type> {
|
||||||
|
match self {
|
||||||
|
Trait::FromRequest => Some(parse_quote!(M)),
|
||||||
|
Trait::FromRequestParts => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl fmt::Display for Trait {
|
impl fmt::Display for Trait {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
@ -25,6 +43,55 @@ impl fmt::Display for Trait {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
enum State {
|
||||||
|
Custom(syn::Type),
|
||||||
|
Default(syn::Type),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl State {
|
||||||
|
/// ```not_rust
|
||||||
|
/// impl<T> A for B {}
|
||||||
|
/// ^ this type
|
||||||
|
/// ```
|
||||||
|
fn impl_generics(&self) -> impl Iterator<Item = Type> {
|
||||||
|
match self {
|
||||||
|
State::Default(inner) => Some(inner.clone()),
|
||||||
|
State::Custom(_) => None,
|
||||||
|
}
|
||||||
|
.into_iter()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// ```not_rust
|
||||||
|
/// impl<T> A<T> for B {}
|
||||||
|
/// ^ this type
|
||||||
|
/// ```
|
||||||
|
fn trait_generics(&self) -> impl Iterator<Item = Type> {
|
||||||
|
match self {
|
||||||
|
State::Default(inner) => iter::once(inner.clone()),
|
||||||
|
State::Custom(inner) => iter::once(inner.clone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bounds(&self) -> TokenStream {
|
||||||
|
match self {
|
||||||
|
State::Custom(_) => quote! {},
|
||||||
|
State::Default(inner) => quote! {
|
||||||
|
#inner: ::std::marker::Send + ::std::marker::Sync,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToTokens for State {
|
||||||
|
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||||
|
match self {
|
||||||
|
State::Custom(inner) => inner.to_tokens(tokens),
|
||||||
|
State::Default(inner) => inner.to_tokens(tokens),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result<TokenStream> {
|
pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result<TokenStream> {
|
||||||
match item {
|
match item {
|
||||||
syn::Item::Struct(item) => {
|
syn::Item::Struct(item) => {
|
||||||
@ -40,7 +107,23 @@ pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result<TokenStream> {
|
|||||||
|
|
||||||
let generic_ident = parse_single_generic_type_on_struct(generics, &fields, tr)?;
|
let generic_ident = parse_single_generic_type_on_struct(generics, &fields, tr)?;
|
||||||
|
|
||||||
let FromRequestContainerAttrs { via, rejection } = parse_attrs("from_request", &attrs)?;
|
let FromRequestContainerAttrs {
|
||||||
|
via,
|
||||||
|
rejection,
|
||||||
|
state,
|
||||||
|
} = parse_attrs("from_request", &attrs)?;
|
||||||
|
|
||||||
|
let state = match state {
|
||||||
|
Some((_, state)) => State::Custom(state),
|
||||||
|
None => infer_state_type_from_field_types(&fields)
|
||||||
|
.map(State::Custom)
|
||||||
|
.or_else(|| infer_state_type_from_field_attributes(&fields).map(State::Custom))
|
||||||
|
.or_else(|| {
|
||||||
|
let via = via.as_ref().map(|(_, via)| via)?;
|
||||||
|
state_from_via(&ident, via).map(State::Custom)
|
||||||
|
})
|
||||||
|
.unwrap_or_else(|| State::Default(syn::parse_quote!(S))),
|
||||||
|
};
|
||||||
|
|
||||||
match (via.map(second), rejection.map(second)) {
|
match (via.map(second), rejection.map(second)) {
|
||||||
(Some(via), rejection) => impl_struct_by_extracting_all_at_once(
|
(Some(via), rejection) => impl_struct_by_extracting_all_at_once(
|
||||||
@ -49,11 +132,12 @@ pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result<TokenStream> {
|
|||||||
via,
|
via,
|
||||||
rejection,
|
rejection,
|
||||||
generic_ident,
|
generic_ident,
|
||||||
|
state,
|
||||||
tr,
|
tr,
|
||||||
),
|
),
|
||||||
(None, rejection) => {
|
(None, rejection) => {
|
||||||
error_on_generic_ident(generic_ident, tr)?;
|
error_on_generic_ident(generic_ident, tr)?;
|
||||||
impl_struct_by_extracting_each_field(ident, fields, rejection, tr)
|
impl_struct_by_extracting_each_field(ident, fields, rejection, state, tr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -78,7 +162,20 @@ pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result<TokenStream> {
|
|||||||
return Err(syn::Error::new_spanned(where_clause, generics_error));
|
return Err(syn::Error::new_spanned(where_clause, generics_error));
|
||||||
}
|
}
|
||||||
|
|
||||||
let FromRequestContainerAttrs { via, rejection } = parse_attrs("from_request", &attrs)?;
|
let FromRequestContainerAttrs {
|
||||||
|
via,
|
||||||
|
rejection,
|
||||||
|
state,
|
||||||
|
} = parse_attrs("from_request", &attrs)?;
|
||||||
|
|
||||||
|
let state = match state {
|
||||||
|
Some((_, state)) => State::Custom(state),
|
||||||
|
None => (|| {
|
||||||
|
let via = via.as_ref().map(|(_, via)| via)?;
|
||||||
|
state_from_via(&ident, via).map(State::Custom)
|
||||||
|
})()
|
||||||
|
.unwrap_or_else(|| State::Default(syn::parse_quote!(S))),
|
||||||
|
};
|
||||||
|
|
||||||
match (via.map(second), rejection) {
|
match (via.map(second), rejection) {
|
||||||
(Some(via), rejection) => impl_enum_by_extracting_all_at_once(
|
(Some(via), rejection) => impl_enum_by_extracting_all_at_once(
|
||||||
@ -86,6 +183,7 @@ pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result<TokenStream> {
|
|||||||
variants,
|
variants,
|
||||||
via,
|
via,
|
||||||
rejection.map(second),
|
rejection.map(second),
|
||||||
|
state,
|
||||||
tr,
|
tr,
|
||||||
),
|
),
|
||||||
(None, Some((rejection_kw, _))) => Err(syn::Error::new_spanned(
|
(None, Some((rejection_kw, _))) => Err(syn::Error::new_spanned(
|
||||||
@ -210,6 +308,7 @@ fn impl_struct_by_extracting_each_field(
|
|||||||
ident: syn::Ident,
|
ident: syn::Ident,
|
||||||
fields: syn::Fields,
|
fields: syn::Fields,
|
||||||
rejection: Option<syn::Path>,
|
rejection: Option<syn::Path>,
|
||||||
|
state: State,
|
||||||
tr: Trait,
|
tr: Trait,
|
||||||
) -> syn::Result<TokenStream> {
|
) -> syn::Result<TokenStream> {
|
||||||
let extract_fields = extract_fields(&fields, &rejection, tr)?;
|
let extract_fields = extract_fields(&fields, &rejection, tr)?;
|
||||||
@ -222,22 +321,34 @@ fn impl_struct_by_extracting_each_field(
|
|||||||
quote!(::axum::response::Response)
|
quote!(::axum::response::Response)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let impl_generics = tr
|
||||||
|
.body_type()
|
||||||
|
.chain(state.impl_generics())
|
||||||
|
.collect::<Punctuated<Type, Token![,]>>();
|
||||||
|
|
||||||
|
let trait_generics = state
|
||||||
|
.trait_generics()
|
||||||
|
.chain(tr.body_type())
|
||||||
|
.collect::<Punctuated<Type, Token![,]>>();
|
||||||
|
|
||||||
|
let state_bounds = state.bounds();
|
||||||
|
|
||||||
Ok(match tr {
|
Ok(match tr {
|
||||||
Trait::FromRequest => quote! {
|
Trait::FromRequest => quote! {
|
||||||
#[::axum::async_trait]
|
#[::axum::async_trait]
|
||||||
#[automatically_derived]
|
#[automatically_derived]
|
||||||
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
|
impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident
|
||||||
where
|
where
|
||||||
B: ::axum::body::HttpBody + ::std::marker::Send + 'static,
|
B: ::axum::body::HttpBody + ::std::marker::Send + 'static,
|
||||||
B::Data: ::std::marker::Send,
|
B::Data: ::std::marker::Send,
|
||||||
B::Error: ::std::convert::Into<::axum::BoxError>,
|
B::Error: ::std::convert::Into<::axum::BoxError>,
|
||||||
S: ::std::marker::Send + ::std::marker::Sync,
|
#state_bounds
|
||||||
{
|
{
|
||||||
type Rejection = #rejection_ident;
|
type Rejection = #rejection_ident;
|
||||||
|
|
||||||
async fn from_request(
|
async fn from_request(
|
||||||
mut req: ::axum::http::Request<B>,
|
mut req: ::axum::http::Request<B>,
|
||||||
state: &S,
|
state: &#state,
|
||||||
) -> ::std::result::Result<Self, Self::Rejection> {
|
) -> ::std::result::Result<Self, Self::Rejection> {
|
||||||
::std::result::Result::Ok(Self {
|
::std::result::Result::Ok(Self {
|
||||||
#(#extract_fields)*
|
#(#extract_fields)*
|
||||||
@ -248,15 +359,15 @@ fn impl_struct_by_extracting_each_field(
|
|||||||
Trait::FromRequestParts => quote! {
|
Trait::FromRequestParts => quote! {
|
||||||
#[::axum::async_trait]
|
#[::axum::async_trait]
|
||||||
#[automatically_derived]
|
#[automatically_derived]
|
||||||
impl<S> ::axum::extract::FromRequestParts<S> for #ident
|
impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident
|
||||||
where
|
where
|
||||||
S: ::std::marker::Send + ::std::marker::Sync,
|
#state_bounds
|
||||||
{
|
{
|
||||||
type Rejection = #rejection_ident;
|
type Rejection = #rejection_ident;
|
||||||
|
|
||||||
async fn from_request_parts(
|
async fn from_request_parts(
|
||||||
parts: &mut ::axum::http::request::Parts,
|
parts: &mut ::axum::http::request::Parts,
|
||||||
state: &S,
|
state: &#state,
|
||||||
) -> ::std::result::Result<Self, Self::Rejection> {
|
) -> ::std::result::Result<Self, Self::Rejection> {
|
||||||
::std::result::Result::Ok(Self {
|
::std::result::Result::Ok(Self {
|
||||||
#(#extract_fields)*
|
#(#extract_fields)*
|
||||||
@ -547,9 +658,10 @@ fn peel_result_ok(ty: &syn::Type) -> Option<&syn::Type> {
|
|||||||
fn impl_struct_by_extracting_all_at_once(
|
fn impl_struct_by_extracting_all_at_once(
|
||||||
ident: syn::Ident,
|
ident: syn::Ident,
|
||||||
fields: syn::Fields,
|
fields: syn::Fields,
|
||||||
path: syn::Path,
|
via_path: syn::Path,
|
||||||
rejection: Option<syn::Path>,
|
rejection: Option<syn::Path>,
|
||||||
generic_ident: Option<Ident>,
|
generic_ident: Option<Ident>,
|
||||||
|
state: State,
|
||||||
tr: Trait,
|
tr: Trait,
|
||||||
) -> syn::Result<TokenStream> {
|
) -> syn::Result<TokenStream> {
|
||||||
let fields = match fields {
|
let fields = match fields {
|
||||||
@ -570,7 +682,7 @@ fn impl_struct_by_extracting_all_at_once(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let path_span = path.span();
|
let path_span = via_path.span();
|
||||||
|
|
||||||
let (associated_rejection_type, map_err) = if let Some(rejection) = &rejection {
|
let (associated_rejection_type, map_err) = if let Some(rejection) = &rejection {
|
||||||
let rejection = quote! { #rejection };
|
let rejection = quote! { #rejection };
|
||||||
@ -584,43 +696,68 @@ fn impl_struct_by_extracting_all_at_once(
|
|||||||
(rejection, map_err)
|
(rejection, map_err)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// for something like
|
||||||
|
//
|
||||||
|
// ```
|
||||||
|
// #[derive(Clone, Default, FromRequest)]
|
||||||
|
// #[from_request(via(State))]
|
||||||
|
// struct AppState {}
|
||||||
|
// ```
|
||||||
|
//
|
||||||
|
// we need to implement `impl<B, M> FromRequest<AppState, B, M>` but only for
|
||||||
|
// - `#[derive(FromRequest)]`, not `#[derive(FromRequestParts)]`
|
||||||
|
// - `State`, not other extractors
|
||||||
|
//
|
||||||
|
// honestly not sure why but the tests all pass
|
||||||
|
let via_marker_type = if path_ident_is_state(&via_path) {
|
||||||
|
tr.via_marker_type()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let impl_generics = tr
|
||||||
|
.body_type()
|
||||||
|
.chain(via_marker_type.clone())
|
||||||
|
.chain(state.impl_generics())
|
||||||
|
.chain(generic_ident.is_some().then(|| parse_quote!(T)))
|
||||||
|
.collect::<Punctuated<Type, Token![,]>>();
|
||||||
|
|
||||||
|
let trait_generics = state
|
||||||
|
.trait_generics()
|
||||||
|
.chain(tr.body_type())
|
||||||
|
.chain(via_marker_type)
|
||||||
|
.collect::<Punctuated<Type, Token![,]>>();
|
||||||
|
|
||||||
|
let ident_generics = generic_ident
|
||||||
|
.is_some()
|
||||||
|
.then(|| quote! { <T> })
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
let rejection_bound = rejection.as_ref().map(|rejection| {
|
let rejection_bound = rejection.as_ref().map(|rejection| {
|
||||||
match (tr, generic_ident.is_some()) {
|
match (tr, generic_ident.is_some()) {
|
||||||
(Trait::FromRequest, true) => {
|
(Trait::FromRequest, true) => {
|
||||||
quote! {
|
quote! {
|
||||||
#rejection: ::std::convert::From<<#path<T> as ::axum::extract::FromRequest<S, B>>::Rejection>,
|
#rejection: ::std::convert::From<<#via_path<T> as ::axum::extract::FromRequest<#trait_generics>>::Rejection>,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
(Trait::FromRequest, false) => {
|
(Trait::FromRequest, false) => {
|
||||||
quote! {
|
quote! {
|
||||||
#rejection: ::std::convert::From<<#path<Self> as ::axum::extract::FromRequest<S, B>>::Rejection>,
|
#rejection: ::std::convert::From<<#via_path<Self> as ::axum::extract::FromRequest<#trait_generics>>::Rejection>,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
(Trait::FromRequestParts, true) => {
|
(Trait::FromRequestParts, true) => {
|
||||||
quote! {
|
quote! {
|
||||||
#rejection: ::std::convert::From<<#path<T> as ::axum::extract::FromRequestParts<S>>::Rejection>,
|
#rejection: ::std::convert::From<<#via_path<T> as ::axum::extract::FromRequestParts<#trait_generics>>::Rejection>,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
(Trait::FromRequestParts, false) => {
|
(Trait::FromRequestParts, false) => {
|
||||||
quote! {
|
quote! {
|
||||||
#rejection: ::std::convert::From<<#path<Self> as ::axum::extract::FromRequestParts<S>>::Rejection>,
|
#rejection: ::std::convert::From<<#via_path<Self> as ::axum::extract::FromRequestParts<#trait_generics>>::Rejection>,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}).unwrap_or_default();
|
}).unwrap_or_default();
|
||||||
|
|
||||||
let impl_generics = match (tr, generic_ident.is_some()) {
|
|
||||||
(Trait::FromRequest, true) => quote! { S, B, T },
|
|
||||||
(Trait::FromRequest, false) => quote! { S, B },
|
|
||||||
(Trait::FromRequestParts, true) => quote! { S, T },
|
|
||||||
(Trait::FromRequestParts, false) => quote! { S },
|
|
||||||
};
|
|
||||||
|
|
||||||
let type_generics = generic_ident
|
|
||||||
.is_some()
|
|
||||||
.then(|| quote! { <T> })
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
let via_type_generics = if generic_ident.is_some() {
|
let via_type_generics = if generic_ident.is_some() {
|
||||||
quote! { T }
|
quote! { T }
|
||||||
} else {
|
} else {
|
||||||
@ -635,27 +772,29 @@ fn impl_struct_by_extracting_all_at_once(
|
|||||||
quote! { value }
|
quote! { value }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let state_bounds = state.bounds();
|
||||||
|
|
||||||
let tokens = match tr {
|
let tokens = match tr {
|
||||||
Trait::FromRequest => {
|
Trait::FromRequest => {
|
||||||
quote_spanned! {path_span=>
|
quote_spanned! {path_span=>
|
||||||
#[::axum::async_trait]
|
#[::axum::async_trait]
|
||||||
#[automatically_derived]
|
#[automatically_derived]
|
||||||
impl<#impl_generics> ::axum::extract::FromRequest<S, B> for #ident #type_generics
|
impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident #ident_generics
|
||||||
where
|
where
|
||||||
#path<#via_type_generics>: ::axum::extract::FromRequest<S, B>,
|
#via_path<#via_type_generics>: ::axum::extract::FromRequest<#trait_generics>,
|
||||||
#rejection_bound
|
#rejection_bound
|
||||||
B: ::std::marker::Send + 'static,
|
B: ::std::marker::Send + 'static,
|
||||||
S: ::std::marker::Send + ::std::marker::Sync,
|
#state_bounds
|
||||||
{
|
{
|
||||||
type Rejection = #associated_rejection_type;
|
type Rejection = #associated_rejection_type;
|
||||||
|
|
||||||
async fn from_request(
|
async fn from_request(
|
||||||
req: ::axum::http::Request<B>,
|
req: ::axum::http::Request<B>,
|
||||||
state: &S
|
state: &#state,
|
||||||
) -> ::std::result::Result<Self, Self::Rejection> {
|
) -> ::std::result::Result<Self, Self::Rejection> {
|
||||||
::axum::extract::FromRequest::from_request(req, state)
|
::axum::extract::FromRequest::from_request(req, state)
|
||||||
.await
|
.await
|
||||||
.map(|#path(value)| #value_to_self)
|
.map(|#via_path(value)| #value_to_self)
|
||||||
.map_err(#map_err)
|
.map_err(#map_err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -665,21 +804,21 @@ fn impl_struct_by_extracting_all_at_once(
|
|||||||
quote_spanned! {path_span=>
|
quote_spanned! {path_span=>
|
||||||
#[::axum::async_trait]
|
#[::axum::async_trait]
|
||||||
#[automatically_derived]
|
#[automatically_derived]
|
||||||
impl<#impl_generics> ::axum::extract::FromRequestParts<S> for #ident #type_generics
|
impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident #ident_generics
|
||||||
where
|
where
|
||||||
#path<#via_type_generics>: ::axum::extract::FromRequestParts<S>,
|
#via_path<#via_type_generics>: ::axum::extract::FromRequestParts<#trait_generics>,
|
||||||
#rejection_bound
|
#rejection_bound
|
||||||
S: ::std::marker::Send + ::std::marker::Sync,
|
#state_bounds
|
||||||
{
|
{
|
||||||
type Rejection = #associated_rejection_type;
|
type Rejection = #associated_rejection_type;
|
||||||
|
|
||||||
async fn from_request_parts(
|
async fn from_request_parts(
|
||||||
parts: &mut ::axum::http::request::Parts,
|
parts: &mut ::axum::http::request::Parts,
|
||||||
state: &S
|
state: &#state,
|
||||||
) -> ::std::result::Result<Self, Self::Rejection> {
|
) -> ::std::result::Result<Self, Self::Rejection> {
|
||||||
::axum::extract::FromRequestParts::from_request_parts(parts, state)
|
::axum::extract::FromRequestParts::from_request_parts(parts, state)
|
||||||
.await
|
.await
|
||||||
.map(|#path(value)| #value_to_self)
|
.map(|#via_path(value)| #value_to_self)
|
||||||
.map_err(#map_err)
|
.map_err(#map_err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -695,6 +834,7 @@ fn impl_enum_by_extracting_all_at_once(
|
|||||||
variants: Punctuated<syn::Variant, Token![,]>,
|
variants: Punctuated<syn::Variant, Token![,]>,
|
||||||
path: syn::Path,
|
path: syn::Path,
|
||||||
rejection: Option<syn::Path>,
|
rejection: Option<syn::Path>,
|
||||||
|
state: State,
|
||||||
tr: Trait,
|
tr: Trait,
|
||||||
) -> syn::Result<TokenStream> {
|
) -> syn::Result<TokenStream> {
|
||||||
for variant in variants {
|
for variant in variants {
|
||||||
@ -738,23 +878,35 @@ fn impl_enum_by_extracting_all_at_once(
|
|||||||
|
|
||||||
let path_span = path.span();
|
let path_span = path.span();
|
||||||
|
|
||||||
|
let impl_generics = tr
|
||||||
|
.body_type()
|
||||||
|
.chain(state.impl_generics())
|
||||||
|
.collect::<Punctuated<Type, Token![,]>>();
|
||||||
|
|
||||||
|
let trait_generics = state
|
||||||
|
.trait_generics()
|
||||||
|
.chain(tr.body_type())
|
||||||
|
.collect::<Punctuated<Type, Token![,]>>();
|
||||||
|
|
||||||
|
let state_bounds = state.bounds();
|
||||||
|
|
||||||
let tokens = match tr {
|
let tokens = match tr {
|
||||||
Trait::FromRequest => {
|
Trait::FromRequest => {
|
||||||
quote_spanned! {path_span=>
|
quote_spanned! {path_span=>
|
||||||
#[::axum::async_trait]
|
#[::axum::async_trait]
|
||||||
#[automatically_derived]
|
#[automatically_derived]
|
||||||
impl<S, B> ::axum::extract::FromRequest<S, B> for #ident
|
impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident
|
||||||
where
|
where
|
||||||
B: ::axum::body::HttpBody + ::std::marker::Send + 'static,
|
B: ::axum::body::HttpBody + ::std::marker::Send + 'static,
|
||||||
B::Data: ::std::marker::Send,
|
B::Data: ::std::marker::Send,
|
||||||
B::Error: ::std::convert::Into<::axum::BoxError>,
|
B::Error: ::std::convert::Into<::axum::BoxError>,
|
||||||
S: ::std::marker::Send + ::std::marker::Sync,
|
#state_bounds
|
||||||
{
|
{
|
||||||
type Rejection = #associated_rejection_type;
|
type Rejection = #associated_rejection_type;
|
||||||
|
|
||||||
async fn from_request(
|
async fn from_request(
|
||||||
req: ::axum::http::Request<B>,
|
req: ::axum::http::Request<B>,
|
||||||
state: &S
|
state: &#state,
|
||||||
) -> ::std::result::Result<Self, Self::Rejection> {
|
) -> ::std::result::Result<Self, Self::Rejection> {
|
||||||
::axum::extract::FromRequest::from_request(req, state)
|
::axum::extract::FromRequest::from_request(req, state)
|
||||||
.await
|
.await
|
||||||
@ -768,15 +920,15 @@ fn impl_enum_by_extracting_all_at_once(
|
|||||||
quote_spanned! {path_span=>
|
quote_spanned! {path_span=>
|
||||||
#[::axum::async_trait]
|
#[::axum::async_trait]
|
||||||
#[automatically_derived]
|
#[automatically_derived]
|
||||||
impl<S> ::axum::extract::FromRequestParts<S> for #ident
|
impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident
|
||||||
where
|
where
|
||||||
S: ::std::marker::Send + ::std::marker::Sync,
|
#state_bounds
|
||||||
{
|
{
|
||||||
type Rejection = #associated_rejection_type;
|
type Rejection = #associated_rejection_type;
|
||||||
|
|
||||||
async fn from_request_parts(
|
async fn from_request_parts(
|
||||||
parts: &mut ::axum::http::request::Parts,
|
parts: &mut ::axum::http::request::Parts,
|
||||||
state: &S
|
state: &#state,
|
||||||
) -> ::std::result::Result<Self, Self::Rejection> {
|
) -> ::std::result::Result<Self, Self::Rejection> {
|
||||||
::axum::extract::FromRequestParts::from_request_parts(parts, state)
|
::axum::extract::FromRequestParts::from_request_parts(parts, state)
|
||||||
.await
|
.await
|
||||||
@ -791,6 +943,90 @@ fn impl_enum_by_extracting_all_at_once(
|
|||||||
Ok(tokens)
|
Ok(tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// For a struct like
|
||||||
|
///
|
||||||
|
/// ```skip
|
||||||
|
/// struct Extractor {
|
||||||
|
/// state: State<AppState>,
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// We can infer the state type to be `AppState` because it appears inside a `State`
|
||||||
|
fn infer_state_type_from_field_types(fields: &Fields) -> Option<Type> {
|
||||||
|
match fields {
|
||||||
|
Fields::Named(fields_named) => {
|
||||||
|
crate::infer_state_type(fields_named.named.iter().map(|field| &field.ty))
|
||||||
|
}
|
||||||
|
Fields::Unnamed(fields_unnamed) => {
|
||||||
|
crate::infer_state_type(fields_unnamed.unnamed.iter().map(|field| &field.ty))
|
||||||
|
}
|
||||||
|
Fields::Unit => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// For a struct like
|
||||||
|
///
|
||||||
|
/// ```skip
|
||||||
|
/// struct Extractor {
|
||||||
|
/// #[from_request(via(State))]
|
||||||
|
/// state: AppState,
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// We can infer the state type to be `AppState` because it has `via(State)` and thus can be
|
||||||
|
/// extracted with `State<AppState>`
|
||||||
|
fn infer_state_type_from_field_attributes(fields: &Fields) -> Option<Type> {
|
||||||
|
let state_inputs = match fields {
|
||||||
|
Fields::Named(fields_named) => {
|
||||||
|
fields_named
|
||||||
|
.named
|
||||||
|
.iter()
|
||||||
|
.filter_map(|field| {
|
||||||
|
// TODO(david): its a little wasteful to parse the attributes again here
|
||||||
|
// ideally we should parse things once and pass the data down
|
||||||
|
let FromRequestFieldAttrs { via } =
|
||||||
|
parse_attrs("from_request", &field.attrs).ok()?;
|
||||||
|
let (_, via_path) = via?;
|
||||||
|
path_ident_is_state(&via_path).then(|| &field.ty)
|
||||||
|
})
|
||||||
|
.collect::<HashSet<_>>()
|
||||||
|
}
|
||||||
|
Fields::Unnamed(fields_unnamed) => {
|
||||||
|
fields_unnamed
|
||||||
|
.unnamed
|
||||||
|
.iter()
|
||||||
|
.filter_map(|field| {
|
||||||
|
// TODO(david): its a little wasteful to parse the attributes again here
|
||||||
|
// ideally we should parse things once and pass the data down
|
||||||
|
let FromRequestFieldAttrs { via } =
|
||||||
|
parse_attrs("from_request", &field.attrs).ok()?;
|
||||||
|
let (_, via_path) = via?;
|
||||||
|
path_ident_is_state(&via_path).then(|| &field.ty)
|
||||||
|
})
|
||||||
|
.collect::<HashSet<_>>()
|
||||||
|
}
|
||||||
|
Fields::Unit => return None,
|
||||||
|
};
|
||||||
|
|
||||||
|
if state_inputs.len() == 1 {
|
||||||
|
state_inputs.iter().next().map(|&ty| ty.clone())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn path_ident_is_state(path: &Path) -> bool {
|
||||||
|
if let Some(last_segment) = path.segments.last() {
|
||||||
|
last_segment.ident == "State"
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn state_from_via(ident: &Ident, via: &Path) -> Option<Type> {
|
||||||
|
path_ident_is_state(via).then(|| parse_quote!(#ident))
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn ui() {
|
fn ui() {
|
||||||
crate::run_ui_tests("from_request");
|
crate::run_ui_tests("from_request");
|
||||||
|
@ -7,18 +7,21 @@ use syn::{
|
|||||||
pub(crate) mod kw {
|
pub(crate) mod kw {
|
||||||
syn::custom_keyword!(via);
|
syn::custom_keyword!(via);
|
||||||
syn::custom_keyword!(rejection);
|
syn::custom_keyword!(rejection);
|
||||||
|
syn::custom_keyword!(state);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub(super) struct FromRequestContainerAttrs {
|
pub(super) struct FromRequestContainerAttrs {
|
||||||
pub(super) via: Option<(kw::via, syn::Path)>,
|
pub(super) via: Option<(kw::via, syn::Path)>,
|
||||||
pub(super) rejection: Option<(kw::rejection, syn::Path)>,
|
pub(super) rejection: Option<(kw::rejection, syn::Path)>,
|
||||||
|
pub(super) state: Option<(kw::state, syn::Type)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Parse for FromRequestContainerAttrs {
|
impl Parse for FromRequestContainerAttrs {
|
||||||
fn parse(input: ParseStream) -> syn::Result<Self> {
|
fn parse(input: ParseStream) -> syn::Result<Self> {
|
||||||
let mut via = None;
|
let mut via = None;
|
||||||
let mut rejection = None;
|
let mut rejection = None;
|
||||||
|
let mut state = None;
|
||||||
|
|
||||||
while !input.is_empty() {
|
while !input.is_empty() {
|
||||||
let lh = input.lookahead1();
|
let lh = input.lookahead1();
|
||||||
@ -26,6 +29,8 @@ impl Parse for FromRequestContainerAttrs {
|
|||||||
parse_parenthesized_attribute(input, &mut via)?;
|
parse_parenthesized_attribute(input, &mut via)?;
|
||||||
} else if lh.peek(kw::rejection) {
|
} else if lh.peek(kw::rejection) {
|
||||||
parse_parenthesized_attribute(input, &mut rejection)?;
|
parse_parenthesized_attribute(input, &mut rejection)?;
|
||||||
|
} else if lh.peek(kw::state) {
|
||||||
|
parse_parenthesized_attribute(input, &mut state)?;
|
||||||
} else {
|
} else {
|
||||||
return Err(lh.error());
|
return Err(lh.error());
|
||||||
}
|
}
|
||||||
@ -33,15 +38,24 @@ impl Parse for FromRequestContainerAttrs {
|
|||||||
let _ = input.parse::<Token![,]>();
|
let _ = input.parse::<Token![,]>();
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Self { via, rejection })
|
Ok(Self {
|
||||||
|
via,
|
||||||
|
rejection,
|
||||||
|
state,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Combine for FromRequestContainerAttrs {
|
impl Combine for FromRequestContainerAttrs {
|
||||||
fn combine(mut self, other: Self) -> syn::Result<Self> {
|
fn combine(mut self, other: Self) -> syn::Result<Self> {
|
||||||
let Self { via, rejection } = other;
|
let Self {
|
||||||
|
via,
|
||||||
|
rejection,
|
||||||
|
state,
|
||||||
|
} = other;
|
||||||
combine_attribute(&mut self.via, via)?;
|
combine_attribute(&mut self.via, via)?;
|
||||||
combine_attribute(&mut self.rejection, rejection)?;
|
combine_attribute(&mut self.rejection, rejection)?;
|
||||||
|
combine_attribute(&mut self.state, state)?;
|
||||||
Ok(self)
|
Ok(self)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -43,9 +43,11 @@
|
|||||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||||
#![cfg_attr(test, allow(clippy::float_cmp))]
|
#![cfg_attr(test, allow(clippy::float_cmp))]
|
||||||
|
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
use proc_macro::TokenStream;
|
use proc_macro::TokenStream;
|
||||||
use quote::{quote, ToTokens};
|
use quote::{quote, ToTokens};
|
||||||
use syn::parse::Parse;
|
use syn::{parse::Parse, Type};
|
||||||
|
|
||||||
mod attr_parsing;
|
mod attr_parsing;
|
||||||
mod debug_handler;
|
mod debug_handler;
|
||||||
@ -613,6 +615,50 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn infer_state_type<'a, I>(types: I) -> Option<Type>
|
||||||
|
where
|
||||||
|
I: Iterator<Item = &'a Type>,
|
||||||
|
{
|
||||||
|
let state_inputs = types
|
||||||
|
.filter_map(|ty| {
|
||||||
|
if let Type::Path(path) = ty {
|
||||||
|
Some(&path.path)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.filter_map(|path| {
|
||||||
|
if let Some(last_segment) = path.segments.last() {
|
||||||
|
if last_segment.ident != "State" {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
match &last_segment.arguments {
|
||||||
|
syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => {
|
||||||
|
Some(args.args.first().unwrap())
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.filter_map(|generic_arg| {
|
||||||
|
if let syn::GenericArgument::Type(ty) = generic_arg {
|
||||||
|
Some(ty)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<HashSet<_>>();
|
||||||
|
|
||||||
|
if state_inputs.len() == 1 {
|
||||||
|
state_inputs.iter().next().map(|&ty| ty.clone())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
fn run_ui_tests(directory: &str) {
|
fn run_ui_tests(directory: &str) {
|
||||||
#[rustversion::stable]
|
#[rustversion::stable]
|
||||||
|
@ -0,0 +1,22 @@
|
|||||||
|
use axum_macros::FromRequest;
|
||||||
|
use axum::extract::State;
|
||||||
|
|
||||||
|
#[derive(FromRequest)]
|
||||||
|
struct Extractor {
|
||||||
|
inner_state: State<AppState>,
|
||||||
|
other_state: State<OtherState>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct AppState {}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct OtherState {}
|
||||||
|
|
||||||
|
fn assert_from_request()
|
||||||
|
where
|
||||||
|
Extractor: axum::extract::FromRequest<AppState, axum::body::Body, Rejection = axum::response::Response>,
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
@ -0,0 +1,23 @@
|
|||||||
|
error[E0277]: the trait bound `AppState: FromRef<S>` is not satisfied
|
||||||
|
--> tests/from_request/fail/state_infer_multiple_different_types.rs:6:18
|
||||||
|
|
|
||||||
|
6 | inner_state: State<AppState>,
|
||||||
|
| ^^^^^ the trait `FromRef<S>` is not implemented for `AppState`
|
||||||
|
|
|
||||||
|
= note: required because of the requirements on the impl of `FromRequestParts<S>` for `State<AppState>`
|
||||||
|
help: consider extending the `where` clause, but there might be an alternative better way to express this requirement
|
||||||
|
|
|
||||||
|
4 | #[derive(FromRequest, AppState: FromRef<S>)]
|
||||||
|
| ++++++++++++++++++++++
|
||||||
|
|
||||||
|
error[E0277]: the trait bound `OtherState: FromRef<S>` is not satisfied
|
||||||
|
--> tests/from_request/fail/state_infer_multiple_different_types.rs:7:18
|
||||||
|
|
|
||||||
|
7 | other_state: State<OtherState>,
|
||||||
|
| ^^^^^ the trait `FromRef<S>` is not implemented for `OtherState`
|
||||||
|
|
|
||||||
|
= note: required because of the requirements on the impl of `FromRequestParts<S>` for `State<OtherState>`
|
||||||
|
help: consider extending the `where` clause, but there might be an alternative better way to express this requirement
|
||||||
|
|
|
||||||
|
4 | #[derive(FromRequest, OtherState: FromRef<S>)]
|
||||||
|
| ++++++++++++++++++++++++
|
@ -1,4 +1,4 @@
|
|||||||
error: expected `via` or `rejection`
|
error: expected one of: `via`, `rejection`, `state`
|
||||||
--> tests/from_request/fail/unknown_attr_container.rs:4:16
|
--> tests/from_request/fail/unknown_attr_container.rs:4:16
|
||||||
|
|
|
|
||||||
4 | #[from_request(foo)]
|
4 | #[from_request(foo)]
|
||||||
|
27
axum-macros/tests/from_request/pass/state_cookie.rs
Normal file
27
axum-macros/tests/from_request/pass/state_cookie.rs
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
use axum_macros::FromRequest;
|
||||||
|
use axum::extract::FromRef;
|
||||||
|
use axum_extra::extract::cookie::{PrivateCookieJar, Key};
|
||||||
|
|
||||||
|
#[derive(FromRequest)]
|
||||||
|
#[from_request(state(AppState))]
|
||||||
|
struct Extractor {
|
||||||
|
cookies: PrivateCookieJar,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct AppState {
|
||||||
|
key: Key,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<AppState> for Key {
|
||||||
|
fn from_ref(input: &AppState) -> Self {
|
||||||
|
input.key.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn assert_from_request()
|
||||||
|
where
|
||||||
|
Extractor: axum::extract::FromRequest<AppState, axum::body::Body, Rejection = axum::response::Response>,
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
34
axum-macros/tests/from_request/pass/state_enum_via.rs
Normal file
34
axum-macros/tests/from_request/pass/state_enum_via.rs
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::{State, FromRef},
|
||||||
|
routing::get,
|
||||||
|
Router,
|
||||||
|
};
|
||||||
|
use axum_macros::FromRequest;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let _: Router<AppState> = Router::with_state(AppState::default())
|
||||||
|
.route("/a", get(|_: AppState| async {}))
|
||||||
|
.route("/b", get(|_: InnerState| async {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, FromRequest)]
|
||||||
|
#[from_request(via(State))]
|
||||||
|
enum AppState {
|
||||||
|
One,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AppState {
|
||||||
|
fn default() -> AppState {
|
||||||
|
Self::One
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(FromRequest)]
|
||||||
|
#[from_request(via(State), state(AppState))]
|
||||||
|
enum InnerState {}
|
||||||
|
|
||||||
|
impl FromRef<AppState> for InnerState {
|
||||||
|
fn from_ref(_: &AppState) -> Self {
|
||||||
|
todo!(":shrug:")
|
||||||
|
}
|
||||||
|
}
|
35
axum-macros/tests/from_request/pass/state_enum_via_parts.rs
Normal file
35
axum-macros/tests/from_request/pass/state_enum_via_parts.rs
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::{State, FromRef},
|
||||||
|
routing::get,
|
||||||
|
Router,
|
||||||
|
};
|
||||||
|
use axum_macros::FromRequestParts;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let _: Router<AppState> = Router::with_state(AppState::default())
|
||||||
|
.route("/a", get(|_: AppState| async {}))
|
||||||
|
.route("/b", get(|_: InnerState| async {}))
|
||||||
|
.route("/c", get(|_: AppState, _: InnerState| async {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, FromRequestParts)]
|
||||||
|
#[from_request(via(State))]
|
||||||
|
enum AppState {
|
||||||
|
One,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AppState {
|
||||||
|
fn default() -> AppState {
|
||||||
|
Self::One
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(FromRequestParts)]
|
||||||
|
#[from_request(via(State), state(AppState))]
|
||||||
|
enum InnerState {}
|
||||||
|
|
||||||
|
impl FromRef<AppState> for InnerState {
|
||||||
|
fn from_ref(_: &AppState) -> Self {
|
||||||
|
todo!(":shrug:")
|
||||||
|
}
|
||||||
|
}
|
44
axum-macros/tests/from_request/pass/state_explicit.rs
Normal file
44
axum-macros/tests/from_request/pass/state_explicit.rs
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
use axum_macros::FromRequest;
|
||||||
|
use axum::{
|
||||||
|
extract::{FromRef, State},
|
||||||
|
Router,
|
||||||
|
routing::get,
|
||||||
|
};
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let _: Router<AppState> = Router::with_state(AppState::default())
|
||||||
|
.route("/b", get(|_: Extractor| async {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(FromRequest)]
|
||||||
|
#[from_request(state(AppState))]
|
||||||
|
struct Extractor {
|
||||||
|
app_state: State<AppState>,
|
||||||
|
one: State<One>,
|
||||||
|
two: State<Two>,
|
||||||
|
other_extractor: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
struct AppState {
|
||||||
|
one: One,
|
||||||
|
two: Two,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
struct One {}
|
||||||
|
|
||||||
|
impl FromRef<AppState> for One {
|
||||||
|
fn from_ref(input: &AppState) -> Self {
|
||||||
|
input.one.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
struct Two {}
|
||||||
|
|
||||||
|
impl FromRef<AppState> for Two {
|
||||||
|
fn from_ref(input: &AppState) -> Self {
|
||||||
|
input.two.clone()
|
||||||
|
}
|
||||||
|
}
|
33
axum-macros/tests/from_request/pass/state_explicit_parts.rs
Normal file
33
axum-macros/tests/from_request/pass/state_explicit_parts.rs
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
use axum_macros::FromRequestParts;
|
||||||
|
use axum::{
|
||||||
|
extract::{FromRef, State, Query},
|
||||||
|
Router,
|
||||||
|
routing::get,
|
||||||
|
};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let _: Router<AppState> = Router::with_state(AppState::default())
|
||||||
|
.route("/b", get(|_: Extractor| async {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(FromRequestParts)]
|
||||||
|
#[from_request(state(AppState))]
|
||||||
|
struct Extractor {
|
||||||
|
inner_state: State<InnerState>,
|
||||||
|
other: Query<HashMap<String, String>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
struct AppState {
|
||||||
|
inner: InnerState,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
struct InnerState {}
|
||||||
|
|
||||||
|
impl FromRef<AppState> for InnerState {
|
||||||
|
fn from_ref(input: &AppState) -> Self {
|
||||||
|
input.inner.clone()
|
||||||
|
}
|
||||||
|
}
|
34
axum-macros/tests/from_request/pass/state_field_explicit.rs
Normal file
34
axum-macros/tests/from_request/pass/state_field_explicit.rs
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::{State, FromRef},
|
||||||
|
routing::get,
|
||||||
|
Router,
|
||||||
|
};
|
||||||
|
use axum_macros::FromRequest;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let _: Router<AppState> = Router::with_state(AppState::default())
|
||||||
|
.route("/", get(|_: Extractor| async {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(FromRequest)]
|
||||||
|
#[from_request(state(AppState))]
|
||||||
|
struct Extractor {
|
||||||
|
#[from_request(via(State))]
|
||||||
|
state: AppState,
|
||||||
|
#[from_request(via(State))]
|
||||||
|
inner: InnerState,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
struct AppState {
|
||||||
|
inner: InnerState,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
struct InnerState {}
|
||||||
|
|
||||||
|
impl FromRef<AppState> for InnerState {
|
||||||
|
fn from_ref(input: &AppState) -> Self {
|
||||||
|
input.inner.clone()
|
||||||
|
}
|
||||||
|
}
|
20
axum-macros/tests/from_request/pass/state_field_infer.rs
Normal file
20
axum-macros/tests/from_request/pass/state_field_infer.rs
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::State,
|
||||||
|
routing::get,
|
||||||
|
Router,
|
||||||
|
};
|
||||||
|
use axum_macros::FromRequest;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let _: Router<AppState> = Router::with_state(AppState::default())
|
||||||
|
.route("/", get(|_: Extractor| async {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(FromRequest)]
|
||||||
|
struct Extractor {
|
||||||
|
#[from_request(via(State))]
|
||||||
|
state: AppState,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
struct AppState {}
|
18
axum-macros/tests/from_request/pass/state_infer.rs
Normal file
18
axum-macros/tests/from_request/pass/state_infer.rs
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
use axum_macros::FromRequest;
|
||||||
|
use axum::extract::State;
|
||||||
|
|
||||||
|
#[derive(FromRequest)]
|
||||||
|
struct Extractor {
|
||||||
|
inner_state: State<AppState>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct AppState {}
|
||||||
|
|
||||||
|
fn assert_from_request()
|
||||||
|
where
|
||||||
|
Extractor: axum::extract::FromRequest<AppState, axum::body::Body, Rejection = axum::response::Response>,
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
19
axum-macros/tests/from_request/pass/state_infer_multiple.rs
Normal file
19
axum-macros/tests/from_request/pass/state_infer_multiple.rs
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
use axum_macros::FromRequest;
|
||||||
|
use axum::extract::State;
|
||||||
|
|
||||||
|
#[derive(FromRequest)]
|
||||||
|
struct Extractor {
|
||||||
|
inner_state: State<AppState>,
|
||||||
|
also_inner_state: State<AppState>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct AppState {}
|
||||||
|
|
||||||
|
fn assert_from_request()
|
||||||
|
where
|
||||||
|
Extractor: axum::extract::FromRequest<AppState, axum::body::Body, Rejection = axum::response::Response>,
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
18
axum-macros/tests/from_request/pass/state_infer_parts.rs
Normal file
18
axum-macros/tests/from_request/pass/state_infer_parts.rs
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
use axum_macros::FromRequestParts;
|
||||||
|
use axum::extract::State;
|
||||||
|
|
||||||
|
#[derive(FromRequestParts)]
|
||||||
|
struct Extractor {
|
||||||
|
inner_state: State<AppState>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct AppState {}
|
||||||
|
|
||||||
|
fn assert_from_request()
|
||||||
|
where
|
||||||
|
Extractor: axum::extract::FromRequestParts<AppState, Rejection = axum::response::Response>,
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
28
axum-macros/tests/from_request/pass/state_via.rs
Normal file
28
axum-macros/tests/from_request/pass/state_via.rs
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::{FromRef, State},
|
||||||
|
routing::get,
|
||||||
|
Router,
|
||||||
|
};
|
||||||
|
use axum_macros::FromRequest;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let _: Router<AppState> = Router::with_state(AppState::default())
|
||||||
|
.route("/b", get(|_: (), _: AppState| async {}))
|
||||||
|
.route("/c", get(|_: (), _: InnerState| async {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default, FromRequest)]
|
||||||
|
#[from_request(via(State), state(AppState))]
|
||||||
|
struct AppState {
|
||||||
|
inner: InnerState,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default, FromRequest)]
|
||||||
|
#[from_request(via(State), state(AppState))]
|
||||||
|
struct InnerState {}
|
||||||
|
|
||||||
|
impl FromRef<AppState> for InnerState {
|
||||||
|
fn from_ref(input: &AppState) -> Self {
|
||||||
|
input.inner.clone()
|
||||||
|
}
|
||||||
|
}
|
17
axum-macros/tests/from_request/pass/state_via_infer.rs
Normal file
17
axum-macros/tests/from_request/pass/state_via_infer.rs
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::State,
|
||||||
|
routing::get,
|
||||||
|
Router,
|
||||||
|
};
|
||||||
|
use axum_macros::FromRequest;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let _: Router<AppState> = Router::with_state(AppState::default())
|
||||||
|
.route("/b", get(|_: AppState| async {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we're extract "via" `State<AppState>` and not specifying state
|
||||||
|
// assume `AppState` is the state
|
||||||
|
#[derive(Clone, Default, FromRequest)]
|
||||||
|
#[from_request(via(State))]
|
||||||
|
struct AppState {}
|
29
axum-macros/tests/from_request/pass/state_via_parts.rs
Normal file
29
axum-macros/tests/from_request/pass/state_via_parts.rs
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::{FromRef, State},
|
||||||
|
routing::get,
|
||||||
|
Router,
|
||||||
|
};
|
||||||
|
use axum_macros::FromRequestParts;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let _: Router<AppState> = Router::with_state(AppState::default())
|
||||||
|
.route("/a", get(|_: AppState, _: InnerState, _: String| async {}))
|
||||||
|
.route("/b", get(|_: AppState, _: String| async {}))
|
||||||
|
.route("/c", get(|_: InnerState, _: String| async {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default, FromRequestParts)]
|
||||||
|
#[from_request(via(State))]
|
||||||
|
struct AppState {
|
||||||
|
inner: InnerState,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default, FromRequestParts)]
|
||||||
|
#[from_request(via(State), state(AppState))]
|
||||||
|
struct InnerState {}
|
||||||
|
|
||||||
|
impl FromRef<AppState> for InnerState {
|
||||||
|
fn from_ref(input: &AppState) -> Self {
|
||||||
|
input.inner.clone()
|
||||||
|
}
|
||||||
|
}
|
36
axum-macros/tests/from_request/pass/state_with_rejection.rs
Normal file
36
axum-macros/tests/from_request/pass/state_with_rejection.rs
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
use std::convert::Infallible;
|
||||||
|
use axum::{
|
||||||
|
extract::State,
|
||||||
|
response::{IntoResponse, Response},
|
||||||
|
routing::get,
|
||||||
|
Router,
|
||||||
|
};
|
||||||
|
use axum_macros::FromRequest;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let _: Router<AppState> =
|
||||||
|
Router::with_state(AppState::default()).route("/a", get(|_: Extractor| async {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default, FromRequest)]
|
||||||
|
#[from_request(rejection(MyRejection))]
|
||||||
|
struct Extractor {
|
||||||
|
state: State<AppState>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
struct AppState {}
|
||||||
|
|
||||||
|
struct MyRejection {}
|
||||||
|
|
||||||
|
impl From<Infallible> for MyRejection {
|
||||||
|
fn from(err: Infallible) -> Self {
|
||||||
|
match err {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for MyRejection {
|
||||||
|
fn into_response(self) -> Response {
|
||||||
|
().into_response()
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user