From 8e58a753b6ac54ee6a08ba07538bdd925d73aa11 Mon Sep 17 00:00:00 2001 From: Benedikt Date: Thu, 18 Sep 2025 10:48:43 +0200 Subject: [PATCH] RMT: refactor copies to/from channel RAM (#4126) * RMT: add check_data_eq to log more details on test failure if defmt is enabled * RMT: add minimal RmtWriter - de-duplicates the copy-to-hardware code - splits copying data to the buffer and `start_send` - this also paves the way for supporting other data types (like iterators instead of slices) * RMT: add minimal RmtReader - de-duplicates the copy-from-hardware code - this also paves the way for supporting other data types (like iterators instead of slices) and wrapping rx (RmtReader as implemented here already supports that, but it's unused for now) --- esp-hal/src/rmt.rs | 156 +++++++++++++++++++------------------- esp-hal/src/rmt/reader.rs | 95 +++++++++++++++++++++++ esp-hal/src/rmt/writer.rs | 78 +++++++++++++++++++ hil-test/tests/rmt.rs | 61 +++++++++++++-- 4 files changed, 305 insertions(+), 85 deletions(-) create mode 100644 esp-hal/src/rmt/reader.rs create mode 100644 esp-hal/src/rmt/writer.rs diff --git a/esp-hal/src/rmt.rs b/esp-hal/src/rmt.rs index 1617afed4..116075100 100644 --- a/esp-hal/src/rmt.rs +++ b/esp-hal/src/rmt.rs @@ -226,6 +226,11 @@ use crate::{ time::Rate, }; +mod reader; +use reader::RmtReader; +mod writer; +use writer::RmtWriter; + /// Errors #[derive(Debug, Clone, Copy, PartialEq)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] @@ -1128,10 +1133,7 @@ where { channel: Channel, - // The position in channel RAM to continue writing at; must be either - // 0 or half the available RAM size if there's further data. - // The position may be invalid if there's no data left. - ram_index: usize, + writer: RmtWriter, // Remaining data that has not yet been written to channel RAM. May be empty. remaining_data: &'a [T], @@ -1149,31 +1151,9 @@ where if status == Some(Event::Threshold) { raw.reset_tx_threshold_set(); - if !self.remaining_data.is_empty() { - // re-fill TX RAM - let memsize = raw.memsize().codes(); - let ptr = unsafe { raw.channel_ram_start().add(self.ram_index) }; - let count = self.remaining_data.len().min(memsize / 2); - let (chunk, remaining) = self.remaining_data.split_at(count); - for (idx, entry) in chunk.iter().enumerate() { - unsafe { - ptr.add(idx).write_volatile((*entry).into()); - } - } - - // If count == memsize / 2 codes were written, update ram_index as - // - 0 -> memsize / 2 - // - memsize / 2 -> 0 - // Otherwise, for count < memsize / 2, the new position is invalid but the new - // slice is empty and we won't use ram_index again. - self.ram_index = memsize / 2 - self.ram_index; - self.remaining_data = remaining; - debug_assert!( - self.ram_index == 0 - || self.ram_index == memsize / 2 - || self.remaining_data.is_empty() - ); - } + // `RmtWriter::write()` is safe to call even if `poll_internal` is called repeatedly + // after the data is exhausted since it returns immediately if already done. + self.writer.write(&mut self.remaining_data, raw, false); } status @@ -1320,16 +1300,28 @@ impl Channel { /// the transaction to complete and get back the channel for further /// use. #[cfg_attr(place_rmt_driver_in_ram, ram)] - pub fn transmit(self, data: &[T]) -> Result, Error> + pub fn transmit(self, mut data: &[T]) -> Result, Error> where T: Into + Copy, { - let index = self.raw.start_send(data, None)?; + let raw = self.raw; + let memsize = raw.memsize(); + + match data.last() { + None => return Err(Error::InvalidArgument), + Some(&code) if code.into().is_end_marker() => (), + Some(_) => return Err(Error::EndMarkerMissing), + } + + let mut writer = RmtWriter::new(); + writer.write(&mut data, raw, true); + + raw.start_send(None, memsize); + Ok(SingleShotTxTransaction { channel: self, - // Either, remaining_data is empty, or we filled the entire buffer. - ram_index: 0, - remaining_data: &data[index..], + writer, + remaining_data: data, }) } @@ -1344,17 +1336,26 @@ impl Channel { #[cfg_attr(place_rmt_driver_in_ram, ram)] pub fn transmit_continuously( self, - data: &[T], + mut data: &[T], loopcount: LoopCount, ) -> Result where T: Into + Copy, { - if data.len() > self.raw.memsize().codes() { + let raw = self.raw; + let memsize = raw.memsize(); + + if data.is_empty() { + return Err(Error::InvalidArgument); + } else if data.len() > memsize.codes() { return Err(Error::Overflow); } - let _index = self.raw.start_send(data, Some(loopcount))?; + let mut writer = RmtWriter::new(); + writer.write(&mut data, raw, true); + + self.raw.start_send(Some(loopcount), memsize); + Ok(ContinuousTxTransaction { channel: self }) } } @@ -1365,6 +1366,9 @@ where T: From, { channel: Channel, + + reader: RmtReader, + data: &'a mut [T], } @@ -1377,18 +1381,16 @@ where let raw = self.channel.raw; let status = raw.get_rx_status(); + if status == Some(Event::End) { // Do not clear the interrupt flags here: Subsequent calls of wait() must // be able to observe them if this is currently called via poll() raw.stop_rx(); raw.update(); - let ptr = raw.channel_ram_start(); - // SAFETY: RxChannel.receive() verifies that the length of self.data does not - // exceed the channel RAM size. - for (idx, entry) in self.data.iter_mut().enumerate() { - *entry = unsafe { ptr.add(idx).read_volatile() }.into(); - } + // `RmtReader::read()` is safe to call even if `poll_internal` is called repeatedly + // after the receiver finished since it returns immediately if already done. + self.reader.read(&mut self.data, raw, true); } status @@ -1437,14 +1439,20 @@ impl Channel { Self: Sized, T: From, { - if data.len() > self.raw.memsize().codes() { + let raw = self.raw; + let memsize = raw.memsize(); + + if data.len() > memsize.codes() { return Err(Error::InvalidDataLength); } - self.raw.start_receive(); + let reader = RmtReader::new(); + + raw.start_receive(); Ok(RxTransaction { channel: self, + reader, data, }) } @@ -1477,20 +1485,30 @@ impl Channel { /// The length of sequence cannot exceed the size of the allocated RMT /// RAM. #[cfg_attr(place_rmt_driver_in_ram, ram)] - pub async fn transmit(&mut self, data: &[T]) -> Result<(), Error> + pub async fn transmit(&mut self, mut data: &[T]) -> Result<(), Error> where Self: Sized, T: Into + Copy, { let raw = self.raw; + let memsize = raw.memsize(); - if data.len() > raw.memsize().codes() { + match data.last() { + None => return Err(Error::InvalidArgument), + Some(&code) if code.into().is_end_marker() => (), + Some(_) => return Err(Error::EndMarkerMissing), + } + + if data.len() > memsize.codes() { return Err(Error::InvalidDataLength); } + let mut writer = RmtWriter::new(); + writer.write(&mut data, raw, true); + raw.clear_tx_interrupts(); raw.listen_tx_interrupt(Event::End | Event::Error); - raw.start_send(data, None)?; + raw.start_send(None, memsize); (RmtTxFuture { raw }).await } @@ -1522,17 +1540,20 @@ impl Channel { /// The length of sequence cannot exceed the size of the allocated RMT /// RAM. #[cfg_attr(place_rmt_driver_in_ram, ram)] - pub async fn receive(&mut self, data: &mut [T]) -> Result<(), Error> + pub async fn receive(&mut self, mut data: &mut [T]) -> Result<(), Error> where Self: Sized, T: From, { let raw = self.raw; + let memsize = raw.memsize(); - if data.len() > raw.memsize().codes() { + if data.len() > memsize.codes() { return Err(Error::InvalidDataLength); } + let mut reader = RmtReader::new(); + raw.clear_rx_interrupts(); raw.listen_rx_interrupt(Event::End | Event::Error); raw.start_receive(); @@ -1544,11 +1565,7 @@ impl Channel { raw.clear_rx_interrupts(); raw.update(); - let ptr = raw.channel_ram_start(); - let len = data.len(); - for (idx, entry) in data.iter_mut().take(len).enumerate() { - *entry = unsafe { ptr.add(idx).read_volatile().into() }; - } + reader.read(&mut data, raw, true); } result @@ -1595,38 +1612,19 @@ impl DynChannelAccess { OUTPUT_SIGNALS[self.ch_idx as usize] } + // We could obtain `memsize` via `self.memsize()` here. However, it is already known at all + // call sites, and passing it as argument avoids a volatile read that the compiler wouldn't be + // able to deduplicate. #[inline] - fn start_send(&self, data: &[T], loopcount: Option) -> Result - where - T: Into + Copy, - { + fn start_send(&self, loopcount: Option, memsize: MemSize) { self.clear_tx_interrupts(); - - if let Some(last) = data.last() { - if loopcount.is_none() && !(*last).into().is_end_marker() { - return Err(Error::EndMarkerMissing); - } - } else { - return Err(Error::InvalidArgument); - } - - let ptr = self.channel_ram_start(); - let memsize = self.memsize().codes(); - for (idx, entry) in data.iter().take(memsize).enumerate() { - unsafe { - ptr.add(idx).write_volatile((*entry).into()); - } - } - - self.set_tx_threshold((memsize / 2) as u8); + self.set_tx_threshold((memsize.codes() / 2) as u8); self.set_tx_continuous(loopcount.is_some()); self.set_generate_repeat_interrupt(loopcount); self.set_tx_wrap_mode(true); self.update(); self.start_tx(); self.update(); - - Ok(data.len().min(memsize)) } #[inline] diff --git a/esp-hal/src/rmt/reader.rs b/esp-hal/src/rmt/reader.rs new file mode 100644 index 000000000..f916c3151 --- /dev/null +++ b/esp-hal/src/rmt/reader.rs @@ -0,0 +1,95 @@ +use super::{DynChannelAccess, Error, PulseCode, Rx}; + +#[derive(Debug, PartialEq)] +pub(crate) enum ReaderState { + Active, + + Error(Error), + + Done, +} + +pub(crate) struct RmtReader { + // The position in channel RAM to continue reading from; must be either + // 0 or half the available RAM size if there's further data. + // The position may be invalid if there's no data left. + offset: u16, + + pub state: ReaderState, +} + +impl RmtReader { + pub(crate) fn new() -> Self { + Self { + offset: 0, + state: ReaderState::Active, + } + } + + // Copy from the hardware buffer to `data`, advancing the `data` slice accordingly. + // + // If `final_` is set, read a full buffer length, potentially wrapping around. Otherwise, fetch + // half the buffer's length. + #[cfg_attr(place_rmt_driver_in_ram, ram)] + pub(crate) fn read(&mut self, data: &mut &mut [T], raw: DynChannelAccess, final_: bool) + where + T: From, + { + if self.state != ReaderState::Active { + return; + } + + let ram_start = raw.channel_ram_start(); + let memsize = raw.memsize().codes(); + + let max_count = if final_ { memsize } else { memsize / 2 }; + let count = data.len().min(max_count); + let mut count0 = count.min(memsize - self.offset as usize); + let mut count1 = count - count0; + + // Read in up to 2 chunks to allow wrapping around the buffer end. This is more efficient + // than checking in each iteration of the inner loop whether we reached the buffer end. + let mut ptr = unsafe { ram_start.add(self.offset as usize) }; + loop { + for entry in data.iter_mut().take(count0) { + // SAFETY: The iteration `count` is smaller than `max_count` such that incrementing + // the `ptr` `count0` times cannot advance further than `ram_start + memsize`. + unsafe { + *entry = ptr.read_volatile().into(); + ptr = ptr.add(1); + } + } + + if count1 == 0 { + break; + } + + count0 = count1; + count1 = 0; + ptr = ram_start; + } + + // Update offset as + // + // | offset | new offset | + // | ----------- + ----------- | + // | 0 | memsize / 2 | + // | memsize / 2 | 0 | + // + // If `count < max_count` or if `final_` is set, the new offset will not correspond to + // where we stopped reading, but the new offset will not be used again since further calls + // will immediately return due to `self.state != Active`. + self.offset = (memsize / 2) as u16 - self.offset; + data.split_off_mut(..count).unwrap(); + + if count < max_count { + // `data` exhausted + self.state = ReaderState::Error(Error::ReceiverError); + } else if final_ { + // Caller indicated that we're done + self.state = ReaderState::Done; + } + + debug_assert!(self.offset == 0 || self.offset as usize == memsize / 2); + } +} diff --git a/esp-hal/src/rmt/writer.rs b/esp-hal/src/rmt/writer.rs new file mode 100644 index 000000000..df6ceaa7d --- /dev/null +++ b/esp-hal/src/rmt/writer.rs @@ -0,0 +1,78 @@ +use super::{DynChannelAccess, PulseCode, Tx}; + +#[derive(PartialEq)] +pub(crate) enum WriterState { + Active, + + Done, +} + +pub(crate) struct RmtWriter { + // The position in channel RAM to continue writing at; must be either + // 0 or half the available RAM size if there's further data. + // The position may be invalid if there's no data left. + offset: u16, + + pub state: WriterState, +} + +impl RmtWriter { + pub(crate) fn new() -> Self { + Self { + offset: 0, + state: WriterState::Active, + } + } + + // Copy from `data` to the hardware buffer, advancing the `data` slice accordingly. + // + // If `initial` is set, fill the entire buffer. Otherwise, append half the buffer's length from + // `data`. + #[cfg_attr(place_rmt_driver_in_ram, ram)] + pub(crate) fn write(&mut self, data: &mut &[T], raw: DynChannelAccess, initial: bool) + where + T: Into + Copy, + { + if self.state != WriterState::Active { + return; + } + + let ram_start = raw.channel_ram_start(); + let memsize = raw.memsize().codes(); + + let max_count = if initial { memsize } else { memsize / 2 }; + let count = data.len().min(max_count); + + debug_assert!(!initial || self.offset == 0); + + let mut ptr = unsafe { ram_start.add(self.offset as usize) }; + for entry in data.iter().take(count) { + // SAFETY: The iteration `count` is smaller than `max_count` such that incrementing the + // `ptr` `count` times cannot advance further than `ram_start + memsize`. + unsafe { + ptr.write_volatile((*entry).into()); + ptr = ptr.add(1); + } + } + + // If the input data was not exhausted, update offset as + // + // | initial | offset | max_count | new offset | + // | ------- + ----------- + ----------- + ----------- | + // | true | 0 | memsize | 0 | + // | false | 0 | memsize / 2 | memsize / 2 | + // | false | memsize / 2 | memsize / 2 | 0 | + // + // Otherwise, the new position is invalid but the new slice is empty and we won't use the + // offset again. In either case, the unsigned subtraction will not underflow. + self.offset = memsize as u16 - max_count as u16 - self.offset; + + // The panic can never trigger since count <= data.len()! + data.split_off(..count).unwrap(); + if data.is_empty() { + self.state = WriterState::Done; + } + + debug_assert!(self.offset == 0 || self.offset as usize == memsize / 2); + } +} diff --git a/hil-test/tests/rmt.rs b/hil-test/tests/rmt.rs index a38c14126..d85b74530 100644 --- a/hil-test/tests/rmt.rs +++ b/hil-test/tests/rmt.rs @@ -80,6 +80,59 @@ fn generate_tx_data(write_end_marker: bool) -> [PulseCode; tx_data } +// When running this with defmt: +// - use `DEFMT_RTT_BUFFER_SIZE=32768 xtask run ...` to avoid truncated output +// - increase embedded_test's default_timeout below to avoid timeouts while printing +// Note that probe-rs reading the buffer might mess up timing-sensitive tests! +fn check_data_eq(tx: &[PulseCode], rx: &[PulseCode], tx_len: usize) { + let mut errors: usize = 0; + + for (idx, (&code_tx, &code_rx)) in core::iter::zip(tx, rx).enumerate() { + let _msg = if idx == tx_len - 1 { + // The last pulse code is the stop code, which can't be received. + "" + } else if idx == tx_len - 2 { + // The second-to-last pulse-code is the one which exceeds the idle threshold and + // should be received as stop code. + if !(code_rx.level1() == Level::High && code_rx.length1() == 0) { + errors += 1; + "rx code not a stop code!" + } else { + "" + } + } else if code_tx != code_rx { + errors += 1; + "rx/tx code mismatch!" + } else { + "" + }; + + #[cfg(feature = "defmt")] + if _msg.len() > 0 { + defmt::error!( + "loopback @ idx {}: {:?} (tx) -> {:?} (rx): {}", + idx, + code_tx, + code_rx, + _msg + ); + } else { + defmt::info!( + "loopback @ idx {}: {:?} (tx) -> {:?} (rx)", + idx, + code_tx, + code_rx, + ); + } + } + + assert_eq!( + errors, 0, + "rx/tx code mismatch at {}/{} indices", + errors, tx_len + ); +} + fn do_rmt_loopback_inner( tx_channel: Channel, rx_channel: Channel, @@ -101,9 +154,7 @@ fn do_rmt_loopback_inner( tx_transaction.wait().unwrap(); rx_transaction.wait().unwrap(); - // the last two pulse-codes are the ones which wait for the timeout so - // they can't be equal - assert_eq!(&tx_data[..TX_LEN - 2], &rcv_data[..TX_LEN - 2]); + check_data_eq(&tx_data, &rcv_data, TX_LEN); } // Run a test where some data is sent from one channel and looped back to @@ -151,9 +202,7 @@ async fn do_rmt_loopback_async(tx_memsize: u8, rx_memsize: tx_res.unwrap(); rx_res.unwrap(); - // the last two pulse-codes are the ones which wait for the timeout so - // they can't be equal - assert_eq!(&tx_data[..TX_LEN - 2], &rcv_data[..TX_LEN - 2]); + check_data_eq(&tx_data, &rcv_data, TX_LEN); } // Run a test that just sends some data, without trying to recive it.