Implement suspending the work queue (#3919)

* Implement suspending the work queue

* Add test case

* Rename variable

* Fix race condition

* Move imports out
This commit is contained in:
Dániel Buga 2025-08-10 14:19:26 +02:00 committed by GitHub
parent 66be1e1001
commit 5b24baf904
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 196 additions and 17 deletions

View File

@ -829,7 +829,7 @@ impl<'d> AesBackend<'d> {
/// ```
pub fn start(&mut self) -> AesWorkQueueDriver<'_, 'd> {
AesWorkQueueDriver {
_inner: WorkQueueDriver::new(self, BLOCKING_AES_VTABLE, &AES_WORK_QUEUE),
inner: WorkQueueDriver::new(self, BLOCKING_AES_VTABLE, &AES_WORK_QUEUE),
}
}
@ -871,7 +871,14 @@ pub enum Error {
///
/// This object must be kept around, otherwise AES operations will never complete.
pub struct AesWorkQueueDriver<'t, 'd> {
_inner: WorkQueueDriver<'t, AesBackend<'d>, AesOperation>,
inner: WorkQueueDriver<'t, AesBackend<'d>, AesOperation>,
}
impl<'t, 'd> AesWorkQueueDriver<'t, 'd> {
/// Finishes processing the current work queue item, then stops the driver.
pub fn stop(self) -> impl Future<Output = ()> {
self.inner.stop()
}
}
/// An AES work queue user.

View File

@ -12,7 +12,7 @@
//! the work item has been processed. Dropping the handle will cancel the work item.
#![cfg_attr(esp32c2, allow(unused))]
use core::{future::poll_fn, marker::PhantomData, ptr::NonNull};
use core::{future::poll_fn, marker::PhantomData, ptr::NonNull, task::Context};
use embassy_sync::waitqueue::WakerRegistration;
@ -70,6 +70,14 @@ struct Inner<T: Sync + Send> {
// The data pointer will be passed to VTable functions, which may be called in any context.
data: NonNull<()>,
vtable: VTable<T>,
// Counts suspend requests. When this reaches 0 again, the all wakers in the queue need to be
// waken to continue processing.
suspend_count: usize,
// The task waiting for the queue to be suspended. There can be multiple tasks, but that's
// practically rare (in this setup, it needs both HMAC and DSA to want to work at the same
// time).
suspend_waker: WakerRegistration,
}
impl<T: Sync + Send> Inner<T> {
@ -112,7 +120,14 @@ impl<T: Sync + Send> Inner<T> {
Poll::Ready(status) => {
unsafe { current.as_mut() }.complete(status);
self.current = None;
self.dequeue_and_post(true)
if self.suspend_count > 0 {
// Queue suspended, stop the driver.
(self.vtable.stop)(self.data);
self.suspend_waker.wake();
false
} else {
self.dequeue_and_post(true)
}
}
Poll::Pending(recall) => recall,
}
@ -260,6 +275,64 @@ impl<T: Sync + Send> Inner<T> {
// Did not find `ptr`.
false
}
/// Increases the suspend counter, preventing new work items from starting to be processed.
///
/// If the current work item finishes processing, the driver is shut down. Call `is_active` to
/// determine when the queue enters suspended state.
fn suspend(&mut self, ctx: Option<&Context<'_>>) {
self.suspend_count += 1;
if let Some(ctx) = ctx {
if self.current.is_some() {
self.suspend_waker.register(ctx.waker());
} else {
ctx.waker().wake_by_ref();
}
}
}
/// Decreases the suspend counter.
///
/// When it reaches 0, this function wakes async tasks that poll the queue. They need to be
/// waken to ensure that their items don't end up stuck. Blocking pollers will eventually end up
/// looping when their turn comes.
fn resume(&mut self) {
self.suspend_count -= 1;
if self.suspend_count == 0 {
self.wake_polling_tasks();
}
}
fn wake_polling_tasks(&mut self) {
if self.data == NonNull::dangling() {
// No VTable means no driver, no need to continue processing.
return;
}
// Walk through the list and wake polling tasks.
let mut current = self.head;
while let Some(mut current_item) = current {
let item = unsafe { current_item.as_mut() };
item.waker.wake();
current = item.next;
}
}
fn is_active(&self) -> bool {
self.current.is_some()
}
unsafe fn configure(&mut self, data: NonNull<()>, vtable: VTable<T>) {
(self.vtable.stop)(self.data);
self.data = data;
self.vtable = vtable;
if self.suspend_count == 0 {
self.wake_polling_tasks();
}
}
}
/// A generic work queue.
@ -278,6 +351,9 @@ impl<T: Sync + Send> WorkQueue<T> {
data: NonNull::dangling(),
vtable: VTable::noop(),
suspend_count: 0,
suspend_waker: WakerRegistration::new(),
}),
}
}
@ -292,12 +368,8 @@ impl<T: Sync + Send> WorkQueue<T> {
/// driver must access the data pointer appropriately (i.e. it must not move !Send data out of
/// it).
pub unsafe fn configure<D: Sync + Send>(&self, data: NonNull<D>, vtable: VTable<T>) {
self.inner.with(|inner| {
(inner.vtable.stop)(inner.data);
inner.data = data.cast();
inner.vtable = vtable;
})
self.inner
.with(|inner| unsafe { inner.configure(data.cast(), vtable) })
}
/// Enqueues a work item.
@ -522,6 +594,40 @@ where
_marker: PhantomData,
}
}
/// Shuts down the driver.
pub fn stop(self) -> impl Future<Output = ()> {
let mut suspended = false;
poll_fn(move |ctx| {
self.queue.inner.with(|inner| {
if !inner.is_active() {
unsafe {
// Safety: the noop VTable functions don't use the pointer at all.
self.queue
.configure(NonNull::<D>::dangling(), VTable::noop())
};
// Make sure the queue doesn't remain suspended when the driver is re-started.
if suspended {
inner.resume();
}
return core::task::Poll::Ready(());
}
// This may kick out other suspend() callers, but that should be okay. They will
// only be able to do work if the queue is !active, for them it doesn't matter if
// the queue is suspended or stopped completely - just that it isn't running. As for
// the possible waker churn, we can use MultiWakerRegistration with a capacity
// suitable for the number of possible suspenders (2-3 unless the work queue ends up
// being used more widely), if this turns out to be a problem.
inner.suspend_waker.register(ctx.waker());
if !suspended {
inner.suspend(Some(ctx));
suspended = true;
}
core::task::Poll::Pending
})
})
}
}
impl<D, T> Drop for WorkQueueDriver<'_, D, T>
@ -530,11 +636,36 @@ where
T: Sync + Send,
{
fn drop(&mut self) {
unsafe {
// Safety: the noop VTable functions don't use the pointer at all.
self.queue
.configure(NonNull::<D>::dangling(), VTable::noop())
};
let wait_for_suspended = self.queue.inner.with(|inner| {
if inner.is_active() {
inner.suspend(None);
true
} else {
unsafe { inner.configure(NonNull::dangling(), VTable::noop()) };
false
}
});
if !wait_for_suspended {
return;
}
loop {
let done = self.queue.inner.with(|inner| {
if inner.is_active() {
return false;
}
unsafe { inner.configure(NonNull::dangling(), VTable::noop()) };
inner.resume();
true
});
if done {
break;
}
}
}
}

