From 93cdfe8c5f25b815afbdb502922a8ded87c5c584 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Tue, 17 Aug 2021 19:00:24 +0200 Subject: [PATCH] Work around for http2 hang with `Or` (#199) This is a nasty hack that works around https://github.com/hyperium/hyper/issues/2621. Fixes https://github.com/tokio-rs/axum/issues/191 --- examples/hello_world.rs | 2 +- src/routing/mod.rs | 39 +++++++++++++- src/routing/or.rs | 24 +++++++-- src/tests/or.rs | 115 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 175 insertions(+), 5 deletions(-) diff --git a/examples/hello_world.rs b/examples/hello_world.rs index bdb0a877..56db7946 100644 --- a/examples/hello_world.rs +++ b/examples/hello_world.rs @@ -16,7 +16,7 @@ async fn main() { tracing_subscriber::fmt::init(); // build our application with a route - let app = route("/", get(handler)); + let app = route("/foo", get(handler)).or(route("/bar", get(handler))); // run it let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); diff --git a/src/routing/mod.rs b/src/routing/mod.rs index 3d6d4e90..dfc1f54e 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -547,7 +547,11 @@ where fn call(&mut self, request: Request) -> Self::Future { let mut res = Response::new(crate::body::empty()); - res.extensions_mut().insert(FromEmptyRouter { request }); + + if request.extensions().get::().is_some() { + res.extensions_mut().insert(FromEmptyRouter { request }); + } + *res.status_mut() = self.status; EmptyRouterFuture { future: futures_util::future::ok(res), @@ -565,6 +569,39 @@ struct FromEmptyRouter { request: Request, } +/// We need to track whether we're inside an `Or` or not, and only if we then +/// should we save the request into the response extensions. +/// +/// This is to work around https://github.com/hyperium/hyper/issues/2621. +/// +/// Since ours can be nested we have to track the depth to know when we're +/// leaving the top most `Or`. +/// +/// Hopefully when https://github.com/hyperium/hyper/issues/2621 is resolved we +/// can remove this nasty hack. +#[derive(Debug)] +struct OrDepth(usize); + +impl OrDepth { + fn new() -> Self { + Self(1) + } + + fn increment(&mut self) { + self.0 += 1; + } + + fn decrement(&mut self) { + self.0 -= 1; + } +} + +impl PartialEq for &mut OrDepth { + fn eq(&self, other: &usize) -> bool { + self.0 == *other + } +} + #[derive(Debug, Clone)] pub(crate) struct PathPattern(Arc); diff --git a/src/routing/or.rs b/src/routing/or.rs index 0c884e5a..4d5ee97e 100644 --- a/src/routing/or.rs +++ b/src/routing/or.rs @@ -1,6 +1,6 @@ //! [`Or`] used to combine two services into one. -use super::{FromEmptyRouter, RoutingDsl}; +use super::{FromEmptyRouter, OrDepth, RoutingDsl}; use crate::body::BoxBody; use futures_util::ready; use http::{Request, Response}; @@ -46,7 +46,13 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, mut req: Request) -> Self::Future { + if let Some(count) = req.extensions_mut().get_mut::() { + count.increment(); + } else { + req.extensions_mut().insert(OrDepth::new()); + } + ResponseFuture { state: State::FirstFuture { f: self.first.clone().oneshot(req), @@ -100,7 +106,7 @@ where StateProj::FirstFuture { f } => { let mut response = ready!(f.poll(cx)?); - let req = if let Some(ext) = response + let mut req = if let Some(ext) = response .extensions_mut() .remove::>() { @@ -109,6 +115,18 @@ where return Poll::Ready(Ok(response)); }; + let mut leaving_outermost_or = false; + if let Some(depth) = req.extensions_mut().get_mut::() { + if depth == 1 { + leaving_outermost_or = true; + } else { + depth.decrement(); + } + } + if leaving_outermost_or { + req.extensions_mut().remove::(); + } + let second = this.second.take().expect("future polled after completion"); State::SecondFuture { diff --git a/src/tests/or.rs b/src/tests/or.rs index 3c432751..b8e3766f 100644 --- a/src/tests/or.rs +++ b/src/tests/or.rs @@ -41,6 +41,121 @@ async fn basic() { assert_eq!(res.status(), StatusCode::NOT_FOUND); } +#[tokio::test] +async fn multiple_ors_balanced_differently() { + let one = route("/one", get(|| async { "one" })); + let two = route("/two", get(|| async { "two" })); + let three = route("/three", get(|| async { "three" })); + let four = route("/four", get(|| async { "four" })); + + test( + "one", + one.clone() + .or(two.clone()) + .or(three.clone()) + .or(four.clone()), + ) + .await; + + test( + "two", + one.clone() + .or(two.clone()) + .or(three.clone().or(four.clone())), + ) + .await; + + test( + "three", + one.clone() + .or(two.clone().or(three.clone()).or(four.clone())), + ) + .await; + + test("four", one.or(two.or(three.or(four)))).await; + + async fn test(name: &str, app: S) + where + S: Service, Response = Response> + Clone + Send + 'static, + ResBody: http_body::Body + Send + 'static, + ResBody::Data: Send, + ResBody::Error: Into, + S::Future: Send, + S::Error: Into, + { + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + for n in ["one", "two", "three", "four"].iter() { + println!("running: {} / {}", name, n); + let res = client + .get(format!("http://{}/{}", addr, n)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await.unwrap(), *n); + } + } +} + +#[tokio::test] +async fn or_nested_inside_other_thing() { + let inner = route("/bar", get(|| async {})).or(route("/baz", get(|| async {}))); + let app = nest("/foo", inner); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/foo/bar", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let res = client + .get(format!("http://{}/foo/baz", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); +} + +#[tokio::test] +async fn or_with_route_following() { + let one = route("/one", get(|| async { "one" })); + let two = route("/two", get(|| async { "two" })); + let app = one.or(two).route("/three", get(|| async { "three" })); + + let addr = run_in_background(app).await; + + let client = reqwest::Client::new(); + + let res = client + .get(format!("http://{}/one", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let res = client + .get(format!("http://{}/two", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); + + let res = client + .get(format!("http://{}/three", addr)) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::OK); +} + #[tokio::test] async fn layer() { let one = route("/foo", get(|| async {}));