From cc70a211ad4ce71388c99e8af7480f3ddddbf602 Mon Sep 17 00:00:00 2001 From: Havish Maka <93401162+hmaka@users.noreply.github.com> Date: Mon, 26 Aug 2024 12:06:52 -0400 Subject: [PATCH] task: add `join_all` method to `JoinSet` (#6784) Adds join_all method to JoinSet. join_all consumes JoinSet and awaits the completion of all tasks on it, returning the results of the tasks in a vec. An error or panic in the task will cause join_all to panic, canceling all other tasks. Fixes: #6664 --- tokio/src/task/join_set.rs | 75 +++++++++++++++++++++++++++++++++++- tokio/tests/task_join_set.rs | 40 +++++++++++++++++++ 2 files changed, 114 insertions(+), 1 deletion(-) diff --git a/tokio/src/task/join_set.rs b/tokio/src/task/join_set.rs index 612e53add..a9cd8f52d 100644 --- a/tokio/src/task/join_set.rs +++ b/tokio/src/task/join_set.rs @@ -4,10 +4,10 @@ //! of spawned tasks and allows asynchronously awaiting the output of those //! tasks as they complete. See the documentation for the [`JoinSet`] type for //! details. -use std::fmt; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; +use std::{fmt, panic}; use crate::runtime::Handle; #[cfg(tokio_unstable)] @@ -374,6 +374,79 @@ impl JoinSet { while self.join_next().await.is_some() {} } + /// Awaits the completion of all tasks in this `JoinSet`, returning a vector of their results. + /// + /// The results will be stored in the order they completed not the order they were spawned. + /// This is a convenience method that is equivalent to calling [`join_next`] in + /// a loop. If any tasks on the `JoinSet` fail with an [`JoinError`], then this call + /// to `join_all` will panic and all remaining tasks on the `JoinSet` are + /// cancelled. To handle errors in any other way, manually call [`join_next`] + /// in a loop. + /// + /// # Examples + /// + /// Spawn multiple tasks and `join_all` them. + /// + /// ``` + /// use tokio::task::JoinSet; + /// use std::time::Duration; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut set = JoinSet::new(); + /// + /// for i in 0..3 { + /// set.spawn(async move { + /// tokio::time::sleep(Duration::from_secs(3 - i)).await; + /// i + /// }); + /// } + /// + /// let output = set.join_all().await; + /// assert_eq!(output, vec![2, 1, 0]); + /// } + /// ``` + /// + /// Equivalent implementation of `join_all`, using [`join_next`] and loop. + /// + /// ``` + /// use tokio::task::JoinSet; + /// use std::panic; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut set = JoinSet::new(); + /// + /// for i in 0..3 { + /// set.spawn(async move {i}); + /// } + /// + /// let mut output = Vec::new(); + /// while let Some(res) = set.join_next().await{ + /// match res { + /// Ok(t) => output.push(t), + /// Err(err) if err.is_panic() => panic::resume_unwind(err.into_panic()), + /// Err(err) => panic!("{err}"), + /// } + /// } + /// assert_eq!(output.len(),3); + /// } + /// ``` + /// [`join_next`]: fn@Self::join_next + /// [`JoinError::id`]: fn@crate::task::JoinError::id + pub async fn join_all(mut self) -> Vec { + let mut output = Vec::with_capacity(self.len()); + + while let Some(res) = self.join_next().await { + match res { + Ok(t) => output.push(t), + Err(err) if err.is_panic() => panic::resume_unwind(err.into_panic()), + Err(err) => panic!("{err}"), + } + } + output + } + /// Aborts all tasks on this `JoinSet`. /// /// This does not remove the tasks from the `JoinSet`. To wait for the tasks to complete diff --git a/tokio/tests/task_join_set.rs b/tokio/tests/task_join_set.rs index e87135337..da0652627 100644 --- a/tokio/tests/task_join_set.rs +++ b/tokio/tests/task_join_set.rs @@ -156,6 +156,46 @@ fn runtime_gone() { .is_cancelled()); } +#[tokio::test] +async fn join_all() { + let mut set: JoinSet = JoinSet::new(); + + for _ in 0..5 { + set.spawn(async { 1 }); + } + let res: Vec = set.join_all().await; + + assert_eq!(res.len(), 5); + for itm in res.into_iter() { + assert_eq!(itm, 1) + } +} + +#[cfg(panic = "unwind")] +#[tokio::test(start_paused = true)] +async fn task_panics() { + let mut set: JoinSet<()> = JoinSet::new(); + + let (tx, mut rx) = oneshot::channel(); + assert_eq!(set.len(), 0); + + set.spawn(async move { + tokio::time::sleep(Duration::from_secs(2)).await; + tx.send(()).unwrap(); + }); + assert_eq!(set.len(), 1); + + set.spawn(async { + tokio::time::sleep(Duration::from_secs(1)).await; + panic!(); + }); + assert_eq!(set.len(), 2); + + let panic = tokio::spawn(set.join_all()).await.unwrap_err(); + assert!(rx.try_recv().is_err()); + assert!(panic.is_panic()); +} + #[tokio::test(start_paused = true)] async fn abort_all() { let mut set: JoinSet<()> = JoinSet::new();