net: introduce split on UnixDatagram (#2557)

This commit is contained in:
cssivision 2020-07-24 13:03:47 +08:00 committed by GitHub
parent 7a60a0b362
commit ff7125ec7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 189 additions and 0 deletions

View File

@ -2,12 +2,14 @@ use crate::future::poll_fn;
use crate::io::PollEvented;
use std::convert::TryFrom;
use std::error::Error;
use std::fmt;
use std::io;
use std::net::Shutdown;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::{self, SocketAddr};
use std::path::Path;
use std::sync::Arc;
use std::task::{Context, Poll};
cfg_uds! {
@ -201,6 +203,12 @@ impl UnixDatagram {
pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
self.io.get_ref().shutdown(how)
}
/// Split a `UnixDatagram` into a receive half and a send half, which can be used
/// to receice and send the datagram concurrently.
pub fn into_split(self) -> (OwnedRecvHalf, OwnedSendHalf) {
split_owned(self)
}
}
impl TryFrom<UnixDatagram> for mio_uds::UnixDatagram {
@ -240,3 +248,133 @@ impl AsRawFd for UnixDatagram {
self.io.get_ref().as_raw_fd()
}
}
fn split_owned(socket: UnixDatagram) -> (OwnedRecvHalf, OwnedSendHalf) {
let shared = Arc::new(socket);
let send = shared.clone();
let recv = shared;
(
OwnedRecvHalf { inner: recv },
OwnedSendHalf {
inner: send,
shutdown_on_drop: true,
},
)
}
/// The send half after [`split`](UnixDatagram::into_split).
///
/// Use [`send_to`](#method.send_to) or [`send`](#method.send) to send
/// datagrams.
#[derive(Debug)]
pub struct OwnedSendHalf {
inner: Arc<UnixDatagram>,
shutdown_on_drop: bool,
}
/// The recv half after [`split`](UnixDatagram::into_split).
///
/// Use [`recv_from`](#method.recv_from) or [`recv`](#method.recv) to receive
/// datagrams.
#[derive(Debug)]
pub struct OwnedRecvHalf {
inner: Arc<UnixDatagram>,
}
/// Error indicating two halves were not from the same socket, and thus could
/// not be `reunite`d.
#[derive(Debug)]
pub struct ReuniteError(pub OwnedSendHalf, pub OwnedRecvHalf);
impl fmt::Display for ReuniteError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"tried to reunite halves that are not from the same socket"
)
}
}
impl Error for ReuniteError {}
fn reunite(s: OwnedSendHalf, r: OwnedRecvHalf) -> Result<UnixDatagram, ReuniteError> {
if Arc::ptr_eq(&s.inner, &r.inner) {
s.forget();
// Only two instances of the `Arc` are ever created, one for the
// receiver and one for the sender, and those `Arc`s are never exposed
// externally. And so when we drop one here, the other one must be the
// only remaining one.
Ok(Arc::try_unwrap(r.inner).expect("unixdatagram: try_unwrap failed in reunite"))
} else {
Err(ReuniteError(s, r))
}
}
impl OwnedRecvHalf {
/// Attempts to put the two "halves" of a `UnixDatagram` back together and
/// recover the original socket. Succeeds only if the two "halves"
/// originated from the same call to `UnixDatagram::split`.
pub fn reunite(self, other: OwnedSendHalf) -> Result<UnixDatagram, ReuniteError> {
reunite(other, self)
}
/// Receives data from the socket.
pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
poll_fn(|cx| self.inner.poll_recv_from_priv(cx, buf)).await
}
/// Receives data from the socket.
pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
poll_fn(|cx| self.inner.poll_recv_priv(cx, buf)).await
}
}
impl OwnedSendHalf {
/// Attempts to put the two "halves" of a `UnixDatagram` back together and
/// recover the original socket. Succeeds only if the two "halves"
/// originated from the same call to `UnixDatagram::split`.
pub fn reunite(self, other: OwnedRecvHalf) -> Result<UnixDatagram, ReuniteError> {
reunite(self, other)
}
/// Sends data on the socket to the specified address.
pub async fn send_to<P>(&mut self, buf: &[u8], target: P) -> io::Result<usize>
where
P: AsRef<Path> + Unpin,
{
poll_fn(|cx| self.inner.poll_send_to_priv(cx, buf, target.as_ref())).await
}
/// Sends data on the socket to the socket's peer.
pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
poll_fn(|cx| self.inner.poll_send_priv(cx, buf)).await
}
/// Destroy the send half, but don't close the stream until the recvice half
/// is dropped. If the read half has already been dropped, this closes the
/// stream.
pub fn forget(mut self) {
self.shutdown_on_drop = false;
drop(self);
}
}
impl Drop for OwnedSendHalf {
fn drop(&mut self) {
if self.shutdown_on_drop {
let _ = self.inner.shutdown(Shutdown::Both);
}
}
}
impl AsRef<UnixDatagram> for OwnedSendHalf {
fn as_ref(&self) -> &UnixDatagram {
&self.inner
}
}
impl AsRef<UnixDatagram> for OwnedRecvHalf {
fn as_ref(&self) -> &UnixDatagram {
&self.inner
}
}

View File

@ -1,6 +1,7 @@
//! Unix domain socket utility types
pub(crate) mod datagram;
pub use datagram::{OwnedRecvHalf, OwnedSendHalf, ReuniteError};
mod incoming;
pub use incoming::Incoming;

View File

@ -3,6 +3,7 @@
#![cfg(unix)]
use tokio::net::UnixDatagram;
use tokio::try_join;
use std::io;
@ -41,3 +42,52 @@ async fn echo() -> io::Result<()> {
Ok(())
}
#[tokio::test]
async fn split() -> std::io::Result<()> {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("split.sock");
let socket = UnixDatagram::bind(path.clone())?;
let (mut r, mut s) = socket.into_split();
let msg = b"hello";
let ((), ()) = try_join! {
async {
s.send_to(msg, path).await?;
io::Result::Ok(())
},
async {
let mut recv_buf = [0u8; 32];
let (len, _) = r.recv_from(&mut recv_buf[..]).await?;
assert_eq!(&recv_buf[..len], msg);
Ok(())
},
}?;
Ok(())
}
#[tokio::test]
async fn reunite() -> std::io::Result<()> {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("reunite.sock");
let socket = UnixDatagram::bind(path)?;
let (s, r) = socket.into_split();
assert!(s.reunite(r).is_ok());
Ok(())
}
#[tokio::test]
async fn reunite_error() -> std::io::Result<()> {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("reunit.sock");
let dir = tempfile::tempdir().unwrap();
let path1 = dir.path().join("reunit.sock");
let socket = UnixDatagram::bind(path)?;
let socket1 = UnixDatagram::bind(path1)?;
let (s, _) = socket.into_split();
let (_, r1) = socket1.into_split();
assert!(s.reunite(r1).is_err());
Ok(())
}