wip(mysql): connect phase

This commit is contained in:
Ryan Leckey 2021-01-02 10:47:15 -08:00
parent 2195472e3e
commit 7750168b80
No known key found for this signature in database
GPG Key ID: F8AA68C235AB08C9
7 changed files with 158 additions and 50 deletions

50
Cargo.lock generated
View File

@ -445,6 +445,17 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "065374052e7df7ee4047b1160cca5e1467a12351a40b3da123c870ba0b8eda2a"
[[package]]
name = "atty"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
dependencies = [
"hermit-abi",
"libc",
"winapi 0.3.9",
]
[[package]]
name = "autocfg"
version = "1.0.1"
@ -718,6 +729,19 @@ dependencies = [
"syn",
]
[[package]]
name = "env_logger"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26ecb66b4bdca6c1409b40fb255eefc2bd4f6d135dab3c3124f80ffa2a9661e"
dependencies = [
"atty",
"humantime",
"log",
"regex",
"termcolor",
]
[[package]]
name = "event-listener"
version = "2.5.1"
@ -996,6 +1020,12 @@ version = "1.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd179ae861f0c2e53da70d892f5f3029f9594be0c41dc5269cd371691b1dc2f9"
[[package]]
name = "humantime"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c1ad908cc71012b7bea4d0c53ba96a8cba9962f048fa68d143376143d863b7a"
[[package]]
name = "idna"
version = "0.2.0"
@ -1682,6 +1712,8 @@ dependencies = [
"actix-web",
"anyhow",
"async-std",
"env_logger",
"log",
"sqlx",
"tokio 1.0.1",
]
@ -1777,6 +1809,15 @@ dependencies = [
"unicode-xid",
]
[[package]]
name = "termcolor"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dfed899f0eb03f32ee8c6a0aabdb8a7949659e3466561fc0adf54e26d88c5f4"
dependencies = [
"winapi-util",
]
[[package]]
name = "thiserror"
version = "1.0.23"
@ -2189,6 +2230,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178"
dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"

View File

@ -14,3 +14,5 @@ async-std = { version = "1.8.0", features = ["attributes"] }
#sqlx = { path = "../../sqlx", features = ["tokio", "mysql", "blocking", "async-std", "actix"] }
sqlx = { path = "../../sqlx", features = ["tokio", "mysql"] }
tokio = { version = "1.0.1", features = ["rt", "rt-multi-thread", "macros"] }
log = "0.4"
env_logger = "0.8.2"

View File

@ -1,35 +1,9 @@
// #[async_std::main]
// async fn main() -> anyhow::Result<()> {
// let _stream = AsyncStd::connect_tcp("localhost", 5432).await?;
//
// Ok(())
// }
use sqlx::mysql::MySqlConnectOptions;
use sqlx::mysql::MySqlConnection;
use sqlx::prelude::*;
// #[tokio::main]
// async fn main() -> anyhow::Result<()> {
// let mut conn = <MySqlConnection>::connect("mysql://").await?;
//
// Ok(())
// }
//
// #[async_std::main]
// async fn main() -> anyhow::Result<()> {
// let mut conn = <MySqlConnection>::builder()
// .host("loca%x91lhost")
// .port(20)
// .connect()
// .await?;
//
// Ok(())
// }
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let mut conn = <MySqlConnectOptions>::new().host("localhost").port(3306).connect().await?;
let _conn = <MySqlConnection>::connect("mysql://root:password@localhost:3307/main").await?;
Ok(())
}

View File

@ -49,6 +49,14 @@ impl<S> BufStream<S> {
pub fn consume(&mut self, n: usize) {
let _ = self.take(n);
}
pub fn reserve(&mut self, additional: usize) {
self.wbuf.reserve(additional);
}
pub fn write(&mut self, buf: &[u8]) {
self.wbuf.extend_from_slice(buf);
}
}
#[cfg(feature = "async")]
@ -56,12 +64,26 @@ impl<S> BufStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
pub async fn flush_async(&mut self) -> crate::Result<()> {
// write as much as we can each time and move the cursor as we write from the buffer
// if _this_ future drops, offset will have a record of how much of the wbuf has
// been written
while self.wbuf_offset < self.wbuf.len() {
self.wbuf_offset += self.stream.write(&self.wbuf[self.wbuf_offset..]).await?;
}
// fully written buffer, move cursor back to the beginning
self.wbuf_offset = 0;
self.wbuf.clear();
Ok(())
}
pub async fn read_async(&mut self, n: usize) -> crate::Result<()> {
// // before waiting to receive data
// // ensure that the write buffer is flushed
// if !self.wbuf.is_empty() {
// self.flush().await?;
// }
// before waiting to receive data; ensure that the write buffer is flushed
if !self.wbuf.is_empty() {
self.flush_async().await?;
}
// while our read buffer is too small to satisfy the requested amount
while self.rbuf.len() < n {

View File

@ -16,6 +16,9 @@ where
stream: BufStream<Rt::TcpStream>,
connection_id: u32,
capabilities: Capabilities,
// the sequence-id is incremented with each packet and may wrap around. It starts at 0 and is
// reset to 0 when a new command begins in the Command Phase.
sequence_id: u8,
}
impl<Rt> MySqlConnection<Rt>
@ -26,18 +29,19 @@ where
Self {
stream: BufStream::with_capacity(stream, 4096, 1024),
connection_id: 0,
capabilities: Capabilities::LONG_PASSWORD
sequence_id: 0,
capabilities: Capabilities::PROTOCOL_41 | Capabilities::LONG_PASSWORD
| Capabilities::LONG_FLAG
| Capabilities::IGNORE_SPACE
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PS_MULTI_RESULTS
// | Capabilities::MULTI_STATEMENTS
// | Capabilities::MULTI_RESULTS
// | Capabilities::PS_MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
| Capabilities::SESSION_TRACK
// | Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
// | Capabilities::SESSION_TRACK
| Capabilities::DEPRECATE_EOF,
}
}

