diff --git a/tokio/src/stream/filter.rs b/tokio/src/stream/filter.rs index 88da15b7f..799630b23 100644 --- a/tokio/src/stream/filter.rs +++ b/tokio/src/stream/filter.rs @@ -26,11 +26,7 @@ where } } -impl Filter -where - St: Stream, - F: FnMut(&St::Item) -> bool, -{ +impl Filter { pub(super) fn new(stream: St, f: F) -> Self { Self { stream, f } } diff --git a/tokio/src/stream/filter_map.rs b/tokio/src/stream/filter_map.rs index 3aeb036f4..8dc05a546 100644 --- a/tokio/src/stream/filter_map.rs +++ b/tokio/src/stream/filter_map.rs @@ -26,11 +26,7 @@ where } } -impl FilterMap -where - St: Stream, - F: FnMut(St::Item) -> Option, -{ +impl FilterMap { pub(super) fn new(stream: St, f: F) -> Self { Self { stream, f } } diff --git a/tokio/src/stream/map.rs b/tokio/src/stream/map.rs index cc0c84e0b..dfac5a2c9 100644 --- a/tokio/src/stream/map.rs +++ b/tokio/src/stream/map.rs @@ -24,12 +24,8 @@ where } } -impl Map -where - St: Stream, - F: FnMut(St::Item) -> T, -{ - pub(super) fn new(stream: St, f: F) -> Map { +impl Map { + pub(super) fn new(stream: St, f: F) -> Self { Map { stream, f } } } diff --git a/tokio/src/stream/mod.rs b/tokio/src/stream/mod.rs index f29791b49..5d367f02f 100644 --- a/tokio/src/stream/mod.rs +++ b/tokio/src/stream/mod.rs @@ -25,6 +25,9 @@ use try_next::TryNext; mod take; use take::Take; +mod take_while; +use take_while::TakeWhile; + pub use futures_core::Stream; /// An extension trait for `Stream`s that provides a variety of convenient @@ -232,6 +235,36 @@ pub trait StreamExt: Stream { { Take::new(self, n) } + + /// Take elements from this stream while the provided predicate + /// resolves to `true`. + /// + /// This function, like `Iterator::take_while`, will take elements from the + /// stream until the predicate `f` resolves to `false`. Once one element + /// returns false it will always return that the stream is done. + /// + /// # Examples + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() { + /// use tokio::stream::{self, StreamExt}; + /// + /// let mut stream = stream::iter(1..=10).take_while(|x| *x <= 3); + /// + /// assert_eq!(Some(1), stream.next().await); + /// assert_eq!(Some(2), stream.next().await); + /// assert_eq!(Some(3), stream.next().await); + /// assert_eq!(None, stream.next().await); + /// # } + /// ``` + fn take_while(self, f: F) -> TakeWhile + where + F: FnMut(&Self::Item) -> bool, + Self: Sized, + { + TakeWhile::new(self, f) + } } -impl StreamExt for T where T: Stream {} +impl StreamExt for St where St: Stream {} diff --git a/tokio/src/stream/next.rs b/tokio/src/stream/next.rs index 5139fbea1..3909c0c23 100644 --- a/tokio/src/stream/next.rs +++ b/tokio/src/stream/next.rs @@ -13,7 +13,7 @@ pub struct Next<'a, St: ?Sized> { impl Unpin for Next<'_, St> {} -impl<'a, St: ?Sized + Stream + Unpin> Next<'a, St> { +impl<'a, St: ?Sized> Next<'a, St> { pub(super) fn new(stream: &'a mut St) -> Self { Next { stream } } diff --git a/tokio/src/stream/take.rs b/tokio/src/stream/take.rs index f0dbbb005..a92430b77 100644 --- a/tokio/src/stream/take.rs +++ b/tokio/src/stream/take.rs @@ -7,7 +7,7 @@ use core::task::{Context, Poll}; use pin_project_lite::pin_project; pin_project! { - /// Stream for the [`map`](super::StreamExt::map) method. + /// Stream for the [`take`](super::StreamExt::take) method. #[must_use = "streams do nothing unless polled"] pub struct Take { #[pin] @@ -27,7 +27,7 @@ where } } -impl Take { +impl Take { pub(super) fn new(stream: St, remaining: usize) -> Self { Self { stream, remaining } } diff --git a/tokio/src/stream/take_while.rs b/tokio/src/stream/take_while.rs new file mode 100644 index 000000000..cf1e16061 --- /dev/null +++ b/tokio/src/stream/take_while.rs @@ -0,0 +1,79 @@ +use crate::stream::Stream; + +use core::fmt; +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Stream for the [`take_while`](super::StreamExt::take_while) method. + #[must_use = "streams do nothing unless polled"] + pub struct TakeWhile { + #[pin] + stream: St, + predicate: F, + done: bool, + } +} + +impl fmt::Debug for TakeWhile +where + St: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TakeWhile") + .field("stream", &self.stream) + .field("done", &self.done) + .finish() + } +} + +impl TakeWhile { + pub(super) fn new(stream: St, predicate: F) -> Self { + Self { + stream, + predicate, + done: false, + } + } +} + +impl Stream for TakeWhile +where + St: Stream, + F: FnMut(&St::Item) -> bool, +{ + type Item = St::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if !*self.as_mut().project().done { + self.as_mut().project().stream.poll_next(cx).map(|ready| { + let ready = ready.and_then(|item| { + if !(self.as_mut().project().predicate)(&item) { + None + } else { + Some(item) + } + }); + + if ready.is_none() { + *self.as_mut().project().done = true; + } + + ready + }) + } else { + Poll::Ready(None) + } + } + + fn size_hint(&self) -> (usize, Option) { + if self.done { + return (0, Some(0)); + } + + let (_, upper) = self.stream.size_hint(); + + (0, upper) + } +} diff --git a/tokio/src/stream/try_next.rs b/tokio/src/stream/try_next.rs index ade5ecf09..59e0eb1a4 100644 --- a/tokio/src/stream/try_next.rs +++ b/tokio/src/stream/try_next.rs @@ -13,7 +13,7 @@ pub struct TryNext<'a, St: ?Sized> { impl Unpin for TryNext<'_, St> {} -impl<'a, St: ?Sized + Stream + Unpin> TryNext<'a, St> { +impl<'a, St: ?Sized> TryNext<'a, St> { pub(super) fn new(stream: &'a mut St) -> Self { Self { inner: Next::new(stream), diff --git a/tokio/src/time/tests/mock_clock.rs b/tokio/src/time/tests/mock_clock.rs index 7f8c7eca6..ac509e3fb 100644 --- a/tokio/src/time/tests/mock_clock.rs +++ b/tokio/src/time/tests/mock_clock.rs @@ -85,7 +85,7 @@ impl MockClock { let ctx = context::ThreadContext::clone_current(); let _e = ctx .with_clock(self.clock.clone()) - .with_time_handle(Some(handle.clone())) + .with_time_handle(Some(handle)) .enter(); let time = self.time.clone();