diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 246a217b..2e7ec8dc 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -42,7 +42,7 @@ either = "1.6.1" actix-rt = { version = "2.0.0-beta.2", optional = true } _async-std = { version = "1.8", optional = true, package = "async-std" } futures-util = { version = "0.3", optional = true, features = ["io"] } -_tokio = { version = "1.0", optional = true, package = "tokio", features = ["net"] } +_tokio = { version = "1.0", optional = true, package = "tokio", features = ["net", "io-util"] } async-compat = { version = "*", git = "https://github.com/taiki-e/async-compat", branch = "tokio1", optional = true } futures-io = { version = "0.3", optional = true } futures-core = { version = "0.3", optional = true } diff --git a/sqlx-core/src/blocking/runtime.rs b/sqlx-core/src/blocking/runtime.rs index e0bf6fe9..2f6cb8c2 100644 --- a/sqlx-core/src/blocking/runtime.rs +++ b/sqlx-core/src/blocking/runtime.rs @@ -1,5 +1,5 @@ use std::io::{self, Read, Write}; -use std::net::TcpStream; +use std::net::{Shutdown, TcpStream}; #[cfg(unix)] use std::os::unix::net::UnixStream; #[cfg(unix)] @@ -65,6 +65,10 @@ impl<'s> IoStream<'s, Blocking> for TcpStream { #[cfg(feature = "async")] type WriteFuture = BoxFuture<'s, io::Result>; + #[doc(hidden)] + #[cfg(feature = "async")] + type ShutdownFuture = BoxFuture<'s, io::Result<()>>; + #[inline] #[doc(hidden)] fn read(&'s mut self, buf: &'s mut [u8]) -> io::Result { @@ -80,6 +84,12 @@ impl<'s> IoStream<'s, Blocking> for TcpStream { Ok(size) } + #[inline] + #[doc(hidden)] + fn shutdown(&'s mut self) -> io::Result<()> { + TcpStream::shutdown(self, Shutdown::Both) + } + #[doc(hidden)] #[cfg(feature = "async")] fn read_async(&'s mut self, _buf: &'s mut [u8]) -> Self::ReadFuture { @@ -93,6 +103,13 @@ impl<'s> IoStream<'s, Blocking> for TcpStream { // UNREACHABLE: where Self: Async unreachable!() } + + #[doc(hidden)] + #[cfg(feature = "async")] + fn shutdown_async(&'s mut self) -> Self::ShutdownFuture { + // UNREACHABLE: where Self: Async + unreachable!() + } } // 's: stream @@ -106,6 +123,10 @@ impl<'s> IoStream<'s, Blocking> for UnixStream { #[cfg(feature = "async")] type WriteFuture = BoxFuture<'s, io::Result>; + #[doc(hidden)] + #[cfg(feature = "async")] + type ShutdownFuture = BoxFuture<'s, io::Result<()>>; + #[inline] #[doc(hidden)] fn read(&'s mut self, buf: &'s mut [u8]) -> io::Result { @@ -121,6 +142,12 @@ impl<'s> IoStream<'s, Blocking> for UnixStream { Ok(size) } + #[inline] + #[doc(hidden)] + fn shutdown(&'s mut self) -> io::Result<()> { + UnixStream::shutdown(self, Shutdown::Both) + } + #[doc(hidden)] #[cfg(feature = "async")] #[allow(unused_variables)] @@ -136,4 +163,11 @@ impl<'s> IoStream<'s, Blocking> for UnixStream { // UNREACHABLE: where Self: Async unreachable!() } + + #[doc(hidden)] + #[cfg(feature = "async")] + fn shutdown_async(&'s mut self) -> Self::ShutdownFuture { + // UNREACHABLE: where Self: Async + unreachable!() + } } diff --git a/sqlx-core/src/io/buf_stream.rs b/sqlx-core/src/io/buf_stream.rs index f0043d1f..554adbed 100644 --- a/sqlx-core/src/io/buf_stream.rs +++ b/sqlx-core/src/io/buf_stream.rs @@ -1,4 +1,5 @@ use std::marker::PhantomData; +use std::ops::{Deref, DerefMut}; use bytes::{Bytes, BytesMut}; @@ -155,7 +156,11 @@ macro_rules! flush { } #[cfg(feature = "async")] -impl Stream<'s, Rt>> BufStream { +impl BufStream +where + Rt: crate::Async, + S: for<'s> Stream<'s, Rt>, +{ pub async fn flush_async(&mut self) -> crate::Result<()> { flush!(self) } @@ -166,7 +171,11 @@ impl Stream<'s, Rt>> BufStream { } #[cfg(feature = "blocking")] -impl Stream<'s, Rt>> BufStream { +impl BufStream +where + Rt: crate::blocking::Runtime, + S: for<'s> Stream<'s, Rt>, +{ pub fn flush(&mut self) -> crate::Result<()> { flush!(@blocking self) } @@ -175,3 +184,17 @@ impl Stream<'s, Rt>> BufStream { read!(@blocking self, offset, n) } } + +impl Deref for BufStream { + type Target = S; + + fn deref(&self) -> &Self::Target { + &self.stream + } +} + +impl DerefMut for BufStream { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.stream + } +} diff --git a/sqlx-core/src/io/stream.rs b/sqlx-core/src/io/stream.rs index 878663c7..e54df1c4 100644 --- a/sqlx-core/src/io/stream.rs +++ b/sqlx-core/src/io/stream.rs @@ -15,6 +15,9 @@ where #[cfg(feature = "async")] type WriteFuture: 's + Future> + Send; + #[cfg(feature = "async")] + type ShutdownFuture: 's + Future> + Send; + #[cfg(feature = "async")] #[doc(hidden)] fn read_async(&'s mut self, buf: &'s mut [u8]) -> Self::ReadFuture @@ -27,6 +30,12 @@ where where Rt: crate::Async; + #[cfg(feature = "async")] + #[doc(hidden)] + fn shutdown_async(&'s mut self) -> Self::ShutdownFuture + where + Rt: crate::Async; + #[cfg(feature = "blocking")] #[doc(hidden)] fn read(&'s mut self, buf: &'s mut [u8]) -> io::Result @@ -38,6 +47,12 @@ where fn write(&'s mut self, buf: &'s [u8]) -> io::Result where Rt: crate::blocking::Runtime; + + #[cfg(feature = "blocking")] + #[doc(hidden)] + fn shutdown(&'s mut self) -> io::Result<()> + where + Rt: crate::blocking::Runtime; } #[cfg(not(any( @@ -56,6 +71,9 @@ where #[cfg(feature = "async")] type WriteFuture = futures_util::future::BoxFuture<'s, io::Result>; + #[cfg(feature = "async")] + type ShutdownFuture = futures_util::future::BoxFuture<'s, io::Result<()>>; + #[cfg(feature = "async")] #[doc(hidden)] #[allow(unused_variables)] @@ -71,4 +89,14 @@ where // UNREACHABLE: where Self: Async unreachable!() } + + #[cfg(feature = "async")] + #[doc(hidden)] + fn shutdown_async(&'s mut self) -> Self::ShutdownFuture + where + Rt: crate::Async, + { + // UNREACHABLE: where Self: Async + unreachable!() + } } diff --git a/sqlx-core/src/mock.rs b/sqlx-core/src/mock.rs index 2bbd1b19..84937208 100644 --- a/sqlx-core/src/mock.rs +++ b/sqlx-core/src/mock.rs @@ -107,6 +107,9 @@ impl<'s> IoStream<'s, Mock> for MockStream { #[cfg(feature = "async")] type WriteFuture = BoxFuture<'s, io::Result>; + #[cfg(feature = "async")] + type ShutdownFuture = future::Ready>; + #[cfg(feature = "async")] fn read_async(&'s mut self, mut buf: &'s mut [u8]) -> Self::ReadFuture { Box::pin(async move { @@ -148,6 +151,11 @@ impl<'s> IoStream<'s, Mock> for MockStream { Box::pin(future::ok(buf.len())) } + #[cfg(feature = "async")] + fn shutdown_async(&'s mut self) -> Self::ShutdownFuture { + future::ok(()) + } + #[cfg(feature = "blocking")] fn read(&'s mut self, mut buf: &'s mut [u8]) -> io::Result { use io::Write; @@ -181,4 +189,9 @@ impl<'s> IoStream<'s, Mock> for MockStream { // that was easy Ok(buf.len()) } + + #[cfg(feature = "blocking")] + fn shutdown(&'s mut self) -> io::Result<()> { + Ok(()) + } } diff --git a/sqlx-core/src/net/stream.rs b/sqlx-core/src/net/stream.rs index b39ab2f7..d0ce6f07 100644 --- a/sqlx-core/src/net/stream.rs +++ b/sqlx-core/src/net/stream.rs @@ -41,6 +41,7 @@ where )), } } + #[cfg(feature = "blocking")] pub fn connect(address: Either<&(String, u16), &PathBuf>) -> io::Result where @@ -80,7 +81,13 @@ where >::WriteFuture, >; - #[inline] + #[doc(hidden)] + #[cfg(feature = "async")] + type ShutdownFuture = future::Either< + >::ShutdownFuture, + >::ShutdownFuture, + >; + #[doc(hidden)] #[cfg(feature = "async")] fn read_async(&'s mut self, buf: &'s mut [u8]) -> Self::ReadFuture @@ -93,7 +100,6 @@ where } } - #[inline] #[doc(hidden)] #[cfg(feature = "async")] fn write_async(&'s mut self, buf: &'s [u8]) -> Self::WriteFuture @@ -106,7 +112,18 @@ where } } - #[inline] + #[doc(hidden)] + #[cfg(feature = "async")] + fn shutdown_async(&'s mut self) -> Self::ShutdownFuture + where + Rt: crate::Async, + { + match self { + Self::Tcp(stream) => stream.shutdown_async().left_future(), + Self::Unix(stream) => stream.shutdown_async().right_future(), + } + } + #[doc(hidden)] #[cfg(feature = "blocking")] fn read(&'s mut self, buf: &'s mut [u8]) -> io::Result @@ -119,7 +136,6 @@ where } } - #[inline] #[doc(hidden)] #[cfg(feature = "blocking")] fn write(&'s mut self, buf: &'s [u8]) -> io::Result @@ -131,6 +147,18 @@ where Self::Unix(stream) => stream.write(buf), } } + + #[doc(hidden)] + #[cfg(feature = "blocking")] + fn shutdown(&'s mut self) -> io::Result<()> + where + Rt: crate::blocking::Runtime, + { + match self { + Self::Tcp(stream) => stream.shutdown(), + Self::Unix(stream) => stream.shutdown(), + } + } } #[cfg(not(unix))] @@ -146,7 +174,10 @@ where #[cfg(feature = "async")] type WriteFuture = >::WriteFuture; - #[inline] + #[doc(hidden)] + #[cfg(feature = "async")] + type ShutdownFuture = >::ShutdownFuture; + #[doc(hidden)] #[cfg(feature = "async")] fn read_async(&'s mut self, buf: &'s mut [u8]) -> Self::ReadFuture @@ -158,7 +189,6 @@ where } } - #[inline] #[doc(hidden)] #[cfg(feature = "async")] fn write_async(&'s mut self, buf: &'s [u8]) -> Self::WriteFuture @@ -170,7 +200,18 @@ where } } - #[inline] + #[doc(hidden)] + #[cfg(feature = "async")] + fn shutdown_async(&'s mut self) -> Self::ShutdownFuture + where + Rt: crate::Async, + { + match self { + Self::Tcp(stream) => stream.shutdown_async().left_future(), + Self::Unix(stream) => stream.shutdown_async().right_future(), + } + } + #[doc(hidden)] #[cfg(feature = "blocking")] fn read(&'s mut self, buf: &'s mut [u8]) -> io::Result @@ -182,7 +223,6 @@ where } } - #[inline] #[doc(hidden)] #[cfg(feature = "blocking")] fn write(&'s mut self, buf: &'s [u8]) -> io::Result @@ -193,4 +233,15 @@ where Self::Tcp(stream) => stream.write(buf), } } + + #[doc(hidden)] + #[cfg(feature = "blocking")] + fn shutdown(&'s mut self) -> io::Result + where + Rt: crate::blocking::Runtime, + { + match self { + Self::Tcp(stream) => stream.shutdown(buf), + } + } } diff --git a/sqlx-core/src/runtime/actix.rs b/sqlx-core/src/runtime/actix.rs index 71bbd9c5..d46ca8d4 100644 --- a/sqlx-core/src/runtime/actix.rs +++ b/sqlx-core/src/runtime/actix.rs @@ -67,19 +67,24 @@ impl<'s> Stream<'s, Actix> for Compat { #[doc(hidden)] type WriteFuture = Write<'s, Self>; - #[inline] + #[doc(hidden)] + type ShutdownFuture = BoxFuture<'s, io::Result<()>>; + #[doc(hidden)] fn read_async(&'s mut self, buf: &'s mut [u8]) -> Self::ReadFuture { AsyncReadExt::read(self, buf) } - #[inline] #[doc(hidden)] fn write_async(&'s mut self, buf: &'s [u8]) -> Self::WriteFuture { AsyncWriteExt::write(self, buf) } - #[inline] + #[doc(hidden)] + fn shutdown_async(&'s mut self) -> Self::ShutdownFuture { + _tokio::io::AsyncWriteExt::shutdown(self.get_mut()).boxed() + } + #[doc(hidden)] #[cfg(feature = "blocking")] fn read(&'s mut self, _buf: &'s mut [u8]) -> io::Result { @@ -87,13 +92,19 @@ impl<'s> Stream<'s, Actix> for Compat { unreachable!() } - #[inline] #[doc(hidden)] #[cfg(feature = "blocking")] fn write(&'s mut self, _buf: &'s [u8]) -> io::Result { // UNREACHABLE: where Self: blocking::Runtime unreachable!() } + + #[doc(hidden)] + #[cfg(feature = "blocking")] + fn shutdown(&'s mut self) -> io::Result<()> { + // UNREACHABLE: where Self: blocking::Runtime + unreachable!() + } } // 's: stream @@ -105,19 +116,24 @@ impl<'s> Stream<'s, Actix> for Compat { #[doc(hidden)] type WriteFuture = Write<'s, Self>; - #[inline] + #[doc(hidden)] + type ShutdownFuture = BoxFuture<'s, io::Result<()>>; + #[doc(hidden)] fn read_async(&'s mut self, buf: &'s mut [u8]) -> Self::ReadFuture { AsyncReadExt::read(self, buf) } - #[inline] #[doc(hidden)] fn write_async(&'s mut self, buf: &'s [u8]) -> Self::WriteFuture { AsyncWriteExt::write(self, buf) } - #[inline] + #[doc(hidden)] + fn shutdown_async(&'s mut self) -> Self::ShutdownFuture { + _tokio::io::AsyncWriteExt::shutdown(self.get_mut()).boxed() + } + #[doc(hidden)] #[cfg(feature = "blocking")] fn read(&'s mut self, _buf: &'s mut [u8]) -> io::Result { @@ -125,11 +141,17 @@ impl<'s> Stream<'s, Actix> for Compat { unreachable!() } - #[inline] #[doc(hidden)] #[cfg(feature = "blocking")] fn write(&'s mut self, _buf: &'s [u8]) -> io::Result { // UNREACHABLE: where Self: blocking::Runtime unreachable!() } + + #[doc(hidden)] + #[cfg(feature = "blocking")] + fn shutdown(&'s mut self) -> io::Result<()> { + // UNREACHABLE: where Self: blocking::Runtime + unreachable!() + } } diff --git a/sqlx-core/src/runtime/async_std.rs b/sqlx-core/src/runtime/async_std.rs index 0c22346e..c0698034 100644 --- a/sqlx-core/src/runtime/async_std.rs +++ b/sqlx-core/src/runtime/async_std.rs @@ -7,12 +7,14 @@ use _async_std::net::TcpStream; use _async_std::os::unix::net::UnixStream; #[cfg(feature = "blocking")] use _async_std::task; +use futures_util::future::{self, BoxFuture}; use futures_util::io::{Read, Write}; -use futures_util::{future::BoxFuture, AsyncReadExt, AsyncWriteExt, FutureExt}; +use futures_util::{AsyncReadExt, AsyncWriteExt, FutureExt}; #[cfg(feature = "blocking")] use crate::blocking; use crate::{io::Stream, Async, Runtime}; +use std::net::Shutdown; /// Provides [`Runtime`] for [**async-std**](https://async.rs). Supports both blocking /// and non-blocking operation. @@ -71,31 +73,41 @@ impl<'s> Stream<'s, AsyncStd> for TcpStream { #[doc(hidden)] type WriteFuture = Write<'s, Self>; - #[inline] + #[doc(hidden)] + type ShutdownFuture = future::Ready>; + #[doc(hidden)] fn read_async(&'s mut self, buf: &'s mut [u8]) -> Self::ReadFuture { AsyncReadExt::read(self, buf) } - #[inline] #[doc(hidden)] fn write_async(&'s mut self, buf: &'s [u8]) -> Self::WriteFuture { AsyncWriteExt::write(self, buf) } - #[inline] + #[doc(hidden)] + fn shutdown_async(&'s mut self) -> Self::ShutdownFuture { + future::ready(TcpStream::shutdown(self, Shutdown::Both)) + } + #[doc(hidden)] #[cfg(feature = "blocking")] fn read(&'s mut self, buf: &'s mut [u8]) -> io::Result { task::block_on(self.read_async(buf)) } - #[inline] #[doc(hidden)] #[cfg(feature = "blocking")] fn write(&'s mut self, buf: &'s [u8]) -> io::Result { task::block_on(self.write_async(buf)) } + + #[doc(hidden)] + #[cfg(feature = "blocking")] + fn shutdown(&'s mut self) -> io::Result<()> { + task::block_on(self.shutdown_async()) + } } // 's: stream @@ -107,29 +119,39 @@ impl<'s> Stream<'s, AsyncStd> for UnixStream { #[doc(hidden)] type WriteFuture = Write<'s, Self>; - #[inline] + #[doc(hidden)] + type ShutdownFuture = future::Ready>; + #[doc(hidden)] fn read_async(&'s mut self, buf: &'s mut [u8]) -> Self::ReadFuture { AsyncReadExt::read(self, buf) } - #[inline] #[doc(hidden)] fn write_async(&'s mut self, buf: &'s [u8]) -> Self::WriteFuture { AsyncWriteExt::write(self, buf) } - #[inline] + #[doc(hidden)] + fn shutdown_async(&'s mut self) -> Self::ShutdownFuture { + future::ready(UnixStream::shutdown(self, Shutdown::Both)) + } + #[doc(hidden)] #[cfg(feature = "blocking")] fn read(&'s mut self, buf: &'s mut [u8]) -> io::Result { task::block_on(self.read_async(buf)) } - #[inline] #[doc(hidden)] #[cfg(feature = "blocking")] fn write(&'s mut self, buf: &'s [u8]) -> io::Result { task::block_on(self.write_async(buf)) } + + #[doc(hidden)] + #[cfg(feature = "blocking")] + fn shutdown(&'s mut self) -> io::Result<()> { + task::block_on(self.shutdown_async()) + } } diff --git a/sqlx-core/src/runtime/tokio.rs b/sqlx-core/src/runtime/tokio.rs index 1368b4e6..eb607b78 100644 --- a/sqlx-core/src/runtime/tokio.rs +++ b/sqlx-core/src/runtime/tokio.rs @@ -65,19 +65,24 @@ impl<'s> Stream<'s, Tokio> for Compat { #[doc(hidden)] type WriteFuture = Write<'s, Self>; - #[inline] + #[doc(hidden)] + type ShutdownFuture = BoxFuture<'s, io::Result<()>>; + #[doc(hidden)] fn read_async(&'s mut self, buf: &'s mut [u8]) -> Self::ReadFuture { AsyncReadExt::read(self, buf) } - #[inline] #[doc(hidden)] fn write_async(&'s mut self, buf: &'s [u8]) -> Self::WriteFuture { AsyncWriteExt::write(self, buf) } - #[inline] + #[doc(hidden)] + fn shutdown_async(&'s mut self) -> Self::ShutdownFuture { + _tokio::io::AsyncWriteExt::shutdown(self.get_mut()).boxed() + } + #[doc(hidden)] #[cfg(feature = "blocking")] fn read(&'s mut self, _buf: &'s mut [u8]) -> io::Result { @@ -85,13 +90,19 @@ impl<'s> Stream<'s, Tokio> for Compat { unreachable!() } - #[inline] #[doc(hidden)] #[cfg(feature = "blocking")] fn write(&'s mut self, _buf: &'s [u8]) -> io::Result { // UNREACHABLE: where Self: blocking::Runtime unreachable!() } + + #[doc(hidden)] + #[cfg(feature = "blocking")] + fn shutdown(&'s mut self) -> io::Result<()> { + // UNREACHABLE: where Self: blocking::Runtime + unreachable!() + } } // 's: stream @@ -103,19 +114,24 @@ impl<'s> Stream<'s, Tokio> for Compat { #[doc(hidden)] type WriteFuture = Write<'s, Self>; - #[inline] + #[doc(hidden)] + type ShutdownFuture = BoxFuture<'s, io::Result<()>>; + #[doc(hidden)] fn read_async(&'s mut self, buf: &'s mut [u8]) -> Self::ReadFuture { AsyncReadExt::read(self, buf) } - #[inline] #[doc(hidden)] fn write_async(&'s mut self, buf: &'s [u8]) -> Self::WriteFuture { AsyncWriteExt::write(self, buf) } - #[inline] + #[doc(hidden)] + fn shutdown_async(&'s mut self) -> Self::ShutdownFuture { + _tokio::io::AsyncWriteExt::shutdown(self.get_mut()).boxed() + } + #[doc(hidden)] #[cfg(feature = "blocking")] fn read(&'s mut self, _buf: &'s mut [u8]) -> io::Result { @@ -123,11 +139,17 @@ impl<'s> Stream<'s, Tokio> for Compat { unreachable!() } - #[inline] #[doc(hidden)] #[cfg(feature = "blocking")] fn write(&'s mut self, _buf: &'s [u8]) -> io::Result { // UNREACHABLE: where Self: blocking::Runtime unreachable!() } + + #[doc(hidden)] + #[cfg(feature = "blocking")] + fn shutdown(&'s mut self) -> io::Result<()> { + // UNREACHABLE: where Self: blocking::Runtime + unreachable!() + } } diff --git a/sqlx-mysql/src/connection/close.rs b/sqlx-mysql/src/connection/close.rs index 4c351e5f..f4b877c1 100644 --- a/sqlx-mysql/src/connection/close.rs +++ b/sqlx-mysql/src/connection/close.rs @@ -1,4 +1,4 @@ -use sqlx_core::{Result, Runtime}; +use sqlx_core::{io::Stream, Result, Runtime}; use crate::protocol::Quit; @@ -13,6 +13,7 @@ where { self.write_packet(&Quit)?; self.stream.flush_async().await?; + self.stream.shutdown_async().await?; Ok(()) } @@ -24,6 +25,7 @@ where { self.write_packet(&Quit)?; self.stream.flush()?; + self.stream.shutdown()?; Ok(()) }