// Mirror of https://github.com/tokio-rs/tokio.git (synced 2025-09-25 12:00:35 +00:00; 195 lines, 5.4 KiB, Rust).
use std::{
|
|
future::poll_fn,
|
|
io::IoSlice,
|
|
pin::Pin,
|
|
task::{Context, Poll},
|
|
};
|
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
|
|
use tokio_util::io::{InspectReader, InspectWriter};
|
|
|
|
/// An `AsyncRead` implementation that yields at most one byte per call, to catch
/// out callers who don't allow for `buf` being part-filled before the call.
///
/// Bytes are handed out with `Vec::pop`, i.e. from the end of `contents`, so
/// they come out in reverse order.
#[derive(Debug)]
struct SmallReader {
    /// Bytes still to be read; consumed from the end.
    contents: Vec<u8>,
}

// `Vec<u8>` is already `Unpin`, so this impl is redundant with the auto trait.
// `poll_read` relies on `Unpin` to get `&mut` access through `Pin<&mut Self>`.
impl Unpin for SmallReader {}
|
|
|
|
impl AsyncRead for SmallReader {
|
|
fn poll_read(
|
|
mut self: Pin<&mut Self>,
|
|
_cx: &mut Context<'_>,
|
|
buf: &mut ReadBuf<'_>,
|
|
) -> Poll<std::io::Result<()>> {
|
|
if let Some(byte) = self.contents.pop() {
|
|
buf.put_slice(&[byte])
|
|
}
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn read_tee() {
|
|
let contents = b"This could be really long, you know".to_vec();
|
|
let reader = SmallReader {
|
|
contents: contents.clone(),
|
|
};
|
|
let mut altout: Vec<u8> = Vec::new();
|
|
let mut teeout = Vec::new();
|
|
{
|
|
let mut tee = InspectReader::new(reader, |bytes| altout.extend(bytes));
|
|
tee.read_to_end(&mut teeout).await.unwrap();
|
|
}
|
|
assert_eq!(teeout, altout);
|
|
assert_eq!(altout.len(), contents.len());
|
|
}
|
|
|
|
/// An `AsyncWrite` implementation that accepts one byte per `poll_write` call,
/// and whose `poll_write_vectored` consumes the whole of the first buffer plus
/// one byte from the second.
///
/// This is designed to catch bugs in handling partially written buffers.
#[derive(Debug)]
struct SmallWriter {
    /// Every byte accepted so far, in the order it was written.
    contents: Vec<u8>,
}

// `Vec<u8>` is already `Unpin`, so this impl is redundant with the auto trait.
// The `AsyncWrite` impl relies on `Unpin` to mutate `contents` through
// `Pin<&mut Self>`.
impl Unpin for SmallWriter {}
|
|
|
|
impl AsyncWrite for SmallWriter {
|
|
fn poll_write(
|
|
mut self: Pin<&mut Self>,
|
|
_cx: &mut Context<'_>,
|
|
buf: &[u8],
|
|
) -> Poll<Result<usize, std::io::Error>> {
|
|
// Just write one byte at a time
|
|
if buf.is_empty() {
|
|
return Poll::Ready(Ok(0));
|
|
}
|
|
self.contents.push(buf[0]);
|
|
Poll::Ready(Ok(1))
|
|
}
|
|
|
|
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
|
|
fn poll_shutdown(
|
|
self: Pin<&mut Self>,
|
|
_cx: &mut Context<'_>,
|
|
) -> Poll<Result<(), std::io::Error>> {
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
|
|
fn poll_write_vectored(
|
|
mut self: Pin<&mut Self>,
|
|
_cx: &mut Context<'_>,
|
|
bufs: &[IoSlice<'_>],
|
|
) -> Poll<Result<usize, std::io::Error>> {
|
|
// Write all of the first buffer, then one byte from the second buffer
|
|
// This should trip up anything that doesn't correctly handle multiple
|
|
// buffers.
|
|
if bufs.is_empty() {
|
|
return Poll::Ready(Ok(0));
|
|
}
|
|
let mut written_len = bufs[0].len();
|
|
self.contents.extend_from_slice(&bufs[0]);
|
|
|
|
if bufs.len() > 1 {
|
|
let buf = bufs[1];
|
|
if !buf.is_empty() {
|
|
written_len += 1;
|
|
self.contents.push(buf[0]);
|
|
}
|
|
}
|
|
Poll::Ready(Ok(written_len))
|
|
}
|
|
|
|
fn is_write_vectored(&self) -> bool {
|
|
true
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn write_tee() {
|
|
let mut altout: Vec<u8> = Vec::new();
|
|
let mut writeout = SmallWriter {
|
|
contents: Vec::new(),
|
|
};
|
|
{
|
|
let mut tee = InspectWriter::new(&mut writeout, |bytes| altout.extend(bytes));
|
|
tee.write_all(b"A testing string, very testing")
|
|
.await
|
|
.unwrap();
|
|
}
|
|
assert_eq!(altout, writeout.contents);
|
|
}
|
|
|
|
// This is inefficient, but works well enough for test use.
|
|
// If you want something similar for real code, you'll want to avoid all the
|
|
// fun of manipulating `bufs` - ideally, by the time you read this,
|
|
// IoSlice::advance_slices will be stable, and you can use that.
|
|
async fn write_all_vectored<W: AsyncWrite + Unpin>(
|
|
mut writer: W,
|
|
mut bufs: Vec<Vec<u8>>,
|
|
) -> Result<usize, std::io::Error> {
|
|
let mut res = 0;
|
|
while !bufs.is_empty() {
|
|
let mut written = poll_fn(|cx| {
|
|
let bufs: Vec<IoSlice> = bufs.iter().map(|v| IoSlice::new(v)).collect();
|
|
Pin::new(&mut writer).poll_write_vectored(cx, &bufs)
|
|
})
|
|
.await?;
|
|
res += written;
|
|
while written > 0 {
|
|
let buf_len = bufs[0].len();
|
|
if buf_len <= written {
|
|
bufs.remove(0);
|
|
written -= buf_len;
|
|
} else {
|
|
let buf = &mut bufs[0];
|
|
let drain_len = written.min(buf.len());
|
|
buf.drain(..drain_len);
|
|
written -= drain_len;
|
|
}
|
|
}
|
|
}
|
|
Ok(res)
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn write_tee_vectored() {
|
|
let mut altout: Vec<u8> = Vec::new();
|
|
let mut writeout = SmallWriter {
|
|
contents: Vec::new(),
|
|
};
|
|
let original = b"A very long string split up";
|
|
let bufs: Vec<Vec<u8>> = original
|
|
.split(|b| b.is_ascii_whitespace())
|
|
.map(Vec::from)
|
|
.collect();
|
|
assert!(bufs.len() > 1);
|
|
let expected: Vec<u8> = {
|
|
let mut out = Vec::new();
|
|
for item in &bufs {
|
|
out.extend_from_slice(item)
|
|
}
|
|
out
|
|
};
|
|
{
|
|
let mut bufcount = 0;
|
|
let tee = InspectWriter::new(&mut writeout, |bytes| {
|
|
bufcount += 1;
|
|
altout.extend(bytes)
|
|
});
|
|
|
|
assert!(tee.is_write_vectored());
|
|
|
|
write_all_vectored(tee, bufs.clone()).await.unwrap();
|
|
|
|
assert!(bufcount >= bufs.len());
|
|
}
|
|
assert_eq!(altout, writeout.contents);
|
|
assert_eq!(writeout.contents, expected);
|
|
}
|