From 067ddff0635fe3b7edc9831167f7804de9891cf8 Mon Sep 17 00:00:00 2001 From: Nylonicious <50183564+nylonicious@users.noreply.github.com> Date: Tue, 22 Feb 2022 21:19:21 +0100 Subject: [PATCH] sync: add `watch::Sender::send_modify` method (#4310) --- tokio/src/sync/watch.rs | 81 +++++++++++++++++++++++++++++---------- tokio/tests/sync_watch.rs | 28 ++++++++++++++ 2 files changed, 88 insertions(+), 21 deletions(-) diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index 5673e0fca..ab0ada7d0 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -60,6 +60,7 @@ use crate::loom::sync::atomic::Ordering::Relaxed; use crate::loom::sync::{Arc, RwLock, RwLockReadGuard}; use std::mem; use std::ops; +use std::panic; /// Receives values from the associated [`Sender`](struct@Sender). /// @@ -530,6 +531,61 @@ impl Sender { Ok(()) } + /// Modifies watched value, notifying all receivers. + /// + /// This can useful for modifying the watched value, without + /// having to allocate a new instance. Additionally, this + /// method permits sending values even when there are no receivers. + /// + /// # Panics + /// + /// This function panics if calling `func` results in a panic. + /// No receivers are notified if panic occurred, but if the closure has modified + /// the value, that change is still visible to future calls to `borrow`. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// struct State { + /// counter: usize, + /// } + /// let (state_tx, state_rx) = watch::channel(State { counter: 0 }); + /// state_tx.send_modify(|state| state.counter += 1); + /// assert_eq!(state_rx.borrow().counter, 1); + /// ``` + pub fn send_modify(&self, func: F) + where + F: FnOnce(&mut T), + { + { + // Acquire the write lock and update the value. + let mut lock = self.shared.value.write().unwrap(); + // Update the value and catch possible panic inside func. + let result = panic::catch_unwind(panic::AssertUnwindSafe(|| { + func(&mut lock); + })); + // If the func panicked return the panic to the caller. + if let Err(error) = result { + // Drop the lock to avoid poisoning it. + drop(lock); + panic::resume_unwind(error); + } + + self.shared.state.increment_version(); + + // Release the write lock. + // + // Incrementing the version counter while holding the lock ensures + // that receivers are able to figure out the version number of the + // value they are currently looking at. + drop(lock); + } + + self.shared.notify_rx.notify_waiters(); + } + /// Sends a new value via the channel, notifying all receivers and returning /// the previous value in the channel. /// @@ -546,28 +602,11 @@ impl Sender { /// assert_eq!(tx.send_replace(2), 1); /// assert_eq!(tx.send_replace(3), 2); /// ``` - pub fn send_replace(&self, value: T) -> T { - let old = { - // Acquire the write lock and update the value. - let mut lock = self.shared.value.write().unwrap(); - let old = mem::replace(&mut *lock, value); + pub fn send_replace(&self, mut value: T) -> T { + // swap old watched value with the new one + self.send_modify(|old| mem::swap(old, &mut value)); - self.shared.state.increment_version(); - - // Release the write lock. - // - // Incrementing the version counter while holding the lock ensures - // that receivers are able to figure out the version number of the - // value they are currently looking at. - drop(lock); - - old - }; - - // Notify all watchers - self.shared.notify_rx.notify_waiters(); - - old + value } /// Returns a reference to the most recently sent value diff --git a/tokio/tests/sync_watch.rs b/tokio/tests/sync_watch.rs index 8b9ea81bb..2097b8bdf 100644 --- a/tokio/tests/sync_watch.rs +++ b/tokio/tests/sync_watch.rs @@ -211,3 +211,31 @@ fn reopened_after_subscribe() { drop(rx); assert!(tx.is_closed()); } + +#[test] +fn send_modify_panic() { + let (tx, mut rx) = watch::channel("one"); + + tx.send_modify(|old| *old = "two"); + assert_eq!(*rx.borrow_and_update(), "two"); + + let mut rx2 = rx.clone(); + assert_eq!(*rx2.borrow_and_update(), "two"); + + let mut task = spawn(rx2.changed()); + + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + tx.send_modify(|old| { + *old = "panicked"; + panic!(); + }) + })); + assert!(result.is_err()); + + assert_pending!(task.poll()); + assert_eq!(*rx.borrow(), "panicked"); + + tx.send_modify(|old| *old = "three"); + assert_ready_ok!(task.poll()); + assert_eq!(*rx.borrow_and_update(), "three"); +}