diff --git a/tokio-util/Cargo.toml b/tokio-util/Cargo.toml index d48ac8083..0104bcc4c 100644 --- a/tokio-util/Cargo.toml +++ b/tokio-util/Cargo.toml @@ -47,6 +47,7 @@ futures-util = { version = "0.3.0", optional = true } log = "0.4" pin-project-lite = "0.2.0" slab = { version = "0.4.1", optional = true } # Backs `DelayQueue` +async-stream = "0.3.0" [dev-dependencies] tokio = { version = "1.0.0", features = ["full"] } @@ -55,8 +56,6 @@ tokio-test = { version = "0.4.0" } futures = "0.3.0" futures-test = "0.3.5" -async-stream = "0.3.0" - [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] diff --git a/tokio-util/src/sync/mod.rs b/tokio-util/src/sync/mod.rs index 159f12dbe..7a0637d32 100644 --- a/tokio-util/src/sync/mod.rs +++ b/tokio-util/src/sync/mod.rs @@ -4,3 +4,6 @@ mod cancellation_token; pub use cancellation_token::{CancellationToken, WaitForCancellationFuture}; mod intrusive_double_linked_list; + +mod poll_semaphore; +pub use poll_semaphore::PollSemaphore; diff --git a/tokio-util/src/sync/poll_semaphore.rs b/tokio-util/src/sync/poll_semaphore.rs new file mode 100644 index 000000000..6519bc663 --- /dev/null +++ b/tokio-util/src/sync/poll_semaphore.rs @@ -0,0 +1,85 @@ +use futures_core::Stream; +use std::fmt; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; + +/// A wrapper around [`Semaphore`] that provides a `poll_acquire` method. +/// +/// [`Semaphore`]: tokio::sync::Semaphore +pub struct PollSemaphore { + semaphore: Arc, + inner: Pin + Send + Sync>>, +} + +impl PollSemaphore { + /// Create a new `PollSemaphore`. + pub fn new(semaphore: Arc) -> Self { + Self { + semaphore: semaphore.clone(), + inner: Box::pin(async_stream::stream! { + loop { + match semaphore.clone().acquire_owned().await { + Ok(permit) => yield permit, + Err(_closed) => break, + } + } + }), + } + } + + /// Closes the semaphore. + pub fn close(&self) { + self.semaphore.close() + } + + /// Obtain a clone of the inner semaphore. + pub fn clone_inner(&self) -> Arc { + self.semaphore.clone() + } + + /// Get back the inner semaphore. + pub fn into_inner(self) -> Arc { + self.semaphore + } + + /// Poll to acquire a permit from the semaphore. + /// + /// This can return the following values: + /// + /// - `Poll::Pending` if a permit is not currently available. + /// - `Poll::Ready(Some(permit))` if a permit was acquired. + /// - `Poll::Ready(None)` if the semaphore has been closed. + /// + /// When this method returns `Poll::Pending`, the current task is scheduled + /// to receive a wakeup when a permit becomes available, or when the + /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. + pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.as_mut().poll_next(cx) + } +} + +impl Stream for PollSemaphore { + type Item = OwnedSemaphorePermit; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::into_inner(self).poll_acquire(cx) + } +} + +impl Clone for PollSemaphore { + fn clone(&self) -> PollSemaphore { + PollSemaphore::new(self.clone_inner()) + } +} + +impl fmt::Debug for PollSemaphore { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PollSemaphore") + .field("semaphore", &self.semaphore) + .finish() + } +}