From 096ce02ac4c935c4df485a79107fb26b5897065d Mon Sep 17 00:00:00 2001 From: Egor Karavaev Date: Sun, 10 Sep 2017 23:18:12 +0300 Subject: [PATCH] Implement a SocketRef smart pointer to detect state changes. --- examples/client.rs | 7 ++--- examples/loopback.rs | 7 ++--- examples/ping.rs | 5 ++- examples/server.rs | 12 +++---- src/iface/ethernet.rs | 23 ++++++-------- src/socket/mod.rs | 68 +++++++++++++++++++--------------------- src/socket/ref_.rs | 73 +++++++++++++++++++++++++++++++++++++++++++ src/socket/set.rs | 39 ++++++++++------------- 8 files changed, 146 insertions(+), 88 deletions(-) create mode 100644 src/socket/ref_.rs diff --git a/examples/client.rs b/examples/client.rs index 4f88bd37..8d518cb4 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -12,8 +12,7 @@ use std::os::unix::io::AsRawFd; use smoltcp::phy::wait as phy_wait; use smoltcp::wire::{EthernetAddress, Ipv4Address, IpAddress, IpCidr}; use smoltcp::iface::{ArpCache, SliceArpCache, EthernetInterface}; -use smoltcp::socket::{AsSocket, SocketSet}; -use smoltcp::socket::{TcpSocket, TcpSocketBuffer}; +use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer}; fn main() { utils::setup_logging(""); @@ -50,14 +49,14 @@ fn main() { let tcp_handle = sockets.add(tcp_socket); { - let socket: &mut TcpSocket = sockets.get_mut(tcp_handle).as_socket(); + let mut socket = sockets.get::(tcp_handle); socket.connect((address, port), 49500).unwrap(); } let mut tcp_active = false; loop { { - let socket: &mut TcpSocket = sockets.get_mut(tcp_handle).as_socket(); + let mut socket = sockets.get::(tcp_handle); if socket.is_active() && !tcp_active { debug!("connected"); } else if !socket.is_active() && tcp_active { diff --git a/examples/loopback.rs b/examples/loopback.rs index 92139e07..6d378649 100644 --- a/examples/loopback.rs +++ b/examples/loopback.rs @@ -19,8 +19,7 @@ use core::str; use smoltcp::phy::Loopback; use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr}; use smoltcp::iface::{ArpCache, SliceArpCache, EthernetInterface}; -use smoltcp::socket::{AsSocket, SocketSet}; -use smoltcp::socket::{TcpSocket, TcpSocketBuffer}; +use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer}; #[cfg(not(feature = "std"))] mod mock { @@ -124,7 +123,7 @@ fn main() { let mut done = false; while !done && clock.elapsed() < 10_000 { { - let socket: &mut TcpSocket = socket_set.get_mut(server_handle).as_socket(); + let mut socket = socket_set.get::(server_handle); if !socket.is_active() && !socket.is_listening() { if !did_listen { debug!("listening"); @@ -141,7 +140,7 @@ fn main() { } { - let socket: &mut TcpSocket = socket_set.get_mut(client_handle).as_socket(); + let mut socket = socket_set.get::(client_handle); if !socket.is_open() { if !did_connect { debug!("connecting"); diff --git a/examples/ping.rs b/examples/ping.rs index 2f26fe89..c0266cee 100644 --- a/examples/ping.rs +++ b/examples/ping.rs @@ -16,8 +16,7 @@ use smoltcp::wire::{EthernetAddress, IpVersion, IpProtocol, IpAddress, IpCidr, Ipv4Address, Ipv4Packet, Ipv4Repr, Icmpv4Repr, Icmpv4Packet}; use smoltcp::iface::{ArpCache, SliceArpCache, EthernetInterface}; -use smoltcp::socket::{AsSocket, SocketSet}; -use smoltcp::socket::{RawSocket, RawSocketBuffer, RawPacketBuffer}; +use smoltcp::socket::{SocketSet, RawSocket, RawSocketBuffer, RawPacketBuffer}; use std::collections::HashMap; use byteorder::{ByteOrder, NetworkEndian}; @@ -75,7 +74,7 @@ fn main() { loop { { - let socket: &mut RawSocket = sockets.get_mut(raw_handle).as_socket(); + let mut socket = sockets.get::(raw_handle); let timestamp = Instant::now().duration_since(startup_time); let timestamp_us = (timestamp.as_secs() * 1000000) + diff --git a/examples/server.rs b/examples/server.rs index 26d5acdf..12bf7312 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -13,7 +13,7 @@ use std::os::unix::io::AsRawFd; use smoltcp::phy::wait as phy_wait; use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr}; use smoltcp::iface::{ArpCache, SliceArpCache, EthernetInterface}; -use smoltcp::socket::{AsSocket, SocketSet}; +use smoltcp::socket::SocketSet; use smoltcp::socket::{UdpSocket, UdpSocketBuffer, UdpPacketBuffer}; use smoltcp::socket::{TcpSocket, TcpSocketBuffer}; @@ -70,7 +70,7 @@ fn main() { loop { // udp:6969: respond "hello" { - let socket: &mut UdpSocket = sockets.get_mut(udp_handle).as_socket(); + let mut socket = sockets.get::(udp_handle); if !socket.is_open() { socket.bind(6969).unwrap() } @@ -93,7 +93,7 @@ fn main() { // tcp:6969: respond "hello" { - let socket: &mut TcpSocket = sockets.get_mut(tcp1_handle).as_socket(); + let mut socket = sockets.get::(tcp1_handle); if !socket.is_open() { socket.listen(6969).unwrap(); } @@ -108,7 +108,7 @@ fn main() { // tcp:6970: echo with reverse { - let socket: &mut TcpSocket = sockets.get_mut(tcp2_handle).as_socket(); + let mut socket = sockets.get::(tcp2_handle); if !socket.is_open() { socket.listen(6970).unwrap() } @@ -145,7 +145,7 @@ fn main() { // tcp:6971: sinkhole { - let socket: &mut TcpSocket = sockets.get_mut(tcp3_handle).as_socket(); + let mut socket = sockets.get::(tcp3_handle); if !socket.is_open() { socket.listen(6971).unwrap(); socket.set_keep_alive(Some(1000)); @@ -165,7 +165,7 @@ fn main() { // tcp:6972: fountain { - let socket: &mut TcpSocket = sockets.get_mut(tcp4_handle).as_socket(); + let mut socket = sockets.get::(tcp4_handle); if !socket.is_open() { socket.listen(6972).unwrap() } diff --git a/src/iface/ethernet.rs b/src/iface/ethernet.rs index fb22fe88..97681acd 100644 --- a/src/iface/ethernet.rs +++ b/src/iface/ethernet.rs @@ -13,7 +13,7 @@ use wire::{Ipv4Packet, Ipv4Repr}; use wire::{Icmpv4Packet, Icmpv4Repr, Icmpv4DstUnreachable}; #[cfg(feature = "socket-udp")] use wire::{UdpPacket, UdpRepr}; #[cfg(feature = "socket-tcp")] use wire::{TcpPacket, TcpRepr, TcpControl}; -use socket::{Socket, SocketSet, AsSocket}; +use socket::{Socket, SocketSet, AnySocket}; #[cfg(feature = "socket-raw")] use socket::RawSocket; #[cfg(feature = "socket-udp")] use socket::UdpSocket; #[cfg(feature = "socket-tcp")] use socket::TcpSocket; @@ -195,29 +195,29 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> { let mut caps = self.device.capabilities(); caps.max_transmission_unit -= EthernetFrame::<&[u8]>::header_len(); - for socket in sockets.iter_mut() { + for mut socket in sockets.iter_mut() { let mut device_result = Ok(()); let socket_result = - match socket { + match *socket { #[cfg(feature = "socket-raw")] - &mut Socket::Raw(ref mut socket) => + Socket::Raw(ref mut socket) => socket.dispatch(|response| { device_result = self.dispatch(timestamp, Packet::Raw(response)); device_result }, &caps.checksum), #[cfg(feature = "socket-udp")] - &mut Socket::Udp(ref mut socket) => + Socket::Udp(ref mut socket) => socket.dispatch(|response| { device_result = self.dispatch(timestamp, Packet::Udp(response)); device_result }), #[cfg(feature = "socket-tcp")] - &mut Socket::Tcp(ref mut socket) => + Socket::Tcp(ref mut socket) => socket.dispatch(timestamp, &caps, |response| { device_result = self.dispatch(timestamp, Packet::Tcp(response)); device_result }), - &mut Socket::__Nonexhaustive(_) => unreachable!() + Socket::__Nonexhaustive(_) => unreachable!() }; match (device_result, socket_result) { (Err(Error::Unaddressable), _) => break, // no one to transmit to @@ -323,8 +323,7 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> { // Pass every IP packet to all raw sockets we have registered. #[cfg(feature = "socket-raw")] - for raw_socket in sockets.iter_mut().filter_map( - >::try_as_socket) { + for mut raw_socket in sockets.iter_mut().filter_map(RawSocket::downcast) { if !raw_socket.accepts(&ip_repr) { continue } match raw_socket.process(&ip_repr, ip_payload, &checksum_caps) { @@ -415,8 +414,7 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> { let checksum_caps = self.device.capabilities().checksum; let udp_repr = UdpRepr::parse(&udp_packet, &src_addr, &dst_addr, &checksum_caps)?; - for udp_socket in sockets.iter_mut().filter_map( - >::try_as_socket) { + for mut udp_socket in sockets.iter_mut().filter_map(UdpSocket::downcast) { if !udp_socket.accepts(&ip_repr, &udp_repr) { continue } match udp_socket.process(&ip_repr, &udp_repr) { @@ -458,8 +456,7 @@ impl<'a, 'b, 'c, DeviceT: Device + 'a> Interface<'a, 'b, 'c, DeviceT> { let checksum_caps = self.device.capabilities().checksum; let tcp_repr = TcpRepr::parse(&tcp_packet, &src_addr, &dst_addr, &checksum_caps)?; - for tcp_socket in sockets.iter_mut().filter_map( - >::try_as_socket) { + for mut tcp_socket in sockets.iter_mut().filter_map(TcpSocket::downcast) { if !tcp_socket.accepts(&ip_repr, &tcp_repr) { continue } match tcp_socket.process(timestamp, &ip_repr, &tcp_repr) { diff --git a/src/socket/mod.rs b/src/socket/mod.rs index 011294a8..7665f3bf 100644 --- a/src/socket/mod.rs +++ b/src/socket/mod.rs @@ -17,6 +17,7 @@ use wire::IpRepr; #[cfg(feature = "socket-udp")] mod udp; #[cfg(feature = "socket-tcp")] mod tcp; mod set; +mod ref_; #[cfg(feature = "socket-raw")] pub use self::raw::{PacketBuffer as RawPacketBuffer, @@ -36,19 +37,19 @@ pub use self::tcp::{SocketBuffer as TcpSocketBuffer, pub use self::set::{Set as SocketSet, Item as SocketSetItem, Handle as SocketHandle}; pub use self::set::{Iter as SocketSetIter, IterMut as SocketSetIterMut}; +pub use self::ref_::Ref as SocketRef; +pub(crate) use self::ref_::Session as SocketSession; + /// A network socket. /// /// This enumeration abstracts the various types of sockets based on the IP protocol. -/// To downcast a `Socket` value down to a concrete socket, use -/// the [AsSocket](trait.AsSocket.html) trait, and call e.g. `socket.as_socket::>()`. +/// To downcast a `Socket` value to a concrete socket, use the [AnySocket] trait, +/// e.g. to get `UdpSocket`, call `UdpSocket::downcast(socket)`. /// -/// The `process` and `dispatch` functions are fundamentally asymmetric and thus differ in -/// their use of the [trait PacketRepr](trait.PacketRepr.html). When `process` is called, -/// the packet length is already known and no allocation is required; on the other hand, -/// `process` would have to downcast a `&PacketRepr` to e.g. an `&UdpRepr` through `Any`, -/// which is rather inelegant. Conversely, when `dispatch` is called, the packet length is -/// not yet known and the packet storage has to be allocated; but the `&PacketRepr` is sufficient -/// since the lower layers treat the packet as an opaque octet sequence. +/// It is usually more convenient to use [SocketSet::get] instead. +/// +/// [AnySocket]: trait.AnySocket.html +/// [SocketSet::get]: struct.SocketSet.html#method.get #[derive(Debug)] pub enum Socket<'a, 'b: 'a> { #[cfg(feature = "socket-raw")] @@ -90,40 +91,37 @@ impl<'a, 'b> Socket<'a, 'b> { } } -/// A conversion trait for network sockets. -/// -/// This trait is used to concisely downcast [Socket](trait.Socket.html) values to their -/// concrete types. -pub trait AsSocket { - fn as_socket(&mut self) -> &mut T; - fn try_as_socket(&mut self) -> Option<&mut T>; +impl<'a, 'b> SocketSession for Socket<'a, 'b> { + fn finish(&mut self) { + dispatch_socket!(self, |socket [mut]| socket.finish()) + } } -macro_rules! as_socket { - ($socket:ty, $variant:ident) => { - impl<'a, 'b> AsSocket<$socket> for Socket<'a, 'b> { - fn as_socket(&mut self) -> &mut $socket { - match self { - &mut Socket::$variant(ref mut socket) => socket, - _ => panic!(concat!(".as_socket::<", - stringify!($socket), - "> called on wrong socket type")) - } - } +/// A conversion trait for network sockets. +pub trait AnySocket<'a, 'b>: SocketSession + Sized { + fn downcast<'c>(socket_ref: SocketRef<'c, Socket<'a, 'b>>) -> + Option>; +} - fn try_as_socket(&mut self) -> Option<&mut $socket> { - match self { - &mut Socket::$variant(ref mut socket) => Some(socket), - _ => None, - } +macro_rules! from_socket { + ($socket:ty, $variant:ident) => { + impl<'a, 'b> AnySocket<'a, 'b> for $socket { + fn downcast<'c>(ref_: SocketRef<'c, Socket<'a, 'b>>) -> + Option> { + SocketRef::map(ref_, |socket| { + match *socket { + Socket::$variant(ref mut socket) => Some(socket), + _ => None, + } + }) } } } } #[cfg(feature = "socket-raw")] -as_socket!(RawSocket<'a, 'b>, Raw); +from_socket!(RawSocket<'a, 'b>, Raw); #[cfg(feature = "socket-udp")] -as_socket!(UdpSocket<'a, 'b>, Udp); +from_socket!(UdpSocket<'a, 'b>, Udp); #[cfg(feature = "socket-tcp")] -as_socket!(TcpSocket<'a>, Tcp); +from_socket!(TcpSocket<'a>, Tcp); diff --git a/src/socket/ref_.rs b/src/socket/ref_.rs new file mode 100644 index 00000000..45b1f073 --- /dev/null +++ b/src/socket/ref_.rs @@ -0,0 +1,73 @@ +use core::ops::{Deref, DerefMut}; + +#[cfg(feature = "socket-raw")] +use socket::RawSocket; +#[cfg(feature = "socket-udp")] +use socket::UdpSocket; +#[cfg(feature = "socket-tcp")] +use socket::TcpSocket; + +/// A trait for tracking a socket usage session. +/// +/// Allows implementation of custom drop logic that runs only if the socket was changed +/// in specific ways. For example, drop logic for UDP would check if the local endpoint +/// has changed, and if yes, notify the socket set. +#[doc(hidden)] +pub trait Session { + fn finish(&mut self) {} +} + +#[cfg(feature = "socket-raw")] +impl<'a, 'b> Session for RawSocket<'a, 'b> {} +#[cfg(feature = "socket-udp")] +impl<'a, 'b> Session for UdpSocket<'a, 'b> {} +#[cfg(feature = "socket-tcp")] +impl<'a> Session for TcpSocket<'a> {} + +/// A smart pointer to a socket. +/// +/// Allows the network stack to efficiently determine if the socket state was changed in any way. +pub struct Ref<'a, T: Session + 'a> { + socket: &'a mut T, + consumed: bool, +} + +impl<'a, T: Session> Ref<'a, T> { + pub(crate) fn new(socket: &'a mut T) -> Self { + Ref { socket, consumed: false } + } +} + +impl<'a, T: Session + 'a> Ref<'a, T> { + pub(crate) fn map(mut ref_: Self, f: F) -> Option> + where U: Session + 'a, F: FnOnce(&'a mut T) -> Option<&'a mut U> { + if let Some(socket) = f(ref_.socket) { + ref_.consumed = true; + Some(Ref::new(socket)) + } else { + None + } + } +} + +impl<'a, T: Session> Deref for Ref<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.socket + } +} + +impl<'a, T: Session> DerefMut for Ref<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.socket + } +} + +impl<'a, T: Session> Drop for Ref<'a, T> { + fn drop(&mut self) { + if !self.consumed { + Session::finish(self.socket); + } + } +} diff --git a/src/socket/set.rs b/src/socket/set.rs index 1b488d80..78ec018c 100644 --- a/src/socket/set.rs +++ b/src/socket/set.rs @@ -1,7 +1,7 @@ use core::{fmt, slice}; use managed::ManagedSlice; -use super::Socket; +use super::{Socket, SocketRef, AnySocket}; #[cfg(feature = "socket-tcp")] use super::TcpState; /// An item of a socket set. @@ -28,7 +28,7 @@ impl fmt::Display for Handle { } } -/// An extensible set of sockets, with stable numeric identifiers. +/// An extensible set of sockets. /// /// The lifetimes `'b` and `'c` are used when storing a `Socket<'b, 'c>`. #[derive(Debug)] @@ -79,26 +79,19 @@ impl<'a, 'b: 'a, 'c: 'a + 'b> Set<'a, 'b, 'c> { } } - /// Get a socket from the set by its handle. - /// - /// # Panics - /// This function may panic if the handle does not belong to this socket set. - pub fn get(&self, handle: Handle) -> &Socket<'b, 'c> { - &self.sockets[handle.0] - .as_ref() - .expect("handle does not refer to a valid socket") - .socket - } - /// Get a socket from the set by its handle, as mutable. /// /// # Panics - /// This function may panic if the handle does not belong to this socket set. - pub fn get_mut(&mut self, handle: Handle) -> &mut Socket<'b, 'c> { - &mut self.sockets[handle.0] - .as_mut() - .expect("handle does not refer to a valid socket") - .socket + /// This function may panic if the handle does not belong to this socket set + /// or the socket has the wrong type. + pub fn get>(&mut self, handle: Handle) -> SocketRef { + match self.sockets[handle.0].as_mut() { + Some(item) => { + T::downcast(SocketRef::new(&mut item.socket)) + .expect("handle refers to a socket of a wrong type") + } + None => panic!("handle does not refer to a valid socket") + } } /// Remove a socket from the set, without changing its state. @@ -175,7 +168,7 @@ impl<'a, 'b: 'a, 'c: 'a + 'b> Set<'a, 'b, 'c> { Iter { lower: self.sockets.iter() } } - /// Iterate every socket in this set, as mutable. + /// Iterate every socket in this set, as SocketRef. pub fn iter_mut<'d>(&'d mut self) -> IterMut<'d, 'b, 'c> { IterMut { lower: self.sockets.iter_mut() } } @@ -207,16 +200,16 @@ impl<'a, 'b: 'a, 'c: 'a + 'b> Iterator for Iter<'a, 'b, 'c> { /// This struct is created by the [iter_mut](struct.SocketSet.html#method.iter_mut) /// on [socket sets](struct.SocketSet.html). pub struct IterMut<'a, 'b: 'a, 'c: 'a + 'b> { - lower: slice::IterMut<'a, Option>> + lower: slice::IterMut<'a, Option>>, } impl<'a, 'b: 'a, 'c: 'a + 'b> Iterator for IterMut<'a, 'b, 'c> { - type Item = &'a mut Socket<'b, 'c>; + type Item = SocketRef<'a, Socket<'b, 'c>>; fn next(&mut self) -> Option { while let Some(item_opt) = self.lower.next() { if let Some(item) = item_opt.as_mut() { - return Some(&mut item.socket) + return Some(SocketRef::new(&mut item.socket)) } } None