diff --git a/embassy-sync/src/lib.rs b/embassy-sync/src/lib.rs index df0f5e815..5d91b4d9c 100644 --- a/embassy-sync/src/lib.rs +++ b/embassy-sync/src/lib.rs @@ -18,6 +18,7 @@ pub mod once_lock; pub mod pipe; pub mod priority_channel; pub mod pubsub; +pub mod rwlock; pub mod semaphore; pub mod signal; pub mod waitqueue; diff --git a/embassy-sync/src/rwlock.rs b/embassy-sync/src/rwlock.rs new file mode 100644 index 000000000..15ea8468e --- /dev/null +++ b/embassy-sync/src/rwlock.rs @@ -0,0 +1,256 @@ +use core::cell::RefCell; +use core::future::{poll_fn, Future}; +use core::ops::{Deref, DerefMut}; +use core::task::Poll; + +use crate::blocking_mutex::raw::RawMutex; +use crate::blocking_mutex::Mutex as BlockingMutex; +use crate::waitqueue::MultiWakerRegistration; + +/// Error returned by [`RwLock::try_read`] and [`RwLock::try_write`] +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct TryLockError; + +/// Async read-write lock. +/// +/// The lock is generic over a blocking [`RawMutex`](crate::blocking_mutex::raw::RawMutex). +/// The raw mutex is used to guard access to the internal state. It +/// is held for very short periods only, while locking and unlocking. It is *not* held +/// for the entire time the async RwLock is locked. +/// +/// Which implementation you select depends on the context in which you're using the lock. +/// +/// Use [`CriticalSectionRawMutex`](crate::blocking_mutex::raw::CriticalSectionRawMutex) when data can be shared between threads and interrupts. +/// +/// Use [`NoopRawMutex`](crate::blocking_mutex::raw::NoopRawMutex) when data is only shared between tasks running on the same executor. +/// +/// Use [`ThreadModeRawMutex`](crate::blocking_mutex::raw::ThreadModeRawMutex) when data is shared between tasks running on the same executor but you want a singleton. +/// +pub struct RwLock +where + M: RawMutex, + T: ?Sized, +{ + state: BlockingMutex>, + inner: RefCell, +} + +struct State { + readers: usize, + writer: bool, + writer_waker: MultiWakerRegistration<1>, + reader_wakers: MultiWakerRegistration<8>, +} + +impl State { + fn new() -> Self { + Self { + readers: 0, + writer: false, + writer_waker: MultiWakerRegistration::new(), + reader_wakers: MultiWakerRegistration::new(), + } + } +} + +impl RwLock +where + M: RawMutex, +{ + /// Create a new read-write lock with the given value. + pub const fn new(value: T) -> Self { + Self { + inner: RefCell::new(value), + state: BlockingMutex::new(RefCell::new(State::new())), + } + } +} + +impl RwLock +where + M: RawMutex, + T: ?Sized, +{ + /// Acquire a read lock. + /// + /// This will wait for the lock to be available if it's already locked for writing. + pub fn read(&self) -> impl Future> { + poll_fn(|cx| { + let mut state = self.state.lock(|s| s.borrow_mut()); + if state.writer { + state.reader_wakers.register(cx.waker()); + Poll::Pending + } else { + state.readers += 1; + Poll::Ready(RwLockReadGuard { lock: self }) + } + }) + } + + /// Acquire a write lock. + /// + /// This will wait for the lock to be available if it's already locked for reading or writing. + pub fn write(&self) -> impl Future> { + poll_fn(|cx| { + let mut state = self.state.lock(|s| s.borrow_mut()); + if state.writer || state.readers > 0 { + state.writer_waker.register(cx.waker()); + Poll::Pending + } else { + state.writer = true; + Poll::Ready(RwLockWriteGuard { lock: self }) + } + }) + } + + /// Attempt to immediately acquire a read lock. + /// + /// If the lock is already locked for writing, this will return an error instead of waiting. + pub fn try_read(&self) -> Result, TryLockError> { + let mut state = self.state.lock(|s| s.borrow_mut()); + if state.writer { + Err(TryLockError) + } else { + state.readers += 1; + Ok(RwLockReadGuard { lock: self }) + } + } + + /// Attempt to immediately acquire a write lock. + /// + /// If the lock is already locked for reading or writing, this will return an error instead of waiting. + pub fn try_write(&self) -> Result, TryLockError> { + let mut state = self.state.lock(|s| s.borrow_mut()); + if state.writer || state.readers > 0 { + Err(TryLockError) + } else { + state.writer = true; + Ok(RwLockWriteGuard { lock: self }) + } + } + + /// Consumes this lock, returning the underlying data. + pub fn into_inner(self) -> T + where + T: Sized, + { + self.inner.into_inner() + } + + /// Returns a mutable reference to the underlying data. + /// + /// Since this call borrows the RwLock mutably, no actual locking needs to + /// take place -- the mutable borrow statically guarantees no locks exist. + pub fn get_mut(&mut self) -> &mut T { + self.inner.get_mut() + } +} + +impl From for RwLock +where + M: RawMutex, +{ + fn from(from: T) -> Self { + Self::new(from) + } +} + +impl Default for RwLock +where + M: RawMutex, + T: Default, +{ + fn default() -> Self { + Self::new(Default::default()) + } +} + +/// Async read lock guard. +/// +/// Owning an instance of this type indicates having +/// successfully locked the RwLock for reading, and grants access to the contents. +/// +/// Dropping it unlocks the RwLock. +#[must_use = "if unused the RwLock will immediately unlock"] +pub struct RwLockReadGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + lock: &'a RwLock, +} + +impl<'a, M, T> Drop for RwLockReadGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + fn drop(&mut self) { + let mut state = self.lock.state.lock(|s| s.borrow_mut()); + state.readers -= 1; + if state.readers == 0 { + state.writer_waker.wake(); + } + } +} + +impl<'a, M, T> Deref for RwLockReadGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + type Target = T; + fn deref(&self) -> &Self::Target { + self.lock.inner.borrow() + } +} + +/// Async write lock guard. +/// +/// Owning an instance of this type indicates having +/// successfully locked the RwLock for writing, and grants access to the contents. +/// +/// Dropping it unlocks the RwLock. +#[must_use = "if unused the RwLock will immediately unlock"] +pub struct RwLockWriteGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + lock: &'a RwLock, +} + +impl<'a, M, T> Drop for RwLockWriteGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + fn drop(&mut self) { + let mut state = self.lock.state.lock(|s| s.borrow_mut()); + state.writer = false; + state.reader_wakers.wake(); + state.writer_waker.wake(); + } +} + +impl<'a, M, T> Deref for RwLockWriteGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + type Target = T; + fn deref(&self) -> &Self::Target { + self.lock.inner.borrow() + } +} + +impl<'a, M, T> DerefMut for RwLockWriteGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + self.lock.inner.borrow_mut() + } +} diff --git a/embassy-sync/tests/rwlock.rs b/embassy-sync/tests/rwlock.rs new file mode 100644 index 000000000..301cb0492 --- /dev/null +++ b/embassy-sync/tests/rwlock.rs @@ -0,0 +1,81 @@ +use embassy_sync::blocking_mutex::raw::NoopRawMutex; +use embassy_sync::rwlock::RwLock; +use futures_executor::block_on; + +#[futures_test::test] +async fn test_rwlock_read() { + let lock = RwLock::::new(5); + + { + let read_guard = lock.read().await; + assert_eq!(*read_guard, 5); + } + + { + let read_guard = lock.read().await; + assert_eq!(*read_guard, 5); + } +} + +#[futures_test::test] +async fn test_rwlock_write() { + let lock = RwLock::::new(5); + + { + let mut write_guard = lock.write().await; + *write_guard = 10; + } + + { + let read_guard = lock.read().await; + assert_eq!(*read_guard, 10); + } +} + +#[futures_test::test] +async fn test_rwlock_try_read() { + let lock = RwLock::::new(5); + + { + let read_guard = lock.try_read().unwrap(); + assert_eq!(*read_guard, 5); + } + + { + let read_guard = lock.try_read().unwrap(); + assert_eq!(*read_guard, 5); + } +} + +#[futures_test::test] +async fn test_rwlock_try_write() { + let lock = RwLock::::new(5); + + { + let mut write_guard = lock.try_write().unwrap(); + *write_guard = 10; + } + + { + let read_guard = lock.try_read().unwrap(); + assert_eq!(*read_guard, 10); + } +} + +#[futures_test::test] +async fn test_rwlock_fairness() { + let lock = RwLock::::new(5); + + let read1 = lock.read().await; + let read2 = lock.read().await; + + let write_fut = lock.write(); + futures_util::pin_mut!(write_fut); + + assert!(futures_util::poll!(write_fut.as_mut()).is_pending()); + + drop(read1); + drop(read2); + + assert!(futures_util::poll!(write_fut.as_mut()).is_ready()); +}