task: add LocalSet::enter (#4736) (#4765)

This commit is contained in:
gftea 2022-07-13 18:10:09 +02:00 committed by GitHub
parent 8e20cfb9ef
commit 14fca343d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 113 additions and 29 deletions

View File

@ -10,6 +10,7 @@ use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::rc::Rc;
use std::task::Poll;
use pin_project_lite::pin_project;
@ -215,7 +216,7 @@ cfg_rt! {
tick: Cell<u8>,
/// State available from thread-local.
context: Context,
context: Rc<Context>,
/// This type should not be Send.
_not_send: PhantomData<*const ()>,
@ -260,7 +261,7 @@ pin_project! {
}
}
scoped_thread_local!(static CURRENT: Context);
thread_local!(static CURRENT: Cell<Option<Rc<Context>>> = Cell::new(None));
cfg_rt! {
/// Spawns a `!Send` future on the local task set.
@ -310,10 +311,12 @@ cfg_rt! {
F::Output: 'static
{
CURRENT.with(|maybe_cx| {
let cx = maybe_cx
.expect("`spawn_local` called from outside of a `task::LocalSet`");
let ctx = clone_rc(maybe_cx);
match ctx {
None => panic!("`spawn_local` called from outside of a `task::LocalSet`"),
Some(cx) => cx.spawn(future, name)
}
cx.spawn(future, name)
})
}
}
@ -327,12 +330,29 @@ const MAX_TASKS_PER_TICK: usize = 61;
/// How often it check the remote queue first.
const REMOTE_FIRST_INTERVAL: u8 = 31;
/// Context guard for LocalSet
pub struct LocalEnterGuard(Option<Rc<Context>>);
impl Drop for LocalEnterGuard {
fn drop(&mut self) {
CURRENT.with(|ctx| {
ctx.replace(self.0.take());
})
}
}
impl fmt::Debug for LocalEnterGuard {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("LocalEnterGuard").finish()
}
}
impl LocalSet {
/// Returns a new local task set.
pub fn new() -> LocalSet {
LocalSet {
tick: Cell::new(0),
context: Context {
context: Rc::new(Context {
owned: LocalOwnedTasks::new(),
queue: VecDequeCell::with_capacity(INITIAL_CAPACITY),
shared: Arc::new(Shared {
@ -342,11 +362,24 @@ impl LocalSet {
unhandled_panic: crate::runtime::UnhandledPanic::Ignore,
}),
unhandled_panic: Cell::new(false),
},
}),
_not_send: PhantomData,
}
}
/// Enters the context of this `LocalSet`.
///
/// The [`spawn_local`] method will spawn tasks on the `LocalSet` whose
/// context you are inside.
///
/// [`spawn_local`]: fn@crate::task::spawn_local
pub fn enter(&self) -> LocalEnterGuard {
CURRENT.with(|ctx| {
let old = ctx.replace(Some(self.context.clone()));
LocalEnterGuard(old)
})
}
/// Spawns a `!Send` task onto the local task set.
///
/// This task is guaranteed to be run on the current thread.
@ -579,7 +612,25 @@ impl LocalSet {
}
fn with<T>(&self, f: impl FnOnce() -> T) -> T {
CURRENT.set(&self.context, f)
CURRENT.with(|ctx| {
struct Reset<'a> {
ctx_ref: &'a Cell<Option<Rc<Context>>>,
val: Option<Rc<Context>>,
}
impl<'a> Drop for Reset<'a> {
fn drop(&mut self) {
self.ctx_ref.replace(self.val.take());
}
}
let old = ctx.replace(Some(self.context.clone()));
let _reset = Reset {
ctx_ref: ctx,
val: old,
};
f()
})
}
}
@ -645,8 +696,9 @@ cfg_unstable! {
/// [`JoinHandle`]: struct@crate::task::JoinHandle
pub fn unhandled_panic(&mut self, behavior: crate::runtime::UnhandledPanic) -> &mut Self {
// TODO: This should be set as a builder
Arc::get_mut(&mut self.context.shared)
.expect("TODO: we shouldn't panic")
Rc::get_mut(&mut self.context)
.and_then(|ctx| Arc::get_mut(&mut ctx.shared))
.expect("Unhandled Panic behavior modified after starting LocalSet")
.unhandled_panic = behavior;
self
}
@ -769,23 +821,33 @@ impl<T: Future> Future for RunUntil<'_, T> {
}
}
fn clone_rc<T>(rc: &Cell<Option<Rc<T>>>) -> Option<Rc<T>> {
let value = rc.take();
let cloned = value.clone();
rc.set(value);
cloned
}
impl Shared {
/// Schedule the provided task on the scheduler.
fn schedule(&self, task: task::Notified<Arc<Self>>) {
CURRENT.with(|maybe_cx| match maybe_cx {
Some(cx) if cx.shared.ptr_eq(self) => {
cx.queue.push_back(task);
}
_ => {
// First check whether the queue is still there (if not, the
// LocalSet is dropped). Then push to it if so, and if not,
// do nothing.
let mut lock = self.queue.lock();
CURRENT.with(|maybe_cx| {
let ctx = clone_rc(maybe_cx);
match ctx {
Some(cx) if cx.shared.ptr_eq(self) => {
cx.queue.push_back(task);
}
_ => {
// First check whether the queue is still there (if not, the
// LocalSet is dropped). Then push to it if so, and if not,
// do nothing.
let mut lock = self.queue.lock();
if let Some(queue) = lock.as_mut() {
queue.push_back(task);
drop(lock);
self.waker.wake();
if let Some(queue) = lock.as_mut() {
queue.push_back(task);
drop(lock);
self.waker.wake();
}
}
}
});
@ -799,9 +861,14 @@ impl Shared {
impl task::Schedule for Arc<Shared> {
fn release(&self, task: &Task<Self>) -> Option<Task<Self>> {
CURRENT.with(|maybe_cx| {
let cx = maybe_cx.expect("scheduler context missing");
assert!(cx.shared.ptr_eq(self));
cx.owned.remove(task)
let ctx = clone_rc(maybe_cx);
match ctx {
None => panic!("scheduler context missing"),
Some(cx) => {
assert!(cx.shared.ptr_eq(self));
cx.owned.remove(task)
}
}
})
}
@ -821,13 +888,15 @@ impl task::Schedule for Arc<Shared> {
// This hook is only called from within the runtime, so
// `CURRENT` should match with `&self`, i.e. there is no
// opportunity for a nested scheduler to be called.
CURRENT.with(|maybe_cx| match maybe_cx {
CURRENT.with(|maybe_cx| {
let ctx = clone_rc(maybe_cx);
match ctx {
Some(cx) if Arc::ptr_eq(self, &cx.shared) => {
cx.unhandled_panic.set(true);
cx.owned.close_and_shutdown_all();
}
_ => unreachable!("runtime core not set in CURRENT thread-local"),
})
}})
}
}
}

View File

@ -299,7 +299,7 @@ cfg_rt! {
}
mod local;
pub use local::{spawn_local, LocalSet};
pub use local::{spawn_local, LocalSet, LocalEnterGuard};
mod task_local;
pub use task_local::LocalKey;

View File

@ -135,6 +135,21 @@ async fn local_threadpool_timer() {
})
.await;
}
#[test]
fn enter_guard_spawn() {
let local = LocalSet::new();
let _guard = local.enter();
// Run the local task set.
let join = task::spawn_local(async { true });
let rt = runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
local.block_on(&rt, async move {
assert!(join.await.unwrap());
});
}
#[cfg(not(target_os = "wasi"))] // Wasi doesn't support panic recovery
#[test]