View File

@ -6,6 +6,8 @@
#![no_std]
#![no_main]
use embassy_executor::Spawner;
use embassy_sync::{blocking_mutex::raw::CriticalSectionRawMutex, signal::Signal};
use esp_hal::{
Config,
aes::{
@ -19,7 +21,7 @@ use esp_hal::{
},
clock::CpuClock,
};
use hil_test as _;
use hil_test::mk_static;
const KEY: &[u8] = b"SUp4SeCp@sSw0rd";
const KEY_128: [u8; 16] = pad_to::<16>(KEY);
@ -305,7 +307,7 @@ fn run_cipher_tests(buffer: &mut [u8]) {
}
#[cfg(test)]
#[embedded_test::tests(default_timeout = 3)]
#[embedded_test::tests(default_timeout = 3, executor = hil_test::Executor::new())]
mod tests {
use super::*;
@ -452,6 +454,45 @@ mod tests {
hil_test::assert_eq!(output, CIPHERTEXT_ECB_128);
}
#[test]
async fn test_aes_work_queue_work_posted_before_queue_started_async() {
#[embassy_executor::task]
async fn aes_task(signal: &'static Signal<CriticalSectionRawMutex, ()>) {
let mut output = [0; PLAINTEXT_BUF_SIZE];
let mut plaintext = [0; PLAINTEXT_BUF_SIZE];
fill_with_plaintext(&mut plaintext);
let mut ecb_encrypt = AesContext::new(Ecb, Operation::Encrypt, KEY_128);
let mut handle = ecb_encrypt.process(&plaintext, &mut output).unwrap();
// Backend can start now
signal.signal(());
handle.wait().await;
core::mem::drop(handle);
hil_test::assert_eq!(output, CIPHERTEXT_ECB_128);
// Test can end now
signal.signal(());
}
let p = esp_hal::init(Config::default().with_cpu_clock(CpuClock::max()));
let signal = mk_static!(Signal<CriticalSectionRawMutex, ()>, Signal::new());
// Start task before we'd start the AES operation
let spawner = Spawner::for_current_executor().await;
spawner.must_spawn(aes_task(signal));
signal.wait().await;
let mut aes = AesBackend::new(p.AES);
let _backend = aes.start();
signal.wait().await;
}
#[test]
fn test_aes_work_queue_in_place() {
let p = esp_hal::init(Config::default().with_cpu_clock(CpuClock::max()));