diff --git a/tokio/src/sync/mod.rs b/tokio/src/sync/mod.rs index e1d1a83a3..d89a9ddce 100644 --- a/tokio/src/sync/mod.rs +++ b/tokio/src/sync/mod.rs @@ -461,6 +461,9 @@ cfg_sync! { mod task; pub(crate) use task::AtomicWaker; + mod once_cell; + pub use self::once_cell::{OnceCell, SetError}; + pub mod watch; } diff --git a/tokio/src/sync/once_cell.rs b/tokio/src/sync/once_cell.rs new file mode 100644 index 000000000..fa9b1f19f --- /dev/null +++ b/tokio/src/sync/once_cell.rs @@ -0,0 +1,400 @@ +use super::Semaphore; +use crate::loom::cell::UnsafeCell; +use std::error::Error; +use std::fmt; +use std::future::Future; +use std::mem::MaybeUninit; +use std::ops::Drop; +use std::ptr; +use std::sync::atomic::{AtomicBool, Ordering}; + +/// A thread-safe cell which can be written to only once. +/// +/// Provides the functionality to either set the value, in case `OnceCell` +/// is uninitialized, or get the already initialized value by using an async +/// function via [`OnceCell::get_or_init`]. +/// +/// [`OnceCell::get_or_init`]: crate::sync::OnceCell::get_or_init +/// +/// # Examples +/// ``` +/// use tokio::sync::OnceCell; +/// +/// async fn some_computation() -> u32 { +/// 1 + 1 +/// } +/// +/// static ONCE: OnceCell = OnceCell::const_new(); +/// +/// #[tokio::main] +/// async fn main() { +/// let result1 = ONCE.get_or_init(some_computation).await; +/// assert_eq!(*result1, 2); +/// } +/// ``` +pub struct OnceCell { + value_set: AtomicBool, + value: UnsafeCell>, + semaphore: Semaphore, +} + +impl Default for OnceCell { + fn default() -> OnceCell { + OnceCell::new() + } +} + +impl fmt::Debug for OnceCell { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("OnceCell") + .field("value", &self.get()) + .finish() + } +} + +impl Clone for OnceCell { + fn clone(&self) -> OnceCell { + OnceCell::new_with(self.get().cloned()) + } +} + +impl PartialEq for OnceCell { + fn eq(&self, other: &OnceCell) -> bool { + self.get() == other.get() + } +} + +impl Eq for OnceCell {} + +impl Drop for OnceCell { + fn drop(&mut self) { + if self.initialized() { + unsafe { + self.value + .with_mut(|ptr| ptr::drop_in_place((&mut *ptr).as_mut_ptr())); + }; + } + } +} + +impl OnceCell { + /// Creates a new uninitialized OnceCell instance. + pub fn new() -> Self { + OnceCell { + value_set: AtomicBool::new(false), + value: UnsafeCell::new(MaybeUninit::uninit()), + semaphore: Semaphore::new(1), + } + } + + /// Creates a new initialized OnceCell instance if `value` is `Some`, otherwise + /// has the same functionality as [`OnceCell::new`]. + /// + /// [`OnceCell::new`]: crate::sync::OnceCell::new + pub fn new_with(value: Option) -> Self { + if let Some(v) = value { + let semaphore = Semaphore::new(0); + semaphore.close(); + OnceCell { + value_set: AtomicBool::new(true), + value: UnsafeCell::new(MaybeUninit::new(v)), + semaphore, + } + } else { + OnceCell::new() + } + } + + /// Creates a new uninitialized OnceCell instance. + #[cfg(all(feature = "parking_lot", not(all(loom, test)),))] + #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] + pub const fn const_new() -> Self { + OnceCell { + value_set: AtomicBool::new(false), + value: UnsafeCell::new(MaybeUninit::uninit()), + semaphore: Semaphore::const_new(1), + } + } + + /// Whether the value of the OnceCell is set or not. + pub fn initialized(&self) -> bool { + self.value_set.load(Ordering::Acquire) + } + + // SAFETY: safe to call only once self.initialized() is true + unsafe fn get_unchecked(&self) -> &T { + &*self.value.with(|ptr| (*ptr).as_ptr()) + } + + // SAFETY: safe to call only once self.initialized() is true. Safe because + // because of the mutable reference. + unsafe fn get_unchecked_mut(&mut self) -> &mut T { + &mut *self.value.with_mut(|ptr| (*ptr).as_mut_ptr()) + } + + // SAFETY: safe to call only once a permit on the semaphore has been + // acquired + unsafe fn set_value(&self, value: T) { + self.value.with_mut(|ptr| (*ptr).as_mut_ptr().write(value)); + self.value_set.store(true, Ordering::Release); + self.semaphore.close(); + } + + /// Tries to get a reference to the value of the OnceCell. + /// + /// Returns None if the value of the OnceCell hasn't previously been initialized. + pub fn get(&self) -> Option<&T> { + if self.initialized() { + Some(unsafe { self.get_unchecked() }) + } else { + None + } + } + + /// Tries to return a mutable reference to the value of the cell. + /// + /// Returns None if the cell hasn't previously been initialized. + pub fn get_mut(&mut self) -> Option<&mut T> { + if self.initialized() { + Some(unsafe { self.get_unchecked_mut() }) + } else { + None + } + } + + /// Sets the value of the OnceCell to the argument value. + /// + /// If the value of the OnceCell was already set prior to this call + /// then [`SetError::AlreadyInitializedError`] is returned. If another thread + /// is initializing the cell while this method is called, + /// [`SetError::InitializingError`] is returned. In order to wait + /// for an ongoing initialization to finish, call + /// [`OnceCell::get_or_init`] instead. + /// + /// [`SetError::AlreadyInitializedError`]: crate::sync::SetError::AlreadyInitializedError + /// [`SetError::InitializingError`]: crate::sync::SetError::InitializingError + /// ['OnceCell::get_or_init`]: crate::sync::OnceCell::get_or_init + pub fn set(&self, value: T) -> Result<(), SetError> { + if !self.initialized() { + // Another thread might be initializing the cell, in which case `try_acquire` will + // return an error + match self.semaphore.try_acquire() { + Ok(_permit) => { + if !self.initialized() { + // SAFETY: There is only one permit on the semaphore, hence only one + // mutable reference is created + unsafe { self.set_value(value) }; + + return Ok(()); + } else { + unreachable!( + "acquired the permit after OnceCell value was already initialized." + ); + } + } + _ => { + // Couldn't acquire the permit, look if initializing process is already completed + if !self.initialized() { + return Err(SetError::InitializingError(value)); + } + } + } + } + + Err(SetError::AlreadyInitializedError(value)) + } + + /// Tries to initialize the value of the OnceCell using the async function `f`. + /// If the value of the OnceCell was already initialized prior to this call, + /// a reference to that initialized value is returned. If some other thread + /// initiated the initialization prior to this call and the initialization + /// hasn't completed, this call waits until the initialization is finished. + /// + /// This will deadlock if `f` tries to initialize the cell itself. + pub async fn get_or_init(&self, f: F) -> &T + where + F: FnOnce() -> Fut, + Fut: Future, + { + if self.initialized() { + // SAFETY: once the value is initialized, no mutable references are given out, so + // we can give out arbitrarily many immutable references + unsafe { self.get_unchecked() } + } else { + // After acquire().await we have either acquired a permit while self.value + // is still uninitialized, or the current thread is awoken after another thread + // has intialized the value and closed the semaphore, in which case self.initialized + // is true and we don't set the value here + match self.semaphore.acquire().await { + Ok(_permit) => { + if !self.initialized() { + // If `f()` panics or `select!` is called, this `get_or_init` call + // is aborted and the semaphore permit is dropped. + let value = f().await; + + // SAFETY: There is only one permit on the semaphore, hence only one + // mutable reference is created + unsafe { self.set_value(value) }; + + // SAFETY: once the value is initialized, no mutable references are given out, so + // we can give out arbitrarily many immutable references + unsafe { self.get_unchecked() } + } else { + unreachable!("acquired semaphore after value was already initialized."); + } + } + Err(_) => { + if self.initialized() { + // SAFETY: once the value is initialized, no mutable references are given out, so + // we can give out arbitrarily many immutable references + unsafe { self.get_unchecked() } + } else { + unreachable!( + "Semaphore closed, but the OnceCell has not been initialized." + ); + } + } + } + } + } + + /// Tries to initialize the value of the OnceCell using the async function `f`. + /// If the value of the OnceCell was already initialized prior to this call, + /// a reference to that initialized value is returned. If some other thread + /// initiated the initialization prior to this call and the initialization + /// hasn't completed, this call waits until the initialization is finished. + /// If the function argument `f` returns an error, `get_or_try_init` + /// returns that error, otherwise the result of `f` will be stored in the cell. + /// + /// This will deadlock if `f` tries to initialize the cell itself. + pub async fn get_or_try_init(&self, f: F) -> Result<&T, E> + where + F: FnOnce() -> Fut, + Fut: Future>, + { + if self.initialized() { + // SAFETY: once the value is initialized, no mutable references are given out, so + // we can give out arbitrarily many immutable references + unsafe { Ok(self.get_unchecked()) } + } else { + // After acquire().await we have either acquired a permit while self.value + // is still uninitialized, or the current thread is awoken after another thread + // has intialized the value and closed the semaphore, in which case self.initialized + // is true and we don't set the value here + match self.semaphore.acquire().await { + Ok(_permit) => { + if !self.initialized() { + // If `f()` panics or `select!` is called, this `get_or_try_init` call + // is aborted and the semaphore permit is dropped. + let value = f().await; + + match value { + Ok(value) => { + // SAFETY: There is only one permit on the semaphore, hence only one + // mutable reference is created + unsafe { self.set_value(value) }; + + // SAFETY: once the value is initialized, no mutable references are given out, so + // we can give out arbitrarily many immutable references + unsafe { Ok(self.get_unchecked()) } + } + Err(e) => Err(e), + } + } else { + unreachable!("acquired semaphore after value was already initialized."); + } + } + Err(_) => { + if self.initialized() { + // SAFETY: once the value is initialized, no mutable references are given out, so + // we can give out arbitrarily many immutable references + unsafe { Ok(self.get_unchecked()) } + } else { + unreachable!( + "Semaphore closed, but the OnceCell has not been initialized." + ); + } + } + } + } + } + + /// Moves the value out of the cell, destroying the cell in the process. + /// + /// Returns `None` if the cell is uninitialized. + pub fn into_inner(mut self) -> Option { + if self.initialized() { + // Set to uninitialized for the destructor of `OnceCell` to work properly + *self.value_set.get_mut() = false; + Some(unsafe { self.value.with(|ptr| ptr::read(ptr).assume_init()) }) + } else { + None + } + } + + /// Takes ownership of the current value, leaving the cell uninitialized. + /// + /// Returns `None` if the cell is uninitialized. + pub fn take(&mut self) -> Option { + std::mem::take(self).into_inner() + } +} + +// Since `get` gives us access to immutable references of the +// OnceCell, OnceCell can only be Sync if T is Sync, otherwise +// OnceCell would allow sharing references of !Sync values across +// threads. We need T to be Send in order for OnceCell to by Sync +// because we can use `set` on `&OnceCell` to send +// values (of type T) across threads. +unsafe impl Sync for OnceCell {} + +// Access to OnceCell's value is guarded by the semaphore permit +// and atomic operations on `value_set`, so as long as T itself is Send +// it's safe to send it to another thread +unsafe impl Send for OnceCell {} + +/// Errors that can be returned from [`OnceCell::set`] +/// +/// [`OnceCell::set`]: crate::sync::OnceCell::set +#[derive(Debug, PartialEq)] +pub enum SetError { + /// Error resulting from [`OnceCell::set`] calls if the cell was previously initialized. + /// + /// [`OnceCell::set`]: crate::sync::OnceCell::set + AlreadyInitializedError(T), + + /// Error resulting from [`OnceCell::set`] calls when the cell is currently being + /// inintialized during the calls to that method. + /// + /// [`OnceCell::set`]: crate::sync::OnceCell::set + InitializingError(T), +} + +impl fmt::Display for SetError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SetError::AlreadyInitializedError(_) => write!(f, "AlreadyInitializedError"), + SetError::InitializingError(_) => write!(f, "InitializingError"), + } + } +} + +impl Error for SetError {} + +impl SetError { + /// Whether `SetError` is `SetError::AlreadyInitializedError`. + pub fn is_already_init_err(&self) -> bool { + match self { + SetError::AlreadyInitializedError(_) => true, + SetError::InitializingError(_) => false, + } + } + + /// Whether `SetError` is `SetError::InitializingError` + pub fn is_initializing_err(&self) -> bool { + match self { + SetError::AlreadyInitializedError(_) => false, + SetError::InitializingError(_) => true, + } + } +} diff --git a/tokio/tests/async_send_sync.rs b/tokio/tests/async_send_sync.rs index 671fa4a70..211c572cf 100644 --- a/tokio/tests/async_send_sync.rs +++ b/tokio/tests/async_send_sync.rs @@ -1,9 +1,12 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] +#![allow(clippy::type_complexity)] use std::cell::Cell; +use std::future::Future; use std::io::{Cursor, SeekFrom}; use std::net::SocketAddr; +use std::pin::Pin; use std::rc::Rc; use tokio::net::TcpStream; use tokio::time::{Duration, Instant}; @@ -265,6 +268,28 @@ async_assert_fn!(tokio::sync::watch::Sender::closed(_): Send & Sync); async_assert_fn!(tokio::sync::watch::Sender>::closed(_): !Send & !Sync); async_assert_fn!(tokio::sync::watch::Sender>::closed(_): !Send & !Sync); +async_assert_fn!(tokio::sync::OnceCell::get_or_init( + _, fn() -> Pin + Send + Sync>>): Send & Sync); +async_assert_fn!(tokio::sync::OnceCell::get_or_init( + _, fn() -> Pin + Send>>): Send & !Sync); +async_assert_fn!(tokio::sync::OnceCell::get_or_init( + _, fn() -> Pin>>): !Send & !Sync); +async_assert_fn!(tokio::sync::OnceCell>::get_or_init( + _, fn() -> Pin> + Send + Sync>>): !Send & !Sync); +async_assert_fn!(tokio::sync::OnceCell>::get_or_init( + _, fn() -> Pin> + Send>>): !Send & !Sync); +async_assert_fn!(tokio::sync::OnceCell>::get_or_init( + _, fn() -> Pin>>>): !Send & !Sync); +async_assert_fn!(tokio::sync::OnceCell>::get_or_init( + _, fn() -> Pin> + Send + Sync>>): !Send & !Sync); +async_assert_fn!(tokio::sync::OnceCell>::get_or_init( + _, fn() -> Pin> + Send>>): !Send & !Sync); +async_assert_fn!(tokio::sync::OnceCell>::get_or_init( + _, fn() -> Pin>>>): !Send & !Sync); +assert_value!(tokio::sync::OnceCell: Send & Sync); +assert_value!(tokio::sync::OnceCell>: Send & !Sync); +assert_value!(tokio::sync::OnceCell>: !Send & !Sync); + async_assert_fn!(tokio::task::LocalKey::scope(_, u32, BoxFutureSync<()>): Send & Sync); async_assert_fn!(tokio::task::LocalKey::scope(_, u32, BoxFutureSend<()>): Send & !Sync); async_assert_fn!(tokio::task::LocalKey::scope(_, u32, BoxFuture<()>): !Send & !Sync); diff --git a/tokio/tests/sync_once_cell.rs b/tokio/tests/sync_once_cell.rs new file mode 100644 index 000000000..60f50d214 --- /dev/null +++ b/tokio/tests/sync_once_cell.rs @@ -0,0 +1,268 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use std::mem; +use std::ops::Drop; +use std::sync::atomic::{AtomicU32, Ordering}; +use std::time::Duration; +use tokio::runtime; +use tokio::sync::{OnceCell, SetError}; +use tokio::time; + +async fn func1() -> u32 { + 5 +} + +async fn func2() -> u32 { + time::sleep(Duration::from_millis(1)).await; + 10 +} + +async fn func_err() -> Result { + Err(()) +} + +async fn func_ok() -> Result { + Ok(10) +} + +async fn func_panic() -> u32 { + time::sleep(Duration::from_millis(1)).await; + panic!(); +} + +async fn sleep_and_set() -> u32 { + // Simulate sleep by pausing time and waiting for another thread to + // resume clock when calling `set`, then finding the cell being initialized + // by this call + time::sleep(Duration::from_millis(2)).await; + 5 +} + +async fn advance_time_and_set(cell: &'static OnceCell, v: u32) -> Result<(), SetError> { + time::advance(Duration::from_millis(1)).await; + cell.set(v) +} + +#[test] +fn get_or_init() { + let rt = runtime::Builder::new_current_thread() + .enable_time() + .start_paused(true) + .build() + .unwrap(); + + static ONCE: OnceCell = OnceCell::const_new(); + + rt.block_on(async { + let handle1 = rt.spawn(async { ONCE.get_or_init(func1).await }); + let handle2 = rt.spawn(async { ONCE.get_or_init(func2).await }); + + time::advance(Duration::from_millis(1)).await; + time::resume(); + + let result1 = handle1.await.unwrap(); + let result2 = handle2.await.unwrap(); + + assert_eq!(*result1, 5); + assert_eq!(*result2, 5); + }); +} + +#[test] +fn get_or_init_panic() { + let rt = runtime::Builder::new_current_thread() + .enable_time() + .build() + .unwrap(); + + static ONCE: OnceCell = OnceCell::const_new(); + + rt.block_on(async { + time::pause(); + + let handle1 = rt.spawn(async { ONCE.get_or_init(func1).await }); + let handle2 = rt.spawn(async { ONCE.get_or_init(func_panic).await }); + + time::advance(Duration::from_millis(1)).await; + + let result1 = handle1.await.unwrap(); + let result2 = handle2.await.unwrap(); + + assert_eq!(*result1, 5); + assert_eq!(*result2, 5); + }); +} + +#[test] +fn set_and_get() { + let rt = runtime::Builder::new_current_thread() + .enable_time() + .build() + .unwrap(); + + static ONCE: OnceCell = OnceCell::const_new(); + + rt.block_on(async { + let _ = rt.spawn(async { ONCE.set(5) }).await; + let value = ONCE.get().unwrap(); + assert_eq!(*value, 5); + }); +} + +#[test] +fn get_uninit() { + static ONCE: OnceCell = OnceCell::const_new(); + let uninit = ONCE.get(); + assert!(uninit.is_none()); +} + +#[test] +fn set_twice() { + static ONCE: OnceCell = OnceCell::const_new(); + + let first = ONCE.set(5); + assert_eq!(first, Ok(())); + let second = ONCE.set(6); + assert!(second.err().unwrap().is_already_init_err()); +} + +#[test] +fn set_while_initializing() { + let rt = runtime::Builder::new_current_thread() + .enable_time() + .build() + .unwrap(); + + static ONCE: OnceCell = OnceCell::const_new(); + + rt.block_on(async { + time::pause(); + + let handle1 = rt.spawn(async { ONCE.get_or_init(sleep_and_set).await }); + let handle2 = rt.spawn(async { advance_time_and_set(&ONCE, 10).await }); + + time::advance(Duration::from_millis(2)).await; + + let result1 = handle1.await.unwrap(); + let result2 = handle2.await.unwrap(); + + assert_eq!(*result1, 5); + assert!(result2.err().unwrap().is_initializing_err()); + }); +} + +#[test] +fn get_or_try_init() { + let rt = runtime::Builder::new_current_thread() + .enable_time() + .start_paused(true) + .build() + .unwrap(); + + static ONCE: OnceCell = OnceCell::const_new(); + + rt.block_on(async { + let handle1 = rt.spawn(async { ONCE.get_or_try_init(func_err).await }); + let handle2 = rt.spawn(async { ONCE.get_or_try_init(func_ok).await }); + + time::advance(Duration::from_millis(1)).await; + time::resume(); + + let result1 = handle1.await.unwrap(); + assert!(result1.is_err()); + + let result2 = handle2.await.unwrap(); + assert_eq!(*result2.unwrap(), 10); + }); +} + +#[test] +fn drop_cell() { + static NUM_DROPS: AtomicU32 = AtomicU32::new(0); + + struct Foo {} + + let fooer = Foo {}; + + impl Drop for Foo { + fn drop(&mut self) { + NUM_DROPS.fetch_add(1, Ordering::Release); + } + } + + { + let once_cell = OnceCell::new(); + let prev = once_cell.set(fooer); + assert!(prev.is_ok()) + } + assert!(NUM_DROPS.load(Ordering::Acquire) == 1); +} + +#[test] +fn drop_cell_new_with() { + static NUM_DROPS: AtomicU32 = AtomicU32::new(0); + + struct Foo {} + + let fooer = Foo {}; + + impl Drop for Foo { + fn drop(&mut self) { + NUM_DROPS.fetch_add(1, Ordering::Release); + } + } + + { + let once_cell = OnceCell::new_with(Some(fooer)); + assert!(once_cell.initialized()); + } + assert!(NUM_DROPS.load(Ordering::Acquire) == 1); +} + +#[test] +fn drop_into_inner() { + static NUM_DROPS: AtomicU32 = AtomicU32::new(0); + + struct Foo {} + + let fooer = Foo {}; + + impl Drop for Foo { + fn drop(&mut self) { + NUM_DROPS.fetch_add(1, Ordering::Release); + } + } + + let once_cell = OnceCell::new(); + assert!(once_cell.set(fooer).is_ok()); + let fooer = once_cell.into_inner(); + let count = NUM_DROPS.load(Ordering::Acquire); + assert!(count == 0); + drop(fooer); + let count = NUM_DROPS.load(Ordering::Acquire); + assert!(count == 1); +} + +#[test] +fn drop_into_inner_new_with() { + static NUM_DROPS: AtomicU32 = AtomicU32::new(0); + + struct Foo {} + + let fooer = Foo {}; + + impl Drop for Foo { + fn drop(&mut self) { + NUM_DROPS.fetch_add(1, Ordering::Release); + } + } + + let once_cell = OnceCell::new_with(Some(fooer)); + let fooer = once_cell.into_inner(); + let count = NUM_DROPS.load(Ordering::Acquire); + assert!(count == 0); + mem::drop(fooer); + let count = NUM_DROPS.load(Ordering::Acquire); + assert!(count == 1); +}