From 7b352654653466063fd149a766e2deb14bcfc1ad Mon Sep 17 00:00:00 2001 From: skkeye Date: Mon, 10 Feb 2025 00:51:59 -0500 Subject: [PATCH] feat: Feature match udp sockets fix: fixed compile proto-ipv4/v6 edge cases in the ping module --- embassy-net/src/icmp.rs | 245 +++++++++++++++++++++++++++++++--------- 1 file changed, 192 insertions(+), 53 deletions(-) diff --git a/embassy-net/src/icmp.rs b/embassy-net/src/icmp.rs index 68f6fd536..22c31a589 100644 --- a/embassy-net/src/icmp.rs +++ b/embassy-net/src/icmp.rs @@ -1,8 +1,8 @@ //! ICMP sockets. -use core::future::poll_fn; +use core::future::{poll_fn, Future}; use core::mem; -use core::task::Poll; +use core::task::{Context, Poll}; use smoltcp::iface::{Interface, SocketHandle}; pub use smoltcp::phy::ChecksumCapabilities; @@ -36,6 +36,8 @@ pub enum SendError { NoRoute, /// Socket not bound to an outgoing port. SocketNotBound, + /// There is not enough transmit buffer capacity to ever send this packet. + PacketTooLarge, } /// Error returned by [`IcmpSocket::recv_from`]. @@ -109,25 +111,61 @@ impl<'a> IcmpSocket<'a> { }) } - /// Dequeue a packet received from a remote endpoint, copy the payload into the given slice, - /// and return the amount of octets copied as well as the `IpAddress` + /// Wait until the socket becomes readable. /// - /// **Note**: when the size of the provided buffer is smaller than the size of the payload, - /// the packet is dropped and a `RecvError::Truncated` error is returned. - pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, IpAddress), RecvError> { - poll_fn(move |cx| { - self.with_mut(|s, _| match s.recv_slice(buf) { - Ok(x) => Poll::Ready(Ok(x)), - // No data ready - Err(icmp::RecvError::Exhausted) => { - //s.register_recv_waker(cx.waker()); - cx.waker().wake_by_ref(); - Poll::Pending - } - Err(icmp::RecvError::Truncated) => Poll::Ready(Err(RecvError::Truncated)), - }) + /// A socket is readable when a packet has been received, or when there are queued packets in + /// the buffer. + pub fn wait_recv_ready(&self) -> impl Future + '_ { + poll_fn(move |cx| self.poll_recv_ready(cx)) + } + + /// Wait until a datagram can be read. + /// + /// When no datagram is readable, this method will return `Poll::Pending` and + /// register the current task to be notified when a datagram is received. + /// + /// When a datagram is received, this method will return `Poll::Ready`. + pub fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll<()> { + self.with_mut(|s, _| { + if s.can_recv() { + Poll::Ready(()) + } else { + // socket buffer is empty wait until at least one byte has arrived + s.register_recv_waker(cx.waker()); + Poll::Pending + } + }) + } + + /// Receive a datagram. + /// + /// This method will wait until a datagram is received. + /// + /// Returns the number of bytes received and the remote endpoint. + pub fn recv_from<'s>( + &'s self, + buf: &'s mut [u8], + ) -> impl Future> + 's { + poll_fn(|cx| self.poll_recv_from(buf, cx)) + } + + /// Receive a datagram. + /// + /// When no datagram is available, this method will return `Poll::Pending` and + /// register the current task to be notified when a datagram is received. + /// + /// When a datagram is received, this method will return `Poll::Ready` with the + /// number of bytes received and the remote endpoint. + pub fn poll_recv_from(&self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll> { + self.with_mut(|s, _| match s.recv_slice(buf) { + Ok((n, meta)) => Poll::Ready(Ok((n, meta))), + // No data ready + Err(icmp::RecvError::Truncated) => Poll::Ready(Err(RecvError::Truncated)), + Err(icmp::RecvError::Exhausted) => { + s.register_recv_waker(cx.waker()); + Poll::Pending + } }) - .await } /// Dequeue a packet received from a remote endpoint and calls the provided function with the @@ -136,7 +174,7 @@ impl<'a> IcmpSocket<'a> { /// /// **Note**: when the size of the provided buffer is smaller than the size of the payload, /// the packet is dropped and a `RecvError::Truncated` error is returned. - pub async fn recv_with(&self, f: F) -> Result + pub async fn recv_from_with(&self, f: F) -> Result where F: FnOnce((&[u8], IpAddress)) -> R, { @@ -154,16 +192,106 @@ impl<'a> IcmpSocket<'a> { .await } - /// Enqueue a packet to be sent to a given remote address, and fill it from a slice. + /// Wait until the socket becomes writable. + /// + /// A socket becomes writable when there is space in the buffer, from initial memory or after + /// dispatching datagrams on a full buffer. + pub fn wait_send_ready(&self) -> impl Future + '_ { + poll_fn(|cx| self.poll_send_ready(cx)) + } + + /// Wait until a datagram can be sent. + /// + /// When no datagram can be sent (i.e. the buffer is full), this method will return + /// `Poll::Pending` and register the current task to be notified when + /// space is freed in the buffer after a datagram has been dispatched. + /// + /// When a datagram can be sent, this method will return `Poll::Ready`. + pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<()> { + self.with_mut(|s, _| { + if s.can_send() { + Poll::Ready(()) + } else { + // socket buffer is full wait until a datagram has been dispatched + s.register_send_waker(cx.waker()); + Poll::Pending + } + }) + } + + /// Send a datagram to the specified remote endpoint. + /// + /// This method will wait until the datagram has been sent. + /// + /// If the socket's send buffer is too small to fit `buf`, this method will return `Err(SendError::PacketTooLarge)` + /// + /// When the remote endpoint is not reachable, this method will return `Err(SendError::NoRoute)` pub async fn send_to(&self, buf: &[u8], remote_endpoint: T) -> Result<(), SendError> where T: Into, { + let remote_endpoint: IpAddress = remote_endpoint.into(); + poll_fn(move |cx| self.poll_send_to(buf, remote_endpoint, cx)).await + } + + /// Send a datagram to the specified remote endpoint. + /// + /// When the datagram has been sent, this method will return `Poll::Ready(Ok())`. + /// + /// When the socket's send buffer is full, this method will return `Poll::Pending` + /// and register the current task to be notified when the buffer has space available. + /// + /// If the socket's send buffer is too small to fit `buf`, this method will return `Poll::Ready(Err(SendError::PacketTooLarge))` + /// + /// When the remote endpoint is not reachable, this method will return `Poll::Ready(Err(Error::NoRoute))`. + pub fn poll_send_to(&self, buf: &[u8], remote_endpoint: T, cx: &mut Context<'_>) -> Poll> + where + T: Into, + { + // Don't need to wake waker in `with_mut` if the buffer will never fit the icmp tx_buffer. + let send_capacity_too_small = self.with(|s, _| s.payload_send_capacity() < buf.len()); + if send_capacity_too_small { + return Poll::Ready(Err(SendError::PacketTooLarge)); + } + + self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint.into()) { + // Entire datagram has been sent + Ok(()) => Poll::Ready(Ok(())), + Err(icmp::SendError::BufferFull) => { + s.register_send_waker(cx.waker()); + Poll::Pending + } + Err(icmp::SendError::Unaddressable) => { + // If no sender/outgoing port is specified, there is not really "no route" + if s.is_open() { + Poll::Ready(Err(SendError::NoRoute)) + } else { + Poll::Ready(Err(SendError::SocketNotBound)) + } + } + }) + } + + /// Enqueue a packet to be sent to a given remote address with a zero-copy function. + /// + /// This method will wait until the buffer can fit the requested size before + /// calling the function to fill its contents. + pub async fn send_to_with(&mut self, size: usize, remote_endpoint: T, f: F) -> Result + where + T: Into, + F: FnOnce(&mut [u8]) -> R, + { + // Don't need to wake waker in `with_mut` if the buffer will never fit the icmp tx_buffer. + let send_capacity_too_small = self.with(|s, _| s.payload_send_capacity() < size); + if send_capacity_too_small { + return Err(SendError::PacketTooLarge); + } + + let mut f = Some(f); let remote_endpoint = remote_endpoint.into(); poll_fn(move |cx| { - self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint) { - // Entire datagram has been sent - Ok(()) => Poll::Ready(Ok(())), + self.with_mut(|s, _| match s.send(size, remote_endpoint) { + Ok(buf) => Poll::Ready(Ok({ unwrap!(f.take())(buf) })), Err(icmp::SendError::BufferFull) => { s.register_send_waker(cx.waker()); Poll::Pending @@ -174,28 +302,20 @@ impl<'a> IcmpSocket<'a> { .await } - /// Enqueue a packet to be sent to a given remote address with a zero-copy function. + /// Flush the socket. /// - /// This method will wait until the buffer can fit the requested size before - /// calling the function to fill its contents. - pub async fn send_to_with(&self, size: usize, remote_endpoint: T, f: F) -> Result - where - T: Into, - F: FnOnce(&mut [u8]) -> R, - { - let mut f = Some(f); - let remote_endpoint = remote_endpoint.into(); - poll_fn(move |cx| { - self.with_mut(|s, _| match s.send(size, remote_endpoint) { - Ok(buf) => Poll::Ready(Ok(unwrap!(f.take())(buf))), - Err(icmp::SendError::BufferFull) => { + /// This method will wait until the socket is flushed. + pub fn flush(&mut self) -> impl Future + '_ { + poll_fn(|cx| { + self.with_mut(|s, _| { + if s.send_queue() == 0 { + Poll::Ready(()) + } else { s.register_send_waker(cx.waker()); Poll::Pending } - Err(icmp::SendError::Unaddressable) => Poll::Ready(Err(SendError::NoRoute)), }) }) - .await } /// Check whether the socket is open. @@ -280,9 +400,15 @@ pub mod ping { //! }; //! ``` - use core::net::{IpAddr, Ipv6Addr}; + use core::net::IpAddr; + #[cfg(feature = "proto-ipv6")] + use core::net::Ipv6Addr; use embassy_time::{Duration, Instant, Timer, WithTimeout}; + #[cfg(feature = "proto-ipv6")] + use smoltcp::wire::IpAddress; + #[cfg(feature = "proto-ipv6")] + use smoltcp::wire::Ipv6Address; use super::*; @@ -392,11 +518,11 @@ pub mod ping { // make a single ping // - shorts out errors // - select the ip version - let ping_duration = match params.target().unwrap() { + let ping_duration = match params.target.unwrap() { #[cfg(feature = "proto-ipv4")] - IpAddr::V4(_) => self.single_ping_v4(params, seq_no).await?, + IpAddress::Ipv4(_) => self.single_ping_v4(params, seq_no).await?, #[cfg(feature = "proto-ipv6")] - IpAddr::V6(_) => self.single_ping_v6(params, seq_no).await?, + IpAddress::Ipv6(_) => self.single_ping_v6(params, seq_no).await?, }; // safely add up the durations of each ping @@ -478,7 +604,7 @@ pub mod ping { // Helper function to recieve and return the correct echo reply when it finds it async fn recv_pong(socket: &IcmpSocket<'_>, seq_no: u16) -> Result<(), PingError> { - while match socket.recv_with(|(buf, _)| filter_pong(buf, seq_no)).await { + while match socket.recv_from_with(|(buf, _)| filter_pong(buf, seq_no)).await { Ok(b) => !b, Err(e) => return Err(PingError::SocketRecvError(e)), } {} @@ -548,7 +674,7 @@ pub mod ping { // Helper function to recieve and return the correct echo reply when it finds it async fn recv_pong(socket: &IcmpSocket<'_>, seq_no: u16) -> Result<(), PingError> { - while match socket.recv_with(|(buf, _)| filter_pong(buf, seq_no)).await { + while match socket.recv_from_with(|(buf, _)| filter_pong(buf, seq_no)).await { Ok(b) => !b, Err(e) => return Err(PingError::SocketRecvError(e)), } {} @@ -581,9 +707,9 @@ pub mod ping { /// * `timeout` - The timeout duration before returning a [`PingError::DestinationHostUnreachable`] error. /// * `rate_limit` - The minimum time per echo request. pub struct PingParams<'a> { - target: Option, + target: Option, #[cfg(feature = "proto-ipv6")] - source: Option, + source: Option, payload: &'a [u8], hop_limit: Option, count: u16, @@ -610,7 +736,7 @@ pub mod ping { /// Creates a new instance of [`PingParams`] with the specified target IP address. pub fn new>(target: T) -> Self { Self { - target: Some(target.into()), + target: Some(PingParams::ip_addr_to_smoltcp(target)), #[cfg(feature = "proto-ipv6")] source: None, payload: b"embassy-net", @@ -621,21 +747,34 @@ pub mod ping { } } + fn ip_addr_to_smoltcp>(ip_addr: T) -> IpAddress { + match ip_addr.into() { + #[cfg(feature = "proto-ipv4")] + IpAddr::V4(v4) => IpAddress::Ipv4(v4), + #[cfg(not(feature = "proto-ipv4"))] + IpAddr::V4(_) => unreachable!(), + #[cfg(feature = "proto-ipv6")] + IpAddr::V6(v6) => IpAddress::Ipv6(v6), + #[cfg(not(feature = "proto-ipv6"))] + IpAddr::V6(_) => unreachable!(), + } + } + /// Sets the target IP address for the ping. pub fn set_target>(&mut self, target: T) -> &mut Self { - self.target = Some(target.into()); + self.target = Some(PingParams::ip_addr_to_smoltcp(target)); self } /// Retrieves the target IP address for the ping. pub fn target(&self) -> Option { - self.target + self.target.map(|t| t.into()) } /// Sets the source IP address for the ping (IPv6 only). #[cfg(feature = "proto-ipv6")] - pub fn set_source(&mut self, source: Ipv6Addr) -> &mut Self { - self.source = Some(source); + pub fn set_source>(&mut self, source: T) -> &mut Self { + self.source = Some(source.into()); self }