mirror of
https://github.com/esp-rs/esp-hal.git
synced 2025-10-02 06:40:47 +00:00
Implement suspending the work queue (#3919)
* Implement suspending the work queue * Add test case * Rename variable * Fix race condition * Move imports out
This commit is contained in:
parent
66be1e1001
commit
5b24baf904
@ -829,7 +829,7 @@ impl<'d> AesBackend<'d> {
|
|||||||
/// ```
|
/// ```
|
||||||
pub fn start(&mut self) -> AesWorkQueueDriver<'_, 'd> {
|
pub fn start(&mut self) -> AesWorkQueueDriver<'_, 'd> {
|
||||||
AesWorkQueueDriver {
|
AesWorkQueueDriver {
|
||||||
_inner: WorkQueueDriver::new(self, BLOCKING_AES_VTABLE, &AES_WORK_QUEUE),
|
inner: WorkQueueDriver::new(self, BLOCKING_AES_VTABLE, &AES_WORK_QUEUE),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -871,7 +871,14 @@ pub enum Error {
|
|||||||
///
|
///
|
||||||
/// This object must be kept around, otherwise AES operations will never complete.
|
/// This object must be kept around, otherwise AES operations will never complete.
|
||||||
pub struct AesWorkQueueDriver<'t, 'd> {
|
pub struct AesWorkQueueDriver<'t, 'd> {
|
||||||
_inner: WorkQueueDriver<'t, AesBackend<'d>, AesOperation>,
|
inner: WorkQueueDriver<'t, AesBackend<'d>, AesOperation>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'t, 'd> AesWorkQueueDriver<'t, 'd> {
|
||||||
|
/// Finishes processing the current work queue item, then stops the driver.
|
||||||
|
pub fn stop(self) -> impl Future<Output = ()> {
|
||||||
|
self.inner.stop()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An AES work queue user.
|
/// An AES work queue user.
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
//! the work item has been processed. Dropping the handle will cancel the work item.
|
//! the work item has been processed. Dropping the handle will cancel the work item.
|
||||||
#![cfg_attr(esp32c2, allow(unused))]
|
#![cfg_attr(esp32c2, allow(unused))]
|
||||||
|
|
||||||
use core::{future::poll_fn, marker::PhantomData, ptr::NonNull};
|
use core::{future::poll_fn, marker::PhantomData, ptr::NonNull, task::Context};
|
||||||
|
|
||||||
use embassy_sync::waitqueue::WakerRegistration;
|
use embassy_sync::waitqueue::WakerRegistration;
|
||||||
|
|
||||||
@ -70,6 +70,14 @@ struct Inner<T: Sync + Send> {
|
|||||||
// The data pointer will be passed to VTable functions, which may be called in any context.
|
// The data pointer will be passed to VTable functions, which may be called in any context.
|
||||||
data: NonNull<()>,
|
data: NonNull<()>,
|
||||||
vtable: VTable<T>,
|
vtable: VTable<T>,
|
||||||
|
|
||||||
|
// Counts suspend requests. When this reaches 0 again, the all wakers in the queue need to be
|
||||||
|
// waken to continue processing.
|
||||||
|
suspend_count: usize,
|
||||||
|
// The task waiting for the queue to be suspended. There can be multiple tasks, but that's
|
||||||
|
// practically rare (in this setup, it needs both HMAC and DSA to want to work at the same
|
||||||
|
// time).
|
||||||
|
suspend_waker: WakerRegistration,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Sync + Send> Inner<T> {
|
impl<T: Sync + Send> Inner<T> {
|
||||||
@ -112,8 +120,15 @@ impl<T: Sync + Send> Inner<T> {
|
|||||||
Poll::Ready(status) => {
|
Poll::Ready(status) => {
|
||||||
unsafe { current.as_mut() }.complete(status);
|
unsafe { current.as_mut() }.complete(status);
|
||||||
self.current = None;
|
self.current = None;
|
||||||
|
if self.suspend_count > 0 {
|
||||||
|
// Queue suspended, stop the driver.
|
||||||
|
(self.vtable.stop)(self.data);
|
||||||
|
self.suspend_waker.wake();
|
||||||
|
false
|
||||||
|
} else {
|
||||||
self.dequeue_and_post(true)
|
self.dequeue_and_post(true)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
Poll::Pending(recall) => recall,
|
Poll::Pending(recall) => recall,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -260,6 +275,64 @@ impl<T: Sync + Send> Inner<T> {
|
|||||||
// Did not find `ptr`.
|
// Did not find `ptr`.
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Increases the suspend counter, preventing new work items from starting to be processed.
|
||||||
|
///
|
||||||
|
/// If the current work item finishes processing, the driver is shut down. Call `is_active` to
|
||||||
|
/// determine when the queue enters suspended state.
|
||||||
|
fn suspend(&mut self, ctx: Option<&Context<'_>>) {
|
||||||
|
self.suspend_count += 1;
|
||||||
|
if let Some(ctx) = ctx {
|
||||||
|
if self.current.is_some() {
|
||||||
|
self.suspend_waker.register(ctx.waker());
|
||||||
|
} else {
|
||||||
|
ctx.waker().wake_by_ref();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decreases the suspend counter.
|
||||||
|
///
|
||||||
|
/// When it reaches 0, this function wakes async tasks that poll the queue. They need to be
|
||||||
|
/// waken to ensure that their items don't end up stuck. Blocking pollers will eventually end up
|
||||||
|
/// looping when their turn comes.
|
||||||
|
fn resume(&mut self) {
|
||||||
|
self.suspend_count -= 1;
|
||||||
|
if self.suspend_count == 0 {
|
||||||
|
self.wake_polling_tasks();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn wake_polling_tasks(&mut self) {
|
||||||
|
if self.data == NonNull::dangling() {
|
||||||
|
// No VTable means no driver, no need to continue processing.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Walk through the list and wake polling tasks.
|
||||||
|
let mut current = self.head;
|
||||||
|
while let Some(mut current_item) = current {
|
||||||
|
let item = unsafe { current_item.as_mut() };
|
||||||
|
|
||||||
|
item.waker.wake();
|
||||||
|
|
||||||
|
current = item.next;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_active(&self) -> bool {
|
||||||
|
self.current.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe fn configure(&mut self, data: NonNull<()>, vtable: VTable<T>) {
|
||||||
|
(self.vtable.stop)(self.data);
|
||||||
|
|
||||||
|
self.data = data;
|
||||||
|
self.vtable = vtable;
|
||||||
|
|
||||||
|
if self.suspend_count == 0 {
|
||||||
|
self.wake_polling_tasks();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A generic work queue.
|
/// A generic work queue.
|
||||||
@ -278,6 +351,9 @@ impl<T: Sync + Send> WorkQueue<T> {
|
|||||||
|
|
||||||
data: NonNull::dangling(),
|
data: NonNull::dangling(),
|
||||||
vtable: VTable::noop(),
|
vtable: VTable::noop(),
|
||||||
|
|
||||||
|
suspend_count: 0,
|
||||||
|
suspend_waker: WakerRegistration::new(),
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -292,12 +368,8 @@ impl<T: Sync + Send> WorkQueue<T> {
|
|||||||
/// driver must access the data pointer appropriately (i.e. it must not move !Send data out of
|
/// driver must access the data pointer appropriately (i.e. it must not move !Send data out of
|
||||||
/// it).
|
/// it).
|
||||||
pub unsafe fn configure<D: Sync + Send>(&self, data: NonNull<D>, vtable: VTable<T>) {
|
pub unsafe fn configure<D: Sync + Send>(&self, data: NonNull<D>, vtable: VTable<T>) {
|
||||||
self.inner.with(|inner| {
|
self.inner
|
||||||
(inner.vtable.stop)(inner.data);
|
.with(|inner| unsafe { inner.configure(data.cast(), vtable) })
|
||||||
|
|
||||||
inner.data = data.cast();
|
|
||||||
inner.vtable = vtable;
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enqueues a work item.
|
/// Enqueues a work item.
|
||||||
@ -522,6 +594,40 @@ where
|
|||||||
_marker: PhantomData,
|
_marker: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Shuts down the driver.
|
||||||
|
pub fn stop(self) -> impl Future<Output = ()> {
|
||||||
|
let mut suspended = false;
|
||||||
|
poll_fn(move |ctx| {
|
||||||
|
self.queue.inner.with(|inner| {
|
||||||
|
if !inner.is_active() {
|
||||||
|
unsafe {
|
||||||
|
// Safety: the noop VTable functions don't use the pointer at all.
|
||||||
|
self.queue
|
||||||
|
.configure(NonNull::<D>::dangling(), VTable::noop())
|
||||||
|
};
|
||||||
|
// Make sure the queue doesn't remain suspended when the driver is re-started.
|
||||||
|
if suspended {
|
||||||
|
inner.resume();
|
||||||
|
}
|
||||||
|
return core::task::Poll::Ready(());
|
||||||
|
}
|
||||||
|
// This may kick out other suspend() callers, but that should be okay. They will
|
||||||
|
// only be able to do work if the queue is !active, for them it doesn't matter if
|
||||||
|
// the queue is suspended or stopped completely - just that it isn't running. As for
|
||||||
|
// the possible waker churn, we can use MultiWakerRegistration with a capacity
|
||||||
|
// suitable for the number of possible suspenders (2-3 unless the work queue ends up
|
||||||
|
// being used more widely), if this turns out to be a problem.
|
||||||
|
inner.suspend_waker.register(ctx.waker());
|
||||||
|
if !suspended {
|
||||||
|
inner.suspend(Some(ctx));
|
||||||
|
suspended = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
core::task::Poll::Pending
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<D, T> Drop for WorkQueueDriver<'_, D, T>
|
impl<D, T> Drop for WorkQueueDriver<'_, D, T>
|
||||||
@ -530,11 +636,36 @@ where
|
|||||||
T: Sync + Send,
|
T: Sync + Send,
|
||||||
{
|
{
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
unsafe {
|
let wait_for_suspended = self.queue.inner.with(|inner| {
|
||||||
// Safety: the noop VTable functions don't use the pointer at all.
|
if inner.is_active() {
|
||||||
self.queue
|
inner.suspend(None);
|
||||||
.configure(NonNull::<D>::dangling(), VTable::noop())
|
true
|
||||||
};
|
} else {
|
||||||
|
unsafe { inner.configure(NonNull::dangling(), VTable::noop()) };
|
||||||
|
false
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
if !wait_for_suspended {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let done = self.queue.inner.with(|inner| {
|
||||||
|
if inner.is_active() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe { inner.configure(NonNull::dangling(), VTable::noop()) };
|
||||||
|
|
||||||
|
inner.resume();
|
||||||
|
|
||||||
|
true
|
||||||
|
});
|
||||||
|
if done {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,6 +6,8 @@
|
|||||||
#![no_std]
|
#![no_std]
|
||||||
#![no_main]
|
#![no_main]
|
||||||
|
|
||||||
|
use embassy_executor::Spawner;
|
||||||
|
use embassy_sync::{blocking_mutex::raw::CriticalSectionRawMutex, signal::Signal};
|
||||||
use esp_hal::{
|
use esp_hal::{
|
||||||
Config,
|
Config,
|
||||||
aes::{
|
aes::{
|
||||||
@ -19,7 +21,7 @@ use esp_hal::{
|
|||||||
},
|
},
|
||||||
clock::CpuClock,
|
clock::CpuClock,
|
||||||
};
|
};
|
||||||
use hil_test as _;
|
use hil_test::mk_static;
|
||||||
|
|
||||||
const KEY: &[u8] = b"SUp4SeCp@sSw0rd";
|
const KEY: &[u8] = b"SUp4SeCp@sSw0rd";
|
||||||
const KEY_128: [u8; 16] = pad_to::<16>(KEY);
|
const KEY_128: [u8; 16] = pad_to::<16>(KEY);
|
||||||
@ -305,7 +307,7 @@ fn run_cipher_tests(buffer: &mut [u8]) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
#[embedded_test::tests(default_timeout = 3)]
|
#[embedded_test::tests(default_timeout = 3, executor = hil_test::Executor::new())]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
@ -452,6 +454,45 @@ mod tests {
|
|||||||
hil_test::assert_eq!(output, CIPHERTEXT_ECB_128);
|
hil_test::assert_eq!(output, CIPHERTEXT_ECB_128);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
async fn test_aes_work_queue_work_posted_before_queue_started_async() {
|
||||||
|
#[embassy_executor::task]
|
||||||
|
async fn aes_task(signal: &'static Signal<CriticalSectionRawMutex, ()>) {
|
||||||
|
let mut output = [0; PLAINTEXT_BUF_SIZE];
|
||||||
|
let mut plaintext = [0; PLAINTEXT_BUF_SIZE];
|
||||||
|
fill_with_plaintext(&mut plaintext);
|
||||||
|
|
||||||
|
let mut ecb_encrypt = AesContext::new(Ecb, Operation::Encrypt, KEY_128);
|
||||||
|
let mut handle = ecb_encrypt.process(&plaintext, &mut output).unwrap();
|
||||||
|
|
||||||
|
// Backend can start now
|
||||||
|
signal.signal(());
|
||||||
|
|
||||||
|
handle.wait().await;
|
||||||
|
core::mem::drop(handle);
|
||||||
|
|
||||||
|
hil_test::assert_eq!(output, CIPHERTEXT_ECB_128);
|
||||||
|
|
||||||
|
// Test can end now
|
||||||
|
signal.signal(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let p = esp_hal::init(Config::default().with_cpu_clock(CpuClock::max()));
|
||||||
|
|
||||||
|
let signal = mk_static!(Signal<CriticalSectionRawMutex, ()>, Signal::new());
|
||||||
|
|
||||||
|
// Start task before we'd start the AES operation
|
||||||
|
let spawner = Spawner::for_current_executor().await;
|
||||||
|
spawner.must_spawn(aes_task(signal));
|
||||||
|
|
||||||
|
signal.wait().await;
|
||||||
|
|
||||||
|
let mut aes = AesBackend::new(p.AES);
|
||||||
|
let _backend = aes.start();
|
||||||
|
|
||||||
|
signal.wait().await;
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_aes_work_queue_in_place() {
|
fn test_aes_work_queue_in_place() {
|
||||||
let p = esp_hal::init(Config::default().with_cpu_clock(CpuClock::max()));
|
let p = esp_hal::init(Config::default().with_cpu_clock(CpuClock::max()));
|
||||||
|
Loading…
x
Reference in New Issue
Block a user