io: make copy cooperative (#6265)

This commit is contained in:
Rustin 2024-01-07 00:22:26 +08:00 committed by GitHub
parent 9780bf491f
commit 3275cfb638
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 113 additions and 2 deletions

View File

@ -82,6 +82,19 @@ impl CopyBuffer {
R: AsyncRead + ?Sized, R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized, W: AsyncWrite + ?Sized,
{ {
ready!(crate::trace::trace_leaf(cx));
#[cfg(any(
feature = "fs",
feature = "io-std",
feature = "net",
feature = "process",
feature = "rt",
feature = "signal",
feature = "sync",
feature = "time",
))]
// Keep track of task budget
let coop = ready!(crate::runtime::coop::poll_proceed(cx));
loop { loop {
// If our buffer is empty, then we need to read some data to // If our buffer is empty, then we need to read some data to
// continue. // continue.
@ -90,13 +103,49 @@ impl CopyBuffer {
self.cap = 0; self.cap = 0;
match self.poll_fill_buf(cx, reader.as_mut()) { match self.poll_fill_buf(cx, reader.as_mut()) {
Poll::Ready(Ok(())) => (), Poll::Ready(Ok(())) => {
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), #[cfg(any(
feature = "fs",
feature = "io-std",
feature = "net",
feature = "process",
feature = "rt",
feature = "signal",
feature = "sync",
feature = "time",
))]
coop.made_progress();
}
Poll::Ready(Err(err)) => {
#[cfg(any(
feature = "fs",
feature = "io-std",
feature = "net",
feature = "process",
feature = "rt",
feature = "signal",
feature = "sync",
feature = "time",
))]
coop.made_progress();
return Poll::Ready(Err(err));
}
Poll::Pending => { Poll::Pending => {
// Try flushing when the reader has no progress to avoid deadlock // Try flushing when the reader has no progress to avoid deadlock
// when the reader depends on buffered writer. // when the reader depends on buffered writer.
if self.need_flush { if self.need_flush {
ready!(writer.as_mut().poll_flush(cx))?; ready!(writer.as_mut().poll_flush(cx))?;
#[cfg(any(
feature = "fs",
feature = "io-std",
feature = "net",
feature = "process",
feature = "rt",
feature = "signal",
feature = "sync",
feature = "time",
))]
coop.made_progress();
self.need_flush = false; self.need_flush = false;
} }
@ -108,6 +157,17 @@ impl CopyBuffer {
// If our buffer has some data, let's write it out! // If our buffer has some data, let's write it out!
while self.pos < self.cap { while self.pos < self.cap {
let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?; let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
#[cfg(any(
feature = "fs",
feature = "io-std",
feature = "net",
feature = "process",
feature = "rt",
feature = "signal",
feature = "sync",
feature = "time",
))]
coop.made_progress();
if i == 0 { if i == 0 {
return Poll::Ready(Err(io::Error::new( return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero, io::ErrorKind::WriteZero,
@ -132,6 +192,17 @@ impl CopyBuffer {
// data and finish the transfer. // data and finish the transfer.
if self.pos == self.cap && self.read_done { if self.pos == self.cap && self.read_done {
ready!(writer.as_mut().poll_flush(cx))?; ready!(writer.as_mut().poll_flush(cx))?;
#[cfg(any(
feature = "fs",
feature = "io-std",
feature = "net",
feature = "process",
feature = "rt",
feature = "signal",
feature = "sync",
feature = "time",
))]
coop.made_progress();
return Poll::Ready(Ok(self.amt)); return Poll::Ready(Ok(self.amt));
} }
} }

View File

@ -85,3 +85,18 @@ async fn proxy() {
assert_eq!(n, 1024); assert_eq!(n, 1024);
} }
#[tokio::test]
async fn copy_is_cooperative() {
tokio::select! {
biased;
_ = async {
loop {
let mut reader: &[u8] = b"hello";
let mut writer: Vec<u8> = vec![];
let _ = io::copy(&mut reader, &mut writer).await;
}
} => {},
_ = tokio::task::yield_now() => {}
}
}

View File

@ -138,3 +138,28 @@ async fn immediate_exit_on_read_error() {
assert!(copy_bidirectional(&mut a, &mut b).await.is_err()); assert!(copy_bidirectional(&mut a, &mut b).await.is_err());
} }
#[tokio::test]
async fn copy_bidirectional_is_cooperative() {
tokio::select! {
biased;
_ = async {
loop {
let payload = b"here, take this";
let mut a = tokio_test::io::Builder::new()
.read(payload)
.write(payload)
.build();
let mut b = tokio_test::io::Builder::new()
.read(payload)
.write(payload)
.build();
let _ = copy_bidirectional(&mut a, &mut b).await;
}
} => {},
_ = tokio::task::yield_now() => {}
}
}