wip(mysql): connection/establish, WriteExt, BufExt, and settle on AsyncRuntime

This commit is contained in:
Ryan Leckey
2021-01-01 12:28:48 -08:00
parent 44c175bb19
commit 55a8e7ba29
39 changed files with 4693 additions and 156 deletions

View File

@@ -26,16 +26,22 @@ blocking = []
# abstract async feature
# not meant to be used directly
# activates several crates used in all async runtimes
async = ["futures-util"]
async = ["futures-util", "futures-io"]
# async runtimes
async-std = ["async", "_async-std"]
actix = ["async", "actix-rt", "tokio_02"]
tokio = ["async", "_tokio"]
actix = ["async", "actix-rt", "tokio_02", "async-compat_02"]
tokio = ["async", "_tokio", "async-compat"]
[dependencies]
actix-rt = { version = "1.1.1", optional = true }
_async-std = { version = "1.8.0", optional = true, package = "async-std" }
futures-util = { version = "0.3.8", optional = true }
_tokio = { version = "1.0.1", optional = true, package = "tokio", features = ["net"] }
tokio_02 = { version = "0.2.24", optional = true, package = "tokio", features = ["net"] }
actix-rt = { version = "1.1", optional = true }
_async-std = { version = "1.8", optional = true, package = "async-std" }
futures-util = { version = "0.3", optional = true, features = ["io"] }
_tokio = { version = "1.0", optional = true, package = "tokio", features = ["net"] }
tokio_02 = { version = "0.2", optional = true, package = "tokio", features = ["net"] }
async-compat = { version = "*", git = "https://github.com/taiki-e/async-compat", branch = "tokio1", optional = true }
async-compat_02 = { version = "0.1", optional = true, package = "async-compat" }
futures-io = { version = "0.3", optional = true }
bytes = "1.0"
string = { version = "0.2.1", default-features = false }
memchr = "2.3"

View File

@@ -10,9 +10,8 @@ pub use options::ConnectOptions;
pub use runtime::{Blocking, Runtime};
pub mod prelude {
pub use crate::Database as _;
pub use super::ConnectOptions as _;
pub use super::Connection as _;
pub use super::Runtime as _;
pub use crate::Database as _;
}

View File

@@ -20,8 +20,7 @@ where
where
Self: Sized,
{
url.parse::<<Self as crate::Connection<Rt>>::Options>()?
.connect()
url.parse::<<Self as crate::Connection<Rt>>::Options>()?.connect()
}
/// Explicitly close this database connection.

View File

@@ -1,8 +1,8 @@
use crate::{ConnectOptions, Database, DefaultRuntime, Runtime};
#[cfg(feature = "async")]
use futures_util::future::BoxFuture;
use crate::{ConnectOptions, Database, DefaultRuntime, Runtime};
/// A unique connection (session) with a specific database.
///
/// With a client/server model, this is equivalent to a network connection
@@ -49,7 +49,8 @@ where
fn connect(url: &str) -> BoxFuture<'_, crate::Result<Self>>
where
Self: Sized,
Rt: crate::Async,
Rt: crate::AsyncRuntime,
<Rt as Runtime>::TcpStream: futures_io::AsyncRead + futures_io::AsyncWrite + Unpin,
{
let options = url.parse::<Self::Options>();
Box::pin(async move { options?.connect().await })
@@ -64,7 +65,8 @@ where
#[cfg(feature = "async")]
fn close(self) -> BoxFuture<'static, crate::Result<()>>
where
Rt: crate::Async;
Rt: crate::AsyncRuntime,
<Rt as Runtime>::TcpStream: futures_io::AsyncRead + futures_io::AsyncWrite + Unpin;
/// Checks if a connection to the database is still valid.
///
@@ -76,5 +78,6 @@ where
#[cfg(feature = "async")]
fn ping(&mut self) -> BoxFuture<'_, crate::Result<()>>
where
Rt: crate::Async;
Rt: crate::AsyncRuntime,
<Rt as Runtime>::TcpStream: futures_io::AsyncRead + futures_io::AsyncWrite + Unpin;
}

