wip(mysql): connection/establish, WriteExt, BufExt, and settle on AsyncRuntime

This commit is contained in:
Ryan Leckey 2021-01-01 12:28:48 -08:00
parent 44c175bb19
commit 55a8e7ba29
No known key found for this signature in database
GPG Key ID: F8AA68C235AB08C9
39 changed files with 4693 additions and 156 deletions

2
.rustfmt.toml Normal file
View File

@ -0,0 +1,2 @@
use_small_heuristics = "Max"
group_imports = "StdExternalCrate"

1187
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -4,4 +4,5 @@ members = [
"sqlx-core",
"sqlx-mysql",
"sqlx",
"examples/quickstart"
]

2188
examples/quickstart/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,16 @@
[package]
name = "sqlx-example-quickstart"
version = "0.0.0"
license = "MIT OR Apache-2.0"
edition = "2018"
authors = [
"LaunchBadge <contact@launchbadge.com>"
]
[dependencies]
actix-web = "3.3.2"
anyhow = "1.0.36"
async-std = { version = "1.8.0", features = ["attributes"] }
#sqlx = { path = "../../sqlx", features = ["tokio", "mysql", "blocking", "async-std", "actix"] }
sqlx = { path = "../../sqlx", features = ["tokio", "mysql"] }
tokio = { version = "1.0.1", features = ["rt", "rt-multi-thread", "macros"] }

View File

@ -0,0 +1,35 @@
// #[async_std::main]
// async fn main() -> anyhow::Result<()> {
// let _stream = AsyncStd::connect_tcp("localhost", 5432).await?;
//
// Ok(())
// }
use sqlx::mysql::MySqlConnectOptions;
use sqlx::prelude::*;
// #[tokio::main]
// async fn main() -> anyhow::Result<()> {
// let mut conn = <MySqlConnection>::connect("mysql://").await?;
//
// Ok(())
// }
//
// #[async_std::main]
// async fn main() -> anyhow::Result<()> {
// let mut conn = <MySqlConnection>::builder()
// .host("loca%x91lhost")
// .port(20)
// .connect()
// .await?;
//
// Ok(())
// }
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let mut conn = <MySqlConnectOptions>::new().host("localhost").port(3306).connect().await?;
Ok(())
}

View File

@ -26,16 +26,22 @@ blocking = []
# abstract async feature
# not meant to be used directly
# activates several crates used in all async runtimes
async = ["futures-util"]
async = ["futures-util", "futures-io"]
# async runtimes
async-std = ["async", "_async-std"]
actix = ["async", "actix-rt", "tokio_02"]
tokio = ["async", "_tokio"]
actix = ["async", "actix-rt", "tokio_02", "async-compat_02"]
tokio = ["async", "_tokio", "async-compat"]
[dependencies]
actix-rt = { version = "1.1.1", optional = true }
_async-std = { version = "1.8.0", optional = true, package = "async-std" }
futures-util = { version = "0.3.8", optional = true }
_tokio = { version = "1.0.1", optional = true, package = "tokio", features = ["net"] }
tokio_02 = { version = "0.2.24", optional = true, package = "tokio", features = ["net"] }
actix-rt = { version = "1.1", optional = true }
_async-std = { version = "1.8", optional = true, package = "async-std" }
futures-util = { version = "0.3", optional = true, features = ["io"] }
_tokio = { version = "1.0", optional = true, package = "tokio", features = ["net"] }
tokio_02 = { version = "0.2", optional = true, package = "tokio", features = ["net"] }
async-compat = { version = "*", git = "https://github.com/taiki-e/async-compat", branch = "tokio1", optional = true }
async-compat_02 = { version = "0.1", optional = true, package = "async-compat" }
futures-io = { version = "0.3", optional = true }
bytes = "1.0"
string = { version = "0.2.1", default-features = false }
memchr = "2.3"

View File

@ -10,9 +10,8 @@ pub use options::ConnectOptions;
pub use runtime::{Blocking, Runtime};
pub mod prelude {
pub use crate::Database as _;
pub use super::ConnectOptions as _;
pub use super::Connection as _;
pub use super::Runtime as _;
pub use crate::Database as _;
}

View File

@ -20,8 +20,7 @@ where
where
Self: Sized,
{
url.parse::<<Self as crate::Connection<Rt>>::Options>()?
.connect()
url.parse::<<Self as crate::Connection<Rt>>::Options>()?.connect()
}
/// Explicitly close this database connection.

View File

@ -1,8 +1,8 @@
use crate::{ConnectOptions, Database, DefaultRuntime, Runtime};
#[cfg(feature = "async")]
use futures_util::future::BoxFuture;
use crate::{ConnectOptions, Database, DefaultRuntime, Runtime};
/// A unique connection (session) with a specific database.
///
/// With a client/server model, this is equivalent to a network connection
@ -49,7 +49,8 @@ where
fn connect(url: &str) -> BoxFuture<'_, crate::Result<Self>>
where
Self: Sized,
Rt: crate::Async,
Rt: crate::AsyncRuntime,
<Rt as Runtime>::TcpStream: futures_io::AsyncRead + futures_io::AsyncWrite + Unpin,
{
let options = url.parse::<Self::Options>();
Box::pin(async move { options?.connect().await })
@ -64,7 +65,8 @@ where
#[cfg(feature = "async")]
fn close(self) -> BoxFuture<'static, crate::Result<()>>
where
Rt: crate::Async;
Rt: crate::AsyncRuntime,
<Rt as Runtime>::TcpStream: futures_io::AsyncRead + futures_io::AsyncWrite + Unpin;
/// Checks if a connection to the database is still valid.
///
@ -76,5 +78,6 @@ where
#[cfg(feature = "async")]
fn ping(&mut self) -> BoxFuture<'_, crate::Result<()>>
where
Rt: crate::Async;
Rt: crate::AsyncRuntime,
<Rt as Runtime>::TcpStream: futures_io::AsyncRead + futures_io::AsyncWrite + Unpin;
}

View File

