update to Tokio 1.0 (#489)

This branch updates Tower to depend on Tokio v1.0. In particular, the
following changes were necessary:

* `tokio::sync::Semaphore` now has a `close` operation, so permit
  acquisition is fallible. Our uses of the semaphore are updated to
  handle this. Also, this allows removing the janky homemade
  implementation of closing semaphores by adding a big pile of
  permits!

* `tokio::sync`'s channels are no longer `Stream`s. This necessitated a
  few changes:
  - Replacing a few explicit `poll_next` calls with `poll_recv`
  - Updating some tests that used `mpsc::Receiver` as a `Stream` to add
    a wrapper type that makes it a `Stream`
  - Updating `CallAll`'s examples (I changed it to just use a
    `futures::channel` MPSC)

* `tokio::time::Sleep` is no longer `Unpin`. Therefore, the rate-limit
  `Service` needs to `Box::pin` it. To avoid the overhead of
  allocating/deallocating `Box`es every time the rate limit is
  exhausted, I moved the `Sleep` out of the `State` enum and onto the
  `Service` struct, and changed the code to `reset` it every time the
  service is rate-limited. This way, we only allocate the box once when
  the service is created.

There should be no actual changes in functionality.

Signed-off-by: Eliza Weisman <eliza@buoyant.io>
This commit is contained in:
Eliza Weisman 2020-12-23 11:59:54 -08:00 committed by GitHub
parent 124816a40e
commit 45974d018d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 78 additions and 51 deletions

View File

@ -26,8 +26,8 @@ edition = "2018"
[features]
default = ["log"]
log = ["tracing/log"]
balance = ["discover", "load", "ready-cache", "make", "rand", "slab", "tokio/stream"]
buffer = ["tokio/sync", "tokio/rt", "tokio/stream"]
balance = ["discover", "load", "ready-cache", "make", "rand", "slab", "tokio-stream"]
buffer = ["tokio/sync", "tokio/rt", "tokio-stream"]
discover = []
filter = []
hedge = ["util", "filter", "futures-util", "hdrhistogram", "tokio/time"]
@ -55,14 +55,16 @@ hdrhistogram = { version = "6.0", optional = true }
indexmap = { version = "1.0.2", optional = true }
rand = { version = "0.7", features = ["small_rng"], optional = true }
slab = { version = "0.4", optional = true }
tokio = { version = "0.3.2", optional = true, features = ["sync"] }
tokio = { version = "1", optional = true, features = ["sync"] }
tokio-stream = { version = "0.1.0", optional = true }
[dev-dependencies]
futures-util = { version = "0.3", default-features = false, features = ["alloc", "async-await"] }
futures = "0.3"
hdrhistogram = "6.0"
quickcheck = { version = "0.9", default-features = false }
tokio = { version = "0.3.2", features = ["macros", "stream", "sync", "test-util", "rt-multi-thread"] }
tokio-test = "0.3"
tokio = { version = "1", features = ["macros", "sync", "test-util", "rt-multi-thread"] }
tokio-stream = "0.1"
tokio-test = "0.4"
tower-test = { version = "0.4", path = "../tower-test" }
tracing-subscriber = "0.2.14"
# env_logger = { version = "0.5.3", default-features = false }

View File

