diff --git a/tokio/src/sync/semaphore.rs b/tokio/src/sync/semaphore.rs index a2b407459..a952729b5 100644 --- a/tokio/src/sync/semaphore.rs +++ b/tokio/src/sync/semaphore.rs @@ -990,6 +990,27 @@ impl<'a> SemaphorePermit<'a> { self.permits += other.permits; other.permits = 0; } + + /// Splits `n` permits from `self` and returns a new [`SemaphorePermit`] instance that holds `n` permits. + /// + /// If there are insufficient permits and it's not possible to reduce by `n`, returns `None`. + pub fn split(&mut self, n: u32) -> Option { + if n > self.permits { + return None; + } + + self.permits -= n; + + Some(Self { + sem: self.sem, + permits: n, + }) + } + + /// Returns the number of permits held by `self`. + pub fn num_permits(&self) -> u32 { + self.permits + } } impl OwnedSemaphorePermit { @@ -1019,10 +1040,35 @@ impl OwnedSemaphorePermit { other.permits = 0; } + /// Splits `n` permits from `self` and returns a new [`OwnedSemaphorePermit`] instance that holds `n` permits. + /// + /// If there are insufficient permits and it's not possible to reduce by `n`, returns `None`. + /// + /// # Note + /// + /// It will clone the owned `Arc` to construct the new instance. + pub fn split(&mut self, n: u32) -> Option { + if n > self.permits { + return None; + } + + self.permits -= n; + + Some(Self { + sem: self.sem.clone(), + permits: n, + }) + } + /// Returns the [`Semaphore`] from which this permit was acquired. pub fn semaphore(&self) -> &Arc { &self.sem } + + /// Returns the number of permits held by `self`. + pub fn num_permits(&self) -> u32 { + self.permits + } } impl Drop for SemaphorePermit<'_> { diff --git a/tokio/tests/sync_semaphore.rs b/tokio/tests/sync_semaphore.rs index 40a5a0802..ab4b316ce 100644 --- a/tokio/tests/sync_semaphore.rs +++ b/tokio/tests/sync_semaphore.rs @@ -88,6 +88,32 @@ fn merge_unrelated_permits() { p1.merge(p2); } +#[test] +fn split() { + let sem = Semaphore::new(5); + let mut p1 = sem.try_acquire_many(3).unwrap(); + assert_eq!(sem.available_permits(), 2); + assert_eq!(p1.num_permits(), 3); + let mut p2 = p1.split(1).unwrap(); + assert_eq!(sem.available_permits(), 2); + assert_eq!(p1.num_permits(), 2); + assert_eq!(p2.num_permits(), 1); + let p3 = p1.split(0).unwrap(); + assert_eq!(p3.num_permits(), 0); + drop(p1); + assert_eq!(sem.available_permits(), 4); + let p4 = p2.split(1).unwrap(); + assert_eq!(p2.num_permits(), 0); + assert_eq!(p4.num_permits(), 1); + assert!(p2.split(1).is_none()); + drop(p2); + assert_eq!(sem.available_permits(), 4); + drop(p3); + assert_eq!(sem.available_permits(), 4); + drop(p4); + assert_eq!(sem.available_permits(), 5); +} + #[tokio::test] #[cfg(feature = "full")] async fn stress_test() { diff --git a/tokio/tests/sync_semaphore_owned.rs b/tokio/tests/sync_semaphore_owned.rs index d4b12d40e..f9eeee0cf 100644 --- a/tokio/tests/sync_semaphore_owned.rs +++ b/tokio/tests/sync_semaphore_owned.rs @@ -114,6 +114,32 @@ fn merge_unrelated_permits() { p1.merge(p2) } +#[test] +fn split() { + let sem = Arc::new(Semaphore::new(5)); + let mut p1 = sem.clone().try_acquire_many_owned(3).unwrap(); + assert_eq!(sem.available_permits(), 2); + assert_eq!(p1.num_permits(), 3); + let mut p2 = p1.split(1).unwrap(); + assert_eq!(sem.available_permits(), 2); + assert_eq!(p1.num_permits(), 2); + assert_eq!(p2.num_permits(), 1); + let p3 = p1.split(0).unwrap(); + assert_eq!(p3.num_permits(), 0); + drop(p1); + assert_eq!(sem.available_permits(), 4); + let p4 = p2.split(1).unwrap(); + assert_eq!(p2.num_permits(), 0); + assert_eq!(p4.num_permits(), 1); + assert!(p2.split(1).is_none()); + drop(p2); + assert_eq!(sem.available_permits(), 4); + drop(p3); + assert_eq!(sem.available_permits(), 4); + drop(p4); + assert_eq!(sem.available_permits(), 5); +} + #[tokio::test] #[cfg(feature = "full")] async fn stress_test() {