feat(mysql): thread E: Execute through the executor

+ handle normal errors separate from unexpected errors, an unexpected
   error causes the connection to close (in which case, if this was
   behind a pool, the pool would not allow this connection to be
   acquired again)
This commit is contained in:
Ryan Leckey 2021-02-18 23:35:36 -08:00
parent 60cf88c38f
commit 7db850da71
No known key found for this signature in database
GPG Key ID: F8AA68C235AB08C9
15 changed files with 450 additions and 271 deletions

View File

@ -30,6 +30,8 @@ pub enum Error {
///
RowNotFound,
Closed,
Decode(DecodeError),
Encode(EncodeError),
@ -82,6 +84,8 @@ impl Display for Error {
f.write_str("no row returned by a query required to return at least one row")
}
Self::Closed => f.write_str("connection or pool was closed"),
Self::Decode(error) => {
write!(f, "{}", error)
}

View File

@ -5,7 +5,7 @@ use futures_util::future::{BoxFuture, FutureExt};
use sqlx_core::net::Stream as NetStream;
use sqlx_core::{Close, Connect, Connection, Runtime};
use crate::connection::flush::CommandQueue;
use crate::connection::command::CommandQueue;
use crate::protocol::Capabilities;
use crate::stream::MySqlStream;
use crate::{MySql, MySqlConnectOptions};
@ -13,10 +13,7 @@ use crate::{MySql, MySqlConnectOptions};
#[macro_use]
mod flush;
#[macro_use]
mod prepare;
mod close;
mod command;
mod connect;
mod executor;
mod ping;
@ -29,6 +26,7 @@ where
{
stream: MySqlStream<Rt>,
connection_id: u32,
closed: bool,
// the capability flags are used by the client and server to indicate which
// features they support and want to use.
@ -48,6 +46,7 @@ where
Self {
stream: MySqlStream::new(stream),
connection_id: 0,
closed: false,
commands: CommandQueue::new(),
capabilities: Capabilities::PROTOCOL_41
| Capabilities::LONG_PASSWORD
@ -106,11 +105,15 @@ impl<Rt: Runtime> Connect<Rt> for MySqlConnection<Rt> {
impl<Rt: Runtime> Close<Rt> for MySqlConnection<Rt> {
#[cfg(feature = "async")]
fn close(self) -> BoxFuture<'static, sqlx_core::Result<()>>
fn close(mut self) -> BoxFuture<'static, sqlx_core::Result<()>>
where
Rt: sqlx_core::Async,
{
Box::pin(self.close_async())
Box::pin(async move {
self.stream.close_async().await?;
Ok(())
})
}
}
@ -139,8 +142,8 @@ mod blocking {
impl<Rt: Runtime> Close<Rt> for MySqlConnection<Rt> {
#[inline]
fn close(self) -> sqlx_core::Result<()> {
self.close_blocking()
fn close(mut self) -> sqlx_core::Result<()> {
self.stream.close_blocking()
}
}
}

View File

@ -1,29 +0,0 @@
use sqlx_core::{io::Stream, Result, Runtime};
use crate::protocol::Quit;
impl<Rt: Runtime> super::MySqlConnection<Rt> {
#[cfg(feature = "async")]
pub(crate) async fn close_async(mut self) -> Result<()>
where
Rt: sqlx_core::Async,
{
self.stream.write_packet(&Quit)?;
self.stream.flush_async().await?;
self.stream.shutdown_async().await?;
Ok(())
}
#[cfg(feature = "blocking")]
pub(crate) fn close_blocking(mut self) -> Result<()>
where
Rt: sqlx_core::blocking::Runtime,
{
self.stream.write_packet(&Quit)?;
self.stream.flush()?;
self.stream.shutdown()?;
Ok(())
}
}

View File

@ -0,0 +1,151 @@
use std::collections::VecDeque;
use std::hint::unreachable_unchecked;
use std::marker::PhantomData;
use std::mem;
use std::ops::{Deref, DerefMut};
use sqlx_core::Result;
use crate::protocol::{PrepareResponse, QueryResponse, QueryStep, ResultPacket, Status};
use crate::{MySqlConnection, MySqlDatabaseError};
pub(crate) struct CommandQueue(pub(super) VecDeque<Command>);
impl CommandQueue {
pub(crate) fn new() -> Self {
Self(VecDeque::with_capacity(2))
}
// begin a simple command
// in which we are expecting OK or ERR (a result)
pub(crate) fn begin(&mut self) {
self.0.push_back(Command::Simple);
}
}
impl CommandQueue {
pub(crate) fn end(&mut self) {
self.0.pop_front();
}
}
#[derive(Debug)]
#[repr(u8)]
pub(crate) enum Command {
Simple,
Close,
Query(QueryCommand),
Prepare(PrepareCommand),
}
pub(crate) struct CommandGuard<'cmd, C> {
queue: &'cmd mut CommandQueue,
command: PhantomData<&'cmd mut C>,
index: usize,
ended: bool,
}
impl<'cmd, C> CommandGuard<'cmd, C> {
fn begin(queue: &'cmd mut CommandQueue, command: Command) -> Self {
let index = queue.0.len();
queue.0.push_back(command);
Self { queue, index, ended: false, command: PhantomData }
}
// called on successful command completion
pub(crate) fn end(&mut self) {
self.ended = true;
}
// on an error result, the command needs to end *normally* and pass
// through the error to bubble
pub(crate) fn end_if_error<T>(&mut self, res: Result<T>) -> Result<T> {
match res {
Ok(ok) => Ok(ok),
Err(error) => {
self.end();
Err(error)
}
}
}
}
impl<C> Drop for CommandGuard<'_, C> {
fn drop(&mut self) {
self.queue.end();
if !self.ended {
// if the command was not "completed" by success or a known
// failure, we are in a **weird** state, queue up a close if
// someone tries to re-use this connection
self.queue.0.push_front(Command::Close);
}
}
}
#[derive(Debug)]
#[repr(u8)]
pub(crate) enum QueryCommand {
// expecting [QueryResponse]
QueryResponse,
// expecting [QueryStep]
QueryStep,
// expecting {rem} more [ColumnDefinition] packets
ColumnDefinition { rem: u16 },
}
impl QueryCommand {
pub(crate) fn begin(queue: &mut CommandQueue) -> CommandGuard<'_, Self> {
CommandGuard::begin(queue, Command::Query(Self::QueryResponse))
}
}
impl Deref for CommandGuard<'_, QueryCommand> {
type Target = QueryCommand;
fn deref(&self) -> &Self::Target {
if let Command::Query(cmd) = &self.queue.0[self.index] { cmd } else { unreachable!() }
}
}
impl DerefMut for CommandGuard<'_, QueryCommand> {
fn deref_mut(&mut self) -> &mut Self::Target {
if let Command::Query(cmd) = &mut self.queue.0[self.index] { cmd } else { unreachable!() }
}
}
#[derive(Debug)]
pub(crate) enum PrepareCommand {
// expecting [ERR] or [COM_STMT_PREPARE_OK]
PrepareResponse,
// expecting {rem} more [ColumnDefinition] packets for each parameter
// stores {columns} as this state is before the [ColumnDefinition] state
ParameterDefinition { rem: u16, columns: u16 },
// expecting {rem} more [ColumnDefinition] packets for each parameter
ColumnDefinition { rem: u16 },
}
impl PrepareCommand {
pub(crate) fn begin(queue: &mut CommandQueue) -> CommandGuard<'_, Self> {
CommandGuard::begin(queue, Command::Prepare(Self::PrepareResponse))
}
}
impl Deref for CommandGuard<'_, PrepareCommand> {
type Target = PrepareCommand;
fn deref(&self) -> &Self::Target {
if let Command::Prepare(cmd) = &self.queue.0[self.index] { cmd } else { unreachable!() }
}
}
impl DerefMut for CommandGuard<'_, PrepareCommand> {
fn deref_mut(&mut self) -> &mut Self::Target {
if let Command::Prepare(cmd) = &mut self.queue.0[self.index] { cmd } else { unreachable!() }
}
}