View File

@ -1,10 +1,10 @@
use bytes::Buf;
use bytes::{buf::Chain, Buf, Bytes};
use futures_io::{AsyncRead, AsyncWrite};
use sqlx_core::io::{BufStream, Deserialize};
use sqlx_core::{AsyncRuntime, Result, Runtime};
use sqlx_core::io::{Deserialize, Serialize};
use sqlx_core::{AsyncRuntime, Error, Result, Runtime};
use crate::protocol::Handshake;
use crate::{MySqlConnectOptions, MySqlConnection};
use crate::protocol::{Capabilities, ErrPacket, Handshake, HandshakeResponse, OkPacket};
use crate::{MySqlConnectOptions, MySqlConnection, MySqlDatabaseError};
// https://dev.mysql.com/doc/internals/en/connection-phase.html
@ -16,25 +16,73 @@ use crate::{MySqlConnectOptions, MySqlConnection};
// the server may immediately send an ERR packet and finish the handshake
// or send a [InitialHandshake]
fn make_auth_response(
auth_plugin_name: Option<&str>,
username: &str,
password: Option<&str>,
nonce: &Chain<Bytes, Bytes>,
) -> Vec<u8> {
vec![]
}
fn make_handshake_response<Rt: Runtime>(options: &MySqlConnectOptions<Rt>) -> HandshakeResponse<'_> {
HandshakeResponse {
auth_plugin_name: None,
auth_response: None,
charset: 45, // [utf8mb4]
database: options.get_database(),
max_packet_size: 1024,
username: options.get_username(),
}
}
impl<Rt> MySqlConnection<Rt>
where
Rt: AsyncRuntime,
<Rt as Runtime>::TcpStream: Unpin + AsyncWrite + AsyncRead,
{
fn recv_handshake(&mut self, handshake: &Handshake) {
self.capabilities &= handshake.capabilities;
self.connection_id = handshake.connection_id;
}
pub(crate) async fn establish_async(options: &MySqlConnectOptions<Rt>) -> Result<Self> {
let stream = Rt::connect_tcp(options.get_host(), options.get_port()).await?;
let mut self_ = Self::new(stream);
// FIXME: Handle potential ERR packet here
let handshake = self_.read_packet_async::<Handshake>().await?;
println!("{:#?}", handshake);
let handshake = self_.read_packet_async().await?;
self_.recv_handshake(&handshake);
self_.write_packet(make_handshake_response(options))?;
self_.stream.flush_async().await?;
let _ok: OkPacket = self_.read_packet_async().await?;
Ok(self_)
}
fn write_packet<'ser, T>(&'ser mut self, packet: T) -> Result<()>
where
T: Serialize<'ser, Capabilities>,
{
let mut wbuf = Vec::<u8>::with_capacity(1024);
packet.serialize_with(&mut wbuf, self.capabilities)?;
self.sequence_id = self.sequence_id.wrapping_add(1);
self.stream.reserve(wbuf.len() + 4);
self.stream.write(&(wbuf.len() as u32).to_le_bytes()[..3]);
self.stream.write(&[self.sequence_id]);
self.stream.write(&wbuf);
Ok(())
}
async fn read_packet_async<'de, T>(&'de mut self) -> Result<T>
where
T: Deserialize<'de>,
T: Deserialize<'de, Capabilities>,
{
// https://dev.mysql.com/doc/internals/en/mysql-packet.html
self.stream.read_async(4).await?;
@ -44,13 +92,19 @@ where
// FIXME: handle split packets
assert_ne!(payload_len, 0xFF_FF_FF);
let _seq_no = self.stream.get(3, 1).get_i8();
self.sequence_id = self.stream.get(3, 1).get_u8();
self.stream.read_async(4 + payload_len).await?;
self.stream.consume(4);
let payload = self.stream.take(payload_len);
T::deserialize(payload)
if payload[0] == 0xff {
// if the first byte of the payload is 0xFF and the payload is an ERR packet
let err = ErrPacket::deserialize_with(payload, self.capabilities)?;
return Err(Error::Connect(Box::new(MySqlDatabaseError(err))));
}
T::deserialize_with(payload, self.capabilities)
}
}

View File

@ -24,6 +24,7 @@ mod database;
mod io;
mod options;
mod protocol;
mod error;
#[cfg(feature = "blocking")]
mod blocking;
@ -31,3 +32,4 @@ mod blocking;
pub use connection::MySqlConnection;
pub use database::MySql;
pub use options::MySqlConnectOptions;
pub use error::MySqlDatabaseError;