diff --git a/tokio-util/src/lib.rs b/tokio-util/src/lib.rs index b074aad32..22ad92b8c 100644 --- a/tokio-util/src/lib.rs +++ b/tokio-util/src/lib.rs @@ -57,151 +57,4 @@ pub mod either; pub use bytes; -#[cfg(any(feature = "io", feature = "codec"))] -mod util { - use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; - - use bytes::{Buf, BufMut}; - use futures_core::ready; - use std::io::{self, IoSlice}; - use std::mem::MaybeUninit; - use std::pin::Pin; - use std::task::{Context, Poll}; - - /// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait. - /// - /// [`BufMut`]: bytes::Buf - /// - /// # Example - /// - /// ``` - /// use bytes::{Bytes, BytesMut}; - /// use tokio_stream as stream; - /// use tokio::io::Result; - /// use tokio_util::io::{StreamReader, poll_read_buf}; - /// use futures::future::poll_fn; - /// use std::pin::Pin; - /// # #[tokio::main] - /// # async fn main() -> std::io::Result<()> { - /// - /// // Create a reader from an iterator. This particular reader will always be - /// // ready. - /// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))])); - /// - /// let mut buf = BytesMut::new(); - /// let mut reads = 0; - /// - /// loop { - /// reads += 1; - /// let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?; - /// - /// if n == 0 { - /// break; - /// } - /// } - /// - /// // one or more reads might be necessary. - /// assert!(reads >= 1); - /// assert_eq!(&buf[..], &[0, 1, 2, 3]); - /// # Ok(()) - /// # } - /// ``` - #[cfg_attr(not(feature = "io"), allow(unreachable_pub))] - pub fn poll_read_buf( - io: Pin<&mut T>, - cx: &mut Context<'_>, - buf: &mut B, - ) -> Poll> { - if !buf.has_remaining_mut() { - return Poll::Ready(Ok(0)); - } - - let n = { - let dst = buf.chunk_mut(); - - // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a - // transparent wrapper around `[MaybeUninit]`. - let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit]) }; - let mut buf = ReadBuf::uninit(dst); - let ptr = buf.filled().as_ptr(); - ready!(io.poll_read(cx, &mut buf)?); - - // Ensure the pointer does not change from under us - assert_eq!(ptr, buf.filled().as_ptr()); - buf.filled().len() - }; - - // Safety: This is guaranteed to be the number of initialized (and read) - // bytes due to the invariants provided by `ReadBuf::filled`. - unsafe { - buf.advance_mut(n); - } - - Poll::Ready(Ok(n)) - } - - /// Try to write data from an implementer of the [`Buf`] trait to an - /// [`AsyncWrite`], advancing the buffer's internal cursor. - /// - /// This function will use [vectored writes] when the [`AsyncWrite`] supports - /// vectored writes. - /// - /// # Examples - /// - /// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements - /// [`Buf`]: - /// - /// ```no_run - /// use tokio_util::io::poll_write_buf; - /// use tokio::io; - /// use tokio::fs::File; - /// - /// use bytes::Buf; - /// use std::io::Cursor; - /// use std::pin::Pin; - /// use futures::future::poll_fn; - /// - /// #[tokio::main] - /// async fn main() -> io::Result<()> { - /// let mut file = File::create("foo.txt").await?; - /// let mut buf = Cursor::new(b"data to write"); - /// - /// // Loop until the entire contents of the buffer are written to - /// // the file. - /// while buf.has_remaining() { - /// poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?; - /// } - /// - /// Ok(()) - /// } - /// ``` - /// - /// [`Buf`]: bytes::Buf - /// [`AsyncWrite`]: tokio::io::AsyncWrite - /// [`File`]: tokio::fs::File - /// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored - #[cfg_attr(not(feature = "io"), allow(unreachable_pub))] - pub fn poll_write_buf( - io: Pin<&mut T>, - cx: &mut Context<'_>, - buf: &mut B, - ) -> Poll> { - const MAX_BUFS: usize = 64; - - if !buf.has_remaining() { - return Poll::Ready(Ok(0)); - } - - let n = if io.is_write_vectored() { - let mut slices = [IoSlice::new(&[]); MAX_BUFS]; - let cnt = buf.chunks_vectored(&mut slices); - ready!(io.poll_write_vectored(cx, &slices[..cnt]))? - } else { - ready!(io.poll_write(cx, buf.chunk()))? - }; - - buf.advance(n); - - Poll::Ready(Ok(n)) - } -} +mod util; diff --git a/tokio-util/src/sync/cancellation_token.rs b/tokio-util/src/sync/cancellation_token.rs index 2c4e0250d..2251736a3 100644 --- a/tokio-util/src/sync/cancellation_token.rs +++ b/tokio-util/src/sync/cancellation_token.rs @@ -4,6 +4,7 @@ pub(crate) mod guard; mod tree_node; use crate::loom::sync::Arc; +use crate::util::MaybeDangling; use core::future::Future; use core::pin::Pin; use core::task::{Context, Poll}; @@ -77,11 +78,23 @@ pin_project! { /// [`CancellationToken`] by value instead of using a reference. #[must_use = "futures do nothing unless polled"] pub struct WaitForCancellationFutureOwned { - // Since `future` is the first field, it is dropped before the - // cancellation_token field. This ensures that the reference inside the - // `Notified` remains valid. + // This field internally has a reference to the cancellation token, but camouflages + // the relationship with `'static`. To avoid Undefined Behavior, we must ensure + // that the reference is only used while the cancellation token is still alive. To + // do that, we ensure that the future is the first field, so that it is dropped + // before the cancellation token. + // + // We use `MaybeDanglingFuture` here because without it, the compiler could assert + // the reference inside `future` to be valid even after the destructor of that + // field runs. (Specifically, when the `WaitForCancellationFutureOwned` is passed + // as an argument to a function, the reference can be asserted to be valid for the + // rest of that function.) To avoid that, we use `MaybeDangling` which tells the + // compiler that the reference stored inside it might not be valid. + // + // See + // for more info. #[pin] - future: tokio::sync::futures::Notified<'static>, + future: MaybeDangling>, cancellation_token: CancellationToken, } } @@ -279,7 +292,7 @@ impl WaitForCancellationFutureOwned { // # Safety // // cancellation_token is dropped after future due to the field ordering. - future: unsafe { Self::new_future(&cancellation_token) }, + future: MaybeDangling::new(unsafe { Self::new_future(&cancellation_token) }), cancellation_token, } } @@ -320,8 +333,9 @@ impl Future for WaitForCancellationFutureOwned { // # Safety // // cancellation_token is dropped after future due to the field ordering. - this.future - .set(unsafe { Self::new_future(this.cancellation_token) }); + this.future.set(MaybeDangling::new(unsafe { + Self::new_future(this.cancellation_token) + })); } } } diff --git a/tokio-util/src/util/maybe_dangling.rs b/tokio-util/src/util/maybe_dangling.rs new file mode 100644 index 000000000..c29a0894c --- /dev/null +++ b/tokio-util/src/util/maybe_dangling.rs @@ -0,0 +1,67 @@ +use core::future::Future; +use core::mem::MaybeUninit; +use core::pin::Pin; +use core::task::{Context, Poll}; + +/// A wrapper type that tells the compiler that the contents might not be valid. +/// +/// This is necessary mainly when `T` contains a reference. In that case, the +/// compiler will sometimes assume that the reference is always valid; in some +/// cases it will assume this even after the destructor of `T` runs. For +/// example, when a reference is used as a function argument, then the compiler +/// will assume that the reference is valid until the function returns, even if +/// the reference is destroyed during the function. When the reference is used +/// as part of a self-referential struct, that assumption can be false. Wrapping +/// the reference in this type prevents the compiler from making that +/// assumption. +/// +/// # Invariants +/// +/// The `MaybeUninit` will always contain a valid value until the destructor runs. +// +// Reference +// See +// +// TODO: replace this with an official solution once RFC #3336 or similar is available. +// +#[repr(transparent)] +pub(crate) struct MaybeDangling(MaybeUninit); + +impl Drop for MaybeDangling { + fn drop(&mut self) { + // Safety: `0` is always initialized. + unsafe { core::ptr::drop_in_place(self.0.as_mut_ptr()) }; + } +} + +impl MaybeDangling { + pub(crate) fn new(inner: T) -> Self { + Self(MaybeUninit::new(inner)) + } +} + +impl Future for MaybeDangling { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // Safety: `0` is always initialized. + let fut = unsafe { self.map_unchecked_mut(|this| this.0.assume_init_mut()) }; + fut.poll(cx) + } +} + +#[test] +fn maybedangling_runs_drop() { + struct SetOnDrop<'a>(&'a mut bool); + + impl Drop for SetOnDrop<'_> { + fn drop(&mut self) { + *self.0 = true; + } + } + + let mut success = false; + + drop(MaybeDangling::new(SetOnDrop(&mut success))); + assert!(success); +} diff --git a/tokio-util/src/util/mod.rs b/tokio-util/src/util/mod.rs new file mode 100644 index 000000000..a17f25a6b --- /dev/null +++ b/tokio-util/src/util/mod.rs @@ -0,0 +1,8 @@ +mod maybe_dangling; +#[cfg(any(feature = "io", feature = "codec"))] +mod poll_buf; + +pub(crate) use maybe_dangling::MaybeDangling; +#[cfg(any(feature = "io", feature = "codec"))] +#[cfg_attr(not(feature = "io"), allow(unreachable_pub))] +pub use poll_buf::{poll_read_buf, poll_write_buf}; diff --git a/tokio-util/src/util/poll_buf.rs b/tokio-util/src/util/poll_buf.rs new file mode 100644 index 000000000..82af1bbfc --- /dev/null +++ b/tokio-util/src/util/poll_buf.rs @@ -0,0 +1,145 @@ +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use bytes::{Buf, BufMut}; +use futures_core::ready; +use std::io::{self, IoSlice}; +use std::mem::MaybeUninit; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Try to read data from an `AsyncRead` into an implementer of the [`BufMut`] trait. +/// +/// [`BufMut`]: bytes::Buf +/// +/// # Example +/// +/// ``` +/// use bytes::{Bytes, BytesMut}; +/// use tokio_stream as stream; +/// use tokio::io::Result; +/// use tokio_util::io::{StreamReader, poll_read_buf}; +/// use futures::future::poll_fn; +/// use std::pin::Pin; +/// # #[tokio::main] +/// # async fn main() -> std::io::Result<()> { +/// +/// // Create a reader from an iterator. This particular reader will always be +/// // ready. +/// let mut read = StreamReader::new(stream::iter(vec![Result::Ok(Bytes::from_static(&[0, 1, 2, 3]))])); +/// +/// let mut buf = BytesMut::new(); +/// let mut reads = 0; +/// +/// loop { +/// reads += 1; +/// let n = poll_fn(|cx| poll_read_buf(Pin::new(&mut read), cx, &mut buf)).await?; +/// +/// if n == 0 { +/// break; +/// } +/// } +/// +/// // one or more reads might be necessary. +/// assert!(reads >= 1); +/// assert_eq!(&buf[..], &[0, 1, 2, 3]); +/// # Ok(()) +/// # } +/// ``` +#[cfg_attr(not(feature = "io"), allow(unreachable_pub))] +pub fn poll_read_buf( + io: Pin<&mut T>, + cx: &mut Context<'_>, + buf: &mut B, +) -> Poll> { + if !buf.has_remaining_mut() { + return Poll::Ready(Ok(0)); + } + + let n = { + let dst = buf.chunk_mut(); + + // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a + // transparent wrapper around `[MaybeUninit]`. + let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit]) }; + let mut buf = ReadBuf::uninit(dst); + let ptr = buf.filled().as_ptr(); + ready!(io.poll_read(cx, &mut buf)?); + + // Ensure the pointer does not change from under us + assert_eq!(ptr, buf.filled().as_ptr()); + buf.filled().len() + }; + + // Safety: This is guaranteed to be the number of initialized (and read) + // bytes due to the invariants provided by `ReadBuf::filled`. + unsafe { + buf.advance_mut(n); + } + + Poll::Ready(Ok(n)) +} + +/// Try to write data from an implementer of the [`Buf`] trait to an +/// [`AsyncWrite`], advancing the buffer's internal cursor. +/// +/// This function will use [vectored writes] when the [`AsyncWrite`] supports +/// vectored writes. +/// +/// # Examples +/// +/// [`File`] implements [`AsyncWrite`] and [`Cursor<&[u8]>`] implements +/// [`Buf`]: +/// +/// ```no_run +/// use tokio_util::io::poll_write_buf; +/// use tokio::io; +/// use tokio::fs::File; +/// +/// use bytes::Buf; +/// use std::io::Cursor; +/// use std::pin::Pin; +/// use futures::future::poll_fn; +/// +/// #[tokio::main] +/// async fn main() -> io::Result<()> { +/// let mut file = File::create("foo.txt").await?; +/// let mut buf = Cursor::new(b"data to write"); +/// +/// // Loop until the entire contents of the buffer are written to +/// // the file. +/// while buf.has_remaining() { +/// poll_fn(|cx| poll_write_buf(Pin::new(&mut file), cx, &mut buf)).await?; +/// } +/// +/// Ok(()) +/// } +/// ``` +/// +/// [`Buf`]: bytes::Buf +/// [`AsyncWrite`]: tokio::io::AsyncWrite +/// [`File`]: tokio::fs::File +/// [vectored writes]: tokio::io::AsyncWrite::poll_write_vectored +#[cfg_attr(not(feature = "io"), allow(unreachable_pub))] +pub fn poll_write_buf( + io: Pin<&mut T>, + cx: &mut Context<'_>, + buf: &mut B, +) -> Poll> { + const MAX_BUFS: usize = 64; + + if !buf.has_remaining() { + return Poll::Ready(Ok(0)); + } + + let n = if io.is_write_vectored() { + let mut slices = [IoSlice::new(&[]); MAX_BUFS]; + let cnt = buf.chunks_vectored(&mut slices); + ready!(io.poll_write_vectored(cx, &slices[..cnt]))? + } else { + ready!(io.poll_write(cx, buf.chunk()))? + }; + + buf.advance(n); + + Poll::Ready(Ok(n)) +}