Mirror of https://github.com/tokio-rs/tokio.git (synced 2025-10-01 12:20:39 +00:00)
sync: use intrusive list strategy for broadcast (#2509)
Previously, in the broadcast channel, receiver wakers were passed to the sender via an atomic stack with allocated nodes. When a message was sent, the stack was drained. This caused a problem when many receivers pushed a waiter node and then dropped: if no values were ever sent, the waiter nodes remained allocated indefinitely. This patch switches broadcast to the intrusive linked-list waiter strategy already used by `Notify` and `Semaphore`.
parent a32f918671
commit fb7dfcf432
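For context, the sketch below illustrates the waiter-registration discipline the commit message describes; it is not the tokio implementation. The real patch stores a `Waiter` inline in each receiver and links it into tokio's internal `linked_list::LinkedList` (see the diff below); this sketch swaps the intrusive list for a plain `HashMap` keyed by a waiter id so it stays safe and self-contained. The names here (`Channel`, `Changed`, `Shared`) are hypothetical. The property being shown is the one the patch fixes: a waiter registers itself under the channel lock and unregisters itself in `Drop`, so a receiver that stops waiting cannot leave a stale node behind.

use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};

#[derive(Default)]
struct Shared {
    // Monotonically increasing id source for waiters.
    next_id: u64,
    // "Version" of the channel; bumped on every send.
    version: u64,
    // Registered waiters, keyed by waiter id. The real code keeps an
    // intrusive linked list of `Waiter` nodes here instead.
    waiters: HashMap<u64, Waker>,
}

#[derive(Clone, Default)]
struct Channel {
    shared: Arc<Mutex<Shared>>,
}

impl Channel {
    // Publish a new value: bump the version and wake every registered waiter,
    // draining the wait list (this mirrors the drain-and-wake loop in the patch).
    fn send(&self) {
        let mut shared = self.shared.lock().unwrap();
        shared.version += 1;
        for (_, waker) in shared.waiters.drain() {
            waker.wake();
        }
    }

    // Future that resolves once the version moves past `seen`.
    fn changed(&self, seen: u64) -> Changed {
        Changed {
            chan: self.clone(),
            seen,
            id: None,
        }
    }
}

struct Changed {
    chan: Channel,
    seen: u64,
    // Id of our registered waiter, if any (plays the role of `Waiter::queued`).
    id: Option<u64>,
}

impl Future for Changed {
    type Output = u64;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<u64> {
        // Clone the Arc so the lock guard does not keep `self` borrowed.
        let handle = self.chan.shared.clone();
        let mut shared = handle.lock().unwrap();

        if shared.version != self.seen {
            return Poll::Ready(shared.version);
        }

        // Register (or refresh) our waker while holding the lock.
        if self.id.is_none() {
            shared.next_id += 1;
            self.id = Some(shared.next_id);
        }
        shared.waiters.insert(self.id.unwrap(), cx.waker().clone());
        Poll::Pending
    }
}

impl Drop for Changed {
    fn drop(&mut self) {
        // The crux of the fix: a waiter dropped before a value is sent removes
        // its own entry, so abandoned receivers cannot accumulate forever.
        if let Some(id) = self.id {
            self.chan.shared.lock().unwrap().waiters.remove(&id);
        }
    }
}

In the actual diff, the same roles are played by `Waiter::queued` and `Waiter::waker` (registration state), `Tail::notify_rx` (the drain-and-wake loop run during `send`), and the `Drop` implementations of `Receiver` and `Recv`, which remove the waiter from `tail.waiters` while holding the tail lock.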
@@ -11,10 +11,6 @@ impl<T> AtomicPtr<T> {
         let inner = std::sync::atomic::AtomicPtr::new(ptr);
         AtomicPtr { inner }
     }
-
-    pub(crate) fn with_mut<R>(&mut self, f: impl FnOnce(&mut *mut T) -> R) -> R {
-        f(self.inner.get_mut())
-    }
 }
 
 impl<T> Deref for AtomicPtr<T> {
@@ -109,12 +109,15 @@
 //! }
 
 use crate::loom::cell::UnsafeCell;
-use crate::loom::future::AtomicWaker;
-use crate::loom::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize};
+use crate::loom::sync::atomic::AtomicUsize;
 use crate::loom::sync::{Arc, Mutex, RwLock, RwLockReadGuard};
+use crate::util::linked_list::{self, LinkedList};
 
 use std::fmt;
-use std::ptr;
+use std::future::Future;
+use std::marker::PhantomPinned;
+use std::pin::Pin;
+use std::ptr::NonNull;
 use std::sync::atomic::Ordering::SeqCst;
 use std::task::{Context, Poll, Waker};
 use std::usize;
@@ -192,8 +195,8 @@ pub struct Receiver<T> {
     /// Next position to read from
     next: u64,
 
-    /// Waiter state
-    wait: Arc<WaitNode>,
+    /// Used to support the deprecated `poll_recv` fn
+    waiter: Option<Pin<Box<UnsafeCell<Waiter>>>>,
 }
 
 /// Error returned by [`Sender::send`][Sender::send].
@@ -251,12 +254,9 @@ struct Shared<T> {
     /// Mask a position -> index
     mask: usize,
 
-    /// Tail of the queue
+    /// Tail of the queue. Includes the rx wait list.
     tail: Mutex<Tail>,
 
-    /// Stack of pending waiters
-    wait_stack: AtomicPtr<WaitNode>,
-
     /// Number of outstanding Sender handles
     num_tx: AtomicUsize,
 }
@@ -271,6 +271,9 @@ struct Tail {
 
     /// True if the channel is closed
    closed: bool,
+
+    /// Receivers waiting for a value
+    waiters: LinkedList<Waiter>,
 }
 
 /// Slot in the buffer
@@ -296,23 +299,59 @@ struct Slot<T> {
     val: UnsafeCell<Option<T>>,
 }
 
-/// Tracks a waiting receiver
-#[derive(Debug)]
-struct WaitNode {
-    /// `true` if queued
-    queued: AtomicBool,
+/// An entry in the wait queue
+struct Waiter {
+    /// True if queued
+    queued: bool,
 
-    /// Task to wake when a permit is made available.
-    waker: AtomicWaker,
+    /// Task waiting on the broadcast channel.
+    waker: Option<Waker>,
 
-    /// Next pointer in the stack of waiting senders.
-    next: UnsafeCell<*const WaitNode>,
+    /// Intrusive linked-list pointers.
+    pointers: linked_list::Pointers<Waiter>,
+
+    /// Should not be `Unpin`.
+    _p: PhantomPinned,
 }
 
 struct RecvGuard<'a, T> {
     slot: RwLockReadGuard<'a, Slot<T>>,
 }
 
+/// Receive a value future
+struct Recv<R, T>
+where
+    R: AsMut<Receiver<T>>,
+{
+    /// Receiver being waited on
+    receiver: R,
+
+    /// Entry in the waiter `LinkedList`
+    waiter: UnsafeCell<Waiter>,
+
+    _p: std::marker::PhantomData<T>,
+}
+
+/// `AsMut<T>` is not implemented for `T` (coherence). Explicitly implementing
+/// `AsMut` for `Receiver` would be included in the public API of the receiver
+/// type. Instead, `Borrow` is used internally to bridge the gap.
+struct Borrow<T>(T);
+
+impl<T> AsMut<Receiver<T>> for Borrow<Receiver<T>> {
+    fn as_mut(&mut self) -> &mut Receiver<T> {
+        &mut self.0
+    }
+}
+
+impl<'a, T> AsMut<Receiver<T>> for Borrow<&'a mut Receiver<T>> {
+    fn as_mut(&mut self) -> &mut Receiver<T> {
+        &mut *self.0
+    }
+}
+
+unsafe impl<R: AsMut<Receiver<T>> + Send, T: Send> Send for Recv<R, T> {}
+unsafe impl<R: AsMut<Receiver<T>> + Sync, T: Send> Sync for Recv<R, T> {}
+
 /// Max number of receivers. Reserve space to lock.
 const MAX_RECEIVERS: usize = usize::MAX >> 2;
 
@@ -386,19 +425,15 @@ pub fn channel<T>(mut capacity: usize) -> (Sender<T>, Receiver<T>) {
             pos: 0,
             rx_cnt: 1,
             closed: false,
+            waiters: LinkedList::new(),
         }),
-        wait_stack: AtomicPtr::new(ptr::null_mut()),
         num_tx: AtomicUsize::new(1),
     });
 
     let rx = Receiver {
        shared: shared.clone(),
        next: 0,
-        wait: Arc::new(WaitNode {
-            queued: AtomicBool::new(false),
-            waker: AtomicWaker::new(),
-            next: UnsafeCell::new(ptr::null()),
-        }),
+        waiter: None,
    };
 
    let tx = Sender { shared };
@@ -508,11 +543,7 @@ impl<T> Sender<T> {
         Receiver {
             shared,
             next,
-            wait: Arc::new(WaitNode {
-                queued: AtomicBool::new(false),
-                waker: AtomicWaker::new(),
-                next: UnsafeCell::new(ptr::null()),
-            }),
+            waiter: None,
         }
     }
 
@@ -589,34 +620,31 @@ impl<T> Sender<T> {
             slot.val.with_mut(|ptr| unsafe { *ptr = value });
         }
 
-        // Release the slot lock before the tail lock
+        // Release the slot lock before notifying the receivers.
         drop(slot);
 
+        tail.notify_rx();
+
         // Release the mutex. This must happen after the slot lock is released,
         // otherwise the writer lock bit could be cleared while another thread
         // is in the critical section.
         drop(tail);
 
-        // Notify waiting receivers
-        self.notify_rx();
-
         Ok(rem)
     }
+}
 
-    fn notify_rx(&self) {
-        let mut curr = self.shared.wait_stack.swap(ptr::null_mut(), SeqCst) as *const WaitNode;
-
-        while !curr.is_null() {
-            let waiter = unsafe { Arc::from_raw(curr) };
-
-            // Update `curr` before toggling `queued` and waking
-            curr = waiter.next.with(|ptr| unsafe { *ptr });
-
-            // Unset queued
-            waiter.queued.store(false, SeqCst);
-
-            // Wake
-            waiter.waker.wake();
+impl Tail {
+    fn notify_rx(&mut self) {
+        while let Some(mut waiter) = self.waiters.pop_back() {
+            // Safety: `waiters` lock is still held.
+            let waiter = unsafe { waiter.as_mut() };
+
+            assert!(waiter.queued);
+            waiter.queued = false;
+
+            let waker = waiter.waker.take().unwrap();
+            waker.wake();
         }
     }
 }
@@ -640,15 +668,21 @@ impl<T> Drop for Sender<T> {
 
 impl<T> Receiver<T> {
     /// Locks the next value if there is one.
-    fn recv_ref(&mut self) -> Result<RecvGuard<'_, T>, TryRecvError> {
+    fn recv_ref(
+        &mut self,
+        waiter: Option<(&UnsafeCell<Waiter>, &Waker)>,
+    ) -> Result<RecvGuard<'_, T>, TryRecvError> {
         let idx = (self.next & self.shared.mask as u64) as usize;
 
         // The slot holding the next value to read
         let mut slot = self.shared.buffer[idx].read().unwrap();
 
         if slot.pos != self.next {
-            // The receiver has read all current values in the channel
-            if slot.pos.wrapping_add(self.shared.buffer.len() as u64) == self.next {
+            let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64);
+
+            // The receiver has read all current values in the channel and there
+            // is no waiter to register
+            if waiter.is_none() && next_pos == self.next {
                 return Err(TryRecvError::Empty);
             }
 
@@ -661,13 +695,60 @@ impl<T> Receiver<T> {
             // the slot lock.
             drop(slot);
 
-            let tail = self.shared.tail.lock().unwrap();
+            let mut tail = self.shared.tail.lock().unwrap();
 
             // Acquire slot lock again
             slot = self.shared.buffer[idx].read().unwrap();
 
-            // `tail.pos` points to the slot that the **next** send writes to. If
-            // the channel is closed, the previous slot is the oldest value.
+            // Make sure the position did not change. This could happen in the
+            // unlikely event that the buffer is wrapped between dropping the
+            // read lock and acquiring the tail lock.
+            if slot.pos != self.next {
+                let next_pos = slot.pos.wrapping_add(self.shared.buffer.len() as u64);
+
+                if next_pos == self.next {
+                    // Store the waker
+                    if let Some((waiter, waker)) = waiter {
+                        // Safety: called while locked.
+                        unsafe {
+                            // Only queue if not already queued
+                            waiter.with_mut(|ptr| {
+                                // If there is no waker **or** if the currently
+                                // stored waker references a **different** task,
+                                // track the tasks' waker to be notified on
+                                // receipt of a new value.
+                                match (*ptr).waker {
+                                    Some(ref w) if w.will_wake(waker) => {}
+                                    _ => {
+                                        (*ptr).waker = Some(waker.clone());
+                                    }
+                                }
+
+                                if !(*ptr).queued {
+                                    (*ptr).queued = true;
+                                    tail.waiters.push_front(NonNull::new_unchecked(&mut *ptr));
+                                }
+                            });
+                        }
+                    }
+
+                    return Err(TryRecvError::Empty);
+                }
+
+                // At this point, the receiver has lagged behind the sender by
+                // more than the channel capacity. The receiver will attempt to
+                // catch up by skipping dropped messages and setting the
+                // internal cursor to the **oldest** message stored by the
+                // channel.
+                //
+                // However, finding the oldest position is a bit more
+                // complicated than `tail-position - buffer-size`. When
+                // the channel is closed, the tail position is incremented to
+                // signal a new `None` message, but `None` is not stored in the
+                // channel itself (see issue #2425 for why).
+                //
+                // To account for this, if the channel is closed, the tail
+                // position is decremented by `buffer-size + 1`.
                 let mut adjust = 0;
                 if tail.closed {
                     adjust = 1
@@ -691,6 +772,7 @@ impl<T> Receiver<T> {
 
                 return Err(TryRecvError::Lagged(missed));
             }
+        }
 
         self.next = self.next.wrapping_add(1);
 
@@ -746,22 +828,59 @@ where
     /// }
     /// ```
     pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
-        let guard = self.recv_ref()?;
+        let guard = self.recv_ref(None)?;
         guard.clone_value().ok_or(TryRecvError::Closed)
     }
 
-    #[doc(hidden)] // TODO: document
+    #[doc(hidden)]
+    #[deprecated(since = "0.2.21", note = "use async fn recv()")]
     pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
-        if let Some(value) = ok_empty(self.try_recv())? {
-            return Poll::Ready(Ok(value));
+        use Poll::{Pending, Ready};
+
+        // The borrow checker prohibits calling `self.poll_ref` while passing in
+        // a mutable ref to a field (as it should). To work around this,
+        // `waiter` is first *removed* from `self` then `poll_recv` is called.
+        //
+        // However, for safety, we must ensure that `waiter` is **not** dropped.
+        // It could be contained in the intrusive linked list. The `Receiver`
+        // drop implementation handles cleanup.
+        //
+        // The guard pattern is used to ensure that, on return, even due to
+        // panic, the waiter node is replaced on `self`.
+
+        struct Guard<'a, T> {
+            waiter: Option<Pin<Box<UnsafeCell<Waiter>>>>,
+            receiver: &'a mut Receiver<T>,
         }
 
-        self.register_waker(cx.waker());
+        impl<'a, T> Drop for Guard<'a, T> {
+            fn drop(&mut self) {
+                self.receiver.waiter = self.waiter.take();
+            }
+        }
 
-        if let Some(value) = ok_empty(self.try_recv())? {
-            Poll::Ready(Ok(value))
-        } else {
-            Poll::Pending
+        let waiter = self.waiter.take().or_else(|| {
+            Some(Box::pin(UnsafeCell::new(Waiter {
+                queued: false,
+                waker: None,
+                pointers: linked_list::Pointers::new(),
+                _p: PhantomPinned,
+            })))
+        });
+
+        let guard = Guard {
+            waiter,
+            receiver: self,
+        };
+        let res = guard
+            .receiver
+            .recv_ref(Some((&guard.waiter.as_ref().unwrap(), cx.waker())));
+
+        match res {
+            Ok(guard) => Ready(guard.clone_value().ok_or(RecvError::Closed)),
+            Err(TryRecvError::Closed) => Ready(Err(RecvError::Closed)),
+            Err(TryRecvError::Lagged(n)) => Ready(Err(RecvError::Lagged(n))),
+            Err(TryRecvError::Empty) => Pending,
         }
     }
 
@@ -830,44 +949,14 @@ where
     /// assert_eq!(30, rx.recv().await.unwrap());
     /// }
     pub async fn recv(&mut self) -> Result<T, RecvError> {
-        use crate::future::poll_fn;
-
-        poll_fn(|cx| self.poll_recv(cx)).await
-    }
-
-    fn register_waker(&self, cx: &Waker) {
-        self.wait.waker.register_by_ref(cx);
-
-        if !self.wait.queued.load(SeqCst) {
-            // Set `queued` before queuing.
-            self.wait.queued.store(true, SeqCst);
-
-            let mut curr = self.shared.wait_stack.load(SeqCst);
-
-            // The ref count is decremented in `notify_rx` when all nodes are
-            // removed from the waiter stack.
-            let node = Arc::into_raw(self.wait.clone()) as *mut _;
-
-            loop {
-                // Safety: `queued == false` means the caller has exclusive
-                // access to `self.wait.next`.
-                self.wait.next.with_mut(|ptr| unsafe { *ptr = curr });
-
-                let res = self
-                    .shared
-                    .wait_stack
-                    .compare_exchange(curr, node, SeqCst, SeqCst);
-
-                match res {
-                    Ok(_) => return,
-                    Err(actual) => curr = actual,
-                }
-            }
-        }
+        let fut = Recv::<_, T>::new(Borrow(self));
+        fut.await
     }
 }
 
 #[cfg(feature = "stream")]
+#[doc(hidden)]
+#[deprecated(since = "0.2.21", note = "use `into_stream()`")]
 impl<T> crate::stream::Stream for Receiver<T>
 where
     T: Clone,
@@ -878,6 +967,7 @@ where
         mut self: std::pin::Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll<Option<Result<T, RecvError>>> {
+        #[allow(deprecated)]
         self.poll_recv(cx).map(|v| match v {
             Ok(v) => Some(Ok(v)),
             lag @ Err(RecvError::Lagged(_)) => Some(lag),
@@ -890,13 +980,30 @@ impl<T> Drop for Receiver<T> {
     fn drop(&mut self) {
         let mut tail = self.shared.tail.lock().unwrap();
 
+        if let Some(waiter) = &self.waiter {
+            // safety: tail lock is held
+            let queued = waiter.with(|ptr| unsafe { (*ptr).queued });
+
+            if queued {
+                // Remove the node
+                //
+                // safety: tail lock is held and the wait node is verified to be in
+                // the list.
+                unsafe {
+                    waiter.with_mut(|ptr| {
+                        tail.waiters.remove((&mut *ptr).into());
+                    });
+                }
+            }
+        }
+
         tail.rx_cnt -= 1;
         let until = tail.pos;
 
         drop(tail);
 
         while self.next != until {
-            match self.recv_ref() {
+            match self.recv_ref(None) {
                 Ok(_) => {}
                 // The channel is closed
                 Err(TryRecvError::Closed) => break,
@@ -909,16 +1016,168 @@ impl<T> Drop for Receiver<T> {
     }
 }
 
-impl<T> Drop for Shared<T> {
-    fn drop(&mut self) {
-        // Clear the wait stack
-        let mut curr = self.wait_stack.with_mut(|ptr| *ptr as *const WaitNode);
-
-        while !curr.is_null() {
-            let waiter = unsafe { Arc::from_raw(curr) };
-            curr = waiter.next.with(|ptr| unsafe { *ptr });
+impl<R, T> Recv<R, T>
+where
+    R: AsMut<Receiver<T>>,
+{
+    fn new(receiver: R) -> Recv<R, T> {
+        Recv {
+            receiver,
+            waiter: UnsafeCell::new(Waiter {
+                queued: false,
+                waker: None,
+                pointers: linked_list::Pointers::new(),
+                _p: PhantomPinned,
+            }),
+            _p: std::marker::PhantomData,
         }
     }
+
+    /// A custom `project` implementation is used in place of `pin-project-lite`
+    /// as a custom drop implementation is needed.
+    fn project(self: Pin<&mut Self>) -> (&mut Receiver<T>, &UnsafeCell<Waiter>) {
+        unsafe {
+            // Safety: Receiver is Unpin
+            is_unpin::<&mut Receiver<T>>();
+
+            let me = self.get_unchecked_mut();
+            (me.receiver.as_mut(), &me.waiter)
+        }
+    }
+}
+
+impl<R, T> Future for Recv<R, T>
+where
+    R: AsMut<Receiver<T>>,
+    T: Clone,
+{
+    type Output = Result<T, RecvError>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
+        let (receiver, waiter) = self.project();
+
+        let guard = match receiver.recv_ref(Some((waiter, cx.waker()))) {
+            Ok(value) => value,
+            Err(TryRecvError::Empty) => return Poll::Pending,
+            Err(TryRecvError::Lagged(n)) => return Poll::Ready(Err(RecvError::Lagged(n))),
+            Err(TryRecvError::Closed) => return Poll::Ready(Err(RecvError::Closed)),
+        };
+
+        Poll::Ready(guard.clone_value().ok_or(RecvError::Closed))
+    }
+}
+
+cfg_stream! {
+    use futures_core::Stream;
+
+    impl<T: Clone> Receiver<T> {
+        /// Convert the receiver into a `Stream`.
+        ///
+        /// The conversion allows using `Receiver` with APIs that require stream
+        /// values.
+        ///
+        /// # Examples
+        ///
+        /// ```
+        /// use tokio::stream::StreamExt;
+        /// use tokio::sync::broadcast;
+        ///
+        /// #[tokio::main]
+        /// async fn main() {
+        ///     let (tx, rx) = broadcast::channel(128);
+        ///
+        ///     tokio::spawn(async move {
+        ///         for i in 0..10_i32 {
+        ///             tx.send(i).unwrap();
+        ///         }
+        ///     });
+        ///
+        ///     // Streams must be pinned to iterate.
+        ///     tokio::pin! {
+        ///         let stream = rx
+        ///             .into_stream()
+        ///             .filter(Result::is_ok)
+        ///             .map(Result::unwrap)
+        ///             .filter(|v| v % 2 == 0)
+        ///             .map(|v| v + 1);
+        ///     }
+        ///
+        ///     while let Some(i) = stream.next().await {
+        ///         println!("{}", i);
+        ///     }
+        /// }
+        /// ```
+        pub fn into_stream(self) -> impl Stream<Item = Result<T, RecvError>> {
+            Recv::new(Borrow(self))
+        }
+    }
+
+    impl<R, T: Clone> Stream for Recv<R, T>
+    where
+        R: AsMut<Receiver<T>>,
+        T: Clone,
+    {
+        type Item = Result<T, RecvError>;
+
+        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+            let (receiver, waiter) = self.project();
+
+            let guard = match receiver.recv_ref(Some((waiter, cx.waker()))) {
+                Ok(value) => value,
+                Err(TryRecvError::Empty) => return Poll::Pending,
+                Err(TryRecvError::Lagged(n)) => return Poll::Ready(Some(Err(RecvError::Lagged(n)))),
+                Err(TryRecvError::Closed) => return Poll::Ready(None),
+            };
+
+            Poll::Ready(guard.clone_value().map(Ok))
+        }
+    }
+}
+
+impl<R, T> Drop for Recv<R, T>
+where
+    R: AsMut<Receiver<T>>,
+{
+    fn drop(&mut self) {
+        // Acquire the tail lock. This is required for safety before accessing
+        // the waiter node.
+        let mut tail = self.receiver.as_mut().shared.tail.lock().unwrap();
+
+        // safety: tail lock is held
+        let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued });
+
+        if queued {
+            // Remove the node
+            //
+            // safety: tail lock is held and the wait node is verified to be in
+            // the list.
+            unsafe {
+                self.waiter.with_mut(|ptr| {
+                    tail.waiters.remove((&mut *ptr).into());
+                });
+            }
+        }
+    }
+}
+
+/// # Safety
+///
+/// `Waiter` is forced to be !Unpin.
+unsafe impl linked_list::Link for Waiter {
+    type Handle = NonNull<Waiter>;
+    type Target = Waiter;
+
+    fn as_raw(handle: &NonNull<Waiter>) -> NonNull<Waiter> {
+        *handle
+    }
+
+    unsafe fn from_raw(ptr: NonNull<Waiter>) -> NonNull<Waiter> {
+        ptr
+    }
+
+    unsafe fn pointers(mut target: NonNull<Waiter>) -> NonNull<linked_list::Pointers<Waiter>> {
+        NonNull::from(&mut target.as_mut().pointers)
+    }
 }
 
 impl<T> fmt::Debug for Sender<T> {
@@ -952,15 +1211,6 @@ impl<'a, T> Drop for RecvGuard<'a, T> {
     }
 }
 
-fn ok_empty<T>(res: Result<T, TryRecvError>) -> Result<Option<T>, RecvError> {
-    match res {
-        Ok(value) => Ok(Some(value)),
-        Err(TryRecvError::Empty) => Ok(None),
-        Err(TryRecvError::Lagged(n)) => Err(RecvError::Lagged(n)),
-        Err(TryRecvError::Closed) => Err(RecvError::Closed),
-    }
-}
-
 impl fmt::Display for RecvError {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match self {
@@ -983,3 +1233,5 @@ impl fmt::Display for TryRecvError {
 }
 
 impl std::error::Error for TryRecvError {}
+
+fn is_unpin<T: Unpin>() {}
@@ -90,10 +90,13 @@ fn send_two_recv() {
 }
 
 #[tokio::test]
-async fn send_recv_stream() {
+async fn send_recv_into_stream_ready() {
     use tokio::stream::StreamExt;
 
-    let (tx, mut rx) = broadcast::channel::<i32>(8);
+    let (tx, rx) = broadcast::channel::<i32>(8);
+    tokio::pin! {
+        let rx = rx.into_stream();
+    }
 
     assert_ok!(tx.send(1));
     assert_ok!(tx.send(2));
@@ -106,6 +109,26 @@ async fn send_recv_stream() {
     assert_eq!(None, rx.next().await);
 }
 
+#[tokio::test]
+async fn send_recv_into_stream_pending() {
+    use tokio::stream::StreamExt;
+
+    let (tx, rx) = broadcast::channel::<i32>(8);
+
+    tokio::pin! {
+        let rx = rx.into_stream();
+    }
+
+    let mut recv = task::spawn(rx.next());
+    assert_pending!(recv.poll());
+
+    assert_ok!(tx.send(1));
+
+    assert!(recv.is_woken());
+    let val = assert_ready!(recv.poll());
+    assert_eq!(val, Some(Ok(1)));
+}
+
 #[test]
 fn send_recv_bounded() {
     let (tx, mut rx) = broadcast::channel(16);
@@ -160,6 +183,23 @@ fn send_two_recv_bounded() {
     assert_eq!(val2, "world");
 }
 
+#[test]
+fn change_tasks() {
+    let (tx, mut rx) = broadcast::channel(1);
+
+    let mut recv = Box::pin(rx.recv());
+
+    let mut task1 = task::spawn(&mut recv);
+    assert_pending!(task1.poll());
+
+    let mut task2 = task::spawn(&mut recv);
+    assert_pending!(task2.poll());
+
+    tx.send("hello").unwrap();
+
+    assert!(task2.is_woken());
+}
+
 #[test]
 fn send_slow_rx() {
     let (tx, mut rx1) = broadcast::channel(16);
@@ -451,6 +491,39 @@ fn lagging_receiver_recovers_after_wrap_open() {
     assert_empty!(rx);
 }
 
+#[tokio::test]
+async fn send_recv_stream_ready_deprecated() {
+    use tokio::stream::StreamExt;
+
+    let (tx, mut rx) = broadcast::channel::<i32>(8);
+
+    assert_ok!(tx.send(1));
+    assert_ok!(tx.send(2));
+
+    assert_eq!(Some(Ok(1)), rx.next().await);
+    assert_eq!(Some(Ok(2)), rx.next().await);
+
+    drop(tx);
+
+    assert_eq!(None, rx.next().await);
+}
+
+#[tokio::test]
+async fn send_recv_stream_pending_deprecated() {
+    use tokio::stream::StreamExt;
+
+    let (tx, mut rx) = broadcast::channel::<i32>(8);
+
+    let mut recv = task::spawn(rx.next());
+    assert_pending!(recv.poll());
+
+    assert_ok!(tx.send(1));
+
+    assert!(recv.is_woken());
+    let val = assert_ready!(recv.poll());
+    assert_eq!(val, Some(Ok(1)));
+}
+
 fn is_closed(err: broadcast::RecvError) -> bool {
     match err {
         broadcast::RecvError::Closed => true,