util: make UdpFramed take Borrow<UdpSocket> (#3451)

This commit is contained in:
Evan Cameron 2021-04-14 14:16:23 -04:00 committed by GitHub
parent 39706b198c
commit 9eeec039f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 63 additions and 19 deletions

View File

@ -6,9 +6,12 @@ use tokio::{io::ReadBuf, net::UdpSocket};
use bytes::{BufMut, BytesMut};
use futures_core::ready;
use futures_sink::Sink;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{
borrow::Borrow,
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
};
use std::{io, mem::MaybeUninit};
/// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using
@ -34,8 +37,8 @@ use std::{io, mem::MaybeUninit};
#[must_use = "sinks do nothing unless polled"]
#[cfg_attr(docsrs, doc(all(feature = "codec", feature = "udp")))]
#[derive(Debug)]
pub struct UdpFramed<C> {
socket: UdpSocket,
pub struct UdpFramed<C, T = UdpSocket> {
socket: T,
codec: C,
rd: BytesMut,
wr: BytesMut,
@ -48,7 +51,13 @@ pub struct UdpFramed<C> {
const INITIAL_RD_CAPACITY: usize = 64 * 1024;
const INITIAL_WR_CAPACITY: usize = 8 * 1024;
impl<C: Decoder + Unpin> Stream for UdpFramed<C> {
impl<C, T> Unpin for UdpFramed<C, T> {}
impl<C, T> Stream for UdpFramed<C, T>
where
T: Borrow<UdpSocket>,
C: Decoder,
{
type Item = Result<(C::Item, SocketAddr), C::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
@ -79,7 +88,7 @@ impl<C: Decoder + Unpin> Stream for UdpFramed<C> {
let buf = &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]);
let mut read = ReadBuf::uninit(buf);
let ptr = read.filled().as_ptr();
let res = ready!(Pin::new(&mut pin.socket).poll_recv_from(cx, &mut read));
let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read));
assert_eq!(ptr, read.filled().as_ptr());
let addr = res?;
@ -93,7 +102,11 @@ impl<C: Decoder + Unpin> Stream for UdpFramed<C> {
}
}
impl<I, C: Encoder<I> + Unpin> Sink<(I, SocketAddr)> for UdpFramed<C> {
impl<I, C, T> Sink<(I, SocketAddr)> for UdpFramed<C, T>
where
T: Borrow<UdpSocket>,
C: Encoder<I>,
{
type Error = C::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -125,13 +138,13 @@ impl<I, C: Encoder<I> + Unpin> Sink<(I, SocketAddr)> for UdpFramed<C> {
}
let Self {
ref mut socket,
ref socket,
ref mut out_addr,
ref mut wr,
..
} = *self;
let n = ready!(socket.poll_send_to(cx, &wr, *out_addr))?;
let n = ready!(socket.borrow().poll_send_to(cx, &wr, *out_addr))?;
let wrote_all = n == self.wr.len();
self.wr.clear();
@ -156,11 +169,14 @@ impl<I, C: Encoder<I> + Unpin> Sink<(I, SocketAddr)> for UdpFramed<C> {
}
}
impl<C> UdpFramed<C> {
impl<C, T> UdpFramed<C, T>
where
T: Borrow<UdpSocket>,
{
/// Create a new `UdpFramed` backed by the given socket and codec.
///
/// See struct level documentation for more details.
pub fn new(socket: UdpSocket, codec: C) -> UdpFramed<C> {
pub fn new(socket: T, codec: C) -> UdpFramed<C, T> {
Self {
socket,
codec,
@ -180,27 +196,21 @@ impl<C> UdpFramed<C> {
/// Care should be taken to not tamper with the underlying stream of data
/// coming in as it may corrupt the stream of frames otherwise being worked
/// with.
pub fn get_ref(&self) -> &UdpSocket {
pub fn get_ref(&self) -> &T {
&self.socket
}
/// Returns a mutable reference to the underlying I/O stream wrapped by
/// `Framed`.
/// Returns a mutable reference to the underlying I/O stream wrapped by `Framed`.
///
/// # Note
///
/// Care should be taken to not tamper with the underlying stream of data
/// coming in as it may corrupt the stream of frames otherwise being worked
/// with.
pub fn get_mut(&mut self) -> &mut UdpSocket {
pub fn get_mut(&mut self) -> &mut T {
&mut self.socket
}
/// Consumes the `Framed`, returning its underlying I/O stream.
pub fn into_inner(self) -> UdpSocket {
self.socket
}
/// Returns a reference to the underlying codec wrapped by
/// `Framed`.
///
@ -228,4 +238,9 @@ impl<C> UdpFramed<C> {
pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
&mut self.rd
}
/// Consumes the `Framed`, returning its underlying I/O stream.
pub fn into_inner(self) -> T {
self.socket
}
}

View File

@ -10,6 +10,7 @@ use futures::future::try_join;
use futures::future::FutureExt;
use futures::sink::SinkExt;
use std::io;
use std::sync::Arc;
#[cfg_attr(any(target_os = "macos", target_os = "ios"), allow(unused_assignments))]
#[tokio::test]
@ -101,3 +102,31 @@ async fn send_framed_lines_codec() -> std::io::Result<()> {
Ok(())
}
#[tokio::test]
async fn framed_half() -> std::io::Result<()> {
let a_soc = Arc::new(UdpSocket::bind("127.0.0.1:0").await?);
let b_soc = a_soc.clone();
let a_addr = a_soc.local_addr()?;
let b_addr = b_soc.local_addr()?;
let mut a = UdpFramed::new(a_soc, ByteCodec);
let mut b = UdpFramed::new(b_soc, LinesCodec::new());
let msg = b"1\r\n2\r\n3\r\n".to_vec();
a.send((&msg, b_addr)).await?;
let msg = b"4\r\n5\r\n6\r\n".to_vec();
a.send((&msg, b_addr)).await?;
assert_eq!(b.next().await.unwrap().unwrap(), ("1".to_string(), a_addr));
assert_eq!(b.next().await.unwrap().unwrap(), ("2".to_string(), a_addr));
assert_eq!(b.next().await.unwrap().unwrap(), ("3".to_string(), a_addr));
assert_eq!(b.next().await.unwrap().unwrap(), ("4".to_string(), a_addr));
assert_eq!(b.next().await.unwrap().unwrap(), ("5".to_string(), a_addr));
assert_eq!(b.next().await.unwrap().unwrap(), ("6".to_string(), a_addr));
Ok(())
}