sync: implement try_recv for mpsc channels (#4113)

This commit is contained in:
Alice Ryhl 2021-09-18 09:27:16 +02:00 committed by GitHub
parent 8e92f05795
commit ddd33f2b05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 349 additions and 8 deletions

View File

@ -1,6 +1,6 @@
use crate::sync::batch_semaphore::{self as semaphore, TryAcquireError};
use crate::sync::mpsc::chan;
use crate::sync::mpsc::error::{SendError, TrySendError};
use crate::sync::mpsc::error::{SendError, TryRecvError, TrySendError};
cfg_time! {
use crate::sync::mpsc::error::SendTimeoutError;
@ -187,6 +187,46 @@ impl<T> Receiver<T> {
poll_fn(|cx| self.chan.recv(cx)).await
}
/// Try to receive the next value for this receiver.
///
/// This method returns the [`Empty`] error if the channel is currently
/// empty, but there are still outstanding [senders] or [permits].
///
/// This method returns the [`Disconnected`] error if the channel is
/// currently empty, and there are no outstanding [senders] or [permits].
///
/// [`Empty`]: crate::sync::mpsc::error::TryRecvError::Empty
/// [`Disconnected`]: crate::sync::mpsc::error::TryRecvError::Disconnected
/// [senders]: crate::sync::mpsc::Sender
/// [permits]: crate::sync::mpsc::Permit
///
/// # Examples
///
/// ```
/// use tokio::sync::mpsc;
/// use tokio::sync::mpsc::error::TryRecvError;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx) = mpsc::channel(100);
///
/// tx.send("hello").await.unwrap();
///
/// assert_eq!(Ok("hello"), rx.try_recv());
/// assert_eq!(Err(TryRecvError::Empty), rx.try_recv());
///
/// tx.send("hello").await.unwrap();
/// // Drop the last sender, closing the channel.
/// drop(tx);
///
/// assert_eq!(Ok("hello"), rx.try_recv());
/// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv());
/// }
/// ```
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
self.chan.try_recv()
}
/// Blocking receive to call outside of asynchronous contexts.
///
/// This method returns `None` if the channel has been closed and there are

View File

@ -2,6 +2,9 @@ use crate::loom::cell::UnsafeCell;
use crate::loom::future::AtomicWaker;
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::Arc;
use crate::park::thread::CachedParkThread;
use crate::park::Park;
use crate::sync::mpsc::error::TryRecvError;
use crate::sync::mpsc::list;
use crate::sync::notify::Notify;
@ -263,6 +266,51 @@ impl<T, S: Semaphore> Rx<T, S> {
}
})
}
/// Try to receive the next value.
pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> {
use super::list::TryPopResult;
self.inner.rx_fields.with_mut(|rx_fields_ptr| {
let rx_fields = unsafe { &mut *rx_fields_ptr };
macro_rules! try_recv {
() => {
match rx_fields.list.try_pop(&self.inner.tx) {
TryPopResult::Ok(value) => {
self.inner.semaphore.add_permit();
return Ok(value);
}
TryPopResult::Closed => return Err(TryRecvError::Disconnected),
TryPopResult::Empty => return Err(TryRecvError::Empty),
TryPopResult::Busy => {} // fall through
}
};
}
try_recv!();
// If a previous `poll_recv` call has set a waker, we wake it here.
// This allows us to put our own CachedParkThread waker in the
// AtomicWaker slot instead.
//
// This is not a spurious wakeup to `poll_recv` since we just got a
// Busy from `try_pop`, which only happens if there are messages in
// the queue.
self.inner.rx_waker.wake();
// Park the thread until the problematic send has completed.
let mut park = CachedParkThread::new();
let waker = park.unpark().into_waker();
loop {
self.inner.rx_waker.register_by_ref(&waker);
// It is possible that the problematic send has now completed,
// so we have to check for messages again.
try_recv!();
park.park().expect("park failed");
}
})
}
}
impl<T, S: Semaphore> Drop for Rx<T, S> {

View File

@ -51,6 +51,30 @@ impl<T> From<SendError<T>> for TrySendError<T> {
}
}
// ===== TryRecvError =====
/// Error returned by `try_recv`.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub enum TryRecvError {
/// This **channel** is currently empty, but the **Sender**(s) have not yet
/// disconnected, so data may yet become available.
Empty,
/// The **channel**'s sending half has become disconnected, and there will
/// never be any more data received on it.
Disconnected,
}
impl fmt::Display for TryRecvError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
TryRecvError::Empty => "receiving on an empty channel".fmt(fmt),
TryRecvError::Disconnected => "receiving on a closed channel".fmt(fmt),
}
}
}
impl Error for TryRecvError {}
// ===== RecvError =====
/// Error returned by `Receiver`.

View File

