diff --git a/tokio-util/src/task/join_queue.rs b/tokio-util/src/task/join_queue.rs new file mode 100644 index 000000000..744b9cfb8 --- /dev/null +++ b/tokio-util/src/task/join_queue.rs @@ -0,0 +1,349 @@ +use super::AbortOnDropHandle; +use std::{ + collections::VecDeque, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::{ + runtime::Handle, + task::{AbortHandle, Id, JoinError, JoinHandle}, +}; + +/// A FIFO queue for of tasks spawned on a Tokio runtime. +/// +/// A [`JoinQueue`] can be used to await the completion of the tasks in FIFO +/// order. That is, if tasks are spawned in the order A, B, C, then +/// awaiting the next completed task will always return A first, then B, +/// then C, regardless of the order in which the tasks actually complete. +/// +/// All of the tasks must have the same return type `T`. +/// +/// When the [`JoinQueue`] is dropped, all tasks in the [`JoinQueue`] are +/// immediately aborted. +#[derive(Debug)] +pub struct JoinQueue(VecDeque>); + +impl JoinQueue { + /// Create a new empty [`JoinQueue`]. + pub const fn new() -> Self { + Self(VecDeque::new()) + } + + /// Creates an empty [`JoinQueue`] with space for at least `capacity` tasks. + pub fn with_capacity(capacity: usize) -> Self { + Self(VecDeque::with_capacity(capacity)) + } + + /// Returns the number of tasks currently in the [`JoinQueue`]. + /// + /// This includes both tasks that are currently running and tasks that have + /// completed but not yet been removed from the queue because outputting of + /// them waits for FIFO order. + pub fn len(&self) -> usize { + self.0.len() + } + + /// Returns whether the [`JoinQueue`] is empty. + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Spawn the provided task on the [`JoinQueue`], returning an [`AbortHandle`] + /// that can be used to remotely cancel the task. + /// + /// The provided future will start running in the background immediately + /// when this method is called, even if you don't await anything on this + /// [`JoinQueue`]. + /// + /// # Panics + /// + /// This method panics if called outside of a Tokio runtime. + /// + /// [`AbortHandle`]: tokio::task::AbortHandle + #[track_caller] + pub fn spawn(&mut self, task: F) -> AbortHandle + where + F: Future + Send + 'static, + T: Send + 'static, + { + self.push_back(tokio::spawn(task)) + } + + /// Spawn the provided task on the provided runtime and store it in this + /// [`JoinQueue`] returning an [`AbortHandle`] that can be used to remotely + /// cancel the task. + /// + /// The provided future will start running in the background immediately + /// when this method is called, even if you don't await anything on this + /// [`JoinQueue`]. + /// + /// [`AbortHandle`]: tokio::task::AbortHandle + #[track_caller] + pub fn spawn_on(&mut self, task: F, handle: &Handle) -> AbortHandle + where + F: Future + Send + 'static, + T: Send + 'static, + { + self.push_back(handle.spawn(task)) + } + + /// Spawn the provided task on the current [`LocalSet`] and store it in this + /// [`JoinQueue`], returning an [`AbortHandle`] that can be used to remotely + /// cancel the task. + /// + /// The provided future will start running in the background immediately + /// when this method is called, even if you don't await anything on this + /// [`JoinQueue`]. + /// + /// # Panics + /// + /// This method panics if it is called outside of a `LocalSet`. + /// + /// [`LocalSet`]: tokio::task::LocalSet + /// [`AbortHandle`]: tokio::task::AbortHandle + #[track_caller] + pub fn spawn_local(&mut self, task: F) -> AbortHandle + where + F: Future + 'static, + T: 'static, + { + self.push_back(tokio::task::spawn_local(task)) + } + + /// Spawn the blocking code on the blocking threadpool and store + /// it in this [`JoinQueue`], returning an [`AbortHandle`] that can be + /// used to remotely cancel the task. + /// + /// # Panics + /// + /// This method panics if called outside of a Tokio runtime. + /// + /// [`AbortHandle`]: tokio::task::AbortHandle + #[track_caller] + pub fn spawn_blocking(&mut self, f: F) -> AbortHandle + where + F: FnOnce() -> T + Send + 'static, + T: Send + 'static, + { + self.push_back(tokio::task::spawn_blocking(f)) + } + + /// Spawn the blocking code on the blocking threadpool of the + /// provided runtime and store it in this [`JoinQueue`], returning an + /// [`AbortHandle`] that can be used to remotely cancel the task. + /// + /// [`AbortHandle`]: tokio::task::AbortHandle + #[track_caller] + pub fn spawn_blocking_on(&mut self, f: F, handle: &Handle) -> AbortHandle + where + F: FnOnce() -> T + Send + 'static, + T: Send + 'static, + { + self.push_back(handle.spawn_blocking(f)) + } + + fn push_back(&mut self, jh: JoinHandle) -> AbortHandle { + let jh = AbortOnDropHandle::new(jh); + let abort_handle = jh.abort_handle(); + self.0.push_back(jh); + abort_handle + } + + /// Waits until the next task in FIFO order completes and returns its output. + /// + /// Returns `None` if the queue is empty. + /// + /// # Cancel Safety + /// + /// This method is cancel safe. If `join_next` is used as the event in a `tokio::select!` + /// statement and some other branch completes first, it is guaranteed that no tasks were + /// removed from this [`JoinQueue`]. + pub async fn join_next(&mut self) -> Option> { + std::future::poll_fn(|cx| self.poll_join_next(cx)).await + } + + /// Waits until the next task in FIFO order completes and returns its output, + /// along with the [task ID] of the completed task. + /// + /// Returns `None` if the queue is empty. + /// + /// When this method returns an error, then the id of the task that failed can be accessed + /// using the [`JoinError::id`] method. + /// + /// # Cancel Safety + /// + /// This method is cancel safe. If `join_next_with_id` is used as the event in a `tokio::select!` + /// statement and some other branch completes first, it is guaranteed that no tasks were + /// removed from this [`JoinQueue`]. + /// + /// [task ID]: tokio::task::Id + /// [`JoinError::id`]: fn@tokio::task::JoinError::id + pub async fn join_next_with_id(&mut self) -> Option> { + std::future::poll_fn(|cx| self.poll_join_next_with_id(cx)).await + } + + /// Aborts all tasks and waits for them to finish shutting down. + /// + /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in + /// a loop until it returns `None`. + /// + /// This method ignores any panics in the tasks shutting down. When this call returns, the + /// [`JoinQueue`] will be empty. + /// + /// [`abort_all`]: fn@Self::abort_all + /// [`join_next`]: fn@Self::join_next + pub async fn shutdown(&mut self) { + self.abort_all(); + while self.join_next().await.is_some() {} + } + + /// Awaits the completion of all tasks in this [`JoinQueue`], returning a vector of their results. + /// + /// The results will be stored in the order they were spawned, not the order they completed. + /// This is a convenience method that is equivalent to calling [`join_next`] in + /// a loop. If any tasks on the [`JoinQueue`] fail with an [`JoinError`], then this call + /// to `join_all` will panic and all remaining tasks on the [`JoinQueue`] are + /// cancelled. To handle errors in any other way, manually call [`join_next`] + /// in a loop. + /// + /// # Cancel Safety + /// + /// This method is not cancel safe as it calls `join_next` in a loop. If you need + /// cancel safety, manually call `join_next` in a loop with `Vec` accumulator. + /// + /// [`join_next`]: fn@Self::join_next + /// [`JoinError::id`]: fn@tokio::task::JoinError::id + pub async fn join_all(mut self) -> Vec { + let mut output = Vec::with_capacity(self.len()); + + while let Some(res) = self.join_next().await { + match res { + Ok(t) => output.push(t), + Err(err) if err.is_panic() => std::panic::resume_unwind(err.into_panic()), + Err(err) => panic!("{err}"), + } + } + output + } + + /// Aborts all tasks on this [`JoinQueue`]. + /// + /// This does not remove the tasks from the [`JoinQueue`]. To wait for the tasks to complete + /// cancellation, you should call `join_next` in a loop until the [`JoinQueue`] is empty. + pub fn abort_all(&mut self) { + self.0.iter().for_each(|jh| jh.abort()); + } + + /// Removes all tasks from this [`JoinQueue`] without aborting them. + /// + /// The tasks removed by this call will continue to run in the background even if the [`JoinQueue`] + /// is dropped. + pub fn detach_all(&mut self) { + self.0.drain(..).for_each(|jh| drop(jh.detach())); + } + + /// Polls for the next task in [`JoinQueue`] to complete. + /// + /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue. + /// + /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled + /// to receive a wakeup when a task in the [`JoinQueue`] completes. Note that on multiple calls to + /// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. + /// + /// # Returns + /// + /// This function returns: + /// + /// * `Poll::Pending` if the [`JoinQueue`] is not empty but there is no task whose output is + /// available right now. + /// * `Poll::Ready(Some(Ok(value)))` if the next task in this [`JoinQueue`] has completed. + /// The `value` is the return value that task. + /// * `Poll::Ready(Some(Err(err)))` if the next task in this [`JoinQueue`] has panicked or been + /// aborted. The `err` is the `JoinError` from the panicked/aborted task. + /// * `Poll::Ready(None)` if the [`JoinQueue`] is empty. + pub fn poll_join_next(&mut self, cx: &mut Context<'_>) -> Poll>> { + let jh = match self.0.front_mut() { + None => return Poll::Ready(None), + Some(jh) => jh, + }; + if let Poll::Ready(res) = Pin::new(jh).poll(cx) { + // Use `detach` to avoid calling `abort` on a task that has already completed. + // Dropping `AbortOnDropHandle` would abort the task, but since it is finished, + // we only need to drop the `JoinHandle` for cleanup. + drop(self.0.pop_front().unwrap().detach()); + Poll::Ready(Some(res)) + } else { + Poll::Pending + } + } + + /// Polls for the next task in [`JoinQueue`] to complete. + /// + /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue. + /// + /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled + /// to receive a wakeup when a task in the [`JoinQueue`] completes. Note that on multiple calls to + /// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. + /// + /// # Returns + /// + /// This function returns: + /// + /// * `Poll::Pending` if the [`JoinQueue`] is not empty but there is no task whose output is + /// available right now. + /// * `Poll::Ready(Some(Ok((id, value))))` if the next task in this [`JoinQueue`] has completed. + /// The `value` is the return value that task, and `id` is its [task ID]. + /// * `Poll::Ready(Some(Err(err)))` if the next task in this [`JoinQueue`] has panicked or been + /// aborted. The `err` is the `JoinError` from the panicked/aborted task. + /// * `Poll::Ready(None)` if the [`JoinQueue`] is empty. + /// + /// [task ID]: tokio::task::Id + pub fn poll_join_next_with_id( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + let jh = match self.0.front_mut() { + None => return Poll::Ready(None), + Some(jh) => jh, + }; + if let Poll::Ready(res) = Pin::new(jh).poll(cx) { + // Use `detach` to avoid calling `abort` on a task that has already completed. + // Dropping `AbortOnDropHandle` would abort the task, but since it is finished, + // we only need to drop the `JoinHandle` for cleanup. + let jh = self.0.pop_front().unwrap().detach(); + let id = jh.id(); + drop(jh); + // If the task succeeded, add the task ID to the output. Otherwise, the + // `JoinError` will already have the task's ID. + Poll::Ready(Some(res.map(|output| (id, output)))) + } else { + Poll::Pending + } + } +} + +impl Default for JoinQueue { + fn default() -> Self { + Self::new() + } +} + +/// Collect an iterator of futures into a [`JoinQueue`]. +/// +/// This is equivalent to calling [`JoinQueue::spawn`] on each element of the iterator. +impl std::iter::FromIterator for JoinQueue +where + F: Future + Send + 'static, + T: Send + 'static, +{ + fn from_iter>(iter: I) -> Self { + let mut set = Self::new(); + iter.into_iter().for_each(|task| { + set.spawn(task); + }); + set + } +} diff --git a/tokio-util/src/task/mod.rs b/tokio-util/src/task/mod.rs index 7635b74f4..3c35ebff4 100644 --- a/tokio-util/src/task/mod.rs +++ b/tokio-util/src/task/mod.rs @@ -13,6 +13,9 @@ cfg_rt! { mod abort_on_drop; pub use abort_on_drop::AbortOnDropHandle; + + mod join_queue; + pub use join_queue::JoinQueue; } #[cfg(feature = "join-map")] diff --git a/tokio-util/tests/task_join_queue.rs b/tokio-util/tests/task_join_queue.rs new file mode 100644 index 000000000..6b23aa2fd --- /dev/null +++ b/tokio-util/tests/task_join_queue.rs @@ -0,0 +1,223 @@ +#![warn(rust_2018_idioms)] + +use tokio::sync::oneshot; +use tokio::task::yield_now; +use tokio::time::Duration; +use tokio_test::{assert_pending, assert_ready, task}; +use tokio_util::task::JoinQueue; + +#[tokio::test] +async fn test_join_queue_no_spurious_wakeups() { + let (tx, rx) = oneshot::channel::<()>(); + let mut join_queue = JoinQueue::new(); + join_queue.spawn(async move { + let _ = rx.await; + 42 + }); + + let mut join_next = task::spawn(join_queue.join_next()); + + assert_pending!(join_next.poll()); + + assert!(!join_next.is_woken()); + + let _ = tx.send(()); + yield_now().await; + + assert!(join_next.is_woken()); + + let output = assert_ready!(join_next.poll()); + assert_eq!(output.unwrap().unwrap(), 42); +} + +#[tokio::test] +async fn test_join_queue_abort_on_drop() { + let mut queue = JoinQueue::new(); + + let mut recvs = Vec::new(); + + for _ in 0..16 { + let (send, recv) = oneshot::channel::<()>(); + recvs.push(recv); + + queue.spawn(async move { + // This task will never complete on its own. + futures::future::pending::<()>().await; + drop(send); + }); + } + + drop(queue); + + for recv in recvs { + // The task is aborted soon and we will receive an error. + assert!(recv.await.is_err()); + } +} + +#[tokio::test] +async fn test_join_queue_alternating() { + let mut queue = JoinQueue::new(); + + assert_eq!(queue.len(), 0); + queue.spawn(async {}); + assert_eq!(queue.len(), 1); + queue.spawn(async {}); + assert_eq!(queue.len(), 2); + + for _ in 0..16 { + let res = queue.join_next().await.unwrap(); + assert!(res.is_ok()); + assert_eq!(queue.len(), 1); + queue.spawn(async {}); + assert_eq!(queue.len(), 2); + } +} + +#[tokio::test(start_paused = true)] +async fn test_join_queue_abort_all() { + let mut queue: JoinQueue<()> = JoinQueue::new(); + + for _ in 0..5 { + queue.spawn(futures::future::pending()); + } + for _ in 0..5 { + queue.spawn(async { + tokio::time::sleep(Duration::from_secs(1)).await; + }); + } + + // The join queue will now have 5 pending tasks and 5 ready tasks. + tokio::time::sleep(Duration::from_secs(2)).await; + + queue.abort_all(); + assert_eq!(queue.len(), 10); + + let mut count = 0; + while let Some(res) = queue.join_next().await { + if count < 5 { + assert!(res.unwrap_err().is_cancelled()); + } else { + assert!(res.is_ok()); + } + count += 1; + } + assert_eq!(count, 10); + assert!(queue.is_empty()); +} + +#[tokio::test] +async fn test_join_queue_join_all() { + let mut queue = JoinQueue::new(); + let mut senders = Vec::new(); + for i in 0..5 { + let (tx, rx) = oneshot::channel::<()>(); + senders.push(tx); + queue.spawn(async move { + let _ = rx.await; + i + }); + } + // Complete all tasks in reverse order + while let Some(tx) = senders.pop() { + let _ = tx.send(()); + } + let results = queue.join_all().await; + assert_eq!(results, vec![0, 1, 2, 3, 4]); +} + +#[tokio::test] +async fn test_join_queue_shutdown() { + let mut queue = JoinQueue::new(); + let mut senders = Vec::new(); + + for _ in 0..5 { + let (tx, rx) = oneshot::channel::<()>(); + senders.push(tx); + queue.spawn(async move { + let _ = rx.await; + }); + } + + queue.shutdown().await; + assert!(queue.is_empty()); + while let Some(tx) = senders.pop() { + assert!(tx.is_closed()); + } +} + +#[tokio::test] +async fn test_join_queue_with_manual_abort() { + let mut queue = JoinQueue::new(); + let mut num_canceled = 0; + let mut num_completed = 0; + let mut senders = Vec::new(); + for i in 0..16 { + let (tx, rx) = oneshot::channel::<()>(); + senders.push(tx); + let abort = queue.spawn(async move { + let _ = rx.await; + i + }); + + if i % 2 != 0 { + // abort odd-numbered tasks. + abort.abort(); + } + } + // Complete all tasks in reverse order + while let Some(tx) = senders.pop() { + let _ = tx.send(()); + } + while let Some(res) = queue.join_next().await { + match res { + Ok(res) => { + assert_eq!(res, num_completed * 2); + num_completed += 1; + } + Err(e) => { + assert!(e.is_cancelled()); + num_canceled += 1; + } + } + } + + assert_eq!(num_canceled, 8); + assert_eq!(num_completed, 8); +} + +#[tokio::test] +async fn test_join_queue_join_next_with_id() { + const TASK_NUM: u32 = 1000; + + let (send, recv) = tokio::sync::watch::channel(()); + + let mut set = JoinQueue::new(); + let mut spawned = Vec::with_capacity(TASK_NUM as usize); + + for _ in 0..TASK_NUM { + let mut recv = recv.clone(); + let handle = set.spawn(async move { recv.changed().await.unwrap() }); + + spawned.push(handle.id()); + } + drop(recv); + + send.send_replace(()); + send.closed().await; + + let mut count = 0; + let mut joined = Vec::with_capacity(TASK_NUM as usize); + while let Some(res) = set.join_next_with_id().await { + match res { + Ok((id, ())) => { + count += 1; + joined.push(id); + } + Err(err) => panic!("failed: {err}"), + } + } + + assert_eq!(count, TASK_NUM); + assert_eq!(joined, spawned); +}