diff --git a/tokio/src/time/mod.rs b/tokio/src/time/mod.rs index cb023c542..20e8dc807 100644 --- a/tokio/src/time/mod.rs +++ b/tokio/src/time/mod.rs @@ -96,6 +96,11 @@ mod timeout; #[doc(inline)] pub use timeout::{timeout, timeout_at, Timeout, Elapsed}; +cfg_stream! { + mod throttle; + pub use throttle::{throttle, Throttle}; +} + mod wheel; #[cfg(test)] diff --git a/tokio/src/time/throttle.rs b/tokio/src/time/throttle.rs index f81f0fcf6..2daa30fcb 100644 --- a/tokio/src/time/throttle.rs +++ b/tokio/src/time/throttle.rs @@ -7,34 +7,62 @@ use std::marker::Unpin; use std::pin::Pin; use std::task::{self, Poll}; +use futures_core::Stream; +use pin_project_lite::pin_project; + /// Slow down a stream by enforcing a delay between items. -#[derive(Debug)] -#[must_use = "streams do nothing unless polled"] -pub struct Throttle { - /// `None` when duration is zero. - delay: Option, +/// They will be produced not more often than the specified interval. +/// +/// # Example +/// +/// Create a throttled stream. +/// ```rust,norun +/// use futures::stream::StreamExt; +/// use std::time::Duration; +/// use tokio::time::throttle; +/// +/// # async fn dox() { +/// let mut item_stream = throttle(Duration::from_secs(2), futures::stream::repeat("one")); +/// +/// loop { +/// // The string will be produced at most every 2 seconds +/// println!("{:?}", item_stream.next().await); +/// } +/// # } +/// ``` +pub fn throttle(duration: Duration, stream: T) -> Throttle +where + T: Stream, +{ + let delay = if duration == Duration::from_millis(0) { + None + } else { + Some(Delay::new_timeout(Instant::now() + duration, duration)) + }; - /// Set to true when `delay` has returned ready, but `stream` hasn't. - has_delayed: bool, - - /// The stream to throttle - stream: T, + Throttle { + delay, + duration, + has_delayed: true, + stream, + } } -impl Throttle { - /// Slow down a stream by enforcing a delay between items. - pub fn new(stream: T, duration: Duration) -> Self { - let delay = if duration == Duration::from_millis(0) { - None - } else { - Some(Delay::new_timeout(Instant::now() + duration, duration)) - }; +pin_project! { + /// Stream for the [`throttle`](throttle) function. + #[derive(Debug)] + #[must_use = "streams do nothing unless polled"] + pub struct Throttle { + // `None` when duration is zero. + delay: Option, + duration: Duration, - Self { - delay, - has_delayed: true, - stream, - } + // Set to true when `delay` has returned ready, but `stream` hasn't. + has_delayed: bool, + + // The stream to throttle + #[pin] + stream: T, } } @@ -68,29 +96,27 @@ impl Stream for Throttle { type Item = T::Item; fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - unsafe { - if !self.has_delayed && self.delay.is_some() { - ready!(self - .as_mut() - .map_unchecked_mut(|me| me.delay.as_mut().unwrap()) - .poll(cx)); - self.as_mut().get_unchecked_mut().has_delayed = true; - } - - let value = ready!(self - .as_mut() - .map_unchecked_mut(|me| &mut me.stream) - .poll_next(cx)); - - if value.is_some() { - if let Some(ref mut delay) = self.as_mut().get_unchecked_mut().delay { - delay.reset_timeout(); - } - - self.as_mut().get_unchecked_mut().has_delayed = false; - } - - Poll::Ready(value) + if !self.has_delayed && self.delay.is_some() { + ready!(Pin::new(self.as_mut() + .project().delay.as_mut().unwrap()) + .poll(cx)); + *self.as_mut().project().has_delayed = true; } + + let value = ready!(self + .as_mut() + .project().stream + .poll_next(cx)); + + if value.is_some() { + let dur = self.duration; + if let Some(ref mut delay) = self.as_mut().project().delay { + delay.reset(Instant::now() + dur); + } + + *self.as_mut().project().has_delayed = false; + } + + Poll::Ready(value) } } diff --git a/tokio/tests/time_throttle.rs b/tokio/tests/time_throttle.rs new file mode 100644 index 000000000..7102d1734 --- /dev/null +++ b/tokio/tests/time_throttle.rs @@ -0,0 +1,30 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use tokio::time::{self, throttle}; +use tokio_test::*; + +use std::time::Duration; + +#[tokio::test] +async fn usage() { + time::pause(); + + let mut stream = task::spawn(throttle( + Duration::from_millis(100), + futures::stream::repeat(()), + )); + + assert_ready!(stream.poll_next()); + assert_pending!(stream.poll_next()); + + time::advance(Duration::from_millis(90)).await; + + assert_pending!(stream.poll_next()); + + time::advance(Duration::from_millis(101)).await; + + assert!(stream.is_woken()); + + assert_ready!(stream.poll_next()); +}