mirror of
https://github.com/tokio-rs/tokio.git
synced 2025-10-01 12:20:39 +00:00
sync: handle possibly dangling reference safely (#5812)
This commit is contained in:
parent
ce23db6bc7
commit
1bfe778acb
@ -57,151 +57,4 @@ pub mod either;
|
|||||||
|
|
||||||
pub use bytes;
|
pub use bytes;
|
||||||
|
|
||||||
#[cfg(any(feature = "io", feature = "codec"))]
|
mod util;
|
||||||
mod util {
|
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
|
||||||
|
|
||||||
use bytes::{Buf, BufMut};
|
|
||||||
use futures_core::ready;
|
|
||||||
use std::io::{self, IoSlice};
|
|
||||||
use std::mem::MaybeUninit;
|
|
||||||
use std::pin::Pin;
|
|
||||||
use std::task::{Context, Poll};
|
|
||||||
|
|
||||||
/// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait.
|
|
||||||
///
|
|
||||||
/// [`BufMut`]: bytes::Buf
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```
|
|
||||||
/// use bytes::{Bytes, BytesMut};
|
|
||||||
/// use tokio_stream as stream;
|
|
||||||
/// use tokio::io::Result;
|
|
||||||
/// use tokio_util::io::{StreamReader, poll_read_buf};
|
|
||||||
/// use futures::future::poll_fn;
|
|
||||||
/// use std::pin::Pin;
|
|
||||||
/// # #[tokio::main]
|
|
||||||
/// # async fn main() -> std::io::Result<()> {
|
|
||||||
///
|
|
||||||
/// // Create a reader from an iterator. This particular reader will always be
|
|
||||||
/// // ready.
|
|
||||||
/// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))]));
|
|
||||||
///
|
|
||||||
/// let mut buf = BytesMut::new();
|
|
||||||
/// let mut reads = 0;
|
|
||||||
///
|
|
||||||
/// loop {
|
|
||||||
/// reads += 1;
|
|
||||||
/// let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?;
|
|
||||||
///
|
|
||||||
/// if n == 0 {
|
|
||||||
/// break;
|
|
||||||
/// }
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// // one or more reads might be necessary.
|
|
||||||
/// assert!(reads >= 1);
|
|
||||||
/// assert_eq!(&buf[..], &[0, 1, 2, 3]);
|
|
||||||
/// # Ok(())
|
|
||||||
/// # }
|
|
||||||
/// ```
|
|
||||||
#[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
|
|
||||||
pub fn poll_read_buf<T: AsyncRead, B: BufMut>(
|
|
||||||
io: Pin<&mut T>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &mut B,
|
|
||||||
) -> Poll<io::Result<usize>> {
|
|
||||||
if !buf.has_remaining_mut() {
|
|
||||||
return Poll::Ready(Ok(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
let n = {
|
|
||||||
let dst = buf.chunk_mut();
|
|
||||||
|
|
||||||
// Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
|
|
||||||
// transparent wrapper around `[MaybeUninit<u8>]`.
|
|
||||||
let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
|
|
||||||
let mut buf = ReadBuf::uninit(dst);
|
|
||||||
let ptr = buf.filled().as_ptr();
|
|
||||||
ready!(io.poll_read(cx, &mut buf)?);
|
|
||||||
|
|
||||||
// Ensure the pointer does not change from under us
|
|
||||||
assert_eq!(ptr, buf.filled().as_ptr());
|
|
||||||
buf.filled().len()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Safety: This is guaranteed to be the number of initialized (and read)
|
|
||||||
// bytes due to the invariants provided by `ReadBuf::filled`.
|
|
||||||
unsafe {
|
|
||||||
buf.advance_mut(n);
|
|
||||||
}
|
|
||||||
|
|
||||||
Poll::Ready(Ok(n))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Try to write data from an implementer of the [`Buf`] trait to an
|
|
||||||
/// [`AsyncWrite`], advancing the buffer's internal cursor.
|
|
||||||
///
|
|
||||||
/// This function will use [vectored writes] when the [`AsyncWrite`] supports
|
|
||||||
/// vectored writes.
|
|
||||||
///
|
|
||||||
/// # Examples
|
|
||||||
///
|
|
||||||
/// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements
|
|
||||||
/// [`Buf`]:
|
|
||||||
///
|
|
||||||
/// ```no_run
|
|
||||||
/// use tokio_util::io::poll_write_buf;
|
|
||||||
/// use tokio::io;
|
|
||||||
/// use tokio::fs::File;
|
|
||||||
///
|
|
||||||
/// use bytes::Buf;
|
|
||||||
/// use std::io::Cursor;
|
|
||||||
/// use std::pin::Pin;
|
|
||||||
/// use futures::future::poll_fn;
|
|
||||||
///
|
|
||||||
/// #[tokio::main]
|
|
||||||
/// async fn main() -> io::Result<()> {
|
|
||||||
/// let mut file = File::create("foo.txt").await?;
|
|
||||||
/// let mut buf = Cursor::new(b"data to write");
|
|
||||||
///
|
|
||||||
/// // Loop until the entire contents of the buffer are written to
|
|
||||||
/// // the file.
|
|
||||||
/// while buf.has_remaining() {
|
|
||||||
/// poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?;
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// Ok(())
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// [`Buf`]: bytes::Buf
|
|
||||||
/// [`AsyncWrite`]: tokio::io::AsyncWrite
|
|
||||||
/// [`File`]: tokio::fs::File
|
|
||||||
/// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored
|
|
||||||
#[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
|
|
||||||
pub fn poll_write_buf<T: AsyncWrite, B: Buf>(
|
|
||||||
io: Pin<&mut T>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &mut B,
|
|
||||||
) -> Poll<io::Result<usize>> {
|
|
||||||
const MAX_BUFS: usize = 64;
|
|
||||||
|
|
||||||
if !buf.has_remaining() {
|
|
||||||
return Poll::Ready(Ok(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
let n = if io.is_write_vectored() {
|
|
||||||
let mut slices = [IoSlice::new(&[]); MAX_BUFS];
|
|
||||||
let cnt = buf.chunks_vectored(&mut slices);
|
|
||||||
ready!(io.poll_write_vectored(cx, &slices[..cnt]))?
|
|
||||||
} else {
|
|
||||||
ready!(io.poll_write(cx, buf.chunk()))?
|
|
||||||
};
|
|
||||||
|
|
||||||
buf.advance(n);
|
|
||||||
|
|
||||||
Poll::Ready(Ok(n))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -4,6 +4,7 @@ pub(crate) mod guard;
|
|||||||
mod tree_node;
|
mod tree_node;
|
||||||
|
|
||||||
use crate::loom::sync::Arc;
|
use crate::loom::sync::Arc;
|
||||||
|
use crate::util::MaybeDangling;
|
||||||
use core::future::Future;
|
use core::future::Future;
|
||||||
use core::pin::Pin;
|
use core::pin::Pin;
|
||||||
use core::task::{Context, Poll};
|
use core::task::{Context, Poll};
|
||||||
@ -77,11 +78,23 @@ pin_project! {
|
|||||||
/// [`CancellationToken`] by value instead of using a reference.
|
/// [`CancellationToken`] by value instead of using a reference.
|
||||||
#[must_use = "futures do nothing unless polled"]
|
#[must_use = "futures do nothing unless polled"]
|
||||||
pub struct WaitForCancellationFutureOwned {
|
pub struct WaitForCancellationFutureOwned {
|
||||||
// Since `future` is the first field, it is dropped before the
|
// This field internally has a reference to the cancellation token, but camouflages
|
||||||
// cancellation_token field. This ensures that the reference inside the
|
// the relationship with `'static`. To avoid Undefined Behavior, we must ensure
|
||||||
// `Notified` remains valid.
|
// that the reference is only used while the cancellation token is still alive. To
|
||||||
|
// do that, we ensure that the future is the first field, so that it is dropped
|
||||||
|
// before the cancellation token.
|
||||||
|
//
|
||||||
|
// We use `MaybeDanglingFuture` here because without it, the compiler could assert
|
||||||
|
// the reference inside `future` to be valid even after the destructor of that
|
||||||
|
// field runs. (Specifically, when the `WaitForCancellationFutureOwned` is passed
|
||||||
|
// as an argument to a function, the reference can be asserted to be valid for the
|
||||||
|
// rest of that function.) To avoid that, we use `MaybeDangling` which tells the
|
||||||
|
// compiler that the reference stored inside it might not be valid.
|
||||||
|
//
|
||||||
|
// See <https://users.rust-lang.org/t/unsafe-code-review-semi-owning-weak-rwlock-t-guard/95706>
|
||||||
|
// for more info.
|
||||||
#[pin]
|
#[pin]
|
||||||
future: tokio::sync::futures::Notified<'static>,
|
future: MaybeDangling<tokio::sync::futures::Notified<'static>>,
|
||||||
cancellation_token: CancellationToken,
|
cancellation_token: CancellationToken,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -279,7 +292,7 @@ impl WaitForCancellationFutureOwned {
|
|||||||
// # Safety
|
// # Safety
|
||||||
//
|
//
|
||||||
// cancellation_token is dropped after future due to the field ordering.
|
// cancellation_token is dropped after future due to the field ordering.
|
||||||
future: unsafe { Self::new_future(&cancellation_token) },
|
future: MaybeDangling::new(unsafe { Self::new_future(&cancellation_token) }),
|
||||||
cancellation_token,
|
cancellation_token,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -320,8 +333,9 @@ impl Future for WaitForCancellationFutureOwned {
|
|||||||
// # Safety
|
// # Safety
|
||||||
//
|
//
|
||||||
// cancellation_token is dropped after future due to the field ordering.
|
// cancellation_token is dropped after future due to the field ordering.
|
||||||
this.future
|
this.future.set(MaybeDangling::new(unsafe {
|
||||||
.set(unsafe { Self::new_future(this.cancellation_token) });
|
Self::new_future(this.cancellation_token)
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
67
tokio-util/src/util/maybe_dangling.rs
Normal file
67
tokio-util/src/util/maybe_dangling.rs
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
use core::future::Future;
|
||||||
|
use core::mem::MaybeUninit;
|
||||||
|
use core::pin::Pin;
|
||||||
|
use core::task::{Context, Poll};
|
||||||
|
|
||||||
|
/// A wrapper type that tells the compiler that the contents might not be valid.
|
||||||
|
///
|
||||||
|
/// This is necessary mainly when `T` contains a reference. In that case, the
|
||||||
|
/// compiler will sometimes assume that the reference is always valid; in some
|
||||||
|
/// cases it will assume this even after the destructor of `T` runs. For
|
||||||
|
/// example, when a reference is used as a function argument, then the compiler
|
||||||
|
/// will assume that the reference is valid until the function returns, even if
|
||||||
|
/// the reference is destroyed during the function. When the reference is used
|
||||||
|
/// as part of a self-referential struct, that assumption can be false. Wrapping
|
||||||
|
/// the reference in this type prevents the compiler from making that
|
||||||
|
/// assumption.
|
||||||
|
///
|
||||||
|
/// # Invariants
|
||||||
|
///
|
||||||
|
/// The `MaybeUninit` will always contain a valid value until the destructor runs.
|
||||||
|
//
|
||||||
|
// Reference
|
||||||
|
// See <https://users.rust-lang.org/t/unsafe-code-review-semi-owning-weak-rwlock-t-guard/95706>
|
||||||
|
//
|
||||||
|
// TODO: replace this with an official solution once RFC #3336 or similar is available.
|
||||||
|
// <https://github.com/rust-lang/rfcs/pull/3336>
|
||||||
|
#[repr(transparent)]
|
||||||
|
pub(crate) struct MaybeDangling<T>(MaybeUninit<T>);
|
||||||
|
|
||||||
|
impl<T> Drop for MaybeDangling<T> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
// Safety: `0` is always initialized.
|
||||||
|
unsafe { core::ptr::drop_in_place(self.0.as_mut_ptr()) };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> MaybeDangling<T> {
|
||||||
|
pub(crate) fn new(inner: T) -> Self {
|
||||||
|
Self(MaybeUninit::new(inner))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Future> Future for MaybeDangling<F> {
|
||||||
|
type Output = F::Output;
|
||||||
|
|
||||||
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
|
// Safety: `0` is always initialized.
|
||||||
|
let fut = unsafe { self.map_unchecked_mut(|this| this.0.assume_init_mut()) };
|
||||||
|
fut.poll(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn maybedangling_runs_drop() {
|
||||||
|
struct SetOnDrop<'a>(&'a mut bool);
|
||||||
|
|
||||||
|
impl Drop for SetOnDrop<'_> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
*self.0 = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut success = false;
|
||||||
|
|
||||||
|
drop(MaybeDangling::new(SetOnDrop(&mut success)));
|
||||||
|
assert!(success);
|
||||||
|
}
|
8
tokio-util/src/util/mod.rs
Normal file
8
tokio-util/src/util/mod.rs
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
mod maybe_dangling;
|
||||||
|
#[cfg(any(feature = "io", feature = "codec"))]
|
||||||
|
mod poll_buf;
|
||||||
|
|
||||||
|
pub(crate) use maybe_dangling::MaybeDangling;
|
||||||
|
#[cfg(any(feature = "io", feature = "codec"))]
|
||||||
|
#[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
|
||||||
|
pub use poll_buf::{poll_read_buf, poll_write_buf};
|
145
tokio-util/src/util/poll_buf.rs
Normal file
145
tokio-util/src/util/poll_buf.rs
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
|
|
||||||
|
use bytes::{Buf, BufMut};
|
||||||
|
use futures_core::ready;
|
||||||
|
use std::io::{self, IoSlice};
|
||||||
|
use std::mem::MaybeUninit;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
|
||||||
|
/// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait.
|
||||||
|
///
|
||||||
|
/// [`BufMut`]: bytes::Buf
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// use bytes::{Bytes, BytesMut};
|
||||||
|
/// use tokio_stream as stream;
|
||||||
|
/// use tokio::io::Result;
|
||||||
|
/// use tokio_util::io::{StreamReader, poll_read_buf};
|
||||||
|
/// use futures::future::poll_fn;
|
||||||
|
/// use std::pin::Pin;
|
||||||
|
/// # #[tokio::main]
|
||||||
|
/// # async fn main() -> std::io::Result<()> {
|
||||||
|
///
|
||||||
|
/// // Create a reader from an iterator. This particular reader will always be
|
||||||
|
/// // ready.
|
||||||
|
/// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))]));
|
||||||
|
///
|
||||||
|
/// let mut buf = BytesMut::new();
|
||||||
|
/// let mut reads = 0;
|
||||||
|
///
|
||||||
|
/// loop {
|
||||||
|
/// reads += 1;
|
||||||
|
/// let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?;
|
||||||
|
///
|
||||||
|
/// if n == 0 {
|
||||||
|
/// break;
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// // one or more reads might be necessary.
|
||||||
|
/// assert!(reads >= 1);
|
||||||
|
/// assert_eq!(&buf[..], &[0, 1, 2, 3]);
|
||||||
|
/// # Ok(())
|
||||||
|
/// # }
|
||||||
|
/// ```
|
||||||
|
#[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
|
||||||
|
pub fn poll_read_buf<T: AsyncRead, B: BufMut>(
|
||||||
|
io: Pin<&mut T>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut B,
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
if !buf.has_remaining_mut() {
|
||||||
|
return Poll::Ready(Ok(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
let n = {
|
||||||
|
let dst = buf.chunk_mut();
|
||||||
|
|
||||||
|
// Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
|
||||||
|
// transparent wrapper around `[MaybeUninit<u8>]`.
|
||||||
|
let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
|
||||||
|
let mut buf = ReadBuf::uninit(dst);
|
||||||
|
let ptr = buf.filled().as_ptr();
|
||||||
|
ready!(io.poll_read(cx, &mut buf)?);
|
||||||
|
|
||||||
|
// Ensure the pointer does not change from under us
|
||||||
|
assert_eq!(ptr, buf.filled().as_ptr());
|
||||||
|
buf.filled().len()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Safety: This is guaranteed to be the number of initialized (and read)
|
||||||
|
// bytes due to the invariants provided by `ReadBuf::filled`.
|
||||||
|
unsafe {
|
||||||
|
buf.advance_mut(n);
|
||||||
|
}
|
||||||
|
|
||||||
|
Poll::Ready(Ok(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to write data from an implementer of the [`Buf`] trait to an
|
||||||
|
/// [`AsyncWrite`], advancing the buffer's internal cursor.
|
||||||
|
///
|
||||||
|
/// This function will use [vectored writes] when the [`AsyncWrite`] supports
|
||||||
|
/// vectored writes.
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements
|
||||||
|
/// [`Buf`]:
|
||||||
|
///
|
||||||
|
/// ```no_run
|
||||||
|
/// use tokio_util::io::poll_write_buf;
|
||||||
|
/// use tokio::io;
|
||||||
|
/// use tokio::fs::File;
|
||||||
|
///
|
||||||
|
/// use bytes::Buf;
|
||||||
|
/// use std::io::Cursor;
|
||||||
|
/// use std::pin::Pin;
|
||||||
|
/// use futures::future::poll_fn;
|
||||||
|
///
|
||||||
|
/// #[tokio::main]
|
||||||
|
/// async fn main() -> io::Result<()> {
|
||||||
|
/// let mut file = File::create("foo.txt").await?;
|
||||||
|
/// let mut buf = Cursor::new(b"data to write");
|
||||||
|
///
|
||||||
|
/// // Loop until the entire contents of the buffer are written to
|
||||||
|
/// // the file.
|
||||||
|
/// while buf.has_remaining() {
|
||||||
|
/// poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?;
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// Ok(())
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// [`Buf`]: bytes::Buf
|
||||||
|
/// [`AsyncWrite`]: tokio::io::AsyncWrite
|
||||||
|
/// [`File`]: tokio::fs::File
|
||||||
|
/// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored
|
||||||
|
#[cfg_attr(not(feature = "io"), allow(unreachable_pub))]
|
||||||
|
pub fn poll_write_buf<T: AsyncWrite, B: Buf>(
|
||||||
|
io: Pin<&mut T>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut B,
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
const MAX_BUFS: usize = 64;
|
||||||
|
|
||||||
|
if !buf.has_remaining() {
|
||||||
|
return Poll::Ready(Ok(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
let n = if io.is_write_vectored() {
|
||||||
|
let mut slices = [IoSlice::new(&[]); MAX_BUFS];
|
||||||
|
let cnt = buf.chunks_vectored(&mut slices);
|
||||||
|
ready!(io.poll_write_vectored(cx, &slices[..cnt]))?
|
||||||
|
} else {
|
||||||
|
ready!(io.poll_write(cx, buf.chunk()))?
|
||||||
|
};
|
||||||
|
|
||||||
|
buf.advance(n);
|
||||||
|
|
||||||
|
Poll::Ready(Ok(n))
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user