diff --git a/tokio/src/io/util/copy.rs b/tokio/src/io/util/copy.rs index 8bd0bff7f..56310c86f 100644 --- a/tokio/src/io/util/copy.rs +++ b/tokio/src/io/util/copy.rs @@ -82,6 +82,19 @@ impl CopyBuffer { R: AsyncRead + ?Sized, W: AsyncWrite + ?Sized, { + ready!(crate::trace::trace_leaf(cx)); + #[cfg(any( + feature = "fs", + feature = "io-std", + feature = "net", + feature = "process", + feature = "rt", + feature = "signal", + feature = "sync", + feature = "time", + ))] + // Keep track of task budget + let coop = ready!(crate::runtime::coop::poll_proceed(cx)); loop { // If our buffer is empty, then we need to read some data to // continue. @@ -90,13 +103,49 @@ impl CopyBuffer { self.cap = 0; match self.poll_fill_buf(cx, reader.as_mut()) { - Poll::Ready(Ok(())) => (), - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Ready(Ok(())) => { + #[cfg(any( + feature = "fs", + feature = "io-std", + feature = "net", + feature = "process", + feature = "rt", + feature = "signal", + feature = "sync", + feature = "time", + ))] + coop.made_progress(); + } + Poll::Ready(Err(err)) => { + #[cfg(any( + feature = "fs", + feature = "io-std", + feature = "net", + feature = "process", + feature = "rt", + feature = "signal", + feature = "sync", + feature = "time", + ))] + coop.made_progress(); + return Poll::Ready(Err(err)); + } Poll::Pending => { // Try flushing when the reader has no progress to avoid deadlock // when the reader depends on buffered writer. if self.need_flush { ready!(writer.as_mut().poll_flush(cx))?; + #[cfg(any( + feature = "fs", + feature = "io-std", + feature = "net", + feature = "process", + feature = "rt", + feature = "signal", + feature = "sync", + feature = "time", + ))] + coop.made_progress(); self.need_flush = false; } @@ -108,6 +157,17 @@ impl CopyBuffer { // If our buffer has some data, let's write it out! while self.pos < self.cap { let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?; + #[cfg(any( + feature = "fs", + feature = "io-std", + feature = "net", + feature = "process", + feature = "rt", + feature = "signal", + feature = "sync", + feature = "time", + ))] + coop.made_progress(); if i == 0 { return Poll::Ready(Err(io::Error::new( io::ErrorKind::WriteZero, @@ -132,6 +192,17 @@ impl CopyBuffer { // data and finish the transfer. if self.pos == self.cap && self.read_done { ready!(writer.as_mut().poll_flush(cx))?; + #[cfg(any( + feature = "fs", + feature = "io-std", + feature = "net", + feature = "process", + feature = "rt", + feature = "signal", + feature = "sync", + feature = "time", + ))] + coop.made_progress(); return Poll::Ready(Ok(self.amt)); } } diff --git a/tokio/tests/io_copy.rs b/tokio/tests/io_copy.rs index 005e17011..82d92a968 100644 --- a/tokio/tests/io_copy.rs +++ b/tokio/tests/io_copy.rs @@ -85,3 +85,18 @@ async fn proxy() { assert_eq!(n, 1024); } + +#[tokio::test] +async fn copy_is_cooperative() { + tokio::select! { + biased; + _ = async { + loop { + let mut reader: &[u8] = b"hello"; + let mut writer: Vec = vec![]; + let _ = io::copy(&mut reader, &mut writer).await; + } + } => {}, + _ = tokio::task::yield_now() => {} + } +} diff --git a/tokio/tests/io_copy_bidirectional.rs b/tokio/tests/io_copy_bidirectional.rs index 10eba3166..3cdce32d0 100644 --- a/tokio/tests/io_copy_bidirectional.rs +++ b/tokio/tests/io_copy_bidirectional.rs @@ -138,3 +138,28 @@ async fn immediate_exit_on_read_error() { assert!(copy_bidirectional(&mut a, &mut b).await.is_err()); } + +#[tokio::test] +async fn copy_bidirectional_is_cooperative() { + tokio::select! { + biased; + _ = async { + loop { + let payload = b"here, take this"; + + let mut a = tokio_test::io::Builder::new() + .read(payload) + .write(payload) + .build(); + + let mut b = tokio_test::io::Builder::new() + .read(payload) + .write(payload) + .build(); + + let _ = copy_bidirectional(&mut a, &mut b).await; + } + } => {}, + _ = tokio::task::yield_now() => {} + } +}