diff --git a/CHANGELOG.md b/CHANGELOG.md index 7947ec8e..ade148fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Use `pin-project-lite` instead of `pin-project`. ([#95](https://github.com/tokio-rs/axum/pull/95)) - Re-export `http` crate and `hyper::Server`. ([#110](https://github.com/tokio-rs/axum/pull/110)) - Fix `Query` and `Form` extractors giving bad request error when query string is empty. ([#117](https://github.com/tokio-rs/axum/pull/117)) +- Add `Path` extractor. ([#124](https://github.com/tokio-rs/axum/pull/124)) ## Breaking changes diff --git a/examples/error_handling_and_dependency_injection.rs b/examples/error_handling_and_dependency_injection.rs index d9231374..e354f402 100644 --- a/examples/error_handling_and_dependency_injection.rs +++ b/examples/error_handling_and_dependency_injection.rs @@ -11,7 +11,7 @@ use axum::{ async_trait, - extract::{Extension, Json, UrlParams}, + extract::{Extension, Json, Path}, prelude::*, response::IntoResponse, AddExtensionLayer, @@ -56,7 +56,7 @@ async fn main() { /// are automatically converted into `AppError` which implements `IntoResponse` /// so it can be returned from handlers directly. async fn users_show( - UrlParams((user_id,)): UrlParams<(Uuid,)>, + Path(user_id): Path, Extension(user_repo): Extension, ) -> Result, AppError> { let user = user_repo.find(user_id).await?; diff --git a/examples/key_value_store.rs b/examples/key_value_store.rs index 5d8c913c..81923f3f 100644 --- a/examples/key_value_store.rs +++ b/examples/key_value_store.rs @@ -8,7 +8,7 @@ use axum::{ async_trait, - extract::{extractor_middleware, ContentLengthLimit, Extension, RequestParts, UrlParams}, + extract::{extractor_middleware, ContentLengthLimit, Extension, Path, RequestParts}, prelude::*, response::IntoResponse, routing::BoxRoute, @@ -79,7 +79,7 @@ struct State { } async fn kv_get( - UrlParams((key,)): UrlParams<(String,)>, + Path(key): Path, Extension(state): Extension, ) -> Result { let db = &state.read().unwrap().db; @@ -92,7 +92,7 @@ async fn kv_get( } async fn kv_set( - UrlParams((key,)): UrlParams<(String,)>, + Path(key): Path, ContentLengthLimit(bytes): ContentLengthLimit, // ~5mb Extension(state): Extension, ) { @@ -113,10 +113,7 @@ fn admin_routes() -> BoxRoute { state.write().unwrap().db.clear(); } - async fn remove_key( - UrlParams((key,)): UrlParams<(String,)>, - Extension(state): Extension, - ) { + async fn remove_key(Path(key): Path, Extension(state): Extension) { state.write().unwrap().db.remove(&key); } diff --git a/examples/templates.rs b/examples/templates.rs index f2c5f087..4f3bda27 100644 --- a/examples/templates.rs +++ b/examples/templates.rs @@ -29,14 +29,8 @@ async fn main() { .unwrap(); } -async fn greet(params: extract::UrlParamsMap) -> impl IntoResponse { - let name = params - .get("name") - .expect("`name` will be there if route was matched") - .to_string(); - +async fn greet(extract::Path(name): extract::Path) -> impl IntoResponse { let template = HelloTemplate { name }; - HtmlTemplate(template) } diff --git a/examples/todos.rs b/examples/todos.rs index 76bc85b5..188396ea 100644 --- a/examples/todos.rs +++ b/examples/todos.rs @@ -14,7 +14,7 @@ //! ``` use axum::{ - extract::{Extension, Json, Query, UrlParams}, + extract::{Extension, Json, Path, Query}, prelude::*, response::IntoResponse, service::ServiceExt, @@ -129,7 +129,7 @@ struct UpdateTodo { } async fn todos_update( - UrlParams((id,)): UrlParams<(Uuid,)>, + Path(id): Path, Json(input): Json, Extension(db): Extension, ) -> Result { @@ -153,10 +153,7 @@ async fn todos_update( Ok(response::Json(todo)) } -async fn todos_delete( - UrlParams((id,)): UrlParams<(Uuid,)>, - Extension(db): Extension, -) -> impl IntoResponse { +async fn todos_delete(Path(id): Path, Extension(db): Extension) -> impl IntoResponse { if db.write().unwrap().remove(&id).is_some() { StatusCode::NO_CONTENT } else { diff --git a/examples/versioning.rs b/examples/versioning.rs index 9df60371..f87c3955 100644 --- a/examples/versioning.rs +++ b/examples/versioning.rs @@ -12,6 +12,7 @@ use axum::{ }; use http::Response; use http::StatusCode; +use std::collections::HashMap; use std::net::SocketAddr; #[tokio::main] @@ -53,7 +54,7 @@ where type Rejection = Response; async fn from_request(req: &mut RequestParts) -> Result { - let params = extract::UrlParamsMap::from_request(req) + let params = extract::Path::>::from_request(req) .await .map_err(IntoResponse::into_response)?; @@ -61,7 +62,7 @@ where .get("version") .ok_or_else(|| (StatusCode::NOT_FOUND, "version param missing").into_response())?; - match version { + match version.as_str() { "v1" => Ok(Version::V1), "v2" => Ok(Version::V2), "v3" => Ok(Version::V3), diff --git a/src/extract/mod.rs b/src/extract/mod.rs index d9cffd38..5edfe64f 100644 --- a/src/extract/mod.rs +++ b/src/extract/mod.rs @@ -258,6 +258,7 @@ mod content_length_limit; mod extension; mod form; mod json; +mod path; mod query; mod raw_query; mod request_parts; @@ -273,6 +274,7 @@ pub use self::{ extractor_middleware::extractor_middleware, form::Form, json::Json, + path::Path, query::Query, raw_query::RawQuery, request_parts::{Body, BodyStream}, diff --git a/src/extract/path/de.rs b/src/extract/path/de.rs new file mode 100644 index 00000000..acfb3196 --- /dev/null +++ b/src/extract/path/de.rs @@ -0,0 +1,671 @@ +use crate::routing::UrlParams; +use crate::util::ByteStr; +use serde::{ + de::{self, DeserializeSeed, EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor}, + forward_to_deserialize_any, Deserializer, +}; +use std::fmt::{self, Display}; + +/// This type represents errors that can occur when deserializing. +#[derive(Debug, Eq, PartialEq)] +pub(crate) struct PathDeserializerError(pub(crate) String); + +impl de::Error for PathDeserializerError { + #[inline] + fn custom(msg: T) -> Self { + PathDeserializerError(msg.to_string()) + } +} + +impl std::error::Error for PathDeserializerError { + #[inline] + fn description(&self) -> &str { + "path deserializer error" + } +} + +impl fmt::Display for PathDeserializerError { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + PathDeserializerError(msg) => write!(f, "{}", msg), + } + } +} + +macro_rules! unsupported_type { + ($trait_fn:ident, $name:literal) => { + fn $trait_fn(self, _: V) -> Result + where + V: Visitor<'de>, + { + Err(PathDeserializerError::custom(concat!( + "unsupported type: ", + $name + ))) + } + }; +} + +macro_rules! parse_single_value { + ($trait_fn:ident, $visit_fn:ident, $tp:literal) => { + fn $trait_fn(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.url_params.0.len() != 1 { + return Err(PathDeserializerError::custom( + format!( + "wrong number of parameters: {} expected 1", + self.url_params.0.len() + ) + .as_str(), + )); + } + + let value = self.url_params.0[0].1.parse().map_err(|_| { + PathDeserializerError::custom(format!( + "can not parse `{:?}` to a `{}`", + self.url_params.0[0].1.as_str(), + $tp + )) + })?; + visitor.$visit_fn(value) + } + }; +} + +pub(crate) struct PathDeserializer<'de> { + url_params: &'de UrlParams, +} + +impl<'de> PathDeserializer<'de> { + #[inline] + pub(crate) fn new(url_params: &'de UrlParams) -> Self { + PathDeserializer { url_params } + } +} + +impl<'de> Deserializer<'de> for PathDeserializer<'de> { + type Error = PathDeserializerError; + + unsupported_type!(deserialize_any, "'any'"); + unsupported_type!(deserialize_bytes, "bytes"); + unsupported_type!(deserialize_option, "Option"); + unsupported_type!(deserialize_identifier, "identifier"); + unsupported_type!(deserialize_ignored_any, "ignored_any"); + + parse_single_value!(deserialize_bool, visit_bool, "bool"); + parse_single_value!(deserialize_i8, visit_i8, "i8"); + parse_single_value!(deserialize_i16, visit_i16, "i16"); + parse_single_value!(deserialize_i32, visit_i32, "i32"); + parse_single_value!(deserialize_i64, visit_i64, "i64"); + parse_single_value!(deserialize_u8, visit_u8, "u8"); + parse_single_value!(deserialize_u16, visit_u16, "u16"); + parse_single_value!(deserialize_u32, visit_u32, "u32"); + parse_single_value!(deserialize_u64, visit_u64, "u64"); + parse_single_value!(deserialize_f32, visit_f32, "f32"); + parse_single_value!(deserialize_f64, visit_f64, "f64"); + parse_single_value!(deserialize_string, visit_string, "String"); + parse_single_value!(deserialize_byte_buf, visit_string, "String"); + parse_single_value!(deserialize_char, visit_char, "char"); + + fn deserialize_str(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.url_params.0.len() != 1 { + return Err(PathDeserializerError::custom(format!( + "wrong number of parameters: {} expected 1", + self.url_params.0.len() + ))); + } + visitor.visit_str(&self.url_params.0[0].1) + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_unit_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_seq(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq(SeqDeserializer { + params: &self.url_params.0, + }) + } + + fn deserialize_tuple(self, len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.url_params.0.len() < len { + return Err(PathDeserializerError::custom( + format!( + "wrong number of parameters: {} expected {}", + self.url_params.0.len(), + len + ) + .as_str(), + )); + } + visitor.visit_seq(SeqDeserializer { + params: &self.url_params.0, + }) + } + + fn deserialize_tuple_struct( + self, + _name: &'static str, + len: usize, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + if self.url_params.0.len() < len { + return Err(PathDeserializerError::custom( + format!( + "wrong number of parameters: {} expected {}", + self.url_params.0.len(), + len + ) + .as_str(), + )); + } + visitor.visit_seq(SeqDeserializer { + params: &self.url_params.0, + }) + } + + fn deserialize_map(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_map(MapDeserializer { + params: &self.url_params.0, + value: None, + }) + } + + fn deserialize_struct( + self, + _name: &'static str, + _fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.deserialize_map(visitor) + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + if self.url_params.0.len() != 1 { + return Err(PathDeserializerError::custom(format!( + "wrong number of parameters: {} expected 1", + self.url_params.0.len() + ))); + } + + visitor.visit_enum(EnumDeserializer { + value: &self.url_params.0[0].1, + }) + } +} + +struct MapDeserializer<'de> { + params: &'de [(ByteStr, ByteStr)], + value: Option<&'de str>, +} + +impl<'de> MapAccess<'de> for MapDeserializer<'de> { + type Error = PathDeserializerError; + + fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> + where + K: DeserializeSeed<'de>, + { + match self.params.split_first() { + Some(((key, value), tail)) => { + self.value = Some(value); + self.params = tail; + seed.deserialize(KeyDeserializer { key }).map(Some) + } + None => Ok(None), + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'de>, + { + match self.value.take() { + Some(value) => seed.deserialize(ValueDeserializer { value }), + None => Err(serde::de::Error::custom("value is missing")), + } + } +} + +struct KeyDeserializer<'de> { + key: &'de str, +} + +macro_rules! parse_key { + ($trait_fn:ident) => { + fn $trait_fn(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_str(self.key) + } + }; +} + +impl<'de> Deserializer<'de> for KeyDeserializer<'de> { + type Error = PathDeserializerError; + + parse_key!(deserialize_identifier); + parse_key!(deserialize_str); + parse_key!(deserialize_string); + + fn deserialize_any(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(PathDeserializerError::custom("Unexpected")) + } + + forward_to_deserialize_any! { + bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char bytes + byte_buf option unit unit_struct seq tuple + tuple_struct map newtype_struct struct enum ignored_any + } +} + +macro_rules! parse_value { + ($trait_fn:ident, $visit_fn:ident, $ty:literal) => { + fn $trait_fn(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let v = self.value.parse().map_err(|_| { + PathDeserializerError::custom(format!( + "can not parse `{:?}` to a `{}`", + self.value, $ty + )) + })?; + visitor.$visit_fn(v) + } + }; +} + +struct ValueDeserializer<'de> { + value: &'de str, +} + +impl<'de> Deserializer<'de> for ValueDeserializer<'de> { + type Error = PathDeserializerError; + + unsupported_type!(deserialize_any, "any"); + unsupported_type!(deserialize_seq, "seq"); + unsupported_type!(deserialize_map, "map"); + unsupported_type!(deserialize_identifier, "identifier"); + + parse_value!(deserialize_bool, visit_bool, "bool"); + parse_value!(deserialize_i8, visit_i8, "i8"); + parse_value!(deserialize_i16, visit_i16, "i16"); + parse_value!(deserialize_i32, visit_i32, "i16"); + parse_value!(deserialize_i64, visit_i64, "i64"); + parse_value!(deserialize_u8, visit_u8, "u8"); + parse_value!(deserialize_u16, visit_u16, "u16"); + parse_value!(deserialize_u32, visit_u32, "u32"); + parse_value!(deserialize_u64, visit_u64, "u64"); + parse_value!(deserialize_f32, visit_f32, "f32"); + parse_value!(deserialize_f64, visit_f64, "f64"); + parse_value!(deserialize_string, visit_string, "String"); + parse_value!(deserialize_byte_buf, visit_string, "String"); + parse_value!(deserialize_char, visit_char, "char"); + + fn deserialize_str(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_borrowed_str(self.value) + } + + fn deserialize_bytes(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_borrowed_bytes(self.value.as_bytes()) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_some(self) + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_unit_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_tuple(self, _len: usize, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(PathDeserializerError::custom("unsupported type: tuple")) + } + + fn deserialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + Err(PathDeserializerError::custom( + "unsupported type: tuple struct", + )) + } + + fn deserialize_struct( + self, + _name: &'static str, + _fields: &'static [&'static str], + _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + Err(PathDeserializerError::custom("unsupported type: struct")) + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_enum(EnumDeserializer { value: self.value }) + } + + fn deserialize_ignored_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } +} + +struct EnumDeserializer<'de> { + value: &'de str, +} + +impl<'de> EnumAccess<'de> for EnumDeserializer<'de> { + type Error = PathDeserializerError; + type Variant = UnitVariant; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> + where + V: de::DeserializeSeed<'de>, + { + Ok(( + seed.deserialize(KeyDeserializer { key: self.value })?, + UnitVariant, + )) + } +} + +struct UnitVariant; + +impl<'de> VariantAccess<'de> for UnitVariant { + type Error = PathDeserializerError; + + fn unit_variant(self) -> Result<(), Self::Error> { + Ok(()) + } + + fn newtype_variant_seed(self, _seed: T) -> Result + where + T: DeserializeSeed<'de>, + { + Err(PathDeserializerError::custom("not supported")) + } + + fn tuple_variant(self, _len: usize, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(PathDeserializerError::custom("not supported")) + } + + fn struct_variant( + self, + _fields: &'static [&'static str], + _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + Err(PathDeserializerError::custom("not supported")) + } +} + +struct SeqDeserializer<'de> { + params: &'de [(ByteStr, ByteStr)], +} + +impl<'de> SeqAccess<'de> for SeqDeserializer<'de> { + type Error = PathDeserializerError; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: DeserializeSeed<'de>, + { + match self.params.split_first() { + Some(((_, value), tail)) => { + self.params = tail; + Ok(Some(seed.deserialize(ValueDeserializer { value })?)) + } + None => Ok(None), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::util::ByteStr; + use serde::Deserialize; + use std::collections::HashMap; + + #[derive(Debug, Deserialize, Eq, PartialEq)] + enum MyEnum { + A, + B, + #[serde(rename = "c")] + C, + } + + #[derive(Debug, Deserialize, Eq, PartialEq)] + struct Struct { + c: String, + b: bool, + a: i32, + } + + fn create_url_params(values: I) -> UrlParams + where + I: IntoIterator, + K: AsRef, + V: AsRef, + { + UrlParams( + values + .into_iter() + .map(|(k, v)| (ByteStr::new(k), ByteStr::new(v))) + .collect(), + ) + } + + macro_rules! check_single_value { + ($ty:ty, $value_str:literal, $value:expr) => { + #[allow(clippy::bool_assert_comparison)] + { + let url_params = create_url_params([("value", $value_str)]); + let deserializer = PathDeserializer::new(&url_params); + assert_eq!(<$ty>::deserialize(deserializer).unwrap(), $value); + } + }; + } + + #[test] + fn test_parse_single_value() { + check_single_value!(bool, "true", true); + check_single_value!(bool, "false", false); + check_single_value!(i8, "-123", -123); + check_single_value!(i16, "-123", -123); + check_single_value!(i32, "-123", -123); + check_single_value!(i64, "-123", -123); + check_single_value!(u8, "123", 123); + check_single_value!(u16, "123", 123); + check_single_value!(u32, "123", 123); + check_single_value!(u64, "123", 123); + check_single_value!(f32, "123", 123.0); + check_single_value!(f64, "123", 123.0); + check_single_value!(String, "abc", "abc"); + check_single_value!(char, "a", 'a'); + + let url_params = create_url_params([("a", "B")]); + assert_eq!( + MyEnum::deserialize(PathDeserializer::new(&url_params)).unwrap(), + MyEnum::B + ); + + let url_params = create_url_params([("a", "1"), ("b", "2")]); + assert_eq!( + i32::deserialize(PathDeserializer::new(&url_params)).unwrap_err(), + PathDeserializerError::custom("wrong number of parameters: 2 expected 1".to_string()) + ); + } + + #[test] + fn test_parse_seq() { + let url_params = create_url_params([("a", "1"), ("b", "true"), ("c", "abc")]); + assert_eq!( + <(i32, bool, String)>::deserialize(PathDeserializer::new(&url_params)).unwrap(), + (1, true, "abc".to_string()) + ); + + #[derive(Debug, Deserialize, Eq, PartialEq)] + struct TupleStruct(i32, bool, String); + assert_eq!( + TupleStruct::deserialize(PathDeserializer::new(&url_params)).unwrap(), + TupleStruct(1, true, "abc".to_string()) + ); + + let url_params = create_url_params([("a", "1"), ("b", "2"), ("c", "3")]); + assert_eq!( + >::deserialize(PathDeserializer::new(&url_params)).unwrap(), + vec![1, 2, 3] + ); + + let url_params = create_url_params([("a", "c"), ("a", "B")]); + assert_eq!( + >::deserialize(PathDeserializer::new(&url_params)).unwrap(), + vec![MyEnum::C, MyEnum::B] + ); + } + + #[test] + fn test_parse_struct() { + let url_params = create_url_params([("a", "1"), ("b", "true"), ("c", "abc")]); + assert_eq!( + Struct::deserialize(PathDeserializer::new(&url_params)).unwrap(), + Struct { + c: "abc".to_string(), + b: true, + a: 1, + } + ); + } + + #[test] + fn test_parse_map() { + let url_params = create_url_params([("a", "1"), ("b", "true"), ("c", "abc")]); + assert_eq!( + >::deserialize(PathDeserializer::new(&url_params)).unwrap(), + [("a", "1"), ("b", "true"), ("c", "abc")] + .iter() + .map(|(key, value)| ((*key).to_string(), (*value).to_string())) + .collect() + ); + } +} diff --git a/src/extract/path/mod.rs b/src/extract/path/mod.rs new file mode 100644 index 00000000..767631bf --- /dev/null +++ b/src/extract/path/mod.rs @@ -0,0 +1,113 @@ +mod de; + +use super::{rejection::*, FromRequest}; +use crate::{extract::RequestParts, routing::UrlParams}; +use async_trait::async_trait; +use serde::de::DeserializeOwned; +use std::ops::{Deref, DerefMut}; + +/// Extractor that will get captures from the URL and parse them using [`serde`](https://crates.io/crates/serde). +/// +/// # Example +/// +/// ```rust,no_run +/// use axum::{extract::Path, prelude::*}; +/// use uuid::Uuid; +/// +/// async fn users_teams_show( +/// Path((user_id, team_id)): Path<(Uuid, Uuid)>, +/// ) { +/// // ... +/// } +/// +/// let app = route("/users/:user_id/team/:team_id", get(users_teams_show)); +/// # async { +/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +/// +/// If the path contains only one parameter, then you can omit the tuple. +/// +/// ```rust,no_run +/// use axum::{extract::Path, prelude::*}; +/// use uuid::Uuid; +/// +/// async fn user_info(Path(user_id): Path) { +/// // ... +/// } +/// +/// let app = route("/users/:user_id", get(user_info)); +/// # async { +/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +/// +/// Path segments also can be deserialized into any type that implements [serde::Deserialize](https://docs.rs/serde/1.0.127/serde/trait.Deserialize.html). +/// Path segment labels will be matched with struct field names. +/// +/// ```rust,no_run +/// use axum::{extract::Path, prelude::*}; +/// use serde::Deserialize; +/// use uuid::Uuid; +/// +/// #[derive(Deserialize)] +/// struct Params { +/// user_id: Uuid, +/// team_id: Uuid, +/// } +/// +/// async fn users_teams_show( +/// Path(Params { user_id, team_id }): Path, +/// ) { +/// // ... +/// } +/// +/// let app = route("/users/:user_id/team/:team_id", get(users_teams_show)); +/// # async { +/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +/// # }; +/// ``` +#[derive(Debug)] +pub struct Path(pub T); + +impl Deref for Path { + type Target = T; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Path { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +#[async_trait] +impl FromRequest for Path +where + T: DeserializeOwned + Send, + B: Send, +{ + type Rejection = PathParamsRejection; + + async fn from_request(req: &mut RequestParts) -> Result { + const EMPTY_URL_PARAMS: &UrlParams = &UrlParams(Vec::new()); + + let url_params = if let Some(params) = req + .extensions_mut() + .and_then(|ext| ext.get::>()) + { + params.as_ref().unwrap_or(EMPTY_URL_PARAMS) + } else { + return Err(MissingRouteParams.into()); + }; + + T::deserialize(de::PathDeserializer::new(url_params)) + .map_err(|err| PathParamsRejection::InvalidPathParam(InvalidPathParam::new(err.0))) + .map(Path) + } +} diff --git a/src/extract/rejection.rs b/src/extract/rejection.rs index ccd1f150..64cf80f9 100644 --- a/src/extract/rejection.rs +++ b/src/extract/rejection.rs @@ -159,6 +159,25 @@ impl IntoResponse for InvalidUrlParam { } } +/// Rejection type for [`Path`](super::Path) if the capture route +/// param didn't have the expected type. +#[derive(Debug)] +pub struct InvalidPathParam(String); + +impl InvalidPathParam { + pub(super) fn new(err: impl Into) -> Self { + InvalidPathParam(err.into()) + } +} + +impl IntoResponse for InvalidPathParam { + fn into_response(self) -> http::Response { + let mut res = http::Response::new(Body::from(format!("Invalid URL param. {}", self.0))); + *res.status_mut() = http::StatusCode::BAD_REQUEST; + res + } +} + /// Rejection type for extractors that deserialize query strings if the input /// couldn't be deserialized into the target type. #[derive(Debug)] @@ -254,6 +273,17 @@ composite_rejection! { } } +composite_rejection! { + /// Rejection used for [`Path`](super::Path). + /// + /// Contains one variant for each way the [`Path`](super::Path) extractor + /// can fail. + pub enum PathParamsRejection { + InvalidPathParam, + MissingRouteParams, + } +} + composite_rejection! { /// Rejection used for [`Bytes`](bytes::Bytes). /// diff --git a/src/lib.rs b/src/lib.rs index 62bc3525..75f87ccc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -239,8 +239,8 @@ //! # }; //! ``` //! -//! [`extract::UrlParams`] can be used to extract params from a dynamic URL. It -//! is compatible with any type that implements [`std::str::FromStr`], such as +//! [`extract::Path`] can be used to extract params from a dynamic URL. It +//! is compatible with any type that implements [`serde::Deserialize`], such as //! [`Uuid`]: //! //! ```rust,no_run @@ -249,9 +249,7 @@ //! //! let app = route("/users/:id", post(create_user)); //! -//! async fn create_user(params: extract::UrlParams<(Uuid,)>) { -//! let user_id: Uuid = (params.0).0; -//! +//! async fn create_user(extract::Path(user_id): extract::Path) { //! // ... //! } //! # async { @@ -259,9 +257,6 @@ //! # }; //! ``` //! -//! There is also [`UrlParamsMap`](extract::UrlParamsMap) which provide a map -//! like API for extracting URL params. -//! //! You can also apply multiple extractors: //! //! ```rust,no_run @@ -284,10 +279,9 @@ //! } //! //! async fn get_user_things( -//! params: extract::UrlParams<(Uuid,)>, +//! extract::Path(user_id): extract::Path, //! pagination: Option>, //! ) { -//! let user_id: Uuid = (params.0).0; //! let pagination: Pagination = pagination.unwrap_or_default().0; //! //! // ... diff --git a/src/tests.rs b/src/tests.rs index 6d8e5973..183de768 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -9,6 +9,7 @@ use hyper::{Body, Server}; use serde::Deserialize; use serde_json::json; use std::{ + collections::HashMap, convert::Infallible, net::{SocketAddr, TcpListener}, task::{Context, Poll}, @@ -244,20 +245,14 @@ async fn routing() { async fn extracting_url_params() { let app = route( "/users/:id", - get(|params: extract::UrlParams<(i32,)>| async move { - let (id,) = params.0; + get(|extract::Path(id): extract::Path| async move { assert_eq!(id, 42); }) - .post(|params_map: extract::UrlParamsMap| async move { - assert_eq!(params_map.get("id").unwrap(), "1337"); - assert_eq!( - params_map - .get_typed::("id") - .expect("missing") - .expect("failed to parse"), - 1337 - ); - }), + .post( + |extract::Path(params_map): extract::Path>| async move { + assert_eq!(params_map.get("id").unwrap(), &1337); + }, + ), ); let addr = run_in_background(app).await; @@ -283,12 +278,7 @@ async fn extracting_url_params() { async fn extracting_url_params_multiple_times() { let app = route( "/users/:id", - get( - |_: extract::UrlParams<(i32,)>, - _: extract::UrlParamsMap, - _: extract::UrlParams<(i32,)>, - _: extract::UrlParamsMap| async {}, - ), + get(|_: extract::Path, _: extract::Path| async {}), ); let addr = run_in_background(app).await; diff --git a/src/tests/nest.rs b/src/tests/nest.rs index d1b0ac37..f518eb4e 100644 --- a/src/tests/nest.rs +++ b/src/tests/nest.rs @@ -1,4 +1,5 @@ use super::*; +use std::collections::HashMap; #[tokio::test] async fn nesting_apps() { @@ -8,23 +9,27 @@ async fn nesting_apps() { ) .route( "/users/:id", - get(|params: extract::UrlParamsMap| async move { - format!( - "{}: users#show ({})", - params.get("version").unwrap(), - params.get("id").unwrap() - ) - }), + get( + |params: extract::Path>| async move { + format!( + "{}: users#show ({})", + params.get("version").unwrap(), + params.get("id").unwrap() + ) + }, + ), ) .route( "/games/:id", - get(|params: extract::UrlParamsMap| async move { - format!( - "{}: games#show ({})", - params.get("version").unwrap(), - params.get("id").unwrap() - ) - }), + get( + |params: extract::Path>| async move { + format!( + "{}: games#show ({})", + params.get("version").unwrap(), + params.get("id").unwrap() + ) + }, + ), ); let app = route("/", get(|| async { "hi" })).nest("/:version/api", api_routes); diff --git a/src/util.rs b/src/util.rs index cdf437a0..1c05533c 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,9 +1,19 @@ use bytes::Bytes; +use std::ops::Deref; /// A string like type backed by `Bytes` making it cheap to clone. #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub(crate) struct ByteStr(Bytes); +impl Deref for ByteStr { + type Target = str; + + #[inline] + fn deref(&self) -> &Self::Target { + self.as_str() + } +} + impl ByteStr { pub(crate) fn new(s: S) -> Self where