diff --git a/tokio-stream/src/lib.rs b/tokio-stream/src/lib.rs index bbd4cef03..351c77e37 100644 --- a/tokio-stream/src/lib.rs +++ b/tokio-stream/src/lib.rs @@ -96,5 +96,8 @@ pub use pending::{pending, Pending}; mod stream_map; pub use stream_map::StreamMap; +mod stream_close; +pub use stream_close::StreamNotifyClose; + #[doc(no_inline)] pub use futures_core::Stream; diff --git a/tokio-stream/src/stream_close.rs b/tokio-stream/src/stream_close.rs new file mode 100644 index 000000000..735acf091 --- /dev/null +++ b/tokio-stream/src/stream_close.rs @@ -0,0 +1,93 @@ +use crate::Stream; +use pin_project_lite::pin_project; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A `Stream` that wraps the values in an `Option`. + /// + /// Whenever the wrapped stream yields an item, this stream yields that item + /// wrapped in `Some`. When the inner stream ends, then this stream first + /// yields a `None` item, and then this stream will also end. + /// + /// # Example + /// + /// Using `StreamNotifyClose` to handle closed streams with `StreamMap`. + /// + /// ``` + /// use tokio_stream::{StreamExt, StreamMap, StreamNotifyClose}; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut map = StreamMap::new(); + /// let stream = StreamNotifyClose::new(tokio_stream::iter(vec![0, 1])); + /// let stream2 = StreamNotifyClose::new(tokio_stream::iter(vec![0, 1])); + /// map.insert(0, stream); + /// map.insert(1, stream2); + /// while let Some((key, val)) = map.next().await { + /// match val { + /// Some(val) => println!("got {val:?} from stream {key:?}"), + /// None => println!("stream {key:?} closed"), + /// } + /// } + /// } + /// ``` + #[must_use = "streams do nothing unless polled"] + pub struct StreamNotifyClose { + #[pin] + inner: Option, + } +} + +impl StreamNotifyClose { + /// Create a new `StreamNotifyClose`. + pub fn new(stream: S) -> Self { + Self { + inner: Some(stream), + } + } + + /// Get back the inner `Stream`. + /// + /// Returns `None` if the stream has reached its end. + pub fn into_inner(self) -> Option { + self.inner + } +} + +impl Stream for StreamNotifyClose +where + S: Stream, +{ + type Item = Option; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // We can't invoke poll_next after it ended, so we unset the inner stream as a marker. + match self + .as_mut() + .project() + .inner + .as_pin_mut() + .map(|stream| S::poll_next(stream, cx)) + { + Some(Poll::Ready(Some(item))) => Poll::Ready(Some(Some(item))), + Some(Poll::Ready(None)) => { + self.project().inner.set(None); + Poll::Ready(Some(None)) + } + Some(Poll::Pending) => Poll::Pending, + None => Poll::Ready(None), + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + if let Some(inner) = &self.inner { + // We always return +1 because when there's stream there's atleast one more item. + let (l, u) = inner.size_hint(); + (l.saturating_add(1), u.and_then(|u| u.checked_add(1))) + } else { + (0, Some(0)) + } + } +} diff --git a/tokio-stream/src/stream_map.rs b/tokio-stream/src/stream_map.rs index ddd14d586..0c11bf1d5 100644 --- a/tokio-stream/src/stream_map.rs +++ b/tokio-stream/src/stream_map.rs @@ -42,10 +42,18 @@ use std::task::{Context, Poll}; /// to be merged, it may be advisable to use tasks sending values on a shared /// [`mpsc`] channel. /// +/// # Notes +/// +/// `StreamMap` removes finished streams automatically, without alerting the user. +/// In some scenarios, the caller would want to know on closed streams. +/// To do this, use [`StreamNotifyClose`] as a wrapper to your stream. +/// It will return None when the stream is closed. +/// /// [`StreamExt::merge`]: crate::StreamExt::merge /// [`mpsc`]: https://docs.rs/tokio/1.0/tokio/sync/mpsc/index.html /// [`pin!`]: https://docs.rs/tokio/1.0/tokio/macro.pin.html /// [`Box::pin`]: std::boxed::Box::pin +/// [`StreamNotifyClose`]: crate::StreamNotifyClose /// /// # Examples /// @@ -170,6 +178,28 @@ use std::task::{Context, Poll}; /// } /// } /// ``` +/// +/// Using `StreamNotifyClose` to handle closed streams with `StreamMap`. +/// +/// ``` +/// use tokio_stream::{StreamExt, StreamMap, StreamNotifyClose}; +/// +/// #[tokio::main] +/// async fn main() { +/// let mut map = StreamMap::new(); +/// let stream = StreamNotifyClose::new(tokio_stream::iter(vec![0, 1])); +/// let stream2 = StreamNotifyClose::new(tokio_stream::iter(vec![0, 1])); +/// map.insert(0, stream); +/// map.insert(1, stream2); +/// while let Some((key, val)) = map.next().await { +/// match val { +/// Some(val) => println!("got {val:?} from stream {key:?}"), +/// None => println!("stream {key:?} closed"), +/// } +/// } +/// } +/// ``` + #[derive(Debug)] pub struct StreamMap { /// Streams stored in the map diff --git a/tokio-stream/tests/stream_close.rs b/tokio-stream/tests/stream_close.rs new file mode 100644 index 000000000..9ddb5650e --- /dev/null +++ b/tokio-stream/tests/stream_close.rs @@ -0,0 +1,11 @@ +use tokio_stream::{StreamExt, StreamNotifyClose}; + +#[tokio::test] +async fn basic_usage() { + let mut stream = StreamNotifyClose::new(tokio_stream::iter(vec![0, 1])); + + assert_eq!(stream.next().await, Some(Some(0))); + assert_eq!(stream.next().await, Some(Some(1))); + assert_eq!(stream.next().await, Some(None)); + assert_eq!(stream.next().await, None); +}