From 88e775dcf0ad33e1faeed08ffd8482f467417df8 Mon Sep 17 00:00:00 2001 From: Yin Guanhao Date: Tue, 9 Jul 2019 05:47:31 +0800 Subject: [PATCH] udp: UdpSocket split support (#1226) --- tokio-udp/src/lib.rs | 1 + tokio-udp/src/recv.rs | 6 +- tokio-udp/src/recv_from.rs | 6 +- tokio-udp/src/send.rs | 6 +- tokio-udp/src/send_to.rs | 6 +- tokio-udp/src/socket.rs | 54 ++++++++++++++ tokio-udp/src/split.rs | 145 +++++++++++++++++++++++++++++++++++++ tokio-udp/tests/udp.rs | 34 +++++++++ tokio/src/net.rs | 2 +- 9 files changed, 247 insertions(+), 13 deletions(-) create mode 100644 tokio-udp/src/split.rs diff --git a/tokio-udp/src/lib.rs b/tokio-udp/src/lib.rs index 2d33ede14..2fe90a8fa 100644 --- a/tokio-udp/src/lib.rs +++ b/tokio-udp/src/lib.rs @@ -27,6 +27,7 @@ mod recv_from; mod send; mod send_to; mod socket; +pub mod split; // pub use self::frame::UdpFramed; pub use self::recv::Recv; diff --git a/tokio-udp/src/recv.rs b/tokio-udp/src/recv.rs index 7e425823c..1c9075436 100644 --- a/tokio-udp/src/recv.rs +++ b/tokio-udp/src/recv.rs @@ -10,12 +10,12 @@ use std::task::{Context, Poll}; #[must_use = "futures do nothing unless polled"] #[derive(Debug)] pub struct Recv<'a, 'b> { - socket: &'a mut UdpSocket, + socket: &'a UdpSocket, buf: &'b mut [u8], } impl<'a, 'b> Recv<'a, 'b> { - pub(super) fn new(socket: &'a mut UdpSocket, buf: &'b mut [u8]) -> Self { + pub(super) fn new(socket: &'a UdpSocket, buf: &'b mut [u8]) -> Self { Self { socket, buf } } } @@ -25,6 +25,6 @@ impl<'a, 'b> Future for Recv<'a, 'b> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let Recv { socket, buf } = self.get_mut(); - Pin::new(&mut **socket).poll_recv(cx, buf) + socket.poll_recv_priv(cx, buf) } } diff --git a/tokio-udp/src/recv_from.rs b/tokio-udp/src/recv_from.rs index 4d42fd102..f4d934265 100644 --- a/tokio-udp/src/recv_from.rs +++ b/tokio-udp/src/recv_from.rs @@ -11,12 +11,12 @@ use std::task::{Context, Poll}; #[must_use = "futures do nothing unless polled"] #[derive(Debug)] pub struct RecvFrom<'a, 'b> { - socket: &'a mut UdpSocket, + socket: &'a UdpSocket, buf: &'b mut [u8], } impl<'a, 'b> RecvFrom<'a, 'b> { - pub(super) fn new(socket: &'a mut UdpSocket, buf: &'b mut [u8]) -> Self { + pub(super) fn new(socket: &'a UdpSocket, buf: &'b mut [u8]) -> Self { Self { socket, buf } } } @@ -26,6 +26,6 @@ impl<'a, 'b> Future for RecvFrom<'a, 'b> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let RecvFrom { socket, buf } = self.get_mut(); - Pin::new(&mut **socket).poll_recv_from(cx, buf) + socket.poll_recv_from_priv(cx, buf) } } diff --git a/tokio-udp/src/send.rs b/tokio-udp/src/send.rs index 77557cd2c..6debe4e3e 100644 --- a/tokio-udp/src/send.rs +++ b/tokio-udp/src/send.rs @@ -10,12 +10,12 @@ use std::task::{Context, Poll}; #[must_use = "futures do nothing unless polled"] #[derive(Debug)] pub struct Send<'a, 'b> { - socket: &'a mut UdpSocket, + socket: &'a UdpSocket, buf: &'b [u8], } impl<'a, 'b> Send<'a, 'b> { - pub(super) fn new(socket: &'a mut UdpSocket, buf: &'b [u8]) -> Self { + pub(super) fn new(socket: &'a UdpSocket, buf: &'b [u8]) -> Self { Self { socket, buf } } } @@ -25,6 +25,6 @@ impl<'a, 'b> Future for Send<'a, 'b> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let Send { socket, buf } = self.get_mut(); - Pin::new(&mut **socket).poll_send(cx, buf) + socket.poll_send_priv(cx, buf) } } diff --git a/tokio-udp/src/send_to.rs b/tokio-udp/src/send_to.rs index 96d79f569..fa1ab7637 100644 --- a/tokio-udp/src/send_to.rs +++ b/tokio-udp/src/send_to.rs @@ -11,13 +11,13 @@ use std::task::{Context, Poll}; #[must_use = "futures do nothing unless polled"] #[derive(Debug)] pub struct SendTo<'a, 'b> { - socket: &'a mut UdpSocket, + socket: &'a UdpSocket, buf: &'b [u8], target: &'b SocketAddr, } impl<'a, 'b> SendTo<'a, 'b> { - pub(super) fn new(socket: &'a mut UdpSocket, buf: &'b [u8], target: &'b SocketAddr) -> Self { + pub(super) fn new(socket: &'a UdpSocket, buf: &'b [u8], target: &'b SocketAddr) -> Self { Self { socket, buf, @@ -35,6 +35,6 @@ impl<'a, 'b> Future for SendTo<'a, 'b> { buf, target, } = self.get_mut(); - Pin::new(&mut **socket).poll_send_to(cx, buf, target) + socket.poll_send_to_priv(cx, buf, target) } } diff --git a/tokio-udp/src/socket.rs b/tokio-udp/src/socket.rs index 2a9cc4eff..55648a429 100644 --- a/tokio-udp/src/socket.rs +++ b/tokio-udp/src/socket.rs @@ -1,3 +1,4 @@ +use super::split::{split, UdpSocketRecvHalf, UdpSocketSendHalf}; use super::{Recv, RecvFrom, Send, SendTo}; use mio; use std::convert::TryFrom; @@ -42,6 +43,16 @@ impl UdpSocket { Ok(UdpSocket { io }) } + /// Split the `UdpSocket` into a receive half and a send half. The two parts + /// can be used to receive and send datagrams concurrently, even from two + /// different tasks. + /// + /// See the module level documenation of [`split`](super::split) for more + /// details. + pub fn split(self) -> (UdpSocketRecvHalf, UdpSocketSendHalf) { + split(self) + } + /// Returns the local address that this socket is bound to. pub fn local_addr(&self) -> io::Result { self.io.get_ref().local_addr() @@ -83,6 +94,24 @@ impl UdpSocket { self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], + ) -> Poll> { + self.poll_send_priv(cx, buf) + } + + // Poll IO functions that takes `&self` are provided for the split API. + // + // They are not public because (taken from the doc of `PollEvented`): + // + // While `PollEvented` is `Sync` (if the underlying I/O type is `Sync`), the + // caller must ensure that there are at most two tasks that use a + // `PollEvented` instance concurrently. One for reading and one for writing. + // While violating this requirement is "safe" from a Rust memory model point + // of view, it will result in unexpected behavior in the form of lost + // notifications and tasks hanging. + pub(crate) fn poll_send_priv( + &self, + cx: &mut Context<'_>, + buf: &[u8], ) -> Poll> { ready!(self.io.poll_write_ready(cx))?; @@ -134,6 +163,14 @@ impl UdpSocket { self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], + ) -> Poll> { + self.poll_recv_priv(cx, buf) + } + + pub(crate) fn poll_recv_priv( + &self, + cx: &mut Context<'_>, + buf: &mut [u8], ) -> Poll> { ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; @@ -173,6 +210,15 @@ impl UdpSocket { cx: &mut Context<'_>, buf: &[u8], target: &SocketAddr, + ) -> Poll> { + self.poll_send_to_priv(cx, buf, target) + } + + pub(crate) fn poll_send_to_priv( + &self, + cx: &mut Context<'_>, + buf: &[u8], + target: &SocketAddr, ) -> Poll> { ready!(self.io.poll_write_ready(cx))?; @@ -201,6 +247,14 @@ impl UdpSocket { self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], + ) -> Poll> { + self.poll_recv_from_priv(cx, buf) + } + + pub(crate) fn poll_recv_from_priv( + &self, + cx: &mut Context<'_>, + buf: &mut [u8], ) -> Poll> { ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; diff --git a/tokio-udp/src/split.rs b/tokio-udp/src/split.rs new file mode 100644 index 000000000..fc2c6001a --- /dev/null +++ b/tokio-udp/src/split.rs @@ -0,0 +1,145 @@ +//! [`UdpSocket`](../struct.UdpSocket.html) split support. +//! +//! The [`split`](../struct.UdpSocket.html#method.split) method splits a +//! `UdpSocket` into a receive half and a send half, which can be used to +//! receive and send datagrams concurrently, even from two different tasks. +//! +//! The halves provide access to the underlying socket, implementing +//! `AsRef`. This allows you to call `UdpSocket` methods that takes +//! `&self`, e.g., to get local address, to get and set socket options, to join +//! or leave multicast groups, etc. +//! +//! The halves can be reunited to the original socket with their `reunite` +//! methods. + +use super::{Recv, RecvFrom, Send, SendTo, UdpSocket}; +use std::error::Error; +use std::fmt; +use std::net::SocketAddr; +use std::sync::Arc; + +/// The send half after [`split`](super::UdpSocket::split). +/// +/// Use [`send_to`](#method.send_to) or [`send`](#method.send) to send +/// datagrams. +#[derive(Debug)] +pub struct UdpSocketSendHalf(Arc); + +/// The recv half after [`split`](super::UdpSocket::split). +/// +/// Use [`recv_from`](#method.recv_from) or [`recv`](#method.recv) to receive +/// datagrams. +#[derive(Debug)] +pub struct UdpSocketRecvHalf(Arc); + +pub(crate) fn split(socket: UdpSocket) -> (UdpSocketRecvHalf, UdpSocketSendHalf) { + let shared = Arc::new(socket); + let send = shared.clone(); + let recv = shared; + (UdpSocketRecvHalf(recv), UdpSocketSendHalf(send)) +} + +/// Error indicating two halves were not from the same socket, and thus could +/// not be `reunite`d. +#[derive(Debug)] +pub struct ReuniteError(pub UdpSocketSendHalf, pub UdpSocketRecvHalf); + +impl fmt::Display for ReuniteError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "tried to reunite halves that are not from the same socket" + ) + } +} + +impl Error for ReuniteError {} + +fn reunite(s: UdpSocketSendHalf, r: UdpSocketRecvHalf) -> Result { + if Arc::ptr_eq(&s.0, &r.0) { + drop(r); + // Only two instances of the `Arc` are ever created, one for the + // receiver and one for the sender, and those `Arc`s are never exposed + // externally. And so when we drop one here, the other one must be the + // only remaining one. + Ok(Arc::try_unwrap(s.0).expect("tokio_udp: try_unwrap failed in reunite")) + } else { + Err(ReuniteError(s, r)) + } +} + +impl UdpSocketRecvHalf { + /// Attempts to put the two "halves" of a `UdpSocket` back together and + /// recover the original socket. Succeeds only if the two "halves" + /// originated from the same call to `UdpSocket::split`. + pub fn reunite(self, other: UdpSocketSendHalf) -> Result { + reunite(other, self) + } + + /// Returns a future that receives a single datagram on the socket. On success, + /// the future resolves to the number of bytes read and the origin. + /// + /// The function must be called with valid byte array `buf` of sufficient size + /// to hold the message bytes. If a message is too long to fit in the supplied + /// buffer, excess bytes may be discarded. + pub fn recv_from<'a, 'b>(&'a mut self, buf: &'b mut [u8]) -> RecvFrom<'a, 'b> { + RecvFrom::new(&self.0, buf) + } + + /// Returns a future that receives a single datagram message on the socket from + /// the remote address to which it is connected. On success, the future will resolve + /// to the number of bytes read. + /// + /// The function must be called with valid byte array `buf` of sufficient size to + /// hold the message bytes. If a message is too long to fit in the supplied buffer, + /// excess bytes may be discarded. + /// + /// The [`connect`] method will connect this socket to a remote address. The future + /// will fail if the socket is not connected. + /// + /// [`connect`]: super::UdpSocket::connect + pub fn recv<'a, 'b>(&'a mut self, buf: &'b mut [u8]) -> Recv<'a, 'b> { + Recv::new(&self.0, buf) + } +} + +impl UdpSocketSendHalf { + /// Attempts to put the two "halves" of a `UdpSocket` back together and + /// recover the original socket. Succeeds only if the two "halves" + /// originated from the same call to `UdpSocket::split`. + pub fn reunite(self, other: UdpSocketRecvHalf) -> Result { + reunite(self, other) + } + + /// Returns a future that sends data on the socket to the given address. + /// On success, the future will resolve to the number of bytes written. + /// + /// The future will resolve to an error if the IP version of the socket does + /// not match that of `target`. + pub fn send_to<'a, 'b>(&'a mut self, buf: &'b [u8], target: &'b SocketAddr) -> SendTo<'a, 'b> { + SendTo::new(&self.0, buf, target) + } + + /// Returns a future that sends data on the socket to the remote address to which it is connected. + /// On success, the future will resolve to the number of bytes written. + /// + /// The [`connect`] method will connect this socket to a remote address. The future + /// will resolve to an error if the socket is not connected. + /// + /// [`connect`]: super::UdpSocket::connect + pub fn send<'a, 'b>(&'a mut self, buf: &'b [u8]) -> Send<'a, 'b> { + Send::new(&self.0, buf) + } +} + +impl AsRef for UdpSocketSendHalf { + fn as_ref(&self) -> &UdpSocket { + &self.0 + } +} + +impl AsRef for UdpSocketRecvHalf { + fn as_ref(&self) -> &UdpSocket { + &self.0 + } +} diff --git a/tokio-udp/tests/udp.rs b/tokio-udp/tests/udp.rs index 1739c488d..bc140daa4 100644 --- a/tokio-udp/tests/udp.rs +++ b/tokio-udp/tests/udp.rs @@ -38,6 +38,40 @@ async fn send_to_recv_from() -> std::io::Result<()> { Ok(()) } +#[tokio::test] +async fn split() -> std::io::Result<()> { + let socket = UdpSocket::bind(&"127.0.0.1:0".parse().unwrap())?; + let (mut r, mut s) = socket.split(); + + let msg = b"hello"; + let addr = s.as_ref().local_addr()?; + tokio::spawn(async move { + s.send_to(msg, &addr).await.unwrap(); + }); + let mut recv_buf = [0u8; 32]; + let (len, _) = r.recv_from(&mut recv_buf[..]).await?; + assert_eq!(&recv_buf[..len], msg); + Ok(()) +} + +#[tokio::test] +async fn reunite() -> std::io::Result<()> { + let socket = UdpSocket::bind(&"127.0.0.1:0".parse().unwrap())?; + let (s, r) = socket.split(); + assert!(s.reunite(r).is_ok()); + Ok(()) +} + +#[tokio::test] +async fn reunite_error() -> std::io::Result<()> { + let socket = UdpSocket::bind(&"127.0.0.1:0".parse().unwrap())?; + let socket1 = UdpSocket::bind(&"127.0.0.1:0".parse().unwrap())?; + let (s, _) = socket.split(); + let (_, r1) = socket1.split(); + assert!(s.reunite(r1).is_err()); + Ok(()) +} + // pub struct ByteCodec; // impl Decoder for ByteCodec { diff --git a/tokio/src/net.rs b/tokio/src/net.rs index 3b37b5371..599f85518 100644 --- a/tokio/src/net.rs +++ b/tokio/src/net.rs @@ -58,7 +58,7 @@ pub mod udp { //! [`Send`]: struct.Send.html //! [`RecvFrom`]: struct.RecvFrom.html //! [`SendTo`]: struct.SendTo.html - pub use tokio_udp::{Recv, RecvFrom, Send, SendTo, UdpSocket}; + pub use tokio_udp::{split, Recv, RecvFrom, Send, SendTo, UdpSocket}; } #[cfg(feature = "udp")] pub use self::udp::UdpSocket;