From 7c3f1cb4a3d6076cb5e1aedf2311f62c8a7a2fd7 Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Sat, 11 Jan 2020 16:33:52 -0800 Subject: [PATCH] stream: add `StreamExt::chain` (#2093) Asynchronous equivalent to `Iterator::chain`. --- tokio/src/stream/chain.rs | 57 +++++++++++++++++++++++++++++ tokio/src/stream/mod.rs | 38 ++++++++++++++++++++ tokio/tests/stream_chain.rs | 71 +++++++++++++++++++++++++++++++++++++ 3 files changed, 166 insertions(+) create mode 100644 tokio/src/stream/chain.rs create mode 100644 tokio/tests/stream_chain.rs diff --git a/tokio/src/stream/chain.rs b/tokio/src/stream/chain.rs new file mode 100644 index 000000000..5f0324a4b --- /dev/null +++ b/tokio/src/stream/chain.rs @@ -0,0 +1,57 @@ +use crate::stream::{Fuse, Stream}; + +use core::pin::Pin; +use core::task::{Context, Poll}; +use pin_project_lite::pin_project; + +pin_project! { + /// Stream returned by the [`chain`](super::StreamExt::chain) method. + pub struct Chain { + #[pin] + a: Fuse, + #[pin] + b: U, + } +} + +impl Chain { + pub(super) fn new(a: T, b: U) -> Chain + where + T: Stream, + U: Stream, + { + Chain { a: Fuse::new(a), b } + } +} + +impl Stream for Chain +where + T: Stream, + U: Stream, +{ + type Item = T::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use Poll::Ready; + + let me = self.project(); + + if let Some(v) = ready!(me.a.poll_next(cx)) { + return Ready(Some(v)); + } + + me.b.poll_next(cx) + } + + fn size_hint(&self) -> (usize, Option) { + let (a_lower, a_upper) = self.a.size_hint(); + let (b_lower, b_upper) = self.b.size_hint(); + + let upper = match (a_upper, b_upper) { + (Some(a_upper), Some(b_upper)) => Some(a_upper + b_upper), + _ => None, + }; + + (a_lower + b_lower, upper) + } +} diff --git a/tokio/src/stream/mod.rs b/tokio/src/stream/mod.rs index fada4442b..9be1b102c 100644 --- a/tokio/src/stream/mod.rs +++ b/tokio/src/stream/mod.rs @@ -10,6 +10,9 @@ use all::AllFuture; mod any; use any::AnyFuture; +mod chain; +use chain::Chain; + mod empty; pub use empty::{empty, Empty}; @@ -539,6 +542,41 @@ pub trait StreamExt: Stream { { AnyFuture::new(self, f) } + + /// Combine two streams into one by first returning all values from the + /// first stream then all values from the second stream. + /// + /// As long as `self` still has values to emit, no values from `other` are + /// emitted, even if some are ready. + /// + /// # Examples + /// + /// ``` + /// use tokio::stream::{self, StreamExt}; + /// + /// #[tokio::main] + /// async fn main() { + /// let one = stream::iter(vec![1, 2, 3]); + /// let two = stream::iter(vec![4, 5, 6]); + /// + /// let mut stream = one.chain(two); + /// + /// assert_eq!(stream.next().await, Some(1)); + /// assert_eq!(stream.next().await, Some(2)); + /// assert_eq!(stream.next().await, Some(3)); + /// assert_eq!(stream.next().await, Some(4)); + /// assert_eq!(stream.next().await, Some(5)); + /// assert_eq!(stream.next().await, Some(6)); + /// assert_eq!(stream.next().await, None); + /// } + /// ``` + fn chain(self, other: U) -> Chain + where + U: Stream, + Self: Sized, + { + Chain::new(self, other) + } } impl StreamExt for St where St: Stream {} diff --git a/tokio/tests/stream_chain.rs b/tokio/tests/stream_chain.rs new file mode 100644 index 000000000..0e14618b4 --- /dev/null +++ b/tokio/tests/stream_chain.rs @@ -0,0 +1,71 @@ +use tokio::stream::{self, Stream, StreamExt}; +use tokio::sync::mpsc; +use tokio_test::{assert_pending, assert_ready, task}; + +#[tokio::test] +async fn basic_usage() { + let one = stream::iter(vec![1, 2, 3]); + let two = stream::iter(vec![4, 5, 6]); + + let mut stream = one.chain(two); + + assert_eq!(stream.size_hint(), (6, Some(6))); + assert_eq!(stream.next().await, Some(1)); + + assert_eq!(stream.size_hint(), (5, Some(5))); + assert_eq!(stream.next().await, Some(2)); + + assert_eq!(stream.size_hint(), (4, Some(4))); + assert_eq!(stream.next().await, Some(3)); + + assert_eq!(stream.size_hint(), (3, Some(3))); + assert_eq!(stream.next().await, Some(4)); + + assert_eq!(stream.size_hint(), (2, Some(2))); + assert_eq!(stream.next().await, Some(5)); + + assert_eq!(stream.size_hint(), (1, Some(1))); + assert_eq!(stream.next().await, Some(6)); + + assert_eq!(stream.size_hint(), (0, Some(0))); + assert_eq!(stream.next().await, None); + + assert_eq!(stream.size_hint(), (0, Some(0))); + assert_eq!(stream.next().await, None); +} + +#[tokio::test] +async fn pending_first() { + let (tx1, rx1) = mpsc::unbounded_channel(); + let (tx2, rx2) = mpsc::unbounded_channel(); + + let mut stream = task::spawn(rx1.chain(rx2)); + assert_eq!(stream.size_hint(), (0, None)); + + assert_pending!(stream.poll_next()); + + tx2.send(2).unwrap(); + assert!(!stream.is_woken()); + + assert_pending!(stream.poll_next()); + + tx1.send(1).unwrap(); + assert!(stream.is_woken()); + assert_eq!(Some(1), assert_ready!(stream.poll_next())); + + assert_pending!(stream.poll_next()); + + drop(tx1); + + assert_eq!(stream.size_hint(), (0, None)); + + assert!(stream.is_woken()); + assert_eq!(Some(2), assert_ready!(stream.poll_next())); + + assert_eq!(stream.size_hint(), (0, None)); + + drop(tx2); + + assert_eq!(stream.size_hint(), (0, None)); + assert_eq!(None, assert_ready!(stream.poll_next())); +}