@ -90,7 +90,7 @@ where
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
while let Poll::Ready(Some(sid)) = this.died_rx.as_mut().poll_next(cx) {
while let Poll::Ready(Some(sid)) = this.died_rx.as_mut().poll_recv(cx) {
this.services.remove(sid);
tracing::trace!(
pool.services = this.services.len(),

View File

@ -5,7 +5,6 @@ use super::{
};
use crate::semaphore::Semaphore;
use futures_core::ready;
use std::task::{Context, Poll};
use tokio::sync::{mpsc, oneshot};
use tower_service::Service;
@ -116,9 +115,9 @@ where
// Then, poll to acquire a semaphore permit. If we acquire a permit,
// then there's enough buffer capacity to send a new request. Otherwise,
// we need to wait for capacity.
ready!(self.semaphore.poll_acquire(cx));
Poll::Ready(Ok(()))
self.semaphore
.poll_acquire(cx)
.map_err(|_| self.get_worker_error())
}
fn call(&mut self, request: Request) -> Self::Future {

View File

@ -10,7 +10,7 @@ use std::{
pin::Pin,
task::{Context, Poll},
};
use tokio::{stream::Stream, sync::mpsc};
use tokio::sync::mpsc;
use tower_service::Service;
/// Task that handles processing the buffer. This type should not be used
@ -96,7 +96,7 @@ where
}
// Get the next request
while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_next(cx)) {
while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) {
if !msg.tx.is_closed() {
tracing::trace!("processing new request");
return Poll::Ready(Some((msg, true)));

View File

@ -48,7 +48,10 @@ where
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// First, poll the semaphore...
ready!(self.semaphore.poll_acquire(cx));
ready!(self.semaphore.poll_acquire(cx)).expect(
"ConcurrencyLimit semaphore is never closed, so `poll_acquire` \
should never fail",
);
// ...and if it's ready, poll the inner service.
self.inner.poll_ready(cx)
}

View File

@ -15,20 +15,22 @@ pub struct RateLimit<T> {
inner: T,
rate: Rate,
state: State,
sleep: Pin<Box<Sleep>>,
}
#[derive(Debug)]
enum State {
// The service has hit its limit
Limited(Sleep),
Limited,
Ready { until: Instant, rem: u64 },
}
impl<T> RateLimit<T> {
/// Create a new rate limiter
pub fn new(inner: T, rate: Rate) -> Self {
let until = Instant::now();
let state = State::Ready {
until: Instant::now(),
until,
rem: rate.num(),
};
@ -36,6 +38,10 @@ impl<T> RateLimit<T> {
inner,
rate,
state: state,
// The sleep won't actually be used with this duration, but
// we create it eagerly so that we can reset it in place rather than
// `Box::pin`ning a new `Sleep` every time we need one.
sleep: Box::pin(tokio::time::sleep_until(until)),
}
}
@ -66,8 +72,8 @@ where
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.state {
State::Ready { .. } => return Poll::Ready(ready!(self.inner.poll_ready(cx))),
State::Limited(ref mut sleep) => {
if let Poll::Pending = Pin::new(sleep).poll(cx) {
State::Limited => {
if let Poll::Pending = Pin::new(&mut self.sleep).poll(cx) {
tracing::trace!("rate limit exceeded; sleeping.");
return Poll::Pending;
}
@ -98,14 +104,16 @@ where
self.state = State::Ready { until, rem };
} else {
// The service is disabled until further notice
let sleep = tokio::time::sleep_until(until);
self.state = State::Limited(sleep);
// Reset the sleep future in place, so that we don't have to
// deallocate the existing box and allocate a new one.
self.sleep.as_mut().reset(until);
self.state = State::Limited;
}
// Call the inner future
self.inner.call(request)
}
State::Limited(..) => panic!("service not ready; poll_ready must be called first"),
State::Limited => panic!("service not ready; poll_ready must be called first"),
}
}
}

View File

@ -1,4 +1,4 @@
pub(crate) use self::sync::OwnedSemaphorePermit as Permit;
pub(crate) use self::sync::{AcquireError, OwnedSemaphorePermit as Permit};
use futures_core::ready;
use std::{
fmt,
@ -19,11 +19,10 @@ pub(crate) struct Semaphore {
#[derive(Debug)]
pub(crate) struct Close {
semaphore: Weak<sync::Semaphore>,
permits: usize,
}
enum State {
Waiting(Pin<Box<dyn Future<Output = Permit> + Send + 'static>>),
Waiting(Pin<Box<dyn Future<Output = Result<Permit, AcquireError>> + Send + 'static>>),
Ready(Permit),
Empty,
}
@ -33,7 +32,6 @@ impl Semaphore {
let semaphore = Arc::new(sync::Semaphore::new(permits));
let close = Close {
semaphore: Arc::downgrade(&semaphore),
permits,
};
let semaphore = Self {
semaphore,
@ -49,12 +47,12 @@ impl Semaphore {
}
}
pub(crate) fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<()> {
pub(crate) fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), AcquireError>> {
loop {
self.state = match self.state {
State::Ready(_) => return Poll::Ready(()),
State::Ready(_) => return Poll::Ready(Ok(())),
State::Waiting(ref mut fut) => {
let permit = ready!(Pin::new(fut).poll(cx));
let permit = ready!(Pin::new(fut).poll(cx))?;
State::Ready(permit)
}
State::Empty => State::Waiting(Box::pin(self.semaphore.clone().acquire_owned())),
@ -95,19 +93,8 @@ impl fmt::Debug for State {
impl Close {
/// Close the semaphore, waking any remaining tasks currently awaiting a permit.
pub(crate) fn close(self) {
// The maximum number of permits that a `tokio::sync::Semaphore`
// can hold is usize::MAX >> 3. If we attempt to add more than that
// number of permits, the semaphore will panic.
// XXX(eliza): another shift is kinda janky but if we add (usize::MAX
// > 3 - initial permits) the semaphore impl panics (I think due to a
// bug in tokio?).
// TODO(eliza): Tokio should _really_ just expose `Semaphore::close`
// publicly so we don't have to do this nonsense...
const MAX: usize = std::usize::MAX >> 4;
if let Some(semaphore) = self.semaphore.upgrade() {
// If we added `MAX - available_permits`, any tasks that are
// currently holding permits could drop them, overflowing the max.
semaphore.add_permits(MAX - self.permits);
semaphore.close()
}
}
}

View File

@ -20,11 +20,11 @@ use tower_service::Service;
/// # use std::error::Error;
/// # use std::rc::Rc;
/// #
/// use futures_util::future::{ready, Ready};
/// use futures_util::StreamExt;
/// use futures::future::{ready, Ready};
/// use futures::StreamExt;
/// use futures::channel::mpsc;
/// use tower_service::Service;
/// use tower::util::ServiceExt;
/// use tokio::prelude::*;
///
/// // First, we need to have a Service to process our requests.
/// #[derive(Debug, Eq, PartialEq)]
@ -46,15 +46,17 @@ use tower_service::Service;
/// #[tokio::main]
/// async fn main() {
/// // Next, we need a Stream of requests.
/// let (mut reqs, rx) = tokio::sync::mpsc::unbounded_channel();
// TODO(eliza): when `tokio-util` has a nice way to convert MPSCs to streams,
// tokio::sync::mpsc again?
/// let (mut reqs, rx) = mpsc::unbounded();
/// // Note that we have to help Rust out here by telling it what error type to use.
/// // Specifically, it has to be From<Service::Error> + From<Stream::Error>.
/// let mut rsps = FirstLetter.call_all(rx);
///
/// // Now, let's send a few requests and then check that we get the corresponding responses.
/// reqs.send("one");
/// reqs.send("two");
/// reqs.send("three");
/// reqs.unbounded_send("one").unwrap();
/// reqs.unbounded_send("two").unwrap();
/// reqs.unbounded_send("three").unwrap();
/// drop(reqs);
///
/// // We then loop over the response Strem that we get back from call_all.

View File

@ -37,7 +37,7 @@ fn stress() {
let _t = support::trace_init();
let mut task = task::spawn(());
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<Result<_, &'static str>>();
let mut cache = Balance::<_, Req>::new(rx);
let mut cache = Balance::<_, Req>::new(support::IntoStream(rx));
let mut nready = 0;
let mut services = slab::Slab::<(mock::Handle<Req, Req>, bool)>::new();

View File

@ -1,5 +1,10 @@
#![allow(dead_code)]
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::mpsc;
use tokio_stream::Stream;
pub(crate) fn trace_init() -> tracing::subscriber::DefaultGuard {
let subscriber = tracing_subscriber::fmt()
.with_test_writer()
@ -8,3 +13,23 @@ pub(crate) fn trace_init() -> tracing::subscriber::DefaultGuard {
.finish();
tracing::subscriber::set_default(subscriber)
}
#[pin_project::pin_project]
#[derive(Clone, Debug)]
pub struct IntoStream<S>(#[pin] pub S);
impl<I> Stream for IntoStream<mpsc::Receiver<I>> {
type Item = I;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().0.poll_recv(cx)
}
}
impl<I> Stream for IntoStream<mpsc::UnboundedReceiver<I>> {
type Item = I;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().0.poll_recv(cx)
}
}

View File

@ -1,3 +1,4 @@
use super::support;
use futures_core::Stream;
use futures_util::{
future::{ready, Ready},
@ -39,7 +40,7 @@ impl Service<&'static str> for Srv {
#[test]
fn ordered() {
let _t = super::support::trace_init();
let _t = support::trace_init();
let mut mock = task::spawn(());
@ -50,7 +51,7 @@ fn ordered() {
admit: admit.clone(),
};
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let ca = srv.call_all(rx);
let ca = srv.call_all(support::IntoStream(rx));
pin_mut!(ca);
assert_pending!(mock.enter(|cx, _| ca.as_mut().poll_next(cx)));
@ -112,7 +113,7 @@ fn ordered() {
#[tokio::test(flavor = "current_thread")]
async fn unordered() {
let _t = super::support::trace_init();
let _t = support::trace_init();
let (mock, handle) = mock::pair::<_, &'static str>();
pin_mut!(handle);