Merge pull request #3655 from bugadani/header-executor

Fix racy access of TaskHeader::executor
Dario Nieuwenhuis 2024-12-16 16:52:21 +00:00 committed by GitHub
commit d3f0294fb1
4 changed files with 31 additions and 12 deletions
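
The racy access in question: `TaskHeader::executor` was a plain `SyncUnsafeCell<Option<&'static SyncExecutor>>`, written by `SyncExecutor::spawn` and read by `wake_task`, `wake_task_no_pend`, `TaskRef::executor` and the spawner code with no synchronization, and concurrent unsynchronized reads and writes of the same field are a data race. The diff below switches the field to an `AtomicPtr<SyncExecutor>`, with a null pointer taking the place of `None`. A minimal, self-contained sketch of the resulting pattern (stand-in types, not the actual embassy-executor definitions):

```rust
// Sketch of the new access pattern, using stand-in types (the real embassy
// code stores a pointer to its SyncExecutor inside the task header).
use core::ptr;
use core::sync::atomic::{AtomicPtr, Ordering};

struct SyncExecutor; // stand-in for the real executor type

struct TaskHeader {
    executor: AtomicPtr<SyncExecutor>,
}

impl TaskHeader {
    const fn new() -> Self {
        Self {
            // Null means "not spawned on any executor yet" (the old `None`).
            executor: AtomicPtr::new(ptr::null_mut()),
        }
    }

    // Writer side: called when the task is spawned onto an executor.
    fn attach(&self, executor: &'static SyncExecutor) {
        self.executor
            .store((executor as *const SyncExecutor).cast_mut(), Ordering::Relaxed);
    }

    // Reader side: called from wakers, possibly concurrently with `attach`.
    fn executor(&self) -> Option<&'static SyncExecutor> {
        // SAFETY: the pointer is either null or was derived from a &'static
        // reference in `attach`, so it is valid for the 'static lifetime.
        unsafe { self.executor.load(Ordering::Relaxed).as_ref() }
    }
}

fn main() {
    static EXECUTOR: SyncExecutor = SyncExecutor;
    static HEADER: TaskHeader = TaskHeader::new();

    assert!(HEADER.executor().is_none());
    HEADER.attach(&EXECUTOR);
    assert!(HEADER.executor().is_some());
}
```

The `main` function exercises both sides of the sketch: the header starts out null (never spawned) and reports the attached executor afterwards.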


@@ -28,6 +28,7 @@ use core::marker::PhantomData;
 use core::mem;
 use core::pin::Pin;
 use core::ptr::NonNull;
+use core::sync::atomic::{AtomicPtr, Ordering};
 use core::task::{Context, Poll};

 use self::run_queue::{RunQueue, RunQueueItem};
@@ -40,7 +41,7 @@ use super::SpawnToken;
 pub(crate) struct TaskHeader {
     pub(crate) state: State,
     pub(crate) run_queue_item: RunQueueItem,
-    pub(crate) executor: SyncUnsafeCell<Option<&'static SyncExecutor>>,
+    pub(crate) executor: AtomicPtr<SyncExecutor>,
     poll_fn: SyncUnsafeCell<Option<unsafe fn(TaskRef)>>,

     /// Integrated timer queue storage. This field should not be accessed outside of the timer queue.
@@ -86,7 +87,8 @@ impl TaskRef {
     /// Returns a reference to the executor that the task is currently running on.
     pub unsafe fn executor(self) -> Option<&'static Executor> {
-        self.header().executor.get().map(|e| Executor::wrap(e))
+        let executor = self.header().executor.load(Ordering::Relaxed);
+        executor.as_ref().map(|e| Executor::wrap(e))
     }

     /// Returns a reference to the timer queue item.
@@ -153,7 +155,7 @@ impl<F: Future + 'static> TaskStorage<F> {
             raw: TaskHeader {
                 state: State::new(),
                 run_queue_item: RunQueueItem::new(),
-                executor: SyncUnsafeCell::new(None),
+                executor: AtomicPtr::new(core::ptr::null_mut()),
                 // Note: this is lazily initialized so that a static `TaskStorage` will go in `.bss`
                 poll_fn: SyncUnsafeCell::new(None),
@@ -396,7 +398,9 @@ impl SyncExecutor {
     }

     pub(super) unsafe fn spawn(&'static self, task: TaskRef) {
-        task.header().executor.set(Some(self));
+        task.header()
+            .executor
+            .store((self as *const Self).cast_mut(), Ordering::Relaxed);

         #[cfg(feature = "trace")]
         trace::task_new(self, &task);
@@ -549,7 +553,7 @@ pub fn wake_task(task: TaskRef) {
     header.state.run_enqueue(|l| {
         // We have just marked the task as scheduled, so enqueue it.
         unsafe {
-            let executor = header.executor.get().unwrap_unchecked();
+            let executor = header.executor.load(Ordering::Relaxed).as_ref().unwrap_unchecked();
             executor.enqueue(task, l);
         }
     });
@@ -563,7 +567,7 @@ pub fn wake_task_no_pend(task: TaskRef) {
     header.state.run_enqueue(|l| {
         // We have just marked the task as scheduled, so enqueue it.
         unsafe {
-            let executor = header.executor.get().unwrap_unchecked();
+            let executor = header.executor.load(Ordering::Relaxed).as_ref().unwrap_unchecked();
             executor.run_queue.enqueue(task, l);
         }
     });


@@ -2,13 +2,14 @@ use core::sync::atomic::{AtomicU32, Ordering};
 use super::timer_queue::TimerEnqueueOperation;

 #[derive(Clone, Copy)]
 pub(crate) struct Token(());

 /// Creates a token and passes it to the closure.
 ///
 /// This is a no-op replacement for `CriticalSection::with` because we don't need any locking.
-pub(crate) fn locked(f: impl FnOnce(Token)) {
-    f(Token(()));
+pub(crate) fn locked<R>(f: impl FnOnce(Token) -> R) -> R {
+    f(Token(()))
 }

 /// Task is spawned (has a future)


@@ -3,13 +3,14 @@ use core::sync::atomic::{compiler_fence, AtomicBool, AtomicU32, Ordering};
 use super::timer_queue::TimerEnqueueOperation;

 #[derive(Clone, Copy)]
 pub(crate) struct Token(());

 /// Creates a token and passes it to the closure.
 ///
 /// This is a no-op replacement for `CriticalSection::with` because we don't need any locking.
-pub(crate) fn locked(f: impl FnOnce(Token)) {
-    f(Token(()));
+pub(crate) fn locked<R>(f: impl FnOnce(Token) -> R) -> R {
+    f(Token(()))
 }

 // Must be kept in sync with the layout of `State`!
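
Both atomic `State` implementations also generalize `locked` so that the closure can hand a value back to the caller. A hypothetical caller (not taken from this diff) showing what the new signature allows:

```rust
// Hypothetical use of the generalized `locked`: the closure's result is
// returned directly instead of being written into a captured local.
#[derive(Clone, Copy)]
struct Token(());

fn locked<R>(f: impl FnOnce(Token) -> R) -> R {
    f(Token(()))
}

fn example() -> bool {
    // Before this change the caller had to do:
    //     let mut did_work = false;
    //     locked(|_token| did_work = true);
    // Now the value flows out of `locked` directly:
    locked(|_token| true)
}

fn main() {
    assert!(example());
}
```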


@@ -1,6 +1,7 @@
 use core::future::poll_fn;
 use core::marker::PhantomData;
 use core::mem;
+use core::sync::atomic::Ordering;
 use core::task::Poll;

 use super::raw;
@@ -92,7 +93,13 @@ impl Spawner {
     pub async fn for_current_executor() -> Self {
         poll_fn(|cx| {
             let task = raw::task_from_waker(cx.waker());
-            let executor = unsafe { task.header().executor.get().unwrap_unchecked() };
+            let executor = unsafe {
+                task.header()
+                    .executor
+                    .load(Ordering::Relaxed)
+                    .as_ref()
+                    .unwrap_unchecked()
+            };
             let executor = unsafe { raw::Executor::wrap(executor) };
             Poll::Ready(Self::new(executor))
         })
@@ -164,7 +171,13 @@ impl SendSpawner {
     pub async fn for_current_executor() -> Self {
         poll_fn(|cx| {
             let task = raw::task_from_waker(cx.waker());
-            let executor = unsafe { task.header().executor.get().unwrap_unchecked() };
+            let executor = unsafe {
+                task.header()
+                    .executor
+                    .load(Ordering::Relaxed)
+                    .as_ref()
+                    .unwrap_unchecked()
+            };
             Poll::Ready(Self::new(executor))
         })
         .await
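
All of the new stores and loads use `Ordering::Relaxed`; presumably this is enough because the pointer is stored in `spawn` before the task can be polled or produce wakers, so readers only ever observe null or the one published value, and any cross-thread ordering is provided elsewhere (for example by the task `State` atomics). For context, the `for_current_executor` paths touched above are reached from task code roughly like the following usage sketch (it assumes a project with embassy-executor already set up and is not part of this diff):

```rust
// Usage sketch: a task asks for the spawner of the executor it is running on,
// which goes through Spawner::for_current_executor and thus reads the
// executor pointer out of the calling task's header.
use embassy_executor::Spawner;

#[embassy_executor::task]
async fn child() {
    // ...
}

#[embassy_executor::task]
async fn parent() {
    let spawner = Spawner::for_current_executor().await;
    spawner.spawn(child()).unwrap();
}
```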