mirror of
https://github.com/tokio-rs/tokio.git
synced 2025-09-28 12:10:37 +00:00
io: efficient implementation of vectored writes for BufWriter (#3163)
This commit is contained in:
parent
38204f5fba
commit
0531549b6e
@ -2,7 +2,7 @@ use crate::io::util::DEFAULT_BUF_SIZE;
|
||||
use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
|
||||
|
||||
use pin_project_lite::pin_project;
|
||||
use std::io::{self, SeekFrom};
|
||||
use std::io::{self, IoSlice, SeekFrom};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::{cmp, fmt, mem};
|
||||
@ -268,6 +268,18 @@ impl<R: AsyncRead + AsyncWrite> AsyncWrite for BufReader<R> {
|
||||
self.get_pin_mut().poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[IoSlice<'_>],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
self.get_pin_mut().poll_write_vectored(cx, bufs)
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.get_ref().is_write_vectored()
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.get_pin_mut().poll_flush(cx)
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ use crate::io::util::{BufReader, BufWriter};
|
||||
use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
|
||||
|
||||
use pin_project_lite::pin_project;
|
||||
use std::io::{self, SeekFrom};
|
||||
use std::io::{self, IoSlice, SeekFrom};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
@ -127,6 +127,18 @@ impl<RW: AsyncRead + AsyncWrite> AsyncWrite for BufStream<RW> {
|
||||
self.project().inner.poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[IoSlice<'_>],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
self.project().inner.poll_write_vectored(cx, bufs)
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.inner.is_write_vectored()
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.project().inner.poll_flush(cx)
|
||||
}
|
||||
|
@ -3,7 +3,7 @@ use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
|
||||
|
||||
use pin_project_lite::pin_project;
|
||||
use std::fmt;
|
||||
use std::io::{self, SeekFrom, Write};
|
||||
use std::io::{self, IoSlice, SeekFrom, Write};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
@ -133,6 +133,72 @@ impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
mut bufs: &[IoSlice<'_>],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
if self.inner.is_write_vectored() {
|
||||
let total_len = bufs
|
||||
.iter()
|
||||
.fold(0usize, |acc, b| acc.saturating_add(b.len()));
|
||||
if total_len > self.buf.capacity() - self.buf.len() {
|
||||
ready!(self.as_mut().flush_buf(cx))?;
|
||||
}
|
||||
let me = self.as_mut().project();
|
||||
if total_len >= me.buf.capacity() {
|
||||
// It's more efficient to pass the slices directly to the
|
||||
// underlying writer than to buffer them.
|
||||
// The case when the total_len calculation saturates at
|
||||
// usize::MAX is also handled here.
|
||||
me.inner.poll_write_vectored(cx, bufs)
|
||||
} else {
|
||||
bufs.iter().for_each(|b| me.buf.extend_from_slice(b));
|
||||
Poll::Ready(Ok(total_len))
|
||||
}
|
||||
} else {
|
||||
// Remove empty buffers at the beginning of bufs.
|
||||
while bufs.first().map(|buf| buf.len()) == Some(0) {
|
||||
bufs = &bufs[1..];
|
||||
}
|
||||
if bufs.is_empty() {
|
||||
return Poll::Ready(Ok(0));
|
||||
}
|
||||
// Flush if the first buffer doesn't fit.
|
||||
let first_len = bufs[0].len();
|
||||
if first_len > self.buf.capacity() - self.buf.len() {
|
||||
ready!(self.as_mut().flush_buf(cx))?;
|
||||
debug_assert!(self.buf.is_empty());
|
||||
}
|
||||
let me = self.as_mut().project();
|
||||
if first_len >= me.buf.capacity() {
|
||||
// The slice is at least as large as the buffering capacity,
|
||||
// so it's better to write it directly, bypassing the buffer.
|
||||
debug_assert!(me.buf.is_empty());
|
||||
return me.inner.poll_write(cx, &bufs[0]);
|
||||
} else {
|
||||
me.buf.extend_from_slice(&bufs[0]);
|
||||
bufs = &bufs[1..];
|
||||
}
|
||||
let mut total_written = first_len;
|
||||
debug_assert!(total_written != 0);
|
||||
// Append the buffers that fit in the internal buffer.
|
||||
for buf in bufs {
|
||||
if buf.len() > me.buf.capacity() - me.buf.len() {
|
||||
break;
|
||||
} else {
|
||||
me.buf.extend_from_slice(buf);
|
||||
total_written += buf.len();
|
||||
}
|
||||
}
|
||||
Poll::Ready(Ok(total_written))
|
||||
}
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
ready!(self.as_mut().flush_buf(cx))?;
|
||||
self.get_pin_mut().poll_flush(cx)
|
||||
|
@ -8,6 +8,17 @@ use std::io::{self, Cursor};
|
||||
use std::pin::Pin;
|
||||
use tokio::io::{AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufWriter, SeekFrom};
|
||||
|
||||
use futures::future;
|
||||
use tokio_test::assert_ok;
|
||||
|
||||
use std::cmp;
|
||||
use std::io::IoSlice;
|
||||
|
||||
mod support {
|
||||
pub(crate) mod io_vec;
|
||||
}
|
||||
use support::io_vec::IoBufs;
|
||||
|
||||
struct MaybePending {
|
||||
inner: Vec<u8>,
|
||||
ready: bool,
|
||||
@ -47,6 +58,14 @@ impl AsyncWrite for MaybePending {
|
||||
}
|
||||
}
|
||||
|
||||
async fn write_vectored<W>(writer: &mut W, bufs: &[IoSlice<'_>]) -> io::Result<usize>
|
||||
where
|
||||
W: AsyncWrite + Unpin,
|
||||
{
|
||||
let mut writer = Pin::new(writer);
|
||||
future::poll_fn(|cx| writer.as_mut().poll_write_vectored(cx, bufs)).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn buf_writer() {
|
||||
let mut writer = BufWriter::with_capacity(2, Vec::new());
|
||||
@ -249,3 +268,270 @@ async fn maybe_pending_buf_writer_seek() {
|
||||
&[0, 1, 8, 9, 4, 5, 6, 7]
|
||||
);
|
||||
}
|
||||
|
||||
struct MockWriter {
|
||||
data: Vec<u8>,
|
||||
write_len: usize,
|
||||
vectored: bool,
|
||||
}
|
||||
|
||||
impl MockWriter {
|
||||
fn new(write_len: usize) -> Self {
|
||||
MockWriter {
|
||||
data: Vec::new(),
|
||||
write_len,
|
||||
vectored: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn vectored(write_len: usize) -> Self {
|
||||
MockWriter {
|
||||
data: Vec::new(),
|
||||
write_len,
|
||||
vectored: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn write_up_to(&mut self, buf: &[u8], limit: usize) -> usize {
|
||||
let len = cmp::min(buf.len(), limit);
|
||||
self.data.extend_from_slice(&buf[..len]);
|
||||
len
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for MockWriter {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
_: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
let this = self.get_mut();
|
||||
let n = this.write_up_to(buf, this.write_len);
|
||||
Ok(n).into()
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
_: &mut Context<'_>,
|
||||
bufs: &[IoSlice<'_>],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
let this = self.get_mut();
|
||||
let mut total_written = 0;
|
||||
for buf in bufs {
|
||||
let n = this.write_up_to(buf, this.write_len - total_written);
|
||||
total_written += n;
|
||||
if total_written == this.write_len {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(total_written).into()
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.vectored
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
Ok(()).into()
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
Ok(()).into()
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_vectored_empty_on_non_vectored() {
|
||||
let mut w = BufWriter::new(MockWriter::new(4));
|
||||
let n = assert_ok!(write_vectored(&mut w, &[]).await);
|
||||
assert_eq!(n, 0);
|
||||
|
||||
let io_vec = [IoSlice::new(&[]); 3];
|
||||
let n = assert_ok!(write_vectored(&mut w, &io_vec).await);
|
||||
assert_eq!(n, 0);
|
||||
|
||||
assert_ok!(w.flush().await);
|
||||
assert!(w.get_ref().data.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_vectored_empty_on_vectored() {
|
||||
let mut w = BufWriter::new(MockWriter::vectored(4));
|
||||
let n = assert_ok!(write_vectored(&mut w, &[]).await);
|
||||
assert_eq!(n, 0);
|
||||
|
||||
let io_vec = [IoSlice::new(&[]); 3];
|
||||
let n = assert_ok!(write_vectored(&mut w, &io_vec).await);
|
||||
assert_eq!(n, 0);
|
||||
|
||||
assert_ok!(w.flush().await);
|
||||
assert!(w.get_ref().data.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_vectored_basic_on_non_vectored() {
|
||||
let msg = b"foo bar baz";
|
||||
let bufs = [
|
||||
IoSlice::new(&msg[0..4]),
|
||||
IoSlice::new(&msg[4..8]),
|
||||
IoSlice::new(&msg[8..]),
|
||||
];
|
||||
let mut w = BufWriter::new(MockWriter::new(4));
|
||||
let n = assert_ok!(write_vectored(&mut w, &bufs).await);
|
||||
assert_eq!(n, msg.len());
|
||||
assert!(w.buffer() == &msg[..]);
|
||||
assert_ok!(w.flush().await);
|
||||
assert_eq!(w.get_ref().data, msg);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_vectored_basic_on_vectored() {
|
||||
let msg = b"foo bar baz";
|
||||
let bufs = [
|
||||
IoSlice::new(&msg[0..4]),
|
||||
IoSlice::new(&msg[4..8]),
|
||||
IoSlice::new(&msg[8..]),
|
||||
];
|
||||
let mut w = BufWriter::new(MockWriter::vectored(4));
|
||||
let n = assert_ok!(write_vectored(&mut w, &bufs).await);
|
||||
assert_eq!(n, msg.len());
|
||||
assert!(w.buffer() == &msg[..]);
|
||||
assert_ok!(w.flush().await);
|
||||
assert_eq!(w.get_ref().data, msg);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_vectored_large_total_on_non_vectored() {
|
||||
let msg = b"foo bar baz";
|
||||
let mut bufs = [
|
||||
IoSlice::new(&msg[0..4]),
|
||||
IoSlice::new(&msg[4..8]),
|
||||
IoSlice::new(&msg[8..]),
|
||||
];
|
||||
let io_vec = IoBufs::new(&mut bufs);
|
||||
let mut w = BufWriter::with_capacity(8, MockWriter::new(4));
|
||||
let n = assert_ok!(write_vectored(&mut w, &io_vec).await);
|
||||
assert_eq!(n, 8);
|
||||
assert!(w.buffer() == &msg[..8]);
|
||||
let io_vec = io_vec.advance(n);
|
||||
let n = assert_ok!(write_vectored(&mut w, &io_vec).await);
|
||||
assert_eq!(n, 3);
|
||||
assert!(w.get_ref().data.as_slice() == &msg[..8]);
|
||||
assert!(w.buffer() == &msg[8..]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_vectored_large_total_on_vectored() {
|
||||
let msg = b"foo bar baz";
|
||||
let mut bufs = [
|
||||
IoSlice::new(&msg[0..4]),
|
||||
IoSlice::new(&msg[4..8]),
|
||||
IoSlice::new(&msg[8..]),
|
||||
];
|
||||
let io_vec = IoBufs::new(&mut bufs);
|
||||
let mut w = BufWriter::with_capacity(8, MockWriter::vectored(10));
|
||||
let n = assert_ok!(write_vectored(&mut w, &io_vec).await);
|
||||
assert_eq!(n, 10);
|
||||
assert!(w.buffer().is_empty());
|
||||
let io_vec = io_vec.advance(n);
|
||||
let n = assert_ok!(write_vectored(&mut w, &io_vec).await);
|
||||
assert_eq!(n, 1);
|
||||
assert!(w.get_ref().data.as_slice() == &msg[..10]);
|
||||
assert!(w.buffer() == &msg[10..]);
|
||||
}
|
||||
|
||||
struct VectoredWriteHarness {
|
||||
writer: BufWriter<MockWriter>,
|
||||
buf_capacity: usize,
|
||||
}
|
||||
|
||||
impl VectoredWriteHarness {
|
||||
fn new(buf_capacity: usize) -> Self {
|
||||
VectoredWriteHarness {
|
||||
writer: BufWriter::with_capacity(buf_capacity, MockWriter::new(4)),
|
||||
buf_capacity,
|
||||
}
|
||||
}
|
||||
|
||||
fn with_vectored_backend(buf_capacity: usize) -> Self {
|
||||
VectoredWriteHarness {
|
||||
writer: BufWriter::with_capacity(buf_capacity, MockWriter::vectored(4)),
|
||||
buf_capacity,
|
||||
}
|
||||
}
|
||||
|
||||
async fn write_all<'a, 'b>(&mut self, mut io_vec: IoBufs<'a, 'b>) -> usize {
|
||||
let mut total_written = 0;
|
||||
while !io_vec.is_empty() {
|
||||
let n = assert_ok!(write_vectored(&mut self.writer, &io_vec).await);
|
||||
assert!(n != 0);
|
||||
assert!(self.writer.buffer().len() <= self.buf_capacity);
|
||||
total_written += n;
|
||||
io_vec = io_vec.advance(n);
|
||||
}
|
||||
total_written
|
||||
}
|
||||
|
||||
async fn flush(&mut self) -> &[u8] {
|
||||
assert_ok!(self.writer.flush().await);
|
||||
&self.writer.get_ref().data
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_vectored_odd_on_non_vectored() {
|
||||
let msg = b"foo bar baz";
|
||||
let mut bufs = [
|
||||
IoSlice::new(&msg[0..4]),
|
||||
IoSlice::new(&[]),
|
||||
IoSlice::new(&msg[4..9]),
|
||||
IoSlice::new(&msg[9..]),
|
||||
];
|
||||
let mut h = VectoredWriteHarness::new(8);
|
||||
let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await;
|
||||
assert_eq!(bytes_written, msg.len());
|
||||
assert_eq!(h.flush().await, msg);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_vectored_odd_on_vectored() {
|
||||
let msg = b"foo bar baz";
|
||||
let mut bufs = [
|
||||
IoSlice::new(&msg[0..4]),
|
||||
IoSlice::new(&[]),
|
||||
IoSlice::new(&msg[4..9]),
|
||||
IoSlice::new(&msg[9..]),
|
||||
];
|
||||
let mut h = VectoredWriteHarness::with_vectored_backend(8);
|
||||
let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await;
|
||||
assert_eq!(bytes_written, msg.len());
|
||||
assert_eq!(h.flush().await, msg);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_vectored_large_slice_on_non_vectored() {
|
||||
let msg = b"foo bar baz";
|
||||
let mut bufs = [
|
||||
IoSlice::new(&[]),
|
||||
IoSlice::new(&msg[..9]),
|
||||
IoSlice::new(&msg[9..]),
|
||||
];
|
||||
let mut h = VectoredWriteHarness::new(8);
|
||||
let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await;
|
||||
assert_eq!(bytes_written, msg.len());
|
||||
assert_eq!(h.flush().await, msg);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_vectored_large_slice_on_vectored() {
|
||||
let msg = b"foo bar baz";
|
||||
let mut bufs = [
|
||||
IoSlice::new(&[]),
|
||||
IoSlice::new(&msg[..9]),
|
||||
IoSlice::new(&msg[9..]),
|
||||
];
|
||||
let mut h = VectoredWriteHarness::with_vectored_backend(8);
|
||||
let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await;
|
||||
assert_eq!(bytes_written, msg.len());
|
||||
assert_eq!(h.flush().await, msg);
|
||||
}
|
||||
|
45
tokio/tests/support/io_vec.rs
Normal file
45
tokio/tests/support/io_vec.rs
Normal file
@ -0,0 +1,45 @@
|
||||
use std::io::IoSlice;
|
||||
use std::ops::Deref;
|
||||
use std::slice;
|
||||
|
||||
pub struct IoBufs<'a, 'b>(&'b mut [IoSlice<'a>]);
|
||||
|
||||
impl<'a, 'b> IoBufs<'a, 'b> {
|
||||
pub fn new(slices: &'b mut [IoSlice<'a>]) -> Self {
|
||||
IoBufs(slices)
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
}
|
||||
|
||||
pub fn advance(mut self, n: usize) -> IoBufs<'a, 'b> {
|
||||
let mut to_remove = 0;
|
||||
let mut remaining_len = n;
|
||||
for slice in self.0.iter() {
|
||||
if remaining_len < slice.len() {
|
||||
break;
|
||||
} else {
|
||||
remaining_len -= slice.len();
|
||||
to_remove += 1;
|
||||
}
|
||||
}
|
||||
self.0 = self.0.split_at_mut(to_remove).1;
|
||||
if let Some(slice) = self.0.first_mut() {
|
||||
let tail = &slice[remaining_len..];
|
||||
// Safety: recasts slice to the original lifetime
|
||||
let tail = unsafe { slice::from_raw_parts(tail.as_ptr(), tail.len()) };
|
||||
*slice = IoSlice::new(tail);
|
||||
} else if remaining_len != 0 {
|
||||
panic!("advance past the end of the slice vector");
|
||||
}
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b> Deref for IoBufs<'a, 'b> {
|
||||
type Target = [IoSlice<'a>];
|
||||
fn deref(&self) -> &[IoSlice<'a>] {
|
||||
self.0
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user