From 45974d018d6d229270aa9a81170093e046b2a86c Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Wed, 23 Dec 2020 11:59:54 -0800 Subject: [PATCH] update to Tokio 1.0 (#489) This branch updates Tower to depend on Tokio v1.0. In particular, the following changes were necessary: * `tokio::sync::Semaphore` now has a `close` operation, so permit acquisition is fallible. Our uses of the semaphore are updated to handle this. Also, this allows removing the janky homemade implementation of closing semaphores by adding a big pile of permits! * `tokio::sync`'s channels are no longer `Stream`s. This necessitated a few changes: - Replacing a few explicit `poll_next` calls with `poll_recv` - Updating some tests that used `mpsc::Receiver` as a `Stream` to add a wrapper type that makes it a `Stream` - Updating `CallAll`'s examples (I changed it to just use a `futures::channel` MPSC) * `tokio::time::Sleep` is no longer `Unpin`. Therefore, the rate-limit `Service` needs to `Box::pin` it. To avoid the overhead of allocating/deallocating `Box`es every time the rate limit is exhausted, I moved the `Sleep` out of the `State` enum and onto the `Service` struct, and changed the code to `reset` it every time the service is rate-limited. This way, we only allocate the box once when the service is created. There should be no actual changes in functionality. Signed-off-by: Eliza Weisman --- tower/Cargo.toml | 14 ++++++++------ tower/src/balance/pool/mod.rs | 2 +- tower/src/buffer/service.rs | 7 +++---- tower/src/buffer/worker.rs | 4 ++-- tower/src/limit/concurrency/service.rs | 5 ++++- tower/src/limit/rate/service.rs | 22 +++++++++++++++------- tower/src/semaphore.rs | 25 ++++++------------------- tower/src/util/call_all/ordered.rs | 16 +++++++++------- tower/tests/balance/main.rs | 2 +- tower/tests/support.rs | 25 +++++++++++++++++++++++++ tower/tests/util/call_all.rs | 7 ++++--- 11 files changed, 78 insertions(+), 51 deletions(-) diff --git a/tower/Cargo.toml b/tower/Cargo.toml index 52b3a661..c11b1f62 100644 --- a/tower/Cargo.toml +++ b/tower/Cargo.toml @@ -26,8 +26,8 @@ edition = "2018" [features] default = ["log"] log = ["tracing/log"] -balance = ["discover", "load", "ready-cache", "make", "rand", "slab", "tokio/stream"] -buffer = ["tokio/sync", "tokio/rt", "tokio/stream"] +balance = ["discover", "load", "ready-cache", "make", "rand", "slab", "tokio-stream"] +buffer = ["tokio/sync", "tokio/rt", "tokio-stream"] discover = [] filter = [] hedge = ["util", "filter", "futures-util", "hdrhistogram", "tokio/time"] @@ -55,14 +55,16 @@ hdrhistogram = { version = "6.0", optional = true } indexmap = { version = "1.0.2", optional = true } rand = { version = "0.7", features = ["small_rng"], optional = true } slab = { version = "0.4", optional = true } -tokio = { version = "0.3.2", optional = true, features = ["sync"] } +tokio = { version = "1", optional = true, features = ["sync"] } +tokio-stream = { version = "0.1.0", optional = true } [dev-dependencies] -futures-util = { version = "0.3", default-features = false, features = ["alloc", "async-await"] } +futures = "0.3" hdrhistogram = "6.0" quickcheck = { version = "0.9", default-features = false } -tokio = { version = "0.3.2", features = ["macros", "stream", "sync", "test-util", "rt-multi-thread"] } -tokio-test = "0.3" +tokio = { version = "1", features = ["macros", "sync", "test-util", "rt-multi-thread"] } +tokio-stream = "0.1" +tokio-test = "0.4" tower-test = { version = "0.4", path = "../tower-test" } tracing-subscriber = "0.2.14" # env_logger = { version = "0.5.3", default-features = false } diff --git a/tower/src/balance/pool/mod.rs b/tower/src/balance/pool/mod.rs index 705ae916..799aaf84 100644 --- a/tower/src/balance/pool/mod.rs +++ b/tower/src/balance/pool/mod.rs @@ -90,7 +90,7 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.project(); - while let Poll::Ready(Some(sid)) = this.died_rx.as_mut().poll_next(cx) { + while let Poll::Ready(Some(sid)) = this.died_rx.as_mut().poll_recv(cx) { this.services.remove(sid); tracing::trace!( pool.services = this.services.len(), diff --git a/tower/src/buffer/service.rs b/tower/src/buffer/service.rs index e3a77fb3..ff414bfb 100644 --- a/tower/src/buffer/service.rs +++ b/tower/src/buffer/service.rs @@ -5,7 +5,6 @@ use super::{ }; use crate::semaphore::Semaphore; -use futures_core::ready; use std::task::{Context, Poll}; use tokio::sync::{mpsc, oneshot}; use tower_service::Service; @@ -116,9 +115,9 @@ where // Then, poll to acquire a semaphore permit. If we acquire a permit, // then there's enough buffer capacity to send a new request. Otherwise, // we need to wait for capacity. - ready!(self.semaphore.poll_acquire(cx)); - - Poll::Ready(Ok(())) + self.semaphore + .poll_acquire(cx) + .map_err(|_| self.get_worker_error()) } fn call(&mut self, request: Request) -> Self::Future { diff --git a/tower/src/buffer/worker.rs b/tower/src/buffer/worker.rs index 1f70c8e2..c8d8d6a9 100644 --- a/tower/src/buffer/worker.rs +++ b/tower/src/buffer/worker.rs @@ -10,7 +10,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tokio::{stream::Stream, sync::mpsc}; +use tokio::sync::mpsc; use tower_service::Service; /// Task that handles processing the buffer. This type should not be used @@ -96,7 +96,7 @@ where } // Get the next request - while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_next(cx)) { + while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) { if !msg.tx.is_closed() { tracing::trace!("processing new request"); return Poll::Ready(Some((msg, true))); diff --git a/tower/src/limit/concurrency/service.rs b/tower/src/limit/concurrency/service.rs index 790b46a7..fc3377e7 100644 --- a/tower/src/limit/concurrency/service.rs +++ b/tower/src/limit/concurrency/service.rs @@ -48,7 +48,10 @@ where fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { // First, poll the semaphore... - ready!(self.semaphore.poll_acquire(cx)); + ready!(self.semaphore.poll_acquire(cx)).expect( + "ConcurrencyLimit semaphore is never closed, so `poll_acquire` \ + should never fail", + ); // ...and if it's ready, poll the inner service. self.inner.poll_ready(cx) } diff --git a/tower/src/limit/rate/service.rs b/tower/src/limit/rate/service.rs index e7332ea8..d641c341 100644 --- a/tower/src/limit/rate/service.rs +++ b/tower/src/limit/rate/service.rs @@ -15,20 +15,22 @@ pub struct RateLimit { inner: T, rate: Rate, state: State, + sleep: Pin>, } #[derive(Debug)] enum State { // The service has hit its limit - Limited(Sleep), + Limited, Ready { until: Instant, rem: u64 }, } impl RateLimit { /// Create a new rate limiter pub fn new(inner: T, rate: Rate) -> Self { + let until = Instant::now(); let state = State::Ready { - until: Instant::now(), + until, rem: rate.num(), }; @@ -36,6 +38,10 @@ impl RateLimit { inner, rate, state: state, + // The sleep won't actually be used with this duration, but + // we create it eagerly so that we can reset it in place rather than + // `Box::pin`ning a new `Sleep` every time we need one. + sleep: Box::pin(tokio::time::sleep_until(until)), } } @@ -66,8 +72,8 @@ where fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { match self.state { State::Ready { .. } => return Poll::Ready(ready!(self.inner.poll_ready(cx))), - State::Limited(ref mut sleep) => { - if let Poll::Pending = Pin::new(sleep).poll(cx) { + State::Limited => { + if let Poll::Pending = Pin::new(&mut self.sleep).poll(cx) { tracing::trace!("rate limit exceeded; sleeping."); return Poll::Pending; } @@ -98,14 +104,16 @@ where self.state = State::Ready { until, rem }; } else { // The service is disabled until further notice - let sleep = tokio::time::sleep_until(until); - self.state = State::Limited(sleep); + // Reset the sleep future in place, so that we don't have to + // deallocate the existing box and allocate a new one. + self.sleep.as_mut().reset(until); + self.state = State::Limited; } // Call the inner future self.inner.call(request) } - State::Limited(..) => panic!("service not ready; poll_ready must be called first"), + State::Limited => panic!("service not ready; poll_ready must be called first"), } } } diff --git a/tower/src/semaphore.rs b/tower/src/semaphore.rs index ea1c005c..fecc0aba 100644 --- a/tower/src/semaphore.rs +++ b/tower/src/semaphore.rs @@ -1,4 +1,4 @@ -pub(crate) use self::sync::OwnedSemaphorePermit as Permit; +pub(crate) use self::sync::{AcquireError, OwnedSemaphorePermit as Permit}; use futures_core::ready; use std::{ fmt, @@ -19,11 +19,10 @@ pub(crate) struct Semaphore { #[derive(Debug)] pub(crate) struct Close { semaphore: Weak, - permits: usize, } enum State { - Waiting(Pin + Send + 'static>>), + Waiting(Pin> + Send + 'static>>), Ready(Permit), Empty, } @@ -33,7 +32,6 @@ impl Semaphore { let semaphore = Arc::new(sync::Semaphore::new(permits)); let close = Close { semaphore: Arc::downgrade(&semaphore), - permits, }; let semaphore = Self { semaphore, @@ -49,12 +47,12 @@ impl Semaphore { } } - pub(crate) fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<()> { + pub(crate) fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll> { loop { self.state = match self.state { - State::Ready(_) => return Poll::Ready(()), + State::Ready(_) => return Poll::Ready(Ok(())), State::Waiting(ref mut fut) => { - let permit = ready!(Pin::new(fut).poll(cx)); + let permit = ready!(Pin::new(fut).poll(cx))?; State::Ready(permit) } State::Empty => State::Waiting(Box::pin(self.semaphore.clone().acquire_owned())), @@ -95,19 +93,8 @@ impl fmt::Debug for State { impl Close { /// Close the semaphore, waking any remaining tasks currently awaiting a permit. pub(crate) fn close(self) { - // The maximum number of permits that a `tokio::sync::Semaphore` - // can hold is usize::MAX >> 3. If we attempt to add more than that - // number of permits, the semaphore will panic. - // XXX(eliza): another shift is kinda janky but if we add (usize::MAX - // > 3 - initial permits) the semaphore impl panics (I think due to a - // bug in tokio?). - // TODO(eliza): Tokio should _really_ just expose `Semaphore::close` - // publicly so we don't have to do this nonsense... - const MAX: usize = std::usize::MAX >> 4; if let Some(semaphore) = self.semaphore.upgrade() { - // If we added `MAX - available_permits`, any tasks that are - // currently holding permits could drop them, overflowing the max. - semaphore.add_permits(MAX - self.permits); + semaphore.close() } } } diff --git a/tower/src/util/call_all/ordered.rs b/tower/src/util/call_all/ordered.rs index c3ae7e93..235cb90b 100644 --- a/tower/src/util/call_all/ordered.rs +++ b/tower/src/util/call_all/ordered.rs @@ -20,11 +20,11 @@ use tower_service::Service; /// # use std::error::Error; /// # use std::rc::Rc; /// # -/// use futures_util::future::{ready, Ready}; -/// use futures_util::StreamExt; +/// use futures::future::{ready, Ready}; +/// use futures::StreamExt; +/// use futures::channel::mpsc; /// use tower_service::Service; /// use tower::util::ServiceExt; -/// use tokio::prelude::*; /// /// // First, we need to have a Service to process our requests. /// #[derive(Debug, Eq, PartialEq)] @@ -46,15 +46,17 @@ use tower_service::Service; /// #[tokio::main] /// async fn main() { /// // Next, we need a Stream of requests. -/// let (mut reqs, rx) = tokio::sync::mpsc::unbounded_channel(); +// TODO(eliza): when `tokio-util` has a nice way to convert MPSCs to streams, +// tokio::sync::mpsc again? +/// let (mut reqs, rx) = mpsc::unbounded(); /// // Note that we have to help Rust out here by telling it what error type to use. /// // Specifically, it has to be From + From. /// let mut rsps = FirstLetter.call_all(rx); /// /// // Now, let's send a few requests and then check that we get the corresponding responses. -/// reqs.send("one"); -/// reqs.send("two"); -/// reqs.send("three"); +/// reqs.unbounded_send("one").unwrap(); +/// reqs.unbounded_send("two").unwrap(); +/// reqs.unbounded_send("three").unwrap(); /// drop(reqs); /// /// // We then loop over the response Strem that we get back from call_all. diff --git a/tower/tests/balance/main.rs b/tower/tests/balance/main.rs index bc323da5..c526418d 100644 --- a/tower/tests/balance/main.rs +++ b/tower/tests/balance/main.rs @@ -37,7 +37,7 @@ fn stress() { let _t = support::trace_init(); let mut task = task::spawn(()); let (tx, rx) = tokio::sync::mpsc::unbounded_channel::>(); - let mut cache = Balance::<_, Req>::new(rx); + let mut cache = Balance::<_, Req>::new(support::IntoStream(rx)); let mut nready = 0; let mut services = slab::Slab::<(mock::Handle, bool)>::new(); diff --git a/tower/tests/support.rs b/tower/tests/support.rs index ba67c0cb..6a036ec9 100644 --- a/tower/tests/support.rs +++ b/tower/tests/support.rs @@ -1,5 +1,10 @@ #![allow(dead_code)] +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::sync::mpsc; +use tokio_stream::Stream; + pub(crate) fn trace_init() -> tracing::subscriber::DefaultGuard { let subscriber = tracing_subscriber::fmt() .with_test_writer() @@ -8,3 +13,23 @@ pub(crate) fn trace_init() -> tracing::subscriber::DefaultGuard { .finish(); tracing::subscriber::set_default(subscriber) } + +#[pin_project::pin_project] +#[derive(Clone, Debug)] +pub struct IntoStream(#[pin] pub S); + +impl Stream for IntoStream> { + type Item = I; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().0.poll_recv(cx) + } +} + +impl Stream for IntoStream> { + type Item = I; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().0.poll_recv(cx) + } +} diff --git a/tower/tests/util/call_all.rs b/tower/tests/util/call_all.rs index a7633814..373b3aea 100644 --- a/tower/tests/util/call_all.rs +++ b/tower/tests/util/call_all.rs @@ -1,3 +1,4 @@ +use super::support; use futures_core::Stream; use futures_util::{ future::{ready, Ready}, @@ -39,7 +40,7 @@ impl Service<&'static str> for Srv { #[test] fn ordered() { - let _t = super::support::trace_init(); + let _t = support::trace_init(); let mut mock = task::spawn(()); @@ -50,7 +51,7 @@ fn ordered() { admit: admit.clone(), }; let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - let ca = srv.call_all(rx); + let ca = srv.call_all(support::IntoStream(rx)); pin_mut!(ca); assert_pending!(mock.enter(|cx, _| ca.as_mut().poll_next(cx))); @@ -112,7 +113,7 @@ fn ordered() { #[tokio::test(flavor = "current_thread")] async fn unordered() { - let _t = super::support::trace_init(); + let _t = support::trace_init(); let (mock, handle) = mock::pair::<_, &'static str>(); pin_mut!(handle);