From 7750168b80fff71c42eaa1c171df5360cfd400fe Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Sat, 2 Jan 2021 10:47:15 -0800 Subject: [PATCH] wip(mysql): connect phase --- Cargo.lock | 50 +++++++++++++++++ examples/quickstart/Cargo.toml | 2 + examples/quickstart/src/main.rs | 30 +--------- sqlx-core/src/io/buf_stream.rs | 32 +++++++++-- sqlx-mysql/src/connection.rs | 16 ++++-- sqlx-mysql/src/connection/establish.rs | 76 ++++++++++++++++++++++---- sqlx-mysql/src/lib.rs | 2 + 7 files changed, 158 insertions(+), 50 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3a6a0e2b..212bf68e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -445,6 +445,17 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "065374052e7df7ee4047b1160cca5e1467a12351a40b3da123c870ba0b8eda2a" +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi", + "libc", + "winapi 0.3.9", +] + [[package]] name = "autocfg" version = "1.0.1" @@ -718,6 +729,19 @@ dependencies = [ "syn", ] +[[package]] +name = "env_logger" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26ecb66b4bdca6c1409b40fb255eefc2bd4f6d135dab3c3124f80ffa2a9661e" +dependencies = [ + "atty", + "humantime", + "log", + "regex", + "termcolor", +] + [[package]] name = "event-listener" version = "2.5.1" @@ -996,6 +1020,12 @@ version = "1.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cd179ae861f0c2e53da70d892f5f3029f9594be0c41dc5269cd371691b1dc2f9" +[[package]] +name = "humantime" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c1ad908cc71012b7bea4d0c53ba96a8cba9962f048fa68d143376143d863b7a" + [[package]] name = "idna" version = "0.2.0" @@ -1682,6 +1712,8 @@ dependencies = [ "actix-web", "anyhow", "async-std", + "env_logger", + "log", "sqlx", "tokio 1.0.1", ] @@ -1777,6 +1809,15 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "termcolor" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dfed899f0eb03f32ee8c6a0aabdb8a7949659e3466561fc0adf54e26d88c5f4" +dependencies = [ + "winapi-util", +] + [[package]] name = "thiserror" version = "1.0.23" @@ -2189,6 +2230,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +dependencies = [ + "winapi 0.3.9", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" diff --git a/examples/quickstart/Cargo.toml b/examples/quickstart/Cargo.toml index 7642bb76..fec215bd 100644 --- a/examples/quickstart/Cargo.toml +++ b/examples/quickstart/Cargo.toml @@ -14,3 +14,5 @@ async-std = { version = "1.8.0", features = ["attributes"] } #sqlx = { path = "../../sqlx", features = ["tokio", "mysql", "blocking", "async-std", "actix"] } sqlx = { path = "../../sqlx", features = ["tokio", "mysql"] } tokio = { version = "1.0.1", features = ["rt", "rt-multi-thread", "macros"] } +log = "0.4" +env_logger = "0.8.2" diff --git a/examples/quickstart/src/main.rs b/examples/quickstart/src/main.rs index 1a9e40a7..fe4e3798 100644 --- a/examples/quickstart/src/main.rs +++ b/examples/quickstart/src/main.rs @@ -1,35 +1,9 @@ -// #[async_std::main] -// async fn main() -> anyhow::Result<()> { -// let _stream = AsyncStd::connect_tcp("localhost", 5432).await?; -// -// Ok(()) -// } - -use sqlx::mysql::MySqlConnectOptions; +use sqlx::mysql::MySqlConnection; use sqlx::prelude::*; -// #[tokio::main] -// async fn main() -> anyhow::Result<()> { -// let mut conn = ::connect("mysql://").await?; -// -// Ok(()) -// } -// - -// #[async_std::main] -// async fn main() -> anyhow::Result<()> { -// let mut conn = ::builder() -// .host("loca%x91lhost") -// .port(20) -// .connect() -// .await?; -// -// Ok(()) -// } - #[tokio::main] async fn main() -> anyhow::Result<()> { - let mut conn = ::new().host("localhost").port(3306).connect().await?; + let _conn = ::connect("mysql://root:password@localhost:3307/main").await?; Ok(()) } diff --git a/sqlx-core/src/io/buf_stream.rs b/sqlx-core/src/io/buf_stream.rs index 50387f26..141932ea 100644 --- a/sqlx-core/src/io/buf_stream.rs +++ b/sqlx-core/src/io/buf_stream.rs @@ -49,6 +49,14 @@ impl BufStream { pub fn consume(&mut self, n: usize) { let _ = self.take(n); } + + pub fn reserve(&mut self, additional: usize) { + self.wbuf.reserve(additional); + } + + pub fn write(&mut self, buf: &[u8]) { + self.wbuf.extend_from_slice(buf); + } } #[cfg(feature = "async")] @@ -56,12 +64,26 @@ impl BufStream where S: AsyncRead + AsyncWrite + Unpin, { + pub async fn flush_async(&mut self) -> crate::Result<()> { + // write as much as we can each time and move the cursor as we write from the buffer + // if _this_ future drops, offset will have a record of how much of the wbuf has + // been written + while self.wbuf_offset < self.wbuf.len() { + self.wbuf_offset += self.stream.write(&self.wbuf[self.wbuf_offset..]).await?; + } + + // fully written buffer, move cursor back to the beginning + self.wbuf_offset = 0; + self.wbuf.clear(); + + Ok(()) + } + pub async fn read_async(&mut self, n: usize) -> crate::Result<()> { - // // before waiting to receive data - // // ensure that the write buffer is flushed - // if !self.wbuf.is_empty() { - // self.flush().await?; - // } + // before waiting to receive data; ensure that the write buffer is flushed + if !self.wbuf.is_empty() { + self.flush_async().await?; + } // while our read buffer is too small to satisfy the requested amount while self.rbuf.len() < n { diff --git a/sqlx-mysql/src/connection.rs b/sqlx-mysql/src/connection.rs index 72106ca7..37d58cf3 100644 --- a/sqlx-mysql/src/connection.rs +++ b/sqlx-mysql/src/connection.rs @@ -16,6 +16,9 @@ where stream: BufStream, connection_id: u32, capabilities: Capabilities, + // the sequence-id is incremented with each packet and may wrap around. It starts at 0 and is + // reset to 0 when a new command begins in the Command Phase. + sequence_id: u8, } impl MySqlConnection @@ -26,18 +29,19 @@ where Self { stream: BufStream::with_capacity(stream, 4096, 1024), connection_id: 0, - capabilities: Capabilities::LONG_PASSWORD + sequence_id: 0, + capabilities: Capabilities::PROTOCOL_41 | Capabilities::LONG_PASSWORD | Capabilities::LONG_FLAG | Capabilities::IGNORE_SPACE | Capabilities::TRANSACTIONS | Capabilities::SECURE_CONNECTION - | Capabilities::MULTI_STATEMENTS - | Capabilities::MULTI_RESULTS - | Capabilities::PS_MULTI_RESULTS + // | Capabilities::MULTI_STATEMENTS + // | Capabilities::MULTI_RESULTS + // | Capabilities::PS_MULTI_RESULTS | Capabilities::PLUGIN_AUTH | Capabilities::PLUGIN_AUTH_LENENC_DATA - | Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS - | Capabilities::SESSION_TRACK + // | Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS + // | Capabilities::SESSION_TRACK | Capabilities::DEPRECATE_EOF, } } diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index e0a8b077..e7f18a7e 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -1,10 +1,10 @@ -use bytes::Buf; +use bytes::{buf::Chain, Buf, Bytes}; use futures_io::{AsyncRead, AsyncWrite}; -use sqlx_core::io::{BufStream, Deserialize}; -use sqlx_core::{AsyncRuntime, Result, Runtime}; +use sqlx_core::io::{Deserialize, Serialize}; +use sqlx_core::{AsyncRuntime, Error, Result, Runtime}; -use crate::protocol::Handshake; -use crate::{MySqlConnectOptions, MySqlConnection}; +use crate::protocol::{Capabilities, ErrPacket, Handshake, HandshakeResponse, OkPacket}; +use crate::{MySqlConnectOptions, MySqlConnection, MySqlDatabaseError}; // https://dev.mysql.com/doc/internals/en/connection-phase.html @@ -16,25 +16,73 @@ use crate::{MySqlConnectOptions, MySqlConnection}; // the server may immediately send an ERR packet and finish the handshake // or send a [InitialHandshake] +fn make_auth_response( + auth_plugin_name: Option<&str>, + username: &str, + password: Option<&str>, + nonce: &Chain, +) -> Vec { + vec![] +} + +fn make_handshake_response(options: &MySqlConnectOptions) -> HandshakeResponse<'_> { + HandshakeResponse { + auth_plugin_name: None, + auth_response: None, + charset: 45, // [utf8mb4] + database: options.get_database(), + max_packet_size: 1024, + username: options.get_username(), + } +} + impl MySqlConnection where Rt: AsyncRuntime, ::TcpStream: Unpin + AsyncWrite + AsyncRead, { + fn recv_handshake(&mut self, handshake: &Handshake) { + self.capabilities &= handshake.capabilities; + self.connection_id = handshake.connection_id; + } + pub(crate) async fn establish_async(options: &MySqlConnectOptions) -> Result { let stream = Rt::connect_tcp(options.get_host(), options.get_port()).await?; let mut self_ = Self::new(stream); - // FIXME: Handle potential ERR packet here - let handshake = self_.read_packet_async::().await?; - println!("{:#?}", handshake); + let handshake = self_.read_packet_async().await?; + self_.recv_handshake(&handshake); + + self_.write_packet(make_handshake_response(options))?; + + self_.stream.flush_async().await?; + + let _ok: OkPacket = self_.read_packet_async().await?; Ok(self_) } + fn write_packet<'ser, T>(&'ser mut self, packet: T) -> Result<()> + where + T: Serialize<'ser, Capabilities>, + { + let mut wbuf = Vec::::with_capacity(1024); + + packet.serialize_with(&mut wbuf, self.capabilities)?; + + self.sequence_id = self.sequence_id.wrapping_add(1); + + self.stream.reserve(wbuf.len() + 4); + self.stream.write(&(wbuf.len() as u32).to_le_bytes()[..3]); + self.stream.write(&[self.sequence_id]); + self.stream.write(&wbuf); + + Ok(()) + } + async fn read_packet_async<'de, T>(&'de mut self) -> Result where - T: Deserialize<'de>, + T: Deserialize<'de, Capabilities>, { // https://dev.mysql.com/doc/internals/en/mysql-packet.html self.stream.read_async(4).await?; @@ -44,13 +92,19 @@ where // FIXME: handle split packets assert_ne!(payload_len, 0xFF_FF_FF); - let _seq_no = self.stream.get(3, 1).get_i8(); + self.sequence_id = self.stream.get(3, 1).get_u8(); self.stream.read_async(4 + payload_len).await?; self.stream.consume(4); let payload = self.stream.take(payload_len); - T::deserialize(payload) + if payload[0] == 0xff { + // if the first byte of the payload is 0xFF and the payload is an ERR packet + let err = ErrPacket::deserialize_with(payload, self.capabilities)?; + return Err(Error::Connect(Box::new(MySqlDatabaseError(err)))); + } + + T::deserialize_with(payload, self.capabilities) } } diff --git a/sqlx-mysql/src/lib.rs b/sqlx-mysql/src/lib.rs index 41ec4ec0..f35bf755 100644 --- a/sqlx-mysql/src/lib.rs +++ b/sqlx-mysql/src/lib.rs @@ -24,6 +24,7 @@ mod database; mod io; mod options; mod protocol; +mod error; #[cfg(feature = "blocking")] mod blocking; @@ -31,3 +32,4 @@ mod blocking; pub use connection::MySqlConnection; pub use database::MySql; pub use options::MySqlConnectOptions; +pub use error::MySqlDatabaseError;