View File

@@ -7,10 +7,7 @@ pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
Configuration {
message: Cow<'static, str>,
source: Option<Box<dyn StdError + Send + Sync>>,
},
Configuration { message: Cow<'static, str>, source: Option<Box<dyn StdError + Send + Sync>> },
Network(std::io::Error),
}
@@ -21,18 +18,12 @@ impl Error {
message: impl Into<Cow<'static, str>>,
source: impl Into<Box<dyn StdError + Send + Sync>>,
) -> Self {
Self::Configuration {
message: message.into(),
source: Some(source.into()),
}
Self::Configuration { message: message.into(), source: Some(source.into()) }
}
#[doc(hidden)]
pub fn configuration_msg(message: impl Into<Cow<'static, str>>) -> Self {
Self::Configuration {
message: message.into(),
source: None,
}
Self::Configuration { message: message.into(), source: None }
}
}
@@ -41,15 +32,11 @@ impl Display for Error {
match self {
Self::Network(source) => write!(f, "network: {}", source),
Self::Configuration {
message,
source: None,
} => write!(f, "configuration: {}", message),
Self::Configuration { message, source: None } => {
write!(f, "configuration: {}", message)
}
Self::Configuration {
message,
source: Some(source),
} => {
Self::Configuration { message, source: Some(source) } => {
write!(f, "configuration: {}: {}", message, source)
}
}
@@ -59,10 +46,7 @@ impl Display for Error {
impl StdError for Error {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
match self {
Self::Configuration {
source: Some(source),
..
} => Some(&**source),
Self::Configuration { source: Some(source), .. } => Some(&**source),
Self::Network(source) => Some(source),
@@ -76,3 +60,9 @@ impl From<std::io::Error> for Error {
Error::Network(error)
}
}
impl From<std::io::ErrorKind> for Error {
fn from(error: std::io::ErrorKind) -> Self {
Error::Network(error.into())
}
}

11
sqlx-core/src/io.rs Normal file
View File

@@ -0,0 +1,11 @@
mod buf;
mod write;
mod buf_stream;
mod deserialize;
mod serialize;
pub use buf::BufExt;
pub use write::WriteExt;
pub use buf_stream::BufStream;
pub use deserialize::Deserialize;
pub use serialize::Serialize;

72
sqlx-core/src/io/buf.rs Normal file
View File

@@ -0,0 +1,72 @@
use std::io;
use bytes::{Bytes, Buf};
use memchr::memchr;
use string::String;
// UNSAFE: _unchecked string methods
// intended for use when the protocol is *known* to always produce
// valid UTF-8 data
pub trait BufExt: Buf {
#[allow(unsafe_code)]
unsafe fn get_str_unchecked(&mut self, n: usize) -> String<Bytes>;
#[allow(unsafe_code)]
unsafe fn get_str_nul_unchecked(&mut self) -> io::Result<String<Bytes>>;
}
impl BufExt for Bytes {
#[allow(unsafe_code)]
unsafe fn get_str_unchecked(&mut self, n: usize) -> String<Bytes> {
String::from_utf8_unchecked(self.split_to(n))
}
#[allow(unsafe_code)]
unsafe fn get_str_nul_unchecked(&mut self) -> io::Result<String<Bytes>> {
let nul = memchr(b'\0', self).ok_or(io::ErrorKind::InvalidData)?;
Ok(String::from_utf8_unchecked(self.split_to(nul + 1).slice(..nul)))
}
}
#[cfg(test)]
mod tests {
use std::io;
use bytes::{Bytes, Buf};
use super::BufExt;
#[test]
fn test_get_str() {
let mut buf = Bytes::from_static(b"Hello World\0");
#[allow(unsafe_code)]
let s = unsafe { buf.get_str_unchecked(5) };
buf.advance(1);
#[allow(unsafe_code)]
let s2 = unsafe { buf.get_str_unchecked(5) };
assert_eq!(&s, "Hello");
assert_eq!(&s2, "World");
}
#[test]
fn test_get_str_nul() -> io::Result<()> {
let mut buf = Bytes::from_static(b"Hello\0 World\0");
#[allow(unsafe_code)]
let s = unsafe { buf.get_str_nul_unchecked()? };
buf.advance(1);
#[allow(unsafe_code)]
let s2 = unsafe { buf.get_str_nul_unchecked()? };
assert_eq!(&s, "Hello");
assert_eq!(&s2, "World");
Ok(())
}
}

View File

@@ -0,0 +1,92 @@
#[cfg(feature = "blocking")]
use std::io::{Read, Write};
use std::slice::from_raw_parts_mut;
use bytes::{BufMut, Bytes, BytesMut};
#[cfg(feature = "async")]
use futures_io::{AsyncRead, AsyncWrite};
#[cfg(feature = "async")]
use futures_util::{AsyncReadExt, AsyncWriteExt};
/// Wraps a stream and buffers input and output to and from it.
///
/// It can be excessively inefficient to work directly with a `Read` or `Write`. For example,
/// every call to `read` or `write` on `TcpStream` results in a system call (leading to
/// a network interaction). `BufStream` keeps a read and write buffer with infrequent calls
/// to `read` and `write` on the underlying stream.
///
pub struct BufStream<S> {
stream: S,
// (r)ead buffer
rbuf: BytesMut,
// (w)rite buffer
wbuf: Vec<u8>,
// offset into [wbuf] that a previous write operation has written into
wbuf_offset: usize,
}
impl<S> BufStream<S> {
pub fn with_capacity(stream: S, read: usize, write: usize) -> Self {
Self {
stream,
rbuf: BytesMut::with_capacity(read),
wbuf: Vec::with_capacity(write),
wbuf_offset: 0,
}
}
pub fn get(&self, offset: usize, n: usize) -> &[u8] {
&(self.rbuf.as_ref())[offset..(offset + n)]
}
pub fn take(&mut self, n: usize) -> Bytes {
self.rbuf.split_to(n).freeze()
}
pub fn consume(&mut self, n: usize) {
let _ = self.take(n);
}
}
#[cfg(feature = "async")]
impl<S> BufStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
pub async fn read_async(&mut self, n: usize) -> crate::Result<()> {
// // before waiting to receive data
// // ensure that the write buffer is flushed
// if !self.wbuf.is_empty() {
// self.flush().await?;
// }
// while our read buffer is too small to satisfy the requested amount
while self.rbuf.len() < n {
// ensure that there is room in the read buffer
self.rbuf.reserve(n.max(128));
#[allow(unsafe_code)]
unsafe {
// prepare a chunk of uninitialized memory to write to
// this is UB if the Read impl of the stream reads from the write buffer
let b = self.rbuf.chunk_mut();
let b = from_raw_parts_mut(b.as_mut_ptr(), b.len());
// read as much as we can and return when the stream or our buffer is exhausted
let n = self.stream.read(b).await?;
// [!] read more than the length of our buffer
debug_assert!(n <= b.len());
// update the len of the read buffer to let the safe world that its okay
// to look at these bytes now
self.rbuf.advance_mut(n);
}
}
Ok(())
}
}

View File

@@ -0,0 +1,13 @@
use bytes::Bytes;
pub trait Deserialize<'de, Cx = ()>: Sized {
#[inline]
fn deserialize(buf: Bytes) -> crate::Result<Self>
where
Self: Deserialize<'de, ()>,
{
Self::deserialize_with(buf, ())
}
fn deserialize_with(buf: Bytes, context: Cx) -> crate::Result<Self>;
}

View File

@@ -0,0 +1,11 @@
pub trait Serialize<'ser, Cx = ()>: Sized {
#[inline]
fn serialize(&self, buf: &mut Vec<u8>) -> crate::Result<()>
where
Self: Serialize<'ser, ()>,
{
self.serialize_with(buf, ())
}
fn serialize_with(&self, buf: &mut Vec<u8>, context: Cx) -> crate::Result<()>;
}

34
sqlx-core/src/io/write.rs Normal file
View File

@@ -0,0 +1,34 @@
pub trait WriteExt {
fn write_str_nul(&mut self, s: &str);
fn write_maybe_str_nul(&mut self, s: Option<&str>);
}
impl WriteExt for Vec<u8> {
fn write_str_nul(&mut self, s: &str) {
self.reserve(s.len() + 1);
self.extend_from_slice(s.as_bytes());
self.push(0);
}
fn write_maybe_str_nul(&mut self, s: Option<&str>) {
if let Some(s) = s {
self.reserve(s.len() + 1);
self.extend_from_slice(s.as_bytes());
}
self.push(0);
}
}
#[cfg(test)]
mod tests {
use super::WriteExt;
#[test]
fn write_str() {
let mut buf = Vec::new();
buf.write_str_nul("this is a random dice roll");
assert_eq!(&buf, b"this is a random dice roll\0");
}
}

View File

@@ -33,6 +33,9 @@ mod error;
mod options;
mod runtime;
#[doc(hidden)]
pub mod io;
#[cfg(feature = "blocking")]
pub mod blocking;
@@ -40,19 +43,15 @@ pub use connection::Connection;
pub use database::{Database, HasOutput};
pub use error::{Error, Result};
pub use options::ConnectOptions;
pub use runtime::Runtime;
#[cfg(feature = "async-std")]
pub use runtime::AsyncStd;
#[cfg(feature = "tokio")]
pub use runtime::Tokio;
#[cfg(feature = "actix")]
pub use runtime::Actix;
#[cfg(feature = "async")]
pub use runtime::Async;
pub use runtime::AsyncRuntime;
#[cfg(feature = "async-std")]
pub use runtime::AsyncStd;
pub use runtime::Runtime;
#[cfg(feature = "tokio")]
pub use runtime::Tokio;
// pick a default runtime
// this is so existing applications in SQLx pre 0.6 work and to
@@ -78,11 +77,10 @@ pub type DefaultRuntime = blocking::Blocking;
pub type DefaultRuntime = ();
pub mod prelude {
#[cfg(all(not(feature = "async"), feature = "blocking"))]
pub use super::blocking::prelude::*;
pub use super::ConnectOptions as _;
pub use super::Connection as _;
pub use super::Database as _;
pub use super::Runtime as _;
#[cfg(all(not(feature = "async"), feature = "blocking"))]
pub use super::blocking::prelude::*;
}

View File

@@ -17,5 +17,6 @@ where
fn connect(&self) -> futures_util::future::BoxFuture<'_, crate::Result<Self::Connection>>
where
Self::Connection: Sized,
Rt: crate::Async;
Rt: crate::AsyncRuntime,
<Rt as Runtime>::TcpStream: futures_io::AsyncRead + futures_io::AsyncWrite + Unpin;
}

