task: add join_all method to JoinSet (#6784)

Adds join_all method to JoinSet. join_all consumes JoinSet and awaits
the completion of all tasks on it, returning the results of the tasks in
a vec. An error or panic in the task will cause join_all to panic,
canceling all other tasks.

Fixes: #6664
This commit is contained in:
Havish Maka 2024-08-26 12:06:52 -04:00 committed by GitHub
parent 1ac8dff213
commit cc70a211ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 114 additions and 1 deletions

View File

@ -4,10 +4,10 @@
//! of spawned tasks and allows asynchronously awaiting the output of those
//! tasks as they complete. See the documentation for the [`JoinSet`] type for
//! details.
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, panic};
use crate::runtime::Handle;
#[cfg(tokio_unstable)]
@ -374,6 +374,79 @@ impl<T: 'static> JoinSet<T> {
while self.join_next().await.is_some() {}
}
/// Awaits the completion of all tasks in this `JoinSet`, returning a vector of their results.
///
/// The results will be stored in the order they completed not the order they were spawned.
/// This is a convenience method that is equivalent to calling [`join_next`] in
/// a loop. If any tasks on the `JoinSet` fail with an [`JoinError`], then this call
/// to `join_all` will panic and all remaining tasks on the `JoinSet` are
/// cancelled. To handle errors in any other way, manually call [`join_next`]
/// in a loop.
///
/// # Examples
///
/// Spawn multiple tasks and `join_all` them.
///
/// ```
/// use tokio::task::JoinSet;
/// use std::time::Duration;
///
/// #[tokio::main]
/// async fn main() {
/// let mut set = JoinSet::new();
///
/// for i in 0..3 {
/// set.spawn(async move {
/// tokio::time::sleep(Duration::from_secs(3 - i)).await;
/// i
/// });
/// }
///
/// let output = set.join_all().await;
/// assert_eq!(output, vec![2, 1, 0]);
/// }
/// ```
///
/// Equivalent implementation of `join_all`, using [`join_next`] and loop.
///
/// ```
/// use tokio::task::JoinSet;
/// use std::panic;
///
/// #[tokio::main]
/// async fn main() {
/// let mut set = JoinSet::new();
///
/// for i in 0..3 {
/// set.spawn(async move {i});
/// }
///
/// let mut output = Vec::new();
/// while let Some(res) = set.join_next().await{
/// match res {
/// Ok(t) => output.push(t),
/// Err(err) if err.is_panic() => panic::resume_unwind(err.into_panic()),
/// Err(err) => panic!("{err}"),
/// }
/// }
/// assert_eq!(output.len(),3);
/// }
/// ```
/// [`join_next`]: fn@Self::join_next
/// [`JoinError::id`]: fn@crate::task::JoinError::id
pub async fn join_all(mut self) -> Vec<T> {
let mut output = Vec::with_capacity(self.len());
while let Some(res) = self.join_next().await {
match res {
Ok(t) => output.push(t),
Err(err) if err.is_panic() => panic::resume_unwind(err.into_panic()),
Err(err) => panic!("{err}"),
}
}
output
}
/// Aborts all tasks on this `JoinSet`.
///
/// This does not remove the tasks from the `JoinSet`. To wait for the tasks to complete

View File

@ -156,6 +156,46 @@ fn runtime_gone() {
.is_cancelled());
}
#[tokio::test]
async fn join_all() {
let mut set: JoinSet<i32> = JoinSet::new();
for _ in 0..5 {
set.spawn(async { 1 });
}
let res: Vec<i32> = set.join_all().await;
assert_eq!(res.len(), 5);
for itm in res.into_iter() {
assert_eq!(itm, 1)
}
}
#[cfg(panic = "unwind")]
#[tokio::test(start_paused = true)]
async fn task_panics() {
let mut set: JoinSet<()> = JoinSet::new();
let (tx, mut rx) = oneshot::channel();
assert_eq!(set.len(), 0);
set.spawn(async move {
tokio::time::sleep(Duration::from_secs(2)).await;
tx.send(()).unwrap();
});
assert_eq!(set.len(), 1);
set.spawn(async {
tokio::time::sleep(Duration::from_secs(1)).await;
panic!();
});
assert_eq!(set.len(), 2);
let panic = tokio::spawn(set.join_all()).await.unwrap_err();
assert!(rx.try_recv().is_err());
assert!(panic.is_panic());
}
#[tokio::test(start_paused = true)]
async fn abort_all() {
let mut set: JoinSet<()> = JoinSet::new();