net: add try_read_buf and try_recv_buf (#3351)

This commit is contained in:
cssivision 2021-01-02 17:37:34 +08:00 committed by GitHub
parent 56272b2ec7
commit 3b6bee822d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 770 additions and 1 deletions

View File

@ -17,6 +17,10 @@ use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
cfg_io_util! {
use bytes::BufMut;
}
cfg_net! {
/// A TCP stream between a local and a remote socket.
///
@ -559,6 +563,85 @@ impl TcpStream {
.try_io(Interest::READABLE, || (&*self.io).read(buf))
}
cfg_io_util! {
/// Try to read data from the stream into the provided buffer, advancing the
/// buffer's internal cursor, returning how many bytes were read.
///
/// Receives any pending data from the socket but does not wait for new data
/// to arrive. On success, returns the number of bytes read. Because
/// `try_read_buf()` is non-blocking, the buffer does not have to be stored by
/// the async task and can exist entirely on the stack.
///
/// Usually, [`readable()`] or [`ready()`] is used with this function.
///
/// [`readable()`]: TcpStream::readable()
/// [`ready()`]: TcpStream::ready()
///
/// # Return
///
/// If data is successfully read, `Ok(n)` is returned, where `n` is the
/// number of bytes read. `Ok(0)` indicates the stream's read half is closed
/// and will no longer yield data. If the stream is not ready to read data
/// `Err(io::ErrorKind::WouldBlock)` is returned.
///
/// # Examples
///
/// ```no_run
/// use tokio::net::TcpStream;
/// use std::error::Error;
/// use std::io;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn Error>> {
/// // Connect to a peer
/// let stream = TcpStream::connect("127.0.0.1:8080").await?;
///
/// loop {
/// // Wait for the socket to be readable
/// stream.readable().await?;
///
/// let mut buf = Vec::with_capacity(4096);
///
/// // Try to read data, this may still fail with `WouldBlock`
/// // if the readiness event is a false positive.
/// match stream.try_read_buf(&mut buf) {
/// Ok(0) => break,
/// Ok(n) => {
/// println!("read {} bytes", n);
/// }
/// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
/// continue;
/// }
/// Err(e) => {
/// return Err(e.into());
/// }
/// }
/// }
///
/// Ok(())
/// }
/// ```
pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> {
self.io.registration().try_io(Interest::READABLE, || {
use std::io::Read;
let dst = buf.chunk_mut();
let dst =
unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };
// Safety: We trust `TcpStream::read` to have filled up `n` bytes in the
// buffer.
let n = (&*self.io).read(dst)?;
unsafe {
buf.advance_mut(n);
}
Ok(n)
})
}
}
/// Wait for the socket to become writable.
///
/// This function is equivalent to `ready(Interest::WRITABLE)` and is usually

View File

@ -7,6 +7,10 @@ use std::io;
use std::net::{self, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::task::{Context, Poll};
cfg_io_util! {
use bytes::BufMut;
}
cfg_net! {
/// A UDP socket
///
@ -683,6 +687,137 @@ impl UdpSocket {
.try_io(Interest::READABLE, || self.io.recv(buf))
}
cfg_io_util! {
/// Try to receive data from the stream into the provided buffer, advancing the
/// buffer's internal cursor, returning how many bytes were read.
///
/// The function must be called with valid byte array buf of sufficient size
/// to hold the message bytes. If a message is too long to fit in the
/// supplied buffer, excess bytes may be discarded.
///
/// When there is no pending data, `Err(io::ErrorKind::WouldBlock)` is
/// returned. This function is usually paired with `readable()`.
///
/// # Examples
///
/// ```no_run
/// use tokio::net::UdpSocket;
/// use std::io;
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// // Connect to a peer
/// let socket = UdpSocket::bind("127.0.0.1:8080").await?;
/// socket.connect("127.0.0.1:8081").await?;
///
/// loop {
/// // Wait for the socket to be readable
/// socket.readable().await?;
///
/// let mut buf = Vec::with_capacity(1024);
///
/// // Try to recv data, this may still fail with `WouldBlock`
/// // if the readiness event is a false positive.
/// match socket.try_recv_buf(&mut buf) {
/// Ok(n) => {
/// println!("GOT {:?}", &buf[..n]);
/// break;
/// }
/// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
/// continue;
/// }
/// Err(e) => {
/// return Err(e);
/// }
/// }
/// }
///
/// Ok(())
/// }
/// ```
pub fn try_recv_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> {
self.io.registration().try_io(Interest::READABLE, || {
let dst = buf.chunk_mut();
let dst =
unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };
// Safety: We trust `UdpSocket::recv` to have filled up `n` bytes in the
// buffer.
let n = (&*self.io).recv(dst)?;
unsafe {
buf.advance_mut(n);
}
Ok(n)
})
}
/// Try to receive a single datagram message on the socket. On success,
/// returns the number of bytes read and the origin.
///
/// The function must be called with valid byte array buf of sufficient size
/// to hold the message bytes. If a message is too long to fit in the
/// supplied buffer, excess bytes may be discarded.
///
/// When there is no pending data, `Err(io::ErrorKind::WouldBlock)` is
/// returned. This function is usually paired with `readable()`.
///
/// # Examples
///
/// ```no_run
/// use tokio::net::UdpSocket;
/// use std::io;
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// // Connect to a peer
/// let socket = UdpSocket::bind("127.0.0.1:8080").await?;
///
/// loop {
/// // Wait for the socket to be readable
/// socket.readable().await?;
///
/// let mut buf = Vec::with_capacity(1024);
///
/// // Try to recv data, this may still fail with `WouldBlock`
/// // if the readiness event is a false positive.
/// match socket.try_recv_buf_from(&mut buf) {
/// Ok((n, _addr)) => {
/// println!("GOT {:?}", &buf[..n]);
/// break;
/// }
/// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
/// continue;
/// }
/// Err(e) => {
/// return Err(e);
/// }
/// }
/// }
///
/// Ok(())
/// }
/// ```
pub fn try_recv_buf_from<B: BufMut>(&self, buf: &mut B) -> io::Result<(usize, SocketAddr)> {
self.io.registration().try_io(Interest::READABLE, || {
let dst = buf.chunk_mut();
let dst =
unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };
// Safety: We trust `UdpSocket::recv_from` to have filled up `n` bytes in the
// buffer.
let (n, addr) = (&*self.io).recv_from(dst)?;
unsafe {
buf.advance_mut(n);
}
Ok((n, addr))
})
}
}
/// Sends data on the socket to the given address. On success, returns the
/// number of bytes written.
///
@ -904,7 +1039,6 @@ impl UdpSocket {
/// async fn main() -> io::Result<()> {
/// // Connect to a peer
/// let socket = UdpSocket::bind("127.0.0.1:8080").await?;
/// socket.connect("127.0.0.1:8081").await?;
///
/// loop {
/// // Wait for the socket to be readable

View File

@ -10,6 +10,10 @@ use std::os::unix::net;
use std::path::Path;
use std::task::{Context, Poll};
cfg_io_util! {
use bytes::BufMut;
}
cfg_net_unix! {
/// An I/O object representing a Unix datagram socket.
///
@ -652,6 +656,130 @@ impl UnixDatagram {
.try_io(Interest::READABLE, || self.io.recv(buf))
}
cfg_io_util! {
/// Try to receive data from the socket without waiting.
///
/// # Examples
///
/// ```no_run
/// use tokio::net::UnixDatagram;
/// use std::io;
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// // Connect to a peer
/// let dir = tempfile::tempdir().unwrap();
/// let client_path = dir.path().join("client.sock");
/// let server_path = dir.path().join("server.sock");
/// let socket = UnixDatagram::bind(&client_path)?;
///
/// loop {
/// // Wait for the socket to be readable
/// socket.readable().await?;
///
/// let mut buf = Vec::with_capacity(1024);
///
/// // Try to recv data, this may still fail with `WouldBlock`
/// // if the readiness event is a false positive.
/// match socket.try_recv_buf_from(&mut buf) {
/// Ok((n, _addr)) => {
/// println!("GOT {:?}", &buf[..n]);
/// break;
/// }
/// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
/// continue;
/// }
/// Err(e) => {
/// return Err(e);
/// }
/// }
/// }
///
/// Ok(())
/// }
/// ```
pub fn try_recv_buf_from<B: BufMut>(&self, buf: &mut B) -> io::Result<(usize, SocketAddr)> {
let (n, addr) = self.io.registration().try_io(Interest::READABLE, || {
let dst = buf.chunk_mut();
let dst =
unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };
// Safety: We trust `UnixDatagram::recv_from` to have filled up `n` bytes in the
// buffer.
let (n, addr) = (&*self.io).recv_from(dst)?;
unsafe {
buf.advance_mut(n);
}
Ok((n, addr))
})?;
Ok((n, SocketAddr(addr)))
}
/// Try to read data from the stream into the provided buffer, advancing the
/// buffer's internal cursor, returning how many bytes were read.
///
/// # Examples
///
/// ```no_run
/// use tokio::net::UnixDatagram;
/// use std::io;
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// // Connect to a peer
/// let dir = tempfile::tempdir().unwrap();
/// let client_path = dir.path().join("client.sock");
/// let server_path = dir.path().join("server.sock");
/// let socket = UnixDatagram::bind(&client_path)?;
/// socket.connect(&server_path)?;
///
/// loop {
/// // Wait for the socket to be readable
/// socket.readable().await?;
///
/// let mut buf = Vec::with_capacity(1024);
///
/// // Try to recv data, this may still fail with `WouldBlock`
/// // if the readiness event is a false positive.
/// match socket.try_recv_buf(&mut buf) {
/// Ok(n) => {
/// println!("GOT {:?}", &buf[..n]);
/// break;
/// }
/// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
/// continue;
/// }
/// Err(e) => {
/// return Err(e);
/// }
/// }
/// }
///
/// Ok(())
/// }
/// ```
pub fn try_recv_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> {
self.io.registration().try_io(Interest::READABLE, || {
let dst = buf.chunk_mut();
let dst =
unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };
// Safety: We trust `UnixDatagram::recv` to have filled up `n` bytes in the
// buffer.
let n = (&*self.io).recv(dst)?;
unsafe {
buf.advance_mut(n);
}
Ok(n)
})
}
}
/// Sends data on the socket to the specified address.
///
/// # Examples

View File

@ -15,6 +15,10 @@ use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
cfg_io_util! {
use bytes::BufMut;
}
cfg_net_unix! {
/// A structure representing a connected Unix socket.
///
@ -267,6 +271,87 @@ impl UnixStream {
.try_io(Interest::READABLE, || (&*self.io).read(buf))
}
cfg_io_util! {
/// Try to read data from the stream into the provided buffer, advancing the
/// buffer's internal cursor, returning how many bytes were read.
///
/// Receives any pending data from the socket but does not wait for new data
/// to arrive. On success, returns the number of bytes read. Because
/// `try_read_buf()` is non-blocking, the buffer does not have to be stored by
/// the async task and can exist entirely on the stack.
///
/// Usually, [`readable()`] or [`ready()`] is used with this function.
///
/// [`readable()`]: UnixStream::readable()
/// [`ready()`]: UnixStream::ready()
///
/// # Return
///
/// If data is successfully read, `Ok(n)` is returned, where `n` is the
/// number of bytes read. `Ok(0)` indicates the stream's read half is closed
/// and will no longer yield data. If the stream is not ready to read data
/// `Err(io::ErrorKind::WouldBlock)` is returned.
///
/// # Examples
///
/// ```no_run
/// use tokio::net::UnixStream;
/// use std::error::Error;
/// use std::io;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn Error>> {
/// // Connect to a peer
/// let dir = tempfile::tempdir().unwrap();
/// let bind_path = dir.path().join("bind_path");
/// let stream = UnixStream::connect(bind_path).await?;
///
/// loop {
/// // Wait for the socket to be readable
/// stream.readable().await?;
///
/// let mut buf = Vec::with_capacity(4096);
///
/// // Try to read data, this may still fail with `WouldBlock`
/// // if the readiness event is a false positive.
/// match stream.try_read_buf(&mut buf) {
/// Ok(0) => break,
/// Ok(n) => {
/// println!("read {} bytes", n);
/// }
/// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
/// continue;
/// }
/// Err(e) => {
/// return Err(e.into());
/// }
/// }
/// }
///
/// Ok(())
/// }
/// ```
pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> {
self.io.registration().try_io(Interest::READABLE, || {
use std::io::Read;
let dst = buf.chunk_mut();
let dst =
unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };
// Safety: We trust `UnixStream::read` to have filled up `n` bytes in the
// buffer.
let n = (&*self.io).read(dst)?;
unsafe {
buf.advance_mut(n);
}
Ok(n)
})
}
}
/// Wait for the socket to become writable.
///
/// This function is equivalent to `ready(Interest::WRITABLE)` and is usually

View File

@ -234,3 +234,81 @@ fn write_until_pending(stream: &mut TcpStream) {
}
}
}
#[tokio::test]
async fn try_read_buf() {
const DATA: &[u8] = b"this is some data to write to the socket";
// Create listener
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
// Create socket pair
let client = TcpStream::connect(listener.local_addr().unwrap())
.await
.unwrap();
let (server, _) = listener.accept().await.unwrap();
let mut written = DATA.to_vec();
// Track the server receiving data
let mut readable = task::spawn(server.readable());
assert_pending!(readable.poll());
// Write data.
client.writable().await.unwrap();
assert_eq!(DATA.len(), client.try_write(DATA).unwrap());
// The task should be notified
while !readable.is_woken() {
tokio::task::yield_now().await;
}
// Fill the write buffer
loop {
// Still ready
let mut writable = task::spawn(client.writable());
assert_ready_ok!(writable.poll());
match client.try_write(DATA) {
Ok(n) => written.extend(&DATA[..n]),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
break;
}
Err(e) => panic!("error = {:?}", e),
}
}
{
// Write buffer full
let mut writable = task::spawn(client.writable());
assert_pending!(writable.poll());
// Drain the socket from the server end
let mut read = Vec::with_capacity(written.len());
let mut i = 0;
while i < read.capacity() {
server.readable().await.unwrap();
match server.try_read_buf(&mut read) {
Ok(n) => i += n,
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
Err(e) => panic!("error = {:?}", e),
}
}
assert_eq!(read, written);
}
// Now, we listen for shutdown
drop(client);
loop {
let ready = server.ready(Interest::READABLE).await.unwrap();
if ready.is_read_closed() {
return;
} else {
tokio::task::yield_now().await;
}
}
}

