Routing with dynamic parts!

This commit is contained in:
David Pedersen 2021-05-30 15:44:26 +02:00
parent 7328127a3d
commit 763d4e8d21
5 changed files with 182 additions and 30 deletions

View File

@ -1,3 +1,5 @@
#![allow(warnings)]
use bytes::Bytes; use bytes::Bytes;
use http::{Request, StatusCode}; use http::{Request, StatusCode};
use hyper::Server; use hyper::Server;
@ -20,9 +22,8 @@ async fn main() {
// build our application with some routes // build our application with some routes
let app = tower_web::app() let app = tower_web::app()
.at("/get") .at("/:key")
.get(get) .get(get)
.at("/set")
.post(set) .post(set)
// convert it into a `Service` // convert it into a `Service`
.into_service(); .into_service();
@ -49,41 +50,36 @@ struct State {
db: HashMap<String, Bytes>, db: HashMap<String, Bytes>,
} }
#[derive(Deserialize)]
struct GetSetQueryString {
key: String,
}
async fn get( async fn get(
_req: Request<Body>, _req: Request<Body>,
query: extract::Query<GetSetQueryString>, params: extract::UrlParams,
state: extract::Extension<SharedState>, state: extract::Extension<SharedState>,
) -> Result<Bytes, Error> { ) -> Result<Bytes, Error> {
let state = state.into_inner(); let state = state.into_inner();
let db = &state.lock().unwrap().db; let db = &state.lock().unwrap().db;
let key = query.into_inner().key; let key = params.get("key")?;
if let Some(value) = db.get(&key) { if let Some(value) = db.get(key) {
Ok(value.clone()) Ok(value.clone())
} else { } else {
Err(Error::WithStatus(StatusCode::NOT_FOUND)) Err(Error::Status(StatusCode::NOT_FOUND))
} }
} }
async fn set( async fn set(
_req: Request<Body>, _req: Request<Body>,
query: extract::Query<GetSetQueryString>, params: extract::UrlParams,
value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb
state: extract::Extension<SharedState>, state: extract::Extension<SharedState>,
) -> Result<response::Empty, Error> { ) -> Result<response::Empty, Error> {
let state = state.into_inner(); let state = state.into_inner();
let db = &mut state.lock().unwrap().db; let db = &mut state.lock().unwrap().db;
let key = query.into_inner().key; let key = params.get("key")?;
let value = value.into_inner(); let value = value.into_inner();
db.insert(key, value); db.insert(key.to_string(), value);
Ok(response::Empty) Ok(response::Empty)
} }

View File

