net: add ReadHalf::{poll,poll_peak} (#2151)

The `&mut self` requirements for `TcpStream` methods ensure that there are at
most two tasks using the stream--one for reading and one for writing.

`TcpStream::split` allows two separate tasks to hold a reference to a single
`TcpStream`. `TcpStream::{peek,poll_peek}` only poll for read readiness, and
therefore are safe to use with a `ReadHalf`.

Instead of duplicating `TcpStream::poll_peek`, a private method is now used by
both `poll_peek` methods that uses the fact that only a `&TcpStream` is
required.

Closes #2136
This commit is contained in:
Kevin Leimkuhler 2020-01-22 13:22:10 -08:00 committed by Carl Lerche
parent 5fe2df0fba
commit 7f580071f3
3 changed files with 124 additions and 1 deletions

View File

@ -8,6 +8,7 @@
//! split has no associated overhead and enforces all invariants at the type
//! level.
use crate::future::poll_fn;
use crate::io::{AsyncRead, AsyncWrite};
use crate::net::TcpStream;
@ -33,6 +34,79 @@ pub(crate) fn split(stream: &mut TcpStream) -> (ReadHalf<'_>, WriteHalf<'_>) {
(ReadHalf(&*stream), WriteHalf(&*stream))
}
impl ReadHalf<'_> {
/// Attempt to receive data on the socket, without removing that data from
/// the queue, registering the current task for wakeup if data is not yet
/// available.
///
/// See the [`TcpStream::poll_peek`] level documenation for more details.
///
/// # Examples
///
/// ```no_run
/// use tokio::io;
/// use tokio::net::TcpStream;
///
/// use futures::future::poll_fn;
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// let mut stream = TcpStream::connect("127.0.0.1:8000").await?;
/// let (mut read_half, _) = stream.split();
/// let mut buf = [0; 10];
///
/// poll_fn(|cx| {
/// read_half.poll_peek(cx, &mut buf)
/// }).await?;
///
/// Ok(())
/// }
/// ```
///
/// [`TcpStream::poll_peek`]: TcpStream::poll_peek
pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
self.0.poll_peek2(cx, buf)
}
/// Receives data on the socket from the remote address to which it is
/// connected, without removing that data from the queue. On success,
/// returns the number of bytes peeked.
///
/// See the [`TcpStream::peek`] level documenation for more details.
///
/// # Examples
///
/// ```no_run
/// use tokio::net::TcpStream;
/// use tokio::prelude::*;
/// use std::error::Error;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn Error>> {
/// // Connect to a peer
/// let mut stream = TcpStream::connect("127.0.0.1:8080").await?;
/// let (mut read_half, _) = stream.split();
///
/// let mut b1 = [0; 10];
/// let mut b2 = [0; 10];
///
/// // Peek at the data
/// let n = read_half.peek(&mut b1).await?;
///
/// // Read the data
/// assert_eq!(n, read_half.read(&mut b2[..n]).await?);
/// assert_eq!(&b1[..n], &b2[..n]);
///
/// Ok(())
/// }
/// ```
///
/// [`TcpStream::peek`]: TcpStream::peek
pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
poll_fn(|cx| self.poll_peek(cx, buf)).await
}
}
impl AsyncRead for ReadHalf<'_> {
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
false

View File

@ -258,6 +258,14 @@ impl TcpStream {
/// }
/// ```
pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
self.poll_peek2(cx, buf)
}
pub(super) fn poll_peek2(
&self,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;
match self.io.get_ref().peek(buf) {

View File

@ -1 +1,42 @@
// TODO: write tests using TcpStream::split()
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]
use std::io::Result;
use std::io::{Read, Write};
use std::{net, thread};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
#[tokio::test]
async fn split() -> Result<()> {
const MSG: &[u8] = b"split";
let listener = net::TcpListener::bind("127.0.0.1:0")?;
let addr = listener.local_addr()?;
let handle = thread::spawn(move || {
let (mut stream, _) = listener.accept().unwrap();
stream.write(MSG).unwrap();
let mut read_buf = [0u8; 32];
let read_len = stream.read(&mut read_buf).unwrap();
assert_eq!(&read_buf[..read_len], MSG);
});
let mut stream = TcpStream::connect(&addr).await?;
let (mut read_half, mut write_half) = stream.split();
let mut read_buf = [0u8; 32];
let peek_len1 = read_half.peek(&mut read_buf[..]).await?;
let peek_len2 = read_half.peek(&mut read_buf[..]).await?;
assert_eq!(peek_len1, peek_len2);
let read_len = read_half.read(&mut read_buf[..]).await?;
assert_eq!(peek_len1, read_len);
assert_eq!(&read_buf[..read_len], MSG);
write_half.write(MSG).await?;
handle.join().unwrap();
Ok(())
}