Allow raw socket to receive all protocols and versions

This commit is contained in:
KingCol13 2025-06-24 19:37:42 +01:00
parent e2b75e37d7
commit a6cd31e090
3 changed files with 44 additions and 32 deletions

View File

@ -66,8 +66,8 @@ fn main() {
// Will not send IGMP
let raw_tx_buffer = raw::PacketBuffer::new(vec![], vec![]);
let raw_socket = raw::Socket::new(
IpVersion::Ipv4,
IpProtocol::Igmp,
Some(IpVersion::Ipv4),
Some(IpProtocol::Igmp),
raw_rx_buffer,
raw_tx_buffer,
);

View File

@ -852,7 +852,12 @@ fn test_raw_socket_no_reply(#[case] medium: Medium) {
vec![raw::PacketMetadata::EMPTY; packets],
vec![0; 48 * packets],
);
let raw_socket = raw::Socket::new(IpVersion::Ipv4, IpProtocol::Udp, rx_buffer, tx_buffer);
let raw_socket = raw::Socket::new(
Some(IpVersion::Ipv4),
Some(IpProtocol::Udp),
rx_buffer,
tx_buffer,
);
sockets.add(raw_socket);
let src_addr = Ipv4Address::new(127, 0, 0, 2);
@ -948,8 +953,8 @@ fn test_raw_socket_with_udp_socket(#[case] medium: Medium) {
vec![0; 48 * packets],
);
let raw_socket = raw::Socket::new(
IpVersion::Ipv4,
IpProtocol::Udp,
Some(IpVersion::Ipv4),
Some(IpProtocol::Udp),
raw_rx_buffer,
raw_tx_buffer,
);

View File

@ -80,12 +80,12 @@ pub type PacketBuffer<'a> = crate::storage::PacketBuffer<'a, ()>;
/// A raw IP socket.
///
/// A raw socket is bound to a specific IP protocol, and owns
/// A raw socket may be bound to a specific IP protocol, and owns
/// transmit and receive packet buffers.
#[derive(Debug)]
pub struct Socket<'a> {
ip_version: IpVersion,
ip_protocol: IpProtocol,
ip_version: Option<IpVersion>,
ip_protocol: Option<IpProtocol>,
rx_buffer: PacketBuffer<'a>,
tx_buffer: PacketBuffer<'a>,
#[cfg(feature = "async")]
@ -98,8 +98,8 @@ impl<'a> Socket<'a> {
/// Create a raw IP socket bound to the given IP version and datagram protocol,
/// with the given buffers.
pub fn new(
ip_version: IpVersion,
ip_protocol: IpProtocol,
ip_version: Option<IpVersion>,
ip_protocol: Option<IpProtocol>,
rx_buffer: PacketBuffer<'a>,
tx_buffer: PacketBuffer<'a>,
) -> Socket<'a> {
@ -152,13 +152,13 @@ impl<'a> Socket<'a> {
/// Return the IP version the socket is bound to.
#[inline]
pub fn ip_version(&self) -> IpVersion {
pub fn ip_version(&self) -> Option<IpVersion> {
self.ip_version
}
/// Return the IP protocol the socket is bound to.
#[inline]
pub fn ip_protocol(&self) -> IpProtocol {
pub fn ip_protocol(&self) -> Option<IpProtocol> {
self.ip_protocol
}
@ -216,7 +216,7 @@ impl<'a> Socket<'a> {
.map_err(|_| SendError::BufferFull)?;
net_trace!(
"raw:{}:{}: buffer to send {} octets",
"raw:{:?}:{:?}: buffer to send {} octets",
self.ip_version,
self.ip_protocol,
packet_buf.len()
@ -238,7 +238,7 @@ impl<'a> Socket<'a> {
.map_err(|_| SendError::BufferFull)?;
net_trace!(
"raw:{}:{}: buffer to send {} octets",
"raw:{:?}:{:?}: buffer to send {} octets",
self.ip_version,
self.ip_protocol,
size
@ -265,7 +265,7 @@ impl<'a> Socket<'a> {
let ((), packet_buf) = self.rx_buffer.dequeue().map_err(|_| RecvError::Exhausted)?;
net_trace!(
"raw:{}:{}: receive {} buffered octets",
"raw:{:?}:{:?}: receive {} buffered octets",
self.ip_version,
self.ip_protocol,
packet_buf.len()
@ -299,7 +299,7 @@ impl<'a> Socket<'a> {
let ((), packet_buf) = self.rx_buffer.peek().map_err(|_| RecvError::Exhausted)?;
net_trace!(
"raw:{}:{}: receive {} buffered octets",
"raw:{:?}:{:?}: receive {} buffered octets",
self.ip_version,
self.ip_protocol,
packet_buf.len()
@ -338,10 +338,17 @@ impl<'a> Socket<'a> {
}
pub(crate) fn accepts(&self, ip_repr: &IpRepr) -> bool {
if ip_repr.version() != self.ip_version {
if self
.ip_version
.is_some_and(|version| version != ip_repr.version())
{
return false;
}
if ip_repr.next_header() != self.ip_protocol {
if self
.ip_protocol
.is_some_and(|next_header| next_header != ip_repr.next_header())
{
return false;
}
@ -355,7 +362,7 @@ impl<'a> Socket<'a> {
let total_len = header_len + payload.len();
net_trace!(
"raw:{}:{}: receiving {} octets",
"raw:{:?}:{:?}: receiving {} octets",
self.ip_version,
self.ip_protocol,
total_len
@ -367,7 +374,7 @@ impl<'a> Socket<'a> {
buf[header_len..].copy_from_slice(payload);
}
Err(_) => net_trace!(
"raw:{}:{}: buffer full, dropped incoming packet",
"raw:{:?}:{:?}: buffer full, dropped incoming packet",
self.ip_version,
self.ip_protocol
),
@ -395,7 +402,7 @@ impl<'a> Socket<'a> {
return Ok(());
}
};
if packet.next_header() != ip_protocol {
if ip_protocol.is_some_and(|next_header| next_header != packet.next_header()) {
net_trace!("raw: sent packet with wrong ip protocol, dropping.");
return Ok(());
}
@ -415,7 +422,7 @@ impl<'a> Socket<'a> {
return Ok(());
}
};
net_trace!("raw:{}:{}: sending", ip_version, ip_protocol);
net_trace!("raw:{:?}:{:?}: sending", ip_version, ip_protocol);
emit(cx, (IpRepr::Ipv4(ipv4_repr), packet.payload()))
}
#[cfg(feature = "proto-ipv6")]
@ -427,7 +434,7 @@ impl<'a> Socket<'a> {
return Ok(());
}
};
if packet.next_header() != ip_protocol {
if ip_protocol.is_some_and(|next_header| next_header != packet.next_header()) {
net_trace!("raw: sent ipv6 packet with wrong ip protocol, dropping.");
return Ok(());
}
@ -440,7 +447,7 @@ impl<'a> Socket<'a> {
}
};
net_trace!("raw:{}:{}: sending", ip_version, ip_protocol);
net_trace!("raw:{:?}:{:?}: sending", ip_version, ip_protocol);
emit(cx, (IpRepr::Ipv6(ipv6_repr), packet.payload()))
}
Err(_) => {
@ -495,8 +502,8 @@ mod test {
tx_buffer: PacketBuffer<'static>,
) -> Socket<'static> {
Socket::new(
IpVersion::Ipv4,
IpProtocol::Unknown(IP_PROTO),
Some(IpVersion::Ipv4),
Some(IpProtocol::Unknown(IP_PROTO)),
rx_buffer,
tx_buffer,
)
@ -526,8 +533,8 @@ mod test {
tx_buffer: PacketBuffer<'static>,
) -> Socket<'static> {
Socket::new(
IpVersion::Ipv6,
IpProtocol::Unknown(IP_PROTO),
Some(IpVersion::Ipv6),
Some(IpProtocol::Unknown(IP_PROTO)),
rx_buffer,
tx_buffer,
)
@ -827,8 +834,8 @@ mod test {
#[cfg(feature = "proto-ipv4")]
{
let socket = Socket::new(
IpVersion::Ipv4,
IpProtocol::Unknown(ipv4_locals::IP_PROTO + 1),
Some(IpVersion::Ipv4),
Some(IpProtocol::Unknown(ipv4_locals::IP_PROTO + 1)),
buffer(1),
buffer(1),
);
@ -839,8 +846,8 @@ mod test {
#[cfg(feature = "proto-ipv6")]
{
let socket = Socket::new(
IpVersion::Ipv6,
IpProtocol::Unknown(ipv6_locals::IP_PROTO + 1),
Some(IpVersion::Ipv6),
Some(IpProtocol::Unknown(ipv6_locals::IP_PROTO + 1)),
buffer(1),
buffer(1),
);