@ -13,23 +13,35 @@ pub(crate) struct Tx<T> {
/// Tail in the `Block` mpmc list.
block_tail: AtomicPtr<Block<T>>,
/// Position to push the next message. This reference a block and offset
/// Position to push the next message. This references a block and offset
/// into the block.
tail_position: AtomicUsize,
}
/// List queue receive handle
pub(crate) struct Rx<T> {
/// Pointer to the block being processed
/// Pointer to the block being processed.
head: NonNull<Block<T>>,
/// Next slot index to process
/// Next slot index to process.
index: usize,
/// Pointer to the next block pending release
/// Pointer to the next block pending release.
free_head: NonNull<Block<T>>,
}
/// Return value of `Rx::try_pop`.
pub(crate) enum TryPopResult<T> {
/// Successfully popped a value.
Ok(T),
/// The channel is empty.
Empty,
/// The channel is empty and closed.
Closed,
/// The channel is not empty, but the first value is being written.
Busy,
}
pub(crate) fn channel<T>() -> (Tx<T>, Rx<T>) {
// Create the initial block shared between the tx and rx halves.
let initial_block = Box::new(Block::new(0));
@ -218,7 +230,7 @@ impl<T> fmt::Debug for Tx<T> {
}
impl<T> Rx<T> {
/// Pops the next value off the queue
/// Pops the next value off the queue.
pub(crate) fn pop(&mut self, tx: &Tx<T>) -> Option<block::Read<T>> {
// Advance `head`, if needed
if !self.try_advancing_head() {
@ -240,6 +252,26 @@ impl<T> Rx<T> {
}
}
/// Pops the next value off the queue, detecting whether the block
/// is busy or empty on failure.
///
/// This function exists because `Rx::pop` can return `None` even if the
/// channel's queue contains a message that has been completely written.
/// This can happen if the fully delivered message is behind another message
/// that is in the middle of being written to the block, since the channel
/// can't return the messages out of order.
pub(crate) fn try_pop(&mut self, tx: &Tx<T>) -> TryPopResult<T> {
let tail_position = tx.tail_position.load(Acquire);
let result = self.pop(tx);
match result {
Some(block::Read::Value(t)) => TryPopResult::Ok(t),
Some(block::Read::Closed) => TryPopResult::Closed,
None if tail_position == self.index => TryPopResult::Empty,
None => TryPopResult::Busy,
}
}
/// Tries advancing the block pointer to the block referenced by `self.index`.
///
/// Returns `true` if successful, `false` if there is no next block to load.

View File

@ -1,6 +1,6 @@
use crate::loom::sync::atomic::AtomicUsize;
use crate::sync::mpsc::chan;
use crate::sync::mpsc::error::SendError;
use crate::sync::mpsc::error::{SendError, TryRecvError};
use std::fmt;
use std::task::{Context, Poll};
@ -129,6 +129,46 @@ impl<T> UnboundedReceiver<T> {
poll_fn(|cx| self.poll_recv(cx)).await
}
/// Try to receive the next value for this receiver.
///
/// This method returns the [`Empty`] error if the channel is currently
/// empty, but there are still outstanding [senders] or [permits].
///
/// This method returns the [`Disconnected`] error if the channel is
/// currently empty, and there are no outstanding [senders] or [permits].
///
/// [`Empty`]: crate::sync::mpsc::error::TryRecvError::Empty
/// [`Disconnected`]: crate::sync::mpsc::error::TryRecvError::Disconnected
/// [senders]: crate::sync::mpsc::Sender
/// [permits]: crate::sync::mpsc::Permit
///
/// # Examples
///
/// ```
/// use tokio::sync::mpsc;
/// use tokio::sync::mpsc::error::TryRecvError;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx) = mpsc::unbounded_channel();
///
/// tx.send("hello").unwrap();
///
/// assert_eq!(Ok("hello"), rx.try_recv());
/// assert_eq!(Err(TryRecvError::Empty), rx.try_recv());
///
/// tx.send("hello").unwrap();
/// // Drop the last sender, closing the channel.
/// drop(tx);
///
/// assert_eq!(Ok("hello"), rx.try_recv());
/// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv());
/// }
/// ```
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
self.chan.try_recv()
}
/// Blocking receive to call outside of asynchronous contexts.
///
/// # Panics

View File

@ -132,3 +132,59 @@ fn dropping_unbounded_tx() {
assert!(v.is_none());
});
}
#[test]
fn try_recv() {
loom::model(|| {
use crate::sync::{mpsc, Semaphore};
use loom::sync::{Arc, Mutex};
const PERMITS: usize = 2;
const TASKS: usize = 2;
const CYCLES: usize = 1;
struct Context {
sem: Arc<Semaphore>,
tx: mpsc::Sender<()>,
rx: Mutex<mpsc::Receiver<()>>,
}
fn run(ctx: &Context) {
block_on(async {
let permit = ctx.sem.acquire().await;
assert_ok!(ctx.rx.lock().unwrap().try_recv());
crate::task::yield_now().await;
assert_ok!(ctx.tx.clone().try_send(()));
drop(permit);
});
}
let (tx, rx) = mpsc::channel(PERMITS);
let sem = Arc::new(Semaphore::new(PERMITS));
let ctx = Arc::new(Context {
sem,
tx,
rx: Mutex::new(rx),
});
for _ in 0..PERMITS {
assert_ok!(ctx.tx.clone().try_send(()));
}
let mut ths = Vec::new();
for _ in 0..TASKS {
let ctx = ctx.clone();
ths.push(thread::spawn(move || {
run(&ctx);
}));
}
run(&ctx);
for th in ths {
th.join().unwrap();
}
});
}

