Prevent task from respawning while in the timer queue

This commit is contained in:
Dániel Buga 2024-12-09 08:43:57 +01:00
parent d45ea43892
commit ec96395d08
No known key found for this signature in database
7 changed files with 181 additions and 9 deletions

View File

@ -50,7 +50,7 @@ pub(crate) struct TaskHeader {
}
/// This is essentially a `&'static TaskStorage<F>` where the type of the future has been erased.
#[derive(Clone, Copy)]
#[derive(Clone, Copy, PartialEq)]
pub struct TaskRef {
ptr: NonNull<TaskHeader>,
}
@ -72,6 +72,16 @@ impl TaskRef {
}
}
/// # Safety
///
/// The result of this function must only be compared
/// for equality, or stored, but not used.
pub const unsafe fn dangling() -> Self {
Self {
ptr: NonNull::dangling(),
}
}
pub(crate) fn header(self) -> &'static TaskHeader {
unsafe { self.ptr.as_ref() }
}
@ -88,6 +98,30 @@ impl TaskRef {
&self.header().timer_queue_item
}
/// Mark the task as timer-queued. Return whether it was newly queued (i.e. not queued before)
///
/// Entering this state prevents the task from being respawned while in a timer queue.
///
/// Safety:
///
/// This functions should only be called by the timer queue implementation, before
/// enqueueing the timer item.
#[cfg(feature = "integrated-timers")]
pub unsafe fn timer_enqueue(&self) -> timer_queue::TimerEnqueueOperation {
self.header().state.timer_enqueue()
}
/// Unmark the task as timer-queued.
///
/// Safety:
///
/// This functions should only be called by the timer queue implementation, after the task has
/// been removed from the timer queue.
#[cfg(feature = "integrated-timers")]
pub unsafe fn timer_dequeue(&self) {
self.header().state.timer_dequeue()
}
/// The returned pointer is valid for the entire TaskStorage.
pub(crate) fn as_ptr(self) -> *const TaskHeader {
self.ptr.as_ptr()

View File

@ -1,9 +1,15 @@
use core::sync::atomic::{AtomicU32, Ordering};
#[cfg(feature = "integrated-timers")]
use super::timer_queue::TimerEnqueueOperation;
/// Task is spawned (has a future)
pub(crate) const STATE_SPAWNED: u32 = 1 << 0;
/// Task is in the executor run queue
pub(crate) const STATE_RUN_QUEUED: u32 = 1 << 1;
/// Task is in the executor timer queue
#[cfg(feature = "integrated-timers")]
pub(crate) const STATE_TIMER_QUEUED: u32 = 1 << 2;
pub(crate) struct State {
state: AtomicU32,
@ -52,4 +58,34 @@ impl State {
let state = self.state.fetch_and(!STATE_RUN_QUEUED, Ordering::AcqRel);
state & STATE_SPAWNED != 0
}
/// Mark the task as timer-queued. Return whether it can be enqueued.
#[cfg(feature = "integrated-timers")]
#[inline(always)]
pub fn timer_enqueue(&self) -> TimerEnqueueOperation {
if self
.state
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |state| {
// If not started, ignore it
if state & STATE_SPAWNED == 0 {
None
} else {
// Mark it as enqueued
Some(state | STATE_TIMER_QUEUED)
}
})
.is_ok()
{
TimerEnqueueOperation::Enqueue
} else {
TimerEnqueueOperation::Ignore
}
}
/// Unmark the task as timer-queued.
#[cfg(feature = "integrated-timers")]
#[inline(always)]
pub fn timer_dequeue(&self) {
self.state.fetch_and(!STATE_TIMER_QUEUED, Ordering::Relaxed);
}
}

View File

@ -1,9 +1,14 @@
use core::arch::asm;
use core::sync::atomic::{compiler_fence, AtomicBool, AtomicU32, Ordering};
#[cfg(feature = "integrated-timers")]
use super::timer_queue::TimerEnqueueOperation;
// Must be kept in sync with the layout of `State`!
pub(crate) const STATE_SPAWNED: u32 = 1 << 0;
pub(crate) const STATE_RUN_QUEUED: u32 = 1 << 8;
#[cfg(feature = "integrated-timers")]
pub(crate) const STATE_TIMER_QUEUED: u32 = 1 << 16;
#[repr(C, align(4))]
pub(crate) struct State {
@ -11,8 +16,9 @@ pub(crate) struct State {
spawned: AtomicBool,
/// Task is in the executor run queue
run_queued: AtomicBool,
/// Task is in the executor timer queue
timer_queued: AtomicBool,
pad: AtomicBool,
pad2: AtomicBool,
}
impl State {
@ -20,8 +26,8 @@ impl State {
Self {
spawned: AtomicBool::new(false),
run_queued: AtomicBool::new(false),
timer_queued: AtomicBool::new(false),
pad: AtomicBool::new(false),
pad2: AtomicBool::new(false),
}
}
@ -85,4 +91,34 @@ impl State {
self.run_queued.store(false, Ordering::Relaxed);
r
}
/// Mark the task as timer-queued. Return whether it can be enqueued.
#[cfg(feature = "integrated-timers")]
#[inline(always)]
pub fn timer_enqueue(&self) -> TimerEnqueueOperation {
if self
.as_u32()
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |state| {
// If not started, ignore it
if state & STATE_SPAWNED == 0 {
None
} else {
// Mark it as enqueued
Some(state | STATE_TIMER_QUEUED)
}
})
.is_ok()
{
TimerEnqueueOperation::Enqueue
} else {
TimerEnqueueOperation::Ignore
}
}
/// Unmark the task as timer-queued.
#[cfg(feature = "integrated-timers")]
#[inline(always)]
pub fn timer_dequeue(&self) {
self.timer_queued.store(false, Ordering::Relaxed);
}
}

