From 4a1087db6c69f699ccd8381c7c1bd7d78148cc82 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Wed, 10 Feb 2021 17:19:08 -0800 Subject: [PATCH] refactor(mysql): create ResultPacket (OK or ERR), make Row parsing lazy, and move ERR handling to the Packet type --- sqlx-mysql/src/protocol.rs | 2 ++ sqlx-mysql/src/protocol/auth_response.rs | 6 ++--- sqlx-mysql/src/protocol/packet.rs | 14 +++++++++- sqlx-mysql/src/protocol/query_response.rs | 6 ++--- sqlx-mysql/src/protocol/query_step.rs | 18 ++++++------- sqlx-mysql/src/protocol/result.rs | 32 +++++++++++++++++++++++ sqlx-mysql/src/stream.rs | 5 ---- 7 files changed, 61 insertions(+), 22 deletions(-) create mode 100644 sqlx-mysql/src/protocol/result.rs diff --git a/sqlx-mysql/src/protocol.rs b/sqlx-mysql/src/protocol.rs index 944b40af..fd3c9c44 100644 --- a/sqlx-mysql/src/protocol.rs +++ b/sqlx-mysql/src/protocol.rs @@ -13,6 +13,7 @@ mod ping; mod query; mod query_response; mod info; +mod result; mod query_step; mod packet; mod quit; @@ -29,6 +30,7 @@ pub(crate) use column_def::ColumnDefinition; pub(crate) use command::{Command, MaybeCommand}; pub(crate) use eof::EofPacket; pub(crate) use err::ErrPacket; +pub(crate) use result::ResultPacket; pub(crate) use handshake::Handshake; pub(crate) use handshake_response::HandshakeResponse; pub(crate) use ok::OkPacket; diff --git a/sqlx-mysql/src/protocol/auth_response.rs b/sqlx-mysql/src/protocol/auth_response.rs index d368b138..4d17def2 100644 --- a/sqlx-mysql/src/protocol/auth_response.rs +++ b/sqlx-mysql/src/protocol/auth_response.rs @@ -4,12 +4,12 @@ use bytes::Bytes; use sqlx_core::io::Deserialize; use sqlx_core::{Error, Result}; -use crate::protocol::{AuthSwitch, Capabilities, OkPacket}; +use crate::protocol::{AuthSwitch, Capabilities, ResultPacket}; use crate::MySqlDatabaseError; #[derive(Debug)] pub(crate) enum AuthResponse { - Ok(OkPacket), + End(ResultPacket), MoreData(Bytes), Switch(AuthSwitch), } @@ -17,7 +17,7 @@ pub(crate) enum AuthResponse { impl Deserialize<'_, Capabilities> for AuthResponse { fn deserialize_with(buf: Bytes, capabilities: Capabilities) -> Result { match buf.get(0) { - Some(0x00) => OkPacket::deserialize_with(buf, capabilities).map(Self::Ok), + Some(0x00) => ResultPacket::deserialize_with(buf, capabilities).map(Self::End), Some(0x01) => Ok(Self::MoreData(buf.slice(1..))), Some(0xfe) => AuthSwitch::deserialize(buf).map(Self::Switch), diff --git a/sqlx-mysql/src/protocol/packet.rs b/sqlx-mysql/src/protocol/packet.rs index 79911a72..170d0e48 100644 --- a/sqlx-mysql/src/protocol/packet.rs +++ b/sqlx-mysql/src/protocol/packet.rs @@ -2,7 +2,9 @@ use std::fmt::Debug; use bytes::Bytes; use sqlx_core::io::Deserialize; -use sqlx_core::Result; +use sqlx_core::{Error, Result}; + +use crate::MySqlDatabaseError; #[derive(Debug)] pub(crate) struct Packet { @@ -10,6 +12,11 @@ pub(crate) struct Packet { } impl Packet { + pub(crate) fn is_error(&self) -> bool { + // if the first byte of the payload is 0xFF and the payload is an ERR packet + !self.bytes.is_empty() && self.bytes[0] == 0xff + } + #[inline] pub(crate) fn deserialize<'de, T>(self) -> Result where @@ -23,6 +30,11 @@ impl Packet { where T: Deserialize<'de, Cx> + Debug, { + if self.is_error() { + // if the first byte of the payload is 0xFF and the payload is an ERR packet + return Err(Error::connect(MySqlDatabaseError(self.deserialize()?))); + } + let packet = T::deserialize_with(self.bytes, context)?; log::trace!("read > {:?}", packet); diff --git a/sqlx-mysql/src/protocol/query_response.rs b/sqlx-mysql/src/protocol/query_response.rs index 9cfcaebe..c7e86952 100644 --- a/sqlx-mysql/src/protocol/query_response.rs +++ b/sqlx-mysql/src/protocol/query_response.rs @@ -2,7 +2,7 @@ use bytes::Bytes; use sqlx_core::io::Deserialize; use sqlx_core::{Error, Result}; -use super::{Capabilities, OkPacket}; +use super::{Capabilities, ResultPacket}; use crate::io::MySqlBufExt; use crate::MySqlDatabaseError; @@ -22,7 +22,7 @@ use crate::MySqlDatabaseError; /// #[derive(Debug)] pub(crate) enum QueryResponse { - Ok(OkPacket), + End(ResultPacket), ResultSet { columns: u64 }, } @@ -30,7 +30,7 @@ impl Deserialize<'_, Capabilities> for QueryResponse { fn deserialize_with(mut buf: Bytes, capabilities: Capabilities) -> Result { // .get does not consume the byte match buf.get(0) { - Some(0x00) => OkPacket::deserialize_with(buf, capabilities).map(Self::Ok), + Some(0x00) => ResultPacket::deserialize_with(buf, capabilities).map(Self::End), // ERR packets are handled on a higher-level (in `recv_packet`), we will // never receive them here diff --git a/sqlx-mysql/src/protocol/query_step.rs b/sqlx-mysql/src/protocol/query_step.rs index 7faf57b6..cc9575b0 100644 --- a/sqlx-mysql/src/protocol/query_step.rs +++ b/sqlx-mysql/src/protocol/query_step.rs @@ -2,35 +2,33 @@ use bytes::Bytes; use sqlx_core::io::Deserialize; use sqlx_core::{Error, Result}; -use super::{Capabilities, ColumnDefinition, OkPacket, Row}; +use super::{Capabilities, ResultPacket}; +use crate::protocol::Packet; use crate::MySqlDatabaseError; /// /// #[derive(Debug)] pub(crate) enum QueryStep { - Row(Row), - End(OkPacket), + Row(Packet), + End(ResultPacket), } -impl Deserialize<'_, (Capabilities, &'_ [ColumnDefinition])> for QueryStep { - fn deserialize_with( - buf: Bytes, - (capabilities, columns): (Capabilities, &'_ [ColumnDefinition]), - ) -> Result { +impl Deserialize<'_, Capabilities> for QueryStep { + fn deserialize_with(buf: Bytes, capabilities: Capabilities) -> Result { // .get does not consume the byte match buf.get(0) { // To safely confirm that a packet with a 0xFE header is an OK packet (OK_Packet) or an // EOF packet (EOF_Packet), you must also check that the packet length is less than 0xFFFFFF Some(0xfe) if buf.len() < 0xFF_FF_FF => { - OkPacket::deserialize_with(buf, capabilities).map(Self::End) + ResultPacket::deserialize_with(buf, capabilities).map(Self::End) } // ERR packets are handled on a higher-level (in `recv_packet`), we will // never receive them here // If its non-0, then its a Row - Some(_) => Row::deserialize_with(buf, columns).map(Self::Row), + Some(_) => Ok(Self::Row(Packet { bytes: buf })), None => Err(Error::connect(MySqlDatabaseError::malformed_packet( "Received no bytes for the next step in a result set", diff --git a/sqlx-mysql/src/protocol/result.rs b/sqlx-mysql/src/protocol/result.rs new file mode 100644 index 00000000..b6ffdb38 --- /dev/null +++ b/sqlx-mysql/src/protocol/result.rs @@ -0,0 +1,32 @@ +use bytes::Bytes; +use sqlx_core::io::Deserialize; +use sqlx_core::{Error, Result}; + +use super::{Capabilities, ErrPacket, OkPacket}; +use crate::MySqlDatabaseError; + +#[derive(Debug)] +#[allow(clippy::module_name_repetitions)] +pub(crate) enum ResultPacket { + Ok(OkPacket), + Err(ErrPacket), +} + +impl ResultPacket { + pub(crate) fn into_result(self) -> Result { + match self { + Self::Ok(ok) => Ok(ok), + Self::Err(err) => Err(Error::connect(MySqlDatabaseError(err))), + } + } +} + +impl Deserialize<'_, Capabilities> for ResultPacket { + fn deserialize_with(buf: Bytes, capabilities: Capabilities) -> Result { + Ok(if buf[0] == 0xff { + Self::Err(ErrPacket::deserialize(buf)?) + } else { + Self::Ok(OkPacket::deserialize_with(buf, capabilities)?) + }) + } +} diff --git a/sqlx-mysql/src/stream.rs b/sqlx-mysql/src/stream.rs index 5e64af95..1855f723 100644 --- a/sqlx-mysql/src/stream.rs +++ b/sqlx-mysql/src/stream.rs @@ -108,11 +108,6 @@ impl MySqlStream { )))); } - if packet.bytes[0] == 0xff { - // if the first byte of the payload is 0xFF and the payload is an ERR packet - return Err(Error::connect(MySqlDatabaseError(packet.deserialize()?))); - } - Ok(packet) } }