From 7db850da71a9637a8395ea25fc668dbb1d399730 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Thu, 18 Feb 2021 23:35:36 -0800 Subject: [PATCH] feat(mysql): thread E: Execute through the executor + handle normal errors separate from unexpected errors, an unexpected error causes the connection to close (in which case, if this was behind a pool, the pool would not allow this connection to be acquired again) --- sqlx-core/src/error.rs | 4 + sqlx-mysql/src/connection.rs | 21 +-- sqlx-mysql/src/connection/close.rs | 29 ---- sqlx-mysql/src/connection/command.rs | 151 ++++++++++++++++++ sqlx-mysql/src/connection/executor.rs | 28 +++- sqlx-mysql/src/connection/executor/columns.rs | 23 +-- sqlx-mysql/src/connection/executor/execute.rs | 51 +++--- .../src/connection/executor/fetch_all.rs | 55 ++++--- .../src/connection/executor/fetch_optional.rs | 60 ++++--- .../{prepare.rs => executor/raw_prepare.rs} | 42 ++--- .../src/connection/executor/raw_query.rs | 67 ++++++++ sqlx-mysql/src/connection/flush.rs | 137 ++++------------ sqlx-mysql/src/connection/ping.rs | 9 +- sqlx-mysql/src/raw_statement.rs | 4 +- sqlx-mysql/src/stream.rs | 40 ++++- 15 files changed, 450 insertions(+), 271 deletions(-) delete mode 100644 sqlx-mysql/src/connection/close.rs create mode 100644 sqlx-mysql/src/connection/command.rs rename sqlx-mysql/src/connection/{prepare.rs => executor/raw_prepare.rs} (66%) create mode 100644 sqlx-mysql/src/connection/executor/raw_query.rs diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index b445c8c5..3f5c6664 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -30,6 +30,8 @@ pub enum Error { /// RowNotFound, + Closed, + Decode(DecodeError), Encode(EncodeError), @@ -82,6 +84,8 @@ impl Display for Error { f.write_str("no row returned by a query required to return at least one row") } + Self::Closed => f.write_str("connection or pool was closed"), + Self::Decode(error) => { write!(f, "{}", error) } diff --git a/sqlx-mysql/src/connection.rs b/sqlx-mysql/src/connection.rs index dccf86cc..42d96fb7 100644 --- a/sqlx-mysql/src/connection.rs +++ b/sqlx-mysql/src/connection.rs @@ -5,7 +5,7 @@ use futures_util::future::{BoxFuture, FutureExt}; use sqlx_core::net::Stream as NetStream; use sqlx_core::{Close, Connect, Connection, Runtime}; -use crate::connection::flush::CommandQueue; +use crate::connection::command::CommandQueue; use crate::protocol::Capabilities; use crate::stream::MySqlStream; use crate::{MySql, MySqlConnectOptions}; @@ -13,10 +13,7 @@ use crate::{MySql, MySqlConnectOptions}; #[macro_use] mod flush; -#[macro_use] -mod prepare; - -mod close; +mod command; mod connect; mod executor; mod ping; @@ -29,6 +26,7 @@ where { stream: MySqlStream, connection_id: u32, + closed: bool, // the capability flags are used by the client and server to indicate which // features they support and want to use. @@ -48,6 +46,7 @@ where Self { stream: MySqlStream::new(stream), connection_id: 0, + closed: false, commands: CommandQueue::new(), capabilities: Capabilities::PROTOCOL_41 | Capabilities::LONG_PASSWORD @@ -106,11 +105,15 @@ impl Connect for MySqlConnection { impl Close for MySqlConnection { #[cfg(feature = "async")] - fn close(self) -> BoxFuture<'static, sqlx_core::Result<()>> + fn close(mut self) -> BoxFuture<'static, sqlx_core::Result<()>> where Rt: sqlx_core::Async, { - Box::pin(self.close_async()) + Box::pin(async move { + self.stream.close_async().await?; + + Ok(()) + }) } } @@ -139,8 +142,8 @@ mod blocking { impl Close for MySqlConnection { #[inline] - fn close(self) -> sqlx_core::Result<()> { - self.close_blocking() + fn close(mut self) -> sqlx_core::Result<()> { + self.stream.close_blocking() } } } diff --git a/sqlx-mysql/src/connection/close.rs b/sqlx-mysql/src/connection/close.rs deleted file mode 100644 index b986754d..00000000 --- a/sqlx-mysql/src/connection/close.rs +++ /dev/null @@ -1,29 +0,0 @@ -use sqlx_core::{io::Stream, Result, Runtime}; - -use crate::protocol::Quit; - -impl super::MySqlConnection { - #[cfg(feature = "async")] - pub(crate) async fn close_async(mut self) -> Result<()> - where - Rt: sqlx_core::Async, - { - self.stream.write_packet(&Quit)?; - self.stream.flush_async().await?; - self.stream.shutdown_async().await?; - - Ok(()) - } - - #[cfg(feature = "blocking")] - pub(crate) fn close_blocking(mut self) -> Result<()> - where - Rt: sqlx_core::blocking::Runtime, - { - self.stream.write_packet(&Quit)?; - self.stream.flush()?; - self.stream.shutdown()?; - - Ok(()) - } -} diff --git a/sqlx-mysql/src/connection/command.rs b/sqlx-mysql/src/connection/command.rs new file mode 100644 index 00000000..4868ab04 --- /dev/null +++ b/sqlx-mysql/src/connection/command.rs @@ -0,0 +1,151 @@ +use std::collections::VecDeque; +use std::hint::unreachable_unchecked; +use std::marker::PhantomData; +use std::mem; +use std::ops::{Deref, DerefMut}; + +use sqlx_core::Result; + +use crate::protocol::{PrepareResponse, QueryResponse, QueryStep, ResultPacket, Status}; +use crate::{MySqlConnection, MySqlDatabaseError}; + +pub(crate) struct CommandQueue(pub(super) VecDeque); + +impl CommandQueue { + pub(crate) fn new() -> Self { + Self(VecDeque::with_capacity(2)) + } + + // begin a simple command + // in which we are expecting OK or ERR (a result) + pub(crate) fn begin(&mut self) { + self.0.push_back(Command::Simple); + } +} + +impl CommandQueue { + pub(crate) fn end(&mut self) { + self.0.pop_front(); + } +} + +#[derive(Debug)] +#[repr(u8)] +pub(crate) enum Command { + Simple, + Close, + Query(QueryCommand), + Prepare(PrepareCommand), +} + +pub(crate) struct CommandGuard<'cmd, C> { + queue: &'cmd mut CommandQueue, + command: PhantomData<&'cmd mut C>, + index: usize, + ended: bool, +} + +impl<'cmd, C> CommandGuard<'cmd, C> { + fn begin(queue: &'cmd mut CommandQueue, command: Command) -> Self { + let index = queue.0.len(); + queue.0.push_back(command); + + Self { queue, index, ended: false, command: PhantomData } + } + + // called on successful command completion + pub(crate) fn end(&mut self) { + self.ended = true; + } + + // on an error result, the command needs to end *normally* and pass + // through the error to bubble + pub(crate) fn end_if_error(&mut self, res: Result) -> Result { + match res { + Ok(ok) => Ok(ok), + Err(error) => { + self.end(); + Err(error) + } + } + } +} + +impl Drop for CommandGuard<'_, C> { + fn drop(&mut self) { + self.queue.end(); + + if !self.ended { + // if the command was not "completed" by success or a known + // failure, we are in a **weird** state, queue up a close if + // someone tries to re-use this connection + self.queue.0.push_front(Command::Close); + } + } +} + +#[derive(Debug)] +#[repr(u8)] +pub(crate) enum QueryCommand { + // expecting [QueryResponse] + QueryResponse, + + // expecting [QueryStep] + QueryStep, + + // expecting {rem} more [ColumnDefinition] packets + ColumnDefinition { rem: u16 }, +} + +impl QueryCommand { + pub(crate) fn begin(queue: &mut CommandQueue) -> CommandGuard<'_, Self> { + CommandGuard::begin(queue, Command::Query(Self::QueryResponse)) + } +} + +impl Deref for CommandGuard<'_, QueryCommand> { + type Target = QueryCommand; + + fn deref(&self) -> &Self::Target { + if let Command::Query(cmd) = &self.queue.0[self.index] { cmd } else { unreachable!() } + } +} + +impl DerefMut for CommandGuard<'_, QueryCommand> { + fn deref_mut(&mut self) -> &mut Self::Target { + if let Command::Query(cmd) = &mut self.queue.0[self.index] { cmd } else { unreachable!() } + } +} + +#[derive(Debug)] +pub(crate) enum PrepareCommand { + // expecting [ERR] or [COM_STMT_PREPARE_OK] + PrepareResponse, + + // expecting {rem} more [ColumnDefinition] packets for each parameter + // stores {columns} as this state is before the [ColumnDefinition] state + ParameterDefinition { rem: u16, columns: u16 }, + + // expecting {rem} more [ColumnDefinition] packets for each parameter + ColumnDefinition { rem: u16 }, +} + +impl PrepareCommand { + pub(crate) fn begin(queue: &mut CommandQueue) -> CommandGuard<'_, Self> { + CommandGuard::begin(queue, Command::Prepare(Self::PrepareResponse)) + } +} + +impl Deref for CommandGuard<'_, PrepareCommand> { + type Target = PrepareCommand; + + fn deref(&self) -> &Self::Target { + if let Command::Prepare(cmd) = &self.queue.0[self.index] { cmd } else { unreachable!() } + } +} + +impl DerefMut for CommandGuard<'_, PrepareCommand> { + fn deref_mut(&mut self) -> &mut Self::Target { + if let Command::Prepare(cmd) = &mut self.queue.0[self.index] { cmd } else { unreachable!() } + } +} diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index 4bc9f31a..db72ba7d 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -1,12 +1,18 @@ #[cfg(feature = "async")] use futures_util::{future::BoxFuture, FutureExt}; -use sqlx_core::{Executor, Result, Runtime}; +use sqlx_core::{Arguments, Execute, Executor, Result, Runtime}; use crate::{MySql, MySqlConnection, MySqlQueryResult, MySqlRow}; #[macro_use] mod columns; +#[macro_use] +mod raw_prepare; + +#[macro_use] +mod raw_query; + mod execute; mod fetch_all; mod fetch_optional; @@ -16,38 +22,44 @@ impl Executor for MySqlConnection { #[cfg(feature = "async")] #[inline] - fn execute<'x, 'e, 'q>(&'e mut self, sql: &'q str) -> BoxFuture<'x, Result> + fn execute<'x, 'e, 'q, 'a, E>(&'e mut self, query: E) -> BoxFuture<'x, Result> where Rt: sqlx_core::Async, + E: 'x + Execute<'q, 'a, MySql>, 'e: 'x, 'q: 'x, + 'a: 'x, { - self.execute_async(sql).boxed() + self.execute_async(query).boxed() } #[cfg(feature = "async")] #[inline] - fn fetch_all<'x, 'e, 'q>(&'e mut self, sql: &'q str) -> BoxFuture<'x, Result>> + fn fetch_all<'x, 'e, 'q, 'a, E>(&'e mut self, query: E) -> BoxFuture<'x, Result>> where Rt: sqlx_core::Async, + E: 'x + Execute<'q, 'a, MySql>, 'e: 'x, 'q: 'x, + 'a: 'x, { - self.fetch_all_async(sql).boxed() + self.fetch_all_async(query).boxed() } #[cfg(feature = "async")] #[inline] - fn fetch_optional<'x, 'e, 'q>( + fn fetch_optional<'x, 'e, 'q, 'a, E>( &'e mut self, - sql: &'q str, + query: E, ) -> BoxFuture<'x, Result>> where Rt: sqlx_core::Async, + E: 'x + Execute<'q, 'a, MySql>, 'e: 'x, 'q: 'x, + 'a: 'x, { - self.fetch_optional_async(sql).boxed() + self.fetch_optional_async(query).boxed() } } diff --git a/sqlx-mysql/src/connection/executor/columns.rs b/sqlx-mysql/src/connection/executor/columns.rs index 1f01a398..e9c26902 100644 --- a/sqlx-mysql/src/connection/executor/columns.rs +++ b/sqlx-mysql/src/connection/executor/columns.rs @@ -1,22 +1,23 @@ use sqlx_core::{Result, Runtime}; -use crate::connection::flush::QueryCommand; +use crate::connection::command::QueryCommand; use crate::protocol::ColumnDefinition; use crate::stream::MySqlStream; +use crate::MySqlColumn; macro_rules! impl_recv_columns { ($(@$blocking:ident)? $store:expr, $num_columns:ident, $stream:ident, $cmd:ident) => {{ #[allow(clippy::cast_possible_truncation)] let mut columns = if $store { - Vec::::with_capacity($num_columns as usize) + Vec::::with_capacity($num_columns as usize) } else { // we are going to drop column definitions, do not allocate Vec::new() }; - for index in (1..=$num_columns).rev() { + for (ordinal, rem) in (1..=$num_columns).rev().enumerate() { // STATE: remember that we are expecting #rem more columns - *$cmd = QueryCommand::ColumnDefinition { rem: index }; + *$cmd = QueryCommand::ColumnDefinition { rem }; // read in definition and only deserialize if we are saving // the column definitions @@ -24,7 +25,7 @@ macro_rules! impl_recv_columns { let packet = read_packet!($(@$blocking)? $stream); if $store { - columns.push(packet.deserialize()?); + columns.push(MySqlColumn::new(ordinal, packet.deserialize()?)); } } @@ -40,9 +41,9 @@ impl MySqlStream { pub(super) async fn recv_columns_async( &mut self, store: bool, - columns: u64, + columns: u16, cmd: &mut QueryCommand, - ) -> Result> + ) -> Result> where Rt: sqlx_core::Async, { @@ -53,9 +54,9 @@ impl MySqlStream { pub(crate) fn recv_columns_blocking( &mut self, store: bool, - columns: u64, + columns: u16, cmd: &mut QueryCommand, - ) -> Result> + ) -> Result> where Rt: sqlx_core::blocking::Runtime, { @@ -65,10 +66,10 @@ impl MySqlStream { macro_rules! recv_columns { (@blocking $store:expr, $columns:ident, $stream:ident, $cmd:ident) => { - $stream.recv_columns_blocking($store, $columns, $cmd)? + $stream.recv_columns_blocking($store, $columns, &mut *$cmd)? }; ($store:expr, $columns:ident, $stream:ident, $cmd:ident) => { - $stream.recv_columns_async($store, $columns, $cmd).await? + $stream.recv_columns_async($store, $columns, &mut *$cmd).await? }; } diff --git a/sqlx-mysql/src/connection/executor/execute.rs b/sqlx-mysql/src/connection/executor/execute.rs index 5eb37e51..8023d0f0 100644 --- a/sqlx-mysql/src/connection/executor/execute.rs +++ b/sqlx-mysql/src/connection/executor/execute.rs @@ -1,19 +1,17 @@ -use sqlx_core::Result; +use sqlx_core::{Execute, Result, Runtime}; -use crate::connection::flush::QueryCommand; +use crate::connection::command::QueryCommand; use crate::protocol::{Query, QueryResponse, QueryStep, Status}; -use crate::{MySqlConnection, MySqlQueryResult}; +use crate::{MySql, MySqlConnection, MySqlQueryResult}; macro_rules! impl_execute { - ($(@$blocking:ident)? $self:ident, $sql:ident) => {{ + ($(@$blocking:ident)? $self:ident, $query:ident) => {{ + raw_query!($(@$blocking)? $self, $query); + let Self { ref mut stream, ref mut commands, capabilities, .. } = *$self; - // send the server a text-based query that will be executed immediately - // replies with ERR, OK, or a result set - stream.write_packet(&Query { sql: $sql })?; - // STATE: remember that we are now expecting a query response - let cmd = QueryCommand::begin(commands); + let mut cmd = QueryCommand::begin(commands); // default an empty query result // execute collects all discovered query results and SUMs @@ -22,9 +20,9 @@ macro_rules! impl_execute { #[allow(clippy::while_let_loop, unused_labels)] 'results: loop { - let ok = 'result: loop { + let res = 'result: loop { match read_packet!($(@$blocking)? stream).deserialize_with(capabilities)? { - QueryResponse::End(res) => break 'result res.into_result()?, + QueryResponse::End(res) => break 'result res.into_result(), QueryResponse::ResultSet { columns } => { // acknowledge but discard any columns as execute returns no rows recv_columns!($(@$blocking)? /* store = */ false, columns, stream, cmd); @@ -34,13 +32,16 @@ macro_rules! impl_execute { // execute ignores any rows returned // but we do increment affected rows QueryStep::Row(_row) => result.0.affected_rows += 1, - QueryStep::End(res) => break 'result res.into_result()?, + QueryStep::End(res) => break 'result res.into_result(), } } } } }; + // STATE: command is complete on error + let ok = cmd.end_if_error(res)?; + // fold this into the total result for the SQL result.extend(Some(ok.into())); @@ -54,24 +55,30 @@ macro_rules! impl_execute { } // STATE: the current command is complete - commands.end(); + cmd.end(); Ok(result) }}; } -#[cfg(feature = "async")] -impl MySqlConnection { - pub(super) async fn execute_async(&mut self, sql: &str) -> Result { +impl MySqlConnection { + #[cfg(feature = "async")] + pub(super) async fn execute_async<'q, 'a, E>(&mut self, query: E) -> Result + where + Rt: sqlx_core::Async, + E: Execute<'q, 'a, MySql>, + { flush!(self); - impl_execute!(self, sql) + impl_execute!(self, query) } -} -#[cfg(feature = "blocking")] -impl MySqlConnection { - pub(super) fn execute_blocking(&mut self, sql: &str) -> Result { + #[cfg(feature = "blocking")] + pub(super) fn execute_blocking<'q, 'a, E>(&mut self, query: E) -> Result + where + Rt: sqlx_core::blocking::Runtime, + E: Execute<'q, 'a, MySql>, + { flush!(@blocking self); - impl_execute!(@blocking self, sql) + impl_execute!(@blocking self, query) } } diff --git a/sqlx-mysql/src/connection/executor/fetch_all.rs b/sqlx-mysql/src/connection/executor/fetch_all.rs index 8a9bf3e5..801d571c 100644 --- a/sqlx-mysql/src/connection/executor/fetch_all.rs +++ b/sqlx-mysql/src/connection/executor/fetch_all.rs @@ -1,28 +1,26 @@ -use sqlx_core::Result; +use sqlx_core::{Arguments, Execute, Result, Runtime}; -use crate::connection::flush::QueryCommand; -use crate::protocol::{Query, QueryResponse, QueryStep, Status}; -use crate::{MySqlConnection, MySqlRow}; +use crate::connection::command::QueryCommand; +use crate::protocol::{self, Query, QueryResponse, QueryStep, Status}; +use crate::{MySql, MySqlConnection, MySqlRawValueFormat, MySqlRow}; macro_rules! impl_fetch_all { - ($(@$blocking:ident)? $self:ident, $sql:ident) => {{ + ($(@$blocking:ident)? $self:ident, $query:ident) => {{ + let format = raw_query!($(@$blocking)? $self, $query); + let Self { ref mut stream, ref mut commands, capabilities, .. } = *$self; - // send the server a text-based query that will be executed immediately - // replies with ERR, OK, or a result set - stream.write_packet(&Query { sql: $sql })?; - // STATE: remember that we are now expecting a query response - let cmd = QueryCommand::begin(commands); + let mut cmd = QueryCommand::begin(commands); // default an empty row set let mut rows = Vec::with_capacity(10); #[allow(clippy::while_let_loop, unused_labels)] 'results: loop { - let ok = 'result: loop { + let res = 'result: loop { match read_packet!($(@$blocking)? stream).deserialize_with(capabilities)? { - QueryResponse::End(res) => break 'result res.into_result()?, + QueryResponse::End(res) => break 'result res.into_result(), QueryResponse::ResultSet { columns } => { let columns = recv_columns!($(@$blocking)? /* store = */ true, columns, stream, cmd); @@ -30,14 +28,17 @@ macro_rules! impl_fetch_all { match read_packet!($(@$blocking)? stream).deserialize_with(capabilities)? { // execute ignores any rows returned // but we do increment affected rows - QueryStep::End(res) => break 'result res.into_result()?, - QueryStep::Row(row) => rows.push(MySqlRow::new(row.deserialize_with(&columns[..])?)), + QueryStep::End(res) => break 'result res.into_result(), + QueryStep::Row(row) => rows.push(MySqlRow::new(row.deserialize_with((format, &columns[..]))?, &columns)), } } } } }; + // STATE: command is complete on error + let ok = cmd.end_if_error(res)?; + if !ok.status.contains(Status::MORE_RESULTS_EXISTS) { // no more results, time to finally call it quits break; @@ -48,24 +49,30 @@ macro_rules! impl_fetch_all { } // STATE: the current command is complete - commands.end(); + cmd.end(); Ok(rows) }}; } -#[cfg(feature = "async")] -impl MySqlConnection { - pub(super) async fn fetch_all_async(&mut self, sql: &str) -> Result> { +impl MySqlConnection { + #[cfg(feature = "async")] + pub(super) async fn fetch_all_async<'q, 'a, E>(&mut self, query: E) -> Result> + where + Rt: sqlx_core::Async, + E: Execute<'q, 'a, MySql>, + { flush!(self); - impl_fetch_all!(self, sql) + impl_fetch_all!(self, query) } -} -#[cfg(feature = "blocking")] -impl MySqlConnection { - pub(super) fn fetch_all_blocking(&mut self, sql: &str) -> Result> { + #[cfg(feature = "blocking")] + pub(super) fn fetch_all_blocking<'q, 'a, E>(&mut self, query: E) -> Result> + where + Rt: sqlx_core::blocking::Runtime, + E: Execute<'q, 'a, MySql>, + { flush!(@blocking self); - impl_fetch_all!(@blocking self, sql) + impl_fetch_all!(@blocking self, query) } } diff --git a/sqlx-mysql/src/connection/executor/fetch_optional.rs b/sqlx-mysql/src/connection/executor/fetch_optional.rs index 15dbbb44..b8a95d5f 100644 --- a/sqlx-mysql/src/connection/executor/fetch_optional.rs +++ b/sqlx-mysql/src/connection/executor/fetch_optional.rs @@ -1,39 +1,36 @@ -use sqlx_core::Result; +use sqlx_core::{Execute, Result, Runtime}; -use crate::connection::flush::QueryCommand; +use crate::connection::command::QueryCommand; use crate::protocol::{Query, QueryResponse, QueryStep, Status}; -use crate::{MySqlConnection, MySqlRow}; +use crate::{MySql, MySqlConnection, MySqlRawValueFormat, MySqlRow}; macro_rules! impl_fetch_optional { - ($(@$blocking:ident)? $self:ident, $sql:ident) => {{ + ($(@$blocking:ident)? $self:ident, $query:ident) => {{ + let format = raw_query!($(@$blocking)? $self, $query); + let Self { ref mut stream, ref mut commands, capabilities, .. } = *$self; - // send the server a text-based query that will be executed immediately - // replies with ERR, OK, or a result set - stream.write_packet(&Query { sql: $sql })?; - // STATE: remember that we are now expecting a query response - let cmd = QueryCommand::begin(commands); + let mut cmd = QueryCommand::begin(commands); // default we did not find a row let mut first_row = None; #[allow(clippy::while_let_loop, unused_labels)] 'results: loop { - let ok = 'result: loop { + let res = 'result: loop { match read_packet!($(@$blocking)? stream).deserialize_with(capabilities)? { - QueryResponse::End(res) => break 'result res.into_result()?, + QueryResponse::End(res) => break 'result res.into_result(), QueryResponse::ResultSet { columns } => { let columns = recv_columns!($(@$blocking)? /* store = */ true, columns, stream, cmd); - log::debug!("columns = {:?}", columns); 'rows: loop { match read_packet!($(@$blocking)? stream).deserialize_with(capabilities)? { // execute ignores any rows returned // but we do increment affected rows - QueryStep::End(res) => break 'result res.into_result()?, + QueryStep::End(res) => break 'result res.into_result(), QueryStep::Row(row) => { - first_row = Some(MySqlRow::new(row.deserialize_with(&columns[..])?)); + first_row = Some(MySqlRow::new(row.deserialize_with((format, &columns[..]))?, &columns)); // get out as soon as possible after finding our one row break 'results; @@ -44,9 +41,12 @@ macro_rules! impl_fetch_optional { } }; + // STATE: command is complete on error + let ok = cmd.end_if_error(res)?; + if !ok.status.contains(Status::MORE_RESULTS_EXISTS) { // STATE: the current command is complete - commands.end(); + cmd.end(); // no more results, time to finally call it quits and give up break; @@ -60,18 +60,30 @@ macro_rules! impl_fetch_optional { }}; } -#[cfg(feature = "async")] -impl MySqlConnection { - pub(super) async fn fetch_optional_async(&mut self, sql: &str) -> Result> { +impl MySqlConnection { + #[cfg(feature = "async")] + pub(super) async fn fetch_optional_async<'q, 'a, E>( + &mut self, + query: E, + ) -> Result> + where + Rt: sqlx_core::Async, + E: Execute<'q, 'a, MySql>, + { flush!(self); - impl_fetch_optional!(self, sql) + impl_fetch_optional!(self, query) } -} -#[cfg(feature = "blocking")] -impl MySqlConnection { - pub(super) fn fetch_optional_blocking(&mut self, sql: &str) -> Result> { + #[cfg(feature = "blocking")] + pub(super) fn fetch_optional_blocking<'q, 'a, E>( + &mut self, + query: E, + ) -> Result> + where + Rt: sqlx_core::blocking::Runtime, + E: Execute<'q, 'a, MySql>, + { flush!(@blocking self); - impl_fetch_optional!(@blocking self, sql) + impl_fetch_optional!(@blocking self, query) } } diff --git a/sqlx-mysql/src/connection/prepare.rs b/sqlx-mysql/src/connection/executor/raw_prepare.rs similarity index 66% rename from sqlx-mysql/src/connection/prepare.rs rename to sqlx-mysql/src/connection/executor/raw_prepare.rs index ad95a702..d2949686 100644 --- a/sqlx-mysql/src/connection/prepare.rs +++ b/sqlx-mysql/src/connection/executor/raw_prepare.rs @@ -1,10 +1,11 @@ use sqlx_core::{Result, Runtime}; -use crate::connection::flush::PrepareCommand; +use crate::connection::command::PrepareCommand; use crate::protocol::{ColumnDefinition, Prepare, PrepareResponse}; -use crate::{MySqlColumn, MySqlStatement, MySqlTypeInfo}; +use crate::raw_statement::RawStatement; +use crate::{MySqlColumn, MySqlTypeInfo}; -macro_rules! impl_prepare { +macro_rules! impl_raw_prepare { ($(@$blocking:ident)? $self:ident, $sql:ident) => {{ let Self { ref mut stream, ref mut commands, capabilities, .. } = *$self; @@ -12,25 +13,15 @@ macro_rules! impl_prepare { stream.write_packet(&Prepare { sql: $sql })?; // STATE: remember that we are now expecting a prepare response - let cmd = PrepareCommand::begin(commands); + let mut cmd = PrepareCommand::begin(commands); let res = read_packet!($(@$blocking)? stream) .deserialize_with::(capabilities)?.into_result(); - let ok = match res { - Ok(ok) => ok, - Err(error) => { - // STATE: prepare failed, command ended - commands.end(); + // STATE: command is complete on error + let ok = cmd.end_if_error(res)?; - return Err(error); - }, - }; - - let mut stmt = MySqlStatement::new(ok.statement_id); - - stmt.parameters.reserve(ok.params.into()); - stmt.columns.reserve(ok.columns.into()); + let mut stmt = RawStatement::new(&ok); for index in (1..=ok.params).rev() { // STATE: remember that we are expecting #rem more columns @@ -57,39 +48,38 @@ macro_rules! impl_prepare { // TODO: handle EOF for old MySQL // STATE: the command is complete - commands.end(); + cmd.end(); Ok(stmt) }}; } -// TODO: should be private impl super::MySqlConnection { #[cfg(feature = "async")] - pub async fn prepare_async(&mut self, sql: &str) -> Result + pub(super) async fn raw_prepare_async(&mut self, sql: &str) -> Result where Rt: sqlx_core::Async, { flush!(self); - impl_prepare!(self, sql) + impl_raw_prepare!(self, sql) } #[cfg(feature = "blocking")] - pub fn prepare_blocking(&mut self, sql: &str) -> Result + pub(super) fn raw_prepare_blocking(&mut self, sql: &str) -> Result where Rt: sqlx_core::blocking::Runtime, { flush!(@blocking self); - impl_prepare!(@blocking self, sql) + impl_raw_prepare!(@blocking self, sql) } } -macro_rules! prepare { +macro_rules! raw_prepare { (@blocking $self:ident, $sql:expr) => { - $self.prepare_blocking($sql)? + $self.raw_prepare_blocking($sql)? }; ($self:ident, $sql:expr) => { - $self.prepare_async($sql).await? + $self.raw_prepare_async($sql).await? }; } diff --git a/sqlx-mysql/src/connection/executor/raw_query.rs b/sqlx-mysql/src/connection/executor/raw_query.rs new file mode 100644 index 00000000..7efbf1d0 --- /dev/null +++ b/sqlx-mysql/src/connection/executor/raw_query.rs @@ -0,0 +1,67 @@ +use sqlx_core::{Arguments, Execute, Result, Runtime}; + +use crate::protocol::{self, Query, QueryResponse, QueryStep, Status}; +use crate::{MySql, MySqlConnection, MySqlRawValueFormat, MySqlRow}; + +macro_rules! impl_raw_query { + ($(@$blocking:ident)? $self:ident, $query:ident) => {{ + let format = if let Some(arguments) = $query.arguments() { + // prepare the statement for execution + let statement = raw_prepare!($(@$blocking:ident)? $self, $query.sql()); + + // execute the prepared statement + $self.stream.write_packet(&protocol::Execute { + statement: statement.id(), + parameters: &statement.parameters, + arguments: &arguments, + })?; + + // prepared queries always use the BINARY format + MySqlRawValueFormat::Binary + } else { + // directly execute the query as an unprepared, simple query + $self.stream.write_packet(&Query { sql: $query.sql() })?; + + // unprepared queries use the TEXT format + // this is a significant waste of bandwidth for large result sets + MySqlRawValueFormat::Text + }; + + Ok(format) + }}; +} + +impl MySqlConnection { + #[cfg(feature = "async")] + pub(super) async fn raw_query_async<'q, 'a, E>( + &mut self, + query: E, + ) -> Result + where + Rt: sqlx_core::Async, + E: Execute<'q, 'a, MySql>, + { + flush!(self); + impl_raw_query!(self, query) + } + + #[cfg(feature = "blocking")] + pub(super) fn raw_query_blocking<'q, 'a, E>(&mut self, query: E) -> Result + where + Rt: sqlx_core::blocking::Runtime, + E: Execute<'q, 'a, MySql>, + { + flush!(@blocking self); + impl_raw_query!(@blocking self, query) + } +} + +macro_rules! raw_query { + (@blocking $self:ident, $sql:expr) => { + $self.raw_query_blocking($sql)? + }; + + ($self:ident, $sql:expr) => { + $self.raw_query_async($sql).await? + }; +} diff --git a/sqlx-mysql/src/connection/flush.rs b/sqlx-mysql/src/connection/flush.rs index 74aae206..96990422 100644 --- a/sqlx-mysql/src/connection/flush.rs +++ b/sqlx-mysql/src/connection/flush.rs @@ -1,133 +1,52 @@ use std::collections::VecDeque; use std::hint::unreachable_unchecked; -use sqlx_core::Result; +use sqlx_core::{Error, Result}; +use crate::connection::command::{Command, CommandQueue, PrepareCommand, QueryCommand}; use crate::protocol::{PrepareResponse, QueryResponse, QueryStep, ResultPacket, Status}; use crate::{MySqlConnection, MySqlDatabaseError}; -pub(crate) struct CommandQueue(VecDeque); - -impl CommandQueue { - pub(crate) fn new() -> Self { - Self(VecDeque::with_capacity(2)) - } - - // begin a simple command - // in which we are expecting OK or ERR (a result) - pub(crate) fn begin(&mut self) { - self.0.push_back(Command::Simple); - } -} - -impl CommandQueue { - pub(crate) fn end(&mut self) { - self.0.pop_front(); - } - - fn maybe_end(&mut self, res: ResultPacket) { - match res { - ResultPacket::Ok(ok) => { - if ok.status.contains(Status::MORE_RESULTS_EXISTS) { - // an attached query response is next - // we are still expecting one - return; - } - } - - ResultPacket::Err(error) => { - // without context, we should not bubble this err - // log and continue forward - log::error!("{}", MySqlDatabaseError(error)); +fn maybe_end_with(queue: &mut CommandQueue, res: ResultPacket) { + match res { + ResultPacket::Ok(ok) => { + if ok.status.contains(Status::MORE_RESULTS_EXISTS) { + // an attached query response is next + // we are still expecting one + return; } } - // STATE: end of query - self.0.pop_front(); - } -} - -#[derive(Debug)] -#[repr(u8)] -pub(crate) enum Command { - // expecting [ResultPacket] - Simple, - Query(QueryCommand), - Prepare(PrepareCommand), -} - -#[derive(Debug)] -#[repr(u8)] -pub(crate) enum QueryCommand { - // expecting [QueryResponse] - QueryResponse, - - // expecting [QueryStep] - QueryStep, - - // expecting {rem} more [ColumnDefinition] packets - ColumnDefinition { rem: u16 }, -} - -impl QueryCommand { - pub(crate) fn begin(queue: &mut CommandQueue) -> &mut Self { - queue.0.push_back(Command::Query(Self::QueryResponse)); - - if let Some(Command::Query(cmd)) = queue.0.back_mut() { - cmd - } else { - // UNREACHABLE: just pushed a query command to the back of the vector, and we - // have &mut access, nobody else is pushing to it - #[allow(unsafe_code)] - unsafe { - unreachable_unchecked() - } + ResultPacket::Err(error) => { + // without context, we should not bubble this err + // log and continue forward + log::error!("{}", MySqlDatabaseError(error)); } } -} -#[derive(Debug)] -pub(crate) enum PrepareCommand { - // expecting [ERR] or [COM_STMT_PREPARE_OK] - PrepareResponse, - - // expecting {rem} more [ColumnDefinition] packets for each parameter - // stores {columns} as this state is before the [ColumnDefinition] state - ParameterDefinition { rem: u16, columns: u16 }, - - // expecting {rem} more [ColumnDefinition] packets for each parameter - ColumnDefinition { rem: u16 }, -} - -impl PrepareCommand { - pub(crate) fn begin(queue: &mut CommandQueue) -> &mut Self { - queue.0.push_back(Command::Prepare(Self::PrepareResponse)); - - if let Some(Command::Prepare(cmd)) = queue.0.back_mut() { - cmd - } else { - // UNREACHABLE: just pushed a prepare command to the back of the vector, and we - // have &mut access, nobody else is pushing to it - #[allow(unsafe_code)] - unsafe { - unreachable_unchecked() - } - } - } + // STATE: end of query + queue.0.pop_front(); } macro_rules! impl_flush { ($(@$blocking:ident)? $self:ident) => {{ - let Self { ref mut commands, ref mut stream, capabilities, .. } = *$self; - - log::debug!("flush!"); + let Self { ref mut commands, ref mut stream, ref mut closed, capabilities, .. } = *$self; while let Some(command) = commands.0.get_mut(0) { match command { + Command::Close => { + if !*closed { + close!($(@$blocking)? stream); + *closed = true; + } + + return Err(Error::Closed); + } + Command::Simple => { // simple commands where we expect an OK or ERR // ex. COM_PING, COM_QUERY, COM_STMT_RESET, COM_SET_OPTION - commands.maybe_end(read_packet!($(@$blocking)? stream).deserialize_with(capabilities)?); + maybe_end_with(commands, read_packet!($(@$blocking)? stream).deserialize_with(capabilities)?); } Command::Prepare(ref mut cmd) => { @@ -185,7 +104,7 @@ macro_rules! impl_flush { // expecting OK, ERR, or a result set QueryCommand::QueryResponse => { match read_packet!($(@$blocking)? stream).deserialize_with(capabilities)? { - QueryResponse::End(end) => break commands.maybe_end(end), + QueryResponse::End(end) => break maybe_end_with(commands, end), QueryResponse::ResultSet { columns } => { // STATE: expect the column definitions for each column *cmd = QueryCommand::ColumnDefinition { rem: columns }; @@ -214,7 +133,7 @@ macro_rules! impl_flush { // either the query result set has ended or we receive // and immediately drop a row match read_packet!($(@$blocking)? stream).deserialize_with(capabilities)? { - QueryStep::End(end) => break commands.maybe_end(end), + QueryStep::End(end) => break maybe_end_with(commands, end), QueryStep::Row(_) => {} } } diff --git a/sqlx-mysql/src/connection/ping.rs b/sqlx-mysql/src/connection/ping.rs index 426d986b..7c29d3ad 100644 --- a/sqlx-mysql/src/connection/ping.rs +++ b/sqlx-mysql/src/connection/ping.rs @@ -13,14 +13,13 @@ macro_rules! impl_ping { // STATE: remember that we are expecting an OK packet $self.commands.begin(); - let _ok = read_packet!($(@$blocking)? $self.stream) - .deserialize_with::($self.capabilities)? - .into_result()?; + let res = read_packet!($(@$blocking)? $self.stream) + .deserialize_with::($self.capabilities)?; - // STATE: received OK packet + // STATE: received result packet $self.commands.end(); - Ok(()) + res.into_result().map(|_| ()) }}; } diff --git a/sqlx-mysql/src/raw_statement.rs b/sqlx-mysql/src/raw_statement.rs index 8e511edb..246bbdd4 100644 --- a/sqlx-mysql/src/raw_statement.rs +++ b/sqlx-mysql/src/raw_statement.rs @@ -9,11 +9,11 @@ pub(crate) struct RawStatement { } impl RawStatement { - pub(crate) fn new(ok: PrepareOk) -> Self { + pub(crate) fn new(ok: &PrepareOk) -> Self { Self { id: ok.statement_id, columns: Vec::with_capacity(ok.columns.into()), - parameters: Vec::with_capacity(ok.parameters.into()), + parameters: Vec::with_capacity(ok.params.into()), } } diff --git a/sqlx-mysql/src/stream.rs b/sqlx-mysql/src/stream.rs index 1855f723..e42b75c8 100644 --- a/sqlx-mysql/src/stream.rs +++ b/sqlx-mysql/src/stream.rs @@ -2,11 +2,11 @@ use std::fmt::Debug; use std::ops::{Deref, DerefMut}; use bytes::{Buf, BufMut}; -use sqlx_core::io::{BufStream, Serialize}; +use sqlx_core::io::{BufStream, Serialize, Stream}; use sqlx_core::net::Stream as NetStream; use sqlx_core::{Error, Result, Runtime}; -use crate::protocol::{MaybeCommand, Packet}; +use crate::protocol::{MaybeCommand, Packet, Quit}; use crate::MySqlDatabaseError; /// Reads and writes packets to and from the MySQL database server. @@ -186,3 +186,39 @@ macro_rules! read_packet { $stream.read_packet_async().await? }; } + +impl MySqlStream { + #[cfg(feature = "async")] + pub(crate) async fn close_async(&mut self) -> Result<()> + where + Rt: sqlx_core::Async, + { + self.write_packet(&Quit)?; + self.flush_async().await?; + self.shutdown_async().await?; + + Ok(()) + } + + #[cfg(feature = "blocking")] + pub(crate) fn close_blocking(&mut self) -> Result<()> + where + Rt: sqlx_core::blocking::Runtime, + { + self.write_packet(&Quit)?; + self.flush()?; + self.shutdown()?; + + Ok(()) + } +} + +macro_rules! close { + (@blocking $self:ident) => { + $self.close_blocking()? + }; + + ($self:ident) => { + $self.close_async().await? + }; +}