View File

@ -353,3 +353,90 @@ async fn try_send_to_recv_from() {
}
}
}
#[tokio::test]
async fn try_recv_buf() {
// Create listener
let server = UdpSocket::bind("127.0.0.1:0").await.unwrap();
// Create socket pair
let client = UdpSocket::bind("127.0.0.1:0").await.unwrap();
// Connect the two
client.connect(server.local_addr().unwrap()).await.unwrap();
server.connect(client.local_addr().unwrap()).await.unwrap();
for _ in 0..5 {
loop {
client.writable().await.unwrap();
match client.try_send(b"hello world") {
Ok(n) => {
assert_eq!(n, 11);
break;
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
Err(e) => panic!("{:?}", e),
}
}
loop {
server.readable().await.unwrap();
let mut buf = Vec::with_capacity(512);
match server.try_recv_buf(&mut buf) {
Ok(n) => {
assert_eq!(n, 11);
assert_eq!(&buf[0..11], &b"hello world"[..]);
break;
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
Err(e) => panic!("{:?}", e),
}
}
}
}
#[tokio::test]
async fn try_recv_buf_from() {
// Create listener
let server = UdpSocket::bind("127.0.0.1:0").await.unwrap();
let saddr = server.local_addr().unwrap();
// Create socket pair
let client = UdpSocket::bind("127.0.0.1:0").await.unwrap();
let caddr = client.local_addr().unwrap();
for _ in 0..5 {
loop {
client.writable().await.unwrap();
match client.try_send_to(b"hello world", saddr) {
Ok(n) => {
assert_eq!(n, 11);
break;
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
Err(e) => panic!("{:?}", e),
}
}
loop {
server.readable().await.unwrap();
let mut buf = Vec::with_capacity(512);
match server.try_recv_buf_from(&mut buf) {
Ok((n, addr)) => {
assert_eq!(n, 11);
assert_eq!(addr, caddr);
assert_eq!(&buf[0..11], &b"hello world"[..]);
break;
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
Err(e) => panic!("{:?}", e),
}
}
}
}

View File

@ -230,3 +230,95 @@ async fn try_send_to_recv_from() -> std::io::Result<()> {
Ok(())
}
#[tokio::test]
async fn try_recv_buf_from() -> std::io::Result<()> {
let dir = tempfile::tempdir().unwrap();
let server_path = dir.path().join("server.sock");
let client_path = dir.path().join("client.sock");
// Create listener
let server = UnixDatagram::bind(&server_path)?;
// Create socket pair
let client = UnixDatagram::bind(&client_path)?;
for _ in 0..5 {
loop {
client.writable().await?;
match client.try_send_to(b"hello world", &server_path) {
Ok(n) => {
assert_eq!(n, 11);
break;
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
Err(e) => panic!("{:?}", e),
}
}
loop {
server.readable().await?;
let mut buf = Vec::with_capacity(512);
match server.try_recv_buf_from(&mut buf) {
Ok((n, addr)) => {
assert_eq!(n, 11);
assert_eq!(addr.as_pathname(), Some(client_path.as_ref()));
assert_eq!(&buf[0..11], &b"hello world"[..]);
break;
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
Err(e) => panic!("{:?}", e),
}
}
}
Ok(())
}
// Even though we use sync non-blocking io we still need a reactor.
#[tokio::test]
async fn try_recv_buf_never_block() -> io::Result<()> {
let payload = b"PAYLOAD";
let mut count = 0;
let (dgram1, dgram2) = UnixDatagram::pair()?;
// Send until we hit the OS `net.unix.max_dgram_qlen`.
loop {
dgram1.writable().await.unwrap();
match dgram1.try_send(payload) {
Err(err) => match err.kind() {
io::ErrorKind::WouldBlock | io::ErrorKind::Other => break,
_ => unreachable!("unexpected error {:?}", err),
},
Ok(len) => {
assert_eq!(len, payload.len());
}
}
count += 1;
}
// Read every dgram we sent.
while count > 0 {
let mut recv_buf = Vec::with_capacity(16);
dgram2.readable().await.unwrap();
let len = dgram2.try_recv_buf(&mut recv_buf)?;
assert_eq!(len, payload.len());
assert_eq!(payload, &recv_buf[..len]);
count -= 1;
}
let mut recv_buf = vec![0; 16];
let err = dgram2.try_recv_from(&mut recv_buf).unwrap_err();
match err.kind() {
io::ErrorKind::WouldBlock => (),
_ => unreachable!("unexpected error {:?}", err),
}
Ok(())
}

View File

@ -252,3 +252,85 @@ fn write_until_pending(stream: &mut UnixStream) {
}
}
}
#[tokio::test]
async fn try_read_buf() -> std::io::Result<()> {
let msg = b"hello world";
let dir = tempfile::tempdir()?;
let bind_path = dir.path().join("bind.sock");
// Create listener
let listener = UnixListener::bind(&bind_path)?;
// Create socket pair
let client = UnixStream::connect(&bind_path).await?;
let (server, _) = listener.accept().await?;
let mut written = msg.to_vec();
// Track the server receiving data
let mut readable = task::spawn(server.readable());
assert_pending!(readable.poll());
// Write data.
client.writable().await?;
assert_eq!(msg.len(), client.try_write(msg)?);
// The task should be notified
while !readable.is_woken() {
tokio::task::yield_now().await;
}
// Fill the write buffer
loop {
// Still ready
let mut writable = task::spawn(client.writable());
assert_ready_ok!(writable.poll());
match client.try_write(msg) {
Ok(n) => written.extend(&msg[..n]),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
break;
}
Err(e) => panic!("error = {:?}", e),
}
}
{
// Write buffer full
let mut writable = task::spawn(client.writable());
assert_pending!(writable.poll());
// Drain the socket from the server end
let mut read = Vec::with_capacity(written.len());
let mut i = 0;
while i < read.capacity() {
server.readable().await?;
match server.try_read_buf(&mut read) {
Ok(n) => i += n,
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
Err(e) => panic!("error = {:?}", e),
}
}
assert_eq!(read, written);
}
// Now, we listen for shutdown
drop(client);
loop {
let ready = server.ready(Interest::READABLE).await?;
if ready.is_read_closed() {
break;
} else {
tokio::task::yield_now().await;
}
}
Ok(())
}