@ -37,7 +37,10 @@ pub enum Error {
PayloadTooLarge, PayloadTooLarge,
#[error("response failed with status {0}")] #[error("response failed with status {0}")]
WithStatus(StatusCode), Status(StatusCode),
#[error("unknown URL param `{0}`")]
UnknownUrlParam(String),
} }
impl From<Infallible> for Error { impl From<Infallible> for Error {
@ -64,14 +67,14 @@ where
| Error::QueryStringMissing | Error::QueryStringMissing
| Error::DeserializeQueryString(_) => make_response(StatusCode::BAD_REQUEST), | Error::DeserializeQueryString(_) => make_response(StatusCode::BAD_REQUEST),
Error::WithStatus(status) => make_response(status), Error::Status(status) => make_response(status),
Error::LengthRequired => make_response(StatusCode::LENGTH_REQUIRED), Error::LengthRequired => make_response(StatusCode::LENGTH_REQUIRED),
Error::PayloadTooLarge => make_response(StatusCode::PAYLOAD_TOO_LARGE), Error::PayloadTooLarge => make_response(StatusCode::PAYLOAD_TOO_LARGE),
Error::MissingExtension { .. } | Error::SerializeResponseBody(_) => { Error::MissingExtension { .. }
make_response(StatusCode::INTERNAL_SERVER_ERROR) | Error::SerializeResponseBody(_)
} | Error::UnknownUrlParam(_) => make_response(StatusCode::INTERNAL_SERVER_ERROR),
Error::Service(err) => match err.downcast::<Error>() { Error::Service(err) => match err.downcast::<Error>() {
Ok(err) => Err(*err), Ok(err) => Err(*err),

View File

@ -6,8 +6,10 @@ use http_body::Body as _;
use pin_project::pin_project; use pin_project::pin_project;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use std::{ use std::{
collections::HashMap,
future::Future, future::Future,
pin::Pin, pin::Pin,
str::FromStr,
task::{Context, Poll}, task::{Context, Poll},
}; };
@ -181,3 +183,31 @@ impl<const N: u64> FromRequest for BytesMaxLength<N> {
}) })
} }
} }
pub struct UrlParams(HashMap<String, String>);
impl UrlParams {
pub fn get(&self, key: &str) -> Result<&str, Error> {
if let Some(value) = self.0.get(key) {
Ok(value)
} else {
Err(Error::UnknownUrlParam(key.to_string()))
}
}
}
impl FromRequest for UrlParams {
type Future = future::Ready<Result<Self, Error>>;
fn from_request(req: &mut Request<Body>) -> Self::Future {
if let Some(params) = req
.extensions_mut()
.get_mut::<Option<crate::routing::UrlParams>>()
{
let params = params.take().expect("params already taken").0;
future::ok(Self(params.into_iter().collect()))
} else {
panic!("no url params found for matched route. This is a bug in tower-web")
}
}
}

View File

@ -6,8 +6,6 @@ Improvements to make:
Support extracting headers, perhaps via `headers::Header`? Support extracting headers, perhaps via `headers::Header`?
Actual routing
Improve compile times with lots of routes, can we box and combine routers? Improve compile times with lots of routes, can we box and combine routers?
Tests Tests

View File

