diff --git a/tokio/src/io/mod.rs b/tokio/src/io/mod.rs index dd48acdaf..6d6edf558 100644 --- a/tokio/src/io/mod.rs +++ b/tokio/src/io/mod.rs @@ -200,10 +200,14 @@ cfg_io_util! { pub(crate) mod util; pub use util::{ - copy, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt, BufReader, - BufStream, BufWriter, Copy, Empty, Lines, Repeat, Sink, Split, Take, + copy, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt, + BufReader, BufStream, BufWriter, Copy, Empty, Lines, Repeat, Sink, Split, Take, }; + cfg_stream! { + pub use util::{stream_reader, StreamReader}; + } + // Re-export io::Error so that users don't have to deal with conflicts when // `use`ing `tokio::io` and `std::io`. pub use std::io::{Error, ErrorKind, Result}; diff --git a/tokio/src/io/util/mod.rs b/tokio/src/io/util/mod.rs index d239e1a24..c4754abf0 100644 --- a/tokio/src/io/util/mod.rs +++ b/tokio/src/io/util/mod.rs @@ -60,6 +60,11 @@ cfg_io_util! { mod split; pub use split::Split; + cfg_stream! { + mod stream_reader; + pub use stream_reader::{stream_reader, StreamReader}; + } + mod take; pub use take::Take; diff --git a/tokio/src/io/util/stream_reader.rs b/tokio/src/io/util/stream_reader.rs new file mode 100644 index 000000000..b98f8bdfc --- /dev/null +++ b/tokio/src/io/util/stream_reader.rs @@ -0,0 +1,184 @@ +use crate::io::{AsyncBufRead, AsyncRead}; +use crate::stream::Stream; +use bytes::{Buf, BufMut}; +use pin_project_lite::pin_project; +use std::io; +use std::mem::MaybeUninit; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// Convert a stream of byte chunks into an [`AsyncRead`]. + /// + /// This type is usually created using the [`stream_reader`] function. + /// + /// [`AsyncRead`]: crate::io::AsyncRead + /// [`stream_reader`]: crate::io::stream_reader + #[derive(Debug)] + #[cfg_attr(docsrs, doc(cfg(feature = "stream")))] + #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] + pub struct StreamReader { + #[pin] + inner: S, + chunk: Option, + } +} + +/// Convert a stream of byte chunks into an [`AsyncRead`](crate::io::AsyncRead). +/// +/// # Example +/// +/// ``` +/// use bytes::Bytes; +/// use tokio::io::{stream_reader, AsyncReadExt}; +/// # #[tokio::main] +/// # async fn main() -> std::io::Result<()> { +/// +/// // Create a stream from an iterator. +/// let stream = tokio::stream::iter(vec![ +/// Ok(Bytes::from_static(&[0, 1, 2, 3])), +/// Ok(Bytes::from_static(&[4, 5, 6, 7])), +/// Ok(Bytes::from_static(&[8, 9, 10, 11])), +/// ]); +/// +/// // Convert it to an AsyncRead. +/// let mut read = stream_reader(stream); +/// +/// // Read five bytes from the stream. +/// let mut buf = [0; 5]; +/// read.read_exact(&mut buf).await?; +/// assert_eq!(buf, [0, 1, 2, 3, 4]); +/// +/// // Read the rest of the current chunk. +/// assert_eq!(read.read(&mut buf).await?, 3); +/// assert_eq!(&buf[..3], [5, 6, 7]); +/// +/// // Read the next chunk. +/// assert_eq!(read.read(&mut buf).await?, 4); +/// assert_eq!(&buf[..4], [8, 9, 10, 11]); +/// +/// // We have now reached the end. +/// assert_eq!(read.read(&mut buf).await?, 0); +/// +/// # Ok(()) +/// # } +/// ``` +#[cfg_attr(docsrs, doc(cfg(feature = "stream")))] +#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] +pub fn stream_reader(stream: S) -> StreamReader +where + S: Stream>, + B: Buf, +{ + StreamReader::new(stream) +} + +impl StreamReader +where + S: Stream>, + B: Buf, +{ + /// Convert the provided stream into an `AsyncRead`. + fn new(stream: S) -> Self { + Self { + inner: stream, + chunk: None, + } + } + /// Do we have a chunk and is it non-empty? + fn has_chunk(self: Pin<&mut Self>) -> bool { + if let Some(chunk) = self.project().chunk { + chunk.remaining() > 0 + } else { + false + } + } +} + +impl AsyncRead for StreamReader +where + S: Stream>, + B: Buf, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + + let inner_buf = match self.as_mut().poll_fill_buf(cx) { + Poll::Ready(Ok(buf)) => buf, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + }; + let len = std::cmp::min(inner_buf.len(), buf.len()); + (&mut buf[..len]).copy_from_slice(&inner_buf[..len]); + + self.consume(len); + Poll::Ready(Ok(len)) + } + fn poll_read_buf( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut BM, + ) -> Poll> + where + Self: Sized, + { + if !buf.has_remaining_mut() { + return Poll::Ready(Ok(0)); + } + + let inner_buf = match self.as_mut().poll_fill_buf(cx) { + Poll::Ready(Ok(buf)) => buf, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + }; + let len = std::cmp::min(inner_buf.len(), buf.remaining_mut()); + buf.put_slice(&inner_buf[..len]); + + self.consume(len); + Poll::Ready(Ok(len)) + } + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [MaybeUninit]) -> bool { + false + } +} + +impl AsyncBufRead for StreamReader +where + S: Stream>, + B: Buf, +{ + fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + if self.as_mut().has_chunk() { + // This unwrap is very sad, but it can't be avoided. + let buf = self.project().chunk.as_ref().unwrap().bytes(); + return Poll::Ready(Ok(buf)); + } else { + match self.as_mut().project().inner.poll_next(cx) { + Poll::Ready(Some(Ok(chunk))) => { + // Go around the loop in case the chunk is empty. + *self.as_mut().project().chunk = Some(chunk); + } + Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err)), + Poll::Ready(None) => return Poll::Ready(Ok(&[])), + Poll::Pending => return Poll::Pending, + } + } + } + } + fn consume(self: Pin<&mut Self>, amt: usize) { + if amt > 0 { + self.project() + .chunk + .as_mut() + .expect("No chunk present") + .advance(amt); + } + } +} diff --git a/tokio/tests/stream_reader.rs b/tokio/tests/stream_reader.rs new file mode 100644 index 000000000..8370df4da --- /dev/null +++ b/tokio/tests/stream_reader.rs @@ -0,0 +1,35 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use bytes::Bytes; +use tokio::io::{stream_reader, AsyncReadExt}; +use tokio::stream::iter; + +#[tokio::test] +async fn test_stream_reader() -> std::io::Result<()> { + let stream = iter(vec![ + Ok(Bytes::from_static(&[])), + Ok(Bytes::from_static(&[0, 1, 2, 3])), + Ok(Bytes::from_static(&[])), + Ok(Bytes::from_static(&[4, 5, 6, 7])), + Ok(Bytes::from_static(&[])), + Ok(Bytes::from_static(&[8, 9, 10, 11])), + Ok(Bytes::from_static(&[])), + ]); + + let mut read = stream_reader(stream); + + let mut buf = [0; 5]; + read.read_exact(&mut buf).await?; + assert_eq!(buf, [0, 1, 2, 3, 4]); + + assert_eq!(read.read(&mut buf).await?, 3); + assert_eq!(&buf[..3], [5, 6, 7]); + + assert_eq!(read.read(&mut buf).await?, 4); + assert_eq!(&buf[..4], [8, 9, 10, 11]); + + assert_eq!(read.read(&mut buf).await?, 0); + + Ok(()) +}