View File

@ -2,10 +2,16 @@ use core::cell::Cell;
use critical_section::Mutex;
#[cfg(feature = "integrated-timers")]
use super::timer_queue::TimerEnqueueOperation;
/// Task is spawned (has a future)
pub(crate) const STATE_SPAWNED: u32 = 1 << 0;
/// Task is in the executor run queue
pub(crate) const STATE_RUN_QUEUED: u32 = 1 << 1;
/// Task is in the executor timer queue
#[cfg(feature = "integrated-timers")]
pub(crate) const STATE_TIMER_QUEUED: u32 = 1 << 2;
pub(crate) struct State {
state: Mutex<Cell<u32>>,
@ -69,4 +75,27 @@ impl State {
ok
})
}
/// Mark the task as timer-queued. Return whether it can be enqueued.
#[cfg(feature = "integrated-timers")]
#[inline(always)]
pub fn timer_enqueue(&self) -> TimerEnqueueOperation {
self.update(|s| {
// FIXME: we need to split SPAWNED into two phases, to prevent enqueueing a task that is
// just being spawned, because its executor pointer may still be changing.
if *s & STATE_SPAWNED == STATE_SPAWNED {
*s |= STATE_TIMER_QUEUED;
TimerEnqueueOperation::Enqueue
} else {
TimerEnqueueOperation::Ignore
}
})
}
/// Unmark the task as timer-queued.
#[cfg(feature = "integrated-timers")]
#[inline(always)]
pub fn timer_dequeue(&self) {
self.update(|s| *s &= !STATE_TIMER_QUEUED);
}
}

View File

@ -7,6 +7,9 @@ use super::TaskRef;
/// An item in the timer queue.
pub struct TimerQueueItem {
/// The next item in the queue.
///
/// If this field contains `Some`, the item is in the queue. The last item in the queue has a
/// value of `Some(dangling_pointer)`
pub next: Cell<Option<TaskRef>>,
/// The time at which this item expires.
@ -19,7 +22,17 @@ impl TimerQueueItem {
pub(crate) const fn new() -> Self {
Self {
next: Cell::new(None),
expires_at: Cell::new(0),
expires_at: Cell::new(u64::MAX),
}
}
}
/// The operation to perform after `timer_enqueue` is called.
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum TimerEnqueueOperation {
/// Enqueue the task.
Enqueue,
/// Update the task's expiration time.
Ignore,
}

View File

@ -73,6 +73,20 @@ extern "Rust" {
/// Schedule the given waker to be woken at `at`.
pub fn schedule_wake(at: u64, waker: &Waker) {
#[cfg(feature = "integrated-timers")]
{
use embassy_executor::raw::task_from_waker;
use embassy_executor::raw::timer_queue::TimerEnqueueOperation;
// The very first thing we must do, before we even access the timer queue, is to
// mark the task a TIMER_QUEUED. This ensures that the task that is being scheduled
// can not be respawn while we are accessing the timer queue.
let task = task_from_waker(waker);
if unsafe { task.timer_enqueue() } == TimerEnqueueOperation::Ignore {
// We are not allowed to enqueue the task in the timer queue. This is because the
// task is not spawned, and so it makes no sense to schedule it.
return;
}
}
unsafe { _embassy_time_schedule_wake(at, waker) }
}

View File

@ -24,16 +24,21 @@ impl TimerQueue {
if item.next.get().is_none() {
// If not in the queue, add it and update.
let prev = self.head.replace(Some(p));
item.next.set(prev);
item.next.set(if prev.is_none() {
Some(unsafe { TaskRef::dangling() })
} else {
prev
});
item.expires_at.set(at);
true
} else if at <= item.expires_at.get() {
// If expiration is sooner than previously set, update.
item.expires_at.set(at);
true
} else {
// Task does not need to be updated.
return false;
false
}
item.expires_at.set(at);
true
}
/// Dequeues expired timers and returns the next alarm time.
@ -64,6 +69,10 @@ impl TimerQueue {
fn retain(&self, mut f: impl FnMut(TaskRef) -> bool) {
let mut prev = &self.head;
while let Some(p) = prev.get() {
if unsafe { p == TaskRef::dangling() } {
// prev was the last item, stop
break;
}
let item = p.timer_queue_item();
if f(p) {
// Skip to next
@ -72,6 +81,7 @@ impl TimerQueue {
// Remove it
prev.set(item.next.get());
item.next.set(None);
unsafe { p.timer_dequeue() };
}
}
}