diff --git a/embassy-usb-driver/Cargo.toml b/embassy-usb-driver/Cargo.toml index 41493f00d..edb6551b0 100644 --- a/embassy-usb-driver/Cargo.toml +++ b/embassy-usb-driver/Cargo.toml @@ -20,3 +20,4 @@ features = ["defmt"] [dependencies] defmt = { version = "0.3", optional = true } +embedded-io-async = "0.6.1" diff --git a/embassy-usb-driver/src/lib.rs b/embassy-usb-driver/src/lib.rs index 3b705c8c4..d204e4d85 100644 --- a/embassy-usb-driver/src/lib.rs +++ b/embassy-usb-driver/src/lib.rs @@ -395,3 +395,12 @@ pub enum EndpointError { /// The endpoint is disabled. Disabled, } + +impl embedded_io_async::Error for EndpointError { + fn kind(&self) -> embedded_io_async::ErrorKind { + match self { + Self::BufferOverflow => embedded_io_async::ErrorKind::OutOfMemory, + Self::Disabled => embedded_io_async::ErrorKind::NotConnected, + } + } +} diff --git a/embassy-usb/Cargo.toml b/embassy-usb/Cargo.toml index 771190c89..4950fbe2a 100644 --- a/embassy-usb/Cargo.toml +++ b/embassy-usb/Cargo.toml @@ -52,6 +52,7 @@ embassy-sync = { version = "0.6.2", path = "../embassy-sync" } embassy-net-driver-channel = { version = "0.3.0", path = "../embassy-net-driver-channel" } defmt = { version = "0.3", optional = true } +embedded-io-async = "0.6.1" log = { version = "0.4.14", optional = true } heapless = "0.8" diff --git a/embassy-usb/src/class/cdc_acm.rs b/embassy-usb/src/class/cdc_acm.rs index ea9d9fb7b..732a433f8 100644 --- a/embassy-usb/src/class/cdc_acm.rs +++ b/embassy-usb/src/class/cdc_acm.rs @@ -410,6 +410,18 @@ impl<'d, D: Driver<'d>> Sender<'d, D> { } } +impl<'d, D: Driver<'d>> embedded_io_async::ErrorType for Sender<'d, D> { + type Error = EndpointError; +} + +impl<'d, D: Driver<'d>> embedded_io_async::Write for Sender<'d, D> { + async fn write(&mut self, buf: &[u8]) -> Result { + let len = core::cmp::min(buf.len(), self.max_packet_size() as usize); + self.write_packet(&buf[..len]).await?; + Ok(len) + } +} + /// CDC ACM class packet receiver. /// /// You can obtain a `Receiver` with [`CdcAcmClass::split`] @@ -451,6 +463,93 @@ impl<'d, D: Driver<'d>> Receiver<'d, D> { pub async fn wait_connection(&mut self) { self.read_ep.wait_enabled().await; } + + /// Turn the `Receiver` into a [`BufferedReceiver`]. + /// + /// The supplied buffer must be large enough to hold max_packet_size bytes. + pub fn into_buffered(self, buf: &'d mut [u8]) -> BufferedReceiver<'d, D> { + BufferedReceiver { + receiver: self, + buffer: buf, + start: 0, + end: 0, + } + } +} + +/// CDC ACM class buffered receiver. +/// +/// It is a requirement of the [`embedded_io_async::Read`] trait that arbitrarily small lengths of +/// data can be read from the stream. The [`Receiver`] can only read full packets at a time. The +/// `BufferedReceiver` instead buffers a single packet if the caller does not read all of the data, +/// so that the remaining data can be returned in subsequent calls. +/// +/// If you have no requirement to use the [`embedded_io_async::Read`] trait or to read a data length +/// less than the packet length, then it is more efficient to use the [`Receiver`] directly. +/// +/// You can obtain a `BufferedReceiver` with [`Receiver::into_buffered`]. +/// +/// [`embedded_io_async::Read`]: https://docs.rs/embedded-io-async/latest/embedded_io_async/trait.Read.html +pub struct BufferedReceiver<'d, D: Driver<'d>> { + receiver: Receiver<'d, D>, + buffer: &'d mut [u8], + start: usize, + end: usize, +} + +impl<'d, D: Driver<'d>> BufferedReceiver<'d, D> { + fn read_from_buffer(&mut self, buf: &mut [u8]) -> usize { + let available = &self.buffer[self.start..self.end]; + let len = core::cmp::min(available.len(), buf.len()); + buf[..len].copy_from_slice(&self.buffer[..len]); + self.start += len; + len + } + + /// Gets the current line coding. The line coding contains information that's mainly relevant + /// for USB to UART serial port emulators, and can be ignored if not relevant. + pub fn line_coding(&self) -> LineCoding { + self.receiver.line_coding() + } + + /// Gets the DTR (data terminal ready) state + pub fn dtr(&self) -> bool { + self.receiver.dtr() + } + + /// Gets the RTS (request to send) state + pub fn rts(&self) -> bool { + self.receiver.rts() + } + + /// Waits for the USB host to enable this interface + pub async fn wait_connection(&mut self) { + self.receiver.wait_connection().await; + } +} + +impl<'d, D: Driver<'d>> embedded_io_async::ErrorType for BufferedReceiver<'d, D> { + type Error = EndpointError; +} + +impl<'d, D: Driver<'d>> embedded_io_async::Read for BufferedReceiver<'d, D> { + async fn read(&mut self, buf: &mut [u8]) -> Result { + // If there is a buffered packet, return data from that first + if self.start != self.end { + return Ok(self.read_from_buffer(buf)); + } + + // If the caller's buffer is large enough to contain an entire packet, read directly into + // that instead of buffering the packet internally. + if buf.len() > self.receiver.max_packet_size() as usize { + return self.receiver.read_packet(buf).await; + } + + // Otherwise read a packet into the internal buffer, and return some of it to the caller + self.start = 0; + self.end = self.receiver.read_packet(&mut self.buffer).await?; + return Ok(self.read_from_buffer(buf)); + } } /// Number of stop bits for LineCoding