feat: Feature match udp sockets

fix: fixed compile proto-ipv4/v6 edge cases in the ping module
This commit is contained in:
skkeye 2025-02-10 00:51:59 -05:00 committed by Ulf Lilleengen
parent 7d2ffa76e5
commit 7b35265465

View File

@ -1,8 +1,8 @@
//! ICMP sockets.
use core::future::poll_fn;
use core::future::{poll_fn, Future};
use core::mem;
use core::task::Poll;
use core::task::{Context, Poll};
use smoltcp::iface::{Interface, SocketHandle};
pub use smoltcp::phy::ChecksumCapabilities;
@ -36,6 +36,8 @@ pub enum SendError {
NoRoute,
/// Socket not bound to an outgoing port.
SocketNotBound,
/// There is not enough transmit buffer capacity to ever send this packet.
PacketTooLarge,
}
/// Error returned by [`IcmpSocket::recv_from`].
@ -109,25 +111,61 @@ impl<'a> IcmpSocket<'a> {
})
}
/// Dequeue a packet received from a remote endpoint, copy the payload into the given slice,
/// and return the amount of octets copied as well as the `IpAddress`
/// Wait until the socket becomes readable.
///
/// **Note**: when the size of the provided buffer is smaller than the size of the payload,
/// the packet is dropped and a `RecvError::Truncated` error is returned.
pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, IpAddress), RecvError> {
poll_fn(move |cx| {
self.with_mut(|s, _| match s.recv_slice(buf) {
Ok(x) => Poll::Ready(Ok(x)),
// No data ready
Err(icmp::RecvError::Exhausted) => {
//s.register_recv_waker(cx.waker());
cx.waker().wake_by_ref();
/// A socket is readable when a packet has been received, or when there are queued packets in
/// the buffer.
pub fn wait_recv_ready(&self) -> impl Future<Output = ()> + '_ {
poll_fn(move |cx| self.poll_recv_ready(cx))
}
/// Wait until a datagram can be read.
///
/// When no datagram is readable, this method will return `Poll::Pending` and
/// register the current task to be notified when a datagram is received.
///
/// When a datagram is received, this method will return `Poll::Ready`.
pub fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
self.with_mut(|s, _| {
if s.can_recv() {
Poll::Ready(())
} else {
// socket buffer is empty wait until at least one byte has arrived
s.register_recv_waker(cx.waker());
Poll::Pending
}
})
}
/// Receive a datagram.
///
/// This method will wait until a datagram is received.
///
/// Returns the number of bytes received and the remote endpoint.
pub fn recv_from<'s>(
&'s self,
buf: &'s mut [u8],
) -> impl Future<Output = Result<(usize, IpAddress), RecvError>> + 's {
poll_fn(|cx| self.poll_recv_from(buf, cx))
}
/// Receive a datagram.
///
/// When no datagram is available, this method will return `Poll::Pending` and
/// register the current task to be notified when a datagram is received.
///
/// When a datagram is received, this method will return `Poll::Ready` with the
/// number of bytes received and the remote endpoint.
pub fn poll_recv_from(&self, buf: &mut [u8], cx: &mut Context<'_>) -> Poll<Result<(usize, IpAddress), RecvError>> {
self.with_mut(|s, _| match s.recv_slice(buf) {
Ok((n, meta)) => Poll::Ready(Ok((n, meta))),
// No data ready
Err(icmp::RecvError::Truncated) => Poll::Ready(Err(RecvError::Truncated)),
Err(icmp::RecvError::Exhausted) => {
s.register_recv_waker(cx.waker());
Poll::Pending
}
})
})
.await
}
/// Dequeue a packet received from a remote endpoint and calls the provided function with the
@ -136,7 +174,7 @@ impl<'a> IcmpSocket<'a> {
///
/// **Note**: when the size of the provided buffer is smaller than the size of the payload,
/// the packet is dropped and a `RecvError::Truncated` error is returned.
pub async fn recv_with<F, R>(&self, f: F) -> Result<R, RecvError>
pub async fn recv_from_with<F, R>(&self, f: F) -> Result<R, RecvError>
where
F: FnOnce((&[u8], IpAddress)) -> R,
{
@ -154,48 +192,130 @@ impl<'a> IcmpSocket<'a> {
.await
}
/// Enqueue a packet to be sent to a given remote address, and fill it from a slice.
/// Wait until the socket becomes writable.
///
/// A socket becomes writable when there is space in the buffer, from initial memory or after
/// dispatching datagrams on a full buffer.
pub fn wait_send_ready(&self) -> impl Future<Output = ()> + '_ {
poll_fn(|cx| self.poll_send_ready(cx))
}
/// Wait until a datagram can be sent.
///
/// When no datagram can be sent (i.e. the buffer is full), this method will return
/// `Poll::Pending` and register the current task to be notified when
/// space is freed in the buffer after a datagram has been dispatched.
///
/// When a datagram can be sent, this method will return `Poll::Ready`.
pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
self.with_mut(|s, _| {
if s.can_send() {
Poll::Ready(())
} else {
// socket buffer is full wait until a datagram has been dispatched
s.register_send_waker(cx.waker());
Poll::Pending
}
})
}
/// Send a datagram to the specified remote endpoint.
///
/// This method will wait until the datagram has been sent.
///
/// If the socket's send buffer is too small to fit `buf`, this method will return `Err(SendError::PacketTooLarge)`
///
/// When the remote endpoint is not reachable, this method will return `Err(SendError::NoRoute)`
pub async fn send_to<T>(&self, buf: &[u8], remote_endpoint: T) -> Result<(), SendError>
where
T: Into<IpAddress>,
{
let remote_endpoint = remote_endpoint.into();
poll_fn(move |cx| {
self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint) {
let remote_endpoint: IpAddress = remote_endpoint.into();
poll_fn(move |cx| self.poll_send_to(buf, remote_endpoint, cx)).await
}
/// Send a datagram to the specified remote endpoint.
///
/// When the datagram has been sent, this method will return `Poll::Ready(Ok())`.
///
/// When the socket's send buffer is full, this method will return `Poll::Pending`
/// and register the current task to be notified when the buffer has space available.
///
/// If the socket's send buffer is too small to fit `buf`, this method will return `Poll::Ready(Err(SendError::PacketTooLarge))`
///
/// When the remote endpoint is not reachable, this method will return `Poll::Ready(Err(Error::NoRoute))`.
pub fn poll_send_to<T>(&self, buf: &[u8], remote_endpoint: T, cx: &mut Context<'_>) -> Poll<Result<(), SendError>>
where
T: Into<IpAddress>,
{
// Don't need to wake waker in `with_mut` if the buffer will never fit the icmp tx_buffer.
let send_capacity_too_small = self.with(|s, _| s.payload_send_capacity() < buf.len());
if send_capacity_too_small {
return Poll::Ready(Err(SendError::PacketTooLarge));
}
self.with_mut(|s, _| match s.send_slice(buf, remote_endpoint.into()) {
// Entire datagram has been sent
Ok(()) => Poll::Ready(Ok(())),
Err(icmp::SendError::BufferFull) => {
s.register_send_waker(cx.waker());
Poll::Pending
}
Err(icmp::SendError::Unaddressable) => {
// If no sender/outgoing port is specified, there is not really "no route"
if s.is_open() {
Poll::Ready(Err(SendError::NoRoute))
} else {
Poll::Ready(Err(SendError::SocketNotBound))
}
}
})
}
/// Enqueue a packet to be sent to a given remote address with a zero-copy function.
///
/// This method will wait until the buffer can fit the requested size before
/// calling the function to fill its contents.
pub async fn send_to_with<T, F, R>(&mut self, size: usize, remote_endpoint: T, f: F) -> Result<R, SendError>
where
T: Into<IpAddress>,
F: FnOnce(&mut [u8]) -> R,
{
// Don't need to wake waker in `with_mut` if the buffer will never fit the icmp tx_buffer.
let send_capacity_too_small = self.with(|s, _| s.payload_send_capacity() < size);
if send_capacity_too_small {
return Err(SendError::PacketTooLarge);
}
let mut f = Some(f);
let remote_endpoint = remote_endpoint.into();
poll_fn(move |cx| {
self.with_mut(|s, _| match s.send(size, remote_endpoint) {
Ok(buf) => Poll::Ready(Ok({ unwrap!(f.take())(buf) })),
Err(icmp::SendError::BufferFull) => {
s.register_send_waker(cx.waker());
Poll::Pending
}
Err(icmp::SendError::Unaddressable) => Poll::Ready(Err(SendError::NoRoute)),
})
})
.await
}
/// Enqueue a packet to be sent to a given remote address with a zero-copy function.
/// Flush the socket.
///
/// This method will wait until the buffer can fit the requested size before
/// calling the function to fill its contents.
pub async fn send_to_with<T, F, R>(&self, size: usize, remote_endpoint: T, f: F) -> Result<R, SendError>
where
T: Into<IpAddress>,
F: FnOnce(&mut [u8]) -> R,
{
let mut f = Some(f);
let remote_endpoint = remote_endpoint.into();
poll_fn(move |cx| {
self.with_mut(|s, _| match s.send(size, remote_endpoint) {
Ok(buf) => Poll::Ready(Ok(unwrap!(f.take())(buf))),
Err(icmp::SendError::BufferFull) => {
/// This method will wait until the socket is flushed.
pub fn flush(&mut self) -> impl Future<Output = ()> + '_ {
poll_fn(|cx| {
self.with_mut(|s, _| {
if s.send_queue() == 0 {
Poll::Ready(())
} else {
s.register_send_waker(cx.waker());
Poll::Pending
}
Err(icmp::SendError::Unaddressable) => Poll::Ready(Err(SendError::NoRoute)),
})
})
.await
}
/// Check whether the socket is open.
@ -280,9 +400,15 @@ pub mod ping {
//! };
//! ```
use core::net::{IpAddr, Ipv6Addr};
use core::net::IpAddr;
#[cfg(feature = "proto-ipv6")]
use core::net::Ipv6Addr;
use embassy_time::{Duration, Instant, Timer, WithTimeout};
#[cfg(feature = "proto-ipv6")]
use smoltcp::wire::IpAddress;
#[cfg(feature = "proto-ipv6")]
use smoltcp::wire::Ipv6Address;
use super::*;
@ -392,11 +518,11 @@ pub mod ping {
// make a single ping
// - shorts out errors
// - select the ip version
let ping_duration = match params.target().unwrap() {
let ping_duration = match params.target.unwrap() {
#[cfg(feature = "proto-ipv4")]
IpAddr::V4(_) => self.single_ping_v4(params, seq_no).await?,
IpAddress::Ipv4(_) => self.single_ping_v4(params, seq_no).await?,
#[cfg(feature = "proto-ipv6")]
IpAddr::V6(_) => self.single_ping_v6(params, seq_no).await?,
IpAddress::Ipv6(_) => self.single_ping_v6(params, seq_no).await?,
};
// safely add up the durations of each ping
@ -478,7 +604,7 @@ pub mod ping {
// Helper function to recieve and return the correct echo reply when it finds it
async fn recv_pong(socket: &IcmpSocket<'_>, seq_no: u16) -> Result<(), PingError> {
while match socket.recv_with(|(buf, _)| filter_pong(buf, seq_no)).await {
while match socket.recv_from_with(|(buf, _)| filter_pong(buf, seq_no)).await {
Ok(b) => !b,
Err(e) => return Err(PingError::SocketRecvError(e)),
} {}
@ -548,7 +674,7 @@ pub mod ping {
// Helper function to recieve and return the correct echo reply when it finds it
async fn recv_pong(socket: &IcmpSocket<'_>, seq_no: u16) -> Result<(), PingError> {
while match socket.recv_with(|(buf, _)| filter_pong(buf, seq_no)).await {
while match socket.recv_from_with(|(buf, _)| filter_pong(buf, seq_no)).await {
Ok(b) => !b,
Err(e) => return Err(PingError::SocketRecvError(e)),
} {}
@ -581,9 +707,9 @@ pub mod ping {
/// * `timeout` - The timeout duration before returning a [`PingError::DestinationHostUnreachable`] error.
/// * `rate_limit` - The minimum time per echo request.
pub struct PingParams<'a> {
target: Option<IpAddr>,
target: Option<IpAddress>,
#[cfg(feature = "proto-ipv6")]
source: Option<Ipv6Addr>,
source: Option<Ipv6Address>,
payload: &'a [u8],
hop_limit: Option<u8>,
count: u16,
@ -610,7 +736,7 @@ pub mod ping {
/// Creates a new instance of [`PingParams`] with the specified target IP address.
pub fn new<T: Into<IpAddr>>(target: T) -> Self {
Self {
target: Some(target.into()),
target: Some(PingParams::ip_addr_to_smoltcp(target)),
#[cfg(feature = "proto-ipv6")]
source: None,
payload: b"embassy-net",
@ -621,21 +747,34 @@ pub mod ping {
}
}
fn ip_addr_to_smoltcp<T: Into<IpAddr>>(ip_addr: T) -> IpAddress {
match ip_addr.into() {
#[cfg(feature = "proto-ipv4")]
IpAddr::V4(v4) => IpAddress::Ipv4(v4),
#[cfg(not(feature = "proto-ipv4"))]
IpAddr::V4(_) => unreachable!(),
#[cfg(feature = "proto-ipv6")]
IpAddr::V6(v6) => IpAddress::Ipv6(v6),
#[cfg(not(feature = "proto-ipv6"))]
IpAddr::V6(_) => unreachable!(),
}
}
/// Sets the target IP address for the ping.
pub fn set_target<T: Into<IpAddr>>(&mut self, target: T) -> &mut Self {
self.target = Some(target.into());
self.target = Some(PingParams::ip_addr_to_smoltcp(target));
self
}
/// Retrieves the target IP address for the ping.
pub fn target(&self) -> Option<IpAddr> {
self.target
self.target.map(|t| t.into())
}
/// Sets the source IP address for the ping (IPv6 only).
#[cfg(feature = "proto-ipv6")]
pub fn set_source(&mut self, source: Ipv6Addr) -> &mut Self {
self.source = Some(source);
pub fn set_source<T: Into<Ipv6Address>>(&mut self, source: T) -> &mut Self {
self.source = Some(source.into());
self
}