@ -7,10 +7,7 @@ pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
Configuration {
message: Cow<'static, str>,
source: Option<Box<dyn StdError + Send + Sync>>,
},
Configuration { message: Cow<'static, str>, source: Option<Box<dyn StdError + Send + Sync>> },
Network(std::io::Error),
}
@ -21,18 +18,12 @@ impl Error {
message: impl Into<Cow<'static, str>>,
source: impl Into<Box<dyn StdError + Send + Sync>>,
) -> Self {
Self::Configuration {
message: message.into(),
source: Some(source.into()),
}
Self::Configuration { message: message.into(), source: Some(source.into()) }
}
#[doc(hidden)]
pub fn configuration_msg(message: impl Into<Cow<'static, str>>) -> Self {
Self::Configuration {
message: message.into(),
source: None,
}
Self::Configuration { message: message.into(), source: None }
}
}
@ -41,15 +32,11 @@ impl Display for Error {
match self {
Self::Network(source) => write!(f, "network: {}", source),
Self::Configuration {
message,
source: None,
} => write!(f, "configuration: {}", message),
Self::Configuration { message, source: None } => {
write!(f, "configuration: {}", message)
}
Self::Configuration {
message,
source: Some(source),
} => {
Self::Configuration { message, source: Some(source) } => {
write!(f, "configuration: {}: {}", message, source)
}
}
@ -59,10 +46,7 @@ impl Display for Error {
impl StdError for Error {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
match self {
Self::Configuration {
source: Some(source),
..
} => Some(&**source),
Self::Configuration { source: Some(source), .. } => Some(&**source),
Self::Network(source) => Some(source),
@ -76,3 +60,9 @@ impl From<std::io::Error> for Error {
Error::Network(error)
}
}
impl From<std::io::ErrorKind> for Error {
fn from(error: std::io::ErrorKind) -> Self {
Error::Network(error.into())
}
}

11
sqlx-core/src/io.rs Normal file
View File

@ -0,0 +1,11 @@
mod buf;
mod write;
mod buf_stream;
mod deserialize;
mod serialize;
pub use buf::BufExt;
pub use write::WriteExt;
pub use buf_stream::BufStream;
pub use deserialize::Deserialize;
pub use serialize::Serialize;

72
sqlx-core/src/io/buf.rs Normal file
View File

@ -0,0 +1,72 @@
use std::io;
use bytes::{Bytes, Buf};
use memchr::memchr;
use string::String;
// UNSAFE: _unchecked string methods
// intended for use when the protocol is *known* to always produce
// valid UTF-8 data
pub trait BufExt: Buf {
#[allow(unsafe_code)]
unsafe fn get_str_unchecked(&mut self, n: usize) -> String<Bytes>;
#[allow(unsafe_code)]
unsafe fn get_str_nul_unchecked(&mut self) -> io::Result<String<Bytes>>;
}
impl BufExt for Bytes {
#[allow(unsafe_code)]
unsafe fn get_str_unchecked(&mut self, n: usize) -> String<Bytes> {
String::from_utf8_unchecked(self.split_to(n))
}
#[allow(unsafe_code)]
unsafe fn get_str_nul_unchecked(&mut self) -> io::Result<String<Bytes>> {
let nul = memchr(b'\0', self).ok_or(io::ErrorKind::InvalidData)?;
Ok(String::from_utf8_unchecked(self.split_to(nul + 1).slice(..nul)))
}
}
#[cfg(test)]
mod tests {
use std::io;
use bytes::{Bytes, Buf};
use super::BufExt;
#[test]
fn test_get_str() {
let mut buf = Bytes::from_static(b"Hello World\0");
#[allow(unsafe_code)]
let s = unsafe { buf.get_str_unchecked(5) };
buf.advance(1);
#[allow(unsafe_code)]
let s2 = unsafe { buf.get_str_unchecked(5) };
assert_eq!(&s, "Hello");
assert_eq!(&s2, "World");
}
#[test]
fn test_get_str_nul() -> io::Result<()> {
let mut buf = Bytes::from_static(b"Hello\0 World\0");
#[allow(unsafe_code)]
let s = unsafe { buf.get_str_nul_unchecked()? };
buf.advance(1);
#[allow(unsafe_code)]
let s2 = unsafe { buf.get_str_nul_unchecked()? };
assert_eq!(&s, "Hello");
assert_eq!(&s2, "World");
Ok(())
}
}

View File

@ -0,0 +1,92 @@
#[cfg(feature = "blocking")]
use std::io::{Read, Write};
use std::slice::from_raw_parts_mut;
use bytes::{BufMut, Bytes, BytesMut};
#[cfg(feature = "async")]
use futures_io::{AsyncRead, AsyncWrite};
#[cfg(feature = "async")]
use futures_util::{AsyncReadExt, AsyncWriteExt};
/// Wraps a stream and buffers input and output to and from it.
///
/// It can be excessively inefficient to work directly with a `Read` or `Write`. For example,
/// every call to `read` or `write` on `TcpStream` results in a system call (leading to
/// a network interaction). `BufStream` keeps a read and write buffer with infrequent calls
/// to `read` and `write` on the underlying stream.
///
pub struct BufStream<S> {
stream: S,
// (r)ead buffer
rbuf: BytesMut,
// (w)rite buffer
wbuf: Vec<u8>,
// offset into [wbuf] that a previous write operation has written into
wbuf_offset: usize,
}
impl<S> BufStream<S> {
pub fn with_capacity(stream: S, read: usize, write: usize) -> Self {
Self {
stream,
rbuf: BytesMut::with_capacity(read),
wbuf: Vec::with_capacity(write),
wbuf_offset: 0,
}
}
pub fn get(&self, offset: usize, n: usize) -> &[u8] {
&(self.rbuf.as_ref())[offset..(offset + n)]
}
pub fn take(&mut self, n: usize) -> Bytes {
self.rbuf.split_to(n).freeze()
}
pub fn consume(&mut self, n: usize) {
let _ = self.take(n);
}
}
#[cfg(feature = "async")]
impl<S> BufStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
pub async fn read_async(&mut self, n: usize) -> crate::Result<()> {
// // before waiting to receive data
// // ensure that the write buffer is flushed
// if !self.wbuf.is_empty() {
// self.flush().await?;
// }
// while our read buffer is too small to satisfy the requested amount
while self.rbuf.len() < n {
// ensure that there is room in the read buffer
self.rbuf.reserve(n.max(128));
#[allow(unsafe_code)]
unsafe {
// prepare a chunk of uninitialized memory to write to
// this is UB if the Read impl of the stream reads from the write buffer
let b = self.rbuf.chunk_mut();
let b = from_raw_parts_mut(b.as_mut_ptr(), b.len());
// read as much as we can and return when the stream or our buffer is exhausted
let n = self.stream.read(b).await?;
// [!] read more than the length of our buffer
debug_assert!(n <= b.len());
// update the len of the read buffer to let the safe world that its okay
// to look at these bytes now
self.rbuf.advance_mut(n);
}
}
Ok(())
}
}

View File

@ -0,0 +1,13 @@
use bytes::Bytes;
pub trait Deserialize<'de, Cx = ()>: Sized {
#[inline]
fn deserialize(buf: Bytes) -> crate::Result<Self>
where
Self: Deserialize<'de, ()>,
{
Self::deserialize_with(buf, ())
}
fn deserialize_with(buf: Bytes, context: Cx) -> crate::Result<Self>;
}

View File

@ -0,0 +1,11 @@
pub trait Serialize<'ser, Cx = ()>: Sized {
#[inline]
fn serialize(&self, buf: &mut Vec<u8>) -> crate::Result<()>
where
Self: Serialize<'ser, ()>,
{
self.serialize_with(buf, ())
}
fn serialize_with(&self, buf: &mut Vec<u8>, context: Cx) -> crate::Result<()>;
}

34
sqlx-core/src/io/write.rs Normal file
View File

@ -0,0 +1,34 @@
pub trait WriteExt {
fn write_str_nul(&mut self, s: &str);
fn write_maybe_str_nul(&mut self, s: Option<&str>);
}
impl WriteExt for Vec<u8> {
fn write_str_nul(&mut self, s: &str) {
self.reserve(s.len() + 1);
self.extend_from_slice(s.as_bytes());
self.push(0);
}
fn write_maybe_str_nul(&mut self, s: Option<&str>) {
if let Some(s) = s {
self.reserve(s.len() + 1);
self.extend_from_slice(s.as_bytes());
}
self.push(0);
}
}
#[cfg(test)]
mod tests {
use super::WriteExt;
#[test]
fn write_str() {
let mut buf = Vec::new();
buf.write_str_nul("this is a random dice roll");
assert_eq!(&buf, b"this is a random dice roll\0");
}
}

View File

@ -33,6 +33,9 @@ mod error;
mod options;
mod runtime;
#[doc(hidden)]
pub mod io;
#[cfg(feature = "blocking")]
pub mod blocking;
@ -40,19 +43,15 @@ pub use connection::Connection;
pub use database::{Database, HasOutput};
pub use error::{Error, Result};
pub use options::ConnectOptions;
pub use runtime::Runtime;
#[cfg(feature = "async-std")]
pub use runtime::AsyncStd;
#[cfg(feature = "tokio")]
pub use runtime::Tokio;
#[cfg(feature = "actix")]
pub use runtime::Actix;
#[cfg(feature = "async")]
pub use runtime::Async;
pub use runtime::AsyncRuntime;
#[cfg(feature = "async-std")]
pub use runtime::AsyncStd;
pub use runtime::Runtime;
#[cfg(feature = "tokio")]
pub use runtime::Tokio;
// pick a default runtime
// this is so existing applications in SQLx pre 0.6 work and to
@ -78,11 +77,10 @@ pub type DefaultRuntime = blocking::Blocking;
pub type DefaultRuntime = ();
pub mod prelude {
#[cfg(all(not(feature = "async"), feature = "blocking"))]
pub use super::blocking::prelude::*;
pub use super::ConnectOptions as _;
pub use super::Connection as _;
pub use super::Database as _;
pub use super::Runtime as _;
#[cfg(all(not(feature = "async"), feature = "blocking"))]
pub use super::blocking::prelude::*;
}

View File

@ -17,5 +17,6 @@ where
fn connect(&self) -> futures_util::future::BoxFuture<'_, crate::Result<Self::Connection>>
where
Self::Connection: Sized,
Rt: crate::Async;
Rt: crate::AsyncRuntime,
<Rt as Runtime>::TcpStream: futures_io::AsyncRead + futures_io::AsyncWrite + Unpin;
}

