Fail RMT one-shot transactions if end-marker is missing (#2463)

* Fail RMT one-shot transactions if end-marker is missing

* CHANGELOG.md

* Add test

* Fix

* Fix

* RMT: use u32, turn PulseCode into a convenience trait

* Clippy

* Adapt test
This commit is contained in:
Björn Quentin 2024-11-13 12:29:36 +01:00 committed by GitHub
parent 8cbc249e2e
commit 7da4444a7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 176 additions and 152 deletions

View File

@ -62,6 +62,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `slave::Spi` constructors no longer take pins (#2485) - `slave::Spi` constructors no longer take pins (#2485)
- The `I2c` master driver has been moved from `esp_hal::i2c` to `esp_hal::i2c::master`. (#2476) - The `I2c` master driver has been moved from `esp_hal::i2c` to `esp_hal::i2c::master`. (#2476)
- `I2c` SCL timeout is now defined in bus clock cycles. (#2477) - `I2c` SCL timeout is now defined in bus clock cycles. (#2477)
- Trying to send a single-shot RMT transmission will result in an error now, `RMT` deals with `u32` now, `PulseCode` is a convenience trait now (#2463)
### Fixed ### Fixed

View File

@ -368,3 +368,31 @@ If you were using an 16-bit bus, you don't need to change anything, `set_byte_or
If you were sharing the bus between an 8-bit and 16-bit device, you will have to call the corresponding method when If you were sharing the bus between an 8-bit and 16-bit device, you will have to call the corresponding method when
you switch between devices. Be sure to read the documentation of the new methods. you switch between devices. Be sure to read the documentation of the new methods.
## `rmt::Channel::transmit` now returns `Result`, `PulseCode` is now `u32`
When trying to send a one-shot transmission will fail if it doesn't end with an end-marker.
```diff
- let mut data = [PulseCode {
- level1: true,
- length1: 200,
- level2: false,
- length2: 50,
- }; 20];
-
- data[data.len() - 2] = PulseCode {
- level1: true,
- length1: 3000,
- level2: false,
- length2: 500,
- };
- data[data.len() - 1] = PulseCode::default();
+ let mut data = [PulseCode::new(true, 200, false, 50); 20];
+ data[data.len() - 2] = PulseCode::new(true, 3000, false, 500);
+ data[data.len() - 1] = PulseCode::empty();
- let transaction = channel.transmit(&data);
+ let transaction = channel.transmit(&data).unwrap();
```

View File

@ -112,61 +112,62 @@ pub enum Error {
InvalidArgument, InvalidArgument,
/// An error occurred during transmission /// An error occurred during transmission
TransmissionError, TransmissionError,
/// No transmission end marker found
EndMarkerMissing,
} }
/// Convenience representation of a pulse code entry. /// Convenience trait to work with pulse codes.
/// pub trait PulseCode: crate::private::Sealed {
/// Allows for the assignment of two levels and their lengths /// Create a new instance
#[derive(Clone, Copy, Debug, Default, PartialEq)] fn new(level1: bool, length1: u16, level2: bool, length2: u16) -> Self;
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct PulseCode { /// Create a new empty instance
fn empty() -> Self;
/// Set all levels and lengths to 0
fn reset(&mut self);
/// Logical output level in the first pulse code interval /// Logical output level in the first pulse code interval
pub level1: bool, fn level1(&self) -> bool;
/// Length of the first pulse code interval (in clock cycles) /// Length of the first pulse code interval (in clock cycles)
pub length1: u16, fn length1(&self) -> u16;
/// Logical output level in the second pulse code interval /// Logical output level in the second pulse code interval
pub level2: bool, fn level2(&self) -> bool;
/// Length of the second pulse code interval (in clock cycles) /// Length of the second pulse code interval (in clock cycles)
pub length2: u16, fn length2(&self) -> u16;
} }
impl From<u32> for PulseCode { impl PulseCode for u32 {
fn from(value: u32) -> Self { fn new(level1: bool, length1: u16, level2: bool, length2: u16) -> Self {
Self { (((level1 as u32) << 15) | length1 as u32 & 0b111_1111_1111_1111)
level1: value & (1 << 15) != 0, | (((level2 as u32) << 15) | length2 as u32 & 0b111_1111_1111_1111) << 16
length1: (value & 0b111_1111_1111_1111) as u16,
level2: value & (1 << 31) != 0,
length2: ((value >> 16) & 0b111_1111_1111_1111) as u16,
}
} }
}
/// Convert a pulse code structure into a u32 value that can be written fn empty() -> Self {
/// into the data registers 0
impl From<PulseCode> for u32 { }
#[inline(always)]
fn from(p: PulseCode) -> u32 {
// The length1 value resides in bits [14:0]
let mut entry: u32 = p.length1 as u32;
// If level1 is high, set bit 15, otherwise clear it fn reset(&mut self) {
if p.level1 { *self = 0
entry |= 1 << 15; }
} else {
entry &= !(1 << 15);
}
// If level2 is high, set bit 31, otherwise clear it fn level1(&self) -> bool {
if p.level2 { self & (1 << 15) != 0
entry |= 1 << 31; }
} else {
entry &= !(1 << 31);
}
// The length2 value resides in bits [30:16] fn length1(&self) -> u16 {
entry |= (p.length2 as u32) << 16; (self & 0b111_1111_1111_1111) as u16
}
entry fn level2(&self) -> bool {
self & (1 << 31) != 0
}
fn length2(&self) -> u16 {
((self >> 16) & 0b111_1111_1111_1111) as u16
} }
} }
@ -423,16 +424,16 @@ where
} }
/// An in-progress transaction for a single shot TX transaction. /// An in-progress transaction for a single shot TX transaction.
pub struct SingleShotTxTransaction<'a, C, T: Into<u32> + Copy> pub struct SingleShotTxTransaction<'a, C>
where where
C: TxChannel, C: TxChannel,
{ {
channel: C, channel: C,
index: usize, index: usize,
data: &'a [T], data: &'a [u32],
} }
impl<C, T: Into<u32> + Copy> SingleShotTxTransaction<'_, C, T> impl<C> SingleShotTxTransaction<'_, C>
where where
C: TxChannel, C: TxChannel,
{ {
@ -466,7 +467,7 @@ where
.enumerate() .enumerate()
{ {
unsafe { unsafe {
ptr.add(idx).write_volatile((*entry).into()); ptr.add(idx).write_volatile(*entry);
} }
} }
@ -982,26 +983,23 @@ pub trait TxChannel: TxChannelInternal<Blocking> {
/// This returns a [`SingleShotTxTransaction`] which can be used to wait for /// This returns a [`SingleShotTxTransaction`] which can be used to wait for
/// the transaction to complete and get back the channel for further /// the transaction to complete and get back the channel for further
/// use. /// use.
fn transmit<T: Into<u32> + Copy>(self, data: &[T]) -> SingleShotTxTransaction<'_, Self, T> fn transmit(self, data: &[u32]) -> Result<SingleShotTxTransaction<'_, Self>, Error>
where where
Self: Sized, Self: Sized,
{ {
let index = Self::send_raw(data, false, 0); let index = Self::send_raw(data, false, 0)?;
SingleShotTxTransaction { Ok(SingleShotTxTransaction {
channel: self, channel: self,
index, index,
data, data,
} })
} }
/// Start transmitting the given pulse code continuously. /// Start transmitting the given pulse code continuously.
/// This returns a [`ContinuousTxTransaction`] which can be used to stop the /// This returns a [`ContinuousTxTransaction`] which can be used to stop the
/// ongoing transmission and get back the channel for further use. /// ongoing transmission and get back the channel for further use.
/// The length of sequence cannot exceed the size of the allocated RMT RAM. /// The length of sequence cannot exceed the size of the allocated RMT RAM.
fn transmit_continuously<T: Into<u32> + Copy>( fn transmit_continuously(self, data: &[u32]) -> Result<ContinuousTxTransaction<Self>, Error>
self,
data: &[T],
) -> Result<ContinuousTxTransaction<Self>, Error>
where where
Self: Sized, Self: Sized,
{ {
@ -1011,10 +1009,10 @@ pub trait TxChannel: TxChannelInternal<Blocking> {
/// Like [`Self::transmit_continuously`] but also sets a loop count. /// Like [`Self::transmit_continuously`] but also sets a loop count.
/// [`ContinuousTxTransaction`] can be used to check if the loop count is /// [`ContinuousTxTransaction`] can be used to check if the loop count is
/// reached. /// reached.
fn transmit_continuously_with_loopcount<T: Into<u32> + Copy>( fn transmit_continuously_with_loopcount(
self, self,
loopcount: u16, loopcount: u16,
data: &[T], data: &[u32],
) -> Result<ContinuousTxTransaction<Self>, Error> ) -> Result<ContinuousTxTransaction<Self>, Error>
where where
Self: Sized, Self: Sized,
@ -1023,21 +1021,21 @@ pub trait TxChannel: TxChannelInternal<Blocking> {
return Err(Error::Overflow); return Err(Error::Overflow);
} }
let _index = Self::send_raw(data, true, loopcount); let _index = Self::send_raw(data, true, loopcount)?;
Ok(ContinuousTxTransaction { channel: self }) Ok(ContinuousTxTransaction { channel: self })
} }
} }
/// RX transaction instance /// RX transaction instance
pub struct RxTransaction<'a, C, T: From<u32> + Copy> pub struct RxTransaction<'a, C>
where where
C: RxChannel, C: RxChannel,
{ {
channel: C, channel: C,
data: &'a mut [T], data: &'a mut [u32],
} }
impl<C, T: From<u32> + Copy> RxTransaction<'_, C, T> impl<C> RxTransaction<'_, C>
where where
C: RxChannel, C: RxChannel,
{ {
@ -1062,7 +1060,7 @@ where
as *mut u32; as *mut u32;
let len = self.data.len(); let len = self.data.len();
for (idx, entry) in self.data.iter_mut().take(len).enumerate() { for (idx, entry) in self.data.iter_mut().take(len).enumerate() {
*entry = unsafe { ptr.add(idx).read_volatile().into() }; *entry = unsafe { ptr.add(idx).read_volatile() };
} }
Ok(self.channel) Ok(self.channel)
@ -1075,10 +1073,7 @@ pub trait RxChannel: RxChannelInternal<Blocking> {
/// This returns a [RxTransaction] which can be used to wait for receive to /// This returns a [RxTransaction] which can be used to wait for receive to
/// complete and get back the channel for further use. /// complete and get back the channel for further use.
/// The length of the received data cannot exceed the allocated RMT RAM. /// The length of the received data cannot exceed the allocated RMT RAM.
fn receive<T: From<u32> + Copy>( fn receive(self, data: &mut [u32]) -> Result<RxTransaction<'_, Self>, Error>
self,
data: &mut [T],
) -> Result<RxTransaction<'_, Self, T>, Error>
where where
Self: Sized, Self: Sized,
{ {
@ -1143,7 +1138,7 @@ pub trait TxChannelAsync: TxChannelInternal<Async> {
/// Start transmitting the given pulse code sequence. /// Start transmitting the given pulse code sequence.
/// The length of sequence cannot exceed the size of the allocated RMT /// The length of sequence cannot exceed the size of the allocated RMT
/// RAM. /// RAM.
async fn transmit<'a, T: Into<u32> + Copy>(&mut self, data: &'a [T]) -> Result<(), Error> async fn transmit<'a>(&mut self, data: &'a [u32]) -> Result<(), Error>
where where
Self: Sized, Self: Sized,
{ {
@ -1154,7 +1149,7 @@ pub trait TxChannelAsync: TxChannelInternal<Async> {
Self::clear_interrupts(); Self::clear_interrupts();
Self::listen_interrupt(Event::End); Self::listen_interrupt(Event::End);
Self::listen_interrupt(Event::Error); Self::listen_interrupt(Event::Error);
Self::send_raw(data, false, 0); Self::send_raw(data, false, 0)?;
RmtTxFuture::new(self).await; RmtTxFuture::new(self).await;
@ -1402,9 +1397,17 @@ where
fn is_loopcount_interrupt_set() -> bool; fn is_loopcount_interrupt_set() -> bool;
fn send_raw<T: Into<u32> + Copy>(data: &[T], continuous: bool, repeat: u16) -> usize { fn send_raw(data: &[u32], continuous: bool, repeat: u16) -> Result<usize, Error> {
Self::clear_interrupts(); Self::clear_interrupts();
if let Some(last) = data.last() {
if !continuous && last.length2() != 0 && last.length1() != 0 {
return Err(Error::EndMarkerMissing);
}
} else {
return Err(Error::InvalidArgument);
}
let ptr = (constants::RMT_RAM_START let ptr = (constants::RMT_RAM_START
+ Self::CHANNEL as usize * constants::RMT_CHANNEL_RAM_SIZE * 4) + Self::CHANNEL as usize * constants::RMT_CHANNEL_RAM_SIZE * 4)
as *mut u32; as *mut u32;
@ -1414,7 +1417,7 @@ where
.enumerate() .enumerate()
{ {
unsafe { unsafe {
ptr.add(idx).write_volatile((*entry).into()); ptr.add(idx).write_volatile(*entry);
} }
} }
@ -1428,9 +1431,9 @@ where
Self::update(); Self::update();
if data.len() >= constants::RMT_CHANNEL_RAM_SIZE { if data.len() >= constants::RMT_CHANNEL_RAM_SIZE {
constants::RMT_CHANNEL_RAM_SIZE Ok(constants::RMT_CHANNEL_RAM_SIZE)
} else { } else {
data.len() Ok(data.len())
} }
} }

View File

@ -73,46 +73,41 @@ async fn main(spawner: Spawner) {
.spawn(signal_task(Output::new(peripherals.GPIO5, Level::Low))) .spawn(signal_task(Output::new(peripherals.GPIO5, Level::Low)))
.unwrap(); .unwrap();
let mut data = [PulseCode { let mut data: [u32; 48] = [PulseCode::empty(); 48];
level1: true,
length1: 1,
level2: false,
length2: 1,
}; 48];
loop { loop {
println!("receive"); println!("receive");
channel.receive(&mut data).await.unwrap(); channel.receive(&mut data).await.unwrap();
let mut total = 0usize; let mut total = 0usize;
for entry in &data[..data.len()] { for entry in &data[..data.len()] {
if entry.length1 == 0 { if entry.length1() == 0 {
break; break;
} }
total += entry.length1 as usize; total += entry.length1() as usize;
if entry.length2 == 0 { if entry.length2() == 0 {
break; break;
} }
total += entry.length2 as usize; total += entry.length2() as usize;
} }
for entry in &data[..data.len()] { for entry in &data[..data.len()] {
if entry.length1 == 0 { if entry.length1() == 0 {
break; break;
} }
let count = WIDTH / (total / entry.length1 as usize); let count = WIDTH / (total / entry.length1() as usize);
let c = if entry.level1 { '-' } else { '_' }; let c = if entry.level1() { '-' } else { '_' };
for _ in 0..count + 1 { for _ in 0..count + 1 {
print!("{}", c); print!("{}", c);
} }
if entry.length2 == 0 { if entry.length2() == 0 {
break; break;
} }
let count = WIDTH / (total / entry.length2 as usize); let count = WIDTH / (total / entry.length2() as usize);
let c = if entry.level2 { '-' } else { '_' }; let c = if entry.level2() { '-' } else { '_' };
for _ in 0..count + 1 { for _ in 0..count + 1 {
print!("{}", c); print!("{}", c);
} }

View File

@ -50,20 +50,10 @@ async fn main(_spawner: Spawner) {
) )
.unwrap(); .unwrap();
let mut data = [PulseCode { let mut data = [PulseCode::new(true, 200, false, 50); 20];
level1: true,
length1: 200,
level2: false,
length2: 50,
}; 20];
data[data.len() - 2] = PulseCode { data[data.len() - 2] = PulseCode::new(true, 3000, false, 500);
level1: true, data[data.len() - 1] = PulseCode::empty();
length1: 3000,
level2: false,
length2: 500,
};
data[data.len() - 1] = PulseCode::default();
loop { loop {
println!("transmit"); println!("transmit");

View File

@ -56,17 +56,11 @@ fn main() -> ! {
let delay = Delay::new(); let delay = Delay::new();
let mut data = [PulseCode { let mut data: [u32; 48] = [PulseCode::empty(); 48];
level1: true,
length1: 1,
level2: false,
length2: 1,
}; 48];
loop { loop {
for x in data.iter_mut() { for x in data.iter_mut() {
x.length1 = 0; x.reset()
x.length2 = 0;
} }
let transaction = channel.receive(&mut data).unwrap(); let transaction = channel.receive(&mut data).unwrap();
@ -84,34 +78,34 @@ fn main() -> ! {
channel = channel_res; channel = channel_res;
let mut total = 0usize; let mut total = 0usize;
for entry in &data[..data.len()] { for entry in &data[..data.len()] {
if entry.length1 == 0 { if entry.length1() == 0 {
break; break;
} }
total += entry.length1 as usize; total += entry.length1() as usize;
if entry.length2 == 0 { if entry.length2() == 0 {
break; break;
} }
total += entry.length2 as usize; total += entry.length2() as usize;
} }
for entry in &data[..data.len()] { for entry in &data[..data.len()] {
if entry.length1 == 0 { if entry.length1() == 0 {
break; break;
} }
let count = WIDTH / (total / entry.length1 as usize); let count = WIDTH / (total / entry.length1() as usize);
let c = if entry.level1 { '-' } else { '_' }; let c = if entry.level1() { '-' } else { '_' };
for _ in 0..count + 1 { for _ in 0..count + 1 {
print!("{}", c); print!("{}", c);
} }
if entry.length2 == 0 { if entry.length2() == 0 {
break; break;
} }
let count = WIDTH / (total / entry.length2 as usize); let count = WIDTH / (total / entry.length2() as usize);
let c = if entry.level2 { '-' } else { '_' }; let c = if entry.level2() { '-' } else { '_' };
for _ in 0..count + 1 { for _ in 0..count + 1 {
print!("{}", c); print!("{}", c);
} }

View File

@ -43,23 +43,12 @@ fn main() -> ! {
let delay = Delay::new(); let delay = Delay::new();
let mut data = [PulseCode { let mut data = [PulseCode::new(true, 200, false, 50); 20];
level1: true, data[data.len() - 2] = PulseCode::new(true, 3000, false, 500);
length1: 200, data[data.len() - 1] = PulseCode::empty();
level2: false,
length2: 50,
}; 20];
data[data.len() - 2] = PulseCode {
level1: true,
length1: 3000,
level2: false,
length2: 500,
};
data[data.len() - 1] = PulseCode::default();
loop { loop {
let transaction = channel.transmit(&data); let transaction = channel.transmit(&data).unwrap();
channel = transaction.wait().unwrap(); channel = transaction.wait().unwrap();
delay.delay_millis(500); delay.delay_millis(500);
} }

View File

@ -14,6 +14,8 @@ use hil_test as _;
#[cfg(test)] #[cfg(test)]
#[embedded_test::tests] #[embedded_test::tests]
mod tests { mod tests {
use esp_hal::rmt::Error;
use super::*; use super::*;
#[init] #[init]
@ -76,30 +78,15 @@ mod tests {
} }
} }
let mut tx_data = [PulseCode { let mut tx_data = [PulseCode::new(true, 200, false, 50); 20];
level1: true,
length1: 200,
level2: false,
length2: 50,
}; 20];
tx_data[tx_data.len() - 2] = PulseCode { tx_data[tx_data.len() - 2] = PulseCode::new(true, 3000, false, 500);
level1: true, tx_data[tx_data.len() - 1] = PulseCode::empty();
length1: 3000,
level2: false,
length2: 500,
};
tx_data[tx_data.len() - 1] = PulseCode::default();
let mut rcv_data = [PulseCode { let mut rcv_data: [u32; 20] = [PulseCode::empty(); 20];
level1: false,
length1: 0,
level2: false,
length2: 0,
}; 20];
let rx_transaction = rx_channel.receive(&mut rcv_data).unwrap(); let rx_transaction = rx_channel.receive(&mut rcv_data).unwrap();
let tx_transaction = tx_channel.transmit(&tx_data); let tx_transaction = tx_channel.transmit(&tx_data).unwrap();
rx_transaction.wait().unwrap(); rx_transaction.wait().unwrap();
tx_transaction.wait().unwrap(); tx_transaction.wait().unwrap();
@ -108,4 +95,41 @@ mod tests {
// they can't be equal // they can't be equal
assert_eq!(&tx_data[..18], &rcv_data[..18]); assert_eq!(&tx_data[..18], &rcv_data[..18]);
} }
#[test]
#[timeout(1)]
fn rmt_single_shot_fails_without_end_marker() {
let peripherals = esp_hal::init(esp_hal::Config::default());
let io = Io::new(peripherals.GPIO, peripherals.IO_MUX);
cfg_if::cfg_if! {
if #[cfg(feature = "esp32h2")] {
let freq = 32.MHz();
} else {
let freq = 80.MHz();
}
};
let rmt = Rmt::new(peripherals.RMT, freq).unwrap();
let (_, tx) = hil_test::common_test_pins!(io);
let tx_config = TxChannelConfig {
clk_divider: 255,
..TxChannelConfig::default()
};
let tx_channel = {
use esp_hal::rmt::TxChannelCreator;
rmt.channel0.configure(tx, tx_config).unwrap()
};
let tx_data = [PulseCode::new(true, 200, false, 50); 20];
let tx_transaction = tx_channel.transmit(&tx_data);
assert!(tx_transaction.is_err());
assert!(matches!(tx_transaction, Err(Error::EndMarkerMissing)));
}
} }