sync: add broadcast::Sender::closed (#6685)

This commit is contained in:
Evan Rittenhouse 2025-01-09 09:37:49 -06:00 committed by GitHub
parent 5f3296df77
commit 5c8cd33820
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 115 additions and 2 deletions

View File

@ -1,7 +1,7 @@
use std::sync::{self, MutexGuard, TryLockError};
/// Adapter for `std::Mutex` that removes the poisoning aspects
/// from its api.
/// from its API.
#[derive(Debug)]
pub(crate) struct Mutex<T: ?Sized>(sync::Mutex<T>);

View File

@ -301,6 +301,8 @@ pub mod error {
use self::error::{RecvError, SendError, TryRecvError};
use super::Notify;
/// Data shared between senders and receivers.
struct Shared<T> {
/// slots in the channel.
@ -314,6 +316,9 @@ struct Shared<T> {
/// Number of outstanding Sender handles.
num_tx: AtomicUsize,
/// Notify when the last subscribed [`Receiver`] drops.
notify_last_rx_drop: Notify,
}
/// Next position to write a value.
@ -528,6 +533,7 @@ impl<T> Sender<T> {
waiters: LinkedList::new(),
}),
num_tx: AtomicUsize::new(1),
notify_last_rx_drop: Notify::new(),
});
Sender { shared }
@ -805,6 +811,50 @@ impl<T> Sender<T> {
Arc::ptr_eq(&self.shared, &other.shared)
}
/// A future which completes when the number of [Receiver]s subscribed to this `Sender` reaches
/// zero.
///
/// # Examples
///
/// ```
/// use futures::FutureExt;
/// use tokio::sync::broadcast;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx1) = broadcast::channel::<u32>(16);
/// let mut rx2 = tx.subscribe();
///
/// tokio::spawn(async move {
/// assert_eq!(rx1.recv().await.unwrap(), 10);
/// });
///
/// let _ = tx.send(10);
/// assert!(tx.closed().now_or_never().is_none());
///
/// let _ = tokio::spawn(async move {
/// assert_eq!(rx2.recv().await.unwrap(), 10);
/// }).await;
///
/// assert!(tx.closed().now_or_never().is_some());
/// }
/// ```
pub async fn closed(&self) {
loop {
let notified = self.shared.notify_last_rx_drop.notified();
{
// Ensure the lock drops if the channel isn't closed
let tail = self.shared.tail.lock();
if tail.closed {
return;
}
}
notified.await;
}
}
fn close_channel(&self) {
let mut tail = self.shared.tail.lock();
tail.closed = true;
@ -819,8 +869,14 @@ fn new_receiver<T>(shared: Arc<Shared<T>>) -> Receiver<T> {
assert!(tail.rx_cnt != MAX_RECEIVERS, "max receivers");
tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow");
if tail.rx_cnt == 0 {
// Potentially need to re-open the channel, if a new receiver has been added between calls
// to poll(). Note that we use rx_cnt == 0 instead of is_closed since is_closed also
// applies if the sender has been dropped
tail.closed = false;
}
tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow");
let next = tail.pos;
drop(tail);
@ -1346,6 +1402,12 @@ impl<T> Drop for Receiver<T> {
tail.rx_cnt -= 1;
let until = tail.pos;
let remaining_rx = tail.rx_cnt;
if remaining_rx == 0 {
self.shared.notify_last_rx_drop.notify_waiters();
tail.closed = true;
}
drop(tail);

View File

@ -656,3 +656,54 @@ async fn receiver_recv_is_cooperative() {
_ = tokio::task::yield_now() => {},
}
}
#[test]
fn broadcast_sender_closed() {
let (tx, rx) = broadcast::channel::<()>(1);
let rx2 = tx.subscribe();
let mut task = task::spawn(tx.closed());
assert_pending!(task.poll());
drop(rx);
assert!(!task.is_woken());
assert_pending!(task.poll());
drop(rx2);
assert!(task.is_woken());
assert_ready!(task.poll());
}
#[test]
fn broadcast_sender_closed_with_extra_subscribe() {
let (tx, rx) = broadcast::channel::<()>(1);
let rx2 = tx.subscribe();
let mut task = task::spawn(tx.closed());
assert_pending!(task.poll());
drop(rx);
assert!(!task.is_woken());
assert_pending!(task.poll());
drop(rx2);
assert!(task.is_woken());
let rx3 = tx.subscribe();
assert_pending!(task.poll());
drop(rx3);
assert!(task.is_woken());
assert_ready!(task.poll());
let mut task2 = task::spawn(tx.closed());
assert_ready!(task2.poll());
let rx4 = tx.subscribe();
let mut task3 = task::spawn(tx.closed());
assert_pending!(task3.poll());
drop(rx4);
assert!(task3.is_woken());
assert_ready!(task3.poll());
}