From d3cfa6fccdbcc25498bda24102fca28467eddae0 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Sun, 10 Jan 2021 11:51:01 -0800 Subject: [PATCH] fix(mock): re-enable mock feature --- sqlx-core/src/lib.rs | 6 +- sqlx-core/src/mock.rs | 167 +++++++++++++++++++++++ sqlx-core/src/runtime.rs | 6 - sqlx-core/src/runtime/mock.rs | 189 --------------------------- sqlx-mysql/src/connection/connect.rs | 2 +- sqlx-mysql/src/mock.rs | 19 ++- 6 files changed, 180 insertions(+), 209 deletions(-) create mode 100644 sqlx-core/src/mock.rs delete mode 100644 sqlx-core/src/runtime/mock.rs diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index bac0b99e..39c951a4 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -30,6 +30,10 @@ mod runtime; #[doc(hidden)] pub mod io; +#[doc(hidden)] +#[cfg(feature = "_mock")] +pub mod mock; + #[cfg(feature = "blocking")] pub mod blocking; @@ -47,8 +51,6 @@ pub use pool::Pool; pub use runtime::Actix; #[cfg(feature = "async-std")] pub use runtime::AsyncStd; -// #[cfg(feature = "_mock")] -// pub use mock::Mock; #[cfg(feature = "tokio")] pub use runtime::Tokio; pub use runtime::{Async, DefaultRuntime, Runtime}; diff --git a/sqlx-core/src/mock.rs b/sqlx-core/src/mock.rs new file mode 100644 index 00000000..b6862292 --- /dev/null +++ b/sqlx-core/src/mock.rs @@ -0,0 +1,167 @@ +use std::collections::HashMap; +use std::io; +#[cfg(feature = "async")] +use std::pin::Pin; +use std::sync::atomic::AtomicU16; +use std::sync::atomic::Ordering; + +use bytes::BytesMut; +use conquer_once::Lazy; +use crossbeam::channel; +use parking_lot::RwLock; + +#[cfg(feature = "blocking")] +use crate::blocking; +use crate::{io::Stream, Runtime}; + +#[derive(Debug)] +pub struct Mock; + +#[derive(Debug)] +#[allow(clippy::module_name_repetitions)] +pub struct MockStream { + port: u16, + rbuf: BytesMut, + read: channel::Receiver>, + write: channel::Sender>, +} + +static MOCK_STREAM_PORT: AtomicU16 = AtomicU16::new(0); + +static MOCK_STREAMS: Lazy>> = Lazy::new(RwLock::default); + +impl Runtime for Mock { + type TcpStream = MockStream; +} + +#[cfg(feature = "async")] +impl crate::Async for Mock { + fn connect_tcp_async( + _host: &str, + port: u16, + ) -> futures_util::future::BoxFuture<'_, io::Result> { + Box::pin(futures_util::future::ready(Self::get_stream(port))) + } +} + +#[cfg(feature = "blocking")] +impl crate::blocking::Runtime for Mock { + fn connect_tcp(_host: &str, port: u16) -> io::Result { + Self::get_stream(port) + } +} + +impl Mock { + #[must_use] + pub fn stream() -> MockStream { + let port = MOCK_STREAM_PORT.fetch_add(1, Ordering::SeqCst) + 1; + + let (write_l, write_r) = channel::unbounded(); + let (read_r, read_l) = channel::unbounded(); + + let stream_l = MockStream { port, read: read_l, write: write_l, rbuf: BytesMut::new() }; + let stream_r = MockStream { port, read: write_r, write: read_r, rbuf: BytesMut::new() }; + + MOCK_STREAMS.write().insert(port, stream_l); + + stream_r + } + + fn get_stream(port: u16) -> io::Result { + match MOCK_STREAMS.write().remove(&port) { + Some(stream) => Ok(stream), + None => Err(io::ErrorKind::ConnectionRefused.into()), + } + } +} + +impl MockStream { + #[must_use] + pub const fn port(&self) -> u16 { + self.port + } +} + +impl<'s> Stream<'s, Mock> for MockStream { + #[cfg(feature = "async")] + type ReadFuture = Pin> + 's + Send>>; + + #[cfg(feature = "async")] + type WriteFuture = Pin> + 's + Send>>; + + #[cfg(feature = "async")] + fn read_async(&'s mut self, mut buf: &'s mut [u8]) -> Self::ReadFuture { + Box::pin(async move { + use io::Write; + + loop { + if !self.rbuf.is_empty() { + // write as much data from our read buffer as we can + let written = buf.write(&self.rbuf)?; + + // remove the bytes that we were able to write + let _ = self.rbuf.split_to(written); + + // return how many bytes we wrote + return Ok(written); + } + + // no bytes in the buffer, ask the channel for more + let message = if let Ok(message) = self.read.try_recv() { + message + } else { + // no data, return pending (and immediately wake again to run try_recv again) + futures_util::pending!(); + continue; + }; + + self.rbuf.extend_from_slice(&message); + // loop around and now send out this message + } + }) + } + + #[cfg(feature = "async")] + fn write_async(&'s mut self, buf: &'s [u8]) -> Self::WriteFuture { + // send it all, right away + let _ = self.write.send(buf.to_vec()); + + // that was easy + Box::pin(futures_util::future::ok(buf.len())) + } +} + +#[cfg(feature = "blocking")] +impl<'s> blocking::io::Stream<'s, Mock> for MockStream { + fn read(&'s mut self, mut buf: &'s mut [u8]) -> io::Result { + use io::Write; + + loop { + if !self.rbuf.is_empty() { + // write as much data from our read buffer as we can + let written = buf.write(&self.rbuf)?; + + // remove the bytes that we were able to write + let _ = self.rbuf.split_to(written); + + // return how many bytes we wrote + return Ok(written); + } + + // no bytes in the buffer, ask the channel for more + #[allow(clippy::map_err_ignore)] + let message = self.read.recv().map_err(|_err| io::ErrorKind::ConnectionAborted)?; + + self.rbuf.extend_from_slice(&message); + // loop around and now send out this message + } + } + + fn write(&'s mut self, buf: &'s [u8]) -> io::Result { + // send it all, right away + let _ = self.write.send(buf.to_vec()); + + // that was easy + Ok(buf.len()) + } +} diff --git a/sqlx-core/src/runtime.rs b/sqlx-core/src/runtime.rs index 9f7f4925..e36cc9c7 100644 --- a/sqlx-core/src/runtime.rs +++ b/sqlx-core/src/runtime.rs @@ -1,9 +1,5 @@ use crate::io::Stream; -// #[cfg(feature = "_mock")] -// #[doc(hidden)] -// pub mod mock; - #[cfg(feature = "async-std")] #[path = "runtime/async_std.rs"] mod async_std_; @@ -20,8 +16,6 @@ mod tokio_; pub use actix_::Actix; #[cfg(feature = "async-std")] pub use async_std_::AsyncStd; -// #[cfg(feature = "_mock")] -// pub use mock::Mock; #[cfg(feature = "tokio")] pub use tokio_::Tokio; diff --git a/sqlx-core/src/runtime/mock.rs b/sqlx-core/src/runtime/mock.rs deleted file mode 100644 index 22af2de3..00000000 --- a/sqlx-core/src/runtime/mock.rs +++ /dev/null @@ -1,189 +0,0 @@ -use std::collections::HashMap; -#[cfg(feature = "async")] -use std::pin::Pin; -use std::sync::atomic::AtomicU16; -use std::sync::atomic::Ordering; -#[cfg(feature = "async")] -use std::task::{Context, Poll}; - -use bytes::BytesMut; -use conquer_once::Lazy; -use crossbeam::channel; -use parking_lot::RwLock; - -use crate::Runtime; - -#[derive(Debug)] -#[doc(hidden)] -pub struct Mock; - -#[derive(Debug)] -#[doc(hidden)] -#[allow(clippy::module_name_repetitions)] -pub struct MockStream { - port: u16, - rbuf: BytesMut, - read: channel::Receiver>, - write: channel::Sender>, -} - -static MOCK_STREAM_PORT: AtomicU16 = AtomicU16::new(0); - -static MOCK_STREAMS: Lazy>> = Lazy::new(RwLock::default); - -impl Runtime for Mock { - type TcpStream = MockStream; -} - -#[cfg(feature = "async")] -impl crate::AsyncRuntime for Mock { - fn connect_tcp( - _host: &str, - port: u16, - ) -> futures_util::future::BoxFuture<'_, std::io::Result> { - Box::pin(match MOCK_STREAMS.write().remove(&port) { - Some(stream) => futures_util::future::ok(stream), - None => futures_util::future::err(std::io::ErrorKind::ConnectionRefused.into()), - }) - } -} - -#[cfg(feature = "blocking")] -impl crate::blocking::Runtime for Mock { - fn connect_tcp(_host: &str, port: u16) -> std::io::Result { - match MOCK_STREAMS.write().remove(&port) { - Some(stream) => Ok(stream), - None => Err(std::io::ErrorKind::ConnectionRefused.into()), - } - } -} - -impl Mock { - #[must_use] - pub fn stream() -> MockStream { - let port = MOCK_STREAM_PORT.fetch_add(1, Ordering::SeqCst) + 1; - - let (write_l, write_r) = channel::unbounded(); - let (read_r, read_l) = channel::unbounded(); - - let stream_l = MockStream { port, read: read_l, write: write_l, rbuf: BytesMut::new() }; - let stream_r = MockStream { port, read: write_r, write: read_r, rbuf: BytesMut::new() }; - - MOCK_STREAMS.write().insert(port, stream_l); - - stream_r - } -} - -impl MockStream { - #[must_use] - pub const fn port(&self) -> u16 { - self.port - } -} - -#[cfg(feature = "blocking")] -impl std::io::Read for MockStream { - fn read(&mut self, mut buf: &mut [u8]) -> std::io::Result { - use std::io::Write; - - loop { - if !self.rbuf.is_empty() { - // write as much data from our read buffer as we can - let written = buf.write(&self.rbuf)?; - - // remove the bytes that we were able to write - let _ = self.rbuf.split_to(written); - - // return how many bytes we wrote - return Ok(written); - } - - // no bytes in the buffer, ask the channel for more - #[allow(clippy::map_err_ignore)] - let message = self.read.recv().map_err(|_err| std::io::ErrorKind::ConnectionAborted)?; - - self.rbuf.extend_from_slice(&message); - // loop around and now send out this message - } - } -} - -#[cfg(feature = "blocking")] -impl std::io::Write for MockStream { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - // send it all, right away - let _ = self.write.send(buf.to_vec()); - - // that was easy - Ok(buf.len()) - } - - fn flush(&mut self) -> std::io::Result<()> { - // no implementation needed - // flush is inherent - Ok(()) - } -} - -#[cfg(feature = "async")] -impl futures_io::AsyncRead for MockStream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - mut buf: &mut [u8], - ) -> Poll> { - use std::io::Write; - - loop { - if !self.rbuf.is_empty() { - // write as much data from our read buffer as we can - let written = buf.write(&self.rbuf)?; - - // remove the bytes that we were able to write - let _ = self.rbuf.split_to(written); - - // return how many bytes we wrote - return Poll::Ready(Ok(written)); - } - - // no bytes in the buffer, ask the channel for more - let message = if let Ok(message) = self.read.try_recv() { - message - } else { - // no data, return pending (and immediately wake again to run try_recv again) - cx.waker().wake_by_ref(); - return Poll::Pending; - }; - - self.rbuf.extend_from_slice(&message); - // loop around and now send out this message - } - } -} - -#[cfg(feature = "async")] -impl futures_io::AsyncWrite for MockStream { - fn poll_write( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - // send it all, right away - let _ = self.write.send(buf.to_vec()); - - // that was easy - Poll::Ready(Ok(buf.len())) - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - // no implementation needed - // flush is inherent - Poll::Ready(Ok(())) - } - - fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - // nothing happens, ha - Poll::Ready(Ok(())) - } -} diff --git a/sqlx-mysql/src/connection/connect.rs b/sqlx-mysql/src/connection/connect.rs index 04d953e4..eafe5df6 100644 --- a/sqlx-mysql/src/connection/connect.rs +++ b/sqlx-mysql/src/connection/connect.rs @@ -143,7 +143,7 @@ where #[cfg(all(test, feature = "async"))] mod tests { use futures_executor::block_on; - use sqlx_core::{ConnectOptions, Mock}; + use sqlx_core::{mock::Mock, ConnectOptions}; use crate::mock::MySqlMockStreamExt; use crate::MySqlConnectOptions; diff --git a/sqlx-mysql/src/mock.rs b/sqlx-mysql/src/mock.rs index 31ec05a5..daec728f 100644 --- a/sqlx-mysql/src/mock.rs +++ b/sqlx-mysql/src/mock.rs @@ -1,5 +1,6 @@ use std::io; +use sqlx_core::io::Stream; use sqlx_core::mock::MockStream; pub(crate) trait MySqlMockStreamExt { @@ -27,12 +28,12 @@ impl MySqlMockStreamExt for MockStream { seq: u8, packet: &'x [u8], ) -> futures_util::future::BoxFuture<'x, io::Result<()>> { - use futures_util::AsyncWriteExt; - Box::pin(async move { - self.write_all(&packet.len().to_le_bytes()[..3]).await?; - self.write_all(&[seq]).await?; - self.write_all(packet).await + self.write_async(&packet.len().to_le_bytes()[..3]).await?; + self.write_async(&[seq]).await?; + self.write_async(packet).await?; + + Ok(()) }) } @@ -41,11 +42,9 @@ impl MySqlMockStreamExt for MockStream { &mut self, n: usize, ) -> futures_util::future::BoxFuture<'_, io::Result>> { - use futures_util::AsyncReadExt; - Box::pin(async move { let mut buf = vec![0; n]; - let read = self.read(&mut buf).await?; + let read = self.read_async(&mut buf).await?; buf.truncate(read); Ok(buf) @@ -54,11 +53,9 @@ impl MySqlMockStreamExt for MockStream { #[cfg(feature = "async")] fn read_all_async(&mut self) -> futures_util::future::BoxFuture<'_, io::Result>> { - use futures_util::AsyncReadExt; - Box::pin(async move { let mut buf = vec![0; 1024]; - let read = self.read(&mut buf).await?; + let read = self.read_async(&mut buf).await?; buf.truncate(read); Ok(buf)