mirror of
https://github.com/tokio-rs/tokio.git
synced 2025-09-28 12:10:37 +00:00
io: add AsyncReadExt::{chain, take} (#1484)
This commit is contained in:
parent
a791f4a758
commit
24fb33e012
@ -1,12 +1,27 @@
|
||||
use crate::io::chain::{chain, Chain};
|
||||
use crate::io::copy::{copy, Copy};
|
||||
use crate::io::read::{read, Read};
|
||||
use crate::io::read_exact::{read_exact, ReadExact};
|
||||
use crate::io::read_to_end::{read_to_end, ReadToEnd};
|
||||
use crate::io::read_to_string::{read_to_string, ReadToString};
|
||||
use crate::io::take::{take, Take};
|
||||
use crate::{AsyncRead, AsyncWrite};
|
||||
|
||||
/// An extension trait which adds utility methods to `AsyncRead` types.
|
||||
pub trait AsyncReadExt: AsyncRead {
|
||||
/// Creates an adaptor which will chain this stream with another.
|
||||
///
|
||||
/// The returned `AsyncRead` instance will first read all bytes from this object
|
||||
/// until EOF is encountered. Afterwards the output is equivalent to the
|
||||
/// output of `next`.
|
||||
fn chain<R>(self, next: R) -> Chain<Self, R>
|
||||
where
|
||||
Self: Sized,
|
||||
R: AsyncRead,
|
||||
{
|
||||
chain(self, next)
|
||||
}
|
||||
|
||||
/// Copy all data from `self` into the provided `AsyncWrite`.
|
||||
///
|
||||
/// The returned future will copy all the bytes read from `reader` into the
|
||||
@ -63,6 +78,15 @@ pub trait AsyncReadExt: AsyncRead {
|
||||
{
|
||||
read_to_string(self, dst)
|
||||
}
|
||||
|
||||
/// Creates an AsyncRead adapter which will read at most `limit` bytes
|
||||
/// from the underlying reader.
|
||||
fn take(self, limit: u64) -> Take<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
take(self, limit)
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + ?Sized> AsyncReadExt for R {}
|
||||
|
142
tokio-io/src/io/chain.rs
Normal file
142
tokio-io/src/io/chain.rs
Normal file
@ -0,0 +1,142 @@
|
||||
use crate::{AsyncBufRead, AsyncRead};
|
||||
use futures_core::ready;
|
||||
use pin_utils::{unsafe_pinned, unsafe_unpinned};
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
/// Stream for the [`chain`](super::AsyncReadExt::chain) method.
|
||||
#[must_use = "streams do nothing unless polled"]
|
||||
pub struct Chain<T, U> {
|
||||
first: T,
|
||||
second: U,
|
||||
done_first: bool,
|
||||
}
|
||||
|
||||
impl<T, U> Unpin for Chain<T, U>
|
||||
where
|
||||
T: Unpin,
|
||||
U: Unpin,
|
||||
{
|
||||
}
|
||||
|
||||
pub(super) fn chain<T, U>(first: T, second: U) -> Chain<T, U>
|
||||
where
|
||||
T: AsyncRead,
|
||||
U: AsyncRead,
|
||||
{
|
||||
Chain {
|
||||
first,
|
||||
second,
|
||||
done_first: false,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> Chain<T, U>
|
||||
where
|
||||
T: AsyncRead,
|
||||
U: AsyncRead,
|
||||
{
|
||||
unsafe_pinned!(first: T);
|
||||
unsafe_pinned!(second: U);
|
||||
unsafe_unpinned!(done_first: bool);
|
||||
|
||||
/// Gets references to the underlying readers in this `Chain`.
|
||||
pub fn get_ref(&self) -> (&T, &U) {
|
||||
(&self.first, &self.second)
|
||||
}
|
||||
|
||||
/// Gets mutable references to the underlying readers in this `Chain`.
|
||||
///
|
||||
/// Care should be taken to avoid modifying the internal I/O state of the
|
||||
/// underlying readers as doing so may corrupt the internal state of this
|
||||
/// `Chain`.
|
||||
pub fn get_mut(&mut self) -> (&mut T, &mut U) {
|
||||
(&mut self.first, &mut self.second)
|
||||
}
|
||||
|
||||
/// Gets pinned mutable references to the underlying readers in this `Chain`.
|
||||
///
|
||||
/// Care should be taken to avoid modifying the internal I/O state of the
|
||||
/// underlying readers as doing so may corrupt the internal state of this
|
||||
/// `Chain`.
|
||||
pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>) {
|
||||
unsafe {
|
||||
let Self { first, second, .. } = self.get_unchecked_mut();
|
||||
(Pin::new_unchecked(first), Pin::new_unchecked(second))
|
||||
}
|
||||
}
|
||||
|
||||
/// Consumes the `Chain`, returning the wrapped readers.
|
||||
pub fn into_inner(self) -> (T, U) {
|
||||
(self.first, self.second)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> fmt::Debug for Chain<T, U>
|
||||
where
|
||||
T: fmt::Debug,
|
||||
U: fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Chain")
|
||||
.field("t", &self.first)
|
||||
.field("u", &self.second)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> AsyncRead for Chain<T, U>
|
||||
where
|
||||
T: AsyncRead,
|
||||
U: AsyncRead,
|
||||
{
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
if !self.done_first {
|
||||
match ready!(self.as_mut().first().poll_read(cx, buf)?) {
|
||||
0 if !buf.is_empty() => *self.as_mut().done_first() = true,
|
||||
n => return Poll::Ready(Ok(n)),
|
||||
}
|
||||
}
|
||||
self.second().poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> AsyncBufRead for Chain<T, U>
|
||||
where
|
||||
T: AsyncBufRead,
|
||||
U: AsyncBufRead,
|
||||
{
|
||||
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
|
||||
let Self {
|
||||
first,
|
||||
second,
|
||||
done_first,
|
||||
} = unsafe { self.get_unchecked_mut() };
|
||||
let first = unsafe { Pin::new_unchecked(first) };
|
||||
let second = unsafe { Pin::new_unchecked(second) };
|
||||
|
||||
if !*done_first {
|
||||
match ready!(first.poll_fill_buf(cx)?) {
|
||||
buf if buf.is_empty() => {
|
||||
*done_first = true;
|
||||
}
|
||||
buf => return Poll::Ready(Ok(buf)),
|
||||
}
|
||||
}
|
||||
second.poll_fill_buf(cx)
|
||||
}
|
||||
|
||||
fn consume(self: Pin<&mut Self>, amt: usize) {
|
||||
if !self.done_first {
|
||||
self.first().consume(amt)
|
||||
} else {
|
||||
self.second().consume(amt)
|
||||
}
|
||||
}
|
||||
}
|
@ -3,6 +3,7 @@ mod async_read_ext;
|
||||
mod async_write_ext;
|
||||
mod buf_reader;
|
||||
mod buf_writer;
|
||||
mod chain;
|
||||
mod copy;
|
||||
mod flush;
|
||||
mod lines;
|
||||
@ -13,6 +14,7 @@ mod read_to_end;
|
||||
mod read_to_string;
|
||||
mod read_until;
|
||||
mod shutdown;
|
||||
mod take;
|
||||
mod write;
|
||||
mod write_all;
|
||||
|
||||
|
120
tokio-io/src/io/take.rs
Normal file
120
tokio-io/src/io/take.rs
Normal file
@ -0,0 +1,120 @@
|
||||
use crate::{AsyncBufRead, AsyncRead};
|
||||
use futures_core::ready;
|
||||
use pin_utils::{unsafe_pinned, unsafe_unpinned};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::{cmp, io};
|
||||
|
||||
/// Stream for the [`take`](super::AsyncReadExt::take) method.
|
||||
#[derive(Debug)]
|
||||
#[must_use = "streams do nothing unless you `.await` or poll them"]
|
||||
pub struct Take<R> {
|
||||
inner: R,
|
||||
// Add '_' to avoid conflicts with `limit` method.
|
||||
limit_: u64,
|
||||
}
|
||||
|
||||
impl<R: Unpin> Unpin for Take<R> {}
|
||||
|
||||
pub(super) fn take<R: AsyncRead>(inner: R, limit: u64) -> Take<R> {
|
||||
Take {
|
||||
inner,
|
||||
limit_: limit,
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead> Take<R> {
|
||||
unsafe_pinned!(inner: R);
|
||||
unsafe_unpinned!(limit_: u64);
|
||||
|
||||
/// Returns the remaining number of bytes that can be
|
||||
/// read before this instance will return EOF.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// This instance may reach `EOF` after reading fewer bytes than indicated by
|
||||
/// this method if the underlying [`AsyncRead`] instance reaches EOF.
|
||||
pub fn limit(&self) -> u64 {
|
||||
self.limit_
|
||||
}
|
||||
|
||||
/// Sets the number of bytes that can be read before this instance will
|
||||
/// return EOF. This is the same as constructing a new `Take` instance, so
|
||||
/// the amount of bytes read and the previous limit value don't matter when
|
||||
/// calling this method.
|
||||
pub fn set_limit(&mut self, limit: u64) {
|
||||
self.limit_ = limit
|
||||
}
|
||||
|
||||
/// Gets a reference to the underlying reader.
|
||||
pub fn get_ref(&self) -> &R {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
/// Gets a mutable reference to the underlying reader.
|
||||
///
|
||||
/// Care should be taken to avoid modifying the internal I/O state of the
|
||||
/// underlying reader as doing so may corrupt the internal limit of this
|
||||
/// `Take`.
|
||||
pub fn get_mut(&mut self) -> &mut R {
|
||||
&mut self.inner
|
||||
}
|
||||
|
||||
/// Gets a pinned mutable reference to the underlying reader.
|
||||
///
|
||||
/// Care should be taken to avoid modifying the internal I/O state of the
|
||||
/// underlying reader as doing so may corrupt the internal limit of this
|
||||
/// `Take`.
|
||||
pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
|
||||
self.inner()
|
||||
}
|
||||
|
||||
/// Consumes the `Take`, returning the wrapped reader.
|
||||
pub fn into_inner(self) -> R {
|
||||
self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead> AsyncRead for Take<R> {
|
||||
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
|
||||
self.inner.prepare_uninitialized_buffer(buf)
|
||||
}
|
||||
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
if self.limit_ == 0 {
|
||||
return Poll::Ready(Ok(0));
|
||||
}
|
||||
|
||||
let max = std::cmp::min(buf.len() as u64, self.limit_) as usize;
|
||||
let n = ready!(self.as_mut().inner().poll_read(cx, &mut buf[..max]))?;
|
||||
*self.as_mut().limit_() -= n as u64;
|
||||
Poll::Ready(Ok(n))
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncBufRead> AsyncBufRead for Take<R> {
|
||||
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
|
||||
let Self { inner, limit_ } = unsafe { self.get_unchecked_mut() };
|
||||
let inner = unsafe { Pin::new_unchecked(inner) };
|
||||
|
||||
// Don't call into inner reader at all at EOF because it may still block
|
||||
if *limit_ == 0 {
|
||||
return Poll::Ready(Ok(&[]));
|
||||
}
|
||||
|
||||
let buf = ready!(inner.poll_fill_buf(cx)?);
|
||||
let cap = cmp::min(buf.len() as u64, *limit_) as usize;
|
||||
Poll::Ready(Ok(&buf[..cap]))
|
||||
}
|
||||
|
||||
fn consume(mut self: Pin<&mut Self>, amt: usize) {
|
||||
// Don't let callers reset the limit by passing an overlarge value
|
||||
let amt = cmp::min(amt as u64, self.limit_) as usize;
|
||||
*self.as_mut().limit_() -= amt as u64;
|
||||
self.inner().consume(amt);
|
||||
}
|
||||
}
|
16
tokio-io/tests/chain.rs
Normal file
16
tokio-io/tests/chain.rs
Normal file
@ -0,0 +1,16 @@
|
||||
#![warn(rust_2018_idioms)]
|
||||
#![feature(async_await)]
|
||||
|
||||
use tokio_io::AsyncReadExt;
|
||||
use tokio_test::assert_ok;
|
||||
|
||||
#[tokio::test]
|
||||
async fn chain() {
|
||||
let mut buf = Vec::new();
|
||||
let rd1: &[u8] = b"hello ";
|
||||
let rd2: &[u8] = b"world";
|
||||
|
||||
let mut rd = rd1.chain(rd2);
|
||||
assert_ok!(rd.read_to_end(&mut buf).await);
|
||||
assert_eq!(buf, b"hello world");
|
||||
}
|
16
tokio-io/tests/take.rs
Normal file
16
tokio-io/tests/take.rs
Normal file
@ -0,0 +1,16 @@
|
||||
#![warn(rust_2018_idioms)]
|
||||
#![feature(async_await)]
|
||||
|
||||
use tokio_io::AsyncReadExt;
|
||||
use tokio_test::assert_ok;
|
||||
|
||||
#[tokio::test]
|
||||
async fn take() {
|
||||
let mut buf = [0; 6];
|
||||
let rd: &[u8] = b"hello world";
|
||||
|
||||
let mut rd = rd.take(4);
|
||||
let n = assert_ok!(rd.read(&mut buf).await);
|
||||
assert_eq!(n, 4);
|
||||
assert_eq!(&buf, &b"hell\0\0"[..]);
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user