diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index 05a58070e..abc4974a3 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -272,6 +272,9 @@ struct Tail { /// Number of active receivers rx_cnt: usize, + + /// True if the channel is closed + closed: bool, } /// Slot in the buffer @@ -319,7 +322,10 @@ struct RecvGuard<'a, T> { } /// Max number of receivers. Reserve space to lock. -const MAX_RECEIVERS: usize = usize::MAX >> 1; +const MAX_RECEIVERS: usize = usize::MAX >> 2; +const CLOSED: usize = 1; +const WRITER: usize = 2; +const READER: usize = 4; /// Create a bounded, multi-producer, multi-consumer channel where each sent /// value is broadcasted to all active receivers. @@ -389,7 +395,11 @@ pub fn channel(mut capacity: usize) -> (Sender, Receiver) { let shared = Arc::new(Shared { buffer: buffer.into_boxed_slice(), mask: capacity - 1, - tail: Mutex::new(Tail { pos: 0, rx_cnt: 1 }), + tail: Mutex::new(Tail { + pos: 0, + rx_cnt: 1, + closed: false, + }), condvar: Condvar::new(), wait_stack: AtomicPtr::new(ptr::null_mut()), num_tx: AtomicUsize::new(1), @@ -580,15 +590,15 @@ impl Sender { let slot = &self.shared.buffer[idx]; // Acquire the write lock - let mut prev = slot.lock.fetch_or(1, SeqCst); + let mut prev = slot.lock.fetch_or(WRITER, SeqCst); - while prev & !1 != 0 { + while prev & !WRITER != 0 { // Concurrent readers, we must go to sleep tail = self.shared.condvar.wait(tail).unwrap(); prev = slot.lock.load(SeqCst); - if prev & 1 == 0 { + if prev & WRITER == 0 { // The writer lock bit was cleared while this thread was // sleeping. This can only happen if a newer write happened on // this slot by another thread. Bail early as an optimization, @@ -604,13 +614,18 @@ impl Sender { // Slot lock acquired slot.write.pos.with_mut(|ptr| unsafe { *ptr = pos }); - slot.write.val.with_mut(|ptr| unsafe { *ptr = value }); // Set remaining receivers slot.rem.store(rem, SeqCst); - // Release the slot lock - slot.lock.store(0, SeqCst); + // Set the closed bit if the value is `None`; otherwise write the value + if value.is_none() { + tail.closed = true; + slot.lock.store(CLOSED, SeqCst); + } else { + slot.write.val.with_mut(|ptr| unsafe { *ptr = value }); + slot.lock.store(0, SeqCst); + } // Release the mutex. This must happen after the slot lock is released, // otherwise the writer lock bit could be cleared while another thread @@ -688,28 +703,52 @@ impl Receiver { if guard.pos() != self.next { let pos = guard.pos(); - guard.drop_no_rem_dec(); - + // The receiver has read all current values in the channel if pos.wrapping_add(self.shared.buffer.len() as u64) == self.next { + guard.drop_no_rem_dec(); return Err(TryRecvError::Empty); - } else { - let tail = self.shared.tail.lock().unwrap(); - - // `tail.pos` points to the slot the **next** send writes to. - // Because a receiver is lagging, this slot also holds the - // oldest value. To make the positions match, we subtract the - // capacity. - let next = tail.pos.wrapping_sub(self.shared.buffer.len() as u64); - let missed = next.wrapping_sub(self.next); - - self.next = next; - - return Err(TryRecvError::Lagged(missed)); } + + let tail = self.shared.tail.lock().unwrap(); + + // `tail.pos` points to the slot that the **next** send writes to. If + // the channel is closed, the previous slot is the oldest value. + let mut adjust = 0; + if tail.closed { + adjust = 1 + } + let next = tail + .pos + .wrapping_sub(self.shared.buffer.len() as u64 + adjust); + + let missed = next.wrapping_sub(self.next); + + drop(tail); + + // The receiver is slow but no values have been missed + if missed == 0 { + self.next = self.next.wrapping_add(1); + return Ok(guard); + } + + guard.drop_no_rem_dec(); + self.next = next; + + return Err(TryRecvError::Lagged(missed)); } self.next = self.next.wrapping_add(1); + // If the `CLOSED` bit it set on the slot, the channel is closed + // + // `try_rx_lock` could check for this and bail early. If it's return + // value was changed to represent the state of the lock, it could + // match on being closed, empty, or available for reading. + if slot.lock.load(SeqCst) & CLOSED == CLOSED { + guard.drop_no_rem_dec(); + return Err(TryRecvError::Closed); + } + Ok(guard) } } @@ -909,7 +948,6 @@ impl Drop for Receiver { while self.next != until { match self.recv_ref(true) { - // Ignore the value Ok(_) => {} // The channel is closed Err(TryRecvError::Closed) => break, @@ -954,13 +992,15 @@ impl Slot { let mut curr = self.lock.load(SeqCst); loop { - if curr & 1 == 1 { + if curr & WRITER == WRITER { // Locked by sender return false; } - // Only increment (by 2) if the LSB "lock" bit is not set. - let res = self.lock.compare_exchange(curr, curr + 2, SeqCst, SeqCst); + // Only increment (by `READER`) if the `WRITER` bit is not set. + let res = self + .lock + .compare_exchange(curr, curr + READER, SeqCst, SeqCst); match res { Ok(_) => return true, @@ -978,7 +1018,7 @@ impl Slot { } } - if 1 == self.lock.fetch_sub(2, SeqCst) - 2 { + if WRITER == self.lock.fetch_sub(READER, SeqCst) - READER { // First acquire the lock to make sure our sender is waiting on the // condition variable, otherwise the notification could be lost. mem::drop(tail.lock().unwrap()); diff --git a/tokio/tests/sync_broadcast.rs b/tokio/tests/sync_broadcast.rs index e9e7b3661..4fb7c0aa7 100644 --- a/tokio/tests/sync_broadcast.rs +++ b/tokio/tests/sync_broadcast.rs @@ -40,6 +40,15 @@ macro_rules! assert_lagged { }; } +macro_rules! assert_closed { + ($e:expr) => { + match assert_err!($e) { + broadcast::TryRecvError::Closed => {} + _ => panic!("did not lag"), + } + }; +} + trait AssertSend: Send {} impl AssertSend for broadcast::Sender {} impl AssertSend for broadcast::Receiver {} @@ -229,7 +238,8 @@ fn lagging_rx() { assert_ok!(tx.send("three")); // Lagged too far - assert_lagged!(rx2.try_recv(), 1); + let x = dbg!(rx2.try_recv()); + assert_lagged!(x, 1); // Calling again gets the next value assert_eq!("two", assert_recv!(rx2)); @@ -349,6 +359,98 @@ fn unconsumed_messages_are_dropped() { assert_eq!(1, Arc::strong_count(&msg)); } +#[test] +fn single_capacity_recvs() { + let (tx, mut rx) = broadcast::channel(1); + + assert_ok!(tx.send(1)); + + assert_eq!(assert_recv!(rx), 1); + assert_empty!(rx); +} + +#[test] +fn single_capacity_recvs_after_drop_1() { + let (tx, mut rx) = broadcast::channel(1); + + assert_ok!(tx.send(1)); + drop(tx); + + assert_eq!(assert_recv!(rx), 1); + assert_closed!(rx.try_recv()); +} + +#[test] +fn single_capacity_recvs_after_drop_2() { + let (tx, mut rx) = broadcast::channel(1); + + assert_ok!(tx.send(1)); + assert_ok!(tx.send(2)); + drop(tx); + + assert_lagged!(rx.try_recv(), 1); + assert_eq!(assert_recv!(rx), 2); + assert_closed!(rx.try_recv()); +} + +#[test] +fn dropping_sender_does_not_overwrite() { + let (tx, mut rx) = broadcast::channel(2); + + assert_ok!(tx.send(1)); + assert_ok!(tx.send(2)); + drop(tx); + + assert_eq!(assert_recv!(rx), 1); + assert_eq!(assert_recv!(rx), 2); + assert_closed!(rx.try_recv()); +} + +#[test] +fn lagging_receiver_recovers_after_wrap_closed_1() { + let (tx, mut rx) = broadcast::channel(2); + + assert_ok!(tx.send(1)); + assert_ok!(tx.send(2)); + assert_ok!(tx.send(3)); + drop(tx); + + assert_lagged!(rx.try_recv(), 1); + assert_eq!(assert_recv!(rx), 2); + assert_eq!(assert_recv!(rx), 3); + assert_closed!(rx.try_recv()); +} + +#[test] +fn lagging_receiver_recovers_after_wrap_closed_2() { + let (tx, mut rx) = broadcast::channel(2); + + assert_ok!(tx.send(1)); + assert_ok!(tx.send(2)); + assert_ok!(tx.send(3)); + assert_ok!(tx.send(4)); + drop(tx); + + assert_lagged!(rx.try_recv(), 2); + assert_eq!(assert_recv!(rx), 3); + assert_eq!(assert_recv!(rx), 4); + assert_closed!(rx.try_recv()); +} + +#[test] +fn lagging_receiver_recovers_after_wrap_open() { + let (tx, mut rx) = broadcast::channel(2); + + assert_ok!(tx.send(1)); + assert_ok!(tx.send(2)); + assert_ok!(tx.send(3)); + + assert_lagged!(rx.try_recv(), 1); + assert_eq!(assert_recv!(rx), 2); + assert_eq!(assert_recv!(rx), 3); + assert_empty!(rx); +} + fn is_closed(err: broadcast::RecvError) -> bool { match err { broadcast::RecvError::Closed => true,