diff --git a/tokio/src/sync/mod.rs b/tokio/src/sync/mod.rs index 263ec62c9..359b14f5e 100644 --- a/tokio/src/sync/mod.rs +++ b/tokio/src/sync/mod.rs @@ -428,7 +428,7 @@ cfg_sync! { pub mod mpsc; mod mutex; - pub use mutex::{Mutex, MutexGuard, TryLockError}; + pub use mutex::{Mutex, MutexGuard, TryLockError, OwnedMutexGuard}; mod notify; pub use notify::Notify; diff --git a/tokio/src/sync/mutex.rs b/tokio/src/sync/mutex.rs index 69eec678c..e0618a5d6 100644 --- a/tokio/src/sync/mutex.rs +++ b/tokio/src/sync/mutex.rs @@ -5,6 +5,7 @@ use std::cell::UnsafeCell; use std::error::Error; use std::fmt; use std::ops::{Deref, DerefMut}; +use std::sync::Arc; /// An asynchronous `Mutex`-like type. /// @@ -18,8 +19,8 @@ use std::ops::{Deref, DerefMut}; /// 1. The lock does not need to be held across await points. /// 2. The duration of any single lock is near-instant. /// -/// On the other hand, the Tokio mutex is for the situation where the lock needs -/// to be held for longer periods of time, or across await points. +/// On the other hand, the Tokio mutex is for the situation where the lock +/// needs to be held for longer periods of time, or across await points. /// /// # Examples: /// @@ -71,18 +72,20 @@ use std::ops::{Deref, DerefMut}; /// } /// ``` /// There are a few things of note here to pay attention to in this example. -/// 1. The mutex is wrapped in an [`Arc`] to allow it to be shared across threads. +/// 1. The mutex is wrapped in an [`Arc`] to allow it to be shared across +/// threads. /// 2. Each spawned task obtains a lock and releases it on every iteration. -/// 3. Mutation of the data protected by the Mutex is done by de-referencing the obtained lock -/// as seen on lines 12 and 19. +/// 3. Mutation of the data protected by the Mutex is done by de-referencing +/// the obtained lock as seen on lines 12 and 19. /// -/// Tokio's Mutex works in a simple FIFO (first in, first out) style where all calls -/// to [`lock`] complete in the order they were performed. In that way -/// the Mutex is "fair" and predictable in how it distributes the locks to inner data. This is why -/// the output of the program above is an in-order count to 50. Locks are released and reacquired -/// after every iteration, so basically, each thread goes to the back of the line after it increments -/// the value once. Finally, since there is only a single valid lock at any given time, there is no -/// possibility of a race condition when mutating the inner value. +/// Tokio's Mutex works in a simple FIFO (first in, first out) style where all +/// calls to [`lock`] complete in the order they were performed. In that way the +/// Mutex is "fair" and predictable in how it distributes the locks to inner +/// data. This is why the output of the program above is an in-order count to +/// 50. Locks are released and reacquired after every iteration, so basically, +/// each thread goes to the back of the line after it increments the value once. +/// Finally, since there is only a single valid lock at any given time, there is +/// no possibility of a race condition when mutating the inner value. /// /// Note that in contrast to [`std::sync::Mutex`], this implementation does not /// poison the mutex when a thread holding the [`MutexGuard`] panics. In such a @@ -104,22 +107,42 @@ pub struct Mutex { /// A handle to a held `Mutex`. /// -/// As long as you have this guard, you have exclusive access to the underlying `T`. The guard -/// internally keeps a reference-couned pointer to the original `Mutex`, so even if the lock goes -/// away, the guard remains valid. +/// As long as you have this guard, you have exclusive access to the underlying +/// `T`. The guard internally borrows the `Mutex`, so the mutex will not be +/// dropped while a guard exists. /// -/// The lock is automatically released whenever the guard is dropped, at which point `lock` -/// will succeed yet again. +/// The lock is automatically released whenever the guard is dropped, at which +/// point `lock` will succeed yet again. pub struct MutexGuard<'a, T> { lock: &'a Mutex, } +/// An owned handle to a held `Mutex`. +/// +/// This guard is only available from a `Mutex` that is wrapped in an [`Arc`]. It +/// is identical to `MutexGuard`, except that rather than borrowing the `Mutex`, +/// it clones the `Arc`, incrementing the reference count. This means that +/// unlike `MutexGuard`, it will have the `'static` lifetime. +/// +/// As long as you have this guard, you have exclusive access to the underlying +/// `T`. The guard internally keeps a reference-couned pointer to the original +/// `Mutex`, so even if the lock goes away, the guard remains valid. +/// +/// The lock is automatically released whenever the guard is dropped, at which +/// point `lock` will succeed yet again. +/// +/// [`Arc`]: std::sync::Arc +pub struct OwnedMutexGuard { + lock: Arc>, +} + // As long as T: Send, it's fine to send and share Mutex between threads. -// If T was not Send, sending and sharing a Mutex would be bad, since you can access T through -// Mutex. +// If T was not Send, sending and sharing a Mutex would be bad, since you can +// access T through Mutex. unsafe impl Send for Mutex where T: Send {} unsafe impl Sync for Mutex where T: Send {} unsafe impl<'a, T> Sync for MutexGuard<'a, T> where T: Send + Sync {} +unsafe impl Sync for OwnedMutexGuard where T: Send + Sync {} /// Error returned from the [`Mutex::try_lock`] function. /// @@ -145,12 +168,20 @@ fn bounds() { // This has to take a value, since the async fn's return type is unnameable. fn check_send_sync_val(_t: T) {} fn check_send_sync() {} + fn check_static() {} + fn check_static_val(_t: T) {} + check_send::>(); + check_send::>(); check_unpin::>(); check_send_sync::>(); + check_static::>(); let mutex = Mutex::new(1); check_send_sync_val(mutex.lock()); + let arc_mutex = Arc::new(Mutex::new(1)); + check_send_sync_val(arc_mutex.clone().lock_owned()); + check_static_val(arc_mutex.lock_owned()); } impl Mutex { @@ -188,12 +219,47 @@ impl Mutex { /// } /// ``` pub async fn lock(&self) -> MutexGuard<'_, T> { + self.acquire().await; + MutexGuard { lock: self } + } + + /// Locks this mutex, causing the current task to yield until the lock has + /// been acquired. When the lock has been acquired, this returns an + /// [`OwnedMutexGuard`]. + /// + /// This method is identical to [`Mutex::lock`], except that the returned + /// guard references the `Mutex` with an [`Arc`] rather than by borrowing + /// it. Therefore, the `Mutex` must be wrapped in an `Arc` to call this + /// method, and the guard will live for the `'static` lifetime, as it keeps + /// the `Mutex` alive by holding an `Arc`. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// use std::sync::Arc; + /// + /// #[tokio::main] + /// async fn main() { + /// let mutex = Arc::new(Mutex::new(1)); + /// + /// let mut n = mutex.clone().lock_owned().await; + /// *n = 2; + /// } + /// ``` + /// + /// [`Arc`]: std::sync::Arc + pub async fn lock_owned(self: Arc) -> OwnedMutexGuard { + self.acquire().await; + OwnedMutexGuard { lock: self } + } + + async fn acquire(&self) { self.s.acquire(1).cooperate().await.unwrap_or_else(|_| { - // The semaphore was closed. but, we never explicitly close it, and we have a - // handle to it through the Arc, which means that this can never happen. + // The semaphore was closed. but, we never explicitly close it, and + // we own it exclusively, which means that this can never happen. unreachable!() }); - MutexGuard { lock: self } } /// Attempts to acquire the lock, and returns [`TryLockError`] if the @@ -220,6 +286,37 @@ impl Mutex { } } + /// Attempts to acquire the lock, and returns [`TryLockError`] if the lock + /// is currently held somewhere else. + /// + /// This method is identical to [`Mutex::try_lock`], except that the + /// returned guard references the `Mutex` with an [`Arc`] rather than by + /// borrowing it. Therefore, the `Mutex` must be wrapped in an `Arc` to call + /// this method, and the guard will live for the `'static` lifetime, as it + /// keeps the `Mutex` alive by holding an `Arc`. + /// + /// [`TryLockError`]: TryLockError + /// [`Arc`]: std::sync::Arc + /// # Examples + /// + /// ``` + /// use tokio::sync::Mutex; + /// use std::sync::Arc; + /// # async fn dox() -> Result<(), tokio::sync::TryLockError> { + /// + /// let mutex = Arc::new(Mutex::new(1)); + /// + /// let n = mutex.clone().try_lock_owned()?; + /// assert_eq!(*n, 1); + /// # Ok(()) + /// # } + pub fn try_lock_owned(self: Arc) -> Result, TryLockError> { + match self.s.try_acquire(1) { + Ok(_) => Ok(OwnedMutexGuard { lock: self }), + Err(_) => Err(TryLockError(())), + } + } + /// Consumes the mutex, returning the underlying data. /// # Examples /// @@ -239,12 +336,6 @@ impl Mutex { } } -impl<'a, T> Drop for MutexGuard<'a, T> { - fn drop(&mut self) { - self.lock.s.release(1) - } -} - impl From for Mutex { fn from(s: T) -> Self { Self::new(s) @@ -260,6 +351,14 @@ where } } +// === impl MutexGuard === + +impl<'a, T> Drop for MutexGuard<'a, T> { + fn drop(&mut self) { + self.lock.s.release(1) + } +} + impl<'a, T> Deref for MutexGuard<'a, T> { type Target = T; fn deref(&self) -> &Self::Target { @@ -284,3 +383,36 @@ impl<'a, T: fmt::Display> fmt::Display for MutexGuard<'a, T> { fmt::Display::fmt(&**self, f) } } + +// === impl OwnedMutexGuard === + +impl Drop for OwnedMutexGuard { + fn drop(&mut self) { + self.lock.s.release(1) + } +} + +impl Deref for OwnedMutexGuard { + type Target = T; + fn deref(&self) -> &Self::Target { + unsafe { &*self.lock.c.get() } + } +} + +impl DerefMut for OwnedMutexGuard { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.lock.c.get() } + } +} + +impl fmt::Debug for OwnedMutexGuard { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl fmt::Display for OwnedMutexGuard { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} diff --git a/tokio/tests/async_send_sync.rs b/tokio/tests/async_send_sync.rs index 1fea19c2a..45d11bd44 100644 --- a/tokio/tests/async_send_sync.rs +++ b/tokio/tests/async_send_sync.rs @@ -203,6 +203,9 @@ async_assert_fn!(tokio::sync::Barrier::wait(_): Send & Sync); async_assert_fn!(tokio::sync::Mutex::lock(_): Send & Sync); async_assert_fn!(tokio::sync::Mutex>::lock(_): Send & Sync); async_assert_fn!(tokio::sync::Mutex>::lock(_): !Send & !Sync); +async_assert_fn!(tokio::sync::Mutex::lock_owned(_): Send & Sync); +async_assert_fn!(tokio::sync::Mutex>::lock_owned(_): Send & Sync); +async_assert_fn!(tokio::sync::Mutex>::lock_owned(_): !Send & !Sync); async_assert_fn!(tokio::sync::Notify::notified(_): Send & !Sync); async_assert_fn!(tokio::sync::RwLock::read(_): Send & Sync); async_assert_fn!(tokio::sync::RwLock::write(_): Send & Sync); diff --git a/tokio/tests/sync_mutex_owned.rs b/tokio/tests/sync_mutex_owned.rs new file mode 100644 index 000000000..eef966fd4 --- /dev/null +++ b/tokio/tests/sync_mutex_owned.rs @@ -0,0 +1,121 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use tokio::sync::Mutex; +use tokio::time::{interval, timeout}; +use tokio_test::task::spawn; +use tokio_test::{assert_pending, assert_ready}; + +use std::sync::Arc; +use std::time::Duration; + +#[test] +fn straight_execution() { + let l = Arc::new(Mutex::new(100)); + + { + let mut t = spawn(l.clone().lock_owned()); + let mut g = assert_ready!(t.poll()); + assert_eq!(&*g, &100); + *g = 99; + } + { + let mut t = spawn(l.clone().lock_owned()); + let mut g = assert_ready!(t.poll()); + assert_eq!(&*g, &99); + *g = 98; + } + { + let mut t = spawn(l.lock_owned()); + let g = assert_ready!(t.poll()); + assert_eq!(&*g, &98); + } +} + +#[test] +fn readiness() { + let l = Arc::new(Mutex::new(100)); + let mut t1 = spawn(l.clone().lock_owned()); + let mut t2 = spawn(l.clone().lock_owned()); + + let g = assert_ready!(t1.poll()); + + // We can't now acquire the lease since it's already held in g + assert_pending!(t2.poll()); + + // But once g unlocks, we can acquire it + drop(g); + assert!(t2.is_woken()); + assert_ready!(t2.poll()); +} + +#[tokio::test] +/// Ensure a mutex is unlocked if a future holding the lock +/// is aborted prematurely. +async fn aborted_future_1() { + let m1: Arc> = Arc::new(Mutex::new(0)); + { + let m2 = m1.clone(); + // Try to lock mutex in a future that is aborted prematurely + timeout(Duration::from_millis(1u64), async move { + let mut iv = interval(Duration::from_millis(1000)); + m2.lock_owned().await; + iv.tick().await; + iv.tick().await; + }) + .await + .unwrap_err(); + } + // This should succeed as there is no lock left for the mutex. + timeout(Duration::from_millis(1u64), async move { + m1.lock_owned().await; + }) + .await + .expect("Mutex is locked"); +} + +#[tokio::test] +/// This test is similar to `aborted_future_1` but this time the +/// aborted future is waiting for the lock. +async fn aborted_future_2() { + let m1: Arc> = Arc::new(Mutex::new(0)); + { + // Lock mutex + let _lock = m1.clone().lock_owned().await; + { + let m2 = m1.clone(); + // Try to lock mutex in a future that is aborted prematurely + timeout(Duration::from_millis(1u64), async move { + m2.lock_owned().await; + }) + .await + .unwrap_err(); + } + } + // This should succeed as there is no lock left for the mutex. + timeout(Duration::from_millis(1u64), async move { + m1.lock_owned().await; + }) + .await + .expect("Mutex is locked"); +} + +#[test] +fn try_lock_owned() { + let m: Arc> = Arc::new(Mutex::new(0)); + { + let g1 = m.clone().try_lock_owned(); + assert_eq!(g1.is_ok(), true); + let g2 = m.clone().try_lock_owned(); + assert_eq!(g2.is_ok(), false); + } + let g3 = m.try_lock_owned(); + assert_eq!(g3.is_ok(), true); +} + +#[tokio::test] +async fn debug_format() { + let s = "debug"; + let m = Arc::new(Mutex::new(s.to_string())); + assert_eq!(format!("{:?}", s), format!("{:?}", m.lock_owned().await)); +}