sync: add Sender::{try_,}reserve_many (#6205)

Théodore Prévot 2024-01-02 17:34:56 +01:00 committed by GitHub
parent 2d2faf6014
commit 7c606ab44a
6 changed files with 425 additions and 28 deletions
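
For orientation, a hedged sketch of the API this commit adds; the channel size and payload values below are illustrative, not taken from the diff:

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<u32>(4);

    // Wait until 3 slots are available and reserve them all at once.
    let permits = tx.reserve_many(3).await.unwrap();

    // Each yielded `Permit` sends exactly one value without awaiting.
    for (i, permit) in permits.enumerate() {
        permit.send(i as u32);
    }

    for expected in 0..3 {
        assert_eq!(rx.recv().await, Some(expected));
    }
}
```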

View File

@ -71,7 +71,7 @@ pub struct AcquireError(());
pub(crate) struct Acquire<'a> {
node: Waiter,
semaphore: &'a Semaphore,
num_permits: u32,
num_permits: usize,
queued: bool,
}
@ -262,13 +262,13 @@ impl Semaphore {
self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED
}
pub(crate) fn try_acquire(&self, num_permits: u32) -> Result<(), TryAcquireError> {
pub(crate) fn try_acquire(&self, num_permits: usize) -> Result<(), TryAcquireError> {
assert!(
num_permits as usize <= Self::MAX_PERMITS,
num_permits <= Self::MAX_PERMITS,
"a semaphore may not have more than MAX_PERMITS permits ({})",
Self::MAX_PERMITS
);
let num_permits = (num_permits as usize) << Self::PERMIT_SHIFT;
let num_permits = num_permits << Self::PERMIT_SHIFT;
let mut curr = self.permits.load(Acquire);
loop {
// Has the semaphore closed?
@ -293,7 +293,7 @@ impl Semaphore {
}
}
pub(crate) fn acquire(&self, num_permits: u32) -> Acquire<'_> {
pub(crate) fn acquire(&self, num_permits: usize) -> Acquire<'_> {
Acquire::new(self, num_permits)
}
@ -371,7 +371,7 @@ impl Semaphore {
fn poll_acquire(
&self,
cx: &mut Context<'_>,
num_permits: u32,
num_permits: usize,
node: Pin<&mut Waiter>,
queued: bool,
) -> Poll<Result<(), AcquireError>> {
@ -380,7 +380,7 @@ impl Semaphore {
let needed = if queued {
node.state.load(Acquire) << Self::PERMIT_SHIFT
} else {
(num_permits as usize) << Self::PERMIT_SHIFT
num_permits << Self::PERMIT_SHIFT
};
let mut lock = None;
@ -506,12 +506,12 @@ impl fmt::Debug for Semaphore {
impl Waiter {
fn new(
num_permits: u32,
num_permits: usize,
#[cfg(all(tokio_unstable, feature = "tracing"))] ctx: trace::AsyncOpTracingCtx,
) -> Self {
Waiter {
waker: UnsafeCell::new(None),
state: AtomicUsize::new(num_permits as usize),
state: AtomicUsize::new(num_permits),
pointers: linked_list::Pointers::new(),
#[cfg(all(tokio_unstable, feature = "tracing"))]
ctx,
@ -591,7 +591,7 @@ impl Future for Acquire<'_> {
}
impl<'a> Acquire<'a> {
fn new(semaphore: &'a Semaphore, num_permits: u32) -> Self {
fn new(semaphore: &'a Semaphore, num_permits: usize) -> Self {
#[cfg(any(not(tokio_unstable), not(feature = "tracing")))]
return Self {
node: Waiter::new(num_permits),
@ -635,14 +635,14 @@ impl<'a> Acquire<'a> {
});
}
fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, u32, &mut bool) {
fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, usize, &mut bool) {
fn is_unpin<T: Unpin>() {}
unsafe {
// Safety: all fields other than `node` are `Unpin`
is_unpin::<&Semaphore>();
is_unpin::<&mut bool>();
is_unpin::<u32>();
is_unpin::<usize>();
let this = self.get_unchecked_mut();
(
@ -673,7 +673,7 @@ impl Drop for Acquire<'_> {
// Safety: we have locked the wait list.
unsafe { waiters.queue.remove(node) };
let acquired_permits = self.num_permits as usize - self.node.state.load(Acquire);
let acquired_permits = self.num_permits - self.node.state.load(Acquire);
if acquired_permits > 0 {
self.semaphore.add_permits_locked(acquired_permits, waiters);
}

View File

@ -68,6 +68,18 @@ pub struct Permit<'a, T> {
chan: &'a chan::Tx<T, Semaphore>,
}
/// An [`Iterator`] of [`Permit`] that can be used to hold `n` slots in the channel.
///
/// `PermitIterator` values are returned by [`Sender::reserve_many()`] and [`Sender::try_reserve_many()`]
/// and are used to guarantee channel capacity before generating `n` messages to send.
///
/// [`Sender::reserve_many()`]: Sender::reserve_many
/// [`Sender::try_reserve_many()`]: Sender::try_reserve_many
pub struct PermitIterator<'a, T> {
chan: &'a chan::Tx<T, Semaphore>,
n: usize,
}
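
As a sketch of how this type behaves when only partially consumed (sizes and values are illustrative), dropping the iterator hands the unused permits back, mirroring the `try_reserve_many_full` test later in this commit:

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<&str>(4);

    let mut permits = tx.try_reserve_many(4).unwrap();
    permits.next().unwrap().send("only one");

    // Dropping the iterator returns the three unused permits to the channel.
    drop(permits);
    assert_eq!(tx.capacity(), 3);

    assert_eq!(rx.recv().await, Some("only one"));
    assert_eq!(tx.capacity(), 4);
}
```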
/// Owned permit to send one value into the channel.
///
/// This is identical to the [`Permit`] type, except that it moves the sender
@ -926,10 +938,74 @@ impl<T> Sender<T> {
/// }
/// ```
pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> {
self.reserve_inner().await?;
self.reserve_inner(1).await?;
Ok(Permit { chan: &self.chan })
}
/// Waits for channel capacity. Once capacity to send `n` messages is
/// available, it is reserved for the caller.
///
/// If the channel is full or if fewer than `n` permits are available, the
/// function waits for the number of unreceived messages to become at least
/// `n` less than the channel capacity. Capacity to send `n` messages is then
/// reserved for the caller.
///
/// A [`PermitIterator`] is returned to track the reserved capacity.
/// Consume this [`Iterator`] until it is exhausted to get one [`Permit`] per
/// reserved slot, then call [`Permit::send`] on each. This function is similar
/// to [`try_reserve_many`] except that it waits for the slots to become
/// available.
///
/// If the channel is closed, the function returns a [`SendError`].
///
/// Dropping the [`PermitIterator`] without consuming it entirely releases the
/// remaining permits back to the channel.
///
/// [`PermitIterator`]: PermitIterator
/// [`Permit`]: Permit
/// [`send`]: Permit::send
/// [`try_reserve_many`]: Sender::try_reserve_many
///
/// # Cancel safety
///
/// This channel uses a queue to ensure that calls to `send` and `reserve_many`
/// complete in the order they were requested. Cancelling a call to
/// `reserve_many` makes you lose your place in the queue.
///
/// # Examples
///
/// ```
/// use tokio::sync::mpsc;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx) = mpsc::channel(2);
///
/// // Reserve capacity
/// let mut permit = tx.reserve_many(2).await.unwrap();
///
/// // Trying to send directly on the `tx` will fail due to no
/// // available capacity.
/// assert!(tx.try_send(123).is_err());
///
/// // Sending with the permit iterator succeeds
/// permit.next().unwrap().send(456);
/// permit.next().unwrap().send(457);
///
/// // The iterator should now be exhausted
/// assert!(permit.next().is_none());
///
/// // The values sent on the permits are received
/// assert_eq!(rx.recv().await.unwrap(), 456);
/// assert_eq!(rx.recv().await.unwrap(), 457);
/// }
/// ```
pub async fn reserve_many(&self, n: usize) -> Result<PermitIterator<'_, T>, SendError<()>> {
self.reserve_inner(n).await?;
Ok(PermitIterator {
chan: &self.chan,
n,
})
}
/// Waits for channel capacity, moving the `Sender` and returning an owned
/// permit. Once capacity to send one message is available, it is reserved
/// for the caller.
@ -1011,16 +1087,19 @@ impl<T> Sender<T> {
/// [`send`]: OwnedPermit::send
/// [`Arc::clone`]: std::sync::Arc::clone
pub async fn reserve_owned(self) -> Result<OwnedPermit<T>, SendError<()>> {
self.reserve_inner().await?;
self.reserve_inner(1).await?;
Ok(OwnedPermit {
chan: Some(self.chan),
})
}
async fn reserve_inner(&self) -> Result<(), SendError<()>> {
async fn reserve_inner(&self, n: usize) -> Result<(), SendError<()>> {
crate::trace::async_trace_leaf().await;
match self.chan.semaphore().semaphore.acquire(1).await {
if n > self.max_capacity() {
return Err(SendError(()));
}
match self.chan.semaphore().semaphore.acquire(n).await {
Ok(()) => Ok(()),
Err(_) => Err(SendError(())),
}
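
This early return means a request that can never be satisfied fails immediately instead of waiting forever; a small sketch of the observable behavior (the capacity-1 channel is illustrative, matching the `reserve_many_above_cap` test below):

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, _rx) = mpsc::channel::<()>(1);

    // Asking for more permits than the channel can ever hold errors
    // immediately rather than hanging.
    assert!(tx.reserve_many(2).await.is_err());
    assert!(tx.reserve_many(usize::MAX).await.is_err());
}
```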
@ -1079,6 +1158,91 @@ impl<T> Sender<T> {
Ok(Permit { chan: &self.chan })
}
/// Tries to acquire `n` slots in the channel without waiting for slots to
/// become available.
///
/// A [`PermitIterator`] is returned to track the reserved capacity.
/// Consume this [`Iterator`] until it is exhausted to get one [`Permit`] per
/// reserved slot, then call [`Permit::send`] on each. This function is similar
/// to [`reserve_many`] except that it does not wait for slots to become
/// available.
///
/// If there are fewer than `n` permits available on the channel, then
/// this function will return a [`TrySendError::Full`]. If the channel is closed
/// this function will return a [`TrySendError::Closed`].
///
/// Dropping the [`PermitIterator`] without consuming it entirely releases the
/// remaining permits back to the channel.
///
/// [`PermitIterator`]: PermitIterator
/// [`send`]: Permit::send
/// [`reserve_many`]: Sender::reserve_many
///
/// # Examples
///
/// ```
/// use tokio::sync::mpsc;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx) = mpsc::channel(2);
///
/// // Reserve capacity
/// let mut permit = tx.try_reserve_many(2).unwrap();
///
/// // Trying to send directly on the `tx` will fail due to no
/// // available capacity.
/// assert!(tx.try_send(123).is_err());
///
/// // Trying to reserve an additional slot on the `tx` will
/// // fail because there is no capacity.
/// assert!(tx.try_reserve().is_err());
///
/// // Sending with the permit iterator succeeds
/// permit.next().unwrap().send(456);
/// permit.next().unwrap().send(457);
///
/// // The iterator should now be exhausted
/// assert!(permit.next().is_none());
///
/// // The values sent on the permits are received
/// assert_eq!(rx.recv().await.unwrap(), 456);
/// assert_eq!(rx.recv().await.unwrap(), 457);
///
/// // Trying to call try_reserve_many with 0 will return an empty iterator
/// let mut permit = tx.try_reserve_many(0).unwrap();
/// assert!(permit.next().is_none());
///
/// // Trying to call try_reserve_many with a number greater than the channel
/// // capacity will return an error
/// let permit = tx.try_reserve_many(3);
/// assert!(permit.is_err());
///
/// // Trying to call try_reserve_many on a closed channel will return an error
/// drop(rx);
/// let permit = tx.try_reserve_many(1);
/// assert!(permit.is_err());
///
/// let permit = tx.try_reserve_many(0);
/// assert!(permit.is_err());
/// }
/// ```
pub fn try_reserve_many(&self, n: usize) -> Result<PermitIterator<'_, T>, TrySendError<()>> {
if n > self.max_capacity() {
return Err(TrySendError::Full(()));
}
match self.chan.semaphore().semaphore.try_acquire(n) {
Ok(()) => {}
Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(())),
Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(())),
}
Ok(PermitIterator {
chan: &self.chan,
n,
})
}
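
A brief sketch of how the two error variants surface to callers (mirroring the match above; the channel size is illustrative):

```rust
use tokio::sync::mpsc::{self, error::TrySendError};

fn main() {
    let (tx, rx) = mpsc::channel::<()>(1);

    // All permits are taken: `Full`.
    let _permits = tx.try_reserve_many(1).unwrap();
    assert!(matches!(tx.try_reserve_many(1), Err(TrySendError::Full(()))));

    // Channel closed: `Closed`, even when zero permits are requested.
    drop(rx);
    assert!(matches!(tx.try_reserve_many(0), Err(TrySendError::Closed(()))));
}
```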
/// Tries to acquire a slot in the channel without waiting for the slot to become
/// available, returning an owned permit.
///
@ -1355,6 +1519,58 @@ impl<T> fmt::Debug for Permit<'_, T> {
}
}
// ===== impl PermitIterator =====
impl<'a, T> Iterator for PermitIterator<'a, T> {
type Item = Permit<'a, T>;
fn next(&mut self) -> Option<Self::Item> {
if self.n == 0 {
return None;
}
self.n -= 1;
Some(Permit { chan: self.chan })
}
fn size_hint(&self) -> (usize, Option<usize>) {
let n = self.n;
(n, Some(n))
}
}
impl<T> ExactSizeIterator for PermitIterator<'_, T> {}
impl<T> std::iter::FusedIterator for PermitIterator<'_, T> {}
impl<T> Drop for PermitIterator<'_, T> {
fn drop(&mut self) {
use chan::Semaphore;
if self.n == 0 {
return;
}
let semaphore = self.chan.semaphore();
// Add the remaining permits back to the semaphore
semaphore.add_permits(self.n);
// If this is the last sender for this channel, wake the receiver so
// that it can be notified that the channel is closed.
if semaphore.is_closed() && semaphore.is_idle() {
self.chan.wake_rx();
}
}
}
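
Returning permits via `add_permits` also lets a task parked in `reserve_many` on another sender make progress; a hedged sketch (the second sender and sizes are illustrative, following the `drop_permit_iterator_releases_permits` test below):

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx1, _rx) = mpsc::channel::<i32>(2);
    let tx2 = tx1.clone();

    // Hold every permit so the spawned reservation must wait.
    let permits = tx1.reserve_many(2).await.unwrap();

    let pending = tokio::spawn(async move {
        // Parks until permits are released, then succeeds.
        tx2.reserve_many(2).await.map(drop)
    });

    drop(permits); // returns both permits and wakes the waiting task
    assert!(pending.await.unwrap().is_ok());
}
```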
impl<T> fmt::Debug for PermitIterator<'_, T> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("PermitIterator")
.field("chan", &self.chan)
.field("capacity", &self.n)
.finish()
}
}
// ===== impl Permit =====
impl<T> OwnedPermit<T> {

View File

@ -95,7 +95,9 @@
pub(super) mod block;
mod bounded;
pub use self::bounded::{channel, OwnedPermit, Permit, Receiver, Sender, WeakSender};
pub use self::bounded::{
channel, OwnedPermit, Permit, PermitIterator, Receiver, Sender, WeakSender,
};
mod chan;

View File

@ -772,7 +772,7 @@ impl<T: ?Sized> RwLock<T> {
/// ```
pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
let acquire_fut = async {
self.s.acquire(self.mr).await.unwrap_or_else(|_| {
self.s.acquire(self.mr as usize).await.unwrap_or_else(|_| {
// The semaphore was closed, but we never explicitly close it, and we have a
// handle to it through the Arc, which means that this can never happen.
unreachable!()
@ -907,7 +907,7 @@ impl<T: ?Sized> RwLock<T> {
let resource_span = self.resource_span.clone();
let acquire_fut = async {
self.s.acquire(self.mr).await.unwrap_or_else(|_| {
self.s.acquire(self.mr as usize).await.unwrap_or_else(|_| {
// The semaphore was closed, but we never explicitly close it, and we have a
// handle to it through the Arc, which means that this can never happen.
unreachable!()
@ -971,7 +971,7 @@ impl<T: ?Sized> RwLock<T> {
/// }
/// ```
pub fn try_write(&self) -> Result<RwLockWriteGuard<'_, T>, TryLockError> {
match self.s.try_acquire(self.mr) {
match self.s.try_acquire(self.mr as usize) {
Ok(permit) => permit,
Err(TryAcquireError::NoPermits) => return Err(TryLockError(())),
Err(TryAcquireError::Closed) => unreachable!(),
@ -1029,7 +1029,7 @@ impl<T: ?Sized> RwLock<T> {
/// }
/// ```
pub fn try_write_owned(self: Arc<Self>) -> Result<OwnedRwLockWriteGuard<T>, TryLockError> {
match self.s.try_acquire(self.mr) {
match self.s.try_acquire(self.mr as usize) {
Ok(permit) => permit,
Err(TryAcquireError::NoPermits) => return Err(TryLockError(())),
Err(TryAcquireError::Closed) => unreachable!(),

View File

@ -565,7 +565,7 @@ impl Semaphore {
pub async fn acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, AcquireError> {
#[cfg(all(tokio_unstable, feature = "tracing"))]
trace::async_op(
|| self.ll_sem.acquire(n),
|| self.ll_sem.acquire(n as usize),
self.resource_span.clone(),
"Semaphore::acquire_many",
"poll",
@ -574,7 +574,7 @@ impl Semaphore {
.await?;
#[cfg(not(all(tokio_unstable, feature = "tracing")))]
self.ll_sem.acquire(n).await?;
self.ll_sem.acquire(n as usize).await?;
Ok(SemaphorePermit {
sem: self,
@ -646,7 +646,7 @@ impl Semaphore {
/// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits
/// [`SemaphorePermit`]: crate::sync::SemaphorePermit
pub fn try_acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, TryAcquireError> {
match self.ll_sem.try_acquire(n) {
match self.ll_sem.try_acquire(n as usize) {
Ok(()) => Ok(SemaphorePermit {
sem: self,
permits: n,
@ -764,14 +764,14 @@ impl Semaphore {
) -> Result<OwnedSemaphorePermit, AcquireError> {
#[cfg(all(tokio_unstable, feature = "tracing"))]
let inner = trace::async_op(
|| self.ll_sem.acquire(n),
|| self.ll_sem.acquire(n as usize),
self.resource_span.clone(),
"Semaphore::acquire_many_owned",
"poll",
true,
);
#[cfg(not(all(tokio_unstable, feature = "tracing")))]
let inner = self.ll_sem.acquire(n);
let inner = self.ll_sem.acquire(n as usize);
inner.await?;
Ok(OwnedSemaphorePermit {
@ -855,7 +855,7 @@ impl Semaphore {
self: Arc<Self>,
n: u32,
) -> Result<OwnedSemaphorePermit, TryAcquireError> {
match self.ll_sem.try_acquire(n) {
match self.ll_sem.try_acquire(n as usize) {
Ok(()) => Ok(OwnedSemaphorePermit {
sem: self,
permits: n,

View File

@ -522,6 +522,79 @@ async fn try_send_fail_with_try_recv() {
assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected));
}
#[maybe_tokio_test]
async fn reserve_many_above_cap() {
const MAX_PERMITS: usize = tokio::sync::Semaphore::MAX_PERMITS;
let (tx, _rx) = mpsc::channel::<()>(1);
assert_err!(tx.reserve_many(2).await);
assert_err!(tx.reserve_many(MAX_PERMITS + 1).await);
assert_err!(tx.reserve_many(usize::MAX).await);
}
#[test]
fn try_reserve_many_zero() {
let (tx, rx) = mpsc::channel::<()>(1);
// Succeeds when not closed.
assert!(assert_ok!(tx.try_reserve_many(0)).next().is_none());
// Even when channel is full.
tx.try_send(()).unwrap();
assert!(assert_ok!(tx.try_reserve_many(0)).next().is_none());
drop(rx);
// Closed error when closed.
assert_eq!(
assert_err!(tx.try_reserve_many(0)),
TrySendError::Closed(())
);
}
#[maybe_tokio_test]
async fn reserve_many_zero() {
let (tx, rx) = mpsc::channel::<()>(1);
// Succeeds when not closed.
assert!(assert_ok!(tx.reserve_many(0).await).next().is_none());
// Even when channel is full.
tx.send(()).await.unwrap();
assert!(assert_ok!(tx.reserve_many(0).await).next().is_none());
drop(rx);
// Closed error when closed.
assert_err!(tx.reserve_many(0).await);
}
#[maybe_tokio_test]
async fn try_reserve_many_edge_cases() {
const MAX_PERMITS: usize = tokio::sync::Semaphore::MAX_PERMITS;
let (tx, rx) = mpsc::channel::<()>(1);
let mut permit = assert_ok!(tx.try_reserve_many(0));
assert!(permit.next().is_none());
let permit = tx.try_reserve_many(MAX_PERMITS + 1);
match assert_err!(permit) {
TrySendError::Full(..) => {}
_ => panic!(),
}
let permit = tx.try_reserve_many(usize::MAX);
match assert_err!(permit) {
TrySendError::Full(..) => {}
_ => panic!(),
}
// Dropping the receiver should close the channel
drop(rx);
assert_err!(tx.reserve_many(0).await);
}
#[maybe_tokio_test]
async fn try_reserve_fails() {
let (tx, mut rx) = mpsc::channel(1);
@ -545,6 +618,87 @@ async fn try_reserve_fails() {
let _permit = tx.try_reserve().unwrap();
}
#[maybe_tokio_test]
async fn reserve_many_and_send() {
let (tx, mut rx) = mpsc::channel(100);
for i in 0..100 {
for permit in assert_ok!(tx.reserve_many(i).await) {
permit.send("foo");
assert_eq!(rx.recv().await, Some("foo"));
}
assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
}
}
#[maybe_tokio_test]
async fn try_reserve_many_and_send() {
let (tx, mut rx) = mpsc::channel(100);
for i in 0..100 {
for permit in assert_ok!(tx.try_reserve_many(i)) {
permit.send("foo");
assert_eq!(rx.recv().await, Some("foo"));
}
assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
}
}
#[maybe_tokio_test]
async fn reserve_many_on_closed_channel() {
let (tx, rx) = mpsc::channel::<()>(100);
drop(rx);
assert_err!(tx.reserve_many(10).await);
}
#[maybe_tokio_test]
async fn try_reserve_many_on_closed_channel() {
let (tx, rx) = mpsc::channel::<usize>(100);
drop(rx);
match assert_err!(tx.try_reserve_many(10)) {
TrySendError::Closed(()) => {}
_ => panic!(),
};
}
#[maybe_tokio_test]
async fn try_reserve_many_full() {
// Reserve capacity for n messages and send k of them
for n in 1..100 {
for k in 0..n {
let (tx, mut rx) = mpsc::channel::<usize>(n);
let permits = assert_ok!(tx.try_reserve_many(n));
assert_eq!(permits.len(), n);
assert_eq!(tx.capacity(), 0);
match assert_err!(tx.try_reserve_many(1)) {
TrySendError::Full(..) => {}
_ => panic!(),
};
for permit in permits.take(k) {
permit.send(0);
}
// We only used k of the n reserved permits
assert_eq!(tx.capacity(), n - k);
// We can reserve more permits
assert_ok!(tx.try_reserve_many(1));
// But not more than the current capacity
match assert_err!(tx.try_reserve_many(n - k + 1)) {
TrySendError::Full(..) => {}
_ => panic!(),
};
for _i in 0..k {
assert_eq!(rx.recv().await, Some(0));
}
// Now that we've received everything, capacity should be back to n
assert_eq!(tx.capacity(), n);
}
}
}
#[tokio::test]
#[cfg(feature = "full")]
async fn drop_permit_releases_permit() {
@ -564,6 +718,30 @@ async fn drop_permit_releases_permit() {
assert_ready_ok!(reserve2.poll());
}
#[maybe_tokio_test]
async fn drop_permit_iterator_releases_permits() {
// reserve_many reserves capacity; ensure that the capacity is released if the
// permits are dropped without sending values.
for n in 1..100 {
let (tx1, _rx) = mpsc::channel::<i32>(n);
let tx2 = tx1.clone();
let permits = assert_ok!(tx1.reserve_many(n).await);
let mut reserve2 = tokio_test::task::spawn(tx2.reserve_many(n));
assert_pending!(reserve2.poll());
drop(permits);
assert!(reserve2.is_woken());
let permits = assert_ready_ok!(reserve2.poll());
drop(permits);
assert_eq!(tx1.capacity(), n);
}
}
#[maybe_tokio_test]
async fn dropping_rx_closes_channel() {
let (tx, rx) = mpsc::channel(100);
@ -573,6 +751,7 @@ async fn dropping_rx_closes_channel() {
drop(rx);
assert_err!(tx.reserve().await);
assert_err!(tx.reserve_many(10).await);
assert_eq!(1, Arc::strong_count(&msg));
}