task: add JoinSet for managing sets of tasks (#4335)

Adds `JoinSet` for managing multiple spawned tasks and joining them
in completion order.

Closes: #3903
Alice Ryhl 2022-02-01 23:17:09 +01:00 committed by GitHub
parent f602410227
commit 1bb4d23162
15 changed files with 1052 additions and 23 deletions
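
As a quick illustration of the new API (not part of the diff below; the sleep durations are invented for the example), tasks come back from `join_one` in completion order rather than spawn order:

```rust
use tokio::task::JoinSet;
use tokio::time::{sleep, Duration};

#[tokio::main]
async fn main() {
    let mut set = JoinSet::new();

    // Spawn tasks that finish in reverse spawn order.
    for i in 0..3u64 {
        set.spawn(async move {
            sleep(Duration::from_millis(30 - 10 * i)).await;
            i
        });
    }

    // `join_one` yields outputs in completion order: 2, then 1, then 0.
    while let Some(id) = set.join_one().await.unwrap() {
        println!("task {} finished", id);
    }
}
```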


@@ -143,7 +143,7 @@ wasm-bindgen-test = "0.3.0"
mio-aio = { version = "0.6.0", features = ["tokio"] }
[target.'cfg(loom)'.dev-dependencies]
loom = { version = "0.5", features = ["futures", "checkpoint"] }
loom = { version = "0.5.2", features = ["futures", "checkpoint"] }
[package.metadata.docs.rs]
all-features = true


@@ -1,6 +1,6 @@
use crate::future::poll_fn;
use crate::loom::sync::atomic::AtomicBool;
use crate::loom::sync::Mutex;
use crate::loom::sync::{Arc, Mutex};
use crate::park::{Park, Unpark};
use crate::runtime::context::EnterGuard;
use crate::runtime::driver::Driver;
@@ -16,7 +16,6 @@ use std::collections::VecDeque;
use std::fmt;
use std::future::Future;
use std::sync::atomic::Ordering::{AcqRel, Release};
use std::sync::Arc;
use std::task::Poll::{Pending, Ready};
use std::time::Duration;
@@ -481,8 +480,8 @@ impl Schedule for Arc<Shared> {
}
impl Wake for Shared {
fn wake(self: Arc<Self>) {
Wake::wake_by_ref(&self)
fn wake(arc_self: Arc<Self>) {
Wake::wake_by_ref(&arc_self)
}
/// Wake by reference


@@ -164,6 +164,13 @@ where
}
}
/// Try to set the waker notified when the task is complete. Returns true if
/// the task has already completed. If this call returns false, then the
/// waker will be notified when the task completes.
pub(super) fn try_set_join_waker(self, waker: &Waker) -> bool {
can_read_output(self.header(), self.trailer(), waker)
}
pub(super) fn drop_join_handle_slow(self) {
// Try to unset `JOIN_INTEREST`. This must be done as a first step in
// case the task concurrently completed.


@@ -5,7 +5,7 @@ use std::future::Future;
use std::marker::PhantomData;
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::task::{Context, Poll, Waker};
cfg_rt! {
/// An owned permission to join on a task (await its termination).
@@ -200,6 +200,16 @@ impl<T> JoinHandle<T> {
raw.remote_abort();
}
}
/// Set the waker that is notified when the task completes.
pub(crate) fn set_join_waker(&mut self, waker: &Waker) {
if let Some(raw) = self.raw {
if raw.try_set_join_waker(waker) {
// In this case the task has already completed. We wake the waker immediately.
waker.wake_by_ref();
}
}
}
}
impl<T> Unpin for JoinHandle<T> {}


@@ -19,6 +19,11 @@ pub(super) struct Vtable {
/// Reads the task output, if complete.
pub(super) try_read_output: unsafe fn(NonNull<Header>, *mut (), &Waker),
/// Try to set the waker notified when the task is complete. Returns true if
/// the task has already completed. If this call returns false, then the
/// waker will be notified when the task completes.
pub(super) try_set_join_waker: unsafe fn(NonNull<Header>, &Waker) -> bool,
/// The join handle has been dropped.
pub(super) drop_join_handle_slow: unsafe fn(NonNull<Header>),
@@ -35,6 +40,7 @@ pub(super) fn vtable<T: Future, S: Schedule>() -> &'static Vtable {
poll: poll::<T, S>,
dealloc: dealloc::<T, S>,
try_read_output: try_read_output::<T, S>,
try_set_join_waker: try_set_join_waker::<T, S>,
drop_join_handle_slow: drop_join_handle_slow::<T, S>,
remote_abort: remote_abort::<T, S>,
shutdown: shutdown::<T, S>,
@@ -84,6 +90,11 @@ impl RawTask {
(vtable.try_read_output)(self.ptr, dst, waker);
}
pub(super) fn try_set_join_waker(self, waker: &Waker) -> bool {
let vtable = self.header().vtable;
unsafe { (vtable.try_set_join_waker)(self.ptr, waker) }
}
pub(super) fn drop_join_handle_slow(self) {
let vtable = self.header().vtable;
unsafe { (vtable.drop_join_handle_slow)(self.ptr) }
@@ -129,6 +140,11 @@ unsafe fn try_read_output<T: Future, S: Schedule>(
harness.try_read_output(out, waker);
}
unsafe fn try_set_join_waker<T: Future, S: Schedule>(ptr: NonNull<Header>, waker: &Waker) -> bool {
let harness = Harness::<T, S>::from_raw(ptr);
harness.try_set_join_waker(waker)
}
unsafe fn drop_join_handle_slow<T: Future, S: Schedule>(ptr: NonNull<Header>) {
let harness = Harness::<T, S>::from_raw(ptr);
harness.drop_join_handle_slow()


@@ -0,0 +1,82 @@
use crate::runtime::Builder;
use crate::task::JoinSet;
#[test]
fn test_join_set() {
loom::model(|| {
let rt = Builder::new_multi_thread()
.worker_threads(1)
.build()
.unwrap();
let mut set = JoinSet::new();
rt.block_on(async {
assert_eq!(set.len(), 0);
set.spawn(async { () });
assert_eq!(set.len(), 1);
set.spawn(async { () });
assert_eq!(set.len(), 2);
let () = set.join_one().await.unwrap().unwrap();
assert_eq!(set.len(), 1);
set.spawn(async { () });
assert_eq!(set.len(), 2);
let () = set.join_one().await.unwrap().unwrap();
assert_eq!(set.len(), 1);
let () = set.join_one().await.unwrap().unwrap();
assert_eq!(set.len(), 0);
set.spawn(async { () });
assert_eq!(set.len(), 1);
});
drop(set);
drop(rt);
});
}
#[test]
fn abort_all_during_completion() {
use std::sync::{
atomic::{AtomicBool, Ordering::SeqCst},
Arc,
};
// These booleans assert that at least one execution had the task complete first, and that at
// least one execution had the task be cancelled before it completed.
let complete_happened = Arc::new(AtomicBool::new(false));
let cancel_happened = Arc::new(AtomicBool::new(false));
{
let complete_happened = complete_happened.clone();
let cancel_happened = cancel_happened.clone();
loom::model(move || {
let rt = Builder::new_multi_thread()
.worker_threads(1)
.build()
.unwrap();
let mut set = JoinSet::new();
rt.block_on(async {
set.spawn(async { () });
set.abort_all();
match set.join_one().await {
Ok(Some(())) => complete_happened.store(true, SeqCst),
Err(err) if err.is_cancelled() => cancel_happened.store(true, SeqCst),
Err(err) => panic!("fail: {}", err),
Ok(None) => {
unreachable!("Aborting the task does not remove it from the JoinSet.")
}
}
assert!(matches!(set.join_one().await, Ok(None)));
});
drop(set);
drop(rt);
});
}
assert!(complete_happened.load(SeqCst));
assert!(cancel_happened.load(SeqCst));
}


@@ -30,12 +30,13 @@ mod unowned_wrapper {
cfg_loom! {
mod loom_basic_scheduler;
mod loom_local;
mod loom_blocking;
mod loom_local;
mod loom_oneshot;
mod loom_pool;
mod loom_queue;
mod loom_shutdown_join;
mod loom_join_set;
}
cfg_not_loom! {

tokio/src/task/join_set.rs (new file, 242 lines)

@@ -0,0 +1,242 @@
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::runtime::Handle;
use crate::task::{JoinError, JoinHandle, LocalSet};
use crate::util::IdleNotifiedSet;
/// A collection of tasks spawned on a Tokio runtime.
///
/// A `JoinSet` can be used to await the completion of some or all of the tasks
/// in the set. The set is not ordered, and the tasks will be returned in the
/// order they complete.
///
/// All of the tasks must have the same return type `T`.
///
/// When the `JoinSet` is dropped, all tasks in the `JoinSet` are immediately aborted.
///
/// # Examples
///
/// Spawn multiple tasks and wait for them.
///
/// ```
/// use tokio::task::JoinSet;
///
/// #[tokio::main]
/// async fn main() {
/// let mut set = JoinSet::new();
///
/// for i in 0..10 {
/// set.spawn(async move { i });
/// }
///
/// let mut seen = [false; 10];
/// while let Some(res) = set.join_one().await.unwrap() {
/// seen[res] = true;
/// }
///
/// for i in 0..10 {
/// assert!(seen[i]);
/// }
/// }
/// ```
pub struct JoinSet<T> {
inner: IdleNotifiedSet<JoinHandle<T>>,
}
impl<T> JoinSet<T> {
/// Create a new `JoinSet`.
pub fn new() -> Self {
Self {
inner: IdleNotifiedSet::new(),
}
}
/// Returns the number of tasks currently in the `JoinSet`.
pub fn len(&self) -> usize {
self.inner.len()
}
/// Returns whether the `JoinSet` is empty.
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
}
impl<T: 'static> JoinSet<T> {
/// Spawn the provided task on the `JoinSet`.
///
/// # Panics
///
/// This method panics if called outside of a Tokio runtime.
pub fn spawn<F>(&mut self, task: F)
where
F: Future<Output = T>,
F: Send + 'static,
T: Send,
{
self.insert(crate::spawn(task));
}
/// Spawn the provided task on the provided runtime and store it in this `JoinSet`.
pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
where
F: Future<Output = T>,
F: Send + 'static,
T: Send,
{
self.insert(handle.spawn(task));
}
/// Spawn the provided task on the current [`LocalSet`] and store it in this `JoinSet`.
///
/// # Panics
///
/// This method panics if it is called outside of a `LocalSet`.
///
/// [`LocalSet`]: crate::task::LocalSet
pub fn spawn_local<F>(&mut self, task: F)
where
F: Future<Output = T>,
F: 'static,
{
self.insert(crate::task::spawn_local(task));
}
/// Spawn the provided task on the provided [`LocalSet`] and store it in this `JoinSet`.
///
/// [`LocalSet`]: crate::task::LocalSet
pub fn spawn_local_on<F>(&mut self, task: F, local_set: &LocalSet)
where
F: Future<Output = T>,
F: 'static,
{
self.insert(local_set.spawn_local(task));
}
fn insert(&mut self, jh: JoinHandle<T>) {
let mut entry = self.inner.insert_idle(jh);
// Set the waker that is notified when the task completes.
entry.with_value_and_context(|jh, ctx| jh.set_join_waker(ctx.waker()));
}
/// Waits until one of the tasks in the set completes and returns its output.
///
/// Returns `Ok(None)` if the set is empty.
///
/// # Cancel Safety
///
/// This method is cancel safe. If `join_one` is used as the event in a `tokio::select!`
/// statement and some other branch completes first, it is guaranteed that no tasks were
/// removed from this `JoinSet`.
pub async fn join_one(&mut self) -> Result<Option<T>, JoinError> {
crate::future::poll_fn(|cx| self.poll_join_one(cx)).await
}
/// Aborts all tasks and waits for them to finish shutting down.
///
/// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_one`] in
/// a loop until it returns `Ok(None)`.
///
/// This method ignores any panics in the tasks shutting down. When this call returns, the
/// `JoinSet` will be empty.
///
/// [`abort_all`]: fn@Self::abort_all
/// [`join_one`]: fn@Self::join_one
pub async fn shutdown(&mut self) {
self.abort_all();
while self.join_one().await.transpose().is_some() {}
}
/// Aborts all tasks on this `JoinSet`.
///
/// This does not remove the tasks from the `JoinSet`. To wait for the tasks to complete
/// cancellation, you should call `join_one` in a loop until the `JoinSet` is empty.
pub fn abort_all(&mut self) {
self.inner.for_each(|jh| jh.abort());
}
/// Removes all tasks from this `JoinSet` without aborting them.
///
/// The tasks removed by this call will continue to run in the background even if the `JoinSet`
/// is dropped.
pub fn detach_all(&mut self) {
self.inner.drain(drop);
}
/// Polls for one of the tasks in the set to complete.
///
/// If this returns `Poll::Ready(Ok(Some(_)))` or `Poll::Ready(Err(_))`, then the task that
/// completed is removed from the set.
///
/// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
/// to receive a wakeup when a task in the `JoinSet` completes. Note that on multiple calls to
/// `poll_join_one`, only the `Waker` from the `Context` passed to the most recent call is
/// scheduled to receive a wakeup.
///
/// # Returns
///
/// This function returns:
///
/// * `Poll::Pending` if the `JoinSet` is not empty but there is no task whose output is
/// available right now.
/// * `Poll::Ready(Ok(Some(value)))` if one of the tasks in this `JoinSet` has completed. The
/// `value` is the return value of one of the tasks that completed.
/// * `Poll::Ready(Err(err))` if one of the tasks in this `JoinSet` has panicked or been
/// aborted.
/// * `Poll::Ready(Ok(None))` if the `JoinSet` is empty.
///
/// Note that this method may return `Poll::Pending` even if one of the tasks has completed.
/// This can happen if the [coop budget] is reached.
///
/// [coop budget]: crate::task#cooperative-scheduling
fn poll_join_one(&mut self, cx: &mut Context<'_>) -> Poll<Result<Option<T>, JoinError>> {
// The call to `pop_notified` moves the entry to the `idle` list. It is moved back to
// the `notified` list if the waker is notified in the `poll` call below.
let mut entry = match self.inner.pop_notified(cx.waker()) {
Some(entry) => entry,
None => {
if self.is_empty() {
return Poll::Ready(Ok(None));
} else {
// The waker was set by `pop_notified`.
return Poll::Pending;
}
}
};
let res = entry.with_value_and_context(|jh, ctx| Pin::new(jh).poll(ctx));
if let Poll::Ready(res) = res {
entry.remove();
Poll::Ready(Some(res).transpose())
} else {
// A JoinHandle generally won't emit a wakeup without being ready unless
// the coop limit has been reached. We yield to the executor in this
// case.
cx.waker().wake_by_ref();
Poll::Pending
}
}
}
impl<T> Drop for JoinSet<T> {
fn drop(&mut self) {
self.inner.drain(|join_handle| join_handle.abort());
}
}
impl<T> fmt::Debug for JoinSet<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("JoinSet").field("len", &self.len()).finish()
}
}
impl<T> Default for JoinSet<T> {
fn default() -> Self {
Self::new()
}
}
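
The cancel-safety guarantee documented on `join_one` is what makes it usable directly as a `tokio::select!` branch: if another branch wins the race, no completed task output is lost. A hedged usage sketch built on the API added in this commit (the command channel, `drive`, and the doubling tasks are invented for illustration):

```rust
use tokio::sync::mpsc;
use tokio::task::JoinSet;

async fn drive(mut commands: mpsc::Receiver<u64>) {
    let mut set: JoinSet<u64> = JoinSet::new();

    loop {
        tokio::select! {
            // Because `join_one` is cancel safe, losing this race to the
            // other branch does not lose a completed task's output.
            res = set.join_one(), if !set.is_empty() => {
                match res {
                    Ok(Some(value)) => println!("task finished with {}", value),
                    Ok(None) => {}
                    Err(err) => eprintln!("task failed: {}", err),
                }
            }
            cmd = commands.recv() => {
                match cmd {
                    Some(n) => { set.spawn(async move { n * 2 }); }
                    // Channel closed: stop accepting work and drain the set.
                    None => break,
                }
            }
        }
    }

    // Wait for whatever is still running before returning.
    while let Some(res) = set.join_one().await.transpose() {
        if let Ok(value) = res {
            println!("late task finished with {}", value);
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel(8);
    let driver = tokio::spawn(drive(rx));
    for n in 1..=3u64 {
        tx.send(n).await.unwrap();
    }
    // Dropping the sender closes the channel, letting `drive` drain and return.
    drop(tx);
    driver.await.unwrap();
}
```

The trailing loop is the same drain-until-`Ok(None)` pattern that `shutdown` uses after `abort_all`, except these tasks are left to finish normally.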


@@ -300,6 +300,9 @@ cfg_rt! {
mod unconstrained;
pub use unconstrained::{unconstrained, Unconstrained};
mod join_set;
pub use join_set::JoinSet;
cfg_trace! {
mod builder;
pub use builder::Builder;


@@ -0,0 +1,463 @@
//! This module defines an `IdleNotifiedSet`, which is a collection of elements.
//! Each element is intended to correspond to a task, and the collection will
//! keep track of which tasks have had their waker notified, and which have not.
//!
//! Each entry in the set holds some user-specified value. The value's type is
//! specified using the `T` parameter. It will usually be a `JoinHandle` or
//! similar.
use std::marker::PhantomPinned;
use std::mem::ManuallyDrop;
use std::ptr::NonNull;
use std::task::{Context, Waker};
use crate::loom::cell::UnsafeCell;
use crate::loom::sync::{Arc, Mutex};
use crate::util::linked_list::{self, Link};
use crate::util::{waker_ref, Wake};
type LinkedList<T> =
linked_list::LinkedList<ListEntry<T>, <ListEntry<T> as linked_list::Link>::Target>;
/// This is the main handle to the collection.
pub(crate) struct IdleNotifiedSet<T> {
lists: Arc<Lists<T>>,
length: usize,
}
/// A handle to an entry that is guaranteed to be stored in the idle or notified
/// list of its `IdleNotifiedSet`. This value borrows the `IdleNotifiedSet`
/// mutably to prevent the entry from being moved to the `Neither` list, which
/// only the `IdleNotifiedSet` may do.
///
/// The main consequence of being stored in one of the lists is that the `value`
/// field has not yet been consumed.
///
/// Note: This entry can be moved from the idle to the notified list while this
/// object exists by waking its waker.
pub(crate) struct EntryInOneOfTheLists<'a, T> {
entry: Arc<ListEntry<T>>,
set: &'a mut IdleNotifiedSet<T>,
}
type Lists<T> = Mutex<ListsInner<T>>;
/// The linked lists hold strong references to the ListEntry items, and the
/// ListEntry items also hold a strong reference back to the Lists object, but
/// the destructor of the `IdleNotifiedSet` will clear the two lists, so once
/// that object is destroyed, no ref-cycles will remain.
struct ListsInner<T> {
notified: LinkedList<T>,
idle: LinkedList<T>,
/// Whenever an element in the `notified` list is woken, this waker will be
/// notified and consumed, if it exists.
waker: Option<Waker>,
}
/// Which of the two lists in the shared Lists object is this entry stored in?
///
/// If the value is `Idle`, then an entry's waker may move it to the notified
/// list. Otherwise, only the `IdleNotifiedSet` may move it.
///
/// If the value is `Neither`, then it is still possible that the entry is in
/// some third external list (this happens in `drain`).
#[derive(Copy, Clone, Eq, PartialEq)]
enum List {
Notified,
Idle,
Neither,
}
/// An entry in the list.
///
/// # Safety
///
/// The `my_list` field must only be accessed while holding the mutex in
/// `parent`. It is an invariant that the value of `my_list` corresponds to
/// which linked list in the `parent` holds this entry. Once this field takes
/// the value `Neither`, then it may never be modified again.
///
/// If the value of `my_list` is `Notified` or `Idle`, then the `pointers` field
/// must only be accessed while holding the mutex. If the value of `my_list` is
/// `Neither`, then the `pointers` field may be accessed by the
/// `IdleNotifiedSet` (this happens inside `drain`).
///
/// The `value` field is owned by the `IdleNotifiedSet` and may only be accessed
/// by the `IdleNotifiedSet`. The operation that sets the value of `my_list` to
/// `Neither` assumes ownership of the `value`, and it must either drop it or
/// move it out from this entry to prevent it from getting leaked. (Since the
/// two linked lists are emptied in the destructor of `IdleNotifiedSet`, the
/// value should not be leaked.)
///
/// This type is `#[repr(C)]` because its `linked_list::Link` implementation
/// requires that `pointers` is the first field.
#[repr(C)]
struct ListEntry<T> {
/// The linked list pointers of the list this entry is in.
pointers: linked_list::Pointers<ListEntry<T>>,
/// Pointer to the shared `Lists` struct.
parent: Arc<Lists<T>>,
/// The value stored in this entry.
value: UnsafeCell<ManuallyDrop<T>>,
/// Used to remember which list this entry is in.
my_list: UnsafeCell<List>,
/// Required by the `linked_list::Pointers` field.
_pin: PhantomPinned,
}
// With mutable access to the `IdleNotifiedSet`, you can get mutable access to
// the values.
unsafe impl<T: Send> Send for IdleNotifiedSet<T> {}
// With the current API we strictly speaking don't even need `T: Sync`, but we
// require it anyway to support adding &self APIs that access the values in the
// future.
unsafe impl<T: Sync> Sync for IdleNotifiedSet<T> {}
// These impls control when it is safe to create a Waker. Since the waker does
// not allow access to the value in any way (including its destructor), it is
// not necessary for `T` to be Send or Sync.
unsafe impl<T> Send for ListEntry<T> {}
unsafe impl<T> Sync for ListEntry<T> {}
impl<T> IdleNotifiedSet<T> {
/// Create a new IdleNotifiedSet.
pub(crate) fn new() -> Self {
let lists = Mutex::new(ListsInner {
notified: LinkedList::new(),
idle: LinkedList::new(),
waker: None,
});
IdleNotifiedSet {
lists: Arc::new(lists),
length: 0,
}
}
pub(crate) fn len(&self) -> usize {
self.length
}
pub(crate) fn is_empty(&self) -> bool {
self.length == 0
}
/// Insert the given value into the `idle` list.
pub(crate) fn insert_idle(&mut self, value: T) -> EntryInOneOfTheLists<'_, T> {
self.length += 1;
let entry = Arc::new(ListEntry {
parent: self.lists.clone(),
value: UnsafeCell::new(ManuallyDrop::new(value)),
my_list: UnsafeCell::new(List::Idle),
pointers: linked_list::Pointers::new(),
_pin: PhantomPinned,
});
{
let mut lock = self.lists.lock();
lock.idle.push_front(entry.clone());
}
// Safety: We just put the entry in the idle list, so it is in one of the lists.
EntryInOneOfTheLists { entry, set: self }
}
/// Pop an entry from the notified list to poll it. The entry is moved to
/// the idle list atomically.
pub(crate) fn pop_notified(&mut self, waker: &Waker) -> Option<EntryInOneOfTheLists<'_, T>> {
// We don't decrement the length because this call moves the entry to
// the idle list rather than removing it.
if self.length == 0 {
// Fast path.
return None;
}
let mut lock = self.lists.lock();
let should_update_waker = match lock.waker.as_mut() {
Some(cur_waker) => !waker.will_wake(cur_waker),
None => true,
};
if should_update_waker {
lock.waker = Some(waker.clone());
}
// Pop the entry, returning None if empty.
let entry = lock.notified.pop_back()?;
lock.idle.push_front(entry.clone());
// Safety: We are holding the lock.
entry.my_list.with_mut(|ptr| unsafe {
*ptr = List::Idle;
});
drop(lock);
// Safety: We just put the entry in the idle list, so it is in one of the lists.
Some(EntryInOneOfTheLists { entry, set: self })
}
/// Call a function on every element in this list.
pub(crate) fn for_each<F: FnMut(&mut T)>(&mut self, mut func: F) {
fn get_ptrs<T>(list: &mut LinkedList<T>, ptrs: &mut Vec<*mut T>) {
let mut node = list.last();
while let Some(entry) = node {
ptrs.push(entry.value.with_mut(|ptr| {
let ptr: *mut ManuallyDrop<T> = ptr;
let ptr: *mut T = ptr.cast();
ptr
}));
let prev = entry.pointers.get_prev();
node = prev.map(|prev| unsafe { &*prev.as_ptr() });
}
}
// Atomically get a raw pointer to the value of every entry.
//
// Since this only locks the mutex once, it is not possible for a value
// to get moved from the idle list to the notified list during the
// operation, which would otherwise result in some value being listed
// twice.
let mut ptrs = Vec::with_capacity(self.len());
{
let mut lock = self.lists.lock();
get_ptrs(&mut lock.idle, &mut ptrs);
get_ptrs(&mut lock.notified, &mut ptrs);
}
debug_assert_eq!(ptrs.len(), ptrs.capacity());
for ptr in ptrs {
// Safety: When we grabbed the pointers, the entries were in one of
// the two lists. This means that their value was valid at the time,
// and it must still be valid because we are the IdleNotifiedSet,
// and only we can remove an entry from the two lists. (It's
// possible that an entry is moved from one list to the other during
// this loop, but that is ok.)
func(unsafe { &mut *ptr });
}
}
/// Remove all entries in both lists, applying some function to each element.
///
/// The closure is called on all elements even if it panics. Having it panic
/// twice is a double-panic, and will abort the application.
pub(crate) fn drain<F: FnMut(T)>(&mut self, func: F) {
if self.length == 0 {
// Fast path.
return;
}
self.length = 0;
// The LinkedList is not cleared on panic, so we use a bomb to clear it.
//
// This value has the invariant that any entry in its `all_entries` list
// has `my_list` set to `Neither` and that the value has not yet been
// dropped.
struct AllEntries<T, F: FnMut(T)> {
all_entries: LinkedList<T>,
func: F,
}
impl<T, F: FnMut(T)> AllEntries<T, F> {
fn pop_next(&mut self) -> bool {
if let Some(entry) = self.all_entries.pop_back() {
// Safety: We just took this value from the list, so we can
// destroy the value in the entry.
entry
.value
.with_mut(|ptr| unsafe { (self.func)(ManuallyDrop::take(&mut *ptr)) });
true
} else {
false
}
}
}
impl<T, F: FnMut(T)> Drop for AllEntries<T, F> {
fn drop(&mut self) {
while self.pop_next() {}
}
}
let mut all_entries = AllEntries {
all_entries: LinkedList::new(),
func,
};
// Atomically move all entries to the new linked list in the AllEntries
// object.
{
let mut lock = self.lists.lock();
unsafe {
// Safety: We are holding the lock and `all_entries` is a new
// LinkedList.
move_to_new_list(&mut lock.idle, &mut all_entries.all_entries);
move_to_new_list(&mut lock.notified, &mut all_entries.all_entries);
}
}
// Keep destroying entries in the list until it is empty.
//
// If the closure panics, then the destructor of the `AllEntries` bomb
// ensures that the closure keeps being called on the remaining values.
// A second panic will abort the program.
while all_entries.pop_next() {}
}
}
/// # Safety
///
/// The mutex for the entries must be held, and the target list must be such
/// that setting `my_list` to `Neither` is ok.
unsafe fn move_to_new_list<T>(from: &mut LinkedList<T>, to: &mut LinkedList<T>) {
while let Some(entry) = from.pop_back() {
entry.my_list.with_mut(|ptr| {
*ptr = List::Neither;
});
to.push_front(entry);
}
}
impl<'a, T> EntryInOneOfTheLists<'a, T> {
/// Remove this entry from the list it is in, returning the value associated
/// with the entry.
///
/// This consumes the value, since it is no longer guaranteed to be in a
/// list.
pub(crate) fn remove(self) -> T {
self.set.length -= 1;
{
let mut lock = self.set.lists.lock();
// Safety: We are holding the lock so there is no race, and we will
// remove the entry afterwards to uphold invariants.
let old_my_list = self.entry.my_list.with_mut(|ptr| unsafe {
let old_my_list = *ptr;
*ptr = List::Neither;
old_my_list
});
let list = match old_my_list {
List::Idle => &mut lock.idle,
List::Notified => &mut lock.notified,
// An entry in one of the lists is in one of the lists.
List::Neither => unreachable!(),
};
unsafe {
// Safety: We just checked that the entry is in this particular
// list.
list.remove(ListEntry::as_raw(&self.entry)).unwrap();
}
}
// By setting `my_list` to `Neither`, we have taken ownership of the
// value. We return it to the caller.
//
// Safety: We have a mutable reference to the `IdleNotifiedSet` that
// owns this entry, so we can use its permission to access the value.
self.entry
.value
.with_mut(|ptr| unsafe { ManuallyDrop::take(&mut *ptr) })
}
/// Access the value in this entry together with a context for its waker.
pub(crate) fn with_value_and_context<F, U>(&mut self, func: F) -> U
where
F: FnOnce(&mut T, &mut Context<'_>) -> U,
T: 'static,
{
let waker = waker_ref(&self.entry);
let mut context = Context::from_waker(&waker);
// Safety: We have a mutable reference to the `IdleNotifiedSet` that
// owns this entry, so we can use its permission to access the value.
self.entry
.value
.with_mut(|ptr| unsafe { func(&mut *ptr, &mut context) })
}
}
impl<T> Drop for IdleNotifiedSet<T> {
fn drop(&mut self) {
// Clear both lists.
self.drain(drop);
#[cfg(debug_assertions)]
if !std::thread::panicking() {
let lock = self.lists.lock();
assert!(lock.idle.is_empty());
assert!(lock.notified.is_empty());
}
}
}
impl<T: 'static> Wake for ListEntry<T> {
fn wake_by_ref(me: &Arc<Self>) {
let mut lock = me.parent.lock();
// Safety: We are holding the lock and we will update the lists to
// maintain invariants.
let old_my_list = me.my_list.with_mut(|ptr| unsafe {
let old_my_list = *ptr;
if old_my_list == List::Idle {
*ptr = List::Notified;
}
old_my_list
});
if old_my_list == List::Idle {
// We move ourself to the notified list.
let me = unsafe {
// Safety: We just checked that we are in this particular list.
lock.idle.remove(NonNull::from(&**me)).unwrap()
};
lock.notified.push_front(me);
if let Some(waker) = lock.waker.take() {
drop(lock);
waker.wake();
}
}
}
fn wake(me: Arc<Self>) {
Self::wake_by_ref(&me)
}
}
/// # Safety
///
/// `ListEntry` is forced to be !Unpin.
unsafe impl<T> linked_list::Link for ListEntry<T> {
type Handle = Arc<ListEntry<T>>;
type Target = ListEntry<T>;
fn as_raw(handle: &Self::Handle) -> NonNull<ListEntry<T>> {
let ptr: *const ListEntry<T> = Arc::as_ptr(handle);
// Safety: We can't get a null pointer from `Arc::as_ptr`.
unsafe { NonNull::new_unchecked(ptr as *mut ListEntry<T>) }
}
unsafe fn from_raw(ptr: NonNull<ListEntry<T>>) -> Arc<ListEntry<T>> {
Arc::from_raw(ptr.as_ptr())
}
unsafe fn pointers(
target: NonNull<ListEntry<T>>,
) -> NonNull<linked_list::Pointers<ListEntry<T>>> {
// Safety: The pointers struct is the first field and ListEntry is
// `#[repr(C)]` so this cast is safe.
//
// We do this rather than doing a field access since `std::ptr::addr_of`
// is too new for our MSRV.
target.cast()
}
}
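
The waker protocol above boils down to two moves: waking an entry moves it from the idle list to the notified list and fires the waker stored on the set, while `pop_notified` records the caller's waker and moves the entry back to idle before it is polled. A simplified, non-intrusive sketch of that protocol (hash sets keyed by an entry id stand in for the intrusive linked lists, and `std::task::Wake` stands in for tokio's internal `Wake` trait; all names are illustrative):

```rust
use std::collections::HashSet;
use std::sync::{Arc, Mutex};
use std::task::{Wake, Waker};

// Shared bookkeeping: which entries are idle, which have been notified,
// and the waker of whoever is waiting on the set as a whole.
struct Shared {
    inner: Mutex<Inner>,
}

struct Inner {
    idle: HashSet<u64>,
    notified: HashSet<u64>,
    parent_waker: Option<Waker>,
}

// Per-entry waker: waking it moves the entry from `idle` to `notified` and
// pings the set's waker, mirroring `ListEntry::wake_by_ref` above.
struct EntryWaker {
    id: u64,
    shared: Arc<Shared>,
}

impl Wake for EntryWaker {
    fn wake(self: Arc<Self>) {
        let mut lock = self.shared.inner.lock().unwrap();
        if lock.idle.remove(&self.id) {
            lock.notified.insert(self.id);
            if let Some(waker) = lock.parent_waker.take() {
                drop(lock); // never wake while holding the lock
                waker.wake();
            }
        }
    }
}

impl Shared {
    // Mirrors `pop_notified`: store the caller's waker, then move one entry
    // back to the idle set and hand its id to the caller for polling.
    fn pop_notified(&self, waker: &Waker) -> Option<u64> {
        let mut lock = self.inner.lock().unwrap();
        lock.parent_waker = Some(waker.clone());
        let id = lock.notified.iter().next().copied()?;
        lock.notified.remove(&id);
        lock.idle.insert(id);
        Some(id)
    }
}

struct NoopWake;
impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

fn main() {
    let shared = Arc::new(Shared {
        inner: Mutex::new(Inner {
            idle: HashSet::from([1]),
            notified: HashSet::new(),
            parent_waker: None,
        }),
    });

    // Waking entry 1 moves it to the notified set.
    let entry_waker = Waker::from(Arc::new(EntryWaker { id: 1, shared: shared.clone() }));
    entry_waker.wake_by_ref();
    assert!(shared.inner.lock().unwrap().notified.contains(&1));

    // A consumer (think `poll_join_one`) pops it for polling; it is idle again.
    let consumer = Waker::from(Arc::new(NoopWake));
    assert_eq!(shared.pop_notified(&consumer), Some(1));
    assert!(shared.inner.lock().unwrap().idle.contains(&1));
}
```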


@@ -219,6 +219,7 @@ impl<L: Link> fmt::Debug for LinkedList<L, L::Target> {
#[cfg(any(
feature = "fs",
feature = "rt",
all(unix, feature = "process"),
feature = "signal",
feature = "sync",
@@ -296,7 +297,7 @@ impl<T> Pointers<T> {
}
}
fn get_prev(&self) -> Option<NonNull<T>> {
pub(crate) fn get_prev(&self) -> Option<NonNull<T>> {
// SAFETY: prev is the first field in PointersInner, which is #[repr(C)].
unsafe {
let inner = self.inner.get();
@@ -304,7 +305,7 @@ impl<T> Pointers<T> {
ptr::read(prev)
}
}
fn get_next(&self) -> Option<NonNull<T>> {
pub(crate) fn get_next(&self) -> Option<NonNull<T>> {
// SAFETY: next is the second field in PointersInner, which is #[repr(C)].
unsafe {
let inner = self.inner.get();


@@ -44,6 +44,9 @@ pub(crate) mod linked_list;
mod rand;
cfg_rt! {
mod idle_notified_set;
pub(crate) use idle_notified_set::IdleNotifiedSet;
mod wake;
pub(crate) use wake::WakerRef;
pub(crate) use wake::{waker_ref, Wake};


@@ -1,13 +1,14 @@
use crate::loom::sync::Arc;
use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::ops::Deref;
use std::sync::Arc;
use std::task::{RawWaker, RawWakerVTable, Waker};
/// Simplified waking interface based on Arcs.
pub(crate) trait Wake: Send + Sync {
pub(crate) trait Wake: Send + Sync + Sized + 'static {
/// Wake by value.
fn wake(self: Arc<Self>);
fn wake(arc_self: Arc<Self>);
/// Wake by reference.
fn wake_by_ref(arc_self: &Arc<Self>);


@@ -435,24 +435,33 @@ async_assert_fn!(tokio::sync::watch::Sender<NN>::closed(_): !Send & !Sync & !Unpin);
async_assert_fn!(tokio::sync::watch::Sender<YN>::closed(_): !Send & !Sync & !Unpin);
async_assert_fn!(tokio::sync::watch::Sender<YY>::closed(_): Send & Sync & !Unpin);
async_assert_fn!(tokio::task::LocalKey<u32>::scope(_, u32, BoxFutureSync<()>): Send & Sync & !Unpin);
async_assert_fn!(tokio::task::LocalKey<u32>::scope(_, u32, BoxFutureSend<()>): Send & !Sync & !Unpin);
async_assert_fn!(tokio::task::LocalKey<u32>::scope(_, u32, BoxFuture<()>): !Send & !Sync & !Unpin);
async_assert_fn!(tokio::task::LocalKey<Cell<u32>>::scope(_, Cell<u32>, BoxFutureSync<()>): Send & !Sync & !Unpin);
async_assert_fn!(tokio::task::LocalKey<Cell<u32>>::scope(_, Cell<u32>, BoxFutureSend<()>): Send & !Sync & !Unpin);
async_assert_fn!(tokio::task::JoinSet<Cell<u32>>::join_one(_): Send & Sync & !Unpin);
async_assert_fn!(tokio::task::JoinSet<Cell<u32>>::shutdown(_): Send & Sync & !Unpin);
async_assert_fn!(tokio::task::JoinSet<Rc<u32>>::join_one(_): !Send & !Sync & !Unpin);
async_assert_fn!(tokio::task::JoinSet<Rc<u32>>::shutdown(_): !Send & !Sync & !Unpin);
async_assert_fn!(tokio::task::JoinSet<u32>::join_one(_): Send & Sync & !Unpin);
async_assert_fn!(tokio::task::JoinSet<u32>::shutdown(_): Send & Sync & !Unpin);
async_assert_fn!(tokio::task::LocalKey<Cell<u32>>::scope(_, Cell<u32>, BoxFuture<()>): !Send & !Sync & !Unpin);
async_assert_fn!(tokio::task::LocalKey<Rc<u32>>::scope(_, Rc<u32>, BoxFutureSync<()>): !Send & !Sync & !Unpin);
async_assert_fn!(tokio::task::LocalKey<Rc<u32>>::scope(_, Rc<u32>, BoxFutureSend<()>): !Send & !Sync & !Unpin);
async_assert_fn!(tokio::task::LocalKey<Cell<u32>>::scope(_, Cell<u32>, BoxFutureSend<()>): Send & !Sync & !Unpin);
async_assert_fn!(tokio::task::LocalKey<Cell<u32>>::scope(_, Cell<u32>, BoxFutureSync<()>): Send & !Sync & !Unpin);
async_assert_fn!(tokio::task::LocalKey<Rc<u32>>::scope(_, Rc<u32>, BoxFuture<()>): !Send & !Sync & !Unpin);
async_assert_fn!(tokio::task::LocalKey<Rc<u32>>::scope(_, Rc<u32>, BoxFutureSend<()>): !Send & !Sync & !Unpin);
async_assert_fn!(tokio::task::LocalKey<Rc<u32>>::scope(_, Rc<u32>, BoxFutureSync<()>): !Send & !Sync & !Unpin);
async_assert_fn!(tokio::task::LocalKey<u32>::scope(_, u32, BoxFuture<()>): !Send & !Sync & !Unpin);
async_assert_fn!(tokio::task::LocalKey<u32>::scope(_, u32, BoxFutureSend<()>): Send & !Sync & !Unpin);
async_assert_fn!(tokio::task::LocalKey<u32>::scope(_, u32, BoxFutureSync<()>): Send & Sync & !Unpin);
async_assert_fn!(tokio::task::LocalSet::run_until(_, BoxFutureSync<()>): !Send & !Sync & !Unpin);
async_assert_fn!(tokio::task::unconstrained(BoxFuture<()>): !Send & !Sync & Unpin);
async_assert_fn!(tokio::task::unconstrained(BoxFutureSend<()>): Send & !Sync & Unpin);
async_assert_fn!(tokio::task::unconstrained(BoxFutureSync<()>): Send & Sync & Unpin);
assert_value!(tokio::task::LocalSet: !Send & !Sync & Unpin);
assert_value!(tokio::task::JoinHandle<YY>: Send & Sync & Unpin);
assert_value!(tokio::task::JoinHandle<YN>: Send & Sync & Unpin);
assert_value!(tokio::task::JoinHandle<NN>: !Send & !Sync & Unpin);
assert_value!(tokio::task::JoinError: Send & Sync & Unpin);
assert_value!(tokio::task::JoinHandle<NN>: !Send & !Sync & Unpin);
assert_value!(tokio::task::JoinHandle<YN>: Send & Sync & Unpin);
assert_value!(tokio::task::JoinHandle<YY>: Send & Sync & Unpin);
assert_value!(tokio::task::JoinSet<YY>: Send & Sync & Unpin);
assert_value!(tokio::task::JoinSet<YN>: Send & Sync & Unpin);
assert_value!(tokio::task::JoinSet<NN>: !Send & !Sync & Unpin);
assert_value!(tokio::task::LocalSet: !Send & !Sync & Unpin);
assert_value!(tokio::runtime::Builder: Send & Sync & Unpin);
assert_value!(tokio::runtime::EnterGuard<'_>: Send & Sync & Unpin);


@@ -0,0 +1,192 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]
use tokio::sync::oneshot;
use tokio::task::JoinSet;
use tokio::time::Duration;
use futures::future::FutureExt;
fn rt() -> tokio::runtime::Runtime {
tokio::runtime::Builder::new_current_thread()
.build()
.unwrap()
}
#[tokio::test(start_paused = true)]
async fn test_with_sleep() {
let mut set = JoinSet::new();
for i in 0..10 {
set.spawn(async move { i });
assert_eq!(set.len(), 1 + i);
}
set.detach_all();
assert_eq!(set.len(), 0);
assert!(matches!(set.join_one().await, Ok(None)));
for i in 0..10 {
set.spawn(async move {
tokio::time::sleep(Duration::from_secs(i as u64)).await;
i
});
assert_eq!(set.len(), 1 + i);
}
let mut seen = [false; 10];
while let Some(res) = set.join_one().await.unwrap() {
seen[res] = true;
}
for was_seen in &seen {
assert!(was_seen);
}
assert!(matches!(set.join_one().await, Ok(None)));
// Do it again.
for i in 0..10 {
set.spawn(async move {
tokio::time::sleep(Duration::from_secs(i as u64)).await;
i
});
}
let mut seen = [false; 10];
while let Some(res) = set.join_one().await.unwrap() {
seen[res] = true;
}
for was_seen in &seen {
assert!(was_seen);
}
assert!(matches!(set.join_one().await, Ok(None)));
}
#[tokio::test]
async fn test_abort_on_drop() {
let mut set = JoinSet::new();
let mut recvs = Vec::new();
for _ in 0..16 {
let (send, recv) = oneshot::channel::<()>();
recvs.push(recv);
set.spawn(async {
// This task will never complete on its own.
futures::future::pending::<()>().await;
drop(send);
});
}
drop(set);
for recv in recvs {
// The task is aborted soon and we will receive an error.
assert!(recv.await.is_err());
}
}
#[tokio::test]
async fn alternating() {
let mut set = JoinSet::new();
assert_eq!(set.len(), 0);
set.spawn(async {});
assert_eq!(set.len(), 1);
set.spawn(async {});
assert_eq!(set.len(), 2);
for _ in 0..16 {
let () = set.join_one().await.unwrap().unwrap();
assert_eq!(set.len(), 1);
set.spawn(async {});
assert_eq!(set.len(), 2);
}
}
#[test]
fn runtime_gone() {
let mut set = JoinSet::new();
{
let rt = rt();
set.spawn_on(async { 1 }, rt.handle());
drop(rt);
}
assert!(rt().block_on(set.join_one()).unwrap_err().is_cancelled());
}
// This ensures that `join_one` works correctly when the coop budget is
// exhausted.
#[tokio::test(flavor = "current_thread")]
async fn join_set_coop() {
// Large enough to trigger coop.
const TASK_NUM: u32 = 1000;
static SEM: tokio::sync::Semaphore = tokio::sync::Semaphore::const_new(0);
let mut set = JoinSet::new();
for _ in 0..TASK_NUM {
set.spawn(async {
SEM.add_permits(1);
});
}
// Wait for all tasks to complete.
//
// Since this is a `current_thread` runtime, there's no race condition
// between the last permit being added and the task completing.
let _ = SEM.acquire_many(TASK_NUM).await.unwrap();
let mut count = 0;
let mut coop_count = 0;
loop {
match set.join_one().now_or_never() {
Some(Ok(Some(()))) => {}
Some(Err(err)) => panic!("failed: {}", err),
None => {
coop_count += 1;
tokio::task::yield_now().await;
continue;
}
Some(Ok(None)) => break,
}
count += 1;
}
assert!(coop_count >= 1);
assert_eq!(count, TASK_NUM);
}
#[tokio::test(start_paused = true)]
async fn abort_all() {
let mut set: JoinSet<()> = JoinSet::new();
for _ in 0..5 {
set.spawn(futures::future::pending());
}
for _ in 0..5 {
set.spawn(async {
tokio::time::sleep(Duration::from_secs(1)).await;
});
}
// The join set will now have 5 pending tasks and 5 ready tasks.
tokio::time::sleep(Duration::from_secs(2)).await;
set.abort_all();
assert_eq!(set.len(), 10);
let mut count = 0;
while let Some(res) = set.join_one().await.transpose() {
if let Err(err) = res {
assert!(err.is_cancelled());
}
count += 1;
}
assert_eq!(count, 10);
assert_eq!(set.len(), 0);
}