@ -84,13 +84,15 @@ impl<R> RouteAt<R> {
} }
fn add_route_service<S>(self, service: S, method: Method) -> RouteBuilder<Route<S, R>> { fn add_route_service<S>(self, service: S, method: Method) -> RouteBuilder<Route<S, R>> {
assert!(
self.route_spec.starts_with(b"/"),
"route spec must start with a slash (`/`)"
);
let new_app = App { let new_app = App {
router: Route { router: Route {
service, service,
route_spec: RouteSpec { route_spec: RouteSpec::new(method, self.route_spec.clone()),
method,
spec: self.route_spec.clone(),
},
fallback: self.app.router, fallback: self.app.router,
handler_ready: false, handler_ready: false,
fallback_ready: false, fallback_ready: false,
@ -196,9 +198,47 @@ struct RouteSpec {
} }
impl RouteSpec { impl RouteSpec {
fn matches<B>(&self, req: &Request<B>) -> bool { fn new(method: Method, spec: impl Into<Bytes>) -> Self {
// TODO(david): support dynamic placeholders like `/users/:id` Self {
req.method() == self.method && req.uri().path().as_bytes() == self.spec method,
spec: spec.into(),
}
}
}
impl RouteSpec {
fn matches<B>(&self, req: &Request<B>) -> Option<Vec<(String, String)>> {
if req.method() != self.method {
return None;
}
let path = req.uri().path().as_bytes();
let path_parts = path.split(|b| *b == b'/');
let spec_parts = self.spec.split(|b| *b == b'/');
if spec_parts.clone().count() != path_parts.clone().count() {
return None;
}
let mut params = Vec::new();
spec_parts
.zip(path_parts)
.all(|(spec, path)| {
if let Some(key) = spec.strip_prefix(b":") {
let key = std::str::from_utf8(key).unwrap().to_string();
if let Ok(value) = std::str::from_utf8(path) {
params.push((key, value.to_string()));
true
} else {
false
}
} else {
spec == path
}
})
.then(|| params)
} }
} }
@ -236,8 +276,8 @@ where
} }
} }
fn call(&mut self, req: Request<Body>) -> Self::Future { fn call(&mut self, mut req: Request<Body>) -> Self::Future {
if self.route_spec.matches(&req) { if let Some(params) = self.route_spec.matches(&req) {
assert!( assert!(
self.handler_ready, self.handler_ready,
"handler not ready. Did you forget to call `poll_ready`?" "handler not ready. Did you forget to call `poll_ready`?"
@ -245,6 +285,8 @@ where
self.handler_ready = false; self.handler_ready = false;
req.extensions_mut().insert(Some(UrlParams(params)));
future::Either::Left(BoxResponseBody(self.service.call(req))) future::Either::Left(BoxResponseBody(self.service.call(req)))
} else { } else {
assert!( assert!(
@ -260,6 +302,8 @@ where
} }
} }
pub(crate) struct UrlParams(pub(crate) Vec<(String, String)>);
#[pin_project] #[pin_project]
pub struct BoxResponseBody<F>(#[pin] F); pub struct BoxResponseBody<F>(#[pin] F);
@ -282,3 +326,84 @@ where
Poll::Ready(Ok(response)) Poll::Ready(Ok(response))
} }
} }
#[cfg(test)]
mod tests {
#[allow(unused_imports)]
use super::*;
#[test]
fn test_routing() {
assert_match((Method::GET, "/"), (Method::GET, "/"));
refute_match((Method::GET, "/"), (Method::POST, "/"));
refute_match((Method::POST, "/"), (Method::GET, "/"));
assert_match((Method::GET, "/foo"), (Method::GET, "/foo"));
assert_match((Method::GET, "/foo/"), (Method::GET, "/foo/"));
refute_match((Method::GET, "/foo"), (Method::GET, "/foo/"));
refute_match((Method::GET, "/foo/"), (Method::GET, "/foo"));
assert_match((Method::GET, "/foo/bar"), (Method::GET, "/foo/bar"));
refute_match((Method::GET, "/foo/bar/"), (Method::GET, "/foo/bar"));
refute_match((Method::GET, "/foo/bar"), (Method::GET, "/foo/bar/"));
assert_match((Method::GET, "/:value"), (Method::GET, "/foo"));
assert_match((Method::GET, "/users/:id"), (Method::GET, "/users/1"));
assert_match(
(Method::GET, "/users/:id/action"),
(Method::GET, "/users/42/action"),
);
refute_match(
(Method::GET, "/users/:id/action"),
(Method::GET, "/users/42"),
);
refute_match(
(Method::GET, "/users/:id"),
(Method::GET, "/users/42/action"),
);
}
fn assert_match(route_spec: (Method, &'static str), req_spec: (Method, &'static str)) {
let route = RouteSpec::new(route_spec.0.clone(), route_spec.1);
let req = Request::builder()
.method(req_spec.0.clone())
.uri(req_spec.1)
.body(())
.unwrap();
assert!(
route.matches(&req).is_some(),
"`{} {}` doesn't match `{} {}`",
req.method(),
req.uri().path(),
route.method,
std::str::from_utf8(&route.spec).unwrap(),
);
}
fn refute_match(route_spec: (Method, &'static str), req_spec: (Method, &'static str)) {
let route = RouteSpec::new(route_spec.0.clone(), route_spec.1);
let req = Request::builder()
.method(req_spec.0.clone())
.uri(req_spec.1)
.body(())
.unwrap();
assert!(
route.matches(&req).is_none(),
"`{} {}` shouldn't match `{} {}`",
req.method(),
req.uri().path(),
route.method,
std::str::from_utf8(&route.spec).unwrap(),
);
}
fn route(method: Method, uri: &'static str) -> RouteSpec {
RouteSpec::new(method, uri)
}
fn req(method: Method, uri: &str) -> Request<()> {
Request::builder().uri(uri).method(method).body(()).unwrap()
}
}