Upgrade async runtime dependencies

Co-authored-by: Josh Toft <joshtoft@gmail.com>
Co-authored-by: Philip A Reimer <antreimer@gmail.com>
This commit is contained in:
Jonas Platte
2021-01-14 13:24:25 +01:00
committed by Ryan Leckey
parent de4a7decfb
commit c5d43db312
25 changed files with 176 additions and 323 deletions

View File

@@ -43,9 +43,9 @@ runtime-async-std-rustls = [ "sqlx-rt/runtime-async-std-rustls", "_tls-rustls",
runtime-tokio-rustls = [ "sqlx-rt/runtime-tokio-rustls", "_tls-rustls", "_rt-tokio" ]
# for conditional compilation
_rt-actix = []
_rt-actix = [ "tokio-stream" ]
_rt-async-std = []
_rt-tokio = []
_rt-tokio = [ "tokio-stream" ]
_tls-native-tls = []
_tls-rustls = [ "rustls", "webpki", "webpki-roots" ]
@@ -61,7 +61,7 @@ bigdecimal_ = { version = "0.2.0", optional = true, package = "bigdecimal" }
rust_decimal = { version = "1.8.1", optional = true }
bit-vec = { version = "0.6.2", optional = true }
bitflags = { version = "1.2.1", default-features = false }
bytes = "0.5.0"
bytes = "1.0.0"
byteorder = { version = "1.3.4", default-features = false, features = [ "std" ] }
chrono = { version = "0.4.11", default-features = false, features = [ "clock" ], optional = true }
crc = { version = "1.8.1", optional = true }
@@ -91,7 +91,7 @@ parking_lot = "0.11.0"
rand = { version = "0.7.3", default-features = false, optional = true, features = [ "std" ] }
regex = { version = "1.3.9", optional = true }
rsa = { version = "0.3.0", optional = true }
rustls = { version = "0.18.0", features = [ "dangerous_configuration" ], optional = true }
rustls = { version = "0.19.0", features = [ "dangerous_configuration" ], optional = true }
serde = { version = "1.0.106", features = [ "derive", "rc" ], optional = true }
serde_json = { version = "1.0.51", features = [ "raw_value" ], optional = true }
sha-1 = { version = "0.9.0", default-features = false, optional = true }
@@ -99,6 +99,7 @@ sha2 = { version = "0.9.0", default-features = false, optional = true }
sqlformat = "0.1.0"
thiserror = "1.0.19"
time = { version = "0.2.16", optional = true }
tokio-stream = { version = "0.1.2", features = ["fs"], optional = true }
smallvec = "1.4.0"
url = { version = "2.1.1", default-features = false }
uuid = { version = "0.8.1", default-features = false, optional = true, features = [ "std" ] }

View File

@@ -21,8 +21,8 @@ pub trait BufExt: Buf {
impl BufExt for Bytes {
fn get_bytes_nul(&mut self) -> Result<Bytes, Error> {
let nul = memchr(b'\0', self.bytes())
.ok_or_else(|| err_protocol!("expected NUL in byte sequence"))?;
let nul =
memchr(b'\0', &self).ok_or_else(|| err_protocol!("expected NUL in byte sequence"))?;
let v = self.slice(0..nul);

View File

@@ -14,9 +14,13 @@ pub trait MigrationSource<'s>: Debug {
impl<'s> MigrationSource<'s> for &'s Path {
fn resolve(self) -> BoxFuture<'s, Result<Vec<Migration>, BoxDynError>> {
Box::pin(async move {
#[allow(unused_mut)]
let mut s = fs::read_dir(self.canonicalize()?).await?;
let mut migrations = Vec::new();
#[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))]
let mut s = tokio_stream::wrappers::ReadDirStream::new(s);
while let Some(entry) = s.try_next().await? {
if !entry.metadata().await?.is_file() {
// not a file; ignore

View File

@@ -7,9 +7,8 @@ use crate::mssql::statement::MssqlStatementMetadata;
use crate::mssql::{Mssql, MssqlConnectOptions};
use crate::transaction::Transaction;
use futures_core::future::BoxFuture;
use futures_util::{future::ready, FutureExt, TryFutureExt};
use futures_util::{FutureExt, TryFutureExt};
use std::fmt::{self, Debug, Formatter};
use std::net::Shutdown;
use std::sync::Arc;
mod establish;
@@ -34,9 +33,26 @@ impl Connection for MssqlConnection {
type Options = MssqlConnectOptions;
fn close(self) -> BoxFuture<'static, Result<(), Error>> {
#[allow(unused_mut)]
fn close(mut self) -> BoxFuture<'static, Result<(), Error>> {
// NOTE: there does not seem to be a clean shutdown packet to send to MSSQL
ready(self.stream.shutdown(Shutdown::Both).map_err(Into::into)).boxed()
#[cfg(feature = "_rt-async-std")]
{
use std::future::ready;
use std::net::Shutdown;
ready(self.stream.shutdown(Shutdown::Both).map_err(Into::into)).boxed()
}
#[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))]
{
use sqlx_rt::AsyncWriteExt;
// FIXME: This is equivalent to Shutdown::Write, not Shutdown::Both like above
// https://docs.rs/tokio/1.0.1/tokio/io/trait.AsyncWriteExt.html#method.shutdown
async move { self.stream.shutdown().await.map_err(Into::into) }.boxed()
}
}
fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> {

View File

@@ -1,4 +1,4 @@
use bytes::buf::ext::Chain;
use bytes::buf::Chain;
use bytes::Bytes;
use digest::{Digest, FixedOutput};
use generic_array::GenericArray;

View File

@@ -1,3 +1,4 @@
use bytes::buf::Buf;
use bytes::Bytes;
use crate::common::StatementCache;
@@ -8,7 +9,6 @@ use crate::mysql::protocol::connect::{
};
use crate::mysql::protocol::Capabilities;
use crate::mysql::{MySqlConnectOptions, MySqlConnection, MySqlSslMode};
use bytes::buf::BufExt as _;
impl MySqlConnection {
pub(crate) async fn establish(options: &MySqlConnectOptions) -> Result<Self, Error> {

View File

@@ -50,7 +50,7 @@ impl Connection for MySqlConnection {
fn close(mut self) -> BoxFuture<'static, Result<(), Error>> {
Box::pin(async move {
self.stream.send_packet(Quit).await?;
self.stream.shutdown()?;
self.stream.shutdown().await?;
Ok(())
})

View File

@@ -1,4 +1,4 @@
use bytes::buf::ext::{BufExt as _, Chain};
use bytes::buf::Chain;
use bytes::{Buf, Bytes};
use crate::error::Error;
@@ -134,7 +134,7 @@ fn test_decode_handshake_mysql_8_0_18() {
));
assert_eq!(
&*p.auth_plugin_data.to_bytes(),
&*p.auth_plugin_data.into_iter().collect::<Vec<_>>(),
&[17, 52, 97, 66, 48, 99, 6, 103, 116, 76, 3, 115, 15, 91, 52, 13, 108, 52, 46, 32,]
);
}
@@ -187,7 +187,7 @@ fn test_decode_handshake_mariadb_10_4_7() {
));
assert_eq!(
&*p.auth_plugin_data.to_bytes(),
&*p.auth_plugin_data.into_iter().collect::<Vec<_>>(),
&[116, 54, 76, 92, 106, 34, 100, 83, 85, 49, 52, 79, 112, 104, 57, 34, 60, 72, 53, 110,]
);
}

View File

@@ -3,3 +3,15 @@ mod tls;
pub use socket::Socket;
pub use tls::MaybeTlsStream;
#[cfg(feature = "_rt-async-std")]
type PollReadBuf<'a> = [u8];
#[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))]
type PollReadBuf<'a> = sqlx_rt::ReadBuf<'a>;
#[cfg(feature = "_rt-async-std")]
type PollReadOut = usize;
#[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))]
type PollReadOut = ();

View File

@@ -1,7 +1,6 @@
#![allow(dead_code)]
use std::io;
use std::net::Shutdown;
use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
@@ -36,12 +35,29 @@ impl Socket {
))
}
pub fn shutdown(&self) -> io::Result<()> {
match self {
Socket::Tcp(s) => s.shutdown(Shutdown::Both),
pub async fn shutdown(&mut self) -> io::Result<()> {
#[cfg(feature = "_rt-async-std")]
{
use std::net::Shutdown;
#[cfg(unix)]
Socket::Unix(s) => s.shutdown(Shutdown::Both),
match self {
Socket::Tcp(s) => s.shutdown(Shutdown::Both),
#[cfg(unix)]
Socket::Unix(s) => s.shutdown(Shutdown::Both),
}
}
#[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))]
{
use sqlx_rt::AsyncWriteExt;
match self {
Socket::Tcp(s) => s.shutdown().await,
#[cfg(unix)]
Socket::Unix(s) => s.shutdown().await,
}
}
}
}
@@ -50,8 +66,8 @@ impl AsyncRead for Socket {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut super::PollReadBuf<'_>,
) -> Poll<io::Result<super::PollReadOut>> {
match &mut *self {
Socket::Tcp(s) => Pin::new(s).poll_read(cx, buf),
@@ -59,24 +75,6 @@ impl AsyncRead for Socket {
Socket::Unix(s) => Pin::new(s).poll_read(cx, buf),
}
}
#[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))]
fn poll_read_buf<B>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>>
where
Self: Sized,
B: bytes::BufMut,
{
match &mut *self {
Socket::Tcp(s) => Pin::new(s).poll_read_buf(cx, buf),
#[cfg(unix)]
Socket::Unix(s) => Pin::new(s).poll_read_buf(cx, buf),
}
}
}
impl AsyncWrite for Socket {
@@ -121,22 +119,4 @@ impl AsyncWrite for Socket {
Socket::Unix(s) => Pin::new(s).poll_close(cx),
}
}
#[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))]
fn poll_write_buf<B>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>>
where
Self: Sized,
B: bytes::Buf,
{
match &mut *self {
Socket::Tcp(s) => Pin::new(s).poll_write_buf(cx, buf),
#[cfg(unix)]
Socket::Unix(s) => Pin::new(s).poll_write_buf(cx, buf),
}
}
}

View File

@@ -114,8 +114,8 @@ where
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut super::PollReadBuf<'_>,
) -> Poll<io::Result<super::PollReadOut>> {
match &mut *self {
MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
@@ -123,24 +123,6 @@ where
MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
}
}
#[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))]
fn poll_read_buf<B>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>>
where
Self: Sized,
B: bytes::BufMut,
{
match &mut *self {
MaybeTlsStream::Raw(s) => Pin::new(s).poll_read_buf(cx, buf),
MaybeTlsStream::Tls(s) => Pin::new(s).poll_read_buf(cx, buf),
MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
}
}
}
impl<S> AsyncWrite for MaybeTlsStream<S>
@@ -188,24 +170,6 @@ where
MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
}
}
#[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))]
fn poll_write_buf<B>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>>
where
Self: Sized,
B: bytes::Buf,
{
match &mut *self {
MaybeTlsStream::Raw(s) => Pin::new(s).poll_write_buf(cx, buf),
MaybeTlsStream::Tls(s) => Pin::new(s).poll_write_buf(cx, buf),
MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
}
}
}
impl<S> Deref for MaybeTlsStream<S>

View File

@@ -122,7 +122,7 @@ impl Connection for PgConnection {
Box::pin(async move {
self.stream.send(Terminate).await?;
self.stream.shutdown()?;
self.stream.shutdown().await?;
Ok(())
})

View File

@@ -76,7 +76,7 @@ impl Decode<'_, Postgres> for BitVec {
))?;
}
let mut bitvec = BitVec::from_bytes(bytes.bytes());
let mut bitvec = BitVec::from_bytes(&bytes);
// Chop off zeroes from the back. We get bits in bytes, so if
// our bitvec is not in full bytes, extra zeroes are added to