diff --git a/axum-extra/src/extract/query.rs b/axum-extra/src/extract/query.rs index 3b0330ac..db16461b 100644 --- a/axum-extra/src/extract/query.rs +++ b/axum-extra/src/extract/query.rs @@ -1,7 +1,7 @@ use axum::extract::FromRequestParts; use axum_core::__composite_rejection as composite_rejection; use axum_core::__define_rejection as define_rejection; -use http::request::Parts; +use http::{request::Parts, Uri}; use serde::de::DeserializeOwned; /// Extractor that deserializes query strings into some type. @@ -95,6 +95,37 @@ where } } +impl Query +where + T: DeserializeOwned, +{ + /// Attempts to construct a [`Query`] from a reference to a [`Uri`]. + /// + /// # Example + /// ``` + /// use axum_extra::extract::Query; + /// use http::Uri; + /// use serde::Deserialize; + /// + /// #[derive(Deserialize)] + /// struct ExampleParams { + /// foo: String, + /// bar: u32, + /// } + /// + /// let uri: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap(); + /// let result: Query = Query::try_from_uri(&uri).unwrap(); + /// assert_eq!(result.foo, String::from("hello")); + /// assert_eq!(result.bar, 42); + /// ``` + pub fn try_from_uri(value: &Uri) -> Result { + let query = value.query().unwrap_or_default(); + let params = + serde_html_form::from_str(query).map_err(FailedToDeserializeQueryString::from_err)?; + Ok(Self(params)) + } +} + axum_core::__impl_deref!(Query); define_rejection! { @@ -338,4 +369,34 @@ mod tests { assert_eq!(res.status(), StatusCode::BAD_REQUEST); } + + #[test] + fn test_try_from_uri() { + #[derive(Deserialize)] + struct TestQueryParams { + foo: Vec, + bar: u32, + } + let uri: Uri = "http://example.com/path?foo=hello&bar=42&foo=goodbye" + .parse() + .unwrap(); + let result: Query = Query::try_from_uri(&uri).unwrap(); + assert_eq!(result.foo, [String::from("hello"), String::from("goodbye")]); + assert_eq!(result.bar, 42); + } + + #[test] + fn test_try_from_uri_with_invalid_query() { + #[derive(Deserialize)] + struct TestQueryParams { + _foo: String, + _bar: u32, + } + let uri: Uri = "http://example.com/path?foo=hello&bar=invalid" + .parse() + .unwrap(); + let result: Result, _> = Query::try_from_uri(&uri); + + assert!(result.is_err()); + } }