View File

@ -7,41 +7,35 @@ mod actix;
#[cfg(feature = "tokio")]
mod tokio;
#[cfg(feature = "async-std")]
pub use self::async_std::AsyncStd;
#[cfg(feature = "tokio")]
pub use self::tokio::Tokio;
#[cfg(feature = "actix")]
pub use self::actix::Actix;
#[cfg(feature = "async-std")]
pub use self::async_std::AsyncStd;
#[cfg(feature = "tokio")]
pub use self::tokio::Tokio;
/// Describes a set of types and functions used to open and manage
/// resources within SQLx.
pub trait Runtime: 'static + Send + Sync {
type TcpStream: Send;
}
#[cfg(feature = "async")]
pub trait AsyncRuntime: Runtime
where
Self::TcpStream: futures_io::AsyncRead,
{
/// Opens a TCP connection to a remote host at the specified port.
#[cfg(feature = "async")]
#[allow(unused_variables)]
fn connect_tcp(
host: &str,
port: u16,
) -> futures_util::future::BoxFuture<'_, std::io::Result<Self::TcpStream>>
where
Self: Async,
{
// re-implemented for async runtimes
// for sync runtimes, this cannot be implemented but the compiler
// with guarantee it won't be called
// see: https://github.com/rust-lang/rust/issues/48214
unimplemented!()
}
) -> futures_util::future::BoxFuture<'_, std::io::Result<Self::TcpStream>>;
}
/// Marker trait that identifies a `Runtime` as supporting asynchronous I/O.
#[cfg(feature = "async")]
pub trait Async: Runtime {}
pub trait AsyncRead {
fn read(&mut self, buf: &mut [u8]) -> futures_util::future::BoxFuture<'_, u64>;
}
// when the async feature is not specified, this is an empty trait
// we implement `()` for it to allow the lib to still compile

View File

@ -1,9 +1,10 @@
use std::io;
use actix_rt::net::TcpStream;
use async_compat_02::Compat;
use futures_util::{future::BoxFuture, FutureExt};
use crate::{Async, Runtime};
use crate::{AsyncRuntime, Runtime};
/// Actix SQLx runtime. Uses [`actix-rt`][actix_rt] to provide [`Runtime`].
///
@ -15,12 +16,15 @@ use crate::{Async, Runtime};
#[derive(Debug)]
pub struct Actix;
impl Async for Actix {}
impl Runtime for Actix {
type TcpStream = TcpStream;
type TcpStream = Compat<TcpStream>;
}
impl AsyncRuntime for Actix
where
Self::TcpStream: futures_io::AsyncRead,
{
fn connect_tcp(host: &str, port: u16) -> BoxFuture<'_, io::Result<Self::TcpStream>> {
TcpStream::connect((host, port)).boxed()
TcpStream::connect((host, port)).map_ok(Compat::new).boxed()
}
}

View File

@ -3,18 +3,18 @@ use std::io;
use async_std::{net::TcpStream, task::block_on};
use futures_util::{future::BoxFuture, FutureExt};
use crate::{Async, Runtime};
use crate::{AsyncRuntime, Runtime};
/// [`async-std`](async_std) implementation of [`Runtime`].
#[cfg_attr(doc_cfg, doc(cfg(feature = "async-std")))]
#[derive(Debug)]
pub struct AsyncStd;
impl Async for AsyncStd {}
impl Runtime for AsyncStd {
type TcpStream = TcpStream;
}
impl AsyncRuntime for AsyncStd {
fn connect_tcp(host: &str, port: u16) -> BoxFuture<'_, io::Result<Self::TcpStream>> {
TcpStream::connect((host, port)).boxed()
}
@ -23,6 +23,6 @@ impl Runtime for AsyncStd {
#[cfg(feature = "blocking")]
impl crate::blocking::Runtime for AsyncStd {
fn connect_tcp(host: &str, port: u16) -> io::Result<Self::TcpStream> {
block_on(<AsyncStd as Runtime>::connect_tcp(host, port))
block_on(<AsyncStd as AsyncRuntime>::connect_tcp(host, port))
}
}

View File

@ -1,9 +1,10 @@
use std::io;
use futures_util::{future::BoxFuture, FutureExt};
use async_compat::Compat;
use futures_util::{future::BoxFuture, FutureExt, TryFutureExt};
use tokio::net::TcpStream;
use crate::{Async, Runtime};
use crate::{AsyncRuntime, Runtime};
/// Tokio SQLx runtime. Uses [`tokio`] to provide [`Runtime`].
///
@ -13,12 +14,12 @@ use crate::{Async, Runtime};
#[derive(Debug)]
pub struct Tokio;
impl Async for Tokio {}
impl Runtime for Tokio {
type TcpStream = TcpStream;
type TcpStream = Compat<TcpStream>;
}
impl AsyncRuntime for Tokio {
fn connect_tcp(host: &str, port: u16) -> BoxFuture<'_, io::Result<Self::TcpStream>> {
TcpStream::connect((host, port)).boxed()
TcpStream::connect((host, port)).map_ok(Compat::new).boxed()
}
}

View File

@ -24,7 +24,7 @@ blocking = ["sqlx-core/blocking"]
# async runtime
# not meant to be used directly
async = ["futures-util", "sqlx-core/async"]
async = ["futures-util", "sqlx-core/async", "futures-io"]
[dependencies]
sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core" }
@ -32,3 +32,8 @@ futures-util = { version = "0.3.8", optional = true }
either = "1.6.1"
url = "2.2.0"
percent-encoding = "2.1.0"
futures-io = { version = "0.3", optional = true }
bytes = "1.0"
memchr = "2.3"
bitflags = "1.2"
string = { version = "0.2.1", default-features = false }

View File

