net: restore TcpStream::{poll_read_ready, poll_write_ready} (#2743)

This commit is contained in:
masnagam 2020-11-17 02:51:06 +09:00 committed by GitHub
parent 97c2c4203c
commit 4e39c9b818
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 132 additions and 2 deletions

View File

@ -356,6 +356,17 @@ impl TcpStream {
Ok(())
}
/// Polls for read readiness.
///
/// This function is intended for cases where creating and pinning a future
/// via [`readable`] is not feasible. Where possible, using [`readable`] is
/// preferred, as this supports polling from multiple tasks at once.
///
/// [`readable`]: method@Self::readable
pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.io.registration().poll_read_ready(cx).map_ok(|_| ())
}
/// Try to read data from the stream into the provided buffer, returning how
/// many bytes were read.
///
@ -467,6 +478,17 @@ impl TcpStream {
Ok(())
}
/// Polls for write readiness.
///
/// This function is intended for cases where creating and pinning a future
/// via [`writable`] is not feasible. Where possible, using [`writable`] is
/// preferred, as this supports polling from multiple tasks at once.
///
/// [`writable`]: method@Self::writable
pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.io.registration().poll_write_ready(cx).map_ok(|_| ())
}
/// Try to write a buffer to the stream, returning how many bytes were
/// written.
///

View File

@ -1,12 +1,16 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]
use tokio::io::Interest;
use tokio::io::{AsyncReadExt, AsyncWriteExt, Interest};
use tokio::net::{TcpListener, TcpStream};
use tokio::try_join;
use tokio_test::task;
use tokio_test::{assert_pending, assert_ready_ok};
use tokio_test::{assert_ok, assert_pending, assert_ready_ok};
use std::io;
use std::task::Poll;
use futures::future::poll_fn;
#[tokio::test]
async fn try_read_write() {
@ -110,3 +114,107 @@ fn buffer_not_included_in_future() {
let n = mem::size_of_val(&fut);
assert!(n < 1000);
}
macro_rules! assert_readable_by_polling {
($stream:expr) => {
assert_ok!(poll_fn(|cx| $stream.poll_read_ready(cx)).await);
};
}
macro_rules! assert_not_readable_by_polling {
($stream:expr) => {
poll_fn(|cx| {
assert_pending!($stream.poll_read_ready(cx));
Poll::Ready(())
})
.await;
};
}
macro_rules! assert_writable_by_polling {
($stream:expr) => {
assert_ok!(poll_fn(|cx| $stream.poll_write_ready(cx)).await);
};
}
macro_rules! assert_not_writable_by_polling {
($stream:expr) => {
poll_fn(|cx| {
assert_pending!($stream.poll_write_ready(cx));
Poll::Ready(())
})
.await;
};
}
#[tokio::test]
async fn poll_read_ready() {
let (mut client, mut server) = create_pair().await;
// Initial state - not readable.
assert_not_readable_by_polling!(server);
// There is data in the buffer - readable.
assert_ok!(client.write_all(b"ping").await);
assert_readable_by_polling!(server);
// Readable until calls to `poll_read` return `Poll::Pending`.
let mut buf = [0u8; 4];
assert_ok!(server.read_exact(&mut buf).await);
assert_readable_by_polling!(server);
read_until_pending(&mut server);
assert_not_readable_by_polling!(server);
// Detect the client disconnect.
drop(client);
assert_readable_by_polling!(server);
}
#[tokio::test]
async fn poll_write_ready() {
let (mut client, server) = create_pair().await;
// Initial state - writable.
assert_writable_by_polling!(client);
// No space to write - not writable.
write_until_pending(&mut client);
assert_not_writable_by_polling!(client);
// Detect the server disconnect.
drop(server);
assert_writable_by_polling!(client);
}
async fn create_pair() -> (TcpStream, TcpStream) {
let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
let addr = assert_ok!(listener.local_addr());
let (client, (server, _)) = assert_ok!(try_join!(TcpStream::connect(&addr), listener.accept()));
(client, server)
}
fn read_until_pending(stream: &mut TcpStream) {
let mut buf = vec![0u8; 1024 * 1024];
loop {
match stream.try_read(&mut buf) {
Ok(_) => (),
Err(err) => {
assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
break;
}
}
}
}
fn write_until_pending(stream: &mut TcpStream) {
let buf = vec![0u8; 1024 * 1024];
loop {
match stream.try_write(&buf) {
Ok(_) => (),
Err(err) => {
assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
break;
}
}
}
}