View File

@ -5,7 +5,7 @@
use std::thread;
use tokio::runtime::Runtime;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::mpsc::error::{TryRecvError, TrySendError};
use tokio_test::task;
use tokio_test::{
assert_err, assert_ok, assert_pending, assert_ready, assert_ready_err, assert_ready_ok,
@ -327,6 +327,27 @@ async fn try_send_fail() {
assert!(rx.recv().await.is_none());
}
#[tokio::test]
async fn try_send_fail_with_try_recv() {
let (tx, mut rx) = mpsc::channel(1);
tx.try_send("hello").unwrap();
// This should fail
match assert_err!(tx.try_send("fail")) {
TrySendError::Full(..) => {}
_ => panic!(),
}
assert_eq!(rx.try_recv(), Ok("hello"));
assert_ok!(tx.try_send("goodbye"));
drop(tx);
assert_eq!(rx.try_recv(), Ok("goodbye"));
assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected));
}
#[tokio::test]
async fn try_reserve_fails() {
let (tx, mut rx) = mpsc::channel(1);
@ -494,3 +515,83 @@ async fn permit_available_not_acquired_close() {
drop(permit2);
assert!(rx.recv().await.is_none());
}
#[test]
fn try_recv_bounded() {
let (tx, mut rx) = mpsc::channel(5);
tx.try_send("hello").unwrap();
tx.try_send("hello").unwrap();
tx.try_send("hello").unwrap();
tx.try_send("hello").unwrap();
tx.try_send("hello").unwrap();
assert!(tx.try_send("hello").is_err());
assert_eq!(Ok("hello"), rx.try_recv());
assert_eq!(Ok("hello"), rx.try_recv());
assert_eq!(Ok("hello"), rx.try_recv());
assert_eq!(Ok("hello"), rx.try_recv());
assert_eq!(Ok("hello"), rx.try_recv());
assert_eq!(Err(TryRecvError::Empty), rx.try_recv());
tx.try_send("hello").unwrap();
tx.try_send("hello").unwrap();
tx.try_send("hello").unwrap();
tx.try_send("hello").unwrap();
assert_eq!(Ok("hello"), rx.try_recv());
tx.try_send("hello").unwrap();
tx.try_send("hello").unwrap();
assert!(tx.try_send("hello").is_err());
assert_eq!(Ok("hello"), rx.try_recv());
assert_eq!(Ok("hello"), rx.try_recv());
assert_eq!(Ok("hello"), rx.try_recv());
assert_eq!(Ok("hello"), rx.try_recv());
assert_eq!(Ok("hello"), rx.try_recv());
assert_eq!(Err(TryRecvError::Empty), rx.try_recv());
tx.try_send("hello").unwrap();
tx.try_send("hello").unwrap();
tx.try_send("hello").unwrap();
drop(tx);
assert_eq!(Ok("hello"), rx.try_recv());
assert_eq!(Ok("hello"), rx.try_recv());
assert_eq!(Ok("hello"), rx.try_recv());
assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv());
}
#[test]
fn try_recv_unbounded() {
for num in 0..100 {
let (tx, mut rx) = mpsc::unbounded_channel();
for i in 0..num {
tx.send(i).unwrap();
}
for i in 0..num {
assert_eq!(rx.try_recv(), Ok(i));
}
assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
drop(tx);
assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected));
}
}
#[test]
fn try_recv_close_while_empty_bounded() {
let (tx, mut rx) = mpsc::channel::<()>(5);
assert_eq!(Err(TryRecvError::Empty), rx.try_recv());
drop(tx);
assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv());
}
#[test]
fn try_recv_close_while_empty_unbounded() {
let (tx, mut rx) = mpsc::unbounded_channel::<()>();
assert_eq!(Err(TryRecvError::Empty), rx.try_recv());
drop(tx);
assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv());
}