@ -9,8 +9,9 @@ where
Self::Connection: sqlx_core::Connection<Rt, Options = Self> + Connection<Rt>,
{
fn connect(&self) -> Result<MySqlConnection<Rt>> {
let stream = <Rt as Runtime>::connect_tcp(self.get_host(), self.get_port())?;
Ok(MySqlConnection { stream })
// let stream = <Rt as Runtime>::connect_tcp(self.get_host(), self.get_port())?;
//
// Ok(MySqlConnection { stream })
todo!()
}
}

View File

@ -1,14 +1,46 @@
use std::fmt::{self, Debug, Formatter};
use sqlx_core::io::BufStream;
use sqlx_core::{Connection, DefaultRuntime, Runtime};
use crate::protocol::Capabilities;
use crate::{MySql, MySqlConnectOptions};
#[cfg(feature = "async")]
pub(crate) mod establish;
pub struct MySqlConnection<Rt = DefaultRuntime>
where
Rt: Runtime,
{
pub(crate) stream: Rt::TcpStream,
stream: BufStream<Rt::TcpStream>,
connection_id: u32,
capabilities: Capabilities,
}
impl<Rt> MySqlConnection<Rt>
where
Rt: Runtime,
{
pub(crate) fn new(stream: Rt::TcpStream) -> Self {
Self {
stream: BufStream::with_capacity(stream, 4096, 1024),
connection_id: 0,
capabilities: Capabilities::LONG_PASSWORD
| Capabilities::LONG_FLAG
| Capabilities::IGNORE_SPACE
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PS_MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
| Capabilities::SESSION_TRACK
| Capabilities::DEPRECATE_EOF,
}
}
}
impl<Rt> Debug for MySqlConnection<Rt>
@ -31,7 +63,8 @@ where
#[cfg(feature = "async")]
fn close(self) -> futures_util::future::BoxFuture<'static, sqlx_core::Result<()>>
where
Rt: sqlx_core::Async,
Rt: sqlx_core::AsyncRuntime,
<Rt as Runtime>::TcpStream: futures_io::AsyncRead + futures_io::AsyncWrite + Unpin,
{
unimplemented!()
}
@ -39,7 +72,8 @@ where
#[cfg(feature = "async")]
fn ping(&mut self) -> futures_util::future::BoxFuture<'_, sqlx_core::Result<()>>
where
Rt: sqlx_core::Async,
Rt: sqlx_core::AsyncRuntime,
<Rt as Runtime>::TcpStream: futures_io::AsyncRead + futures_io::AsyncWrite + Unpin,
{
unimplemented!()
}

View File

@ -0,0 +1,56 @@
use bytes::Buf;
use futures_io::{AsyncRead, AsyncWrite};
use sqlx_core::io::{BufStream, Deserialize};
use sqlx_core::{AsyncRuntime, Result, Runtime};
use crate::protocol::Handshake;
use crate::{MySqlConnectOptions, MySqlConnection};
// https://dev.mysql.com/doc/internals/en/connection-phase.html
// the connection phase (establish) performs these tasks:
// - exchange the capabilities of client and server
// - setup SSL communication channel if requested
// - authenticate the client against the server
// the server may immediately send an ERR packet and finish the handshake
// or send a [InitialHandshake]
impl<Rt> MySqlConnection<Rt>
where
Rt: AsyncRuntime,
<Rt as Runtime>::TcpStream: Unpin + AsyncWrite + AsyncRead,
{
pub(crate) async fn establish_async(options: &MySqlConnectOptions<Rt>) -> Result<Self> {
let stream = Rt::connect_tcp(options.get_host(), options.get_port()).await?;
let mut self_ = Self::new(stream);
// FIXME: Handle potential ERR packet here
let handshake = self_.read_packet_async::<Handshake>().await?;
println!("{:#?}", handshake);
Ok(self_)
}
async fn read_packet_async<'de, T>(&'de mut self) -> Result<T>
where
T: Deserialize<'de>,
{
// https://dev.mysql.com/doc/internals/en/mysql-packet.html
self.stream.read_async(4).await?;
let payload_len: usize = self.stream.get(0, 3).get_int_le(3) as usize;
// FIXME: handle split packets
assert_ne!(payload_len, 0xFF_FF_FF);
let _seq_no = self.stream.get(3, 1).get_i8();
self.stream.read_async(4 + payload_len).await?;
self.stream.consume(4);
let payload = self.stream.take(payload_len);
T::deserialize(payload)
}
}

5
sqlx-mysql/src/io.rs Normal file
View File

@ -0,0 +1,5 @@
mod write;
mod buf;
pub(crate) use write::MySqlWriteExt;
pub(crate) use buf::MySqlBufExt;

49
sqlx-mysql/src/io/buf.rs Normal file
View File

@ -0,0 +1,49 @@
use bytes::{Bytes, Buf};
use string::String;
use sqlx_core::io::BufExt;
// UNSAFE: _unchecked string methods
// intended for use when the protocol is *known* to always produce
// valid UTF-8 data
pub(crate) trait MySqlBufExt: BufExt {
fn get_uint_lenenc(&mut self) -> u64;
#[allow(unsafe_code)]
unsafe fn get_str_lenenc_unchecked(&mut self) -> String<Bytes>;
fn get_bytes_lenenc(&mut self) -> Bytes;
}
impl MySqlBufExt for Bytes {
fn get_uint_lenenc(&mut self) -> u64 {
// https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger
match self.get_u8() {
// NOTE: 0xFB represents NULL in TextResultRow
0xfb => unreachable!("unexpected 0xFB (NULL) in `get_uint_lenenc`"),
0xfc => u64::from(self.get_u16_le()),
0xfd => self.get_uint_le(3),
0xfe => self.get_u64_le(),
// NOTE: 0xFF may be the first byte of an ERR packet
0xff => unreachable!("unexpected 0xFF (undefined) in `get_uint_lenenc`"),
value => u64::from(value)
}
}
#[allow(unsafe_code)]
unsafe fn get_str_lenenc_unchecked(&mut self) -> String<Bytes> {
let len = self.get_uint_lenenc() as usize;
self.get_str_unchecked(len)
}
fn get_bytes_lenenc(&mut self) -> Bytes {
let len = self.get_uint_lenenc() as usize;
self.split_to(len)
}
}

137
sqlx-mysql/src/io/write.rs Normal file
View File