View File

@ -1,12 +1,18 @@
#[cfg(feature = "async")]
use futures_util::{future::BoxFuture, FutureExt};
use sqlx_core::{Executor, Result, Runtime};
use sqlx_core::{Arguments, Execute, Executor, Result, Runtime};
use crate::{MySql, MySqlConnection, MySqlQueryResult, MySqlRow};
#[macro_use]
mod columns;
#[macro_use]
mod raw_prepare;
#[macro_use]
mod raw_query;
mod execute;
mod fetch_all;
mod fetch_optional;
@ -16,38 +22,44 @@ impl<Rt: Runtime> Executor<Rt> for MySqlConnection<Rt> {
#[cfg(feature = "async")]
#[inline]
fn execute<'x, 'e, 'q>(&'e mut self, sql: &'q str) -> BoxFuture<'x, Result<MySqlQueryResult>>
fn execute<'x, 'e, 'q, 'a, E>(&'e mut self, query: E) -> BoxFuture<'x, Result<MySqlQueryResult>>
where
Rt: sqlx_core::Async,
E: 'x + Execute<'q, 'a, MySql>,
'e: 'x,
'q: 'x,
'a: 'x,
{
self.execute_async(sql).boxed()
self.execute_async(query).boxed()
}
#[cfg(feature = "async")]
#[inline]
fn fetch_all<'x, 'e, 'q>(&'e mut self, sql: &'q str) -> BoxFuture<'x, Result<Vec<MySqlRow>>>
fn fetch_all<'x, 'e, 'q, 'a, E>(&'e mut self, query: E) -> BoxFuture<'x, Result<Vec<MySqlRow>>>
where
Rt: sqlx_core::Async,
E: 'x + Execute<'q, 'a, MySql>,
'e: 'x,
'q: 'x,
'a: 'x,
{
self.fetch_all_async(sql).boxed()
self.fetch_all_async(query).boxed()
}
#[cfg(feature = "async")]
#[inline]
fn fetch_optional<'x, 'e, 'q>(
fn fetch_optional<'x, 'e, 'q, 'a, E>(
&'e mut self,
sql: &'q str,
query: E,
) -> BoxFuture<'x, Result<Option<MySqlRow>>>
where
Rt: sqlx_core::Async,
E: 'x + Execute<'q, 'a, MySql>,
'e: 'x,
'q: 'x,
'a: 'x,
{
self.fetch_optional_async(sql).boxed()
self.fetch_optional_async(query).boxed()
}
}

View File

@ -1,22 +1,23 @@
use sqlx_core::{Result, Runtime};
use crate::connection::flush::QueryCommand;
use crate::connection::command::QueryCommand;
use crate::protocol::ColumnDefinition;
use crate::stream::MySqlStream;
use crate::MySqlColumn;
macro_rules! impl_recv_columns {
($(@$blocking:ident)? $store:expr, $num_columns:ident, $stream:ident, $cmd:ident) => {{
#[allow(clippy::cast_possible_truncation)]
let mut columns = if $store {
Vec::<ColumnDefinition>::with_capacity($num_columns as usize)
Vec::<MySqlColumn>::with_capacity($num_columns as usize)
} else {
// we are going to drop column definitions, do not allocate
Vec::new()
};
for index in (1..=$num_columns).rev() {
for (ordinal, rem) in (1..=$num_columns).rev().enumerate() {
// STATE: remember that we are expecting #rem more columns
*$cmd = QueryCommand::ColumnDefinition { rem: index };
*$cmd = QueryCommand::ColumnDefinition { rem };
// read in definition and only deserialize if we are saving
// the column definitions
@ -24,7 +25,7 @@ macro_rules! impl_recv_columns {
let packet = read_packet!($(@$blocking)? $stream);
if $store {
columns.push(packet.deserialize()?);
columns.push(MySqlColumn::new(ordinal, packet.deserialize()?));
}
}
@ -40,9 +41,9 @@ impl<Rt: Runtime> MySqlStream<Rt> {
pub(super) async fn recv_columns_async(
&mut self,
store: bool,
columns: u64,
columns: u16,
cmd: &mut QueryCommand,
) -> Result<Vec<ColumnDefinition>>
) -> Result<Vec<MySqlColumn>>
where
Rt: sqlx_core::Async,
{
@ -53,9 +54,9 @@ impl<Rt: Runtime> MySqlStream<Rt> {
pub(crate) fn recv_columns_blocking(
&mut self,
store: bool,
columns: u64,
columns: u16,
cmd: &mut QueryCommand,
) -> Result<Vec<ColumnDefinition>>
) -> Result<Vec<MySqlColumn>>
where
Rt: sqlx_core::blocking::Runtime,
{
@ -65,10 +66,10 @@ impl<Rt: Runtime> MySqlStream<Rt> {
macro_rules! recv_columns {
(@blocking $store:expr, $columns:ident, $stream:ident, $cmd:ident) => {
$stream.recv_columns_blocking($store, $columns, $cmd)?
$stream.recv_columns_blocking($store, $columns, &mut *$cmd)?
};
($store:expr, $columns:ident, $stream:ident, $cmd:ident) => {
$stream.recv_columns_async($store, $columns, $cmd).await?
$stream.recv_columns_async($store, $columns, &mut *$cmd).await?
};
}

View File

@ -1,19 +1,17 @@
use sqlx_core::Result;
use sqlx_core::{Execute, Result, Runtime};
use crate::connection::flush::QueryCommand;
use crate::connection::command::QueryCommand;
use crate::protocol::{Query, QueryResponse, QueryStep, Status};
use crate::{MySqlConnection, MySqlQueryResult};
use crate::{MySql, MySqlConnection, MySqlQueryResult};
macro_rules! impl_execute {
($(@$blocking:ident)? $self:ident, $sql:ident) => {{
($(@$blocking:ident)? $self:ident, $query:ident) => {{
raw_query!($(@$blocking)? $self, $query);
let Self { ref mut stream, ref mut commands, capabilities, .. } = *$self;
// send the server a text-based query that will be executed immediately
// replies with ERR, OK, or a result set
stream.write_packet(&Query { sql: $sql })?;
// STATE: remember that we are now expecting a query response
let cmd = QueryCommand::begin(commands);
let mut cmd = QueryCommand::begin(commands);
// default an empty query result
// execute collects all discovered query results and SUMs
@ -22,9 +20,9 @@ macro_rules! impl_execute {
#[allow(clippy::while_let_loop, unused_labels)]
'results: loop {
let ok = 'result: loop {
let res = 'result: loop {
match read_packet!($(@$blocking)? stream).deserialize_with(capabilities)? {
QueryResponse::End(res) => break 'result res.into_result()?,
QueryResponse::End(res) => break 'result res.into_result(),
QueryResponse::ResultSet { columns } => {
// acknowledge but discard any columns as execute returns no rows
recv_columns!($(@$blocking)? /* store = */ false, columns, stream, cmd);
@ -34,13 +32,16 @@ macro_rules! impl_execute {
// execute ignores any rows returned
// but we do increment affected rows
QueryStep::Row(_row) => result.0.affected_rows += 1,
QueryStep::End(res) => break 'result res.into_result()?,
QueryStep::End(res) => break 'result res.into_result(),
}
}
}
}
};
// STATE: command is complete on error
let ok = cmd.end_if_error(res)?;
// fold this into the total result for the SQL
result.extend(Some(ok.into()));
@ -54,24 +55,30 @@ macro_rules! impl_execute {
}
// STATE: the current command is complete
commands.end();
cmd.end();
Ok(result)
}};
}
#[cfg(feature = "async")]
impl<Rt: sqlx_core::Async> MySqlConnection<Rt> {
pub(super) async fn execute_async(&mut self, sql: &str) -> Result<MySqlQueryResult> {
impl<Rt: Runtime> MySqlConnection<Rt> {
#[cfg(feature = "async")]
pub(super) async fn execute_async<'q, 'a, E>(&mut self, query: E) -> Result<MySqlQueryResult>
where
Rt: sqlx_core::Async,
E: Execute<'q, 'a, MySql>,
{
flush!(self);
impl_execute!(self, sql)
impl_execute!(self, query)
}
}
#[cfg(feature = "blocking")]
impl<Rt: sqlx_core::blocking::Runtime> MySqlConnection<Rt> {
pub(super) fn execute_blocking(&mut self, sql: &str) -> Result<MySqlQueryResult> {
#[cfg(feature = "blocking")]
pub(super) fn execute_blocking<'q, 'a, E>(&mut self, query: E) -> Result<MySqlQueryResult>
where
Rt: sqlx_core::blocking::Runtime,
E: Execute<'q, 'a, MySql>,
{
flush!(@blocking self);
impl_execute!(@blocking self, sql)
impl_execute!(@blocking self, query)
}
}

View File

@ -1,28 +1,26 @@
use sqlx_core::Result;
use sqlx_core::{Arguments, Execute, Result, Runtime};
use crate::connection::flush::QueryCommand;
use crate::protocol::{Query, QueryResponse, QueryStep, Status};
use crate::{MySqlConnection, MySqlRow};
use crate::connection::command::QueryCommand;
use crate::protocol::{self, Query, QueryResponse, QueryStep, Status};
use crate::{MySql, MySqlConnection, MySqlRawValueFormat, MySqlRow};
macro_rules! impl_fetch_all {
($(@$blocking:ident)? $self:ident, $sql:ident) => {{
($(@$blocking:ident)? $self:ident, $query:ident) => {{
let format = raw_query!($(@$blocking)? $self, $query);
let Self { ref mut stream, ref mut commands, capabilities, .. } = *$self;
// send the server a text-based query that will be executed immediately
// replies with ERR, OK, or a result set
stream.write_packet(&Query { sql: $sql })?;
// STATE: remember that we are now expecting a query response
let cmd = QueryCommand::begin(commands);
let mut cmd = QueryCommand::begin(commands);
// default an empty row set
let mut rows = Vec::with_capacity(10);
#[allow(clippy::while_let_loop, unused_labels)]
'results: loop {
let ok = 'result: loop {
let res = 'result: loop {
match read_packet!($(@$blocking)? stream).deserialize_with(capabilities)? {
QueryResponse::End(res) => break 'result res.into_result()?,
QueryResponse::End(res) => break 'result res.into_result(),
QueryResponse::ResultSet { columns } => {
let columns = recv_columns!($(@$blocking)? /* store = */ true, columns, stream, cmd);
@ -30,14 +28,17 @@ macro_rules! impl_fetch_all {
match read_packet!($(@$blocking)? stream).deserialize_with(capabilities)? {
// execute ignores any rows returned
// but we do increment affected rows
QueryStep::End(res) => break 'result res.into_result()?,
QueryStep::Row(row) => rows.push(MySqlRow::new(row.deserialize_with(&columns[..])?)),
QueryStep::End(res) => break 'result res.into_result(),
QueryStep::Row(row) => rows.push(MySqlRow::new(row.deserialize_with((format, &columns[..]))?, &columns)),
}
}
}
}
};
// STATE: command is complete on error
let ok = cmd.end_if_error(res)?;
if !ok.status.contains(Status::MORE_RESULTS_EXISTS) {
// no more results, time to finally call it quits
break;
@ -48,24 +49,30 @@ macro_rules! impl_fetch_all {
}
// STATE: the current command is complete
commands.end();
cmd.end();
Ok(rows)
}};
}
#[cfg(feature = "async")]
impl<Rt: sqlx_core::Async> MySqlConnection<Rt> {
pub(super) async fn fetch_all_async(&mut self, sql: &str) -> Result<Vec<MySqlRow>> {
impl<Rt: Runtime> MySqlConnection<Rt> {
#[cfg(feature = "async")]
pub(super) async fn fetch_all_async<'q, 'a, E>(&mut self, query: E) -> Result<Vec<MySqlRow>>
where
Rt: sqlx_core::Async,
E: Execute<'q, 'a, MySql>,
{
flush!(self);
impl_fetch_all!(self, sql)
impl_fetch_all!(self, query)
}
}
#[cfg(feature = "blocking")]
impl<Rt: sqlx_core::blocking::Runtime> MySqlConnection<Rt> {
pub(super) fn fetch_all_blocking(&mut self, sql: &str) -> Result<Vec<MySqlRow>> {
#[cfg(feature = "blocking")]
pub(super) fn fetch_all_blocking<'q, 'a, E>(&mut self, query: E) -> Result<Vec<MySqlRow>>
where
Rt: sqlx_core::blocking::Runtime,
E: Execute<'q, 'a, MySql>,
{
flush!(@blocking self);
impl_fetch_all!(@blocking self, sql)
impl_fetch_all!(@blocking self, query)
}
}

View File

@ -1,39 +1,36 @@
use sqlx_core::Result;
use sqlx_core::{Execute, Result, Runtime};
use crate::connection::flush::QueryCommand;
use crate::connection::command::QueryCommand;
use crate::protocol::{Query, QueryResponse, QueryStep, Status};
use crate::{MySqlConnection, MySqlRow};
use crate::{MySql, MySqlConnection, MySqlRawValueFormat, MySqlRow};
macro_rules! impl_fetch_optional {
($(@$blocking:ident)? $self:ident, $sql:ident) => {{
($(@$blocking:ident)? $self:ident, $query:ident) => {{
let format = raw_query!($(@$blocking)? $self, $query);
let Self { ref mut stream, ref mut commands, capabilities, .. } = *$self;
// send the server a text-based query that will be executed immediately
// replies with ERR, OK, or a result set
stream.write_packet(&Query { sql: $sql })?;
// STATE: remember that we are now expecting a query response
let cmd = QueryCommand::begin(commands);
let mut cmd = QueryCommand::begin(commands);
// default we did not find a row
let mut first_row = None;
#[allow(clippy::while_let_loop, unused_labels)]
'results: loop {
let ok = 'result: loop {
let res = 'result: loop {
match read_packet!($(@$blocking)? stream).deserialize_with(capabilities)? {
QueryResponse::End(res) => break 'result res.into_result()?,
QueryResponse::End(res) => break 'result res.into_result(),
QueryResponse::ResultSet { columns } => {
let columns = recv_columns!($(@$blocking)? /* store = */ true, columns, stream, cmd);
log::debug!("columns = {:?}", columns);
'rows: loop {
match read_packet!($(@$blocking)? stream).deserialize_with(capabilities)? {
// execute ignores any rows returned
// but we do increment affected rows
QueryStep::End(res) => break 'result res.into_result()?,
QueryStep::End(res) => break 'result res.into_result(),
QueryStep::Row(row) => {
first_row = Some(MySqlRow::new(row.deserialize_with(&columns[..])?));
first_row = Some(MySqlRow::new(row.deserialize_with((format, &columns[..]))?, &columns));
// get out as soon as possible after finding our one row
break 'results;
@ -44,9 +41,12 @@ macro_rules! impl_fetch_optional {
}
};
// STATE: command is complete on error
let ok = cmd.end_if_error(res)?;
if !ok.status.contains(Status::MORE_RESULTS_EXISTS) {
// STATE: the current command is complete
commands.end();
cmd.end();
// no more results, time to finally call it quits and give up
break;
@ -60,18 +60,30 @@ macro_rules! impl_fetch_optional {
}};
}
#[cfg(feature = "async")]
impl<Rt: sqlx_core::Async> MySqlConnection<Rt> {
pub(super) async fn fetch_optional_async(&mut self, sql: &str) -> Result<Option<MySqlRow>> {
impl<Rt: Runtime> MySqlConnection<Rt> {
#[cfg(feature = "async")]
pub(super) async fn fetch_optional_async<'q, 'a, E>(
&mut self,
query: E,
) -> Result<Option<MySqlRow>>
where
Rt: sqlx_core::Async,
E: Execute<'q, 'a, MySql>,
{
flush!(self);
impl_fetch_optional!(self, sql)
impl_fetch_optional!(self, query)
}
}
#[cfg(feature = "blocking")]
impl<Rt: sqlx_core::blocking::Runtime> MySqlConnection<Rt> {
pub(super) fn fetch_optional_blocking(&mut self, sql: &str) -> Result<Option<MySqlRow>> {
#[cfg(feature = "blocking")]
pub(super) fn fetch_optional_blocking<'q, 'a, E>(
&mut self,
query: E,
) -> Result<Option<MySqlRow>>
where
Rt: sqlx_core::blocking::Runtime,
E: Execute<'q, 'a, MySql>,
{
flush!(@blocking self);
impl_fetch_optional!(@blocking self, sql)
impl_fetch_optional!(@blocking self, query)
}
}

View File

@ -1,10 +1,11 @@
use sqlx_core::{Result, Runtime};
use crate::connection::flush::PrepareCommand;
use crate::connection::command::PrepareCommand;
use crate::protocol::{ColumnDefinition, Prepare, PrepareResponse};
use crate::{MySqlColumn, MySqlStatement, MySqlTypeInfo};
use crate::raw_statement::RawStatement;
use crate::{MySqlColumn, MySqlTypeInfo};
macro_rules! impl_prepare {
macro_rules! impl_raw_prepare {
($(@$blocking:ident)? $self:ident, $sql:ident) => {{
let Self { ref mut stream, ref mut commands, capabilities, .. } = *$self;
@ -12,25 +13,15 @@ macro_rules! impl_prepare {
stream.write_packet(&Prepare { sql: $sql })?;
// STATE: remember that we are now expecting a prepare response
let cmd = PrepareCommand::begin(commands);
let mut cmd = PrepareCommand::begin(commands);
let res = read_packet!($(@$blocking)? stream)
.deserialize_with::<PrepareResponse, _>(capabilities)?.into_result();
let ok = match res {
Ok(ok) => ok,
Err(error) => {
// STATE: prepare failed, command ended
commands.end();
// STATE: command is complete on error
let ok = cmd.end_if_error(res)?;
return Err(error);
},
};
let mut stmt = MySqlStatement::new(ok.statement_id);
stmt.parameters.reserve(ok.params.into());
stmt.columns.reserve(ok.columns.into());
let mut stmt = RawStatement::new(&ok);
for index in (1..=ok.params).rev() {
// STATE: remember that we are expecting #rem more columns
@ -57,39 +48,38 @@ macro_rules! impl_prepare {
// TODO: handle EOF for old MySQL
// STATE: the command is complete
commands.end();
cmd.end();
Ok(stmt)
}};
}
// TODO: should be private
impl<Rt: Runtime> super::MySqlConnection<Rt> {
#[cfg(feature = "async")]
pub async fn prepare_async(&mut self, sql: &str) -> Result<MySqlStatement>
pub(super) async fn raw_prepare_async(&mut self, sql: &str) -> Result<RawStatement>
where
Rt: sqlx_core::Async,
{
flush!(self);
impl_prepare!(self, sql)
impl_raw_prepare!(self, sql)
}
#[cfg(feature = "blocking")]
pub fn prepare_blocking(&mut self, sql: &str) -> Result<MySqlStatement>
pub(super) fn raw_prepare_blocking(&mut self, sql: &str) -> Result<RawStatement>
where
Rt: sqlx_core::blocking::Runtime,
{
flush!(@blocking self);
impl_prepare!(@blocking self, sql)
impl_raw_prepare!(@blocking self, sql)
}
}
macro_rules! prepare {
macro_rules! raw_prepare {
(@blocking $self:ident, $sql:expr) => {
$self.prepare_blocking($sql)?
$self.raw_prepare_blocking($sql)?
};
($self:ident, $sql:expr) => {
$self.prepare_async($sql).await?
$self.raw_prepare_async($sql).await?
};
}

View File

@ -0,0 +1,67 @@
use sqlx_core::{Arguments, Execute, Result, Runtime};
use crate::protocol::{self, Query, QueryResponse, QueryStep, Status};
use crate::{MySql, MySqlConnection, MySqlRawValueFormat, MySqlRow};
macro_rules! impl_raw_query {
($(@$blocking:ident)? $self:ident, $query:ident) => {{
let format = if let Some(arguments) = $query.arguments() {
// prepare the statement for execution
let statement = raw_prepare!($(@$blocking:ident)? $self, $query.sql());
// execute the prepared statement
$self.stream.write_packet(&protocol::Execute {
statement: statement.id(),
parameters: &statement.parameters,
arguments: &arguments,
})?;
// prepared queries always use the BINARY format
MySqlRawValueFormat::Binary
} else {
// directly execute the query as an unprepared, simple query
$self.stream.write_packet(&Query { sql: $query.sql() })?;
// unprepared queries use the TEXT format
// this is a significant waste of bandwidth for large result sets
MySqlRawValueFormat::Text
};
Ok(format)
}};
}
impl<Rt: Runtime> MySqlConnection<Rt> {
#[cfg(feature = "async")]
pub(super) async fn raw_query_async<'q, 'a, E>(
&mut self,
query: E,
) -> Result<MySqlRawValueFormat>
where
Rt: sqlx_core::Async,
E: Execute<'q, 'a, MySql>,
{
flush!(self);
impl_raw_query!(self, query)
}
#[cfg(feature = "blocking")]
pub(super) fn raw_query_blocking<'q, 'a, E>(&mut self, query: E) -> Result<MySqlRawValueFormat>
where
Rt: sqlx_core::blocking::Runtime,
E: Execute<'q, 'a, MySql>,
{
flush!(@blocking self);
impl_raw_query!(@blocking self, query)
}
}
macro_rules! raw_query {
(@blocking $self:ident, $sql:expr) => {
$self.raw_query_blocking($sql)?
};
($self:ident, $sql:expr) => {
$self.raw_query_async($sql).await?
};
}

View File

@ -1,133 +1,52 @@
use std::collections::VecDeque;
use std::hint::unreachable_unchecked;
use sqlx_core::Result;
use sqlx_core::{Error, Result};
use crate::connection::command::{Command, CommandQueue, PrepareCommand, QueryCommand};
use crate::protocol::{PrepareResponse, QueryResponse, QueryStep, ResultPacket, Status};
use crate::{MySqlConnection, MySqlDatabaseError};
pub(crate) struct CommandQueue(VecDeque<Command>);
impl CommandQueue {
pub(crate) fn new() -> Self {
Self(VecDeque::with_capacity(2))
}
// begin a simple command
// in which we are expecting OK or ERR (a result)
pub(crate) fn begin(&mut self) {
self.0.push_back(Command::Simple);
}
}
impl CommandQueue {
pub(crate) fn end(&mut self) {
self.0.pop_front();
}
fn maybe_end(&mut self, res: ResultPacket) {
match res {
ResultPacket::Ok(ok) => {
if ok.status.contains(Status::MORE_RESULTS_EXISTS) {
// an attached query response is next
// we are still expecting one
return;
}
}
ResultPacket::Err(error) => {
// without context, we should not bubble this err
// log and continue forward
log::error!("{}", MySqlDatabaseError(error));
fn maybe_end_with(queue: &mut CommandQueue, res: ResultPacket) {
match res {
ResultPacket::Ok(ok) => {
if ok.status.contains(Status::MORE_RESULTS_EXISTS) {
// an attached query response is next
// we are still expecting one
return;
}
}
// STATE: end of query
self.0.pop_front();
}
}
#[derive(Debug)]
#[repr(u8)]
pub(crate) enum Command {
// expecting [ResultPacket]
Simple,
Query(QueryCommand),
Prepare(PrepareCommand),
}
#[derive(Debug)]
#[repr(u8)]
pub(crate) enum QueryCommand {
// expecting [QueryResponse]
QueryResponse,
// expecting [QueryStep]
QueryStep,
// expecting {rem} more [ColumnDefinition] packets
ColumnDefinition { rem: u16 },
}
impl QueryCommand {
pub(crate) fn begin(queue: &mut CommandQueue) -> &mut Self {
queue.0.push_back(Command::Query(Self::QueryResponse));
if let Some(Command::Query(cmd)) = queue.0.back_mut() {
cmd
} else {
// UNREACHABLE: just pushed a query command to the back of the vector, and we
// have &mut access, nobody else is pushing to it
#[allow(unsafe_code)]
unsafe {
unreachable_unchecked()
}
ResultPacket::Err(error) => {
// without context, we should not bubble this err
// log and continue forward
log::error!("{}", MySqlDatabaseError(error));
}
}
}
#[derive(Debug)]
pub(crate) enum PrepareCommand {
// expecting [ERR] or [COM_STMT_PREPARE_OK]
PrepareResponse,
// expecting {rem} more [ColumnDefinition] packets for each parameter
// stores {columns} as this state is before the [ColumnDefinition] state
ParameterDefinition { rem: u16, columns: u16 },
// expecting {rem} more [ColumnDefinition] packets for each parameter
ColumnDefinition { rem: u16 },
}
impl PrepareCommand {
pub(crate) fn begin(queue: &mut CommandQueue) -> &mut Self {
queue.0.push_back(Command::Prepare(Self::PrepareResponse));
if let Some(Command::Prepare(cmd)) = queue.0.back_mut() {
cmd
} else {
// UNREACHABLE: just pushed a prepare command to the back of the vector, and we
// have &mut access, nobody else is pushing to it
#[allow(unsafe_code)]
unsafe {
unreachable_unchecked()
}
}
}
// STATE: end of query
queue.0.pop_front();
}
macro_rules! impl_flush {
($(@$blocking:ident)? $self:ident) => {{
let Self { ref mut commands, ref mut stream, capabilities, .. } = *$self;
log::debug!("flush!");
let Self { ref mut commands, ref mut stream, ref mut closed, capabilities, .. } = *$self;
while let Some(command) = commands.0.get_mut(0) {
match command {
Command::Close => {
if !*closed {
close!($(@$blocking)? stream);
*closed = true;
}
return Err(Error::Closed);
}
Command::Simple => {
// simple commands where we expect an OK or ERR
// ex. COM_PING, COM_QUERY, COM_STMT_RESET, COM_SET_OPTION
commands.maybe_end(read_packet!($(@$blocking)? stream).deserialize_with(capabilities)?);
maybe_end_with(commands, read_packet!($(@$blocking)? stream).deserialize_with(capabilities)?);
}
Command::Prepare(ref mut cmd) => {
@ -185,7 +104,7 @@ macro_rules! impl_flush {
// expecting OK, ERR, or a result set
QueryCommand::QueryResponse => {
match read_packet!($(@$blocking)? stream).deserialize_with(capabilities)? {
QueryResponse::End(end) => break commands.maybe_end(end),
QueryResponse::End(end) => break maybe_end_with(commands, end),
QueryResponse::ResultSet { columns } => {
// STATE: expect the column definitions for each column
*cmd = QueryCommand::ColumnDefinition { rem: columns };
@ -214,7 +133,7 @@ macro_rules! impl_flush {
// either the query result set has ended or we receive
// and immediately drop a row
match read_packet!($(@$blocking)? stream).deserialize_with(capabilities)? {
QueryStep::End(end) => break commands.maybe_end(end),
QueryStep::End(end) => break maybe_end_with(commands, end),
QueryStep::Row(_) => {}
}
}

View File

@ -13,14 +13,13 @@ macro_rules! impl_ping {
// STATE: remember that we are expecting an OK packet
$self.commands.begin();
let _ok = read_packet!($(@$blocking)? $self.stream)
.deserialize_with::<ResultPacket, _>($self.capabilities)?
.into_result()?;
let res = read_packet!($(@$blocking)? $self.stream)
.deserialize_with::<ResultPacket, _>($self.capabilities)?;
// STATE: received OK packet
// STATE: received result packet
$self.commands.end();
Ok(())
res.into_result().map(|_| ())
}};
}

View File

@ -9,11 +9,11 @@ pub(crate) struct RawStatement {
}
impl RawStatement {
pub(crate) fn new(ok: PrepareOk) -> Self {
pub(crate) fn new(ok: &PrepareOk) -> Self {
Self {
id: ok.statement_id,
columns: Vec::with_capacity(ok.columns.into()),
parameters: Vec::with_capacity(ok.parameters.into()),
parameters: Vec::with_capacity(ok.params.into()),
}
}

View File

@ -2,11 +2,11 @@ use std::fmt::Debug;
use std::ops::{Deref, DerefMut};
use bytes::{Buf, BufMut};
use sqlx_core::io::{BufStream, Serialize};
use sqlx_core::io::{BufStream, Serialize, Stream};
use sqlx_core::net::Stream as NetStream;
use sqlx_core::{Error, Result, Runtime};
use crate::protocol::{MaybeCommand, Packet};
use crate::protocol::{MaybeCommand, Packet, Quit};
use crate::MySqlDatabaseError;
/// Reads and writes packets to and from the MySQL database server.
@ -186,3 +186,39 @@ macro_rules! read_packet {
$stream.read_packet_async().await?
};
}
impl<Rt: Runtime> MySqlStream<Rt> {
#[cfg(feature = "async")]
pub(crate) async fn close_async(&mut self) -> Result<()>
where
Rt: sqlx_core::Async,
{
self.write_packet(&Quit)?;
self.flush_async().await?;
self.shutdown_async().await?;
Ok(())
}
#[cfg(feature = "blocking")]
pub(crate) fn close_blocking(&mut self) -> Result<()>
where
Rt: sqlx_core::blocking::Runtime,
{
self.write_packet(&Quit)?;
self.flush()?;
self.shutdown()?;
Ok(())
}
}
macro_rules! close {
(@blocking $self:ident) => {
$self.close_blocking()?
};
($self:ident) => {
$self.close_async().await?
};
}