update to Tokio 1.0 (#489)

This branch updates Tower to depend on Tokio v1.0. In particular, the
following changes were necessary:

* `tokio::sync::Semaphore` now has a `close` operation, so permit
  acquisition is fallible. Our uses of the semaphore are updated to
  handle this. This also lets us remove the janky homemade
  implementation of closing the semaphore by adding a big pile of
  permits! (A minimal sketch of the new close behavior follows this
  list.)

* `tokio::sync`'s channels are no longer `Stream`s. This necessitated a
  few changes:
  - Replacing a few explicit `poll_next` calls with `poll_recv`
  - Updating some tests that used `mpsc::Receiver` as a `Stream` to add
    a wrapper type that makes it a `Stream` again (sketched after this
    list)
  - Updating `CallAll`'s examples (I changed them to just use a
    `futures::channel` MPSC)

* `tokio::time::Sleep` is no longer `Unpin`, so the rate-limit
  `Service` needs to `Box::pin` it. To avoid the overhead of
  allocating/deallocating a `Box` every time the rate limit is
  exhausted, I moved the `Sleep` out of the `State` enum and onto the
  `Service` struct, and changed the code to `reset` it in place every
  time the service is rate-limited. This way, we only allocate the box
  once, when the service is created (see the last sketch after this
  list).
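
For the first bullet, here is a minimal, self-contained sketch (not
Tower's internal wrapper; the setup is illustrative) of the Tokio 1.0
semaphore behavior we now rely on: closing the semaphore fails any
pending or future acquisitions with an `AcquireError`.

    use std::sync::Arc;
    use tokio::sync::Semaphore;

    #[tokio::main]
    async fn main() {
        let semaphore = Arc::new(Semaphore::new(1));

        // Permit acquisition is fallible in Tokio 1.0: it returns a Result...
        let permit = semaphore
            .clone()
            .acquire_owned()
            .await
            .expect("semaphore is still open");
        drop(permit);

        // ...because the semaphore can now be closed, which wakes all
        // waiters and makes further acquisitions fail.
        semaphore.close();
        assert!(semaphore.clone().acquire_owned().await.is_err());
    }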
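
The test-support wrapper for the second bullet is small; roughly (a
sketch with an illustrative name, delegating to the inherent
`poll_recv` that replaced the receiver's `Stream` impl):

    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tokio::sync::mpsc;
    use tokio_stream::Stream;

    // Illustrative name; the real wrapper lives in the test support module.
    pub struct ReceiverStream<T>(pub mpsc::UnboundedReceiver<T>);

    impl<T> Stream for ReceiverStream<T> {
        type Item = T;

        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
            // `UnboundedReceiver` is `Unpin`, so we can reach the inner
            // receiver and call its `poll_recv` directly.
            self.get_mut().0.poll_recv(cx)
        }
    }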
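
And the allocate-once/reset-in-place pattern from the third bullet
looks roughly like this (a sketch with made-up names, not the actual
`RateLimit` service):

    use std::pin::Pin;
    use std::time::Duration;
    use tokio::time::{sleep_until, Instant, Sleep};

    struct Limiter {
        sleep: Pin<Box<Sleep>>,
    }

    impl Limiter {
        fn new() -> Self {
            // Allocate and pin the box once, up front; the initial deadline
            // is never awaited as-is.
            Self {
                sleep: Box::pin(sleep_until(Instant::now())),
            }
        }

        fn limit_until(&mut self, deadline: Instant) {
            // `Sleep::reset` takes `Pin<&mut Sleep>`, so the boxed sleep can
            // be rearmed in place without a new allocation.
            self.sleep.as_mut().reset(deadline);
        }
    }

    #[tokio::main]
    async fn main() {
        let mut limiter = Limiter::new();
        limiter.limit_until(Instant::now() + Duration::from_millis(10));
        limiter.sleep.as_mut().await;
    }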

There should be no actual changes in functionality.

Signed-off-by: Eliza Weisman <eliza@buoyant.io>
11 changed files with 78 additions and 51 deletions


@@ -26,8 +26,8 @@ edition = "2018"
 [features]
 default = ["log"]
 log = ["tracing/log"]
-balance = ["discover", "load", "ready-cache", "make", "rand", "slab", "tokio/stream"]
-buffer = ["tokio/sync", "tokio/rt", "tokio/stream"]
+balance = ["discover", "load", "ready-cache", "make", "rand", "slab", "tokio-stream"]
+buffer = ["tokio/sync", "tokio/rt", "tokio-stream"]
 discover = []
 filter = []
 hedge = ["util", "filter", "futures-util", "hdrhistogram", "tokio/time"]
@@ -55,14 +55,16 @@ hdrhistogram = { version = "6.0", optional = true }
 indexmap = { version = "1.0.2", optional = true }
 rand = { version = "0.7", features = ["small_rng"], optional = true }
 slab = { version = "0.4", optional = true }
-tokio = { version = "0.3.2", optional = true, features = ["sync"] }
+tokio = { version = "1", optional = true, features = ["sync"] }
+tokio-stream = { version = "0.1.0", optional = true }

 [dev-dependencies]
-futures-util = { version = "0.3", default-features = false, features = ["alloc", "async-await"] }
+futures = "0.3"
 hdrhistogram = "6.0"
 quickcheck = { version = "0.9", default-features = false }
-tokio = { version = "0.3.2", features = ["macros", "stream", "sync", "test-util", "rt-multi-thread"] }
-tokio-test = "0.3"
+tokio = { version = "1", features = ["macros", "sync", "test-util", "rt-multi-thread"] }
+tokio-stream = "0.1"
+tokio-test = "0.4"
 tower-test = { version = "0.4", path = "../tower-test" }
 tracing-subscriber = "0.2.14"
 # env_logger = { version = "0.5.3", default-features = false }


@@ -90,7 +90,7 @@ where
     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
         let mut this = self.project();

-        while let Poll::Ready(Some(sid)) = this.died_rx.as_mut().poll_next(cx) {
+        while let Poll::Ready(Some(sid)) = this.died_rx.as_mut().poll_recv(cx) {
             this.services.remove(sid);
             tracing::trace!(
                 pool.services = this.services.len(),


@@ -5,7 +5,6 @@ use super::{
 };
 use crate::semaphore::Semaphore;
-use futures_core::ready;
 use std::task::{Context, Poll};
 use tokio::sync::{mpsc, oneshot};
 use tower_service::Service;

@@ -116,9 +115,9 @@ where
         // Then, poll to acquire a semaphore permit. If we acquire a permit,
         // then there's enough buffer capacity to send a new request. Otherwise,
         // we need to wait for capacity.
-        ready!(self.semaphore.poll_acquire(cx));
-
-        Poll::Ready(Ok(()))
+        self.semaphore
+            .poll_acquire(cx)
+            .map_err(|_| self.get_worker_error())
     }

     fn call(&mut self, request: Request) -> Self::Future {


@@ -10,7 +10,7 @@ use std::{
     pin::Pin,
     task::{Context, Poll},
 };
-use tokio::{stream::Stream, sync::mpsc};
+use tokio::sync::mpsc;
 use tower_service::Service;

 /// Task that handles processing the buffer. This type should not be used
@@ -96,7 +96,7 @@ where
         }

         // Get the next request
-        while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_next(cx)) {
+        while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) {
             if !msg.tx.is_closed() {
                 tracing::trace!("processing new request");
                 return Poll::Ready(Some((msg, true)));


@@ -48,7 +48,10 @@ where
     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
         // First, poll the semaphore...
-        ready!(self.semaphore.poll_acquire(cx));
+        ready!(self.semaphore.poll_acquire(cx)).expect(
+            "ConcurrencyLimit semaphore is never closed, so `poll_acquire` \
+             should never fail",
+        );
         // ...and if it's ready, poll the inner service.
         self.inner.poll_ready(cx)
     }



@@ -15,20 +15,22 @@ pub struct RateLimit<T> {
     inner: T,
     rate: Rate,
     state: State,
+    sleep: Pin<Box<Sleep>>,
 }

 #[derive(Debug)]
 enum State {
     // The service has hit its limit
-    Limited(Sleep),
+    Limited,
     Ready { until: Instant, rem: u64 },
 }

 impl<T> RateLimit<T> {
     /// Create a new rate limiter
     pub fn new(inner: T, rate: Rate) -> Self {
+        let until = Instant::now();
         let state = State::Ready {
-            until: Instant::now(),
+            until,
             rem: rate.num(),
         };

@@ -36,6 +38,10 @@ impl<T> RateLimit<T> {
             inner,
             rate,
             state: state,
+            // The sleep won't actually be used with this duration, but
+            // we create it eagerly so that we can reset it in place rather than
+            // `Box::pin`ning a new `Sleep` every time we need one.
+            sleep: Box::pin(tokio::time::sleep_until(until)),
         }
     }

@@ -66,8 +72,8 @@ where
     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
         match self.state {
             State::Ready { .. } => return Poll::Ready(ready!(self.inner.poll_ready(cx))),
-            State::Limited(ref mut sleep) => {
-                if let Poll::Pending = Pin::new(sleep).poll(cx) {
+            State::Limited => {
+                if let Poll::Pending = Pin::new(&mut self.sleep).poll(cx) {
                     tracing::trace!("rate limit exceeded; sleeping.");
                     return Poll::Pending;
                 }
@@ -98,14 +104,16 @@ where
                     self.state = State::Ready { until, rem };
                 } else {
                     // The service is disabled until further notice
-                    let sleep = tokio::time::sleep_until(until);
-                    self.state = State::Limited(sleep);
+                    // Reset the sleep future in place, so that we don't have to
+                    // deallocate the existing box and allocate a new one.
+                    self.sleep.as_mut().reset(until);
+                    self.state = State::Limited;
                 }

                 // Call the inner future
                 self.inner.call(request)
             }
-            State::Limited(..) => panic!("service not ready; poll_ready must be called first"),
+            State::Limited => panic!("service not ready; poll_ready must be called first"),
         }
     }
 }


@@ -1,4 +1,4 @@
-pub(crate) use self::sync::OwnedSemaphorePermit as Permit;
+pub(crate) use self::sync::{AcquireError, OwnedSemaphorePermit as Permit};
 use futures_core::ready;
 use std::{
     fmt,
@@ -19,11 +19,10 @@ pub(crate) struct Semaphore {
 #[derive(Debug)]
 pub(crate) struct Close {
     semaphore: Weak<sync::Semaphore>,
-    permits: usize,
 }

 enum State {
-    Waiting(Pin<Box<dyn Future<Output = Permit> + Send + 'static>>),
+    Waiting(Pin<Box<dyn Future<Output = Result<Permit, AcquireError>> + Send + 'static>>),
     Ready(Permit),
     Empty,
 }
@@ -33,7 +32,6 @@ impl Semaphore {
         let semaphore = Arc::new(sync::Semaphore::new(permits));
         let close = Close {
             semaphore: Arc::downgrade(&semaphore),
-            permits,
         };
         let semaphore = Self {
             semaphore,
@@ -49,12 +47,12 @@ impl Semaphore {
         }
     }

-    pub(crate) fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<()> {
+    pub(crate) fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), AcquireError>> {
         loop {
             self.state = match self.state {
-                State::Ready(_) => return Poll::Ready(()),
+                State::Ready(_) => return Poll::Ready(Ok(())),
                 State::Waiting(ref mut fut) => {
-                    let permit = ready!(Pin::new(fut).poll(cx));
+                    let permit = ready!(Pin::new(fut).poll(cx))?;
                     State::Ready(permit)
                 }
                 State::Empty => State::Waiting(Box::pin(self.semaphore.clone().acquire_owned())),
@@ -95,19 +93,8 @@ impl fmt::Debug for State {
 impl Close {
     /// Close the semaphore, waking any remaining tasks currently awaiting a permit.
     pub(crate) fn close(self) {
-        // The maximum number of permits that a `tokio::sync::Semaphore`
-        // can hold is usize::MAX >> 3. If we attempt to add more than that
-        // number of permits, the semaphore will panic.
-        // XXX(eliza): another shift is kinda janky but if we add (usize::MAX
-        // >> 3 - initial permits) the semaphore impl panics (I think due to a
-        // bug in tokio?).
-        // TODO(eliza): Tokio should _really_ just expose `Semaphore::close`
-        // publicly so we don't have to do this nonsense...
-        const MAX: usize = std::usize::MAX >> 4;
         if let Some(semaphore) = self.semaphore.upgrade() {
-            // If we added `MAX - available_permits`, any tasks that are
-            // currently holding permits could drop them, overflowing the max.
-            semaphore.add_permits(MAX - self.permits);
+            semaphore.close()
         }
     }
 }


@@ -20,11 +20,11 @@ use tower_service::Service;
 /// # use std::error::Error;
 /// # use std::rc::Rc;
 /// #
-/// use futures_util::future::{ready, Ready};
-/// use futures_util::StreamExt;
+/// use futures::future::{ready, Ready};
+/// use futures::StreamExt;
+/// use futures::channel::mpsc;
 /// use tower_service::Service;
 /// use tower::util::ServiceExt;
-/// use tokio::prelude::*;
 ///
 /// // First, we need to have a Service to process our requests.
 /// #[derive(Debug, Eq, PartialEq)]
@@ -46,15 +46,17 @@ use tower_service::Service;
 /// #[tokio::main]
 /// async fn main() {
 /// // Next, we need a Stream of requests.
-/// let (mut reqs, rx) = tokio::sync::mpsc::unbounded_channel();
+// TODO(eliza): when `tokio-util` has a nice way to convert MPSCs to streams,
+// tokio::sync::mpsc again?
+/// let (mut reqs, rx) = mpsc::unbounded();
 /// // Note that we have to help Rust out here by telling it what error type to use.
 /// // Specifically, it has to be From<Service::Error> + From<Stream::Error>.
 /// let mut rsps = FirstLetter.call_all(rx);
 ///
 /// // Now, let's send a few requests and then check that we get the corresponding responses.
-/// reqs.send("one");
-/// reqs.send("two");
-/// reqs.send("three");
+/// reqs.unbounded_send("one").unwrap();
+/// reqs.unbounded_send("two").unwrap();
+/// reqs.unbounded_send("three").unwrap();
 /// drop(reqs);
 ///
 /// // We then loop over the response Strem that we get back from call_all.


@@ -37,7 +37,7 @@ fn stress() {
     let _t = support::trace_init();

     let mut task = task::spawn(());
     let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<Result<_, &'static str>>();
-    let mut cache = Balance::<_, Req>::new(rx);
+    let mut cache = Balance::<_, Req>::new(support::IntoStream(rx));
     let mut nready = 0;
     let mut services = slab::Slab::<(mock::Handle<Req, Req>, bool)>::new();


@@ -1,5 +1,10 @@
 #![allow(dead_code)]
+
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::sync::mpsc;
+use tokio_stream::Stream;

 pub(crate) fn trace_init() -> tracing::subscriber::DefaultGuard {
     let subscriber = tracing_subscriber::fmt()
         .with_test_writer()
@@ -8,3 +13,23 @@ pub(crate) fn trace_init() -> tracing::subscriber::DefaultGuard {
         .finish();
     tracing::subscriber::set_default(subscriber)
 }
+
+#[pin_project::pin_project]
+#[derive(Clone, Debug)]
+pub struct IntoStream<S>(#[pin] pub S);
+
+impl<I> Stream for IntoStream<mpsc::Receiver<I>> {
+    type Item = I;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        self.project().0.poll_recv(cx)
+    }
+}
+
+impl<I> Stream for IntoStream<mpsc::UnboundedReceiver<I>> {
+    type Item = I;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        self.project().0.poll_recv(cx)
+    }
+}


@@ -1,3 +1,4 @@
+use super::support;
 use futures_core::Stream;
 use futures_util::{
     future::{ready, Ready},
@@ -39,7 +40,7 @@ impl Service<&'static str> for Srv {
 }

 #[test]
 fn ordered() {
-    let _t = super::support::trace_init();
+    let _t = support::trace_init();

     let mut mock = task::spawn(());
@@ -50,7 +51,7 @@ fn ordered() {
         admit: admit.clone(),
     };
     let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
-    let ca = srv.call_all(rx);
+    let ca = srv.call_all(support::IntoStream(rx));
     pin_mut!(ca);

     assert_pending!(mock.enter(|cx, _| ca.as_mut().poll_next(cx)));
@@ -112,7 +113,7 @@ fn ordered() {

 #[tokio::test(flavor = "current_thread")]
 async fn unordered() {
-    let _t = super::support::trace_init();
+    let _t = support::trace_init();

     let (mock, handle) = mock::pair::<_, &'static str>();
     pin_mut!(handle);