@ -0,0 +1,137 @@
pub(crate) trait MySqlWriteExt: sqlx_core::io::WriteExt {
fn write_uint_lenenc(&mut self, value: u64);
fn write_str_lenenc(&mut self, value: &str);
fn write_bytes_lenenc(&mut self, value: &[u8]);
}
impl MySqlWriteExt for Vec<u8> {
fn write_uint_lenenc(&mut self, value: u64) {
// https://dev.mysql.com/doc/internals/en/integer.html
// https://mariadb.com/kb/en/library/protocol-data-types/#length-encoded-integers
if value < 251 {
// if the value is < 251, it is stored as a 1-byte integer
self.push(value as u8);
} else if value < 0x1_00_00 {
// if the value is ≥ 251 and < (2 ** 16), it is stored as fc + 2-byte integer
self.reserve(3);
self.push(0xfc);
self.extend_from_slice(&(value as u16).to_le_bytes());
} else if value < 0x1_00_00_00 {
// if the value is ≥ (2 ** 16) and < (2 ** 24), it is stored as fd + 3-byte integer
self.reserve(4);
self.push(0xfd);
self.extend_from_slice(&(value as u32).to_le_bytes()[..3]);
} else {
// if the value is ≥ (2 ** 24) and < (2 ** 64) it is stored as fe + 8-byte integer
self.reserve(9);
self.push(0xfe);
self.extend_from_slice(&value.to_le_bytes());
}
}
#[inline]
fn write_str_lenenc(&mut self, value: &str) {
self.write_bytes_lenenc(value.as_bytes());
}
fn write_bytes_lenenc(&mut self, value: &[u8]) {
self.write_uint_lenenc(value.len() as u64);
self.extend_from_slice(value);
}
}
#[cfg(test)]
mod tests {
use super::MySqlWriteExt;
#[test]
fn write_int_lenenc_u8() {
let mut buf = Vec::new();
buf.write_uint_lenenc(0xFA as u64);
assert_eq!(&buf[..], b"\xFA");
}
#[test]
fn write_int_lenenc_u16() {
let mut buf = Vec::new();
buf.write_uint_lenenc(std::u16::MAX as u64);
assert_eq!(&buf[..], b"\xFC\xFF\xFF");
}
#[test]
fn write_int_lenenc_u24() {
let mut buf = Vec::new();
buf.write_uint_lenenc(0xFF_FF_FF as u64);
assert_eq!(&buf[..], b"\xFD\xFF\xFF\xFF");
}
#[test]
fn write_int_lenenc_u64() {
let mut buf = Vec::new();
buf.write_uint_lenenc(std::u64::MAX);
assert_eq!(&buf[..], b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
}
#[test]
fn write_int_lenenc_fb() {
let mut buf = Vec::new();
buf.write_uint_lenenc(0xFB as u64);
assert_eq!(&buf[..], b"\xFC\xFB\x00");
}
#[test]
fn write_int_lenenc_fc() {
let mut buf = Vec::new();
buf.write_uint_lenenc(0xFC as u64);
assert_eq!(&buf[..], b"\xFC\xFC\x00");
}
#[test]
fn write_int_lenenc_fd() {
let mut buf = Vec::new();
buf.write_uint_lenenc(0xFD as u64);
assert_eq!(&buf[..], b"\xFC\xFD\x00");
}
#[test]
fn write_int_lenenc_fe() {
let mut buf = Vec::new();
buf.write_uint_lenenc(0xFE as u64);
assert_eq!(&buf[..], b"\xFC\xFE\x00");
}
#[test]
fn write_int_lenenc_ff() {
let mut buf = Vec::new();
buf.write_uint_lenenc(0xFF as u64);
assert_eq!(&buf[..], b"\xFC\xFF\x00");
}
#[test]
fn write_string_lenenc() {
let mut buf = Vec::new();
buf.write_str_lenenc("random_string");
assert_eq!(&buf[..], b"\x0Drandom_string");
}
#[test]
fn write_byte_lenenc() {
let mut buf = Vec::new();
buf.write_bytes_lenenc(b"random_string");
assert_eq!(&buf[..], b"\x0Drandom_string");
}
}

View File

@ -21,7 +21,9 @@
mod connection;
mod database;
mod io;
mod options;
mod protocol;
#[cfg(feature = "blocking")]
mod blocking;

View File

@ -83,20 +83,12 @@ where
{
/// Returns the hostname of the database server.
pub fn get_host(&self) -> &str {
self.address
.as_ref()
.left()
.map(|(host, _)| &**host)
.unwrap_or(default::HOST)
self.address.as_ref().left().map(|(host, _)| &**host).unwrap_or(default::HOST)
}
/// Returns the TCP port number of the database server.
pub fn get_port(&self) -> u16 {
self.address
.as_ref()
.left()
.map(|(_, port)| *port)
.unwrap_or(default::PORT)
self.address.as_ref().left().map(|(_, port)| *port).unwrap_or(default::PORT)
}
/// Returns the path to the Unix domain socket, if one is configured.
@ -140,12 +132,9 @@ where
fn connect(&self) -> futures_util::future::BoxFuture<'_, sqlx_core::Result<Self::Connection>>
where
Self::Connection: Sized,
Rt: sqlx_core::Async,
Rt: sqlx_core::AsyncRuntime,
<Rt as Runtime>::TcpStream: futures_io::AsyncRead + futures_io::AsyncWrite + Unpin,
{
futures_util::FutureExt::boxed(async move {
let stream = Rt::connect_tcp(self.get_host(), self.get_port()).await?;
Ok(MySqlConnection { stream })
})
futures_util::FutureExt::boxed(MySqlConnection::establish_async(self))
}
}

View File

@ -13,9 +13,7 @@ where
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let url: Url = s
.parse()
.map_err(|error| Error::configuration("database url", error))?;
let url: Url = s.parse().map_err(|error| Error::configuration("database url", error))?;
if !matches!(url.scheme(), "mysql") {
return Err(Error::configuration_msg(format!(
@ -36,17 +34,11 @@ where
let username = url.username();
if !username.is_empty() {
options.username(percent_decode_str_utf8(
username,
"username in database url",
)?);
options.username(percent_decode_str_utf8(username, "username in database url")?);
}
if let Some(password) = url.password() {
options.password(percent_decode_str_utf8(
password,
"password in database url",
)?);
options.password(percent_decode_str_utf8(password, "password in database url")?);
}
let mut path = url.path();
@ -113,11 +105,12 @@ fn percent_decode_str_utf8(value: &str, context: &str) -> Result<String, Error>
#[cfg(test)]
mod tests {
use super::MySqlConnectOptions;
use std::path::Path;
use super::MySqlConnectOptions;
#[test]
fn it_should_parse() {
fn parse() {
let url = "mysql://user:password@hostname:5432/database?timezone=system&charset=utf8";
let options: MySqlConnectOptions = url.parse().unwrap();
@ -131,7 +124,7 @@ mod tests {
}
#[test]
fn it_should_parse_with_defaults() {
fn parse_with_defaults() {
let url = "mysql://";
let options: MySqlConnectOptions = url.parse().unwrap();
@ -145,21 +138,18 @@ mod tests {
}
#[test]
fn it_should_parse_socket_from_query() {
fn parse_socket_from_query() {
let url = "mysql://user:password@localhost/database?socket=/var/run/mysqld/mysqld.sock";
let options: MySqlConnectOptions = url.parse().unwrap();
assert_eq!(options.get_username(), Some("user"));
assert_eq!(options.get_password(), Some("password"));
assert_eq!(options.get_database(), Some("database"));
assert_eq!(
options.get_socket(),
Some(Path::new("/var/run/mysqld/mysqld.sock"))
);
assert_eq!(options.get_socket(), Some(Path::new("/var/run/mysqld/mysqld.sock")));
}
#[test]
fn it_should_parse_socket_from_host() {
fn parse_socket_from_host() {
// socket path in host requires URL encoding but does work
let url = "mysql://user:password@%2Fvar%2Frun%2Fmysqld%2Fmysqld.sock/database";
let options: MySqlConnectOptions = url.parse().unwrap();
@ -167,21 +157,18 @@ mod tests {
assert_eq!(options.get_username(), Some("user"));
assert_eq!(options.get_password(), Some("password"));
assert_eq!(options.get_database(), Some("database"));
assert_eq!(
options.get_socket(),
Some(Path::new("/var/run/mysqld/mysqld.sock"))
);
assert_eq!(options.get_socket(), Some(Path::new("/var/run/mysqld/mysqld.sock")));
}
#[test]
#[should_panic]
fn it_should_fail_to_parse_non_mysql() {
fn fail_to_parse_non_mysql() {
let url = "postgres://user:password@hostname:5432/database?timezone=system&charset=utf8";
let _: MySqlConnectOptions = url.parse().unwrap();
}
#[test]
fn it_should_parse_username_with_at_sign() {
fn parse_username_with_at_sign() {
let url = "mysql://user@hostname:password@hostname:5432/database";
let options: MySqlConnectOptions = url.parse().unwrap();
@ -189,7 +176,7 @@ mod tests {
}
#[test]
fn it_should_parse_password_with_non_ascii_chars() {
fn parse_password_with_non_ascii_chars() {
let url = "mysql://username:p@ssw0rd@hostname:5432/database";
let options: MySqlConnectOptions = url.parse().unwrap();

View File

@ -0,0 +1,9 @@
mod capabilities;
mod handshake;
mod handshake_response;
mod status;
pub(crate) use capabilities::Capabilities;
pub(crate) use handshake::Handshake;
pub(crate) use handshake_response::HandshakeResponse;
pub(crate) use status::ServerStatus;

View File

@ -0,0 +1,80 @@
// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__capabilities__flags.html
// https://mariadb.com/kb/en/library/connection/#capabilities
bitflags::bitflags! {
pub struct Capabilities: u64 {
// use the improved version of "old password auth"
// assumed to be set since 4.1
const LONG_PASSWORD = 0x00000001;
// send found (read: matched) rows instead of affected rows in the EOF packet
const FOUND_ROWS = 0x00000002;
// longer flags for column metadata
// not used if PROTOCOL_41 is used (long flags are always received)
const LONG_FLAG = 0x00000004;
// database (schema) name can be specified on connect in Handshake Response Packet
const CONNECT_WITH_DB = 0x00000008;
// do not permit `database.table.column`
const NO_SCHEMA = 0x00000010;
// compression protocol supported
// todo: expose in MySqlConnectOptions
const COMPRESS = 0x00000020;
// legacy flag to enable special ODBC handling
// no handling since MySQL v3.22
const ODBC = 0x00000040;
// enable LOAD DATA LOCAL
const LOCAL_FILES = 0x00000080;
// SQL parser can ignore spaces before '('
const IGNORE_SPACE = 0x00000100;
// uses the 4.1+ protocol
const PROTOCOL_41 = 0x00000200;
// this is an interactive client
// wait_timeout versus wait_interactive_timeout.
const INTERACTIVE = 0x00000400;
// use SSL encryption for this session
const SSL = 0x00000800;
// EOF packets will contain transaction status flags
const TRANSACTIONS = 0x00002000;
// support native 4.1+ authentication
const SECURE_CONNECTION = 0x00008000;
// can handle multiple statements in COM_QUERY and COM_STMT_PREPARE
const MULTI_STATEMENTS = 0x00010000;
// can send multiple result sets for COM_QUERY
const MULTI_RESULTS = 0x00020000;
// can send multiple result sets for COM_STMT_EXECUTE
const PS_MULTI_RESULTS = 0x00040000;
// supports authentication plugins
const PLUGIN_AUTH = 0x00080000;
// permits connection attributes
const CONNECT_ATTRS = 0x00100000;
// enable authentication response packet to be larger than 255 bytes.
const PLUGIN_AUTH_LENENC_DATA = 0x00200000;
// can handle connection for a user account with expired passwords
const CAN_HANDLE_EXPIRED_PASSWORDS = 0x00400000;
// capable of handling server state change information in an OK packet
const SESSION_TRACK = 0x00800000;
// client no longer needs EOF_Packet and will use OK_Packet instead.
const DEPRECATE_EOF = 0x01000000;
}
}

View File

@ -0,0 +1,429 @@
use bytes::buf::Chain;
use bytes::{Buf, Bytes};
use memchr::memchr;
use sqlx_core::io::{BufExt, Deserialize};
use sqlx_core::Result;
use crate::protocol::{Capabilities, ServerStatus};
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeV10
// https://mariadb.com/kb/en/connection/#initial-handshake-packet
#[derive(Debug)]
pub(crate) struct Handshake {
// (0x0a) protocol version
pub(crate) protocol_version: u8,
// human-readable server version
pub(crate) server_version: string::String<Bytes>,
pub(crate) connection_id: u32,
pub(crate) capabilities: Capabilities,
pub(crate) status: ServerStatus,
// default server character set
pub(crate) charset: Option<u8>,
pub(crate) auth_plugin_data: Chain<Bytes, Bytes>,
// name of the auth_method that the auth_plugin_data belongs to
pub(crate) auth_plugin_name: Option<string::String<Bytes>>,
}
impl Deserialize<'_> for Handshake {
fn deserialize_with(mut buf: Bytes, _: ()) -> Result<Self> {
println!("{:?}", buf);
let protocol_version = buf.get_u8();
// UNSAFE: server version is known to be ASCII
#[allow(unsafe_code)]
let server_version = unsafe { buf.get_str_nul_unchecked()? };
let connection_id = buf.get_u32_le();
// first 8 bytes of the auth-plugin data
let auth_plugin_data_1 = buf.split_to(8);
buf.advance(1); // filler [00]
let mut capabilities = Capabilities::from_bits_truncate(buf.get_u16_le().into());
// from this point on, all additional packet fields are **optional**
// the packet payload can end at any time
let charset = if buf.is_empty() { None } else { Some(buf.get_u8()) };
let status = if buf.is_empty() {
ServerStatus::empty()
} else {
ServerStatus::from_bits_truncate(buf.get_u16_le())
};
if !buf.is_empty() {
// upper 2 bytes of the capabilities flags
capabilities |= Capabilities::from_bits_truncate(u64::from(buf.get_u16_le()) << 16);
}
let auth_plugin_data_len = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
buf.get_u8()
} else {
// a single 0 byte, if present
if !buf.is_empty() {
buf.advance(1);
}
0
};
if buf.len() >= 10 {
// reserved (10, 0 bytes)
buf.advance(10);
}
let mut auth_plugin_data_2 = Bytes::new();
let mut auth_plugin_name = None;
if capabilities.contains(Capabilities::SECURE_CONNECTION) {
let len = (if auth_plugin_data_len > 8 { auth_plugin_data_len - 8 } else { 0 }).max(13);
auth_plugin_data_2 = buf.split_to(len as usize);
if capabilities.contains(Capabilities::PLUGIN_AUTH) {
// due to Bug#59453 the auth-plugin-name is missing the terminating NUL-char
// in versions prior to 5.5.10 and 5.6.2
// ref: https://bugs.mysql.com/bug.php?id=59453
// read to NUL or read to the end if we can't find a NUL
let auth_plugin_name_end =
memchr(b'\0', &buf).map(|end| end - 1).unwrap_or(buf.len());
// UNSAFE: auth plugin names are known to be ASCII
#[allow(unsafe_code)]
let auth_plugin_name_ =
unsafe { Some(buf.get_str_unchecked(auth_plugin_name_end)) };
auth_plugin_name = auth_plugin_name_;
}
}
Ok(Self {
protocol_version,
server_version,
connection_id,
charset,
capabilities,
status,
auth_plugin_data: auth_plugin_data_1.chain(auth_plugin_data_2),
auth_plugin_name,
})
}
}
#[cfg(test)]
mod tests {
use bytes::Buf;
use sqlx_core::io::Deserialize;
use super::{Capabilities, Handshake, ServerStatus};
#[test]
fn handshake_mysql_8_0_18() {
const HANDSHAKE_MYSQL_8_0_18: &[u8] = b"\n8.0.18\x00\x19\x00\x00\x00\x114aB0c\x06g\x00\xff\xff\xff\x02\x00\xff\xc7\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00tL\x03s\x0f[4\rl4. \x00caching_sha2_password\x00";
let mut h = Handshake::deserialize(HANDSHAKE_MYSQL_8_0_18.into()).unwrap();
assert_eq!(h.protocol_version, 10);
assert_eq!(
h.capabilities,
Capabilities::LONG_PASSWORD
| Capabilities::FOUND_ROWS
| Capabilities::LONG_FLAG
| Capabilities::CONNECT_WITH_DB
| Capabilities::NO_SCHEMA
| Capabilities::COMPRESS
| Capabilities::ODBC
| Capabilities::LOCAL_FILES
| Capabilities::IGNORE_SPACE
| Capabilities::PROTOCOL_41
| Capabilities::INTERACTIVE
| Capabilities::SSL
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PS_MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::CONNECT_ATTRS
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
| Capabilities::SESSION_TRACK
| Capabilities::DEPRECATE_EOF,
);
assert_eq!(h.charset, Some(255));
assert_eq!(h.status, ServerStatus::AUTOCOMMIT);
assert_eq!(h.auth_plugin_name.as_deref(), Some("caching_sha2_password"));
assert_eq!(
&*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()),
&[17, 52, 97, 66, 48, 99, 6, 103, 116, 76, 3, 115, 15, 91, 52, 13, 108, 52, 46, 32, 0]
);
}
#[test]
fn handshake_mariadb_10_4_7() {
const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\"<H5n\x00mysql_native_password\x00";
let mut h = Handshake::deserialize(HANDSHAKE_MARIA_DB_10_4_7.into()).unwrap();
assert_eq!(h.protocol_version, 10);
assert_eq!(&*h.server_version, "5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic");
assert_eq!(
h.capabilities,
Capabilities::FOUND_ROWS
| Capabilities::LONG_FLAG
| Capabilities::CONNECT_WITH_DB
| Capabilities::NO_SCHEMA
| Capabilities::COMPRESS
| Capabilities::ODBC
| Capabilities::LOCAL_FILES
| Capabilities::IGNORE_SPACE
| Capabilities::PROTOCOL_41
| Capabilities::INTERACTIVE
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PS_MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::CONNECT_ATTRS
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
| Capabilities::SESSION_TRACK
| Capabilities::DEPRECATE_EOF
);
assert_eq!(h.charset, Some(8));
assert_eq!(h.status, ServerStatus::AUTOCOMMIT);
assert_eq!(h.auth_plugin_name.as_deref(), Some("mysql_native_password"));
assert_eq!(
&*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()),
&[
116, 54, 76, 92, 106, 34, 100, 83, 85, 49, 52, 79, 112, 104, 57, 34, 60, 72, 53,
110, 0
]
);
}
#[test]
fn handshake_mariadb_10_5_8() {
const HANDSHAKE_MARIA_DB_10_5_8: &[u8] = b"\n5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal\0\x07\0\0\0'PB949cf\0\xfe\xf7-\x02\0\xff\x81\x15\0\0\0\0\0\0\x0f\0\0\0UY>hr&`3{55H\0mysql_native_password\0";
let mut h = Handshake::deserialize(HANDSHAKE_MARIA_DB_10_5_8.into()).unwrap();
assert_eq!(h.protocol_version, 10);
assert_eq!(&*h.server_version, "5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal");
assert_eq!(
h.capabilities,
Capabilities::FOUND_ROWS
| Capabilities::LONG_FLAG
| Capabilities::CONNECT_WITH_DB
| Capabilities::NO_SCHEMA
| Capabilities::COMPRESS
| Capabilities::ODBC
| Capabilities::LOCAL_FILES
| Capabilities::IGNORE_SPACE
| Capabilities::PROTOCOL_41
| Capabilities::INTERACTIVE
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PS_MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::CONNECT_ATTRS
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
| Capabilities::SESSION_TRACK
| Capabilities::DEPRECATE_EOF
);
assert_eq!(h.charset, Some(45));
assert_eq!(h.status, ServerStatus::AUTOCOMMIT);
assert_eq!(h.auth_plugin_name.as_deref(), Some("mysql_native_password"));
assert_eq!(
&*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()),
&[
39, 80, 66, 57, 52, 57, 99, 102, 85, 89, 62, 104, 114, 38, 96, 51, 123, 53, 53, 72,
0
]
);
}
#[test]
fn handshake_mysql_5_6_50() {
const HANDSHAKE_MYSQL_5_6_50: &[u8] = b"\n5.6.50\0\x01\0\0\0-VLYZ:Pd\0\xff\xf7\x08\x02\0\x7f\x80\x15\0\0\0\0\0\0\0\0\0\0'2f+BL8nGV[G\0mysql_native_password\0";
let mut h = Handshake::deserialize(HANDSHAKE_MYSQL_5_6_50.into()).unwrap();
assert_eq!(h.protocol_version, 10);
assert_eq!(&*h.server_version, "5.6.50");
assert_eq!(
h.capabilities,
Capabilities::LONG_PASSWORD
| Capabilities::FOUND_ROWS
| Capabilities::LONG_FLAG
| Capabilities::CONNECT_WITH_DB
| Capabilities::NO_SCHEMA
| Capabilities::COMPRESS
| Capabilities::ODBC
| Capabilities::LOCAL_FILES
| Capabilities::IGNORE_SPACE
| Capabilities::PROTOCOL_41
| Capabilities::INTERACTIVE
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PS_MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::CONNECT_ATTRS
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
);
assert_eq!(h.charset, Some(8));
assert_eq!(h.status, ServerStatus::AUTOCOMMIT);
assert_eq!(h.auth_plugin_name.as_deref(), Some("mysql_native_password"));
assert_eq!(
&*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()),
&[45, 86, 76, 89, 90, 58, 80, 100, 39, 50, 102, 43, 66, 76, 56, 110, 71, 86, 91, 71, 0]
);
}
#[test]
fn handshake_mysql_5_0_96() {
const HANDSHAKE_MYSQL_5_0_96: &[u8] = b"\n5.0.96\0\x03\0\0\0bs=sNiGe\0,\xa2\x08\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0IzMP)yLLx;[9\0";
let mut h = Handshake::deserialize(HANDSHAKE_MYSQL_5_0_96.into()).unwrap();
assert_eq!(h.protocol_version, 10);
assert_eq!(&*h.server_version, "5.0.96");
assert_eq!(
h.capabilities,
Capabilities::LONG_FLAG
| Capabilities::CONNECT_WITH_DB
| Capabilities::COMPRESS
| Capabilities::PROTOCOL_41
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
);
assert_eq!(h.charset, Some(8));
assert_eq!(h.status, ServerStatus::AUTOCOMMIT);
assert_eq!(h.auth_plugin_name, None);
assert_eq!(
&*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()),
&[
98, 115, 61, 115, 78, 105, 71, 101, 73, 122, 77, 80, 41, 121, 76, 76, 120, 59, 91,
57, 0
]
);
}
#[test]
fn handshake_mysql_5_1_73() {
const HANDSHAKE_MYSQL_5_1_73: &[u8] = b"\n5.1.73\0\x01\0\0\0<fllZ\\Bs\0\xff\xf7\x08\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0<qEC_87JO/9q\0";
let mut h = Handshake::deserialize(HANDSHAKE_MYSQL_5_1_73.into()).unwrap();
assert_eq!(h.protocol_version, 10);
assert_eq!(&*h.server_version, "5.1.73");
assert_eq!(
h.capabilities,
Capabilities::LONG_PASSWORD
| Capabilities::LONG_FLAG
| Capabilities::FOUND_ROWS
| Capabilities::CONNECT_WITH_DB
| Capabilities::NO_SCHEMA
| Capabilities::COMPRESS
| Capabilities::ODBC
| Capabilities::LOCAL_FILES
| Capabilities::IGNORE_SPACE
| Capabilities::INTERACTIVE
| Capabilities::PROTOCOL_41
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
);
assert_eq!(h.charset, Some(8));
assert_eq!(h.status, ServerStatus::AUTOCOMMIT);
assert_eq!(h.auth_plugin_name, None);
assert_eq!(
&*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()),
&[
60, 102, 108, 108, 90, 92, 66, 115, 60, 113, 69, 67, 95, 56, 55, 74, 79, 47, 57,
113, 0
]
);
}
#[test]
fn handshake_mysql_5_5_14() {
const HANDSHAKE_MYSQL_5_5_14: &[u8] = b"\n5.5.14\0\x01\0\0\0`o-/CEp'\0\xff\xf7\x08\x02\0\x0f\x80\x15\0\0\0\0\0\0\0\0\0\0kf@J5j6nJfAP\0mysql_native_password\0";
let mut h = Handshake::deserialize(HANDSHAKE_MYSQL_5_5_14.into()).unwrap();
assert_eq!(h.protocol_version, 10);
assert_eq!(&*h.server_version, "5.5.14");
assert_eq!(
h.capabilities,
Capabilities::LONG_PASSWORD
| Capabilities::LONG_FLAG
| Capabilities::FOUND_ROWS
| Capabilities::CONNECT_WITH_DB
| Capabilities::NO_SCHEMA
| Capabilities::COMPRESS
| Capabilities::ODBC
| Capabilities::LOCAL_FILES
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PS_MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::IGNORE_SPACE
| Capabilities::INTERACTIVE
| Capabilities::PROTOCOL_41
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
);
assert_eq!(h.charset, Some(8));
assert_eq!(h.status, ServerStatus::AUTOCOMMIT);
assert_eq!(h.auth_plugin_name.as_deref(), Some("mysql_native_password"));
assert_eq!(
&*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()),
&[
96, 111, 45, 47, 67, 69, 112, 39, 107, 102, 64, 74, 53, 106, 54, 110, 74, 102, 65,
80, 0
]
);
}
}

View File

@ -0,0 +1,58 @@
use bytes::BufMut;
use sqlx_core::io::{Serialize, WriteExt};
use sqlx_core::Result;
use crate::protocol::Capabilities;
use crate::io::MySqlWriteExt;
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
// https://mariadb.com/kb/en/connection/#client-handshake-response
#[derive(Debug)]
pub(crate) struct HandshakeResponse<'a> {
pub(crate) database: Option<&'a str>,
pub(crate) max_packet_size: u32,
pub(crate) charset: u8,
pub(crate) username: Option<&'a str>,
pub(crate) auth_plugin_name: Option<&'a str>,
pub(crate) auth_response: Option<&'a [u8]>,
}
impl Serialize<'_, Capabilities> for HandshakeResponse<'_> {
fn serialize_with(&self, buf: &mut Vec<u8>, capabilities: Capabilities) -> Result<()> {
buf.extend_from_slice(&(capabilities.bits() as u32).to_le_bytes());
buf.extend_from_slice(&self.max_packet_size.to_le_bytes());
buf.extend_from_slice(&self.charset.to_le_bytes());
// reserved (all 0)
buf.extend_from_slice(&[0_u8; 23]);
buf.write_maybe_str_nul(self.username);
let auth_response = self.auth_response.unwrap_or_default();
if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) {
buf.write_bytes_lenenc(auth_response);
} else if capabilities.contains(Capabilities::SECURE_CONNECTION) {
debug_assert!(auth_response.len() <= u8::max_value().into());
buf.reserve(1 + auth_response.len());
buf.push(auth_response.len() as u8);
buf.extend_from_slice(auth_response);
} else {
buf.reserve(1 + auth_response.len());
buf.extend_from_slice(auth_response);
buf.push(b'\0');
}
if capabilities.contains(Capabilities::CONNECT_WITH_DB) {
buf.write_maybe_str_nul(self.database);
}
if capabilities.contains(Capabilities::PLUGIN_AUTH) {
buf.write_maybe_str_nul(self.auth_plugin_name);
}
Ok(())
}
}

View File

@ -0,0 +1,50 @@
// https://dev.mysql.com/doc/internals/en/status-flags.html#packet-Protocol::StatusFlags
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/mysql__com_8h.html#a1d854e841086925be1883e4d7b4e8cad
// https://mariadb.com/kb/en/library/mariadb-connectorc-types-and-definitions/#server-status
bitflags::bitflags! {
pub struct ServerStatus: u16 {
// Is raised when a multi-statement transaction has been started, either explicitly,
// by means of BEGIN or COMMIT AND CHAIN, or implicitly, by the first
// transactional statement, when autocommit=off.
const IN_TRANS = 0x0001;
// Autocommit mode is set
const AUTOCOMMIT = 0x0002;
// Multi query - next query exists.
const MORE_RESULTS_EXISTS = 0x0008;
const NO_GOOD_INDEX_USED = 0x0010;
const NO_INDEX_USED = 0x0020;
// When using COM_STMT_FETCH, indicate that current cursor still has result
const CURSOR_EXISTS = 0x0040;
// When using COM_STMT_FETCH, indicate that current cursor has finished to send results
const LAST_ROW_SENT = 0x0080;
// Database has been dropped
const DB_DROPPED = 0x0100;
// Current escape mode is "no backslash escape"
const NO_BACKSLASH_ESCAPES = 0x0200;
// A DDL change did have an impact on an existing PREPARE (an automatic
// re-prepare has been executed)
const METADATA_CHANGED = 0x0400;
// Last statement took more than the time value specified
// in server variable long_query_time.
const QUERY_WAS_SLOW = 0x0800;
// This result-set contain stored procedure output parameter.
const PS_OUT_PARAMS = 0x1000;
// Current transaction is a read-only transaction.
const IN_TRANS_READONLY = 0x2000;
// This status flag, when on, implies that one of the state information has changed
// on the server because of the execution of the last statement.
const SESSION_STATE_CHANGED = 0x4000;
}
}

View File

@ -16,21 +16,16 @@
#![warn(clippy::useless_let_if_seq)]
#![allow(clippy::doc_markdown)]
#[cfg(feature = "blocking")]
pub use sqlx_core::blocking;
#[cfg(feature = "actix")]
pub use sqlx_core::Actix;
#[cfg(feature = "async-std")]
pub use sqlx_core::AsyncStd;
#[cfg(feature = "tokio")]
pub use sqlx_core::Tokio;
pub use sqlx_core::{
prelude, ConnectOptions, Connection, Database, DefaultRuntime, Error, Result, Runtime,
};
#[cfg(feature = "blocking")]
pub use sqlx_core::blocking;
#[cfg(feature = "async-std")]
pub use sqlx_core::AsyncStd;
#[cfg(feature = "tokio")]
pub use sqlx_core::Tokio;
#[cfg(feature = "actix")]
pub use sqlx_core::Actix;
#[cfg(feature = "mysql")]
pub use sqlx_mysql as mysql;