diff --git a/sqlx-mysql/src/connection.rs b/sqlx-mysql/src/connection.rs index 97cd843f..84f09c02 100644 --- a/sqlx-mysql/src/connection.rs +++ b/sqlx-mysql/src/connection.rs @@ -1,16 +1,18 @@ +use std::collections::VecDeque; use std::fmt::{self, Debug, Formatter}; -use sqlx_core::io::BufStream; use sqlx_core::net::Stream as NetStream; use sqlx_core::{Close, Connect, Connection, Runtime}; use crate::protocol::Capabilities; +use crate::stream::MySqlStream; use crate::{MySql, MySqlConnectOptions}; mod close; +mod command; mod connect; +mod executor; mod ping; -mod stream; /// A single connection (also known as a session) to a MySQL database server. #[allow(clippy::module_name_repetitions)] @@ -18,16 +20,17 @@ pub struct MySqlConnection where Rt: Runtime, { - stream: BufStream>, + stream: MySqlStream, connection_id: u32, // the capability flags are used by the client and server to indicate which // features they support and want to use. capabilities: Capabilities, - // the sequence-id is incremented with each packet and may wrap around. It starts at 0 and is - // reset to 0 when a new command begins in the Command Phase. - sequence_id: u8, + // queue of commands that are being processed + // this is what we expect to receive from the server + // in the case of a future or stream being dropped + commands: VecDeque, } impl MySqlConnection @@ -36,21 +39,22 @@ where { pub(crate) fn new(stream: NetStream) -> Self { Self { - stream: BufStream::with_capacity(stream, 4096, 1024), + stream: MySqlStream::new(stream), connection_id: 0, - sequence_id: 0, - capabilities: Capabilities::PROTOCOL_41 | Capabilities::LONG_PASSWORD + commands: VecDeque::with_capacity(2), + capabilities: Capabilities::PROTOCOL_41 + | Capabilities::LONG_PASSWORD | Capabilities::LONG_FLAG | Capabilities::IGNORE_SPACE | Capabilities::TRANSACTIONS | Capabilities::SECURE_CONNECTION - // | Capabilities::MULTI_STATEMENTS - // | Capabilities::MULTI_RESULTS - // | Capabilities::PS_MULTI_RESULTS + | Capabilities::MULTI_STATEMENTS + | Capabilities::MULTI_RESULTS + | Capabilities::PS_MULTI_RESULTS | Capabilities::PLUGIN_AUTH | Capabilities::PLUGIN_AUTH_LENENC_DATA - // | Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS - // | Capabilities::SESSION_TRACK + | Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS + | Capabilities::SESSION_TRACK | Capabilities::DEPRECATE_EOF, } } @@ -115,7 +119,7 @@ mod blocking { impl Connection for MySqlConnection { #[inline] fn ping(&mut self) -> sqlx_core::Result<()> { - self.ping() + self.ping_blocking() } } @@ -125,14 +129,14 @@ mod blocking { where Self: Sized, { - Self::connect(&url.parse::>()?) + Self::connect_blocking(&url.parse::>()?) } } impl Close for MySqlConnection { #[inline] fn close(self) -> sqlx_core::Result<()> { - self.close() + self.close_blocking() } } } diff --git a/sqlx-mysql/src/connection/close.rs b/sqlx-mysql/src/connection/close.rs index f4b877c1..b986754d 100644 --- a/sqlx-mysql/src/connection/close.rs +++ b/sqlx-mysql/src/connection/close.rs @@ -2,16 +2,13 @@ use sqlx_core::{io::Stream, Result, Runtime}; use crate::protocol::Quit; -impl super::MySqlConnection -where - Rt: Runtime, -{ +impl super::MySqlConnection { #[cfg(feature = "async")] pub(crate) async fn close_async(mut self) -> Result<()> where Rt: sqlx_core::Async, { - self.write_packet(&Quit)?; + self.stream.write_packet(&Quit)?; self.stream.flush_async().await?; self.stream.shutdown_async().await?; @@ -19,11 +16,11 @@ where } #[cfg(feature = "blocking")] - pub(crate) fn close(mut self) -> Result<()> + pub(crate) fn close_blocking(mut self) -> Result<()> where Rt: sqlx_core::blocking::Runtime, { - self.write_packet(&Quit)?; + self.stream.write_packet(&Quit)?; self.stream.flush()?; self.stream.shutdown()?; diff --git a/sqlx-mysql/src/connection/connect.rs b/sqlx-mysql/src/connection/connect.rs index a0f76366..d9aaef24 100644 --- a/sqlx-mysql/src/connection/connect.rs +++ b/sqlx-mysql/src/connection/connect.rs @@ -13,100 +13,115 @@ //! use sqlx_core::net::Stream as NetStream; use sqlx_core::Result; +use sqlx_core::Runtime; -use crate::protocol::{Auth, AuthResponse, Handshake, HandshakeResponse}; +use crate::protocol::{AuthResponse, Handshake, HandshakeResponse}; use crate::{MySqlConnectOptions, MySqlConnection}; -macro_rules! connect { - (@blocking @tcp $options:ident) => { +impl MySqlConnection { + fn recv_handshake( + &mut self, + options: &MySqlConnectOptions, + handshake: &Handshake, + ) -> Result<()> { + // & the declared server capabilities with our capabilities to find + // what rules the client should operate under + self.capabilities &= handshake.capabilities; + + // store the connection ID, mainly for debugging + self.connection_id = handshake.connection_id; + + // create the initial auth response + // this may just be a request for an RSA public key + let initial_auth_response = handshake + .auth_plugin + .invoke(&handshake.auth_plugin_data, options.get_password().unwrap_or_default()); + + // the contains an initial guess at the correct encoding of + // the password and some other metadata like "which database", "which user", etc. + self.stream.write_packet(&HandshakeResponse { + capabilities: self.capabilities, + auth_plugin_name: handshake.auth_plugin.name(), + auth_response: initial_auth_response, + charset: 45, // [utf8mb4] + database: options.get_database(), + max_packet_size: 1024, + username: options.get_username(), + })?; + + Ok(()) + } + + fn recv_auth_response( + &mut self, + options: &MySqlConnectOptions, + handshake: &mut Handshake, + response: AuthResponse, + ) -> Result { + match response { + AuthResponse::Ok(_) => { + // successful, simple authentication; good to go + return Ok(true); + } + + AuthResponse::MoreData(data) => { + if let Some(data) = handshake.auth_plugin.handle( + data, + &handshake.auth_plugin_data, + options.get_password().unwrap_or_default(), + )? { + // write the response from the plugin + self.stream.write_packet(&&*data)?; + } + } + + AuthResponse::Switch(sw) => { + // switch to the new plugin + handshake.auth_plugin = sw.plugin; + handshake.auth_plugin_data = sw.plugin_data; + + // generate an initial response from this plugin + let data = handshake.auth_plugin.invoke( + &handshake.auth_plugin_data, + options.get_password().unwrap_or_default(), + ); + + // write the response from the plugin + self.stream.write_packet(&&*data)?; + } + } + + Ok(false) + } +} + +macro_rules! impl_connect { + (@blocking @new $options:ident) => { NetStream::connect($options.address.as_ref())?; }; - (@tcp $options:ident) => { + (@new $options:ident) => { NetStream::connect_async($options.address.as_ref()).await?; }; - (@blocking @packet $self:ident) => { - $self.read_packet()?; - }; - - (@packet $self:ident) => { - $self.read_packet_async().await?; - }; - ($(@$blocking:ident)? $options:ident) => {{ // open a network stream to the database server - let stream = connect!($(@$blocking)? @tcp $options); + let stream = impl_connect!($(@$blocking)? @new $options); // construct a around the network stream // wraps the stream in a to buffer read and write let mut self_ = Self::new(stream); // immediately the server should emit a packet - let handshake: Handshake = connect!($(@$blocking)? @packet self_); - - // & the declared server capabilities with our capabilities to find - // what rules the client should operate under - self_.capabilities &= handshake.capabilities; - - // store the connection ID, mainly for debugging - self_.connection_id = handshake.connection_id; - - // extract the auth plugin and data from the handshake - // this can get overwritten by an auth switch - let mut auth_plugin = handshake.auth_plugin; - let mut auth_plugin_data = handshake.auth_plugin_data; - let password = $options.get_password().unwrap_or_default(); - - // create the initial auth response - // this may just be a request for an RSA public key - let initial_auth_response = auth_plugin.invoke(&auth_plugin_data, password); - - // the contains an initial guess at the correct encoding of - // the password and some other metadata like "which database", "which user", etc. - self_.write_packet(&HandshakeResponse { - auth_plugin_name: auth_plugin.name(), - auth_response: initial_auth_response, - charset: 45, // [utf8mb4] - database: $options.get_database(), - max_packet_size: 1024, - username: $options.get_username(), - })?; + // we need to handle that and reply with a + let mut handshake = read_packet!($(@$blocking)? self_.stream).deserialize()?; + self_.recv_handshake($options, &handshake)?; loop { - match connect!($(@$blocking)? @packet self_) { - Auth::Ok(_) => { - // successful, simple authentication; good to go - break; - } - - Auth::MoreData(data) => { - if let Some(data) = auth_plugin.handle(data, &auth_plugin_data, password)? { - // write the response from the plugin - self_.write_packet(&AuthResponse { data })?; - - // let's try again - continue; - } - - // all done, the plugin says we check out - break; - } - - Auth::Switch(sw) => { - // switch to the new plugin - auth_plugin = sw.plugin; - auth_plugin_data = sw.plugin_data; - - // generate an initial response from this plugin - let data = auth_plugin.invoke(&auth_plugin_data, password); - - // write the response from the plugin - self_.write_packet(&AuthResponse { data })?; - - // let's try again - continue; - } + let response = read_packet!($(@$blocking)? self_.stream).deserialize_with(self_.capabilities)?; + if self_.recv_auth_response($options, &mut handshake, response)? { + // complete, successful authentication + break; } } @@ -114,24 +129,21 @@ macro_rules! connect { }}; } -impl MySqlConnection -where - Rt: sqlx_core::Runtime, -{ +impl MySqlConnection { #[cfg(feature = "async")] pub(crate) async fn connect_async(options: &MySqlConnectOptions) -> Result where Rt: sqlx_core::Async, { - connect!(options) + impl_connect!(options) } #[cfg(feature = "blocking")] - pub(crate) fn connect(options: &MySqlConnectOptions) -> Result + pub(crate) fn connect_blocking(options: &MySqlConnectOptions) -> Result where Rt: sqlx_core::blocking::Runtime, { - connect!(@blocking options) + impl_connect!(@blocking options) } } diff --git a/sqlx-mysql/src/connection/ping.rs b/sqlx-mysql/src/connection/ping.rs index 4cef6821..a0da73c4 100644 --- a/sqlx-mysql/src/connection/ping.rs +++ b/sqlx-mysql/src/connection/ping.rs @@ -6,31 +6,37 @@ use crate::protocol::{OkPacket, Ping}; // send the COM_PING packet // should receive an OK -impl super::MySqlConnection -where - Rt: Runtime, -{ +macro_rules! impl_ping { + ($(@$blocking:ident)? $self:ident) => {{ + $self.stream.write_packet(&Ping)?; + + // STATE: remember that we are expecting an OK packet + $self.begin_simple_command(); + + let _ok: OkPacket = read_packet!($(@$blocking)? $self.stream) + .deserialize_with($self.capabilities)?; + + // STATE: received OK packet + $self.end_command(); + + Ok(()) + }}; +} + +impl super::MySqlConnection { #[cfg(feature = "async")] pub(crate) async fn ping_async(&mut self) -> Result<()> where Rt: sqlx_core::Async, { - self.write_packet(&Ping)?; - - let _ok: OkPacket = self.read_packet_async().await?; - - Ok(()) + impl_ping!(self) } #[cfg(feature = "blocking")] - pub(crate) fn ping(&mut self) -> Result<()> + pub(crate) fn ping_blocking(&mut self) -> Result<()> where Rt: sqlx_core::blocking::Runtime, { - self.write_packet(&Ping)?; - - let _ok: OkPacket = self.read_packet()?; - - Ok(()) + impl_ping!(@blocking self) } } diff --git a/sqlx-mysql/src/connection/stream.rs b/sqlx-mysql/src/connection/stream.rs deleted file mode 100644 index ad60ab0b..00000000 --- a/sqlx-mysql/src/connection/stream.rs +++ /dev/null @@ -1,160 +0,0 @@ -//! Reads and writes packets to and from the MySQL database server. -//! -//! The logic for serializing data structures into the packets is found -//! mostly in `protocol/`. -//! -//! Packets in MySQL are prefixed by 4 bytes. -//! 3 for length (in LE) and a sequence id. -//! -//! Packets may only be as large as the communicated size in the initial -//! `HandshakeResponse`. By default, SQLx configures its chunk size to 16M. Sending -//! a larger payload is simply sending completely "full" packets, one after the -//! other, with an increasing sequence id. -//! -//! In other words, when we sent data, we: -//! -//! - Split the data into "packets" of size `2 ** 24 - 1` bytes. -//! -//! - Prepend each packet with a **packet header**, consisting of the length of that packet, -//! and the sequence number. -//! -//! https://dev.mysql.com/doc/internals/en/mysql-packet.html -//! -use std::fmt::Debug; - -use bytes::{Buf, BufMut}; -use sqlx_core::io::{Deserialize, Serialize}; -use sqlx_core::{Error, Result, Runtime}; - -use crate::protocol::{Capabilities, ErrPacket, MaybeCommand}; -use crate::{MySqlConnection, MySqlDatabaseError}; - -impl MySqlConnection -where - Rt: Runtime, -{ - pub(super) fn write_packet<'ser, T>(&'ser mut self, packet: &T) -> Result<()> - where - T: Serialize<'ser, Capabilities> + Debug + MaybeCommand, - { - log::trace!("write > {:?}", packet); - - // the sequence-id is incremented with each packet and may - // wrap around. it starts at 0 and is reset to 0 when a new command - // begins in the Command Phase - - self.sequence_id = if T::is_command() { 0 } else { self.sequence_id.wrapping_add(1) }; - - // optimize for <16M packet sizes, in the case of >= 16M we would - // swap out the write buffer for a fresh buffer and then split it into - // 16M chunks separated by packet headers - - let buf = self.stream.buffer(); - let pos = buf.len(); - - // leave room for the length of the packet header at the start - buf.reserve(4); - buf.extend_from_slice(&[0_u8; 3]); - buf.push(self.sequence_id); - - // serialize the passed packet structure directly into the write buffer - packet.serialize_with(buf, self.capabilities)?; - - let payload_len = buf.len() - pos - 4; - - // FIXME: handle split packets - assert!(payload_len < 0xFF_FF_FF); - - // write back the length of the packet - #[allow(clippy::cast_possible_truncation)] - (&mut buf[pos..]).put_uint_le(payload_len as u64, 3); - - Ok(()) - } - - fn recv_packet<'de, T>(&'de mut self, len: usize) -> Result - where - T: Deserialize<'de, Capabilities> + Debug, - { - // FIXME: handle split packets - assert_ne!(len, 0xFF_FF_FF); - - // We store the sequence id here. To respond to a packet, it should use a - // sequence id of n+1. It only "resets" at the start of a new command. - self.sequence_id = self.stream.get(3, 1).get_u8(); - - // tell the stream that we are done with the 4-byte header - self.stream.consume(4); - - // and remove the remainder of the packet from the stream, the payload - let payload = self.stream.take(len); - - if payload[0] == 0xff { - // if the first byte of the payload is 0xFF and the payload is an ERR packet - let err = ErrPacket::deserialize_with(payload, self.capabilities)?; - log::trace!("read > {:?}", err); - - return Err(Error::connect(MySqlDatabaseError(err))); - } - - let packet = T::deserialize_with(payload, self.capabilities)?; - log::trace!("read > {:?}", packet); - - Ok(packet) - } -} - -macro_rules! read_packet { - ($(@$blocking:ident)? $self:ident) => {{ - // reads at least 4 bytes from the IO stream into the read buffer - read_packet!($(@$blocking)? @stream $self, 0, 4); - - // the first 3 bytes will be the payload length of the packet (in LE) - // ALLOW: the max this len will be is 16M - #[allow(clippy::cast_possible_truncation)] - let payload_len: usize = $self.stream.get(0, 3).get_uint_le(3) as usize; - - // read bytes _after_ the 4 byte packet header - // note that we have not yet told the stream we are done with any of - // these bytes yet. if this next read invocation were to never return (eg., the - // outer future was dropped), then the next time read_packet_async was called - // it will re-read the parsed-above packet header. Note that we have NOT - // mutated `self` _yet_. This is important. - read_packet!($(@$blocking)? @stream $self, 4, payload_len); - - $self.recv_packet(payload_len) - }}; - - (@blocking @stream $self:ident, $offset:expr, $n:expr) => { - $self.stream.read($offset, $n)?; - }; - - (@stream $self:ident, $offset:expr, $n:expr) => { - $self.stream.read_async($offset, $n).await?; - }; -} - -impl MySqlConnection -where - Rt: Runtime, -{ - #[cfg(feature = "async")] - - pub(super) async fn read_packet_async<'de, T>(&'de mut self) -> Result - where - T: Deserialize<'de, Capabilities> + Debug, - Rt: sqlx_core::Async, - { - read_packet!(self) - } - - #[cfg(feature = "blocking")] - - pub(super) fn read_packet<'de, T>(&'de mut self) -> Result - where - T: Deserialize<'de, Capabilities> + Debug, - Rt: sqlx_core::blocking::Runtime, - { - read_packet!(@blocking self) - } -} diff --git a/sqlx-mysql/src/lib.rs b/sqlx-mysql/src/lib.rs index 7cf28cb5..90be62bb 100644 --- a/sqlx-mysql/src/lib.rs +++ b/sqlx-mysql/src/lib.rs @@ -18,6 +18,9 @@ #![warn(clippy::useless_let_if_seq)] #![allow(clippy::doc_markdown)] +#[macro_use] +mod stream; + mod connection; mod database; mod error; diff --git a/sqlx-mysql/src/options.rs b/sqlx-mysql/src/options.rs index 79d654d6..4cf010c4 100644 --- a/sqlx-mysql/src/options.rs +++ b/sqlx-mysql/src/options.rs @@ -97,7 +97,7 @@ mod blocking { where Self::Connection: Sized, { - >::connect(self) + >::connect_blocking(self) } } } diff --git a/sqlx-mysql/src/stream.rs b/sqlx-mysql/src/stream.rs new file mode 100644 index 00000000..5e64af95 --- /dev/null +++ b/sqlx-mysql/src/stream.rs @@ -0,0 +1,193 @@ +use std::fmt::Debug; +use std::ops::{Deref, DerefMut}; + +use bytes::{Buf, BufMut}; +use sqlx_core::io::{BufStream, Serialize}; +use sqlx_core::net::Stream as NetStream; +use sqlx_core::{Error, Result, Runtime}; + +use crate::protocol::{MaybeCommand, Packet}; +use crate::MySqlDatabaseError; + +/// Reads and writes packets to and from the MySQL database server. +/// +/// The logic for serializing data structures into the packets is found +/// mostly in `protocol/`. +/// +/// Packets in MySQL are prefixed by 4 bytes. +/// 3 for length (in LE) and a sequence id. +/// +/// Packets may only be as large as the communicated size in the initial +/// `HandshakeResponse`. By default, SQLx configures its chunk size to 16M. Sending +/// a larger payload is simply sending completely "full" packets, one after the +/// other, with an increasing sequence id. +/// +/// In other words, when we sent data, we: +/// +/// - Split the data into "packets" of size `2 ** 24 - 1` bytes. +/// +/// - Prepend each packet with a **packet header**, consisting of the length of that packet, +/// and the sequence number. +/// +/// +/// +#[allow(clippy::module_name_repetitions)] +pub(crate) struct MySqlStream { + stream: BufStream>, + + // the sequence-id is incremented with each packet and may wrap around. It starts at 0 and is + // reset to 0 when a new command begins in the Command Phase. + sequence_id: u8, +} + +impl MySqlStream { + pub(crate) fn new(stream: NetStream) -> Self { + Self { stream: BufStream::with_capacity(stream, 4096, 1024), sequence_id: 0 } + } + + pub(crate) fn write_packet<'ser, T>(&'ser mut self, packet: &T) -> Result<()> + where + T: Serialize<'ser> + Debug + MaybeCommand, + { + log::trace!("write > {:?}", packet); + + // the sequence-id is incremented with each packet and may + // wrap around. it starts at 0 and is reset to 0 when a new command + // begins in the Command Phase + + self.sequence_id = if T::is_command() { 0 } else { self.sequence_id.wrapping_add(1) }; + + // optimize for <16M packet sizes, in the case of >= 16M we would + // swap out the write buffer for a fresh buffer and then split it into + // 16M chunks separated by packet headers + + let buf = self.stream.buffer(); + let pos = buf.len(); + + // leave room for the length of the packet header at the start + buf.reserve(4); + buf.extend_from_slice(&[0_u8; 3]); + buf.push(self.sequence_id); + + // serialize the passed packet structure directly into the write buffer + packet.serialize(buf)?; + + let payload_len = buf.len() - pos - 4; + + // FIXME: handle split packets + assert!(payload_len < 0xFF_FF_FF); + + // write back the length of the packet + #[allow(clippy::cast_possible_truncation)] + (&mut buf[pos..]).put_uint_le(payload_len as u64, 3); + + Ok(()) + } + + // read and consumes a packet from the stream _buffer_ + // assumes there is a packet on the stream + // is called by [read_packet_blocking] or [read_packet_async] + fn read_packet(&mut self, len: usize) -> Result { + // We store the sequence id here. To respond to a packet, it should use a + // sequence id of n+1. It only "resets" at the start of a new command. + self.sequence_id = self.stream.get(3, 1).get_u8(); + + // tell the stream that we are done with the 4-byte header + self.stream.consume(4); + + // and remove the remainder of the packet from the stream, the payload + let packet = Packet { bytes: self.stream.take(len) }; + + if packet.bytes.len() != len { + // BUG: something is very wrong somewhere if this branch is executed + // either in the SQLx MySQL driver or in the MySQL server + return Err(Error::connect(MySqlDatabaseError::malformed_packet(&format!( + "Received {} bytes for packet but expecting {} bytes", + packet.bytes.len(), + len + )))); + } + + if packet.bytes[0] == 0xff { + // if the first byte of the payload is 0xFF and the payload is an ERR packet + return Err(Error::connect(MySqlDatabaseError(packet.deserialize()?))); + } + + Ok(packet) + } +} + +macro_rules! impl_read_packet { + ($(@$blocking:ident)? $self:ident) => {{ + // reads at least 4 bytes from the IO stream into the read buffer + impl_read_packet!($(@$blocking)? @stream $self, 0, 4); + + // the first 3 bytes will be the payload length of the packet (in LE) + // ALLOW: the max this len will be is 16M + #[allow(clippy::cast_possible_truncation)] + let payload_len: usize = $self.stream.get(0, 3).get_uint_le(3) as usize; + + // read bytes _after_ the 4 byte packet header + // note that we have not yet told the stream we are done with any of + // these bytes yet. if this next read invocation were to never return (eg., the + // outer future was dropped), then the next time read_packet was called + // it will re-read the parsed-above packet header. Note that we have NOT + // mutated `self` _yet_. This is important. + impl_read_packet!($(@$blocking)? @stream $self, 4, payload_len); + + // FIXME: handle split packets + assert_ne!(payload_len, 0xFF_FF_FF); + + $self.read_packet(payload_len) + }}; + + (@blocking @stream $self:ident, $offset:expr, $n:expr) => { + $self.stream.read($offset, $n)?; + }; + + (@stream $self:ident, $offset:expr, $n:expr) => { + $self.stream.read_async($offset, $n).await?; + }; +} + +impl MySqlStream { + #[cfg(feature = "async")] + pub(crate) async fn read_packet_async(&mut self) -> Result + where + Rt: sqlx_core::Async, + { + impl_read_packet!(self) + } + + #[cfg(feature = "blocking")] + pub(crate) fn read_packet_blocking(&mut self) -> Result + where + Rt: sqlx_core::blocking::Runtime, + { + impl_read_packet!(@blocking self) + } +} + +impl Deref for MySqlStream { + type Target = BufStream>; + + fn deref(&self) -> &Self::Target { + &self.stream + } +} + +impl DerefMut for MySqlStream { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.stream + } +} + +macro_rules! read_packet { + (@blocking $stream:expr) => { + $stream.read_packet_blocking()? + }; + + ($stream:expr) => { + $stream.read_packet_async().await? + }; +}