task: fix missing wakeup when using LocalSet::enter (#6016)

This commit is contained in:
inkyu 2023-10-16 02:18:06 +09:00 committed by GitHub
parent f1e41a4ad4
commit f3ad6cffd9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 93 additions and 43 deletions

View File

@ -280,10 +280,43 @@ pin_project! {
tokio_thread_local!(static CURRENT: LocalData = const { LocalData {
ctx: RcCell::new(),
wake_on_schedule: Cell::new(false),
} });
struct LocalData {
ctx: RcCell<Context>,
wake_on_schedule: Cell<bool>,
}
impl LocalData {
/// Should be called except when we call `LocalSet::enter`.
/// Especially when we poll a LocalSet.
#[must_use = "dropping this guard will reset the entered state"]
fn enter(&self, ctx: Rc<Context>) -> LocalDataEnterGuard<'_> {
let ctx = self.ctx.replace(Some(ctx));
let wake_on_schedule = self.wake_on_schedule.replace(false);
LocalDataEnterGuard {
local_data_ref: self,
ctx,
wake_on_schedule,
}
}
}
/// A guard for `LocalData::enter()`
struct LocalDataEnterGuard<'a> {
local_data_ref: &'a LocalData,
ctx: Option<Rc<Context>>,
wake_on_schedule: bool,
}
impl<'a> Drop for LocalDataEnterGuard<'a> {
fn drop(&mut self) {
self.local_data_ref.ctx.set(self.ctx.take());
self.local_data_ref
.wake_on_schedule
.set(self.wake_on_schedule)
}
}
cfg_rt! {
@ -360,13 +393,26 @@ const MAX_TASKS_PER_TICK: usize = 61;
const REMOTE_FIRST_INTERVAL: u8 = 31;
/// Context guard for LocalSet
pub struct LocalEnterGuard(Option<Rc<Context>>);
pub struct LocalEnterGuard {
ctx: Option<Rc<Context>>,
/// Distinguishes whether the context was entered or being polled.
/// When we enter it, the value `wake_on_schedule` is set. In this case
/// `spawn_local` refers the context, whereas it is not being polled now.
wake_on_schedule: bool,
}
impl Drop for LocalEnterGuard {
fn drop(&mut self) {
CURRENT.with(|LocalData { ctx, .. }| {
ctx.set(self.0.take());
})
CURRENT.with(
|LocalData {
ctx,
wake_on_schedule,
}| {
ctx.set(self.ctx.take());
wake_on_schedule.set(self.wake_on_schedule);
},
)
}
}
@ -408,10 +454,20 @@ impl LocalSet {
///
/// [`spawn_local`]: fn@crate::task::spawn_local
pub fn enter(&self) -> LocalEnterGuard {
CURRENT.with(|LocalData { ctx, .. }| {
let old = ctx.replace(Some(self.context.clone()));
LocalEnterGuard(old)
})
CURRENT.with(
|LocalData {
ctx,
wake_on_schedule,
..
}| {
let ctx = ctx.replace(Some(self.context.clone()));
let wake_on_schedule = wake_on_schedule.replace(true);
LocalEnterGuard {
ctx,
wake_on_schedule,
}
},
)
}
/// Spawns a `!Send` task onto the local task set.
@ -667,23 +723,8 @@ impl LocalSet {
}
fn with<T>(&self, f: impl FnOnce() -> T) -> T {
CURRENT.with(|LocalData { ctx, .. }| {
struct Reset<'a> {
ctx_ref: &'a RcCell<Context>,
val: Option<Rc<Context>>,
}
impl<'a> Drop for Reset<'a> {
fn drop(&mut self) {
self.ctx_ref.set(self.val.take());
}
}
let old = ctx.replace(Some(self.context.clone()));
let _reset = Reset {
ctx_ref: ctx,
val: old,
};
CURRENT.with(|local_data| {
let _guard = local_data.enter(self.context.clone());
f()
})
}
@ -693,23 +734,8 @@ impl LocalSet {
fn with_if_possible<T>(&self, f: impl FnOnce() -> T) -> T {
let mut f = Some(f);
let res = CURRENT.try_with(|LocalData { ctx, .. }| {
struct Reset<'a> {
ctx_ref: &'a RcCell<Context>,
val: Option<Rc<Context>>,
}
impl<'a> Drop for Reset<'a> {
fn drop(&mut self) {
self.ctx_ref.replace(self.val.take());
}
}
let old = ctx.replace(Some(self.context.clone()));
let _reset = Reset {
ctx_ref: ctx,
val: old,
};
let res = CURRENT.try_with(|local_data| {
let _guard = local_data.enter(self.context.clone());
(f.take().unwrap())()
});
@ -967,7 +993,10 @@ impl Shared {
fn schedule(&self, task: task::Notified<Arc<Self>>) {
CURRENT.with(|localdata| {
match localdata.ctx.get() {
Some(cx) if cx.shared.ptr_eq(self) => unsafe {
// If the current `LocalSet` is being polled, we don't need to wake it.
// When we `enter` it, then the value `wake_on_schedule` is set to be true.
// In this case it is not being polled, so we need to wake it.
Some(cx) if cx.shared.ptr_eq(self) && !localdata.wake_on_schedule.get() => unsafe {
// Safety: if the current `LocalSet` context points to this
// `LocalSet`, then we are on the thread that owns it.
cx.shared.local_state.task_push_back(task);

View File

@ -573,6 +573,27 @@ async fn spawn_wakes_localset() {
}
}
/// Checks that the task wakes up with `enter`.
/// Reproduces <https://github.com/tokio-rs/tokio/issues/5020>.
#[tokio::test]
async fn sleep_with_local_enter_guard() {
let local = LocalSet::new();
let _guard = local.enter();
let (tx, rx) = oneshot::channel();
local
.run_until(async move {
tokio::task::spawn_local(async move {
time::sleep(Duration::ZERO).await;
tx.send(()).expect("failed to send");
});
assert_eq!(rx.await, Ok(()));
})
.await;
}
#[test]
fn store_local_set_in_thread_local_with_runtime() {
use tokio::runtime::Runtime;