refactor(mysql): create ResultPacket (OK or ERR), make Row parsing lazy, and move ERR handling to the Packet type

This commit is contained in:
Ryan Leckey 2021-02-10 17:19:08 -08:00
parent b837a3ca25
commit 4a1087db6c
No known key found for this signature in database
GPG Key ID: F8AA68C235AB08C9
7 changed files with 61 additions and 22 deletions

View File

@ -13,6 +13,7 @@ mod ping;
mod query;
mod query_response;
mod info;
mod result;
mod query_step;
mod packet;
mod quit;
@ -29,6 +30,7 @@ pub(crate) use column_def::ColumnDefinition;
pub(crate) use command::{Command, MaybeCommand};
pub(crate) use eof::EofPacket;
pub(crate) use err::ErrPacket;
pub(crate) use result::ResultPacket;
pub(crate) use handshake::Handshake;
pub(crate) use handshake_response::HandshakeResponse;
pub(crate) use ok::OkPacket;

View File

@ -4,12 +4,12 @@ use bytes::Bytes;
use sqlx_core::io::Deserialize;
use sqlx_core::{Error, Result};
use crate::protocol::{AuthSwitch, Capabilities, OkPacket};
use crate::protocol::{AuthSwitch, Capabilities, ResultPacket};
use crate::MySqlDatabaseError;
#[derive(Debug)]
pub(crate) enum AuthResponse {
Ok(OkPacket),
End(ResultPacket),
MoreData(Bytes),
Switch(AuthSwitch),
}
@ -17,7 +17,7 @@ pub(crate) enum AuthResponse {
impl Deserialize<'_, Capabilities> for AuthResponse {
fn deserialize_with(buf: Bytes, capabilities: Capabilities) -> Result<Self> {
match buf.get(0) {
Some(0x00) => OkPacket::deserialize_with(buf, capabilities).map(Self::Ok),
Some(0x00) => ResultPacket::deserialize_with(buf, capabilities).map(Self::End),
Some(0x01) => Ok(Self::MoreData(buf.slice(1..))),
Some(0xfe) => AuthSwitch::deserialize(buf).map(Self::Switch),

View File

@ -2,7 +2,9 @@ use std::fmt::Debug;
use bytes::Bytes;
use sqlx_core::io::Deserialize;
use sqlx_core::Result;
use sqlx_core::{Error, Result};
use crate::MySqlDatabaseError;
#[derive(Debug)]
pub(crate) struct Packet {
@ -10,6 +12,11 @@ pub(crate) struct Packet {
}
impl Packet {
pub(crate) fn is_error(&self) -> bool {
// if the first byte of the payload is 0xFF and the payload is an ERR packet
!self.bytes.is_empty() && self.bytes[0] == 0xff
}
#[inline]
pub(crate) fn deserialize<'de, T>(self) -> Result<T>
where
@ -23,6 +30,11 @@ impl Packet {
where
T: Deserialize<'de, Cx> + Debug,
{
if self.is_error() {
// if the first byte of the payload is 0xFF and the payload is an ERR packet
return Err(Error::connect(MySqlDatabaseError(self.deserialize()?)));
}
let packet = T::deserialize_with(self.bytes, context)?;
log::trace!("read > {:?}", packet);

View File

@ -2,7 +2,7 @@ use bytes::Bytes;
use sqlx_core::io::Deserialize;
use sqlx_core::{Error, Result};
use super::{Capabilities, OkPacket};
use super::{Capabilities, ResultPacket};
use crate::io::MySqlBufExt;
use crate::MySqlDatabaseError;
@ -22,7 +22,7 @@ use crate::MySqlDatabaseError;
///
#[derive(Debug)]
pub(crate) enum QueryResponse {
Ok(OkPacket),
End(ResultPacket),
ResultSet { columns: u64 },
}
@ -30,7 +30,7 @@ impl Deserialize<'_, Capabilities> for QueryResponse {
fn deserialize_with(mut buf: Bytes, capabilities: Capabilities) -> Result<Self> {
// .get does not consume the byte
match buf.get(0) {
Some(0x00) => OkPacket::deserialize_with(buf, capabilities).map(Self::Ok),
Some(0x00) => ResultPacket::deserialize_with(buf, capabilities).map(Self::End),
// ERR packets are handled on a higher-level (in `recv_packet`), we will
// never receive them here

View File

@ -2,35 +2,33 @@ use bytes::Bytes;
use sqlx_core::io::Deserialize;
use sqlx_core::{Error, Result};
use super::{Capabilities, ColumnDefinition, OkPacket, Row};
use super::{Capabilities, ResultPacket};
use crate::protocol::Packet;
use crate::MySqlDatabaseError;
/// <https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset>
/// <https://mariadb.com/kb/en/result-set-packets/>
#[derive(Debug)]
pub(crate) enum QueryStep {
Row(Row),
End(OkPacket),
Row(Packet),
End(ResultPacket),
}
impl Deserialize<'_, (Capabilities, &'_ [ColumnDefinition])> for QueryStep {
fn deserialize_with(
buf: Bytes,
(capabilities, columns): (Capabilities, &'_ [ColumnDefinition]),
) -> Result<Self> {
impl Deserialize<'_, Capabilities> for QueryStep {
fn deserialize_with(buf: Bytes, capabilities: Capabilities) -> Result<Self> {
// .get does not consume the byte
match buf.get(0) {
// To safely confirm that a packet with a 0xFE header is an OK packet (OK_Packet) or an
// EOF packet (EOF_Packet), you must also check that the packet length is less than 0xFFFFFF
Some(0xfe) if buf.len() < 0xFF_FF_FF => {
OkPacket::deserialize_with(buf, capabilities).map(Self::End)
ResultPacket::deserialize_with(buf, capabilities).map(Self::End)
}
// ERR packets are handled on a higher-level (in `recv_packet`), we will
// never receive them here
// If its non-0, then its a Row
Some(_) => Row::deserialize_with(buf, columns).map(Self::Row),
Some(_) => Ok(Self::Row(Packet { bytes: buf })),
None => Err(Error::connect(MySqlDatabaseError::malformed_packet(
"Received no bytes for the next step in a result set",

View File

@ -0,0 +1,32 @@
use bytes::Bytes;
use sqlx_core::io::Deserialize;
use sqlx_core::{Error, Result};
use super::{Capabilities, ErrPacket, OkPacket};
use crate::MySqlDatabaseError;
#[derive(Debug)]
#[allow(clippy::module_name_repetitions)]
pub(crate) enum ResultPacket {
Ok(OkPacket),
Err(ErrPacket),
}
impl ResultPacket {
pub(crate) fn into_result(self) -> Result<OkPacket> {
match self {
Self::Ok(ok) => Ok(ok),
Self::Err(err) => Err(Error::connect(MySqlDatabaseError(err))),
}
}
}
impl Deserialize<'_, Capabilities> for ResultPacket {
fn deserialize_with(buf: Bytes, capabilities: Capabilities) -> Result<Self> {
Ok(if buf[0] == 0xff {
Self::Err(ErrPacket::deserialize(buf)?)
} else {
Self::Ok(OkPacket::deserialize_with(buf, capabilities)?)
})
}
}

View File

@ -108,11 +108,6 @@ impl<Rt: Runtime> MySqlStream<Rt> {
))));
}
if packet.bytes[0] == 0xff {
// if the first byte of the payload is 0xFF and the payload is an ERR packet
return Err(Error::connect(MySqlDatabaseError(packet.deserialize()?)));
}
Ok(packet)
}
}