View File

@@ -7,41 +7,35 @@ mod actix;
#[cfg(feature = "tokio")]
mod tokio;
#[cfg(feature = "async-std")]
pub use self::async_std::AsyncStd;
#[cfg(feature = "tokio")]
pub use self::tokio::Tokio;
#[cfg(feature = "actix")]
pub use self::actix::Actix;
#[cfg(feature = "async-std")]
pub use self::async_std::AsyncStd;
#[cfg(feature = "tokio")]
pub use self::tokio::Tokio;
/// Describes a set of types and functions used to open and manage
/// resources within SQLx.
pub trait Runtime: 'static + Send + Sync {
type TcpStream: Send;
}
#[cfg(feature = "async")]
pub trait AsyncRuntime: Runtime
where
Self::TcpStream: futures_io::AsyncRead,
{
/// Opens a TCP connection to a remote host at the specified port.
#[cfg(feature = "async")]
#[allow(unused_variables)]
fn connect_tcp(
host: &str,
port: u16,
) -> futures_util::future::BoxFuture<'_, std::io::Result<Self::TcpStream>>
where
Self: Async,
{
// re-implemented for async runtimes
// for sync runtimes, this cannot be implemented but the compiler
// with guarantee it won't be called
// see: https://github.com/rust-lang/rust/issues/48214
unimplemented!()
}
) -> futures_util::future::BoxFuture<'_, std::io::Result<Self::TcpStream>>;
}
/// Marker trait that identifies a `Runtime` as supporting asynchronous I/O.
#[cfg(feature = "async")]
pub trait Async: Runtime {}
pub trait AsyncRead {
fn read(&mut self, buf: &mut [u8]) -> futures_util::future::BoxFuture<'_, u64>;
}
// when the async feature is not specified, this is an empty trait
// we implement `()` for it to allow the lib to still compile

