mirror of
https://github.com/tokio-rs/tokio.git
synced 2025-09-25 12:00:35 +00:00
io: wrappers for inspecting data on IO resources (#5033)
This commit is contained in:
parent
9b87daad7e
commit
96fab053ab
134
tokio-util/src/io/inspect.rs
Normal file
134
tokio-util/src/io/inspect.rs
Normal file
@ -0,0 +1,134 @@
|
||||
use futures_core::ready;
|
||||
use pin_project_lite::pin_project;
|
||||
use std::io::{IoSlice, Result};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
pin_project! {
|
||||
/// An adapter that lets you inspect the data that's being read.
|
||||
///
|
||||
/// This is useful for things like hashing data as it's read in.
|
||||
pub struct InspectReader<R, F> {
|
||||
#[pin]
|
||||
reader: R,
|
||||
f: F,
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, F> InspectReader<R, F> {
|
||||
/// Create a new InspectReader, wrapping `reader` and calling `f` for the
|
||||
/// new data supplied by each read call.
|
||||
///
|
||||
/// The closure will only be called with an empty slice if the inner reader
|
||||
/// returns without reading data into the buffer. This happens at EOF, or if
|
||||
/// `poll_read` is called with a zero-size buffer.
|
||||
pub fn new(reader: R, f: F) -> InspectReader<R, F>
|
||||
where
|
||||
R: AsyncRead,
|
||||
F: FnMut(&[u8]),
|
||||
{
|
||||
InspectReader { reader, f }
|
||||
}
|
||||
|
||||
/// Consumes the `InspectReader`, returning the wrapped reader
|
||||
pub fn into_inner(self) -> R {
|
||||
self.reader
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead, F: FnMut(&[u8])> AsyncRead for InspectReader<R, F> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<Result<()>> {
|
||||
let me = self.project();
|
||||
let filled_length = buf.filled().len();
|
||||
ready!(me.reader.poll_read(cx, buf))?;
|
||||
(me.f)(&buf.filled()[filled_length..]);
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// An adapter that lets you inspect the data that's being written.
|
||||
///
|
||||
/// This is useful for things like hashing data as it's written out.
|
||||
pub struct InspectWriter<W, F> {
|
||||
#[pin]
|
||||
writer: W,
|
||||
f: F,
|
||||
}
|
||||
}
|
||||
|
||||
impl<W, F> InspectWriter<W, F> {
|
||||
/// Create a new InspectWriter, wrapping `write` and calling `f` for the
|
||||
/// data successfully written by each write call.
|
||||
///
|
||||
/// The closure `f` will never be called with an empty slice. A vectored
|
||||
/// write can result in multiple calls to `f` - at most one call to `f` per
|
||||
/// buffer supplied to `poll_write_vectored`.
|
||||
pub fn new(writer: W, f: F) -> InspectWriter<W, F>
|
||||
where
|
||||
W: AsyncWrite,
|
||||
F: FnMut(&[u8]),
|
||||
{
|
||||
InspectWriter { writer, f }
|
||||
}
|
||||
|
||||
/// Consumes the `InspectWriter`, returning the wrapped writer
|
||||
pub fn into_inner(self) -> W {
|
||||
self.writer
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite, F: FnMut(&[u8])> AsyncWrite for InspectWriter<W, F> {
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
|
||||
let me = self.project();
|
||||
let res = me.writer.poll_write(cx, buf);
|
||||
if let Poll::Ready(Ok(count)) = res {
|
||||
if count != 0 {
|
||||
(me.f)(&buf[..count]);
|
||||
}
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
let me = self.project();
|
||||
me.writer.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
let me = self.project();
|
||||
me.writer.poll_shutdown(cx)
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[IoSlice<'_>],
|
||||
) -> Poll<Result<usize>> {
|
||||
let me = self.project();
|
||||
let res = me.writer.poll_write_vectored(cx, bufs);
|
||||
if let Poll::Ready(Ok(mut count)) = res {
|
||||
for buf in bufs {
|
||||
if count == 0 {
|
||||
break;
|
||||
}
|
||||
let size = count.min(buf.len());
|
||||
if size != 0 {
|
||||
(me.f)(&buf[..size]);
|
||||
count -= size;
|
||||
}
|
||||
}
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.writer.is_write_vectored()
|
||||
}
|
||||
}
|
@ -10,14 +10,17 @@
|
||||
//! [`Body`]: https://docs.rs/hyper/0.13/hyper/struct.Body.html
|
||||
//! [`AsyncRead`]: tokio::io::AsyncRead
|
||||
|
||||
mod inspect;
|
||||
mod read_buf;
|
||||
mod reader_stream;
|
||||
mod stream_reader;
|
||||
|
||||
cfg_io_util! {
|
||||
mod sync_bridge;
|
||||
pub use self::sync_bridge::SyncIoBridge;
|
||||
}
|
||||
|
||||
pub use self::inspect::{InspectReader, InspectWriter};
|
||||
pub use self::read_buf::read_buf;
|
||||
pub use self::reader_stream::ReaderStream;
|
||||
pub use self::stream_reader::StreamReader;
|
||||
|
194
tokio-util/tests/io_inspect.rs
Normal file
194
tokio-util/tests/io_inspect.rs
Normal file
@ -0,0 +1,194 @@
|
||||
use futures::future::poll_fn;
|
||||
use std::{
|
||||
io::IoSlice,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||
use tokio_util::io::{InspectReader, InspectWriter};
|
||||
|
||||
/// An AsyncRead implementation that works byte-by-byte, to catch out callers
|
||||
/// who don't allow for `buf` being part-filled before the call
|
||||
struct SmallReader {
|
||||
contents: Vec<u8>,
|
||||
}
|
||||
|
||||
impl Unpin for SmallReader {}
|
||||
|
||||
impl AsyncRead for SmallReader {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
if let Some(byte) = self.contents.pop() {
|
||||
buf.put_slice(&[byte])
|
||||
}
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_tee() {
|
||||
let contents = b"This could be really long, you know".to_vec();
|
||||
let reader = SmallReader {
|
||||
contents: contents.clone(),
|
||||
};
|
||||
let mut altout: Vec<u8> = Vec::new();
|
||||
let mut teeout = Vec::new();
|
||||
{
|
||||
let mut tee = InspectReader::new(reader, |bytes| altout.extend(bytes));
|
||||
tee.read_to_end(&mut teeout).await.unwrap();
|
||||
}
|
||||
assert_eq!(teeout, altout);
|
||||
assert_eq!(altout.len(), contents.len());
|
||||
}
|
||||
|
||||
/// An AsyncWrite implementation that works byte-by-byte for poll_write, and
|
||||
/// that reads the whole of the first buffer plus one byte from the second in
|
||||
/// poll_write_vectored.
|
||||
///
|
||||
/// This is designed to catch bugs in handling partially written buffers
|
||||
#[derive(Debug)]
|
||||
struct SmallWriter {
|
||||
contents: Vec<u8>,
|
||||
}
|
||||
|
||||
impl Unpin for SmallWriter {}
|
||||
|
||||
impl AsyncWrite for SmallWriter {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
// Just write one byte at a time
|
||||
if buf.is_empty() {
|
||||
return Poll::Ready(Ok(0));
|
||||
}
|
||||
self.contents.push(buf[0]);
|
||||
Poll::Ready(Ok(1))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
bufs: &[IoSlice<'_>],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
// Write all of the first buffer, then one byte from the second buffer
|
||||
// This should trip up anything that doesn't correctly handle multiple
|
||||
// buffers.
|
||||
if bufs.is_empty() {
|
||||
return Poll::Ready(Ok(0));
|
||||
}
|
||||
let mut written_len = bufs[0].len();
|
||||
self.contents.extend_from_slice(&bufs[0]);
|
||||
|
||||
if bufs.len() > 1 {
|
||||
let buf = bufs[1];
|
||||
if !buf.is_empty() {
|
||||
written_len += 1;
|
||||
self.contents.push(buf[0]);
|
||||
}
|
||||
}
|
||||
Poll::Ready(Ok(written_len))
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_tee() {
|
||||
let mut altout: Vec<u8> = Vec::new();
|
||||
let mut writeout = SmallWriter {
|
||||
contents: Vec::new(),
|
||||
};
|
||||
{
|
||||
let mut tee = InspectWriter::new(&mut writeout, |bytes| altout.extend(bytes));
|
||||
tee.write_all(b"A testing string, very testing")
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
assert_eq!(altout, writeout.contents);
|
||||
}
|
||||
|
||||
// This is inefficient, but works well enough for test use.
|
||||
// If you want something similar for real code, you'll want to avoid all the
|
||||
// fun of manipulating `bufs` - ideally, by the time you read this,
|
||||
// IoSlice::advance_slices will be stable, and you can use that.
|
||||
async fn write_all_vectored<W: AsyncWrite + Unpin>(
|
||||
mut writer: W,
|
||||
mut bufs: Vec<Vec<u8>>,
|
||||
) -> Result<usize, std::io::Error> {
|
||||
let mut res = 0;
|
||||
while !bufs.is_empty() {
|
||||
let mut written = poll_fn(|cx| {
|
||||
let bufs: Vec<IoSlice> = bufs.iter().map(|v| IoSlice::new(v)).collect();
|
||||
Pin::new(&mut writer).poll_write_vectored(cx, &bufs)
|
||||
})
|
||||
.await?;
|
||||
res += written;
|
||||
while written > 0 {
|
||||
let buf_len = bufs[0].len();
|
||||
if buf_len <= written {
|
||||
bufs.remove(0);
|
||||
written -= buf_len;
|
||||
} else {
|
||||
let buf = &mut bufs[0];
|
||||
let drain_len = written.min(buf.len());
|
||||
buf.drain(..drain_len);
|
||||
written -= drain_len;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_tee_vectored() {
|
||||
let mut altout: Vec<u8> = Vec::new();
|
||||
let mut writeout = SmallWriter {
|
||||
contents: Vec::new(),
|
||||
};
|
||||
let original = b"A very long string split up";
|
||||
let bufs: Vec<Vec<u8>> = original
|
||||
.split(|b| b.is_ascii_whitespace())
|
||||
.map(Vec::from)
|
||||
.collect();
|
||||
assert!(bufs.len() > 1);
|
||||
let expected: Vec<u8> = {
|
||||
let mut out = Vec::new();
|
||||
for item in &bufs {
|
||||
out.extend_from_slice(item)
|
||||
}
|
||||
out
|
||||
};
|
||||
{
|
||||
let mut bufcount = 0;
|
||||
let tee = InspectWriter::new(&mut writeout, |bytes| {
|
||||
bufcount += 1;
|
||||
altout.extend(bytes)
|
||||
});
|
||||
|
||||
assert!(tee.is_write_vectored());
|
||||
|
||||
write_all_vectored(tee, bufs.clone()).await.unwrap();
|
||||
|
||||
assert!(bufcount >= bufs.len());
|
||||
}
|
||||
assert_eq!(altout, writeout.contents);
|
||||
assert_eq!(writeout.contents, expected);
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user