task: add AbortHandle type for cancelling tasks in a JoinSet (#4530)

## Motivation

Before we stabilize the `JoinSet` API, we intend to add a way to abort
individual tasks in the `JoinSet`. Because the `JoinHandle`s for the
tasks spawned on a `JoinSet` are owned by the `JoinSet`, the user cannot
use them to abort those tasks. Therefore, we need another way to
remotely abort a task on a `JoinSet` without holding its `JoinHandle`.

## Solution

This branch adds a new `AbortHandle` type in `tokio::task`, which
represents the owned permission to remotely cancel a task, but _not_ to
await its output. The `AbortHandle` type holds an additional reference
to the task cell.

A crate-private method is added to `JoinHandle` that returns an
`AbortHandle` for the same task, incrementing its ref count.
`AbortHandle` provides a single method, `AbortHandle::abort(self)`, that
remotely cancels the task. Dropping an `AbortHandle` decrements the
task's ref count but does not cancel it. The `AbortHandle` type is
currently marked as unstable.

The spawning methods on `JoinSet` are modified to return an
`AbortHandle` that can be used to cancel the spawned task.
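
The ownership rules described above can be illustrated with a small,
std-only model. This is only a sketch: `TaskCell`, `ToyJoinHandle` and
`ToyAbortHandle` are hypothetical names, and an `Arc` stands in for the
task's manual ref count. It is not how the Tokio runtime implements the
task cell.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Hypothetical stand-in for the shared task cell.
struct TaskCell {
    cancelled: AtomicBool,
}

/// Toy join handle: in the real API this may also await the output.
struct ToyJoinHandle {
    cell: Arc<TaskCell>,
}

/// Toy abort handle: it may cancel the task and nothing else.
struct ToyAbortHandle {
    cell: Arc<TaskCell>,
}

impl ToyJoinHandle {
    fn new() -> Self {
        Self {
            cell: Arc::new(TaskCell {
                cancelled: AtomicBool::new(false),
            }),
        }
    }

    /// Models `JoinHandle::abort_handle`: takes one more reference.
    fn abort_handle(&self) -> ToyAbortHandle {
        ToyAbortHandle {
            cell: Arc::clone(&self.cell),
        }
    }

    fn is_cancelled(&self) -> bool {
        self.cell.cancelled.load(Ordering::SeqCst)
    }

    /// Number of live handles that reference the cell.
    fn ref_count(&self) -> usize {
        Arc::strong_count(&self.cell)
    }
}

impl ToyAbortHandle {
    /// Models `AbortHandle::abort(self)`: consumes the handle and flags
    /// the task as cancelled. The reference is released when `self` is
    /// dropped at the end of this method.
    fn abort(self) {
        self.cell.cancelled.store(true, Ordering::SeqCst);
    }
}
```

In this model, creating an abort handle raises the count from 1 to 2,
dropping it lowers the count again without setting the cancelled flag,
and only an explicit `abort()` call sets the flag.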

## Future Work

- Currently, the `AbortHandle` type is _only_ available in the public
API through a `JoinSet`. We could also make the
`JoinHandle::abort_handle` method public, to allow users to use the
`AbortHandle` type in other contexts. I didn't do that in this PR,
because I wanted to make the API addition as minimal as possible, but we
could make this method public later.

- Currently, `AbortHandle` is not `Clone`. We could easily make it
`Clone` by incrementing the task's ref count. Since this adds more trait
impls to the API, we may want to be cautious about this, but I see no
obvious reason we would need to remove a `Clone` implementation if one
were added.

- There's been some discussion of adding a `JoinMap` type that allows
aborting tasks by key, and manages a hash map of keys to `AbortHandle`s,
and removes the tasks from the map when they complete. This would make
aborting by key much easier, since the user wouldn't have to worry about
keeping the state of the map of abort handles and the tasks actually
active on the `JoinSet` in sync. After thinking about it a bit, I
think this is probably best as a `tokio-util` API, since it can be
implemented in `tokio-util` with the APIs added in `tokio` in this
PR.

- I noticed while working on this that `JoinSet::join_one` and
`JoinSet::poll_join_one` return a cancelled `JoinError` when a task is
cancelled. I'm not sure I love this behavior: it seems like it would be
nicer to just skip cancelled tasks and continue polling. However, there
are currently tests that expect a cancelled `JoinError` to be returned
for each cancelled task, so I didn't want to change it in _this_ PR. I
think this is worth revisiting before stabilizing the API, though?
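
To make the `JoinMap` idea above more concrete, here is a minimal,
std-only sketch of the bookkeeping it would take over from the user.
Everything in it is hypothetical: `ToyJoinMap` and `ToyAbortHandle` are
illustration names, the "task" is just a shared flag, and a real
implementation would store the `AbortHandle`s returned by `JoinSet`.

```rust
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Hypothetical stand-in for `AbortHandle`: aborting sets a shared flag.
struct ToyAbortHandle {
    cancelled: Arc<AtomicBool>,
}

impl ToyAbortHandle {
    fn abort(self) {
        self.cancelled.store(true, Ordering::SeqCst);
    }
}

/// Hypothetical map from user-chosen keys to abort handles.
struct ToyJoinMap<K> {
    handles: HashMap<K, ToyAbortHandle>,
}

impl<K: Hash + Eq> ToyJoinMap<K> {
    fn new() -> Self {
        Self {
            handles: HashMap::new(),
        }
    }

    /// Registers a "task" under `key` and returns its cancellation flag
    /// so that the caller can observe whether it was aborted.
    fn insert(&mut self, key: K) -> Arc<AtomicBool> {
        let flag = Arc::new(AtomicBool::new(false));
        let handle = ToyAbortHandle {
            cancelled: Arc::clone(&flag),
        };
        self.handles.insert(key, handle);
        flag
    }

    /// Aborts the task stored under `key`. Returns `false` if there is
    /// no such task (never inserted, already aborted, or completed).
    fn abort(&mut self, key: &K) -> bool {
        match self.handles.remove(key) {
            Some(handle) => {
                handle.abort();
                true
            }
            None => false,
        }
    }

    /// Called when a task completes, so that the map stays in sync with
    /// the tasks that are actually still active.
    fn on_complete(&mut self, key: &K) {
        self.handles.remove(key);
    }

    fn len(&self) -> usize {
        self.handles.len()
    }
}
```

The point of the sketch is the last two methods: the map, not the user,
removes an entry both when a task is aborted and when it completes, so
stale handles never accumulate.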

Signed-off-by: Eliza Weisman <eliza@buoyant.io>
Eliza Weisman 2022-02-24 13:08:23 -08:00 committed by GitHub
parent dfac73d580
commit 8e0e56fdf2
9 changed files with 328 additions and 37 deletions

View File

@ -0,0 +1,69 @@
use crate::runtime::task::RawTask;
use std::fmt;
use std::panic::{RefUnwindSafe, UnwindSafe};
/// An owned permission to abort a spawned task, without awaiting its completion.
///
/// Unlike a [`JoinHandle`], an `AbortHandle` does *not* represent the
/// permission to await the task's completion, only to terminate it.
///
/// The task may be aborted by calling the [`AbortHandle::abort`] method.
/// Dropping an `AbortHandle` releases the permission to terminate the task
/// --- it does *not* abort the task.
///
/// **Note**: This is an [unstable API][unstable]. The public API of this type
/// may break in 1.x releases. See [the documentation on unstable
/// features][unstable] for details.
///
/// [unstable]: crate#unstable-features
/// [`JoinHandle`]: crate::task::JoinHandle
#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))]
#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
pub struct AbortHandle {
raw: Option<RawTask>,
}
impl AbortHandle {
pub(super) fn new(raw: Option<RawTask>) -> Self {
Self { raw }
}
/// Abort the task associated with the handle.
///
/// Awaiting a cancelled task might complete as usual if the task was
/// already completed at the time it was cancelled, but most likely it
/// will fail with a [cancelled] `JoinError`.
///
/// If the task was already cancelled, such as by [`JoinHandle::abort`],
/// this method will do nothing.
///
/// [cancelled]: method@super::error::JoinError::is_cancelled
// the `AbortHandle` type is only publicly exposed when `tokio_unstable` is
// enabled, but it is still defined for testing purposes.
#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
pub fn abort(self) {
if let Some(raw) = self.raw {
raw.remote_abort();
}
}
}
unsafe impl Send for AbortHandle {}
unsafe impl Sync for AbortHandle {}
impl UnwindSafe for AbortHandle {}
impl RefUnwindSafe for AbortHandle {}
impl fmt::Debug for AbortHandle {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("AbortHandle").finish()
}
}
impl Drop for AbortHandle {
fn drop(&mut self) {
if let Some(raw) = self.raw.take() {
raw.drop_abort_handle();
}
}
}

View File

@ -210,6 +210,16 @@ impl<T> JoinHandle<T> {
}
}
}
/// Returns a new `AbortHandle` that can be used to remotely abort this task.
#[cfg(any(tokio_unstable, test))]
pub(crate) fn abort_handle(&self) -> super::AbortHandle {
let raw = self.raw.map(|raw| {
raw.ref_inc();
raw
});
super::AbortHandle::new(raw)
}
}
impl<T> Unpin for JoinHandle<T> {}

View File

@ -155,7 +155,14 @@ cfg_rt_multi_thread! {
pub(super) use self::inject::Inject;
}
#[cfg(all(feature = "rt", any(tokio_unstable, test)))]
mod abort;
mod join;
#[cfg(all(feature = "rt", any(tokio_unstable, test)))]
#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411
pub use self::abort::AbortHandle;
#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411
pub use self::join::JoinHandle;

View File

@ -27,6 +27,9 @@ pub(super) struct Vtable {
/// The join handle has been dropped.
pub(super) drop_join_handle_slow: unsafe fn(NonNull<Header>),
/// An abort handle has been dropped.
pub(super) drop_abort_handle: unsafe fn(NonNull<Header>),
/// The task is remotely aborted.
pub(super) remote_abort: unsafe fn(NonNull<Header>),
@ -42,6 +45,7 @@ pub(super) fn vtable<T: Future, S: Schedule>() -> &'static Vtable {
try_read_output: try_read_output::<T, S>,
try_set_join_waker: try_set_join_waker::<T, S>,
drop_join_handle_slow: drop_join_handle_slow::<T, S>,
drop_abort_handle: drop_abort_handle::<T, S>,
remote_abort: remote_abort::<T, S>,
shutdown: shutdown::<T, S>,
}
@ -104,6 +108,11 @@ impl RawTask {
unsafe { (vtable.drop_join_handle_slow)(self.ptr) }
}
pub(super) fn drop_abort_handle(self) {
let vtable = self.header().vtable;
unsafe { (vtable.drop_abort_handle)(self.ptr) }
}
pub(super) fn shutdown(self) {
let vtable = self.header().vtable;
unsafe { (vtable.shutdown)(self.ptr) }
@ -113,6 +122,13 @@ impl RawTask {
let vtable = self.header().vtable;
unsafe { (vtable.remote_abort)(self.ptr) }
}
/// Increment the task's reference count.
///
/// Currently, this is used only when creating an `AbortHandle`.
pub(super) fn ref_inc(self) {
self.header().state.ref_inc();
}
}
impl Clone for RawTask {
@ -154,6 +170,11 @@ unsafe fn drop_join_handle_slow<T: Future, S: Schedule>(ptr: NonNull<Header>) {
harness.drop_join_handle_slow()
}
unsafe fn drop_abort_handle<T: Future, S: Schedule>(ptr: NonNull<Header>) {
let harness = Harness::<T, S>::from_raw(ptr);
harness.drop_reference();
}
unsafe fn remote_abort<T: Future, S: Schedule>(ptr: NonNull<Header>) {
let harness = Harness::<T, S>::from_raw(ptr);
harness.remote_abort()

View File

@ -78,6 +78,44 @@ fn create_drop2() {
handle.assert_dropped();
}
#[test]
fn drop_abort_handle1() {
let (ad, handle) = AssertDrop::new();
let (notified, join) = unowned(
async {
drop(ad);
unreachable!()
},
NoopSchedule,
);
let abort = join.abort_handle();
drop(join);
handle.assert_not_dropped();
drop(notified);
handle.assert_not_dropped();
drop(abort);
handle.assert_dropped();
}
#[test]
fn drop_abort_handle2() {
let (ad, handle) = AssertDrop::new();
let (notified, join) = unowned(
async {
drop(ad);
unreachable!()
},
NoopSchedule,
);
let abort = join.abort_handle();
drop(notified);
handle.assert_not_dropped();
drop(abort);
handle.assert_not_dropped();
drop(join);
handle.assert_dropped();
}
// Shutting down through Notified works
#[test]
fn create_shutdown1() {

View File

@ -3,6 +3,7 @@ use std::panic;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::runtime::task::AbortHandle;
use crate::runtime::Builder;
use crate::sync::oneshot;
use crate::task::JoinHandle;
@ -56,6 +57,12 @@ enum CombiAbort {
AbortedAfterConsumeOutput = 4,
}
#[derive(Copy, Clone, Debug, PartialEq)]
enum CombiAbortSource {
JoinHandle,
AbortHandle,
}
#[test]
fn test_combinations() {
let mut rt = &[
@ -90,6 +97,13 @@ fn test_combinations() {
CombiAbort::AbortedAfterFinish,
CombiAbort::AbortedAfterConsumeOutput,
];
let ah = [
None,
Some(CombiJoinHandle::DropImmediately),
Some(CombiJoinHandle::DropFirstPoll),
Some(CombiJoinHandle::DropAfterNoConsume),
Some(CombiJoinHandle::DropAfterConsume),
];
for rt in rt.iter().copied() {
for ls in ls.iter().copied() {
@ -98,7 +112,34 @@ fn test_combinations() {
for ji in ji.iter().copied() {
for jh in jh.iter().copied() {
for abort in abort.iter().copied() {
test_combination(rt, ls, task, output, ji, jh, abort);
// abort via join handle --- abort handles
// may be dropped at any point
for ah in ah.iter().copied() {
test_combination(
rt,
ls,
task,
output,
ji,
jh,
ah,
abort,
CombiAbortSource::JoinHandle,
);
}
// if aborting via AbortHandle, it will
// never be dropped.
test_combination(
rt,
ls,
task,
output,
ji,
jh,
None,
abort,
CombiAbortSource::AbortHandle,
);
}
}
}
@ -108,6 +149,7 @@ fn test_combinations() {
}
}
#[allow(clippy::too_many_arguments)]
fn test_combination(
rt: CombiRuntime,
ls: CombiLocalSet,
@ -115,12 +157,24 @@ fn test_combination(
output: CombiOutput,
ji: CombiJoinInterest,
jh: CombiJoinHandle,
ah: Option<CombiJoinHandle>,
abort: CombiAbort,
abort_src: CombiAbortSource,
) {
if (jh as usize) < (abort as usize) {
// drop before abort not possible
return;
match (abort_src, ah) {
(CombiAbortSource::JoinHandle, _) if (jh as usize) < (abort as usize) => {
// join handle dropped prior to abort
return;
}
(CombiAbortSource::AbortHandle, Some(_)) => {
// abort handle dropped, we can't abort through the
// abort handle
return;
}
_ => {}
}
if (task == CombiTask::PanicOnDrop) && (output == CombiOutput::PanicOnDrop) {
// this causes double panic
return;
@ -130,7 +184,7 @@ fn test_combination(
return;
}
println!("Runtime {:?}, LocalSet {:?}, Task {:?}, Output {:?}, JoinInterest {:?}, JoinHandle {:?}, Abort {:?}", rt, ls, task, output, ji, jh, abort);
println!("Runtime {:?}, LocalSet {:?}, Task {:?}, Output {:?}, JoinInterest {:?}, JoinHandle {:?}, AbortHandle {:?}, Abort {:?} ({:?})", rt, ls, task, output, ji, jh, ah, abort, abort_src);
// A runtime optionally with a LocalSet
struct Rt {
@ -282,8 +336,24 @@ fn test_combination(
);
}
// If we are aborting the task via an abort handle, or dropping an abort
// handle at some point in this combination, create that handle now.
let mut abort_handle = if ah.is_some() || abort_src == CombiAbortSource::AbortHandle {
handle.as_ref().map(JoinHandle::abort_handle)
} else {
None
};
let do_abort = |abort_handle: &mut Option<AbortHandle>,
join_handle: Option<&mut JoinHandle<_>>| {
match abort_src {
CombiAbortSource::AbortHandle => abort_handle.take().unwrap().abort(),
CombiAbortSource::JoinHandle => join_handle.unwrap().abort(),
}
};
if abort == CombiAbort::AbortedImmediately {
handle.as_mut().unwrap().abort();
do_abort(&mut abort_handle, handle.as_mut());
aborted = true;
}
if jh == CombiJoinHandle::DropImmediately {
@ -301,12 +371,15 @@ fn test_combination(
}
if abort == CombiAbort::AbortedFirstPoll {
handle.as_mut().unwrap().abort();
do_abort(&mut abort_handle, handle.as_mut());
aborted = true;
}
if jh == CombiJoinHandle::DropFirstPoll {
drop(handle.take().unwrap());
}
if ah == Some(CombiJoinHandle::DropFirstPoll) {
drop(abort_handle.take().unwrap());
}
// Signal the future that it can return now
let _ = on_complete.send(());
@ -318,23 +391,42 @@ fn test_combination(
if abort == CombiAbort::AbortedAfterFinish {
// Don't set aborted to true here as the task already finished
handle.as_mut().unwrap().abort();
do_abort(&mut abort_handle, handle.as_mut());
}
if jh == CombiJoinHandle::DropAfterNoConsume {
// The runtime will usually have dropped every ref-count at this point,
// in which case dropping the JoinHandle drops the output.
//
// (But it might race and still hold a ref-count)
let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| {
if ah == Some(CombiJoinHandle::DropAfterNoConsume) {
drop(handle.take().unwrap());
}));
if panic.is_err() {
assert!(
(output == CombiOutput::PanicOnDrop)
&& (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop))
&& !aborted,
"Dropping JoinHandle shouldn't panic here"
);
// The runtime will usually have dropped every ref-count at this point,
// in which case dropping the AbortHandle drops the output.
//
// (But it might race and still hold a ref-count)
let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| {
drop(abort_handle.take().unwrap());
}));
if panic.is_err() {
assert!(
(output == CombiOutput::PanicOnDrop)
&& (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop))
&& !aborted,
"Dropping AbortHandle shouldn't panic here"
);
}
} else {
// The runtime will usually have dropped every ref-count at this point,
// in which case dropping the JoinHandle drops the output.
//
// (But it might race and still hold a ref-count)
let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| {
drop(handle.take().unwrap());
}));
if panic.is_err() {
assert!(
(output == CombiOutput::PanicOnDrop)
&& (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop))
&& !aborted,
"Dropping JoinHandle shouldn't panic here"
);
}
}
}
@ -362,11 +454,15 @@ fn test_combination(
_ => unreachable!(),
}
let handle = handle.take().unwrap();
let mut handle = handle.take().unwrap();
if abort == CombiAbort::AbortedAfterConsumeOutput {
handle.abort();
do_abort(&mut abort_handle, Some(&mut handle));
}
drop(handle);
if ah == Some(CombiJoinHandle::DropAfterConsume) {
drop(abort_handle.take());
}
}
// The output should have been dropped now. Check whether the output

View File

@ -4,7 +4,7 @@ use std::pin::Pin;
use std::task::{Context, Poll};
use crate::runtime::Handle;
use crate::task::{JoinError, JoinHandle, LocalSet};
use crate::task::{AbortHandle, JoinError, JoinHandle, LocalSet};
use crate::util::IdleNotifiedSet;
/// A collection of tasks spawned on a Tokio runtime.
@ -73,61 +73,76 @@ impl<T> JoinSet<T> {
}
impl<T: 'static> JoinSet<T> {
/// Spawn the provided task on the `JoinSet`.
/// Spawn the provided task on the `JoinSet`, returning an [`AbortHandle`]
/// that can be used to remotely cancel the task.
///
/// # Panics
///
/// This method panics if called outside of a Tokio runtime.
pub fn spawn<F>(&mut self, task: F)
///
/// [`AbortHandle`]: crate::task::AbortHandle
pub fn spawn<F>(&mut self, task: F) -> AbortHandle
where
F: Future<Output = T>,
F: Send + 'static,
T: Send,
{
self.insert(crate::spawn(task));
self.insert(crate::spawn(task))
}
/// Spawn the provided task on the provided runtime and store it in this `JoinSet`.
pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
/// Spawn the provided task on the provided runtime and store it in this
/// `JoinSet`, returning an [`AbortHandle`] that can be used to remotely
/// cancel the task.
///
/// [`AbortHandle`]: crate::task::AbortHandle
pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) -> AbortHandle
where
F: Future<Output = T>,
F: Send + 'static,
T: Send,
{
self.insert(handle.spawn(task));
self.insert(handle.spawn(task))
}
/// Spawn the provided task on the current [`LocalSet`] and store it in this `JoinSet`.
/// Spawn the provided task on the current [`LocalSet`] and store it in this
/// `JoinSet`, returning an [`AbortHandle`] that can be used to remotely
/// cancel the task.
///
/// # Panics
///
/// This method panics if it is called outside of a `LocalSet`.
///
/// [`LocalSet`]: crate::task::LocalSet
pub fn spawn_local<F>(&mut self, task: F)
/// [`AbortHandle`]: crate::task::AbortHandle
pub fn spawn_local<F>(&mut self, task: F) -> AbortHandle
where
F: Future<Output = T>,
F: 'static,
{
self.insert(crate::task::spawn_local(task));
self.insert(crate::task::spawn_local(task))
}
/// Spawn the provided task on the provided [`LocalSet`] and store it in this `JoinSet`.
/// Spawn the provided task on the provided [`LocalSet`] and store it in
/// this `JoinSet`, returning an [`AbortHandle`] that can be used to
/// remotely cancel the task.
///
/// [`LocalSet`]: crate::task::LocalSet
pub fn spawn_local_on<F>(&mut self, task: F, local_set: &LocalSet)
/// [`AbortHandle`]: crate::task::AbortHandle
pub fn spawn_local_on<F>(&mut self, task: F, local_set: &LocalSet) -> AbortHandle
where
F: Future<Output = T>,
F: 'static,
{
self.insert(local_set.spawn_local(task));
self.insert(local_set.spawn_local(task))
}
fn insert(&mut self, jh: JoinHandle<T>) {
fn insert(&mut self, jh: JoinHandle<T>) -> AbortHandle {
let abort = jh.abort_handle();
let mut entry = self.inner.insert_idle(jh);
// Set the waker that is notified when the task completes.
entry.with_value_and_context(|jh, ctx| jh.set_join_waker(ctx.waker()));
abort
}
/// Waits until one of the tasks in the set completes and returns its output.

View File

@ -303,6 +303,7 @@ cfg_rt! {
cfg_unstable! {
mod join_set;
pub use join_set::JoinSet;
pub use crate::runtime::task::AbortHandle;
}
cfg_trace! {

View File

@ -106,6 +106,40 @@ async fn alternating() {
}
}
#[tokio::test(start_paused = true)]
async fn abort_tasks() {
let mut set = JoinSet::new();
let mut num_canceled = 0;
let mut num_completed = 0;
for i in 0..16 {
let abort = set.spawn(async move {
tokio::time::sleep(Duration::from_secs(i as u64)).await;
i
});
if i % 2 != 0 {
// abort odd-numbered tasks.
abort.abort();
}
}
loop {
match set.join_one().await {
Ok(Some(res)) => {
num_completed += 1;
assert_eq!(res % 2, 0);
}
Err(e) => {
assert!(e.is_cancelled());
num_canceled += 1;
}
Ok(None) => break,
}
}
assert_eq!(num_canceled, 8);
assert_eq!(num_completed, 8);
}
#[test]
fn runtime_gone() {
let mut set = JoinSet::new();