RMT: refactor copies to/from channel RAM (#4126)

* RMT: add check_data_eq to log more details on test failure

if defmt is enabled

* RMT: add minimal RmtWriter

- de-duplicates the copy-to-hardware code
- splits copying data to the buffer and `start_send`
- this also paves the way for supporting other data types (like
  iterators instead of slices)

* RMT: add minimal RmtReader

- de-duplicates the copy-from-hardware code
- this also paves the way for supporting other data types (like
  iterators instead of slices) and wrapping rx (RmtReader as implemented
  here already supports that, but it's unused for now)
This commit is contained in:
Benedikt 2025-09-18 10:48:43 +02:00 committed by GitHub
parent 97b34faa1b
commit 8e58a753b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 305 additions and 85 deletions

View File

@ -226,6 +226,11 @@ use crate::{
time::Rate,
};
mod reader;
use reader::RmtReader;
mod writer;
use writer::RmtWriter;
/// Errors
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
@ -1128,10 +1133,7 @@ where
{
channel: Channel<Blocking, Tx>,
// The position in channel RAM to continue writing at; must be either
// 0 or half the available RAM size if there's further data.
// The position may be invalid if there's no data left.
ram_index: usize,
writer: RmtWriter,
// Remaining data that has not yet been written to channel RAM. May be empty.
remaining_data: &'a [T],
@ -1149,31 +1151,9 @@ where
if status == Some(Event::Threshold) {
raw.reset_tx_threshold_set();
if !self.remaining_data.is_empty() {
// re-fill TX RAM
let memsize = raw.memsize().codes();
let ptr = unsafe { raw.channel_ram_start().add(self.ram_index) };
let count = self.remaining_data.len().min(memsize / 2);
let (chunk, remaining) = self.remaining_data.split_at(count);
for (idx, entry) in chunk.iter().enumerate() {
unsafe {
ptr.add(idx).write_volatile((*entry).into());
}
}
// If count == memsize / 2 codes were written, update ram_index as
// - 0 -> memsize / 2
// - memsize / 2 -> 0
// Otherwise, for count < memsize / 2, the new position is invalid but the new
// slice is empty and we won't use ram_index again.
self.ram_index = memsize / 2 - self.ram_index;
self.remaining_data = remaining;
debug_assert!(
self.ram_index == 0
|| self.ram_index == memsize / 2
|| self.remaining_data.is_empty()
);
}
// `RmtWriter::write()` is safe to call even if `poll_internal` is called repeatedly
// after the data is exhausted since it returns immediately if already done.
self.writer.write(&mut self.remaining_data, raw, false);
}
status
@ -1320,16 +1300,28 @@ impl Channel<Blocking, Tx> {
/// the transaction to complete and get back the channel for further
/// use.
#[cfg_attr(place_rmt_driver_in_ram, ram)]
pub fn transmit<T>(self, data: &[T]) -> Result<SingleShotTxTransaction<'_, T>, Error>
pub fn transmit<T>(self, mut data: &[T]) -> Result<SingleShotTxTransaction<'_, T>, Error>
where
T: Into<PulseCode> + Copy,
{
let index = self.raw.start_send(data, None)?;
let raw = self.raw;
let memsize = raw.memsize();
match data.last() {
None => return Err(Error::InvalidArgument),
Some(&code) if code.into().is_end_marker() => (),
Some(_) => return Err(Error::EndMarkerMissing),
}
let mut writer = RmtWriter::new();
writer.write(&mut data, raw, true);
raw.start_send(None, memsize);
Ok(SingleShotTxTransaction {
channel: self,
// Either, remaining_data is empty, or we filled the entire buffer.
ram_index: 0,
remaining_data: &data[index..],
writer,
remaining_data: data,
})
}
@ -1344,17 +1336,26 @@ impl Channel<Blocking, Tx> {
#[cfg_attr(place_rmt_driver_in_ram, ram)]
pub fn transmit_continuously<T>(
self,
data: &[T],
mut data: &[T],
loopcount: LoopCount,
) -> Result<ContinuousTxTransaction, Error>
where
T: Into<PulseCode> + Copy,
{
if data.len() > self.raw.memsize().codes() {
let raw = self.raw;
let memsize = raw.memsize();
if data.is_empty() {
return Err(Error::InvalidArgument);
} else if data.len() > memsize.codes() {
return Err(Error::Overflow);
}
let _index = self.raw.start_send(data, Some(loopcount))?;
let mut writer = RmtWriter::new();
writer.write(&mut data, raw, true);
self.raw.start_send(Some(loopcount), memsize);
Ok(ContinuousTxTransaction { channel: self })
}
}
@ -1365,6 +1366,9 @@ where
T: From<PulseCode>,
{
channel: Channel<Blocking, Rx>,
reader: RmtReader,
data: &'a mut [T],
}
@ -1377,18 +1381,16 @@ where
let raw = self.channel.raw;
let status = raw.get_rx_status();
if status == Some(Event::End) {
// Do not clear the interrupt flags here: Subsequent calls of wait() must
// be able to observe them if this is currently called via poll()
raw.stop_rx();
raw.update();
let ptr = raw.channel_ram_start();
// SAFETY: RxChannel.receive() verifies that the length of self.data does not
// exceed the channel RAM size.
for (idx, entry) in self.data.iter_mut().enumerate() {
*entry = unsafe { ptr.add(idx).read_volatile() }.into();
}
// `RmtReader::read()` is safe to call even if `poll_internal` is called repeatedly
// after the receiver finished since it returns immediately if already done.
self.reader.read(&mut self.data, raw, true);
}
status
@ -1437,14 +1439,20 @@ impl Channel<Blocking, Rx> {
Self: Sized,
T: From<PulseCode>,
{
if data.len() > self.raw.memsize().codes() {
let raw = self.raw;
let memsize = raw.memsize();
if data.len() > memsize.codes() {
return Err(Error::InvalidDataLength);
}
self.raw.start_receive();
let reader = RmtReader::new();
raw.start_receive();
Ok(RxTransaction {
channel: self,
reader,
data,
})
}
@ -1477,20 +1485,30 @@ impl Channel<Async, Tx> {
/// The length of sequence cannot exceed the size of the allocated RMT
/// RAM.
#[cfg_attr(place_rmt_driver_in_ram, ram)]
pub async fn transmit<T>(&mut self, data: &[T]) -> Result<(), Error>
pub async fn transmit<T>(&mut self, mut data: &[T]) -> Result<(), Error>
where
Self: Sized,
T: Into<PulseCode> + Copy,
{
let raw = self.raw;
let memsize = raw.memsize();
if data.len() > raw.memsize().codes() {
match data.last() {
None => return Err(Error::InvalidArgument),
Some(&code) if code.into().is_end_marker() => (),
Some(_) => return Err(Error::EndMarkerMissing),
}
if data.len() > memsize.codes() {
return Err(Error::InvalidDataLength);
}
let mut writer = RmtWriter::new();
writer.write(&mut data, raw, true);
raw.clear_tx_interrupts();
raw.listen_tx_interrupt(Event::End | Event::Error);
raw.start_send(data, None)?;
raw.start_send(None, memsize);
(RmtTxFuture { raw }).await
}
@ -1522,17 +1540,20 @@ impl Channel<Async, Rx> {
/// The length of sequence cannot exceed the size of the allocated RMT
/// RAM.
#[cfg_attr(place_rmt_driver_in_ram, ram)]
pub async fn receive<T>(&mut self, data: &mut [T]) -> Result<(), Error>
pub async fn receive<T>(&mut self, mut data: &mut [T]) -> Result<(), Error>
where
Self: Sized,
T: From<PulseCode>,
{
let raw = self.raw;
let memsize = raw.memsize();
if data.len() > raw.memsize().codes() {
if data.len() > memsize.codes() {
return Err(Error::InvalidDataLength);
}
let mut reader = RmtReader::new();
raw.clear_rx_interrupts();
raw.listen_rx_interrupt(Event::End | Event::Error);
raw.start_receive();
@ -1544,11 +1565,7 @@ impl Channel<Async, Rx> {
raw.clear_rx_interrupts();
raw.update();
let ptr = raw.channel_ram_start();
let len = data.len();
for (idx, entry) in data.iter_mut().take(len).enumerate() {
*entry = unsafe { ptr.add(idx).read_volatile().into() };
}
reader.read(&mut data, raw, true);
}
result
@ -1595,38 +1612,19 @@ impl DynChannelAccess<Tx> {
OUTPUT_SIGNALS[self.ch_idx as usize]
}
// We could obtain `memsize` via `self.memsize()` here. However, it is already known at all
// call sites, and passing it as argument avoids a volatile read that the compiler wouldn't be
// able to deduplicate.
#[inline]
fn start_send<T>(&self, data: &[T], loopcount: Option<LoopCount>) -> Result<usize, Error>
where
T: Into<PulseCode> + Copy,
{
fn start_send(&self, loopcount: Option<LoopCount>, memsize: MemSize) {
self.clear_tx_interrupts();
if let Some(last) = data.last() {
if loopcount.is_none() && !(*last).into().is_end_marker() {
return Err(Error::EndMarkerMissing);
}
} else {
return Err(Error::InvalidArgument);
}
let ptr = self.channel_ram_start();
let memsize = self.memsize().codes();
for (idx, entry) in data.iter().take(memsize).enumerate() {
unsafe {
ptr.add(idx).write_volatile((*entry).into());
}
}
self.set_tx_threshold((memsize / 2) as u8);
self.set_tx_threshold((memsize.codes() / 2) as u8);
self.set_tx_continuous(loopcount.is_some());
self.set_generate_repeat_interrupt(loopcount);
self.set_tx_wrap_mode(true);
self.update();
self.start_tx();
self.update();
Ok(data.len().min(memsize))
}
#[inline]

95
esp-hal/src/rmt/reader.rs Normal file
View File

@ -0,0 +1,95 @@
use super::{DynChannelAccess, Error, PulseCode, Rx};
#[derive(Debug, PartialEq)]
pub(crate) enum ReaderState {
Active,
Error(Error),
Done,
}
pub(crate) struct RmtReader {
// The position in channel RAM to continue reading from; must be either
// 0 or half the available RAM size if there's further data.
// The position may be invalid if there's no data left.
offset: u16,
pub state: ReaderState,
}
impl RmtReader {
pub(crate) fn new() -> Self {
Self {
offset: 0,
state: ReaderState::Active,
}
}
// Copy from the hardware buffer to `data`, advancing the `data` slice accordingly.
//
// If `final_` is set, read a full buffer length, potentially wrapping around. Otherwise, fetch
// half the buffer's length.
#[cfg_attr(place_rmt_driver_in_ram, ram)]
pub(crate) fn read<T>(&mut self, data: &mut &mut [T], raw: DynChannelAccess<Rx>, final_: bool)
where
T: From<PulseCode>,
{
if self.state != ReaderState::Active {
return;
}
let ram_start = raw.channel_ram_start();
let memsize = raw.memsize().codes();
let max_count = if final_ { memsize } else { memsize / 2 };
let count = data.len().min(max_count);
let mut count0 = count.min(memsize - self.offset as usize);
let mut count1 = count - count0;
// Read in up to 2 chunks to allow wrapping around the buffer end. This is more efficient
// than checking in each iteration of the inner loop whether we reached the buffer end.
let mut ptr = unsafe { ram_start.add(self.offset as usize) };
loop {
for entry in data.iter_mut().take(count0) {
// SAFETY: The iteration `count` is smaller than `max_count` such that incrementing
// the `ptr` `count0` times cannot advance further than `ram_start + memsize`.
unsafe {
*entry = ptr.read_volatile().into();
ptr = ptr.add(1);
}
}
if count1 == 0 {
break;
}
count0 = count1;
count1 = 0;
ptr = ram_start;
}
// Update offset as
//
// | offset | new offset |
// | ----------- + ----------- |
// | 0 | memsize / 2 |
// | memsize / 2 | 0 |
//
// If `count < max_count` or if `final_` is set, the new offset will not correspond to
// where we stopped reading, but the new offset will not be used again since further calls
// will immediately return due to `self.state != Active`.
self.offset = (memsize / 2) as u16 - self.offset;
data.split_off_mut(..count).unwrap();
if count < max_count {
// `data` exhausted
self.state = ReaderState::Error(Error::ReceiverError);
} else if final_ {
// Caller indicated that we're done
self.state = ReaderState::Done;
}
debug_assert!(self.offset == 0 || self.offset as usize == memsize / 2);
}
}

78
esp-hal/src/rmt/writer.rs Normal file
View File

@ -0,0 +1,78 @@
use super::{DynChannelAccess, PulseCode, Tx};
#[derive(PartialEq)]
pub(crate) enum WriterState {
Active,
Done,
}
pub(crate) struct RmtWriter {
// The position in channel RAM to continue writing at; must be either
// 0 or half the available RAM size if there's further data.
// The position may be invalid if there's no data left.
offset: u16,
pub state: WriterState,
}
impl RmtWriter {
pub(crate) fn new() -> Self {
Self {
offset: 0,
state: WriterState::Active,
}
}
// Copy from `data` to the hardware buffer, advancing the `data` slice accordingly.
//
// If `initial` is set, fill the entire buffer. Otherwise, append half the buffer's length from
// `data`.
#[cfg_attr(place_rmt_driver_in_ram, ram)]
pub(crate) fn write<T>(&mut self, data: &mut &[T], raw: DynChannelAccess<Tx>, initial: bool)
where
T: Into<PulseCode> + Copy,
{
if self.state != WriterState::Active {
return;
}
let ram_start = raw.channel_ram_start();
let memsize = raw.memsize().codes();
let max_count = if initial { memsize } else { memsize / 2 };
let count = data.len().min(max_count);
debug_assert!(!initial || self.offset == 0);
let mut ptr = unsafe { ram_start.add(self.offset as usize) };
for entry in data.iter().take(count) {
// SAFETY: The iteration `count` is smaller than `max_count` such that incrementing the
// `ptr` `count` times cannot advance further than `ram_start + memsize`.
unsafe {
ptr.write_volatile((*entry).into());
ptr = ptr.add(1);
}
}
// If the input data was not exhausted, update offset as
//
// | initial | offset | max_count | new offset |
// | ------- + ----------- + ----------- + ----------- |
// | true | 0 | memsize | 0 |
// | false | 0 | memsize / 2 | memsize / 2 |
// | false | memsize / 2 | memsize / 2 | 0 |
//
// Otherwise, the new position is invalid but the new slice is empty and we won't use the
// offset again. In either case, the unsigned subtraction will not underflow.
self.offset = memsize as u16 - max_count as u16 - self.offset;
// The panic can never trigger since count <= data.len()!
data.split_off(..count).unwrap();
if data.is_empty() {
self.state = WriterState::Done;
}
debug_assert!(self.offset == 0 || self.offset as usize == memsize / 2);
}
}

View File

@ -80,6 +80,59 @@ fn generate_tx_data<const TX_LEN: usize>(write_end_marker: bool) -> [PulseCode;
tx_data
}
// When running this with defmt:
// - use `DEFMT_RTT_BUFFER_SIZE=32768 xtask run ...` to avoid truncated output
// - increase embedded_test's default_timeout below to avoid timeouts while printing
// Note that probe-rs reading the buffer might mess up timing-sensitive tests!
fn check_data_eq(tx: &[PulseCode], rx: &[PulseCode], tx_len: usize) {
let mut errors: usize = 0;
for (idx, (&code_tx, &code_rx)) in core::iter::zip(tx, rx).enumerate() {
let _msg = if idx == tx_len - 1 {
// The last pulse code is the stop code, which can't be received.
""
} else if idx == tx_len - 2 {
// The second-to-last pulse-code is the one which exceeds the idle threshold and
// should be received as stop code.
if !(code_rx.level1() == Level::High && code_rx.length1() == 0) {
errors += 1;
"rx code not a stop code!"
} else {
""
}
} else if code_tx != code_rx {
errors += 1;
"rx/tx code mismatch!"
} else {
""
};
#[cfg(feature = "defmt")]
if _msg.len() > 0 {
defmt::error!(
"loopback @ idx {}: {:?} (tx) -> {:?} (rx): {}",
idx,
code_tx,
code_rx,
_msg
);
} else {
defmt::info!(
"loopback @ idx {}: {:?} (tx) -> {:?} (rx)",
idx,
code_tx,
code_rx,
);
}
}
assert_eq!(
errors, 0,
"rx/tx code mismatch at {}/{} indices",
errors, tx_len
);
}
fn do_rmt_loopback_inner<const TX_LEN: usize>(
tx_channel: Channel<Blocking, Tx>,
rx_channel: Channel<Blocking, Rx>,
@ -101,9 +154,7 @@ fn do_rmt_loopback_inner<const TX_LEN: usize>(
tx_transaction.wait().unwrap();
rx_transaction.wait().unwrap();
// the last two pulse-codes are the ones which wait for the timeout so
// they can't be equal
assert_eq!(&tx_data[..TX_LEN - 2], &rcv_data[..TX_LEN - 2]);
check_data_eq(&tx_data, &rcv_data, TX_LEN);
}
// Run a test where some data is sent from one channel and looped back to
@ -151,9 +202,7 @@ async fn do_rmt_loopback_async<const TX_LEN: usize>(tx_memsize: u8, rx_memsize:
tx_res.unwrap();
rx_res.unwrap();
// the last two pulse-codes are the ones which wait for the timeout so
// they can't be equal
assert_eq!(&tx_data[..TX_LEN - 2], &rcv_data[..TX_LEN - 2]);
check_data_eq(&tx_data, &rcv_data, TX_LEN);
}
// Run a test that just sends some data, without trying to recive it.