View File

@@ -1,9 +1,10 @@
use std::io;
use actix_rt::net::TcpStream;
use async_compat_02::Compat;
use futures_util::{future::BoxFuture, FutureExt};
use crate::{Async, Runtime};
use crate::{AsyncRuntime, Runtime};
/// Actix SQLx runtime. Uses [`actix-rt`][actix_rt] to provide [`Runtime`].
///
@@ -15,12 +16,15 @@ use crate::{Async, Runtime};
#[derive(Debug)]
pub struct Actix;
impl Async for Actix {}
impl Runtime for Actix {
type TcpStream = TcpStream;
type TcpStream = Compat<TcpStream>;
}
impl AsyncRuntime for Actix
where
Self::TcpStream: futures_io::AsyncRead,
{
fn connect_tcp(host: &str, port: u16) -> BoxFuture<'_, io::Result<Self::TcpStream>> {
TcpStream::connect((host, port)).boxed()
TcpStream::connect((host, port)).map_ok(Compat::new).boxed()
}
}

View File

@@ -3,18 +3,18 @@ use std::io;
use async_std::{net::TcpStream, task::block_on};
use futures_util::{future::BoxFuture, FutureExt};
use crate::{Async, Runtime};
use crate::{AsyncRuntime, Runtime};
/// [`async-std`](async_std) implementation of [`Runtime`].
#[cfg_attr(doc_cfg, doc(cfg(feature = "async-std")))]
#[derive(Debug)]
pub struct AsyncStd;
impl Async for AsyncStd {}
impl Runtime for AsyncStd {
type TcpStream = TcpStream;
}
impl AsyncRuntime for AsyncStd {
fn connect_tcp(host: &str, port: u16) -> BoxFuture<'_, io::Result<Self::TcpStream>> {
TcpStream::connect((host, port)).boxed()
}
@@ -23,6 +23,6 @@ impl Runtime for AsyncStd {
#[cfg(feature = "blocking")]
impl crate::blocking::Runtime for AsyncStd {
fn connect_tcp(host: &str, port: u16) -> io::Result<Self::TcpStream> {
block_on(<AsyncStd as Runtime>::connect_tcp(host, port))
block_on(<AsyncStd as AsyncRuntime>::connect_tcp(host, port))
}
}

View File

@@ -1,9 +1,10 @@
use std::io;
use futures_util::{future::BoxFuture, FutureExt};
use async_compat::Compat;
use futures_util::{future::BoxFuture, FutureExt, TryFutureExt};
use tokio::net::TcpStream;
use crate::{Async, Runtime};
use crate::{AsyncRuntime, Runtime};
/// Tokio SQLx runtime. Uses [`tokio`] to provide [`Runtime`].
///
@@ -13,12 +14,12 @@ use crate::{Async, Runtime};
#[derive(Debug)]
pub struct Tokio;
impl Async for Tokio {}
impl Runtime for Tokio {
type TcpStream = TcpStream;
type TcpStream = Compat<TcpStream>;
}
impl AsyncRuntime for Tokio {
fn connect_tcp(host: &str, port: u16) -> BoxFuture<'_, io::Result<Self::TcpStream>> {
TcpStream::connect((host, port)).boxed()
TcpStream::connect((host, port)).map_ok(Compat::new).boxed()
}
}