diff --git a/sqlx-mysql/src/connection.rs b/sqlx-mysql/src/connection.rs index 43edf6a1..dccf86cc 100644 --- a/sqlx-mysql/src/connection.rs +++ b/sqlx-mysql/src/connection.rs @@ -13,6 +13,9 @@ use crate::{MySql, MySqlConnectOptions}; #[macro_use] mod flush; +#[macro_use] +mod prepare; + mod close; mod connect; mod executor; diff --git a/sqlx-mysql/src/connection/flush.rs b/sqlx-mysql/src/connection/flush.rs index 9d7ed39e..74aae206 100644 --- a/sqlx-mysql/src/connection/flush.rs +++ b/sqlx-mysql/src/connection/flush.rs @@ -3,7 +3,7 @@ use std::hint::unreachable_unchecked; use sqlx_core::Result; -use crate::protocol::{QueryResponse, QueryStep, ResultPacket, Status}; +use crate::protocol::{PrepareResponse, QueryResponse, QueryStep, ResultPacket, Status}; use crate::{MySqlConnection, MySqlDatabaseError}; pub(crate) struct CommandQueue(VecDeque); @@ -20,44 +20,6 @@ impl CommandQueue { } } -#[derive(Debug)] -#[repr(u8)] -pub(crate) enum Command { - // expecting [ResultPacket] - Simple, - Query(QueryCommand), -} - -#[derive(Debug)] -#[repr(u8)] -pub(crate) enum QueryCommand { - // expecting [QueryResponse] - QueryResponse, - - // expecting [QueryStep] - QueryStep, - - // expecting {rem} more [ColumnDefinition] packets - ColumnDefinition { rem: u64 }, -} - -impl QueryCommand { - pub(crate) fn begin(queue: &mut CommandQueue) -> &mut Self { - queue.0.push_back(Command::Query(Self::QueryResponse)); - - if let Some(Command::Query(cmd)) = queue.0.back_mut() { - cmd - } else { - // UNREACHABLE: just pushed a query command to the back of the vector, and we - // have &mut access, nobody else is pushing to it - #[allow(unsafe_code)] - unsafe { - unreachable_unchecked() - } - } - } -} - impl CommandQueue { pub(crate) fn end(&mut self) { self.0.pop_front(); @@ -85,6 +47,75 @@ impl CommandQueue { } } +#[derive(Debug)] +#[repr(u8)] +pub(crate) enum Command { + // expecting [ResultPacket] + Simple, + Query(QueryCommand), + Prepare(PrepareCommand), +} + +#[derive(Debug)] +#[repr(u8)] +pub(crate) enum QueryCommand { + // expecting [QueryResponse] + QueryResponse, + + // expecting [QueryStep] + QueryStep, + + // expecting {rem} more [ColumnDefinition] packets + ColumnDefinition { rem: u16 }, +} + +impl QueryCommand { + pub(crate) fn begin(queue: &mut CommandQueue) -> &mut Self { + queue.0.push_back(Command::Query(Self::QueryResponse)); + + if let Some(Command::Query(cmd)) = queue.0.back_mut() { + cmd + } else { + // UNREACHABLE: just pushed a query command to the back of the vector, and we + // have &mut access, nobody else is pushing to it + #[allow(unsafe_code)] + unsafe { + unreachable_unchecked() + } + } + } +} + +#[derive(Debug)] +pub(crate) enum PrepareCommand { + // expecting [ERR] or [COM_STMT_PREPARE_OK] + PrepareResponse, + + // expecting {rem} more [ColumnDefinition] packets for each parameter + // stores {columns} as this state is before the [ColumnDefinition] state + ParameterDefinition { rem: u16, columns: u16 }, + + // expecting {rem} more [ColumnDefinition] packets for each parameter + ColumnDefinition { rem: u16 }, +} + +impl PrepareCommand { + pub(crate) fn begin(queue: &mut CommandQueue) -> &mut Self { + queue.0.push_back(Command::Prepare(Self::PrepareResponse)); + + if let Some(Command::Prepare(cmd)) = queue.0.back_mut() { + cmd + } else { + // UNREACHABLE: just pushed a prepare command to the back of the vector, and we + // have &mut access, nobody else is pushing to it + #[allow(unsafe_code)] + unsafe { + unreachable_unchecked() + } + } + } +} + macro_rules! impl_flush { ($(@$blocking:ident)? $self:ident) => {{ let Self { ref mut commands, ref mut stream, capabilities, .. } = *$self; @@ -99,6 +130,55 @@ macro_rules! impl_flush { commands.maybe_end(read_packet!($(@$blocking)? stream).deserialize_with(capabilities)?); } + Command::Prepare(ref mut cmd) => { + loop { + match cmd { + PrepareCommand::PrepareResponse => { + match read_packet!($(@$blocking)? stream).deserialize_with(capabilities)? { + PrepareResponse::Ok(ok) => { + // STATE: expect the parameter definitions next + *cmd = PrepareCommand::ParameterDefinition { rem: ok.params, columns: ok.columns }; + } + + PrepareResponse::Err(error) => { + // without context, we should not bubble this err; log and continue forward + log::error!("{}", MySqlDatabaseError(error)); + + // STATE: end of command + break commands.end(); + } + } + } + + PrepareCommand::ParameterDefinition { rem, columns } => { + if *rem == 0 { + // no more parameters + // STATE: expect columns next + *cmd = PrepareCommand::ColumnDefinition { rem: *columns }; + continue; + } + + let _ = read_packet!($(@$blocking)? stream); + + // STATE: now expecting the next parameter + *cmd = PrepareCommand::ParameterDefinition { rem: *rem - 1, columns: *columns }; + } + + PrepareCommand::ColumnDefinition { rem } => { + if *rem == 0 { + // no more columns; done + break commands.end(); + } + + let _ = read_packet!($(@$blocking)? stream); + + // STATE: now expecting the next parameter + *cmd = PrepareCommand::ColumnDefinition { rem: *rem - 1 }; + } + } + } + } + Command::Query(ref mut cmd) => { loop { match cmd { @@ -116,15 +196,17 @@ macro_rules! impl_flush { // expecting a column definition // remembers how many more column definitions we need QueryCommand::ColumnDefinition { rem } => { - let _ = read_packet!($(@$blocking)? stream); - - if *rem > 0 { - // STATE: now expecting the next column - *cmd = QueryCommand::ColumnDefinition { rem: *rem - 1 }; - } else { + if *rem == 0 { + // no more parameters // STATE: now expecting OK (END), ERR, or a row *cmd = QueryCommand::QueryStep; + continue; } + + let _ = read_packet!($(@$blocking)? stream); + + // STATE: now expecting the next column + *cmd = QueryCommand::ColumnDefinition { rem: *rem - 1 }; } // expecting OK, ERR, or a Row diff --git a/sqlx-mysql/src/connection/prepare.rs b/sqlx-mysql/src/connection/prepare.rs new file mode 100644 index 00000000..ad95a702 --- /dev/null +++ b/sqlx-mysql/src/connection/prepare.rs @@ -0,0 +1,95 @@ +use sqlx_core::{Result, Runtime}; + +use crate::connection::flush::PrepareCommand; +use crate::protocol::{ColumnDefinition, Prepare, PrepareResponse}; +use crate::{MySqlColumn, MySqlStatement, MySqlTypeInfo}; + +macro_rules! impl_prepare { + ($(@$blocking:ident)? $self:ident, $sql:ident) => {{ + let Self { ref mut stream, ref mut commands, capabilities, .. } = *$self; + + // send the server a query that to be prepared + stream.write_packet(&Prepare { sql: $sql })?; + + // STATE: remember that we are now expecting a prepare response + let cmd = PrepareCommand::begin(commands); + + let res = read_packet!($(@$blocking)? stream) + .deserialize_with::(capabilities)?.into_result(); + + let ok = match res { + Ok(ok) => ok, + Err(error) => { + // STATE: prepare failed, command ended + commands.end(); + + return Err(error); + }, + }; + + let mut stmt = MySqlStatement::new(ok.statement_id); + + stmt.parameters.reserve(ok.params.into()); + stmt.columns.reserve(ok.columns.into()); + + for index in (1..=ok.params).rev() { + // STATE: remember that we are expecting #rem more columns + *cmd = PrepareCommand::ParameterDefinition { rem: index, columns: ok.columns }; + + let def = read_packet!($(@$blocking)? stream).deserialize()?; + + // extract the type only from the column definition + // most other fields are useless + stmt.parameters.push(MySqlTypeInfo::new(&def)); + } + + // TODO: handle EOF for old MySQL + + for (ordinal, rem) in (1..=ok.columns).rev().enumerate() { + // STATE: remember that we are expecting #rem more columns + *cmd = PrepareCommand::ColumnDefinition { rem }; + + let def = read_packet!($(@$blocking)? stream).deserialize()?; + + stmt.columns.push(MySqlColumn::new(ordinal, def)); + } + + // TODO: handle EOF for old MySQL + + // STATE: the command is complete + commands.end(); + + Ok(stmt) + }}; +} + +// TODO: should be private +impl super::MySqlConnection { + #[cfg(feature = "async")] + pub async fn prepare_async(&mut self, sql: &str) -> Result + where + Rt: sqlx_core::Async, + { + flush!(self); + impl_prepare!(self, sql) + } + + #[cfg(feature = "blocking")] + pub fn prepare_blocking(&mut self, sql: &str) -> Result + where + Rt: sqlx_core::blocking::Runtime, + { + flush!(@blocking self); + impl_prepare!(@blocking self, sql) + } +} + +macro_rules! prepare { + (@blocking $self:ident, $sql:expr) => { + $self.prepare_blocking($sql)? + }; + + ($self:ident, $sql:expr) => { + $self.prepare_async($sql).await? + }; +} diff --git a/sqlx-mysql/src/lib.rs b/sqlx-mysql/src/lib.rs index c844c4ce..dc56c016 100644 --- a/sqlx-mysql/src/lib.rs +++ b/sqlx-mysql/src/lib.rs @@ -30,11 +30,12 @@ mod options; mod output; mod protocol; mod query_result; +mod raw_statement; mod row; mod type_id; mod type_info; pub mod types; -mod value; +mod raw_value; #[cfg(test)] mod mock; @@ -49,4 +50,4 @@ pub use query_result::MySqlQueryResult; pub use row::MySqlRow; pub use type_id::MySqlTypeId; pub use type_info::MySqlTypeInfo; -pub use value::{MySqlRawValue, MySqlRawValueFormat}; +pub use raw_value::{MySqlRawValue, MySqlRawValueFormat}; diff --git a/sqlx-mysql/src/raw_statement.rs b/sqlx-mysql/src/raw_statement.rs new file mode 100644 index 00000000..8e511edb --- /dev/null +++ b/sqlx-mysql/src/raw_statement.rs @@ -0,0 +1,23 @@ +use crate::protocol::PrepareOk; +use crate::{MySqlColumn, MySqlTypeInfo}; + +#[derive(Debug)] +pub(crate) struct RawStatement { + id: u32, + pub(crate) columns: Vec, + pub(crate) parameters: Vec, +} + +impl RawStatement { + pub(crate) fn new(ok: PrepareOk) -> Self { + Self { + id: ok.statement_id, + columns: Vec::with_capacity(ok.columns.into()), + parameters: Vec::with_capacity(ok.parameters.into()), + } + } + + pub(crate) fn id(&self) -> u32 { + self.id + } +} diff --git a/sqlx-mysql/src/value.rs b/sqlx-mysql/src/raw_value.rs similarity index 100% rename from sqlx-mysql/src/value.rs rename to sqlx-mysql/src/raw_value.rs