Percent decode automatically in extract::Path (#272)

* Percent decode automatically in `extract::Path`

Fixes https://github.com/tokio-rs/axum/issues/261

* return an error if path param contains invalid utf-8

* Mention automatic decoding in the docs

* Update changelog: This is a breaking change

* cleanup

* fix tests
This commit is contained in:
David Pedersen 2021-10-02 16:04:29 +02:00 committed by GitHub
parent 2c2bcd7754
commit afabded385
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 156 additions and 58 deletions

View File

@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased # Unreleased
- Improve performance of `BoxRoute` ([#339]) - Improve performance of `BoxRoute` ([#339])
- **breaking:** Automatically do percent decoding in `extract::Path`
([#272])
- **breaking:** `Router::boxed` now requires the inner service to implement `Clone` and - **breaking:** `Router::boxed` now requires the inner service to implement `Clone` and
`Sync` in addition to the previous trait bounds ([#339]) `Sync` in addition to the previous trait bounds ([#339])
- **breaking:** Added feature flags for HTTP1 and JSON. This enables removing a - **breaking:** Added feature flags for HTTP1 and JSON. This enables removing a
@ -16,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#339]: https://github.com/tokio-rs/axum/pull/339 [#339]: https://github.com/tokio-rs/axum/pull/339
[#286]: https://github.com/tokio-rs/axum/pull/286 [#286]: https://github.com/tokio-rs/axum/pull/286
[#272]: https://github.com/tokio-rs/axum/pull/272
# 0.2.6 (02. October, 2021) # 0.2.6 (02. October, 2021)

View File

@ -27,30 +27,32 @@ ws = ["tokio-tungstenite", "sha-1", "base64"]
async-trait = "0.1" async-trait = "0.1"
bitflags = "1.0" bitflags = "1.0"
bytes = "1.0" bytes = "1.0"
dyn-clone = "1.0"
futures-util = { version = "0.3", default-features = false, features = ["alloc"] } futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
http = "0.2" http = "0.2"
http-body = "0.4.3" http-body = "0.4.3"
hyper = { version = "0.14", features = ["server", "tcp", "stream"] } hyper = { version = "0.14", features = ["server", "tcp", "stream"] }
percent-encoding = "2.1"
pin-project-lite = "0.2.7" pin-project-lite = "0.2.7"
regex = "1.5" regex = "1.5"
serde = "1.0" serde = "1.0"
serde_json = { version = "1.0", optional = true }
serde_urlencoded = "0.7" serde_urlencoded = "0.7"
sync_wrapper = "0.1.1"
tokio = { version = "1", features = ["time"] } tokio = { version = "1", features = ["time"] }
tokio-util = "0.6" tokio-util = "0.6"
tower = { version = "0.4", default-features = false, features = ["util", "buffer", "make"] } tower = { version = "0.4", default-features = false, features = ["util", "buffer", "make"] }
tower-service = "0.3"
tower-layer = "0.3"
tower-http = { version = "0.1", features = ["add-extension", "map-response-body"] } tower-http = { version = "0.1", features = ["add-extension", "map-response-body"] }
sync_wrapper = "0.1.1" tower-layer = "0.3"
tower-service = "0.3"
# optional dependencies # optional dependencies
tokio-tungstenite = { optional = true, version = "0.15" }
sha-1 = { optional = true, version = "0.9.6" }
base64 = { optional = true, version = "0.13" } base64 = { optional = true, version = "0.13" }
headers = { optional = true, version = "0.3" } headers = { optional = true, version = "0.3" }
multer = { optional = true, version = "2.0.0" }
mime = { optional = true, version = "0.3" } mime = { optional = true, version = "0.3" }
multer = { optional = true, version = "2.0.0" }
serde_json = { version = "1.0", optional = true }
sha-1 = { optional = true, version = "0.9.6" }
tokio-tungstenite = { optional = true, version = "0.15" }
[dev-dependencies] [dev-dependencies]
futures = "0.3" futures = "0.3"
@ -82,4 +84,11 @@ all-features = true
rustdoc-args = ["--cfg", "docsrs"] rustdoc-args = ["--cfg", "docsrs"]
[package.metadata.playground] [package.metadata.playground]
features = ["ws", "multipart", "headers"] features = [
"http1",
"http2",
"json",
"multipart",
"tower",
"ws",
]

View File

@ -1,5 +1,4 @@
use crate::routing::UrlParams; use crate::util::{ByteStr, PercentDecodedByteStr};
use crate::util::ByteStr;
use serde::{ use serde::{
de::{self, DeserializeSeed, EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor}, de::{self, DeserializeSeed, EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor},
forward_to_deserialize_any, Deserializer, forward_to_deserialize_any, Deserializer,
@ -53,20 +52,20 @@ macro_rules! parse_single_value {
where where
V: Visitor<'de>, V: Visitor<'de>,
{ {
if self.url_params.0.len() != 1 { if self.url_params.len() != 1 {
return Err(PathDeserializerError::custom( return Err(PathDeserializerError::custom(
format!( format!(
"wrong number of parameters: {} expected 1", "wrong number of parameters: {} expected 1",
self.url_params.0.len() self.url_params.len()
) )
.as_str(), .as_str(),
)); ));
} }
let value = self.url_params.0[0].1.parse().map_err(|_| { let value = self.url_params[0].1.parse().map_err(|_| {
PathDeserializerError::custom(format!( PathDeserializerError::custom(format!(
"can not parse `{:?}` to a `{}`", "can not parse `{:?}` to a `{}`",
self.url_params.0[0].1.as_str(), self.url_params[0].1.as_str(),
$tp $tp
)) ))
})?; })?;
@ -76,12 +75,12 @@ macro_rules! parse_single_value {
} }
pub(crate) struct PathDeserializer<'de> { pub(crate) struct PathDeserializer<'de> {
url_params: &'de UrlParams, url_params: &'de [(ByteStr, PercentDecodedByteStr)],
} }
impl<'de> PathDeserializer<'de> { impl<'de> PathDeserializer<'de> {
#[inline] #[inline]
pub(crate) fn new(url_params: &'de UrlParams) -> Self { pub(crate) fn new(url_params: &'de [(ByteStr, PercentDecodedByteStr)]) -> Self {
PathDeserializer { url_params } PathDeserializer { url_params }
} }
} }
@ -114,13 +113,13 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
where where
V: Visitor<'de>, V: Visitor<'de>,
{ {
if self.url_params.0.len() != 1 { if self.url_params.len() != 1 {
return Err(PathDeserializerError::custom(format!( return Err(PathDeserializerError::custom(format!(
"wrong number of parameters: {} expected 1", "wrong number of parameters: {} expected 1",
self.url_params.0.len() self.url_params.len()
))); )));
} }
visitor.visit_str(&self.url_params.0[0].1) visitor.visit_str(&self.url_params[0].1)
} }
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error> fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
@ -157,7 +156,7 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
V: Visitor<'de>, V: Visitor<'de>,
{ {
visitor.visit_seq(SeqDeserializer { visitor.visit_seq(SeqDeserializer {
params: &self.url_params.0, params: self.url_params,
}) })
} }
@ -165,18 +164,18 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
where where
V: Visitor<'de>, V: Visitor<'de>,
{ {
if self.url_params.0.len() < len { if self.url_params.len() < len {
return Err(PathDeserializerError::custom( return Err(PathDeserializerError::custom(
format!( format!(
"wrong number of parameters: {} expected {}", "wrong number of parameters: {} expected {}",
self.url_params.0.len(), self.url_params.len(),
len len
) )
.as_str(), .as_str(),
)); ));
} }
visitor.visit_seq(SeqDeserializer { visitor.visit_seq(SeqDeserializer {
params: &self.url_params.0, params: self.url_params,
}) })
} }
@ -189,18 +188,18 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
where where
V: Visitor<'de>, V: Visitor<'de>,
{ {
if self.url_params.0.len() < len { if self.url_params.len() < len {
return Err(PathDeserializerError::custom( return Err(PathDeserializerError::custom(
format!( format!(
"wrong number of parameters: {} expected {}", "wrong number of parameters: {} expected {}",
self.url_params.0.len(), self.url_params.len(),
len len
) )
.as_str(), .as_str(),
)); ));
} }
visitor.visit_seq(SeqDeserializer { visitor.visit_seq(SeqDeserializer {
params: &self.url_params.0, params: self.url_params,
}) })
} }
@ -209,7 +208,7 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
V: Visitor<'de>, V: Visitor<'de>,
{ {
visitor.visit_map(MapDeserializer { visitor.visit_map(MapDeserializer {
params: &self.url_params.0, params: self.url_params,
value: None, value: None,
}) })
} }
@ -235,21 +234,21 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> {
where where
V: Visitor<'de>, V: Visitor<'de>,
{ {
if self.url_params.0.len() != 1 { if self.url_params.len() != 1 {
return Err(PathDeserializerError::custom(format!( return Err(PathDeserializerError::custom(format!(
"wrong number of parameters: {} expected 1", "wrong number of parameters: {} expected 1",
self.url_params.0.len() self.url_params.len()
))); )));
} }
visitor.visit_enum(EnumDeserializer { visitor.visit_enum(EnumDeserializer {
value: &self.url_params.0[0].1, value: &self.url_params[0].1,
}) })
} }
} }
struct MapDeserializer<'de> { struct MapDeserializer<'de> {
params: &'de [(ByteStr, ByteStr)], params: &'de [(ByteStr, PercentDecodedByteStr)],
value: Option<&'de str>, value: Option<&'de str>,
} }
@ -519,7 +518,7 @@ impl<'de> VariantAccess<'de> for UnitVariant {
} }
struct SeqDeserializer<'de> { struct SeqDeserializer<'de> {
params: &'de [(ByteStr, ByteStr)], params: &'de [(ByteStr, PercentDecodedByteStr)],
} }
impl<'de> SeqAccess<'de> for SeqDeserializer<'de> { impl<'de> SeqAccess<'de> for SeqDeserializer<'de> {
@ -561,18 +560,16 @@ mod tests {
a: i32, a: i32,
} }
fn create_url_params<I, K, V>(values: I) -> UrlParams fn create_url_params<I, K, V>(values: I) -> Vec<(ByteStr, PercentDecodedByteStr)>
where where
I: IntoIterator<Item = (K, V)>, I: IntoIterator<Item = (K, V)>,
K: AsRef<str>, K: AsRef<str>,
V: AsRef<str>, V: AsRef<str>,
{ {
UrlParams(
values values
.into_iter() .into_iter()
.map(|(k, v)| (ByteStr::new(k), ByteStr::new(v))) .map(|(k, v)| (ByteStr::new(k), PercentDecodedByteStr::new(v).unwrap()))
.collect(), .collect()
)
} }
macro_rules! check_single_value { macro_rules! check_single_value {
@ -601,6 +598,7 @@ mod tests {
check_single_value!(f32, "123", 123.0); check_single_value!(f32, "123", 123.0);
check_single_value!(f64, "123", 123.0); check_single_value!(f64, "123", 123.0);
check_single_value!(String, "abc", "abc"); check_single_value!(String, "abc", "abc");
check_single_value!(String, "one%20two", "one two");
check_single_value!(char, "a", 'a'); check_single_value!(char, "a", 'a');
let url_params = create_url_params(vec![("a", "B")]); let url_params = create_url_params(vec![("a", "B")]);

View File

@ -1,14 +1,24 @@
mod de; mod de;
use super::{rejection::*, FromRequest}; use super::{rejection::*, FromRequest};
use crate::{extract::RequestParts, routing::UrlParams}; use crate::{
extract::RequestParts,
routing::{InvalidUtf8InPathParam, UrlParams},
};
use async_trait::async_trait; use async_trait::async_trait;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::ops::{Deref, DerefMut}; use std::{
borrow::Cow,
ops::{Deref, DerefMut},
};
/// Extractor that will get captures from the URL and parse them using /// Extractor that will get captures from the URL and parse them using
/// [`serde`]. /// [`serde`].
/// ///
/// Any percent encoded parameters will be automatically decoded. The decoded
/// parameters must be valid UTF-8, otherwise `Path` will fail and return a `400
/// Bad Request` response.
///
/// # Example /// # Example
/// ///
/// ```rust,no_run /// ```rust,no_run
@ -140,20 +150,45 @@ where
{ {
type Rejection = PathParamsRejection; type Rejection = PathParamsRejection;
#[allow(warnings)]
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
const EMPTY_URL_PARAMS: &UrlParams = &UrlParams(Vec::new()); let params = match req
let url_params = if let Some(params) = req
.extensions_mut() .extensions_mut()
.and_then(|ext| ext.get::<Option<UrlParams>>()) .and_then(|ext| ext.get::<Option<UrlParams>>())
{ {
params.as_ref().unwrap_or(EMPTY_URL_PARAMS) Some(Some(UrlParams(Ok(params)))) => Cow::Borrowed(params),
} else { Some(Some(UrlParams(Err(InvalidUtf8InPathParam { key })))) => {
return Err(InvalidPathParam::new(key.as_str()).into())
}
Some(None) => Cow::Owned(Vec::new()),
None => {
return Err(MissingRouteParams.into()); return Err(MissingRouteParams.into());
}
}; };
T::deserialize(de::PathDeserializer::new(url_params)) T::deserialize(de::PathDeserializer::new(&*params))
.map_err(|err| PathParamsRejection::InvalidPathParam(InvalidPathParam::new(err.0))) .map_err(|err| PathParamsRejection::InvalidPathParam(InvalidPathParam::new(err.0)))
.map(Path) .map(Path)
} }
} }
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{handler::get, tests::*, Router};

    /// A percent-encoded path segment must reach the handler already decoded.
    #[tokio::test]
    async fn percent_decoding() {
        // Echo the single captured segment back as the response body.
        let echo = get(|Path(segment): Path<String>| async move { segment });
        let router = Router::new().route("/:key", echo);

        let client = TestClient::new(router);

        // `%20` in the request path should come back as a literal space.
        let response = client.get("/one%20two").send().await;
        let body = response.text().await;

        assert_eq!(body, "one two");
    }
}

View File

@ -107,7 +107,7 @@ define_rejection! {
/// Rejection type for [`Path`](super::Path) if the capture route /// Rejection type for [`Path`](super::Path) if the capture route
/// param didn't have the expected type. /// param didn't have the expected type.
#[derive(Debug)] #[derive(Debug)]
pub struct InvalidPathParam(String); pub struct InvalidPathParam(pub(crate) String);
impl InvalidPathParam { impl InvalidPathParam {
pub(super) fn new(err: impl Into<String>) -> Self { pub(super) fn new(err: impl Into<String>) -> Self {

View File

@ -9,7 +9,7 @@ use crate::{
OriginalUri, OriginalUri,
}, },
service::HandleError, service::HandleError,
util::ByteStr, util::{ByteStr, PercentDecodedByteStr},
BoxError, BoxError,
}; };
use bytes::Bytes; use bytes::Bytes;
@ -627,22 +627,47 @@ where
} }
} }
#[derive(Debug)] // we store the potential error here such that users can handle invalid path
pub(crate) struct UrlParams(pub(crate) Vec<(ByteStr, ByteStr)>); // params using `Result<Path<T>, _>`. That wouldn't be possible if we
// returned an error immediately when decoding the param
pub(crate) struct UrlParams(
pub(crate) Result<Vec<(ByteStr, PercentDecodedByteStr)>, InvalidUtf8InPathParam>,
);
fn insert_url_params<B>(req: &mut Request<B>, params: Vec<(String, String)>) { fn insert_url_params<B>(req: &mut Request<B>, params: Vec<(String, String)>) {
let params = params let params = params
.into_iter() .into_iter()
.map(|(k, v)| (ByteStr::new(k), ByteStr::new(v))); .map(|(k, v)| {
if let Some(decoded) = PercentDecodedByteStr::new(v) {
Ok((ByteStr::new(k), decoded))
} else {
Err(InvalidUtf8InPathParam {
key: ByteStr::new(k),
})
}
})
.collect::<Result<Vec<_>, _>>();
if let Some(current) = req.extensions_mut().get_mut::<Option<UrlParams>>() { if let Some(current) = req.extensions_mut().get_mut::<Option<UrlParams>>() {
match params {
Ok(params) => {
let mut current = current.take().unwrap(); let mut current = current.take().unwrap();
current.0.extend(params); if let Ok(current) = &mut current.0 {
req.extensions_mut().insert(Some(current)); current.extend(params);
} else {
req.extensions_mut()
.insert(Some(UrlParams(params.collect())));
} }
req.extensions_mut().insert(Some(current));
}
Err(err) => {
req.extensions_mut().insert(Some(UrlParams(Err(err))));
}
}
} else {
req.extensions_mut().insert(Some(UrlParams(params)));
}
}
/// Error recorded by `insert_url_params` when percent decoding a captured
/// path parameter's value fails (i.e. `PercentDecodedByteStr::new` returns
/// `None`). It is stored inside `UrlParams` rather than returned right away.
pub(crate) struct InvalidUtf8InPathParam {
/// Name of the path parameter whose value could not be decoded.
pub(crate) key: ByteStr,
}
/// A [`Service`] that responds with `404 Not Found` or `405 Method not allowed` /// A [`Service`] that responds with `404 Not Found` or `405 Method not allowed`

View File

@ -30,6 +30,34 @@ impl ByteStr {
} }
} }
/// A `ByteStr` holding text that has already been percent decoded.
///
/// Values are created through `PercentDecodedByteStr::new`, which performs the
/// decoding and returns `None` when the decoded bytes are not valid UTF-8.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct PercentDecodedByteStr(ByteStr);
impl PercentDecodedByteStr {
pub(crate) fn new<S>(s: S) -> Option<Self>
where
S: AsRef<str>,
{
percent_encoding::percent_decode(s.as_ref().as_bytes())
.decode_utf8()
.ok()
.map(|decoded| Self(ByteStr::new(decoded)))
}
pub(crate) fn as_str(&self) -> &str {
self.0.as_str()
}
}
/// Lets a `PercentDecodedByteStr` be used wherever a `&str` is expected.
impl Deref for PercentDecodedByteStr {
    type Target = str;

    /// Yields the decoded text held by the wrapped `ByteStr`.
    #[inline]
    fn deref(&self) -> &Self::Target {
        self.0.as_str()
    }
}
pin_project! { pin_project! {
#[project = EitherProj] #[project = EitherProj]
pub(crate) enum Either<A, B> { pub(crate) enum Either<A, B> {