task: add JoinMap::keys (#6046)

Co-authored-by: Alice Ryhl <aliceryhl@google.com>
This commit is contained in:
Andrea Stedile 2023-10-15 16:47:41 +02:00 committed by GitHub
parent f9335b8186
commit f1e41a4ad4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 68 additions and 1 deletions

View File

@ -5,6 +5,7 @@ use std::collections::hash_map::RandomState;
use std::fmt;
use std::future::Future;
use std::hash::{BuildHasher, Hash, Hasher};
use std::marker::PhantomData;
use tokio::runtime::Handle;
use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet};
@ -626,6 +627,19 @@ where
}
}
/// Returns an iterator visiting all keys in this `JoinMap` in arbitrary order.
///
/// If a task has completed, but its output hasn't yet been consumed by a
/// call to [`join_next`], this method will still return its key.
///
/// [`join_next`]: fn@Self::join_next
pub fn keys(&self) -> JoinMapKeys<'_, K, V> {
JoinMapKeys {
iter: self.tasks_by_key.keys(),
_value: PhantomData,
}
}
/// Returns `true` if this `JoinMap` contains a task for the provided key.
///
/// If the task has completed, but its output hasn't yet been consumed by a
@ -859,3 +873,32 @@ impl<K: PartialEq> PartialEq for Key<K> {
}
impl<K: Eq> Eq for Key<K> {}
/// An iterator over the keys of a [`JoinMap`].
#[derive(Debug, Clone)]
pub struct JoinMapKeys<'a, K, V> {
iter: hashbrown::hash_map::Keys<'a, Key<K>, AbortHandle>,
/// To make it easier to change JoinMap in the future, keep V as a generic
/// parameter.
_value: PhantomData<&'a V>,
}
impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> {
type Item = &'a K;
fn next(&mut self) -> Option<&'a K> {
self.iter.next().map(|key| &key.key)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.iter.size_hint()
}
}
impl<'a, K, V> ExactSizeIterator for JoinMapKeys<'a, K, V> {
fn len(&self) -> usize {
self.iter.len()
}
}
impl<'a, K, V> std::iter::FusedIterator for JoinMapKeys<'a, K, V> {}

View File

@ -9,4 +9,4 @@ pub use spawn_pinned::LocalPoolHandle;
#[cfg(tokio_unstable)]
#[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "rt"))))]
pub use join_map::JoinMap;
pub use join_map::{JoinMap, JoinMapKeys};

View File

@ -109,6 +109,30 @@ async fn alternating() {
}
}
#[tokio::test]
async fn test_keys() {
use std::collections::HashSet;
let mut map = JoinMap::new();
assert_eq!(map.len(), 0);
map.spawn(1, async {});
assert_eq!(map.len(), 1);
map.spawn(2, async {});
assert_eq!(map.len(), 2);
let keys = map.keys().collect::<HashSet<&u32>>();
assert!(keys.contains(&1));
assert!(keys.contains(&2));
let _ = map.join_next().await.unwrap();
let _ = map.join_next().await.unwrap();
assert_eq!(map.len(), 0);
let keys = map.keys().collect::<HashSet<&u32>>();
assert!(keys.is_empty());
}
#[tokio::test(start_paused = true)]
async fn abort_by_key() {
let mut map = JoinMap::new();