diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index ade7bf03..9c867321 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -8,7 +8,7 @@ use self::rejection::*; use crate::response::IntoResponse; use async_trait::async_trait; use http::{Extensions, HeaderMap, Method, Request, Uri, Version}; -use std::convert::Infallible; +use std::{convert::Infallible, sync::Arc}; pub mod rejection; @@ -49,7 +49,7 @@ pub use self::from_ref::FromRef; /// where /// // these bounds are required by `async_trait` /// B: Send, -/// S: Send, +/// S: Send + Sync, /// { /// type Rejection = http::StatusCode; /// @@ -79,7 +79,7 @@ pub trait FromRequest: Sized { /// Has several convenience methods for getting owned parts of the request. #[derive(Debug)] pub struct RequestParts { - state: S, + pub(crate) state: Arc, method: Method, uri: Uri, version: Version, @@ -110,6 +110,17 @@ impl RequestParts { /// /// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html pub fn with_state(state: S, req: Request) -> Self { + Self::with_state_arc(Arc::new(state), req) + } + + /// Create a new `RequestParts` with the given [`Arc`]'ed state. + /// + /// You generally shouldn't need to construct this type yourself, unless + /// using extractors outside of axum for example to implement a + /// [`tower::Service`]. + /// + /// [`tower::Service`]: https://docs.rs/tower/lastest/tower/trait.Service.html + pub fn with_state_arc(state: Arc, req: Request) -> Self { let ( http::request::Parts { method, @@ -153,7 +164,7 @@ impl RequestParts { /// impl FromRequest for MyExtractor /// where /// B: Send, - /// S: Send, + /// S: Send + Sync, /// { /// type Rejection = Infallible; /// @@ -285,7 +296,7 @@ impl FromRequest for Option where T: FromRequest, B: Send, - S: Send, + S: Send + Sync, { type Rejection = Infallible; @@ -299,7 +310,7 @@ impl FromRequest for Result where T: FromRequest, B: Send, - S: Send, + S: Send + Sync, { type Rejection = Infallible; diff --git a/axum-core/src/extract/request_parts.rs b/axum-core/src/extract/request_parts.rs index 4faaf2d3..ed3ff202 100644 --- a/axum-core/src/extract/request_parts.rs +++ b/axum-core/src/extract/request_parts.rs @@ -3,13 +3,13 @@ use crate::BoxError; use async_trait::async_trait; use bytes::Bytes; use http::{Extensions, HeaderMap, Method, Request, Uri, Version}; -use std::convert::Infallible; +use std::{convert::Infallible, sync::Arc}; #[async_trait] impl FromRequest for Request where B: Send, - S: Clone + Send, + S: Send + Sync, { type Rejection = BodyAlreadyExtracted; @@ -17,7 +17,7 @@ where let req = std::mem::replace( req, RequestParts { - state: req.state().clone(), + state: Arc::clone(&req.state), method: req.method.clone(), version: req.version, uri: req.uri.clone(), @@ -35,7 +35,7 @@ where impl FromRequest for Method where B: Send, - S: Send, + S: Send + Sync, { type Rejection = Infallible; @@ -48,7 +48,7 @@ where impl FromRequest for Uri where B: Send, - S: Send, + S: Send + Sync, { type Rejection = Infallible; @@ -61,7 +61,7 @@ where impl FromRequest for Version where B: Send, - S: Send, + S: Send + Sync, { type Rejection = Infallible; @@ -79,7 +79,7 @@ where impl FromRequest for HeaderMap where B: Send, - S: Send, + S: Send + Sync, { type Rejection = Infallible; @@ -94,7 +94,7 @@ where B: http_body::Body + Send, B::Data: Send, B::Error: Into, - S: Send, + S: Send + Sync, { type Rejection = BytesRejection; @@ -115,7 +115,7 @@ where B: http_body::Body + Send, B::Data: Send, B::Error: Into, - S: Send, + S: Send + Sync, { type Rejection = StringRejection; @@ -137,7 +137,7 @@ where impl FromRequest for http::request::Parts where B: Send, - S: Send, + S: Send + Sync, { type Rejection = Infallible; diff --git a/axum-core/src/extract/tuple.rs b/axum-core/src/extract/tuple.rs index 05e38bf0..3a5938e2 100644 --- a/axum-core/src/extract/tuple.rs +++ b/axum-core/src/extract/tuple.rs @@ -7,7 +7,7 @@ use std::convert::Infallible; impl FromRequest for () where B: Send, - S: Send, + S: Send + Sync, { type Rejection = Infallible; @@ -26,7 +26,7 @@ macro_rules! impl_from_request { where $( $ty: FromRequest + Send, )* B: Send, - S: Send, + S: Send + Sync, { type Rejection = Response; diff --git a/axum-extra/src/either.rs b/axum-extra/src/either.rs index 84b2a91f..d6d1c8ec 100755 --- a/axum-extra/src/either.rs +++ b/axum-extra/src/either.rs @@ -195,7 +195,7 @@ macro_rules! impl_traits_for_either { $($ident: FromRequest),*, $last: FromRequest, B: Send, - S: Send, + S: Send + Sync, { type Rejection = $last::Rejection; diff --git a/axum-extra/src/extract/cached.rs b/axum-extra/src/extract/cached.rs index 9545fe16..64519bf4 100644 --- a/axum-extra/src/extract/cached.rs +++ b/axum-extra/src/extract/cached.rs @@ -33,7 +33,7 @@ use std::ops::{Deref, DerefMut}; /// impl FromRequest for Session /// where /// B: Send, -/// S: Send, +/// S: Send + Sync, /// { /// type Rejection = (StatusCode, String); /// @@ -49,7 +49,7 @@ use std::ops::{Deref, DerefMut}; /// impl FromRequest for CurrentUser /// where /// B: Send, -/// S: Send, +/// S: Send + Sync, /// { /// type Rejection = Response; /// @@ -93,7 +93,7 @@ struct CachedEntry(T); impl FromRequest for Cached where B: Send, - S: Send, + S: Send + Sync, T: FromRequest + Clone + Send + Sync + 'static, { type Rejection = T::Rejection; @@ -145,7 +145,7 @@ mod tests { impl FromRequest for Extractor where B: Send, - S: Send, + S: Send + Sync, { type Rejection = Infallible; diff --git a/axum-extra/src/extract/cookie/mod.rs b/axum-extra/src/extract/cookie/mod.rs index 44842d03..3edbe68f 100644 --- a/axum-extra/src/extract/cookie/mod.rs +++ b/axum-extra/src/extract/cookie/mod.rs @@ -91,7 +91,7 @@ pub struct CookieJar { impl FromRequest for CookieJar where B: Send, - S: Send, + S: Send + Sync, { type Rejection = Infallible; diff --git a/axum-extra/src/extract/cookie/private.rs b/axum-extra/src/extract/cookie/private.rs index d3705fb2..7a88380c 100644 --- a/axum-extra/src/extract/cookie/private.rs +++ b/axum-extra/src/extract/cookie/private.rs @@ -90,7 +90,7 @@ impl fmt::Debug for PrivateCookieJar { impl FromRequest for PrivateCookieJar where B: Send, - S: Send, + S: Send + Sync, K: FromRef + Into, { type Rejection = Infallible; diff --git a/axum-extra/src/extract/cookie/signed.rs b/axum-extra/src/extract/cookie/signed.rs index 74da2a11..05ffcc92 100644 --- a/axum-extra/src/extract/cookie/signed.rs +++ b/axum-extra/src/extract/cookie/signed.rs @@ -108,7 +108,7 @@ impl fmt::Debug for SignedCookieJar { impl FromRequest for SignedCookieJar where B: Send, - S: Send, + S: Send + Sync, K: FromRef + Into, { type Rejection = Infallible; diff --git a/axum-extra/src/extract/form.rs b/axum-extra/src/extract/form.rs index 593bfba6..08c36755 100644 --- a/axum-extra/src/extract/form.rs +++ b/axum-extra/src/extract/form.rs @@ -61,7 +61,7 @@ where B: HttpBody + Send, B::Data: Send, B::Error: Into, - S: Send, + S: Send + Sync, { type Rejection = FormRejection; diff --git a/axum-extra/src/extract/query.rs b/axum-extra/src/extract/query.rs index debc6957..feae007e 100644 --- a/axum-extra/src/extract/query.rs +++ b/axum-extra/src/extract/query.rs @@ -62,7 +62,7 @@ impl FromRequest for Query where T: DeserializeOwned, B: Send, - S: Send, + S: Send + Sync, { type Rejection = QueryRejection; diff --git a/axum-extra/src/extract/with_rejection.rs b/axum-extra/src/extract/with_rejection.rs index e0d2135c..e9abc408 100644 --- a/axum-extra/src/extract/with_rejection.rs +++ b/axum-extra/src/extract/with_rejection.rs @@ -110,7 +110,7 @@ impl DerefMut for WithRejection { impl FromRequest for WithRejection where B: Send, - S: Send, + S: Send + Sync, E: FromRequest, R: From + IntoResponse, { @@ -138,7 +138,7 @@ mod tests { impl FromRequest for TestExtractor where B: Send, - S: Send, + S: Send + Sync, { type Rejection = (); diff --git a/axum-extra/src/handler/mod.rs b/axum-extra/src/handler/mod.rs index 6842327a..ef12f896 100644 --- a/axum-extra/src/handler/mod.rs +++ b/axum-extra/src/handler/mod.rs @@ -6,7 +6,7 @@ use axum::{ response::{IntoResponse, Response}, }; use futures_util::future::{BoxFuture, FutureExt, Map}; -use std::{future::Future, marker::PhantomData}; +use std::{future::Future, marker::PhantomData, sync::Arc}; mod or; @@ -24,7 +24,11 @@ pub trait HandlerCallWithExtractors: Sized { type Future: Future + Send + 'static; /// Call the handler with the extracted inputs. - fn call(self, state: S, extractors: T) -> >::Future; + fn call( + self, + state: Arc, + extractors: T, + ) -> >::Future; /// Conver this `HandlerCallWithExtractors` into [`Handler`]. fn into_handler(self) -> IntoHandler { @@ -70,7 +74,7 @@ pub trait HandlerCallWithExtractors: Sized { /// impl FromRequest for AdminPermissions /// where /// B: Send, - /// S: Send, + /// S: Send + Sync, /// { /// // check for admin permissions... /// # type Rejection = (); @@ -85,7 +89,7 @@ pub trait HandlerCallWithExtractors: Sized { /// impl FromRequest for User /// where /// B: Send, - /// S: Send, + /// S: Send + Sync, /// { /// // check for a logged in user... /// # type Rejection = (); @@ -130,7 +134,7 @@ macro_rules! impl_handler_call_with { fn call( self, - _state: S, + _state: Arc, ($($ty,)*): ($($ty,)*), ) -> >::Future { self($($ty,)*).map(IntoResponse::into_response) @@ -172,13 +176,13 @@ where T: FromRequest + Send + 'static, T::Rejection: Send, B: Send + 'static, - S: Clone + Send + 'static, + S: Send + Sync + 'static, { type Future = BoxFuture<'static, Response>; - fn call(self, state: S, req: http::Request) -> Self::Future { + fn call(self, state: Arc, req: http::Request) -> Self::Future { Box::pin(async move { - let mut req = RequestParts::with_state(state.clone(), req); + let mut req = RequestParts::with_state_arc(Arc::clone(&state), req); match req.extract::().await { Ok(t) => self.handler.call(state, t).await, Err(rejection) => rejection.into_response(), diff --git a/axum-extra/src/handler/or.rs b/axum-extra/src/handler/or.rs index 6478b35d..fb307ccf 100644 --- a/axum-extra/src/handler/or.rs +++ b/axum-extra/src/handler/or.rs @@ -8,7 +8,7 @@ use axum::{ }; use futures_util::future::{BoxFuture, Either as EitherFuture, FutureExt, Map}; use http::StatusCode; -use std::{future::Future, marker::PhantomData}; +use std::{future::Future, marker::PhantomData, sync::Arc}; /// [`Handler`] that runs one [`Handler`] and if that rejects it'll fallback to another /// [`Handler`]. @@ -37,7 +37,7 @@ where fn call( self, - state: S, + state: Arc, extractors: Either, ) -> , S, B>>::Future { match extractors { @@ -64,14 +64,14 @@ where Lt::Rejection: Send, Rt::Rejection: Send, B: Send + 'static, - S: Clone + Send + 'static, + S: Send + Sync + 'static, { // this puts `futures_util` in our public API but thats fine in axum-extra type Future = BoxFuture<'static, Response>; - fn call(self, state: S, req: Request) -> Self::Future { + fn call(self, state: Arc, req: Request) -> Self::Future { Box::pin(async move { - let mut req = RequestParts::with_state(state.clone(), req); + let mut req = RequestParts::with_state_arc(Arc::clone(&state), req); if let Ok(lt) = req.extract::().await { return self.lhs.call(state, lt).await; diff --git a/axum-extra/src/json_lines.rs b/axum-extra/src/json_lines.rs index 242b43e7..46ddc35e 100644 --- a/axum-extra/src/json_lines.rs +++ b/axum-extra/src/json_lines.rs @@ -104,7 +104,7 @@ where B::Data: Into, B::Error: Into, T: DeserializeOwned, - S: Send, + S: Send + Sync, { type Rejection = BodyAlreadyExtracted; diff --git a/axum-extra/src/protobuf.rs b/axum-extra/src/protobuf.rs index a30421a0..ddf45671 100644 --- a/axum-extra/src/protobuf.rs +++ b/axum-extra/src/protobuf.rs @@ -103,7 +103,7 @@ where B: HttpBody + Send, B::Data: Send, B::Error: Into, - S: Send, + S: Send + Sync, { type Rejection = ProtoBufRejection; diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index 0d968b94..d64e1ff9 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -178,7 +178,7 @@ pub trait RouterExt: sealed::Sealed { impl RouterExt for Router where B: axum::body::HttpBody + Send + 'static, - S: Clone + Send + Sync + 'static, + S: Send + Sync + 'static, { #[cfg(feature = "typed-routing")] fn typed_get(self, handler: H) -> Self diff --git a/axum-extra/src/routing/resource.rs b/axum-extra/src/routing/resource.rs index 239d6e18..c15f94d8 100644 --- a/axum-extra/src/routing/resource.rs +++ b/axum-extra/src/routing/resource.rs @@ -53,7 +53,7 @@ where impl Resource where B: axum::body::HttpBody + Send + 'static, - S: Clone + Send + Sync + 'static, + S: Send + Sync + 'static, { /// Create a `Resource` with the given name and state. /// diff --git a/axum-macros/src/from_request.rs b/axum-macros/src/from_request.rs index 9e84bd24..e1e3d457 100644 --- a/axum-macros/src/from_request.rs +++ b/axum-macros/src/from_request.rs @@ -223,7 +223,7 @@ fn impl_struct_by_extracting_each_field( B: ::axum::body::HttpBody + ::std::marker::Send + 'static, B::Data: ::std::marker::Send, B::Error: ::std::convert::Into<::axum::BoxError>, - S: Send, + S: ::std::marker::Send + ::std::marker::Sync, { type Rejection = #rejection_ident; @@ -659,7 +659,7 @@ fn impl_struct_by_extracting_all_at_once( #path<#via_type_generics>: ::axum::extract::FromRequest, #rejection_bound B: ::std::marker::Send, - S: ::std::marker::Send, + S: ::std::marker::Send + ::std::marker::Sync, { type Rejection = #associated_rejection_type; @@ -725,7 +725,7 @@ fn impl_enum_by_extracting_all_at_once( B: ::axum::body::HttpBody + ::std::marker::Send + 'static, B::Data: ::std::marker::Send, B::Error: ::std::convert::Into<::axum::BoxError>, - S: ::std::marker::Send, + S: ::std::marker::Send + ::std::marker::Sync, { type Rejection = #associated_rejection_type; diff --git a/axum-macros/src/lib.rs b/axum-macros/src/lib.rs index 7b27a4bd..2d5983ad 100644 --- a/axum-macros/src/lib.rs +++ b/axum-macros/src/lib.rs @@ -226,7 +226,7 @@ mod typed_path; /// impl FromRequest for OtherExtractor /// where /// B: Send, -/// S: Send, +/// S: Send + Sync, /// { /// // this rejection doesn't implement `Display` and `Error` /// type Rejection = (StatusCode, String); diff --git a/axum-macros/src/typed_path.rs b/axum-macros/src/typed_path.rs index 6a8f03c1..9df47028 100644 --- a/axum-macros/src/typed_path.rs +++ b/axum-macros/src/typed_path.rs @@ -130,7 +130,7 @@ fn expand_named_fields( impl ::axum::extract::FromRequest for #ident where B: Send, - S: Send, + S: Send + Sync, { type Rejection = #rejection_assoc_type; @@ -233,7 +233,7 @@ fn expand_unnamed_fields( impl ::axum::extract::FromRequest for #ident where B: Send, - S: Send, + S: Send + Sync, { type Rejection = #rejection_assoc_type; @@ -315,7 +315,7 @@ fn expand_unit_fields( impl ::axum::extract::FromRequest for #ident where B: Send, - S: Send, + S: Send + Sync, { type Rejection = #rejection_assoc_type; diff --git a/axum-macros/tests/debug_handler/fail/extract_self_mut.rs b/axum-macros/tests/debug_handler/fail/extract_self_mut.rs index d38d5e0c..168a1c81 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_mut.rs +++ b/axum-macros/tests/debug_handler/fail/extract_self_mut.rs @@ -10,7 +10,7 @@ struct A; impl FromRequest for A where B: Send, - S: Send, + S: Send + Sync, { type Rejection = (); diff --git a/axum-macros/tests/debug_handler/fail/extract_self_ref.rs b/axum-macros/tests/debug_handler/fail/extract_self_ref.rs index 06b87f0a..4090265c 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_ref.rs +++ b/axum-macros/tests/debug_handler/fail/extract_self_ref.rs @@ -10,7 +10,7 @@ struct A; impl FromRequest for A where B: Send, - S: Send, + S: Send + Sync, { type Rejection = (); diff --git a/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs b/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs index 762809b6..4941f596 100644 --- a/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs +++ b/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs @@ -123,7 +123,7 @@ impl A { impl FromRequest for A where B: Send, - S: Send, + S: Send + Sync, { type Rejection = (); diff --git a/axum-macros/tests/debug_handler/pass/self_receiver.rs b/axum-macros/tests/debug_handler/pass/self_receiver.rs index a88382cf..a926eb7f 100644 --- a/axum-macros/tests/debug_handler/pass/self_receiver.rs +++ b/axum-macros/tests/debug_handler/pass/self_receiver.rs @@ -10,7 +10,7 @@ struct A; impl FromRequest for A where B: Send, - S: Send, + S: Send + Sync, { type Rejection = (); diff --git a/axum-macros/tests/from_request/pass/derive_opt_out.rs b/axum-macros/tests/from_request/pass/derive_opt_out.rs index e73d5a95..c5ef9deb 100644 --- a/axum-macros/tests/from_request/pass/derive_opt_out.rs +++ b/axum-macros/tests/from_request/pass/derive_opt_out.rs @@ -17,7 +17,7 @@ struct OtherExtractor; impl FromRequest for OtherExtractor where B: Send, - S: Send, + S: Send + Sync, { type Rejection = OtherExtractorRejection; diff --git a/axum-macros/tests/from_request/pass/override_rejection.rs b/axum-macros/tests/from_request/pass/override_rejection.rs index c308d615..40a25d58 100644 --- a/axum-macros/tests/from_request/pass/override_rejection.rs +++ b/axum-macros/tests/from_request/pass/override_rejection.rs @@ -31,7 +31,7 @@ struct OtherExtractor; impl FromRequest for OtherExtractor where B: Send + 'static, - S: Send, + S: Send + Sync, { // this rejection doesn't implement `Display` and `Error` type Rejection = (StatusCode, String); diff --git a/axum/benches/benches.rs b/axum/benches/benches.rs index 81d4ec8c..9cf47e29 100644 --- a/axum/benches/benches.rs +++ b/axum/benches/benches.rs @@ -1,6 +1,7 @@ use axum::{ + extract::State, routing::{get, post}, - Json, Router, Server, + Extension, Json, Router, Server, }; use hyper::server::conn::AddrIncoming; use serde::{Deserialize, Serialize}; @@ -50,6 +51,30 @@ fn main() { }), ) }); + + let state = AppState { + _string: "aaaaaaaaaaaaaaaaaa".to_owned(), + _vec: Vec::from([ + "aaaaaaaaaaaaaaaaaa".to_owned(), + "bbbbbbbbbbbbbbbbbb".to_owned(), + "cccccccccccccccccc".to_owned(), + ]), + }; + + benchmark("extension").run(|| { + Router::new() + .route("/", get(|_: Extension| async {})) + .layer(Extension(state.clone())) + }); + + benchmark("state") + .run(|| Router::with_state(state.clone()).route("/", get(|_: State| async {}))); +} + +#[derive(Clone)] +struct AppState { + _string: String, + _vec: Vec, } #[derive(Deserialize, Serialize)] @@ -92,9 +117,10 @@ impl BenchmarkBuilder { config_method!(headers, &'static [(&'static str, &'static str)]); config_method!(body, &'static str); - fn run(self, f: F) + fn run(self, f: F) where - F: FnOnce() -> Router, + F: FnOnce() -> Router, + S: Clone + Send + Sync + 'static, { // support only running some benchmarks with // ``` diff --git a/axum/src/docs/extract.md b/axum/src/docs/extract.md index 9c31d0e4..d87eb978 100644 --- a/axum/src/docs/extract.md +++ b/axum/src/docs/extract.md @@ -424,7 +424,7 @@ struct ExtractUserAgent(HeaderValue); impl FromRequest for ExtractUserAgent where B: Send, - S: Send, + S: Send + Sync, { type Rejection = (StatusCode, &'static str); @@ -476,7 +476,7 @@ struct AuthenticatedUser { impl FromRequest for AuthenticatedUser where B: Send, - S: Send, + S: Send + Sync, { type Rejection = Response; diff --git a/axum/src/extension.rs b/axum/src/extension.rs index d040b9e4..4c93ce1b 100644 --- a/axum/src/extension.rs +++ b/axum/src/extension.rs @@ -77,7 +77,7 @@ impl FromRequest for Extension where T: Clone + Send + Sync + 'static, B: Send, - S: Send, + S: Send + Sync, { type Rejection = ExtensionRejection; diff --git a/axum/src/extract/connect_info.rs b/axum/src/extract/connect_info.rs index 3aa7684c..ba8c301c 100644 --- a/axum/src/extract/connect_info.rs +++ b/axum/src/extract/connect_info.rs @@ -131,7 +131,7 @@ pub struct ConnectInfo(pub T); impl FromRequest for ConnectInfo where B: Send, - S: Send, + S: Send + Sync, T: Clone + Send + Sync + 'static, { type Rejection = as FromRequest>::Rejection; diff --git a/axum/src/extract/content_length_limit.rs b/axum/src/extract/content_length_limit.rs index f4c47543..ae27a6c2 100644 --- a/axum/src/extract/content_length_limit.rs +++ b/axum/src/extract/content_length_limit.rs @@ -41,7 +41,7 @@ where T: FromRequest, T::Rejection: IntoResponse, B: Send, - S: Send, + S: Send + Sync, { type Rejection = ContentLengthLimitRejection; diff --git a/axum/src/extract/host.rs b/axum/src/extract/host.rs index 79ae13fc..6137c64e 100644 --- a/axum/src/extract/host.rs +++ b/axum/src/extract/host.rs @@ -24,7 +24,7 @@ pub struct Host(pub String); impl FromRequest for Host where B: Send, - S: Send, + S: Send + Sync, { type Rejection = HostRejection; diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index 6413cf77..35a076b2 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -67,7 +67,7 @@ impl MatchedPath { impl FromRequest for MatchedPath where B: Send, - S: Send, + S: Send + Sync, { type Rejection = MatchedPathRejection; diff --git a/axum/src/extract/multipart.rs b/axum/src/extract/multipart.rs index 076f4db1..3063a45c 100644 --- a/axum/src/extract/multipart.rs +++ b/axum/src/extract/multipart.rs @@ -54,7 +54,7 @@ impl FromRequest for Multipart where B: HttpBody + Default + Unpin + Send + 'static, B::Error: Into, - S: Send, + S: Send + Sync, { type Rejection = MultipartRejection; diff --git a/axum/src/extract/path/mod.rs b/axum/src/extract/path/mod.rs index ca9e9fb6..16ed753b 100644 --- a/axum/src/extract/path/mod.rs +++ b/axum/src/extract/path/mod.rs @@ -167,7 +167,7 @@ impl FromRequest for Path where T: DeserializeOwned + Send, B: Send, - S: Send, + S: Send + Sync, { type Rejection = PathRejection; diff --git a/axum/src/extract/query.rs b/axum/src/extract/query.rs index ce1f747c..6eedd8c1 100644 --- a/axum/src/extract/query.rs +++ b/axum/src/extract/query.rs @@ -53,7 +53,7 @@ impl FromRequest for Query where T: DeserializeOwned, B: Send, - S: Send, + S: Send + Sync, { type Rejection = QueryRejection; diff --git a/axum/src/extract/raw_query.rs b/axum/src/extract/raw_query.rs index faf8df6e..b0090957 100644 --- a/axum/src/extract/raw_query.rs +++ b/axum/src/extract/raw_query.rs @@ -30,7 +30,7 @@ pub struct RawQuery(pub Option); impl FromRequest for RawQuery where B: Send, - S: Send, + S: Send + Sync, { type Rejection = Infallible; diff --git a/axum/src/extract/request_parts.rs b/axum/src/extract/request_parts.rs index c04ff534..5a0da460 100644 --- a/axum/src/extract/request_parts.rs +++ b/axum/src/extract/request_parts.rs @@ -89,7 +89,7 @@ pub struct OriginalUri(pub Uri); impl FromRequest for OriginalUri where B: Send, - S: Send, + S: Send + Sync, { type Rejection = Infallible; @@ -146,7 +146,7 @@ where B: HttpBody + Send + 'static, B::Data: Into, B::Error: Into, - S: Send, + S: Send + Sync, { type Rejection = BodyAlreadyExtracted; @@ -201,7 +201,7 @@ pub struct RawBody(pub B); impl FromRequest for RawBody where B: Send, - S: Send, + S: Send + Sync, { type Rejection = BodyAlreadyExtracted; diff --git a/axum/src/extract/state.rs b/axum/src/extract/state.rs index 94ccf5b1..1385dfec 100644 --- a/axum/src/extract/state.rs +++ b/axum/src/extract/state.rs @@ -153,7 +153,7 @@ use std::{ /// // keep `S` generic but require that it can produce a `MyLibraryState` /// // this means users will have to implement `FromRef for MyLibraryState` /// MyLibraryState: FromRef, -/// S: Send, +/// S: Send + Sync, /// { /// type Rejection = Infallible; /// @@ -182,7 +182,7 @@ impl FromRequest for State where B: Send, InnerState: FromRef, - OuterState: Send, + OuterState: Send + Sync, { type Rejection = Infallible; diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index 952ea136..976d12a7 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -278,7 +278,7 @@ impl WebSocketUpgrade { impl FromRequest for WebSocketUpgrade where B: Send, - S: Send, + S: Send + Sync, { type Rejection = WebSocketUpgradeRejection; diff --git a/axum/src/form.rs b/axum/src/form.rs index 8267b8ef..9477eff2 100644 --- a/axum/src/form.rs +++ b/axum/src/form.rs @@ -62,7 +62,7 @@ where B: HttpBody + Send, B::Data: Send, B::Error: Into, - S: Send, + S: Send + Sync, { type Rejection = FormRejection; diff --git a/axum/src/handler/into_service.rs b/axum/src/handler/into_service.rs index c8e51a24..b902eb90 100644 --- a/axum/src/handler/into_service.rs +++ b/axum/src/handler/into_service.rs @@ -5,6 +5,7 @@ use std::{ convert::Infallible, fmt, marker::PhantomData, + sync::Arc, task::{Context, Poll}, }; use tower_service::Service; @@ -16,7 +17,7 @@ use tower_service::Service; /// [`HandlerWithoutStateExt::into_service`]: super::HandlerWithoutStateExt::into_service pub struct IntoService { handler: H, - state: S, + state: Arc, _marker: PhantomData (T, B)>, } @@ -35,7 +36,7 @@ fn traits() { } impl IntoService { - pub(super) fn new(handler: H, state: S) -> Self { + pub(super) fn new(handler: H, state: Arc) -> Self { Self { handler, state, @@ -55,12 +56,11 @@ impl fmt::Debug for IntoService { impl Clone for IntoService where H: Clone, - S: Clone, { fn clone(&self) -> Self { Self { handler: self.handler.clone(), - state: self.state.clone(), + state: Arc::clone(&self.state), _marker: PhantomData, } } @@ -70,7 +70,7 @@ impl Service> for IntoService where H: Handler + Clone + Send + 'static, B: Send + 'static, - S: Clone, + S: Send + Sync, { type Response = Response; type Error = Infallible; @@ -88,7 +88,7 @@ where use futures_util::future::FutureExt; let handler = self.handler.clone(); - let future = Handler::call(handler, self.state.clone(), req); + let future = Handler::call(handler, Arc::clone(&self.state), req); let future = future.map(Ok as _); super::future::IntoServiceFuture::new(future) diff --git a/axum/src/handler/into_service_state_in_extension.rs b/axum/src/handler/into_service_state_in_extension.rs index 011161d9..3c59d043 100644 --- a/axum/src/handler/into_service_state_in_extension.rs +++ b/axum/src/handler/into_service_state_in_extension.rs @@ -5,6 +5,7 @@ use std::{ convert::Infallible, fmt, marker::PhantomData, + sync::Arc, task::{Context, Poll}, }; use tower_service::Service; @@ -54,7 +55,7 @@ impl Service> for IntoServiceStateInExtension where H: Handler + Clone + Send + 'static, B: Send + 'static, - S: Clone + Send + Sync + 'static, + S: Send + Sync + 'static, { type Response = Response; type Error = Infallible; @@ -73,7 +74,7 @@ where let state = req .extensions_mut() - .remove::() + .remove::>() .expect("state extension missing. This is a bug in axum, please file an issue"); let handler = self.handler.clone(); diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index d5142346..80514cf0 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -42,7 +42,7 @@ use crate::{ routing::IntoMakeService, }; use http::Request; -use std::{convert::Infallible, fmt, future::Future, marker::PhantomData, pin::Pin}; +use std::{convert::Infallible, fmt, future::Future, marker::PhantomData, pin::Pin, sync::Arc}; use tower::ServiceExt; use tower_layer::Layer; use tower_service::Service; @@ -100,7 +100,7 @@ pub trait Handler: Clone + Send + Sized + 'static { type Future: Future + Send + 'static; /// Call the handler with the given request. - fn call(self, state: S, req: Request) -> Self::Future; + fn call(self, state: Arc, req: Request) -> Self::Future; /// Apply a [`tower::Layer`] to the handler. /// @@ -151,6 +151,11 @@ pub trait Handler: Clone + Send + Sized + 'static { /// Convert the handler into a [`Service`] by providing the state fn with_state(self, state: S) -> WithState { + self.with_state_arc(Arc::new(state)) + } + + /// Convert the handler into a [`Service`] by providing the state + fn with_state_arc(self, state: Arc) -> WithState { WithState { service: IntoService::new(self, state), } @@ -166,7 +171,7 @@ where { type Future = Pin + Send>>; - fn call(self, _state: S, _req: Request) -> Self::Future { + fn call(self, _state: Arc, _req: Request) -> Self::Future { Box::pin(async move { self().await.into_response() }) } } @@ -179,15 +184,15 @@ macro_rules! impl_handler { F: FnOnce($($ty,)*) -> Fut + Clone + Send + 'static, Fut: Future + Send, B: Send + 'static, - S: Send + 'static, + S: Send + Sync + 'static, Res: IntoResponse, $( $ty: FromRequest + Send,)* { type Future = Pin + Send>>; - fn call(self, state: S, req: Request) -> Self::Future { + fn call(self, state: Arc, req: Request) -> Self::Future { Box::pin(async move { - let mut req = RequestParts::with_state(state, req); + let mut req = RequestParts::with_state_arc(state, req); $( let $ty = match $ty::from_request(&mut req).await { @@ -254,10 +259,10 @@ where { type Future = future::LayeredFuture; - fn call(self, state: S, req: Request) -> Self::Future { + fn call(self, state: Arc, req: Request) -> Self::Future { use futures_util::future::{FutureExt, Map}; - let svc = self.handler.with_state(state); + let svc = self.handler.with_state_arc(state); let svc = self.layer.layer(svc); let future: Map< diff --git a/axum/src/handler/with_state.rs b/axum/src/handler/with_state.rs index 4afc9b10..71ca82bf 100644 --- a/axum/src/handler/with_state.rs +++ b/axum/src/handler/with_state.rs @@ -106,7 +106,7 @@ impl Service> for WithState where H: Handler + Clone + Send + 'static, B: Send + 'static, - S: Clone, + S: Send + Sync, { type Response = as Service>>::Response; type Error = as Service>>::Error; @@ -134,7 +134,6 @@ impl std::fmt::Debug for WithState { impl Clone for WithState where H: Clone, - S: Clone, { fn clone(&self) -> Self { Self { diff --git a/axum/src/json.rs b/axum/src/json.rs index e35a1623..e60edf80 100644 --- a/axum/src/json.rs +++ b/axum/src/json.rs @@ -100,7 +100,7 @@ where B: HttpBody + Send, B::Data: Send, B::Error: Into, - S: Send, + S: Send + Sync, { type Rejection = JsonRejection; diff --git a/axum/src/middleware/from_extractor.rs b/axum/src/middleware/from_extractor.rs index dfa3dfec..f9391731 100644 --- a/axum/src/middleware/from_extractor.rs +++ b/axum/src/middleware/from_extractor.rs @@ -48,7 +48,7 @@ use tower_service::Service; /// impl FromRequest for RequireAuth /// where /// B: Send, -/// S: Send, +/// S: Send + Sync, /// { /// type Rejection = StatusCode; /// @@ -283,7 +283,7 @@ mod tests { impl FromRequest for RequireAuth where B: Send, - S: Send, + S: Send + Sync, { type Rejection = StatusCode; diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index a71dc2fd..52a8e61f 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -16,6 +16,7 @@ use std::{ convert::Infallible, fmt, marker::PhantomData, + sync::Arc, task::{Context, Poll}, }; use tower::{service_fn, util::MapResponseLayer}; @@ -143,7 +144,7 @@ macro_rules! top_level_handler_fn { H: Handler, B: Send + 'static, T: 'static, - S: Clone + Send + Sync + 'static, + S: Send + Sync + 'static, { on(MethodFilter::$method, handler) } @@ -279,7 +280,7 @@ macro_rules! chained_handler_fn { where H: Handler, T: 'static, - S: Clone + Send + Sync + 'static, + S: Send + Sync + 'static, { self.on(MethodFilter::$method, handler) } @@ -428,7 +429,7 @@ where H: Handler, B: Send + 'static, T: 'static, - S: Clone + Send + Sync + 'static, + S: Send + Sync + 'static, { MethodRouter::new().on(filter, handler) } @@ -475,7 +476,7 @@ where H: Handler, B: Send + 'static, T: 'static, - S: Clone + Send + Sync + 'static, + S: Send + Sync + 'static, { MethodRouter::new() .fallback_boxed_response_body(IntoServiceStateInExtension::new(handler)) @@ -599,7 +600,7 @@ where where H: Handler, T: 'static, - S: Clone + Send + Sync + 'static, + S: Send + Sync + 'static, { self.on_service_boxed_response_body(filter, IntoServiceStateInExtension::new(handler)) } @@ -618,7 +619,7 @@ where where H: Handler, T: 'static, - S: Clone + Send + Sync + 'static, + S: Send + Sync + 'static, { self.fallback_service(IntoServiceStateInExtension::new(handler)) } @@ -727,6 +728,13 @@ where /// /// See [`State`](crate::extract::State) for more details about accessing state. pub fn with_state(self, state: S) -> WithState { + self.with_state_arc(Arc::new(state)) + } + + /// Provide the [`Arc`]'ed state. + /// + /// See [`State`](crate::extract::State) for more details about accessing state. + pub fn with_state_arc(self, state: Arc) -> WithState { WithState { method_router: self, state, @@ -1127,7 +1135,7 @@ where /// Created with [`MethodRouter::with_state`] pub struct WithState { method_router: MethodRouter, - state: S, + state: Arc, } impl WithState { @@ -1156,14 +1164,11 @@ impl WithState { } } -impl Clone for WithState -where - S: Clone, -{ +impl Clone for WithState { fn clone(&self) -> Self { Self { method_router: self.method_router.clone(), - state: self.state.clone(), + state: Arc::clone(&self.state), } } } @@ -1183,7 +1188,7 @@ where impl Service> for WithState where B: HttpBody, - S: Clone + Send + Sync + 'static, + S: Send + Sync + 'static, { type Response = Response; type Error = E; @@ -1232,7 +1237,7 @@ where }, } = self; - req.extensions_mut().insert(state.clone()); + req.extensions_mut().insert(Arc::clone(state)); call!(req, method, HEAD, head); call!(req, method, HEAD, get); diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 62d9e507..2768d283 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -62,19 +62,16 @@ impl RouteId { /// The router type for composing handlers and services. pub struct Router { - state: S, + state: Arc, routes: HashMap>, node: Arc, fallback: Fallback, } -impl Clone for Router -where - S: Clone, -{ +impl Clone for Router { fn clone(&self) -> Self { Self { - state: self.state.clone(), + state: Arc::clone(&self.state), routes: self.routes.clone(), node: Arc::clone(&self.node), fallback: self.fallback.clone(), @@ -85,7 +82,7 @@ where impl Default for Router where B: HttpBody + Send + 'static, - S: Default + Clone + Send + Sync + 'static, + S: Default + Send + Sync + 'static, { fn default() -> Self { Self::with_state(S::default()) @@ -125,7 +122,7 @@ where impl Router where B: HttpBody + Send + 'static, - S: Clone + Send + Sync + 'static, + S: Send + Sync + 'static, { /// Create a new `Router` with the given state. /// @@ -134,6 +131,16 @@ where /// Unless you add additional routes this will respond with `404 Not Found` to /// all requests. pub fn with_state(state: S) -> Self { + Self::with_state_arc(Arc::new(state)) + } + + /// Create a new `Router` with the given [`Arc`]'ed state. + /// + /// See [`State`](crate::extract::State) for more details about accessing state. + /// + /// Unless you add additional routes this will respond with `404 Not Found` to + /// all requests. + pub fn with_state_arc(state: Arc) -> Self { Self { state, routes: Default::default(), @@ -262,7 +269,7 @@ where pub fn merge(mut self, other: R) -> Self where R: Into>, - S2: Clone + Send + Sync + 'static, + S2: Send + Sync + 'static, { let Router { state, @@ -282,7 +289,7 @@ where method_router // this will set the state for each route // such we don't override the inner state later in `MethodRouterWithState` - .layer(Extension(state.clone())) + .layer(Extension(Arc::clone(&state))) .downcast_state(), ), Endpoint::Route(route) => self.route_service(path, route), @@ -383,8 +390,8 @@ where H: Handler, T: 'static, { - let state = self.state.clone(); - self.fallback_service(handler.with_state(state)) + let state = Arc::clone(&self.state); + self.fallback_service(handler.with_state_arc(state)) } /// Add a fallback [`Service`] to the router. @@ -484,7 +491,10 @@ where .clone(); match &mut route { - Endpoint::MethodRouter(inner) => inner.clone().with_state(self.state.clone()).call(req), + Endpoint::MethodRouter(inner) => inner + .clone() + .with_state_arc(Arc::clone(&self.state)) + .call(req), Endpoint::Route(inner) => inner.call(req), } } @@ -498,7 +508,7 @@ where impl Service> for Router where B: HttpBody + Send + 'static, - S: Clone + Send + Sync + 'static, + S: Send + Sync + 'static, { type Response = Response; type Error = Infallible; @@ -618,10 +628,7 @@ enum Endpoint { Route(Route), } -impl Clone for Endpoint -where - S: Clone, -{ +impl Clone for Endpoint { fn clone(&self) -> Self { match self { Endpoint::MethodRouter(inner) => Endpoint::MethodRouter(inner.clone()), diff --git a/axum/src/typed_header.rs b/axum/src/typed_header.rs index c28a24a8..60ab2041 100644 --- a/axum/src/typed_header.rs +++ b/axum/src/typed_header.rs @@ -56,7 +56,7 @@ impl FromRequest for TypedHeader where T: headers::Header, B: Send, - S: Send, + S: Send + Sync, { type Rejection = TypedHeaderRejection; diff --git a/examples/consume-body-in-extractor-or-middleware/src/main.rs b/examples/consume-body-in-extractor-or-middleware/src/main.rs index be948375..f9179bb4 100644 --- a/examples/consume-body-in-extractor-or-middleware/src/main.rs +++ b/examples/consume-body-in-extractor-or-middleware/src/main.rs @@ -82,7 +82,7 @@ struct PrintRequestBody; #[async_trait] impl FromRequest for PrintRequestBody where - S: Send + Clone, + S: Clone + Send + Sync, { type Rejection = Response; diff --git a/examples/customize-extractor-error/src/main.rs b/examples/customize-extractor-error/src/main.rs index 20e3b4d4..3e41928e 100644 --- a/examples/customize-extractor-error/src/main.rs +++ b/examples/customize-extractor-error/src/main.rs @@ -58,7 +58,7 @@ struct Json(T); #[async_trait] impl FromRequest for Json where - S: Send, + S: Send + Sync, // these trait bounds are copied from `impl FromRequest for axum::Json` T: DeserializeOwned, B: axum::body::HttpBody + Send, diff --git a/examples/customize-path-rejection/src/main.rs b/examples/customize-path-rejection/src/main.rs index 8330b95a..c4923069 100644 --- a/examples/customize-path-rejection/src/main.rs +++ b/examples/customize-path-rejection/src/main.rs @@ -57,7 +57,7 @@ where // these trait bounds are copied from `impl FromRequest for axum::extract::path::Path` T: DeserializeOwned + Send, B: Send, - S: Send, + S: Send + Sync, { type Rejection = (StatusCode, axum::Json); diff --git a/examples/jwt/src/main.rs b/examples/jwt/src/main.rs index 8725581d..84d51f44 100644 --- a/examples/jwt/src/main.rs +++ b/examples/jwt/src/main.rs @@ -124,7 +124,7 @@ impl AuthBody { #[async_trait] impl FromRequest for Claims where - S: Send, + S: Send + Sync, B: Send, { type Rejection = AuthError; diff --git a/examples/validator/src/main.rs b/examples/validator/src/main.rs index 8682eb85..e5988de2 100644 --- a/examples/validator/src/main.rs +++ b/examples/validator/src/main.rs @@ -63,7 +63,7 @@ pub struct ValidatedForm(pub T); impl FromRequest for ValidatedForm where T: DeserializeOwned + Validate, - S: Send, + S: Send + Sync, B: http_body::Body + Send, B::Data: Send, B::Error: Into, diff --git a/examples/versioning/src/main.rs b/examples/versioning/src/main.rs index cf8e15f2..6b53f77e 100644 --- a/examples/versioning/src/main.rs +++ b/examples/versioning/src/main.rs @@ -51,7 +51,7 @@ enum Version { impl FromRequest for Version where B: Send, - S: Send, + S: Send + Sync, { type Rejection = Response;