feat(mysql): implement the prepare phase for statements

This commit is contained in:
Ryan Leckey 2021-02-18 22:17:13 -08:00
parent 080fa46126
commit 60cf88c38f
No known key found for this signature in database
GPG Key ID: F8AA68C235AB08C9
6 changed files with 251 additions and 47 deletions

View File

@ -13,6 +13,9 @@ use crate::{MySql, MySqlConnectOptions};
#[macro_use]
mod flush;
#[macro_use]
mod prepare;
mod close;
mod connect;
mod executor;

View File

@ -3,7 +3,7 @@ use std::hint::unreachable_unchecked;
use sqlx_core::Result;
use crate::protocol::{QueryResponse, QueryStep, ResultPacket, Status};
use crate::protocol::{PrepareResponse, QueryResponse, QueryStep, ResultPacket, Status};
use crate::{MySqlConnection, MySqlDatabaseError};
pub(crate) struct CommandQueue(VecDeque<Command>);
@ -20,44 +20,6 @@ impl CommandQueue {
}
}
#[derive(Debug)]
#[repr(u8)]
pub(crate) enum Command {
// expecting [ResultPacket]
Simple,
Query(QueryCommand),
}
#[derive(Debug)]
#[repr(u8)]
pub(crate) enum QueryCommand {
// expecting [QueryResponse]
QueryResponse,
// expecting [QueryStep]
QueryStep,
// expecting {rem} more [ColumnDefinition] packets
ColumnDefinition { rem: u64 },
}
impl QueryCommand {
pub(crate) fn begin(queue: &mut CommandQueue) -> &mut Self {
queue.0.push_back(Command::Query(Self::QueryResponse));
if let Some(Command::Query(cmd)) = queue.0.back_mut() {
cmd
} else {
// UNREACHABLE: just pushed a query command to the back of the vector, and we
// have &mut access, nobody else is pushing to it
#[allow(unsafe_code)]
unsafe {
unreachable_unchecked()
}
}
}
}
impl CommandQueue {
pub(crate) fn end(&mut self) {
self.0.pop_front();
@ -85,6 +47,75 @@ impl CommandQueue {
}
}
#[derive(Debug)]
#[repr(u8)]
pub(crate) enum Command {
// expecting [ResultPacket]
Simple,
Query(QueryCommand),
Prepare(PrepareCommand),
}
#[derive(Debug)]
#[repr(u8)]
pub(crate) enum QueryCommand {
// expecting [QueryResponse]
QueryResponse,
// expecting [QueryStep]
QueryStep,
// expecting {rem} more [ColumnDefinition] packets
ColumnDefinition { rem: u16 },
}
impl QueryCommand {
pub(crate) fn begin(queue: &mut CommandQueue) -> &mut Self {
queue.0.push_back(Command::Query(Self::QueryResponse));
if let Some(Command::Query(cmd)) = queue.0.back_mut() {
cmd
} else {
// UNREACHABLE: just pushed a query command to the back of the vector, and we
// have &mut access, nobody else is pushing to it
#[allow(unsafe_code)]
unsafe {
unreachable_unchecked()
}
}
}
}
#[derive(Debug)]
pub(crate) enum PrepareCommand {
// expecting [ERR] or [COM_STMT_PREPARE_OK]
PrepareResponse,
// expecting {rem} more [ColumnDefinition] packets for each parameter
// stores {columns} as this state is before the [ColumnDefinition] state
ParameterDefinition { rem: u16, columns: u16 },
// expecting {rem} more [ColumnDefinition] packets for each parameter
ColumnDefinition { rem: u16 },
}
impl PrepareCommand {
pub(crate) fn begin(queue: &mut CommandQueue) -> &mut Self {
queue.0.push_back(Command::Prepare(Self::PrepareResponse));
if let Some(Command::Prepare(cmd)) = queue.0.back_mut() {
cmd
} else {
// UNREACHABLE: just pushed a prepare command to the back of the vector, and we
// have &mut access, nobody else is pushing to it
#[allow(unsafe_code)]
unsafe {
unreachable_unchecked()
}
}
}
}
macro_rules! impl_flush {
($(@$blocking:ident)? $self:ident) => {{
let Self { ref mut commands, ref mut stream, capabilities, .. } = *$self;
@ -99,6 +130,55 @@ macro_rules! impl_flush {
commands.maybe_end(read_packet!($(@$blocking)? stream).deserialize_with(capabilities)?);
}
Command::Prepare(ref mut cmd) => {
loop {
match cmd {
PrepareCommand::PrepareResponse => {
match read_packet!($(@$blocking)? stream).deserialize_with(capabilities)? {
PrepareResponse::Ok(ok) => {
// STATE: expect the parameter definitions next
*cmd = PrepareCommand::ParameterDefinition { rem: ok.params, columns: ok.columns };
}
PrepareResponse::Err(error) => {
// without context, we should not bubble this err; log and continue forward
log::error!("{}", MySqlDatabaseError(error));
// STATE: end of command
break commands.end();
}
}
}
PrepareCommand::ParameterDefinition { rem, columns } => {
if *rem == 0 {
// no more parameters
// STATE: expect columns next
*cmd = PrepareCommand::ColumnDefinition { rem: *columns };
continue;
}
let _ = read_packet!($(@$blocking)? stream);
// STATE: now expecting the next parameter
*cmd = PrepareCommand::ParameterDefinition { rem: *rem - 1, columns: *columns };
}
PrepareCommand::ColumnDefinition { rem } => {
if *rem == 0 {
// no more columns; done
break commands.end();
}
let _ = read_packet!($(@$blocking)? stream);
// STATE: now expecting the next parameter
*cmd = PrepareCommand::ColumnDefinition { rem: *rem - 1 };
}
}
}
}
Command::Query(ref mut cmd) => {
loop {
match cmd {
@ -116,15 +196,17 @@ macro_rules! impl_flush {
// expecting a column definition
// remembers how many more column definitions we need
QueryCommand::ColumnDefinition { rem } => {
let _ = read_packet!($(@$blocking)? stream);
if *rem > 0 {
// STATE: now expecting the next column
*cmd = QueryCommand::ColumnDefinition { rem: *rem - 1 };
} else {
if *rem == 0 {
// no more parameters
// STATE: now expecting OK (END), ERR, or a row
*cmd = QueryCommand::QueryStep;
continue;
}
let _ = read_packet!($(@$blocking)? stream);
// STATE: now expecting the next column
*cmd = QueryCommand::ColumnDefinition { rem: *rem - 1 };
}
// expecting OK, ERR, or a Row

View File

@ -0,0 +1,95 @@
use sqlx_core::{Result, Runtime};
use crate::connection::flush::PrepareCommand;
use crate::protocol::{ColumnDefinition, Prepare, PrepareResponse};
use crate::{MySqlColumn, MySqlStatement, MySqlTypeInfo};
macro_rules! impl_prepare {
($(@$blocking:ident)? $self:ident, $sql:ident) => {{
let Self { ref mut stream, ref mut commands, capabilities, .. } = *$self;
// send the server a query that to be prepared
stream.write_packet(&Prepare { sql: $sql })?;
// STATE: remember that we are now expecting a prepare response
let cmd = PrepareCommand::begin(commands);
let res = read_packet!($(@$blocking)? stream)
.deserialize_with::<PrepareResponse, _>(capabilities)?.into_result();
let ok = match res {
Ok(ok) => ok,
Err(error) => {
// STATE: prepare failed, command ended
commands.end();
return Err(error);
},
};
let mut stmt = MySqlStatement::new(ok.statement_id);
stmt.parameters.reserve(ok.params.into());
stmt.columns.reserve(ok.columns.into());
for index in (1..=ok.params).rev() {
// STATE: remember that we are expecting #rem more columns
*cmd = PrepareCommand::ParameterDefinition { rem: index, columns: ok.columns };
let def = read_packet!($(@$blocking)? stream).deserialize()?;
// extract the type only from the column definition
// most other fields are useless
stmt.parameters.push(MySqlTypeInfo::new(&def));
}
// TODO: handle EOF for old MySQL
for (ordinal, rem) in (1..=ok.columns).rev().enumerate() {
// STATE: remember that we are expecting #rem more columns
*cmd = PrepareCommand::ColumnDefinition { rem };
let def = read_packet!($(@$blocking)? stream).deserialize()?;
stmt.columns.push(MySqlColumn::new(ordinal, def));
}
// TODO: handle EOF for old MySQL
// STATE: the command is complete
commands.end();
Ok(stmt)
}};
}
// TODO: should be private
impl<Rt: Runtime> super::MySqlConnection<Rt> {
#[cfg(feature = "async")]
pub async fn prepare_async(&mut self, sql: &str) -> Result<MySqlStatement>
where
Rt: sqlx_core::Async,
{
flush!(self);
impl_prepare!(self, sql)
}
#[cfg(feature = "blocking")]
pub fn prepare_blocking(&mut self, sql: &str) -> Result<MySqlStatement>
where
Rt: sqlx_core::blocking::Runtime,
{
flush!(@blocking self);
impl_prepare!(@blocking self, sql)
}
}
macro_rules! prepare {
(@blocking $self:ident, $sql:expr) => {
$self.prepare_blocking($sql)?
};
($self:ident, $sql:expr) => {
$self.prepare_async($sql).await?
};
}

View File

@ -30,11 +30,12 @@ mod options;
mod output;
mod protocol;
mod query_result;
mod raw_statement;
mod row;
mod type_id;
mod type_info;
pub mod types;
mod value;
mod raw_value;
#[cfg(test)]
mod mock;
@ -49,4 +50,4 @@ pub use query_result::MySqlQueryResult;
pub use row::MySqlRow;
pub use type_id::MySqlTypeId;
pub use type_info::MySqlTypeInfo;
pub use value::{MySqlRawValue, MySqlRawValueFormat};
pub use raw_value::{MySqlRawValue, MySqlRawValueFormat};

View File

@ -0,0 +1,23 @@
use crate::protocol::PrepareOk;
use crate::{MySqlColumn, MySqlTypeInfo};
#[derive(Debug)]
pub(crate) struct RawStatement {
id: u32,
pub(crate) columns: Vec<MySqlColumn>,
pub(crate) parameters: Vec<MySqlTypeInfo>,
}
impl RawStatement {
pub(crate) fn new(ok: PrepareOk) -> Self {
Self {
id: ok.statement_id,
columns: Vec::with_capacity(ok.columns.into()),
parameters: Vec::with_capacity(ok.parameters.into()),
}
}
pub(crate) fn id(&self) -> u32 {
self.id
}
}