mirror of
https://github.com/tokio-rs/tokio.git
synced 2025-09-28 12:10:37 +00:00
sync: implement try_recv for mpsc channels (#4113)
This commit is contained in:
parent
8e92f05795
commit
ddd33f2b05
@ -1,6 +1,6 @@
|
||||
use crate::sync::batch_semaphore::{self as semaphore, TryAcquireError};
|
||||
use crate::sync::mpsc::chan;
|
||||
use crate::sync::mpsc::error::{SendError, TrySendError};
|
||||
use crate::sync::mpsc::error::{SendError, TryRecvError, TrySendError};
|
||||
|
||||
cfg_time! {
|
||||
use crate::sync::mpsc::error::SendTimeoutError;
|
||||
@ -187,6 +187,46 @@ impl<T> Receiver<T> {
|
||||
poll_fn(|cx| self.chan.recv(cx)).await
|
||||
}
|
||||
|
||||
/// Try to receive the next value for this receiver.
|
||||
///
|
||||
/// This method returns the [`Empty`] error if the channel is currently
|
||||
/// empty, but there are still outstanding [senders] or [permits].
|
||||
///
|
||||
/// This method returns the [`Disconnected`] error if the channel is
|
||||
/// currently empty, and there are no outstanding [senders] or [permits].
|
||||
///
|
||||
/// [`Empty`]: crate::sync::mpsc::error::TryRecvError::Empty
|
||||
/// [`Disconnected`]: crate::sync::mpsc::error::TryRecvError::Disconnected
|
||||
/// [senders]: crate::sync::mpsc::Sender
|
||||
/// [permits]: crate::sync::mpsc::Permit
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use tokio::sync::mpsc;
|
||||
/// use tokio::sync::mpsc::error::TryRecvError;
|
||||
///
|
||||
/// #[tokio::main]
|
||||
/// async fn main() {
|
||||
/// let (tx, mut rx) = mpsc::channel(100);
|
||||
///
|
||||
/// tx.send("hello").await.unwrap();
|
||||
///
|
||||
/// assert_eq!(Ok("hello"), rx.try_recv());
|
||||
/// assert_eq!(Err(TryRecvError::Empty), rx.try_recv());
|
||||
///
|
||||
/// tx.send("hello").await.unwrap();
|
||||
/// // Drop the last sender, closing the channel.
|
||||
/// drop(tx);
|
||||
///
|
||||
/// assert_eq!(Ok("hello"), rx.try_recv());
|
||||
/// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv());
|
||||
/// }
|
||||
/// ```
|
||||
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
|
||||
self.chan.try_recv()
|
||||
}
|
||||
|
||||
/// Blocking receive to call outside of asynchronous contexts.
|
||||
///
|
||||
/// This method returns `None` if the channel has been closed and there are
|
||||
|
@ -2,6 +2,9 @@ use crate::loom::cell::UnsafeCell;
|
||||
use crate::loom::future::AtomicWaker;
|
||||
use crate::loom::sync::atomic::AtomicUsize;
|
||||
use crate::loom::sync::Arc;
|
||||
use crate::park::thread::CachedParkThread;
|
||||
use crate::park::Park;
|
||||
use crate::sync::mpsc::error::TryRecvError;
|
||||
use crate::sync::mpsc::list;
|
||||
use crate::sync::notify::Notify;
|
||||
|
||||
@ -263,6 +266,51 @@ impl<T, S: Semaphore> Rx<T, S> {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Try to receive the next value.
|
||||
pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> {
|
||||
use super::list::TryPopResult;
|
||||
|
||||
self.inner.rx_fields.with_mut(|rx_fields_ptr| {
|
||||
let rx_fields = unsafe { &mut *rx_fields_ptr };
|
||||
|
||||
macro_rules! try_recv {
|
||||
() => {
|
||||
match rx_fields.list.try_pop(&self.inner.tx) {
|
||||
TryPopResult::Ok(value) => {
|
||||
self.inner.semaphore.add_permit();
|
||||
return Ok(value);
|
||||
}
|
||||
TryPopResult::Closed => return Err(TryRecvError::Disconnected),
|
||||
TryPopResult::Empty => return Err(TryRecvError::Empty),
|
||||
TryPopResult::Busy => {} // fall through
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
try_recv!();
|
||||
|
||||
// If a previous `poll_recv` call has set a waker, we wake it here.
|
||||
// This allows us to put our own CachedParkThread waker in the
|
||||
// AtomicWaker slot instead.
|
||||
//
|
||||
// This is not a spurious wakeup to `poll_recv` since we just got a
|
||||
// Busy from `try_pop`, which only happens if there are messages in
|
||||
// the queue.
|
||||
self.inner.rx_waker.wake();
|
||||
|
||||
// Park the thread until the problematic send has completed.
|
||||
let mut park = CachedParkThread::new();
|
||||
let waker = park.unpark().into_waker();
|
||||
loop {
|
||||
self.inner.rx_waker.register_by_ref(&waker);
|
||||
// It is possible that the problematic send has now completed,
|
||||
// so we have to check for messages again.
|
||||
try_recv!();
|
||||
park.park().expect("park failed");
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S: Semaphore> Drop for Rx<T, S> {
|
||||
|
@ -51,6 +51,30 @@ impl<T> From<SendError<T>> for TrySendError<T> {
|
||||
}
|
||||
}
|
||||
|
||||
// ===== TryRecvError =====
|
||||
|
||||
/// Error returned by `try_recv`.
|
||||
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
|
||||
pub enum TryRecvError {
|
||||
/// This **channel** is currently empty, but the **Sender**(s) have not yet
|
||||
/// disconnected, so data may yet become available.
|
||||
Empty,
|
||||
/// The **channel**'s sending half has become disconnected, and there will
|
||||
/// never be any more data received on it.
|
||||
Disconnected,
|
||||
}
|
||||
|
||||
impl fmt::Display for TryRecvError {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match *self {
|
||||
TryRecvError::Empty => "receiving on an empty channel".fmt(fmt),
|
||||
TryRecvError::Disconnected => "receiving on a closed channel".fmt(fmt),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for TryRecvError {}
|
||||
|
||||
// ===== RecvError =====
|
||||
|
||||
/// Error returned by `Receiver`.
|
||||
|
@ -13,23 +13,35 @@ pub(crate) struct Tx<T> {
|
||||
/// Tail in the `Block` mpmc list.
|
||||
block_tail: AtomicPtr<Block<T>>,
|
||||
|
||||
/// Position to push the next message. This reference a block and offset
|
||||
/// Position to push the next message. This references a block and offset
|
||||
/// into the block.
|
||||
tail_position: AtomicUsize,
|
||||
}
|
||||
|
||||
/// List queue receive handle
|
||||
pub(crate) struct Rx<T> {
|
||||
/// Pointer to the block being processed
|
||||
/// Pointer to the block being processed.
|
||||
head: NonNull<Block<T>>,
|
||||
|
||||
/// Next slot index to process
|
||||
/// Next slot index to process.
|
||||
index: usize,
|
||||
|
||||
/// Pointer to the next block pending release
|
||||
/// Pointer to the next block pending release.
|
||||
free_head: NonNull<Block<T>>,
|
||||
}
|
||||
|
||||
/// Return value of `Rx::try_pop`.
|
||||
pub(crate) enum TryPopResult<T> {
|
||||
/// Successfully popped a value.
|
||||
Ok(T),
|
||||
/// The channel is empty.
|
||||
Empty,
|
||||
/// The channel is empty and closed.
|
||||
Closed,
|
||||
/// The channel is not empty, but the first value is being written.
|
||||
Busy,
|
||||
}
|
||||
|
||||
pub(crate) fn channel<T>() -> (Tx<T>, Rx<T>) {
|
||||
// Create the initial block shared between the tx and rx halves.
|
||||
let initial_block = Box::new(Block::new(0));
|
||||
@ -218,7 +230,7 @@ impl<T> fmt::Debug for Tx<T> {
|
||||
}
|
||||
|
||||
impl<T> Rx<T> {
|
||||
/// Pops the next value off the queue
|
||||
/// Pops the next value off the queue.
|
||||
pub(crate) fn pop(&mut self, tx: &Tx<T>) -> Option<block::Read<T>> {
|
||||
// Advance `head`, if needed
|
||||
if !self.try_advancing_head() {
|
||||
@ -240,6 +252,26 @@ impl<T> Rx<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Pops the next value off the queue, detecting whether the block
|
||||
/// is busy or empty on failure.
|
||||
///
|
||||
/// This function exists because `Rx::pop` can return `None` even if the
|
||||
/// channel's queue contains a message that has been completely written.
|
||||
/// This can happen if the fully delivered message is behind another message
|
||||
/// that is in the middle of being written to the block, since the channel
|
||||
/// can't return the messages out of order.
|
||||
pub(crate) fn try_pop(&mut self, tx: &Tx<T>) -> TryPopResult<T> {
|
||||
let tail_position = tx.tail_position.load(Acquire);
|
||||
let result = self.pop(tx);
|
||||
|
||||
match result {
|
||||
Some(block::Read::Value(t)) => TryPopResult::Ok(t),
|
||||
Some(block::Read::Closed) => TryPopResult::Closed,
|
||||
None if tail_position == self.index => TryPopResult::Empty,
|
||||
None => TryPopResult::Busy,
|
||||
}
|
||||
}
|
||||
|
||||
/// Tries advancing the block pointer to the block referenced by `self.index`.
|
||||
///
|
||||
/// Returns `true` if successful, `false` if there is no next block to load.
|
||||
|
@ -1,6 +1,6 @@
|
||||
use crate::loom::sync::atomic::AtomicUsize;
|
||||
use crate::sync::mpsc::chan;
|
||||
use crate::sync::mpsc::error::SendError;
|
||||
use crate::sync::mpsc::error::{SendError, TryRecvError};
|
||||
|
||||
use std::fmt;
|
||||
use std::task::{Context, Poll};
|
||||
@ -129,6 +129,46 @@ impl<T> UnboundedReceiver<T> {
|
||||
poll_fn(|cx| self.poll_recv(cx)).await
|
||||
}
|
||||
|
||||
/// Try to receive the next value for this receiver.
|
||||
///
|
||||
/// This method returns the [`Empty`] error if the channel is currently
|
||||
/// empty, but there are still outstanding [senders] or [permits].
|
||||
///
|
||||
/// This method returns the [`Disconnected`] error if the channel is
|
||||
/// currently empty, and there are no outstanding [senders] or [permits].
|
||||
///
|
||||
/// [`Empty`]: crate::sync::mpsc::error::TryRecvError::Empty
|
||||
/// [`Disconnected`]: crate::sync::mpsc::error::TryRecvError::Disconnected
|
||||
/// [senders]: crate::sync::mpsc::Sender
|
||||
/// [permits]: crate::sync::mpsc::Permit
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use tokio::sync::mpsc;
|
||||
/// use tokio::sync::mpsc::error::TryRecvError;
|
||||
///
|
||||
/// #[tokio::main]
|
||||
/// async fn main() {
|
||||
/// let (tx, mut rx) = mpsc::unbounded_channel();
|
||||
///
|
||||
/// tx.send("hello").unwrap();
|
||||
///
|
||||
/// assert_eq!(Ok("hello"), rx.try_recv());
|
||||
/// assert_eq!(Err(TryRecvError::Empty), rx.try_recv());
|
||||
///
|
||||
/// tx.send("hello").unwrap();
|
||||
/// // Drop the last sender, closing the channel.
|
||||
/// drop(tx);
|
||||
///
|
||||
/// assert_eq!(Ok("hello"), rx.try_recv());
|
||||
/// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv());
|
||||
/// }
|
||||
/// ```
|
||||
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
|
||||
self.chan.try_recv()
|
||||
}
|
||||
|
||||
/// Blocking receive to call outside of asynchronous contexts.
|
||||
///
|
||||
/// # Panics
|
||||
|
@ -132,3 +132,59 @@ fn dropping_unbounded_tx() {
|
||||
assert!(v.is_none());
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_recv() {
|
||||
loom::model(|| {
|
||||
use crate::sync::{mpsc, Semaphore};
|
||||
use loom::sync::{Arc, Mutex};
|
||||
|
||||
const PERMITS: usize = 2;
|
||||
const TASKS: usize = 2;
|
||||
const CYCLES: usize = 1;
|
||||
|
||||
struct Context {
|
||||
sem: Arc<Semaphore>,
|
||||
tx: mpsc::Sender<()>,
|
||||
rx: Mutex<mpsc::Receiver<()>>,
|
||||
}
|
||||
|
||||
fn run(ctx: &Context) {
|
||||
block_on(async {
|
||||
let permit = ctx.sem.acquire().await;
|
||||
assert_ok!(ctx.rx.lock().unwrap().try_recv());
|
||||
crate::task::yield_now().await;
|
||||
assert_ok!(ctx.tx.clone().try_send(()));
|
||||
drop(permit);
|
||||
});
|
||||
}
|
||||
|
||||
let (tx, rx) = mpsc::channel(PERMITS);
|
||||
let sem = Arc::new(Semaphore::new(PERMITS));
|
||||
let ctx = Arc::new(Context {
|
||||
sem,
|
||||
tx,
|
||||
rx: Mutex::new(rx),
|
||||
});
|
||||
|
||||
for _ in 0..PERMITS {
|
||||
assert_ok!(ctx.tx.clone().try_send(()));
|
||||
}
|
||||
|
||||
let mut ths = Vec::new();
|
||||
|
||||
for _ in 0..TASKS {
|
||||
let ctx = ctx.clone();
|
||||
|
||||
ths.push(thread::spawn(move || {
|
||||
run(&ctx);
|
||||
}));
|
||||
}
|
||||
|
||||
run(&ctx);
|
||||
|
||||
for th in ths {
|
||||
th.join().unwrap();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -5,7 +5,7 @@
|
||||
use std::thread;
|
||||
use tokio::runtime::Runtime;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::error::TrySendError;
|
||||
use tokio::sync::mpsc::error::{TryRecvError, TrySendError};
|
||||
use tokio_test::task;
|
||||
use tokio_test::{
|
||||
assert_err, assert_ok, assert_pending, assert_ready, assert_ready_err, assert_ready_ok,
|
||||
@ -327,6 +327,27 @@ async fn try_send_fail() {
|
||||
assert!(rx.recv().await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn try_send_fail_with_try_recv() {
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
|
||||
tx.try_send("hello").unwrap();
|
||||
|
||||
// This should fail
|
||||
match assert_err!(tx.try_send("fail")) {
|
||||
TrySendError::Full(..) => {}
|
||||
_ => panic!(),
|
||||
}
|
||||
|
||||
assert_eq!(rx.try_recv(), Ok("hello"));
|
||||
|
||||
assert_ok!(tx.try_send("goodbye"));
|
||||
drop(tx);
|
||||
|
||||
assert_eq!(rx.try_recv(), Ok("goodbye"));
|
||||
assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn try_reserve_fails() {
|
||||
let (tx, mut rx) = mpsc::channel(1);
|
||||
@ -494,3 +515,83 @@ async fn permit_available_not_acquired_close() {
|
||||
drop(permit2);
|
||||
assert!(rx.recv().await.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_recv_bounded() {
|
||||
let (tx, mut rx) = mpsc::channel(5);
|
||||
|
||||
tx.try_send("hello").unwrap();
|
||||
tx.try_send("hello").unwrap();
|
||||
tx.try_send("hello").unwrap();
|
||||
tx.try_send("hello").unwrap();
|
||||
tx.try_send("hello").unwrap();
|
||||
assert!(tx.try_send("hello").is_err());
|
||||
|
||||
assert_eq!(Ok("hello"), rx.try_recv());
|
||||
assert_eq!(Ok("hello"), rx.try_recv());
|
||||
assert_eq!(Ok("hello"), rx.try_recv());
|
||||
assert_eq!(Ok("hello"), rx.try_recv());
|
||||
assert_eq!(Ok("hello"), rx.try_recv());
|
||||
assert_eq!(Err(TryRecvError::Empty), rx.try_recv());
|
||||
|
||||
tx.try_send("hello").unwrap();
|
||||
tx.try_send("hello").unwrap();
|
||||
tx.try_send("hello").unwrap();
|
||||
tx.try_send("hello").unwrap();
|
||||
assert_eq!(Ok("hello"), rx.try_recv());
|
||||
tx.try_send("hello").unwrap();
|
||||
tx.try_send("hello").unwrap();
|
||||
assert!(tx.try_send("hello").is_err());
|
||||
assert_eq!(Ok("hello"), rx.try_recv());
|
||||
assert_eq!(Ok("hello"), rx.try_recv());
|
||||
assert_eq!(Ok("hello"), rx.try_recv());
|
||||
assert_eq!(Ok("hello"), rx.try_recv());
|
||||
assert_eq!(Ok("hello"), rx.try_recv());
|
||||
assert_eq!(Err(TryRecvError::Empty), rx.try_recv());
|
||||
|
||||
tx.try_send("hello").unwrap();
|
||||
tx.try_send("hello").unwrap();
|
||||
tx.try_send("hello").unwrap();
|
||||
drop(tx);
|
||||
assert_eq!(Ok("hello"), rx.try_recv());
|
||||
assert_eq!(Ok("hello"), rx.try_recv());
|
||||
assert_eq!(Ok("hello"), rx.try_recv());
|
||||
assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_recv_unbounded() {
|
||||
for num in 0..100 {
|
||||
let (tx, mut rx) = mpsc::unbounded_channel();
|
||||
|
||||
for i in 0..num {
|
||||
tx.send(i).unwrap();
|
||||
}
|
||||
|
||||
for i in 0..num {
|
||||
assert_eq!(rx.try_recv(), Ok(i));
|
||||
}
|
||||
|
||||
assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
|
||||
drop(tx);
|
||||
assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_recv_close_while_empty_bounded() {
|
||||
let (tx, mut rx) = mpsc::channel::<()>(5);
|
||||
|
||||
assert_eq!(Err(TryRecvError::Empty), rx.try_recv());
|
||||
drop(tx);
|
||||
assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_recv_close_while_empty_unbounded() {
|
||||
let (tx, mut rx) = mpsc::unbounded_channel::<()>();
|
||||
|
||||
assert_eq!(Err(TryRecvError::Empty), rx.try_recv());
|
||||
drop(tx);
|
||||
assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv());
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user