mirror of
				https://github.com/smoltcp-rs/smoltcp.git
				synced 2025-11-04 07:12:46 +00:00 
			
		
		
		
	Rework and test UDP sockets.
Before, errors such as packets not fitting into a buffer would have resulted in panics, and errors such as unbound sockets were simply ignored.
This commit is contained in:
		
							parent
							
								
									5bf64586cd
								
							
						
					
					
						commit
						492fe3e4b1
					
				@ -64,7 +64,7 @@ fn main() {
 | 
			
		||||
        {
 | 
			
		||||
            let socket: &mut UdpSocket = sockets.get_mut(udp_handle).as_socket();
 | 
			
		||||
            if !socket.endpoint().is_specified() {
 | 
			
		||||
                socket.bind(6969)
 | 
			
		||||
                socket.bind(6969).unwrap()
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            let client = match socket.recv() {
 | 
			
		||||
 | 
			
		||||
@ -353,9 +353,9 @@ impl<'a> TcpSocket<'a> {
 | 
			
		||||
    pub fn listen<T>(&mut self, local_endpoint: T) -> Result<()>
 | 
			
		||||
            where T: Into<IpEndpoint> {
 | 
			
		||||
        let local_endpoint = local_endpoint.into();
 | 
			
		||||
        if local_endpoint.port == 0 { return Err(Error::Unaddressable) }
 | 
			
		||||
 | 
			
		||||
        if self.is_open() { return Err(Error::Illegal) }
 | 
			
		||||
        if local_endpoint.port == 0 { return Err(Error::Unaddressable) }
 | 
			
		||||
 | 
			
		||||
        self.reset();
 | 
			
		||||
        self.listen_address  = local_endpoint.addr;
 | 
			
		||||
 | 
			
		||||
@ -1,3 +1,4 @@
 | 
			
		||||
use core::cmp::min;
 | 
			
		||||
use managed::Managed;
 | 
			
		||||
 | 
			
		||||
use {Error, Result};
 | 
			
		||||
@ -15,13 +16,6 @@ pub struct PacketBuffer<'a> {
 | 
			
		||||
    payload:  Managed<'a, [u8]>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'a> Resettable for PacketBuffer<'a> {
 | 
			
		||||
    fn reset(&mut self) {
 | 
			
		||||
        self.endpoint = Default::default();
 | 
			
		||||
        self.size = 0;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'a> PacketBuffer<'a> {
 | 
			
		||||
    /// Create a buffered packet.
 | 
			
		||||
    pub fn new<T>(payload: T) -> PacketBuffer<'a>
 | 
			
		||||
@ -40,6 +34,22 @@ impl<'a> PacketBuffer<'a> {
 | 
			
		||||
    fn as_mut<'b>(&'b mut self) -> &'b mut [u8] {
 | 
			
		||||
        &mut self.payload[..self.size]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn resize<'b>(&'b mut self, size: usize) -> Result<&'b mut Self> {
 | 
			
		||||
        if self.payload.len() >= size {
 | 
			
		||||
            self.size = size;
 | 
			
		||||
            Ok(self)
 | 
			
		||||
        } else {
 | 
			
		||||
            Err(Error::Truncated)
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'a> Resettable for PacketBuffer<'a> {
 | 
			
		||||
    fn reset(&mut self) {
 | 
			
		||||
        self.endpoint = Default::default();
 | 
			
		||||
        self.size = 0;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// An UDP packet ring buffer.
 | 
			
		||||
@ -90,8 +100,17 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Bind the socket to the given endpoint.
 | 
			
		||||
    pub fn bind<T: Into<IpEndpoint>>(&mut self, endpoint: T) {
 | 
			
		||||
        self.endpoint = endpoint.into()
 | 
			
		||||
    ///
 | 
			
		||||
    /// Returns `Err(Error::Illegal)` if the socket is already bound,
 | 
			
		||||
    /// and `Err(Error::Unaddressable)` if the port is unspecified.
 | 
			
		||||
    pub fn bind<T: Into<IpEndpoint>>(&mut self, endpoint: T) -> Result<()> {
 | 
			
		||||
        let endpoint = endpoint.into();
 | 
			
		||||
        if endpoint.port == 0 { return Err(Error::Unaddressable) }
 | 
			
		||||
 | 
			
		||||
        if self.endpoint.port != 0 { return Err(Error::Illegal) }
 | 
			
		||||
 | 
			
		||||
        self.endpoint = endpoint;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Check whether the transmit buffer is full.
 | 
			
		||||
@ -109,15 +128,18 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
 | 
			
		||||
    /// Enqueue a packet to be sent to a given remote endpoint, and return a pointer
 | 
			
		||||
    /// to its payload.
 | 
			
		||||
    ///
 | 
			
		||||
    /// This function returns `Err(Error::Exhausted)` if the size is greater than
 | 
			
		||||
    /// the transmit packet buffer size.
 | 
			
		||||
    /// This function returns `Err(Error::Exhausted)` if the transmit buffer is full,
 | 
			
		||||
    /// `Err(Error::Truncated)` if the requested size is larger than the packet buffer
 | 
			
		||||
    /// size, and `Err(Error::Unaddressable)` if local or remote port, or remote address,
 | 
			
		||||
    /// are unspecified.
 | 
			
		||||
    pub fn send(&mut self, size: usize, endpoint: IpEndpoint) -> Result<&mut [u8]> {
 | 
			
		||||
        let packet_buf = self.tx_buffer.enqueue()?;
 | 
			
		||||
        if self.endpoint.port == 0 { return Err(Error::Unaddressable) }
 | 
			
		||||
        if !endpoint.is_specified() { return Err(Error::Unaddressable) }
 | 
			
		||||
 | 
			
		||||
        let packet_buf = self.tx_buffer.try_enqueue(|buf| buf.resize(size))?;
 | 
			
		||||
        packet_buf.endpoint = endpoint;
 | 
			
		||||
        packet_buf.size = size;
 | 
			
		||||
        net_trace!("[{}]{}:{}: buffer to send {} octets",
 | 
			
		||||
                   self.debug_id, self.endpoint,
 | 
			
		||||
                   packet_buf.endpoint, packet_buf.size);
 | 
			
		||||
                   self.debug_id, self.endpoint, packet_buf.endpoint, size);
 | 
			
		||||
        Ok(&mut packet_buf.as_mut()[..size])
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -125,9 +147,7 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
 | 
			
		||||
    ///
 | 
			
		||||
    /// See also [send](#method.send).
 | 
			
		||||
    pub fn send_slice(&mut self, data: &[u8], endpoint: IpEndpoint) -> Result<usize> {
 | 
			
		||||
        let buffer = self.send(data.len(), endpoint)?;
 | 
			
		||||
        let data = &data[..buffer.len()];
 | 
			
		||||
        buffer.copy_from_slice(data);
 | 
			
		||||
        self.send(data.len(), endpoint)?.copy_from_slice(data);
 | 
			
		||||
        Ok(data.len())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -140,7 +160,7 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
 | 
			
		||||
        net_trace!("[{}]{}:{}: receive {} buffered octets",
 | 
			
		||||
                   self.debug_id, self.endpoint,
 | 
			
		||||
                   packet_buf.endpoint, packet_buf.size);
 | 
			
		||||
        Ok((&packet_buf.as_ref()[..packet_buf.size], packet_buf.endpoint))
 | 
			
		||||
        Ok((&packet_buf.as_ref(), packet_buf.endpoint))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Dequeue a packet received from a remote endpoint, and return the endpoint as well
 | 
			
		||||
@ -149,8 +169,9 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
 | 
			
		||||
    /// See also [recv](#method.recv).
 | 
			
		||||
    pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, IpEndpoint)> {
 | 
			
		||||
        let (buffer, endpoint) = self.recv()?;
 | 
			
		||||
        data[..buffer.len()].copy_from_slice(buffer);
 | 
			
		||||
        Ok((buffer.len(), endpoint))
 | 
			
		||||
        let length = min(data.len(), buffer.len());
 | 
			
		||||
        data[..length].copy_from_slice(&buffer[..length]);
 | 
			
		||||
        Ok((length, endpoint))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub(crate) fn process(&mut self, _timestamp: u64, ip_repr: &IpRepr,
 | 
			
		||||
@ -160,15 +181,12 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
 | 
			
		||||
        let packet = UdpPacket::new_checked(&payload[..ip_repr.payload_len()])?;
 | 
			
		||||
        let repr = UdpRepr::parse(&packet, &ip_repr.src_addr(), &ip_repr.dst_addr())?;
 | 
			
		||||
 | 
			
		||||
        if repr.dst_port != self.endpoint.port { return Err(Error::Rejected) }
 | 
			
		||||
        if !self.endpoint.addr.is_unspecified() {
 | 
			
		||||
            if self.endpoint.addr != ip_repr.dst_addr() { return Err(Error::Rejected) }
 | 
			
		||||
        }
 | 
			
		||||
        let endpoint = IpEndpoint { addr: ip_repr.src_addr(), port: repr.src_port };
 | 
			
		||||
        if !self.endpoint.accepts(&endpoint) { return Err(Error::Rejected) }
 | 
			
		||||
 | 
			
		||||
        let packet_buf = self.rx_buffer.enqueue()?;
 | 
			
		||||
        packet_buf.endpoint = IpEndpoint { addr: ip_repr.src_addr(), port: repr.src_port };
 | 
			
		||||
        packet_buf.size = repr.payload.len();
 | 
			
		||||
        packet_buf.as_mut()[..repr.payload.len()].copy_from_slice(repr.payload);
 | 
			
		||||
        let packet_buf = self.rx_buffer.try_enqueue(|buf| buf.resize(repr.payload.len()))?;
 | 
			
		||||
        packet_buf.as_mut().copy_from_slice(repr.payload);
 | 
			
		||||
        packet_buf.endpoint = endpoint;
 | 
			
		||||
        net_trace!("[{}]{}:{}: receiving {} octets",
 | 
			
		||||
                   self.debug_id, self.endpoint,
 | 
			
		||||
                   packet_buf.endpoint, packet_buf.size);
 | 
			
		||||
@ -182,6 +200,7 @@ impl<'a, 'b> UdpSocket<'a, 'b> {
 | 
			
		||||
        net_trace!("[{}]{}:{}: sending {} octets",
 | 
			
		||||
                   self.debug_id, self.endpoint,
 | 
			
		||||
                   packet_buf.endpoint, packet_buf.size);
 | 
			
		||||
 | 
			
		||||
        let repr = UdpRepr {
 | 
			
		||||
            src_port: self.endpoint.port,
 | 
			
		||||
            dst_port: packet_buf.endpoint.port,
 | 
			
		||||
@ -207,3 +226,180 @@ impl<'a> IpPayload for UdpRepr<'a> {
 | 
			
		||||
        self.emit(&mut packet, &repr.src_addr(), &repr.dst_addr())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[cfg(test)]
 | 
			
		||||
mod test {
 | 
			
		||||
    use std::vec::Vec;
 | 
			
		||||
    use wire::{IpAddress, Ipv4Address, IpRepr, Ipv4Repr, UdpRepr};
 | 
			
		||||
    use socket::AsSocket;
 | 
			
		||||
    use super::*;
 | 
			
		||||
 | 
			
		||||
    fn buffer(packets: usize) -> SocketBuffer<'static, 'static> {
 | 
			
		||||
        let mut storage = vec![];
 | 
			
		||||
        for _ in 0..packets {
 | 
			
		||||
            storage.push(PacketBuffer::new(vec![0; 16]))
 | 
			
		||||
        }
 | 
			
		||||
        SocketBuffer::new(storage)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn socket(rx_buffer: SocketBuffer<'static, 'static>,
 | 
			
		||||
              tx_buffer: SocketBuffer<'static, 'static>)
 | 
			
		||||
            -> UdpSocket<'static, 'static> {
 | 
			
		||||
        match UdpSocket::new(rx_buffer, tx_buffer) {
 | 
			
		||||
            Socket::Udp(socket) => socket,
 | 
			
		||||
            _ => unreachable!()
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const LOCAL_IP:    IpAddress  = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 1]));
 | 
			
		||||
    const REMOTE_IP:   IpAddress  = IpAddress::Ipv4(Ipv4Address([10, 0, 0, 2]));
 | 
			
		||||
    const LOCAL_PORT:  u16        = 53;
 | 
			
		||||
    const REMOTE_PORT: u16        = 49500;
 | 
			
		||||
    const LOCAL_END:   IpEndpoint = IpEndpoint { addr: LOCAL_IP,  port: LOCAL_PORT  };
 | 
			
		||||
    const REMOTE_END:  IpEndpoint = IpEndpoint { addr: REMOTE_IP, port: REMOTE_PORT };
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn test_bind_unaddressable() {
 | 
			
		||||
        let mut socket = socket(buffer(0), buffer(0));
 | 
			
		||||
        assert_eq!(socket.bind(0), Err(Error::Unaddressable));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn test_bind_twice() {
 | 
			
		||||
        let mut socket = socket(buffer(0), buffer(0));
 | 
			
		||||
        assert_eq!(socket.bind(1), Ok(()));
 | 
			
		||||
        assert_eq!(socket.bind(2), Err(Error::Illegal));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const LOCAL_IP_REPR: IpRepr = IpRepr::Unspecified {
 | 
			
		||||
        src_addr: LOCAL_IP,
 | 
			
		||||
        dst_addr: REMOTE_IP,
 | 
			
		||||
        protocol: IpProtocol::Udp,
 | 
			
		||||
        payload_len: 8 + 6
 | 
			
		||||
    };
 | 
			
		||||
    const LOCAL_UDP_REPR: UdpRepr = UdpRepr {
 | 
			
		||||
        src_port: LOCAL_PORT,
 | 
			
		||||
        dst_port: REMOTE_PORT,
 | 
			
		||||
        payload: b"abcdef"
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn test_send_unaddressable() {
 | 
			
		||||
        let mut socket = socket(buffer(0), buffer(1));
 | 
			
		||||
        assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Err(Error::Unaddressable));
 | 
			
		||||
        socket.bind(LOCAL_PORT);
 | 
			
		||||
        assert_eq!(socket.send_slice(b"abcdef",
 | 
			
		||||
                                     IpEndpoint { addr: IpAddress::Unspecified, ..REMOTE_END }),
 | 
			
		||||
                   Err(Error::Unaddressable));
 | 
			
		||||
        assert_eq!(socket.send_slice(b"abcdef",
 | 
			
		||||
                                     IpEndpoint { port: 0, ..REMOTE_END }),
 | 
			
		||||
                   Err(Error::Unaddressable));
 | 
			
		||||
        assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Ok(6));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn test_send_truncated() {
 | 
			
		||||
        let mut socket = socket(buffer(0), buffer(1));
 | 
			
		||||
        socket.bind(LOCAL_END);
 | 
			
		||||
        assert_eq!(socket.send_slice(&[0; 32][..], REMOTE_END), Err(Error::Truncated));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn test_send_dispatch() {
 | 
			
		||||
        let limits = DeviceLimits::default();
 | 
			
		||||
 | 
			
		||||
        let mut socket = socket(buffer(0), buffer(1));
 | 
			
		||||
        socket.bind(LOCAL_END);
 | 
			
		||||
 | 
			
		||||
        assert!(socket.can_send());
 | 
			
		||||
        assert_eq!(socket.dispatch(0, &limits, &mut |ip_repr, ip_payload| {
 | 
			
		||||
            unreachable!()
 | 
			
		||||
        }), Err(Error::Exhausted) as Result<()>);
 | 
			
		||||
 | 
			
		||||
        assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Ok(6));
 | 
			
		||||
        assert_eq!(socket.send_slice(b"123456", REMOTE_END), Err(Error::Exhausted));
 | 
			
		||||
        assert!(!socket.can_send());
 | 
			
		||||
 | 
			
		||||
        macro_rules! assert_payload_eq {
 | 
			
		||||
            ($ip_repr:expr, $ip_payload:expr, $expected:expr) => {{
 | 
			
		||||
                let mut buffer = vec![0; $ip_payload.buffer_len()];
 | 
			
		||||
                $ip_payload.emit($ip_repr, &mut buffer);
 | 
			
		||||
                let udp_packet = UdpPacket::new_checked(&buffer).unwrap();
 | 
			
		||||
                let udp_repr = UdpRepr::parse(&udp_packet, &LOCAL_IP, &REMOTE_IP).unwrap();
 | 
			
		||||
                assert_eq!(&udp_repr, $expected)
 | 
			
		||||
            }}
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        assert_eq!(socket.dispatch(0, &limits, &mut |ip_repr, ip_payload| {
 | 
			
		||||
            assert_eq!(ip_repr, &LOCAL_IP_REPR);
 | 
			
		||||
            assert_payload_eq!(ip_repr, ip_payload, &LOCAL_UDP_REPR);
 | 
			
		||||
            Err(Error::Unaddressable)
 | 
			
		||||
        }), Err(Error::Unaddressable) as Result<()>);
 | 
			
		||||
        /*assert!(!socket.can_send());*/
 | 
			
		||||
 | 
			
		||||
        assert_eq!(socket.dispatch(0, &limits, &mut |ip_repr, ip_payload| {
 | 
			
		||||
            assert_eq!(ip_repr, &LOCAL_IP_REPR);
 | 
			
		||||
            assert_payload_eq!(ip_repr, ip_payload, &LOCAL_UDP_REPR);
 | 
			
		||||
            Ok(())
 | 
			
		||||
        }), /*Ok(())*/ Err(Error::Exhausted));
 | 
			
		||||
        assert!(socket.can_send());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const REMOTE_IP_REPR: IpRepr = IpRepr::Ipv4(Ipv4Repr {
 | 
			
		||||
        src_addr: Ipv4Address([10, 0, 0, 2]),
 | 
			
		||||
        dst_addr: Ipv4Address([10, 0, 0, 1]),
 | 
			
		||||
        protocol: IpProtocol::Udp,
 | 
			
		||||
        payload_len: 8 + 6
 | 
			
		||||
    });
 | 
			
		||||
    const REMOTE_UDP_REPR: UdpRepr = UdpRepr {
 | 
			
		||||
        src_port: REMOTE_PORT,
 | 
			
		||||
        dst_port: LOCAL_PORT,
 | 
			
		||||
        payload: b"abcdef"
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn test_recv_process() {
 | 
			
		||||
        let mut socket = socket(buffer(1), buffer(0));
 | 
			
		||||
        socket.bind(LOCAL_PORT);
 | 
			
		||||
        assert!(!socket.can_recv());
 | 
			
		||||
 | 
			
		||||
        let mut buffer = vec![0; REMOTE_UDP_REPR.buffer_len()];
 | 
			
		||||
        REMOTE_UDP_REPR.emit(&mut UdpPacket::new(&mut buffer), &LOCAL_IP, &REMOTE_IP);
 | 
			
		||||
 | 
			
		||||
        assert_eq!(socket.recv(), Err(Error::Exhausted));
 | 
			
		||||
        assert_eq!(socket.process(0, &REMOTE_IP_REPR, &buffer),
 | 
			
		||||
                   Ok(()));
 | 
			
		||||
        assert!(socket.can_recv());
 | 
			
		||||
 | 
			
		||||
        assert_eq!(socket.process(0, &REMOTE_IP_REPR, &buffer),
 | 
			
		||||
                   Err(Error::Exhausted));
 | 
			
		||||
        assert_eq!(socket.recv(), Ok((&b"abcdef"[..], REMOTE_END)));
 | 
			
		||||
        assert!(!socket.can_recv());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn test_recv_truncated_slice() {
 | 
			
		||||
        let mut socket = socket(buffer(1), buffer(0));
 | 
			
		||||
        socket.bind(LOCAL_PORT);
 | 
			
		||||
 | 
			
		||||
        let mut buffer = vec![0; REMOTE_UDP_REPR.buffer_len()];
 | 
			
		||||
        REMOTE_UDP_REPR.emit(&mut UdpPacket::new(&mut buffer), &LOCAL_IP, &REMOTE_IP);
 | 
			
		||||
        assert_eq!(socket.process(0, &REMOTE_IP_REPR, &buffer), Ok(()));
 | 
			
		||||
 | 
			
		||||
        let mut slice = [0; 4];
 | 
			
		||||
        assert_eq!(socket.recv_slice(&mut slice[..]), Ok((4, REMOTE_END)));
 | 
			
		||||
        assert_eq!(&slice, b"abcd");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn test_recv_truncated_packet() {
 | 
			
		||||
        let mut socket = socket(buffer(1), buffer(0));
 | 
			
		||||
        socket.bind(LOCAL_PORT);
 | 
			
		||||
 | 
			
		||||
        let udp_repr = UdpRepr { payload: &[0; 100][..], ..REMOTE_UDP_REPR };
 | 
			
		||||
        let mut buffer = vec![0; udp_repr.buffer_len()];
 | 
			
		||||
        udp_repr.emit(&mut UdpPacket::new(&mut buffer), &LOCAL_IP, &REMOTE_IP);
 | 
			
		||||
        assert_eq!(socket.process(0, &REMOTE_IP_REPR, &buffer),
 | 
			
		||||
                   Err(Error::Truncated));
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -50,29 +50,56 @@ impl<'a, T: 'a> RingBuffer<'a, T> {
 | 
			
		||||
 | 
			
		||||
    /// Enqueue an element into the buffer, and return a pointer to it, or return
 | 
			
		||||
    /// `Err(Error::Exhausted)` if the buffer is full.
 | 
			
		||||
    pub fn enqueue(&mut self) -> Result<&mut T> {
 | 
			
		||||
        if self.full() {
 | 
			
		||||
            Err(Error::Exhausted)
 | 
			
		||||
        } else {
 | 
			
		||||
    pub fn enqueue<'b>(&'b mut self) -> Result<&'b mut T> {
 | 
			
		||||
        if self.full() { return Err(Error::Exhausted) }
 | 
			
		||||
 | 
			
		||||
        let index = self.mask(self.read_at + self.length);
 | 
			
		||||
            let result = &mut self.storage[index];
 | 
			
		||||
        self.length += 1;
 | 
			
		||||
        Ok(&mut self.storage[index])
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Call `f` with a buffer element, and enqueue the element if `f` returns successfully, or
 | 
			
		||||
    /// return `Err(Error::Exhausted)` if the buffer is full.
 | 
			
		||||
    pub fn try_enqueue<'b, R, F>(&'b mut self, f: F) -> Result<R>
 | 
			
		||||
            where F: Fn(&'b mut T) -> Result<R> {
 | 
			
		||||
        if self.full() { return Err(Error::Exhausted) }
 | 
			
		||||
 | 
			
		||||
        let index = self.mask(self.read_at + self.length);
 | 
			
		||||
        match f(&mut self.storage[index]) {
 | 
			
		||||
            Ok(result) => {
 | 
			
		||||
                self.length += 1;
 | 
			
		||||
                Ok(result)
 | 
			
		||||
            }
 | 
			
		||||
            Err(error) => Err(error)
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Dequeue an element from the buffer, and return a mutable reference to it, or return
 | 
			
		||||
    /// `Err(Error::Exhausted)` if the buffer is empty.
 | 
			
		||||
    pub fn dequeue(&mut self) -> Result<&mut T> {
 | 
			
		||||
        if self.empty() {
 | 
			
		||||
            Err(Error::Exhausted)
 | 
			
		||||
        } else {
 | 
			
		||||
            self.length -= 1;
 | 
			
		||||
        if self.empty() { return Err(Error::Exhausted) }
 | 
			
		||||
 | 
			
		||||
        let read_at = self.read_at;
 | 
			
		||||
        self.length -= 1;
 | 
			
		||||
        self.read_at = self.incr(self.read_at);
 | 
			
		||||
            let result = &mut self.storage[read_at];
 | 
			
		||||
        Ok(&mut self.storage[read_at])
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Call `f` with a buffer element, and dequeue the element if `f` returns successfully, or
 | 
			
		||||
    /// return `Err(Error::Exhausted)` if the buffer is empty.
 | 
			
		||||
    pub fn try_dequeue<'b, R, F>(&'b mut self, f: F) -> Result<R>
 | 
			
		||||
            where F: Fn(&'b mut T) -> Result<R> {
 | 
			
		||||
        if self.empty() { return Err(Error::Exhausted) }
 | 
			
		||||
 | 
			
		||||
        let next_at = self.incr(self.read_at);
 | 
			
		||||
        match f(&mut self.storage[self.read_at]) {
 | 
			
		||||
            Ok(result) => {
 | 
			
		||||
                self.length -= 1;
 | 
			
		||||
                self.read_at = next_at;
 | 
			
		||||
                Ok(result)
 | 
			
		||||
            }
 | 
			
		||||
            Err(error) => Err(error)
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -86,33 +113,69 @@ mod test {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    pub fn test_buffer() {
 | 
			
		||||
        const TEST_BUFFER_SIZE: usize = 5;
 | 
			
		||||
    const SIZE: usize = 5;
 | 
			
		||||
 | 
			
		||||
    fn buffer() -> RingBuffer<'static, usize> {
 | 
			
		||||
        let mut storage = vec![];
 | 
			
		||||
        for i in 0..TEST_BUFFER_SIZE {
 | 
			
		||||
        for i in 0..SIZE {
 | 
			
		||||
            storage.push(i + 10);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        let mut ring_buffer = RingBuffer::new(&mut storage[..]);
 | 
			
		||||
        assert!(ring_buffer.empty());
 | 
			
		||||
        assert!(!ring_buffer.full());
 | 
			
		||||
        assert_eq!(ring_buffer.dequeue(), Err(Error::Exhausted));
 | 
			
		||||
        ring_buffer.enqueue().unwrap();
 | 
			
		||||
        assert!(!ring_buffer.empty());
 | 
			
		||||
        assert!(!ring_buffer.full());
 | 
			
		||||
        for i in 1..TEST_BUFFER_SIZE {
 | 
			
		||||
            *ring_buffer.enqueue().unwrap() = i;
 | 
			
		||||
            assert!(!ring_buffer.empty());
 | 
			
		||||
        RingBuffer::new(storage)
 | 
			
		||||
    }
 | 
			
		||||
        assert!(ring_buffer.full());
 | 
			
		||||
        assert_eq!(ring_buffer.enqueue(), Err(Error::Exhausted));
 | 
			
		||||
 | 
			
		||||
        for i in 0..TEST_BUFFER_SIZE {
 | 
			
		||||
            assert_eq!(*ring_buffer.dequeue().unwrap(), i);
 | 
			
		||||
            assert!(!ring_buffer.full());
 | 
			
		||||
    #[test]
 | 
			
		||||
    pub fn test_buffer() {
 | 
			
		||||
        let mut buf = buffer();
 | 
			
		||||
        assert!(buf.empty());
 | 
			
		||||
        assert!(!buf.full());
 | 
			
		||||
        assert_eq!(buf.dequeue(), Err(Error::Exhausted));
 | 
			
		||||
 | 
			
		||||
        buf.enqueue().unwrap();
 | 
			
		||||
        assert!(!buf.empty());
 | 
			
		||||
        assert!(!buf.full());
 | 
			
		||||
 | 
			
		||||
        for i in 1..SIZE {
 | 
			
		||||
            *buf.enqueue().unwrap() = i;
 | 
			
		||||
            assert!(!buf.empty());
 | 
			
		||||
        }
 | 
			
		||||
        assert_eq!(ring_buffer.dequeue(), Err(Error::Exhausted));
 | 
			
		||||
        assert!(ring_buffer.empty());
 | 
			
		||||
        assert!(buf.full());
 | 
			
		||||
        assert_eq!(buf.enqueue(), Err(Error::Exhausted));
 | 
			
		||||
 | 
			
		||||
        for i in 0..SIZE {
 | 
			
		||||
            assert_eq!(*buf.dequeue().unwrap(), i);
 | 
			
		||||
            assert!(!buf.full());
 | 
			
		||||
        }
 | 
			
		||||
        assert_eq!(buf.dequeue(), Err(Error::Exhausted));
 | 
			
		||||
        assert!(buf.empty());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    pub fn test_buffer_try() {
 | 
			
		||||
        let mut buf = buffer();
 | 
			
		||||
        assert!(buf.empty());
 | 
			
		||||
        assert!(!buf.full());
 | 
			
		||||
        assert_eq!(buf.try_dequeue(|_| unreachable!()) as Result<()>,
 | 
			
		||||
                   Err(Error::Exhausted));
 | 
			
		||||
 | 
			
		||||
        buf.try_enqueue(|e| Ok(e)).unwrap();
 | 
			
		||||
        assert!(!buf.empty());
 | 
			
		||||
        assert!(!buf.full());
 | 
			
		||||
 | 
			
		||||
        for i in 1..SIZE {
 | 
			
		||||
            buf.try_enqueue(|e| Ok(*e = i)).unwrap();
 | 
			
		||||
            assert!(!buf.empty());
 | 
			
		||||
        }
 | 
			
		||||
        assert!(buf.full());
 | 
			
		||||
        assert_eq!(buf.try_enqueue(|_| unreachable!()) as Result<()>,
 | 
			
		||||
                   Err(Error::Exhausted));
 | 
			
		||||
 | 
			
		||||
        for i in 0..SIZE {
 | 
			
		||||
            assert_eq!(buf.try_dequeue(|e| Ok(*e)).unwrap(), i);
 | 
			
		||||
            assert!(!buf.full());
 | 
			
		||||
        }
 | 
			
		||||
        assert_eq!(buf.try_dequeue(|_| unreachable!()) as Result<()>,
 | 
			
		||||
                   Err(Error::Exhausted));
 | 
			
		||||
        assert!(buf.empty());
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -124,6 +124,14 @@ impl Endpoint {
 | 
			
		||||
    pub fn is_specified(&self) -> bool {
 | 
			
		||||
        !self.addr.is_unspecified() && self.port != 0
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Query whether `self` should accept packets from `other`.
 | 
			
		||||
    ///
 | 
			
		||||
    /// Returns `true` if `other` is a specified endpoint and `self` either
 | 
			
		||||
    /// has an unspecified address, or the addresses are equal.
 | 
			
		||||
    pub fn accepts(&self, other: &Endpoint) -> bool {
 | 
			
		||||
        other.is_specified() && (self.addr == other.addr || self.addr.is_unspecified())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl fmt::Display for Endpoint {
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user