From 763d4e8d215612b7b203a1b17cbd58f149a2f026 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sun, 30 May 2021 15:44:26 +0200 Subject: [PATCH] Routing with dynamic parts! --- examples/key_value_store.rs | 24 +++--- src/error.rs | 13 ++-- src/extract.rs | 30 ++++++++ src/lib.rs | 2 - src/routing.rs | 143 +++++++++++++++++++++++++++++++++--- 5 files changed, 182 insertions(+), 30 deletions(-) diff --git a/examples/key_value_store.rs b/examples/key_value_store.rs index 686dfef7..8447cee5 100644 --- a/examples/key_value_store.rs +++ b/examples/key_value_store.rs @@ -1,3 +1,5 @@ +#![allow(warnings)] + use bytes::Bytes; use http::{Request, StatusCode}; use hyper::Server; @@ -20,9 +22,8 @@ async fn main() { // build our application with some routes let app = tower_web::app() - .at("/get") + .at("/:key") .get(get) - .at("/set") .post(set) // convert it into a `Service` .into_service(); @@ -49,41 +50,36 @@ struct State { db: HashMap, } -#[derive(Deserialize)] -struct GetSetQueryString { - key: String, -} - async fn get( _req: Request, - query: extract::Query, + params: extract::UrlParams, state: extract::Extension, ) -> Result { let state = state.into_inner(); let db = &state.lock().unwrap().db; - let key = query.into_inner().key; + let key = params.get("key")?; - if let Some(value) = db.get(&key) { + if let Some(value) = db.get(key) { Ok(value.clone()) } else { - Err(Error::WithStatus(StatusCode::NOT_FOUND)) + Err(Error::Status(StatusCode::NOT_FOUND)) } } async fn set( _req: Request, - query: extract::Query, + params: extract::UrlParams, value: extract::BytesMaxLength<{ 1024 * 5_000 }>, // ~5mb state: extract::Extension, ) -> Result { let state = state.into_inner(); let db = &mut state.lock().unwrap().db; - let key = query.into_inner().key; + let key = params.get("key")?; let value = value.into_inner(); - db.insert(key, value); + db.insert(key.to_string(), value); Ok(response::Empty) } diff --git a/src/error.rs b/src/error.rs index 8edbde29..fcfae63c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -37,7 +37,10 @@ pub enum Error { PayloadTooLarge, #[error("response failed with status {0}")] - WithStatus(StatusCode), + Status(StatusCode), + + #[error("unknown URL param `{0}`")] + UnknownUrlParam(String), } impl From for Error { @@ -64,14 +67,14 @@ where | Error::QueryStringMissing | Error::DeserializeQueryString(_) => make_response(StatusCode::BAD_REQUEST), - Error::WithStatus(status) => make_response(status), + Error::Status(status) => make_response(status), Error::LengthRequired => make_response(StatusCode::LENGTH_REQUIRED), Error::PayloadTooLarge => make_response(StatusCode::PAYLOAD_TOO_LARGE), - Error::MissingExtension { .. } | Error::SerializeResponseBody(_) => { - make_response(StatusCode::INTERNAL_SERVER_ERROR) - } + Error::MissingExtension { .. } + | Error::SerializeResponseBody(_) + | Error::UnknownUrlParam(_) => make_response(StatusCode::INTERNAL_SERVER_ERROR), Error::Service(err) => match err.downcast::() { Ok(err) => Err(*err), diff --git a/src/extract.rs b/src/extract.rs index dc42ccb3..998bc232 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -6,8 +6,10 @@ use http_body::Body as _; use pin_project::pin_project; use serde::de::DeserializeOwned; use std::{ + collections::HashMap, future::Future, pin::Pin, + str::FromStr, task::{Context, Poll}, }; @@ -181,3 +183,31 @@ impl FromRequest for BytesMaxLength { }) } } + +pub struct UrlParams(HashMap); + +impl UrlParams { + pub fn get(&self, key: &str) -> Result<&str, Error> { + if let Some(value) = self.0.get(key) { + Ok(value) + } else { + Err(Error::UnknownUrlParam(key.to_string())) + } + } +} + +impl FromRequest for UrlParams { + type Future = future::Ready>; + + fn from_request(req: &mut Request) -> Self::Future { + if let Some(params) = req + .extensions_mut() + .get_mut::>() + { + let params = params.take().expect("params already taken").0; + future::ok(Self(params.into_iter().collect())) + } else { + panic!("no url params found for matched route. This is a bug in tower-web") + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 8ec5006a..50a813f1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,8 +6,6 @@ Improvements to make: Support extracting headers, perhaps via `headers::Header`? -Actual routing - Improve compile times with lots of routes, can we box and combine routers? Tests diff --git a/src/routing.rs b/src/routing.rs index 5fbaf3ff..47c0bdf8 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -84,13 +84,15 @@ impl RouteAt { } fn add_route_service(self, service: S, method: Method) -> RouteBuilder> { + assert!( + self.route_spec.starts_with(b"/"), + "route spec must start with a slash (`/`)" + ); + let new_app = App { router: Route { service, - route_spec: RouteSpec { - method, - spec: self.route_spec.clone(), - }, + route_spec: RouteSpec::new(method, self.route_spec.clone()), fallback: self.app.router, handler_ready: false, fallback_ready: false, @@ -196,9 +198,47 @@ struct RouteSpec { } impl RouteSpec { - fn matches(&self, req: &Request) -> bool { - // TODO(david): support dynamic placeholders like `/users/:id` - req.method() == self.method && req.uri().path().as_bytes() == self.spec + fn new(method: Method, spec: impl Into) -> Self { + Self { + method, + spec: spec.into(), + } + } +} + +impl RouteSpec { + fn matches(&self, req: &Request) -> Option> { + if req.method() != self.method { + return None; + } + + let path = req.uri().path().as_bytes(); + let path_parts = path.split(|b| *b == b'/'); + + let spec_parts = self.spec.split(|b| *b == b'/'); + + if spec_parts.clone().count() != path_parts.clone().count() { + return None; + } + + let mut params = Vec::new(); + + spec_parts + .zip(path_parts) + .all(|(spec, path)| { + if let Some(key) = spec.strip_prefix(b":") { + let key = std::str::from_utf8(key).unwrap().to_string(); + if let Ok(value) = std::str::from_utf8(path) { + params.push((key, value.to_string())); + true + } else { + false + } + } else { + spec == path + } + }) + .then(|| params) } } @@ -236,8 +276,8 @@ where } } - fn call(&mut self, req: Request) -> Self::Future { - if self.route_spec.matches(&req) { + fn call(&mut self, mut req: Request) -> Self::Future { + if let Some(params) = self.route_spec.matches(&req) { assert!( self.handler_ready, "handler not ready. Did you forget to call `poll_ready`?" @@ -245,6 +285,8 @@ where self.handler_ready = false; + req.extensions_mut().insert(Some(UrlParams(params))); + future::Either::Left(BoxResponseBody(self.service.call(req))) } else { assert!( @@ -260,6 +302,8 @@ where } } +pub(crate) struct UrlParams(pub(crate) Vec<(String, String)>); + #[pin_project] pub struct BoxResponseBody(#[pin] F); @@ -282,3 +326,84 @@ where Poll::Ready(Ok(response)) } } + +#[cfg(test)] +mod tests { + #[allow(unused_imports)] + use super::*; + + #[test] + fn test_routing() { + assert_match((Method::GET, "/"), (Method::GET, "/")); + refute_match((Method::GET, "/"), (Method::POST, "/")); + refute_match((Method::POST, "/"), (Method::GET, "/")); + + assert_match((Method::GET, "/foo"), (Method::GET, "/foo")); + assert_match((Method::GET, "/foo/"), (Method::GET, "/foo/")); + refute_match((Method::GET, "/foo"), (Method::GET, "/foo/")); + refute_match((Method::GET, "/foo/"), (Method::GET, "/foo")); + + assert_match((Method::GET, "/foo/bar"), (Method::GET, "/foo/bar")); + refute_match((Method::GET, "/foo/bar/"), (Method::GET, "/foo/bar")); + refute_match((Method::GET, "/foo/bar"), (Method::GET, "/foo/bar/")); + + assert_match((Method::GET, "/:value"), (Method::GET, "/foo")); + assert_match((Method::GET, "/users/:id"), (Method::GET, "/users/1")); + assert_match( + (Method::GET, "/users/:id/action"), + (Method::GET, "/users/42/action"), + ); + refute_match( + (Method::GET, "/users/:id/action"), + (Method::GET, "/users/42"), + ); + refute_match( + (Method::GET, "/users/:id"), + (Method::GET, "/users/42/action"), + ); + } + + fn assert_match(route_spec: (Method, &'static str), req_spec: (Method, &'static str)) { + let route = RouteSpec::new(route_spec.0.clone(), route_spec.1); + let req = Request::builder() + .method(req_spec.0.clone()) + .uri(req_spec.1) + .body(()) + .unwrap(); + + assert!( + route.matches(&req).is_some(), + "`{} {}` doesn't match `{} {}`", + req.method(), + req.uri().path(), + route.method, + std::str::from_utf8(&route.spec).unwrap(), + ); + } + + fn refute_match(route_spec: (Method, &'static str), req_spec: (Method, &'static str)) { + let route = RouteSpec::new(route_spec.0.clone(), route_spec.1); + let req = Request::builder() + .method(req_spec.0.clone()) + .uri(req_spec.1) + .body(()) + .unwrap(); + + assert!( + route.matches(&req).is_none(), + "`{} {}` shouldn't match `{} {}`", + req.method(), + req.uri().path(), + route.method, + std::str::from_utf8(&route.spec).unwrap(), + ); + } + + fn route(method: Method, uri: &'static str) -> RouteSpec { + RouteSpec::new(method, uri) + } + + fn req(method: Method, uri: &str) -> Request<()> { + Request::builder().uri(uri).method(method).body(()).unwrap() + } +}