From f5658d6833cb140296a0b6f25b7eb6d16f06c520 Mon Sep 17 00:00:00 2001 From: Corey Schuhen Date: Thu, 29 May 2025 20:48:40 +1000 Subject: [PATCH] Put State inside a critical section mutex of RefCell. This removes the unsound code that was giving out mut&. to State --- embassy-stm32/src/can/fdcan.rs | 254 +++++++++++++++------------------ 1 file changed, 112 insertions(+), 142 deletions(-) diff --git a/embassy-stm32/src/can/fdcan.rs b/embassy-stm32/src/can/fdcan.rs index a7815f79b..4495280c4 100644 --- a/embassy-stm32/src/can/fdcan.rs +++ b/embassy-stm32/src/can/fdcan.rs @@ -52,36 +52,39 @@ impl interrupt::typelevel::Handler for IT0Interrup regs.ir().write(|w| w.set_tefn(true)); } - match &T::state().tx_mode { - TxMode::NonBuffered(waker) => waker.wake(), - TxMode::ClassicBuffered(buf) => { - if !T::registers().tx_queue_is_full() { - match buf.tx_receiver.try_receive() { - Ok(frame) => { - _ = T::registers().write(&frame); + T::info().state.lock(|s| { + let state = s.borrow_mut(); + match &state.tx_mode { + TxMode::NonBuffered(waker) => waker.wake(), + TxMode::ClassicBuffered(buf) => { + if !T::registers().tx_queue_is_full() { + match buf.tx_receiver.try_receive() { + Ok(frame) => { + _ = T::registers().write(&frame); + } + Err(_) => {} + } + } + } + TxMode::FdBuffered(buf) => { + if !T::registers().tx_queue_is_full() { + match buf.tx_receiver.try_receive() { + Ok(frame) => { + _ = T::registers().write(&frame); + } + Err(_) => {} } - Err(_) => {} } } } - TxMode::FdBuffered(buf) => { - if !T::registers().tx_queue_is_full() { - match buf.tx_receiver.try_receive() { - Ok(frame) => { - _ = T::registers().write(&frame); - } - Err(_) => {} - } - } - } - } - if ir.rfn(0) { - T::state().rx_mode.on_interrupt::(0); - } - if ir.rfn(1) { - T::state().rx_mode.on_interrupt::(1); - } + if ir.rfn(0) { + state.rx_mode.on_interrupt::(0, state.ns_per_timer_tick); + } + if ir.rfn(1) { + state.rx_mode.on_interrupt::(1, state.ns_per_timer_tick); + } + }); if ir.bo() { regs.ir().write(|w| w.set_bo(true)); @@ -165,7 +168,6 @@ pub struct CanConfigurator<'d> { _phantom: PhantomData<&'d ()>, config: crate::can::fd::config::FdCanConfig, info: &'static Info, - state: &'static State, /// Reference to internals. properties: Properties, periph_clock: crate::time::Hertz, @@ -188,9 +190,10 @@ impl<'d> CanConfigurator<'d> { rcc::enable_and_reset::(); let info = T::info(); - let state = unsafe { T::mut_state() }; - state.tx_pin_port = Some(tx.pin_port()); - state.rx_pin_port = Some(rx.pin_port()); + T::info().state.lock(|s| { + s.borrow_mut().tx_pin_port = Some(tx.pin_port()); + s.borrow_mut().rx_pin_port = Some(rx.pin_port()); + }); (info.internal_operation)(InternalOperation::NotifySenderCreated); (info.internal_operation)(InternalOperation::NotifyReceiverCreated); @@ -209,7 +212,6 @@ impl<'d> CanConfigurator<'d> { _phantom: PhantomData, config, info, - state, properties: Properties::new(T::info()), periph_clock: T::frequency(), } @@ -261,12 +263,8 @@ impl<'d> CanConfigurator<'d> { /// Start in mode. pub fn start(self, mode: OperatingMode) -> Can<'d> { let ns_per_timer_tick = calc_ns_per_timer_tick(self.info, self.periph_clock, self.config.frame_transmit); - critical_section::with(|_| { - let state = self.state as *const State; - unsafe { - let mut_state = state as *mut State; - (*mut_state).ns_per_timer_tick = ns_per_timer_tick; - } + self.info.state.lock(|s| { + s.borrow_mut().ns_per_timer_tick = ns_per_timer_tick; }); self.info.regs.into_mode(self.config, mode); (self.info.internal_operation)(InternalOperation::NotifySenderCreated); @@ -275,7 +273,6 @@ impl<'d> CanConfigurator<'d> { _phantom: PhantomData, config: self.config, info: self.info, - state: self.state, _mode: mode, properties: Properties::new(self.info), } @@ -309,7 +306,6 @@ pub struct Can<'d> { _phantom: PhantomData<&'d ()>, config: crate::can::fd::config::FdCanConfig, info: &'static Info, - state: &'static State, _mode: OperatingMode, properties: Properties, } @@ -323,7 +319,9 @@ impl<'d> Can<'d> { /// Flush one of the TX mailboxes. pub async fn flush(&self, idx: usize) { poll_fn(|cx| { - self.state.tx_mode.register(cx.waker()); + self.info.state.lock(|s| { + s.borrow_mut().tx_mode.register(cx.waker()); + }); if idx > 3 { panic!("Bad mailbox"); @@ -343,12 +341,12 @@ impl<'d> Can<'d> { /// can be replaced, this call asynchronously waits for a frame to be successfully /// transmitted, then tries again. pub async fn write(&mut self, frame: &Frame) -> Option { - self.state.tx_mode.write(self.info, frame).await + TxMode::write(self.info, frame).await } /// Returns the next received message frame pub async fn read(&mut self) -> Result { - self.state.rx_mode.read_classic(self.info, self.state).await + RxMode::read_classic(self.info).await } /// Queues the message to be sent but exerts backpressure. If a lower-priority @@ -356,12 +354,12 @@ impl<'d> Can<'d> { /// can be replaced, this call asynchronously waits for a frame to be successfully /// transmitted, then tries again. pub async fn write_fd(&mut self, frame: &FdFrame) -> Option { - self.state.tx_mode.write_fd(self.info, frame).await + TxMode::write_fd(self.info, frame).await } /// Returns the next received message frame pub async fn read_fd(&mut self) -> Result { - self.state.rx_mode.read_fd(self.info, self.state).await + RxMode::read_fd(self.info).await } /// Split instance into separate portions: Tx(write), Rx(read), common properties @@ -372,14 +370,12 @@ impl<'d> Can<'d> { CanTx { _phantom: PhantomData, info: self.info, - state: self.state, config: self.config, _mode: self._mode, }, CanRx { _phantom: PhantomData, info: self.info, - state: self.state, _mode: self._mode, }, Properties { @@ -395,7 +391,6 @@ impl<'d> Can<'d> { _phantom: PhantomData, config: tx.config, info: tx.info, - state: tx.state, _mode: rx._mode, properties: Properties::new(tx.info), } @@ -407,7 +402,7 @@ impl<'d> Can<'d> { tx_buf: &'static mut TxBuf, rxb: &'static mut RxBuf, ) -> BufferedCan<'d, TX_BUF_SIZE, RX_BUF_SIZE> { - BufferedCan::new(self.info, self.state, self._mode, tx_buf, rxb) + BufferedCan::new(self.info, self._mode, tx_buf, rxb) } /// Return a buffered instance of driver with CAN FD support. User must supply Buffers @@ -416,7 +411,7 @@ impl<'d> Can<'d> { tx_buf: &'static mut TxFdBuf, rxb: &'static mut RxFdBuf, ) -> BufferedCanFd<'d, TX_BUF_SIZE, RX_BUF_SIZE> { - BufferedCanFd::new(self.info, self.state, self._mode, tx_buf, rxb) + BufferedCanFd::new(self.info, self._mode, tx_buf, rxb) } } @@ -437,7 +432,6 @@ pub type TxBuf = Channel { _phantom: PhantomData<&'d ()>, info: &'static Info, - state: &'static State, _mode: OperatingMode, tx_buf: &'static TxBuf, rx_buf: &'static RxBuf, @@ -447,7 +441,6 @@ pub struct BufferedCan<'d, const TX_BUF_SIZE: usize, const RX_BUF_SIZE: usize> { impl<'c, 'd, const TX_BUF_SIZE: usize, const RX_BUF_SIZE: usize> BufferedCan<'d, TX_BUF_SIZE, RX_BUF_SIZE> { fn new( info: &'static Info, - state: &'static State, _mode: OperatingMode, tx_buf: &'static TxBuf, rx_buf: &'static RxBuf, @@ -457,7 +450,6 @@ impl<'c, 'd, const TX_BUF_SIZE: usize, const RX_BUF_SIZE: usize> BufferedCan<'d, BufferedCan { _phantom: PhantomData, info, - state, _mode, tx_buf, rx_buf, @@ -473,19 +465,15 @@ impl<'c, 'd, const TX_BUF_SIZE: usize, const RX_BUF_SIZE: usize> BufferedCan<'d, fn setup(self) -> Self { // We don't want interrupts being processed while we change modes. - critical_section::with(|_| { + self.info.state.lock(|s| { let rx_inner = super::common::ClassicBufferedRxInner { rx_sender: self.rx_buf.sender().into(), }; let tx_inner = super::common::ClassicBufferedTxInner { tx_receiver: self.tx_buf.receiver().into(), }; - let state = self.state as *const State; - unsafe { - let mut_state = state as *mut State; - (*mut_state).rx_mode = RxMode::ClassicBuffered(rx_inner); - (*mut_state).tx_mode = TxMode::ClassicBuffered(tx_inner); - } + s.borrow_mut().rx_mode = RxMode::ClassicBuffered(rx_inner); + s.borrow_mut().tx_mode = TxMode::ClassicBuffered(tx_inner); }); self } @@ -545,7 +533,6 @@ pub type BufferedFdCanReceiver = super::common::BufferedReceiver<'static, FdEnve pub struct BufferedCanFd<'d, const TX_BUF_SIZE: usize, const RX_BUF_SIZE: usize> { _phantom: PhantomData<&'d ()>, info: &'static Info, - state: &'static State, _mode: OperatingMode, tx_buf: &'static TxFdBuf, rx_buf: &'static RxFdBuf, @@ -555,7 +542,6 @@ pub struct BufferedCanFd<'d, const TX_BUF_SIZE: usize, const RX_BUF_SIZE: usize> impl<'c, 'd, const TX_BUF_SIZE: usize, const RX_BUF_SIZE: usize> BufferedCanFd<'d, TX_BUF_SIZE, RX_BUF_SIZE> { fn new( info: &'static Info, - state: &'static State, _mode: OperatingMode, tx_buf: &'static TxFdBuf, rx_buf: &'static RxFdBuf, @@ -565,7 +551,6 @@ impl<'c, 'd, const TX_BUF_SIZE: usize, const RX_BUF_SIZE: usize> BufferedCanFd<' BufferedCanFd { _phantom: PhantomData, info, - state, _mode, tx_buf, rx_buf, @@ -581,19 +566,15 @@ impl<'c, 'd, const TX_BUF_SIZE: usize, const RX_BUF_SIZE: usize> BufferedCanFd<' fn setup(self) -> Self { // We don't want interrupts being processed while we change modes. - critical_section::with(|_| { + self.info.state.lock(|s| { let rx_inner = super::common::FdBufferedRxInner { rx_sender: self.rx_buf.sender().into(), }; let tx_inner = super::common::FdBufferedTxInner { tx_receiver: self.tx_buf.receiver().into(), }; - let state = self.state as *const State; - unsafe { - let mut_state = state as *mut State; - (*mut_state).rx_mode = RxMode::FdBuffered(rx_inner); - (*mut_state).tx_mode = TxMode::FdBuffered(tx_inner); - } + s.borrow_mut().rx_mode = RxMode::FdBuffered(rx_inner); + s.borrow_mut().tx_mode = TxMode::FdBuffered(tx_inner); }); self } @@ -641,19 +622,18 @@ impl<'c, 'd, const TX_BUF_SIZE: usize, const RX_BUF_SIZE: usize> Drop for Buffer pub struct CanRx<'d> { _phantom: PhantomData<&'d ()>, info: &'static Info, - state: &'static State, _mode: OperatingMode, } impl<'d> CanRx<'d> { /// Returns the next received message frame pub async fn read(&mut self) -> Result { - self.state.rx_mode.read_classic(&self.info, &self.state).await + RxMode::read_classic(&self.info).await } /// Returns the next received message frame pub async fn read_fd(&mut self) -> Result { - self.state.rx_mode.read_fd(&self.info, &self.state).await + RxMode::read_fd(&self.info).await } } @@ -667,7 +647,6 @@ impl<'d> Drop for CanRx<'d> { pub struct CanTx<'d> { _phantom: PhantomData<&'d ()>, info: &'static Info, - state: &'static State, config: crate::can::fd::config::FdCanConfig, _mode: OperatingMode, } @@ -678,7 +657,7 @@ impl<'c, 'd> CanTx<'d> { /// can be replaced, this call asynchronously waits for a frame to be successfully /// transmitted, then tries again. pub async fn write(&mut self, frame: &Frame) -> Option { - self.state.tx_mode.write(self.info, frame).await + TxMode::write(self.info, frame).await } /// Queues the message to be sent but exerts backpressure. If a lower-priority @@ -686,7 +665,7 @@ impl<'c, 'd> CanTx<'d> { /// can be replaced, this call asynchronously waits for a frame to be successfully /// transmitted, then tries again. pub async fn write_fd(&mut self, frame: &FdFrame) -> Option { - self.state.tx_mode.write_fd(self.info, frame).await + TxMode::write_fd(self.info, frame).await } } @@ -712,19 +691,19 @@ impl RxMode { } } - fn on_interrupt(&self, fifonr: usize) { + fn on_interrupt(&self, fifonr: usize, ns_per_timer_tick: u64) { T::registers().regs.ir().write(|w| w.set_rfn(fifonr, true)); match self { RxMode::NonBuffered(waker) => { waker.wake(); } RxMode::ClassicBuffered(buf) => { - if let Some(result) = self.try_read::() { + if let Some(result) = self.try_read::(ns_per_timer_tick) { let _ = buf.rx_sender.try_send(result); } } RxMode::FdBuffered(buf) => { - if let Some(result) = self.try_read_fd::() { + if let Some(result) = self.try_read_fd::(ns_per_timer_tick) { let _ = buf.rx_sender.try_send(result); } } @@ -732,12 +711,12 @@ impl RxMode { } //async fn read_classic(&self) -> Result { - fn try_read(&self) -> Option> { + fn try_read(&self, ns_per_timer_tick: u64) -> Option> { if let Some((frame, ts)) = T::registers().read(0) { - let ts = T::calc_timestamp(T::state().ns_per_timer_tick, ts); + let ts = T::calc_timestamp(ns_per_timer_tick, ts); Some(Ok(Envelope { ts, frame })) } else if let Some((frame, ts)) = T::registers().read(1) { - let ts = T::calc_timestamp(T::state().ns_per_timer_tick, ts); + let ts = T::calc_timestamp(ns_per_timer_tick, ts); Some(Ok(Envelope { ts, frame })) } else if let Some(err) = T::registers().curr_error() { // TODO: this is probably wrong @@ -747,12 +726,12 @@ impl RxMode { } } - fn try_read_fd(&self) -> Option> { + fn try_read_fd(&self, ns_per_timer_tick: u64) -> Option> { if let Some((frame, ts)) = T::registers().read(0) { - let ts = T::calc_timestamp(T::state().ns_per_timer_tick, ts); + let ts = T::calc_timestamp(ns_per_timer_tick, ts); Some(Ok(FdEnvelope { ts, frame })) } else if let Some((frame, ts)) = T::registers().read(1) { - let ts = T::calc_timestamp(T::state().ns_per_timer_tick, ts); + let ts = T::calc_timestamp(ns_per_timer_tick, ts); Some(Ok(FdEnvelope { ts, frame })) } else if let Some(err) = T::registers().curr_error() { // TODO: this is probably wrong @@ -762,16 +741,12 @@ impl RxMode { } } - fn read( - &self, - info: &'static Info, - state: &'static State, - ) -> Option> { + fn read(info: &'static Info, ns_per_timer_tick: u64) -> Option> { if let Some((msg, ts)) = info.regs.read(0) { - let ts = info.calc_timestamp(state.ns_per_timer_tick, ts); + let ts = info.calc_timestamp(ns_per_timer_tick, ts); Some(Ok((msg, ts))) } else if let Some((msg, ts)) = info.regs.read(1) { - let ts = info.calc_timestamp(state.ns_per_timer_tick, ts); + let ts = info.calc_timestamp(ns_per_timer_tick, ts); Some(Ok((msg, ts))) } else if let Some(err) = info.regs.curr_error() { // TODO: this is probably wrong @@ -781,16 +756,15 @@ impl RxMode { } } - async fn read_async( - &self, - info: &'static Info, - state: &'static State, - ) -> Result<(F, Timestamp), BusError> { - //let _ = self.read::(info, state); + async fn read_async(info: &'static Info) -> Result<(F, Timestamp), BusError> { poll_fn(move |cx| { - state.err_waker.register(cx.waker()); - self.register(cx.waker()); - match self.read::<_>(info, state) { + let ns_per_timer_tick = info.state.lock(|s| { + let state = s.borrow_mut(); + state.err_waker.register(cx.waker()); + state.rx_mode.register(cx.waker()); + state.ns_per_timer_tick + }); + match RxMode::read::<_>(info, ns_per_timer_tick) { Some(result) => Poll::Ready(result), None => Poll::Pending, } @@ -798,15 +772,15 @@ impl RxMode { .await } - async fn read_classic(&self, info: &'static Info, state: &'static State) -> Result { - match self.read_async::<_>(info, state).await { + async fn read_classic(info: &'static Info) -> Result { + match RxMode::read_async::<_>(info).await { Ok((frame, ts)) => Ok(Envelope { ts, frame }), Err(e) => Err(e), } } - async fn read_fd(&self, info: &'static Info, state: &'static State) -> Result { - match self.read_async::<_>(info, state).await { + async fn read_fd(info: &'static Info) -> Result { + match RxMode::read_async::<_>(info).await { Ok((frame, ts)) => Ok(FdEnvelope { ts, frame }), Err(e) => Err(e), } @@ -835,9 +809,11 @@ impl TxMode { /// frame is dropped from the mailbox, it is returned. If no lower-priority frames /// can be replaced, this call asynchronously waits for a frame to be successfully /// transmitted, then tries again. - async fn write_generic(&self, info: &'static Info, frame: &F) -> Option { + async fn write_generic(info: &'static Info, frame: &F) -> Option { poll_fn(|cx| { - self.register(cx.waker()); + info.state.lock(|s| { + s.borrow_mut().tx_mode.register(cx.waker()); + }); if let Ok(dropped) = info.regs.write(frame) { return Poll::Ready(dropped); @@ -854,16 +830,16 @@ impl TxMode { /// frame is dropped from the mailbox, it is returned. If no lower-priority frames /// can be replaced, this call asynchronously waits for a frame to be successfully /// transmitted, then tries again. - async fn write(&self, info: &'static Info, frame: &Frame) -> Option { - self.write_generic::<_>(info, frame).await + async fn write(info: &'static Info, frame: &Frame) -> Option { + TxMode::write_generic::<_>(info, frame).await } /// Queues the message to be sent but exerts backpressure. If a lower-priority /// frame is dropped from the mailbox, it is returned. If no lower-priority frames /// can be replaced, this call asynchronously waits for a frame to be successfully /// transmitted, then tries again. - async fn write_fd(&self, info: &'static Info, frame: &FdFrame) -> Option { - self.write_generic::<_>(info, frame).await + async fn write_fd(info: &'static Info, frame: &FdFrame) -> Option { + TxMode::write_generic::<_>(info, frame).await } } @@ -961,12 +937,14 @@ impl State { } } +type SharedState = embassy_sync::blocking_mutex::Mutex>; struct Info { regs: Registers, interrupt0: crate::interrupt::Interrupt, _interrupt1: crate::interrupt::Interrupt, tx_waker: fn(), internal_operation: fn(InternalOperation), + state: SharedState, } impl Info { @@ -993,8 +971,6 @@ trait SealedInstance { fn info() -> &'static Info; fn registers() -> crate::can::fd::peripheral::Registers; - fn state() -> &'static State; - unsafe fn mut_state() -> &'static mut State; fn internal_operation(val: InternalOperation); fn calc_timestamp(ns_per_timer_tick: u64, ts_val: u16) -> Timestamp; } @@ -1019,32 +995,30 @@ macro_rules! impl_fdcan { const MSG_RAM_OFFSET: usize = $msg_ram_offset; fn internal_operation(val: InternalOperation) { - critical_section::with(|_| { - //let state = self.state as *const State; - unsafe { - //let mut_state = state as *mut State; - let mut_state = peripherals::$inst::mut_state(); - match val { - InternalOperation::NotifySenderCreated => { - mut_state.sender_instance_count += 1; - } - InternalOperation::NotifySenderDestroyed => { - mut_state.sender_instance_count -= 1; - if ( 0 == mut_state.sender_instance_count) { - (*mut_state).tx_mode = TxMode::NonBuffered(embassy_sync::waitqueue::AtomicWaker::new()); - } - } - InternalOperation::NotifyReceiverCreated => { - mut_state.receiver_instance_count += 1; - } - InternalOperation::NotifyReceiverDestroyed => { - mut_state.receiver_instance_count -= 1; - if ( 0 == mut_state.receiver_instance_count) { - (*mut_state).rx_mode = RxMode::NonBuffered(embassy_sync::waitqueue::AtomicWaker::new()); - } + peripherals::$inst::info().state.lock(|s| { + let mut mut_state = s.borrow_mut(); + match val { + InternalOperation::NotifySenderCreated => { + mut_state.sender_instance_count += 1; + } + InternalOperation::NotifySenderDestroyed => { + mut_state.sender_instance_count -= 1; + if ( 0 == mut_state.sender_instance_count) { + (*mut_state).tx_mode = TxMode::NonBuffered(embassy_sync::waitqueue::AtomicWaker::new()); } } - if mut_state.sender_instance_count == 0 && mut_state.receiver_instance_count == 0 { + InternalOperation::NotifyReceiverCreated => { + mut_state.receiver_instance_count += 1; + } + InternalOperation::NotifyReceiverDestroyed => { + mut_state.receiver_instance_count -= 1; + if ( 0 == mut_state.receiver_instance_count) { + (*mut_state).rx_mode = RxMode::NonBuffered(embassy_sync::waitqueue::AtomicWaker::new()); + } + } + } + if mut_state.sender_instance_count == 0 && mut_state.receiver_instance_count == 0 { + unsafe { let tx_pin = crate::gpio::AnyPin::steal(mut_state.tx_pin_port.unwrap()); tx_pin.set_as_disconnected(); let rx_pin = crate::gpio::AnyPin::steal(mut_state.rx_pin_port.unwrap()); @@ -1054,26 +1028,22 @@ macro_rules! impl_fdcan { } }); } + fn info() -> &'static Info { + static INFO: Info = Info { regs: Registers{regs: crate::pac::$inst, msgram: crate::pac::$msg_ram_inst, msg_ram_offset: $msg_ram_offset}, interrupt0: crate::_generated::peripheral_interrupts::$inst::IT0::IRQ, _interrupt1: crate::_generated::peripheral_interrupts::$inst::IT1::IRQ, tx_waker: crate::_generated::peripheral_interrupts::$inst::IT0::pend, internal_operation: peripherals::$inst::internal_operation, + state: embassy_sync::blocking_mutex::Mutex::new(core::cell::RefCell::new(State::new())), }; &INFO } fn registers() -> Registers { Registers{regs: crate::pac::$inst, msgram: crate::pac::$msg_ram_inst, msg_ram_offset: Self::MSG_RAM_OFFSET} } - unsafe fn mut_state() -> &'static mut State { - static mut STATE: State = State::new(); - &mut *core::ptr::addr_of_mut!(STATE) - } - fn state() -> &'static State { - unsafe { peripherals::$inst::mut_state() } - } #[cfg(feature = "time")] fn calc_timestamp(ns_per_timer_tick: u64, ts_val: u16) -> Timestamp {