diff --git a/sqlx-mysql/src/error.rs b/sqlx-mysql/src/error.rs index ef01dfb4..4a9fbd6b 100644 --- a/sqlx-mysql/src/error.rs +++ b/sqlx-mysql/src/error.rs @@ -14,6 +14,10 @@ impl MySqlDatabaseError { pub(crate) fn new(code: u16, message: &str) -> Self { Self(ErrPacket::new(code, message)) } + + pub(crate) fn malformed_packet(message: &str) -> Self { + Self::new(2027, &format!("Malformed packet: {}", message)) + } } impl DatabaseError for MySqlDatabaseError { diff --git a/sqlx-mysql/src/protocol.rs b/sqlx-mysql/src/protocol.rs index bd95b61b..cd6be759 100644 --- a/sqlx-mysql/src/protocol.rs +++ b/sqlx-mysql/src/protocol.rs @@ -1,25 +1,39 @@ -mod auth; mod auth_plugin; +mod auth_response; mod auth_switch; mod capabilities; +mod column_def; mod command; +mod eof; mod err; mod handshake; mod handshake_response; mod ok; mod ping; +mod query; +mod query_response; +mod query_step; +mod packet; mod quit; +mod row; mod status; -pub(crate) use auth::{Auth, AuthResponse}; +pub(crate) use packet::Packet; pub(crate) use auth_plugin::AuthPlugin; +pub(crate) use auth_response::AuthResponse; pub(crate) use auth_switch::AuthSwitch; pub(crate) use capabilities::Capabilities; +pub(crate) use column_def::ColumnDefinition; pub(crate) use command::{Command, MaybeCommand}; +pub(crate) use eof::EofPacket; pub(crate) use err::ErrPacket; pub(crate) use handshake::Handshake; pub(crate) use handshake_response::HandshakeResponse; pub(crate) use ok::OkPacket; pub(crate) use ping::Ping; +pub(crate) use query::Query; +pub(crate) use query_response::QueryResponse; +pub(crate) use query_step::QueryStep; pub(crate) use quit::Quit; +pub(crate) use row::Row; pub(crate) use status::Status; diff --git a/sqlx-mysql/src/protocol/auth.rs b/sqlx-mysql/src/protocol/auth.rs deleted file mode 100644 index 5fd33bf1..00000000 --- a/sqlx-mysql/src/protocol/auth.rs +++ /dev/null @@ -1,48 +0,0 @@ -use std::fmt::Debug; - -use bytes::Bytes; -use sqlx_core::io::{Deserialize, Serialize}; -use sqlx_core::{Error, Result}; - -use crate::protocol::{AuthSwitch, Capabilities, MaybeCommand, OkPacket}; -use crate::MySqlDatabaseError; - -#[derive(Debug)] -pub(crate) enum Auth { - Ok(OkPacket), - MoreData(Bytes), - Switch(AuthSwitch), -} - -impl Deserialize<'_, Capabilities> for Auth { - fn deserialize_with(buf: Bytes, capabilities: Capabilities) -> Result { - match buf[0] { - 0x00 => OkPacket::deserialize_with(buf, capabilities).map(Self::Ok), - 0x01 => Ok(Self::MoreData(buf.slice(1..))), - 0xfe => AuthSwitch::deserialize_with(buf, capabilities).map(Self::Switch), - - tag => Err(Error::connect(MySqlDatabaseError::new( - 2027, - &format!( - "Malformed packet: Received 0x{:x} but expected one of: 0x0, 0x1, or 0xfe", - tag - ), - ))), - } - } -} - -#[derive(Debug)] -pub(crate) struct AuthResponse { - pub(crate) data: Vec, -} - -impl MaybeCommand for AuthResponse {} - -impl Serialize<'_, Capabilities> for AuthResponse { - fn serialize_with(&self, buf: &mut Vec, _context: Capabilities) -> Result<()> { - buf.extend_from_slice(&self.data); - - Ok(()) - } -} diff --git a/sqlx-mysql/src/protocol/auth_response.rs b/sqlx-mysql/src/protocol/auth_response.rs new file mode 100644 index 00000000..d368b138 --- /dev/null +++ b/sqlx-mysql/src/protocol/auth_response.rs @@ -0,0 +1,34 @@ +use std::fmt::Debug; + +use bytes::Bytes; +use sqlx_core::io::Deserialize; +use sqlx_core::{Error, Result}; + +use crate::protocol::{AuthSwitch, Capabilities, OkPacket}; +use crate::MySqlDatabaseError; + +#[derive(Debug)] +pub(crate) enum AuthResponse { + Ok(OkPacket), + MoreData(Bytes), + Switch(AuthSwitch), +} + +impl Deserialize<'_, Capabilities> for AuthResponse { + fn deserialize_with(buf: Bytes, capabilities: Capabilities) -> Result { + match buf.get(0) { + Some(0x00) => OkPacket::deserialize_with(buf, capabilities).map(Self::Ok), + Some(0x01) => Ok(Self::MoreData(buf.slice(1..))), + Some(0xfe) => AuthSwitch::deserialize(buf).map(Self::Switch), + + Some(tag) => Err(Error::connect(MySqlDatabaseError::malformed_packet(&format!( + "Received 0x{:x} but expected one of: 0x0 (OK), 0x1 (MORE DATA), or 0xfe (SWITCH) for auth response", + tag + )))), + + None => Err(Error::connect(MySqlDatabaseError::malformed_packet( + "Received no bytes for auth response", + ))), + } + } +} diff --git a/sqlx-mysql/src/protocol/auth_switch.rs b/sqlx-mysql/src/protocol/auth_switch.rs index 772b79c3..099d89d0 100644 --- a/sqlx-mysql/src/protocol/auth_switch.rs +++ b/sqlx-mysql/src/protocol/auth_switch.rs @@ -2,7 +2,6 @@ use bytes::{buf::Chain, Buf, Bytes}; use sqlx_core::io::{BufExt, Deserialize}; use sqlx_core::Result; -use super::Capabilities; use crate::protocol::AuthPlugin; // https://dev.mysql.com/doc/internals/en/authentication-method-change.html @@ -14,8 +13,8 @@ pub(crate) struct AuthSwitch { pub(crate) plugin_data: Chain, } -impl Deserialize<'_, Capabilities> for AuthSwitch { - fn deserialize_with(mut buf: Bytes, _capabilities: Capabilities) -> Result { +impl Deserialize<'_> for AuthSwitch { + fn deserialize_with(mut buf: Bytes, _: ()) -> Result { let tag = buf.get_u8(); debug_assert_eq!(tag, 0xfe); diff --git a/sqlx-mysql/src/protocol/column_def.rs b/sqlx-mysql/src/protocol/column_def.rs new file mode 100644 index 00000000..648fb8c3 --- /dev/null +++ b/sqlx-mysql/src/protocol/column_def.rs @@ -0,0 +1,65 @@ +use bytes::{Buf, Bytes}; +use bytestring::ByteString; +use sqlx_core::io::Deserialize; +use sqlx_core::Result; + +use crate::io::MySqlBufExt; + +/// Describes a column in the result set. +/// +/// +/// +#[derive(Debug)] +pub(crate) struct ColumnDefinition { + pub(crate) catalog: ByteString, + pub(crate) schema: ByteString, + pub(crate) table_alias: ByteString, + pub(crate) table: ByteString, + pub(crate) alias: ByteString, + pub(crate) name: ByteString, + pub(crate) charset: u16, + pub(crate) max_size: u32, + pub(crate) ty: u8, + pub(crate) flags: u16, + pub(crate) decimals: u8, +} + +impl Deserialize<'_> for ColumnDefinition { + #[allow(unsafe_code)] + fn deserialize_with(mut buf: Bytes, _: ()) -> Result { + // UNSAFE: fields are known to be UTF-8 as we have connected with the + // UTF-8 connection charset + + let catalog = unsafe { buf.get_str_lenenc_unchecked() }; + let schema = unsafe { buf.get_str_lenenc_unchecked() }; + let table_alias = unsafe { buf.get_str_lenenc_unchecked() }; + let table = unsafe { buf.get_str_lenenc_unchecked() }; + let alias = unsafe { buf.get_str_lenenc_unchecked() }; + let name = unsafe { buf.get_str_lenenc_unchecked() }; + + let fixed_len_fields_len = buf.get_uint_lenenc(); + + // we are told that this is *always* 0x0c + debug_assert_eq!(fixed_len_fields_len, 0x0c); + + let charset = buf.get_u16_le(); + let max_size = buf.get_u32_le(); + let ty = buf.get_u8(); + let flags = buf.get_u16_le(); + let decimals = buf.get_u8(); + + Ok(Self { + catalog, + schema, + table_alias, + table, + alias, + name, + charset, + max_size, + ty, + flags, + decimals, + }) + } +} diff --git a/sqlx-mysql/src/protocol/command.rs b/sqlx-mysql/src/protocol/command.rs index 545abfbc..e4ad9298 100644 --- a/sqlx-mysql/src/protocol/command.rs +++ b/sqlx-mysql/src/protocol/command.rs @@ -13,6 +13,9 @@ pub(crate) trait MaybeCommand { } } +// raw bytes are not a command +impl MaybeCommand for &'_ [u8] {} + /// Marker trait to signal that this protocol type is a Command. pub(crate) trait Command: MaybeCommand {} diff --git a/sqlx-mysql/src/protocol/eof.rs b/sqlx-mysql/src/protocol/eof.rs new file mode 100644 index 00000000..63531f87 --- /dev/null +++ b/sqlx-mysql/src/protocol/eof.rs @@ -0,0 +1,31 @@ +use bytes::{Buf, Bytes}; +use sqlx_core::io::Deserialize; +use sqlx_core::Result; + +use crate::protocol::{Capabilities, Status}; + +#[allow(clippy::module_name_repetitions)] +#[derive(Debug)] +pub(crate) struct EofPacket { + pub(crate) status: Status, + pub(crate) warnings: u16, +} + +impl Deserialize<'_, Capabilities> for EofPacket { + fn deserialize_with(mut buf: Bytes, capabilities: Capabilities) -> Result { + let tag = buf.get_u8(); + debug_assert_eq!(tag, 0xfe); + + let status = + if capabilities.intersects(Capabilities::PROTOCOL_41 | Capabilities::TRANSACTIONS) { + Status::from_bits_truncate(buf.get_u16_le()) + } else { + Status::empty() + }; + + let warnings = + if capabilities.contains(Capabilities::PROTOCOL_41) { buf.get_u16_le() } else { 0 }; + + Ok(Self { status, warnings }) + } +} diff --git a/sqlx-mysql/src/protocol/err.rs b/sqlx-mysql/src/protocol/err.rs index 963c660b..bd7f5770 100644 --- a/sqlx-mysql/src/protocol/err.rs +++ b/sqlx-mysql/src/protocol/err.rs @@ -4,7 +4,6 @@ use sqlx_core::io::{BufExt, Deserialize}; use sqlx_core::Result; use crate::io::MySqlBufExt; -use crate::protocol::Capabilities; // https://dev.mysql.com/doc/internals/en/packet-ERR_Packet.html // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_err_packet.html @@ -27,14 +26,14 @@ impl ErrPacket { } } -impl Deserialize<'_, Capabilities> for ErrPacket { - fn deserialize_with(mut buf: Bytes, capabilities: Capabilities) -> Result { +impl Deserialize<'_> for ErrPacket { + fn deserialize_with(mut buf: Bytes, _: ()) -> Result { let tag = buf.get_u8(); debug_assert!(tag == 0xff); let error_code = buf.get_u16_le(); - let sql_state = if capabilities.contains(Capabilities::PROTOCOL_41) && buf[0] == b'#' { + let sql_state = if buf[0] == b'#' { // if the next byte is '#' then we have the SQL STATE buf.advance(1); @@ -55,14 +54,13 @@ impl Deserialize<'_, Capabilities> for ErrPacket { #[cfg(test)] mod tests { - use super::{Capabilities, Deserialize, ErrPacket}; + use super::{Deserialize, ErrPacket}; #[test] fn test_err_connect_auth() { const DATA: &[u8] = b"\xff\xe3\x04Client does not support authentication protocol requested by server; consider upgrading MySQL client"; - let capabilities = Capabilities::PROTOCOL_41; - let ok = ErrPacket::deserialize_with(DATA.into(), capabilities).unwrap(); + let ok = ErrPacket::deserialize(DATA.into()).unwrap(); assert_eq!(ok.sql_state, None); assert_eq!(ok.error_code, 1251); @@ -76,8 +74,7 @@ mod tests { fn test_err_out_of_order() { const DATA: &[u8] = b"\xff\x84\x04Got packets out of order"; - let capabilities = Capabilities::PROTOCOL_41; - let ok = ErrPacket::deserialize_with(DATA.into(), capabilities).unwrap(); + let ok = ErrPacket::deserialize(DATA.into()).unwrap(); assert_eq!(ok.sql_state, None); assert_eq!(ok.error_code, 1156); @@ -88,8 +85,7 @@ mod tests { fn test_err_unknown_database() { const DATA: &[u8] = b"\xff\x19\x04#42000Unknown database \'unknown\'"; - let capabilities = Capabilities::PROTOCOL_41; - let ok = ErrPacket::deserialize_with(DATA.into(), capabilities).unwrap(); + let ok = ErrPacket::deserialize(DATA.into()).unwrap(); assert_eq!(ok.sql_state.as_deref(), Some("42000")); assert_eq!(ok.error_code, 1049); diff --git a/sqlx-mysql/src/protocol/handshake.rs b/sqlx-mysql/src/protocol/handshake.rs index 2ca377c3..f953db0a 100644 --- a/sqlx-mysql/src/protocol/handshake.rs +++ b/sqlx-mysql/src/protocol/handshake.rs @@ -31,8 +31,8 @@ pub(crate) struct Handshake { pub(crate) auth_plugin_data: Chain, } -impl Deserialize<'_, Capabilities> for Handshake { - fn deserialize_with(mut buf: Bytes, _: Capabilities) -> Result { +impl Deserialize<'_> for Handshake { + fn deserialize_with(mut buf: Bytes, _: ()) -> Result { let protocol_version = buf.get_u8(); // UNSAFE: server version is known to be ASCII @@ -135,13 +135,11 @@ mod tests { use super::{Capabilities, Handshake, Status}; - const EMPTY: Capabilities = Capabilities::empty(); - #[test] fn handshake_mysql_8_0_18() { const HANDSHAKE_MYSQL_8_0_18: &[u8] = b"\n8.0.18\x00\x19\x00\x00\x00\x114aB0c\x06g\x00\xff\xff\xff\x02\x00\xff\xc7\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00tL\x03s\x0f[4\rl4. \x00caching_sha2_password\x00"; - let mut h = Handshake::deserialize_with(HANDSHAKE_MYSQL_8_0_18.into(), EMPTY).unwrap(); + let mut h = Handshake::deserialize(HANDSHAKE_MYSQL_8_0_18.into()).unwrap(); assert_eq!(h.protocol_version, 10); @@ -186,7 +184,7 @@ mod tests { fn handshake_mariadb_10_4_7() { const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\"hr&`3{55H\0mysql_native_password\0"; - let mut h = Handshake::deserialize_with(HANDSHAKE_MARIA_DB_10_5_8.into(), EMPTY).unwrap(); + let mut h = Handshake::deserialize(HANDSHAKE_MARIA_DB_10_5_8.into()).unwrap(); assert_eq!(h.protocol_version, 10); assert_eq!(&*h.server_version, "5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal"); @@ -274,7 +272,7 @@ mod tests { fn handshake_mysql_5_6_50() { const HANDSHAKE_MYSQL_5_6_50: &[u8] = b"\n5.6.50\0\x01\0\0\0-VLYZ:Pd\0\xff\xf7\x08\x02\0\x7f\x80\x15\0\0\0\0\0\0\0\0\0\0'2f+BL8nGV[G\0mysql_native_password\0"; - let mut h = Handshake::deserialize_with(HANDSHAKE_MYSQL_5_6_50.into(), EMPTY).unwrap(); + let mut h = Handshake::deserialize(HANDSHAKE_MYSQL_5_6_50.into()).unwrap(); assert_eq!(h.protocol_version, 10); @@ -318,7 +316,7 @@ mod tests { fn handshake_mysql_5_0_96() { const HANDSHAKE_MYSQL_5_0_96: &[u8] = b"\n5.0.96\0\x03\0\0\0bs=sNiGe\0,\xa2\x08\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0IzMP)yLLx;[9\0"; - let mut h = Handshake::deserialize_with(HANDSHAKE_MYSQL_5_0_96.into(), EMPTY).unwrap(); + let mut h = Handshake::deserialize(HANDSHAKE_MYSQL_5_0_96.into()).unwrap(); assert_eq!(h.protocol_version, 10); assert_eq!(&*h.server_version, "5.0.96"); @@ -350,7 +348,7 @@ mod tests { fn handshake_mysql_5_1_73() { const HANDSHAKE_MYSQL_5_1_73: &[u8] = b"\n5.1.73\0\x01\0\0\0 { + pub(crate) capabilities: Capabilities, pub(crate) database: Option<&'a str>, pub(crate) max_packet_size: u32, pub(crate) charset: u8, @@ -19,12 +20,12 @@ pub(crate) struct HandshakeResponse<'a> { impl MaybeCommand for HandshakeResponse<'_> {} -impl Serialize<'_, Capabilities> for HandshakeResponse<'_> { - fn serialize_with(&self, buf: &mut Vec, capabilities: Capabilities) -> Result<()> { +impl Serialize<'_> for HandshakeResponse<'_> { + fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { // the truncation is the intent // capability bits over 32 are MariaDB only (and we don't currently support them) #[allow(clippy::cast_possible_truncation)] - buf.extend_from_slice(&(capabilities.bits() as u32).to_le_bytes()); + buf.extend_from_slice(&(self.capabilities.bits() as u32).to_le_bytes()); buf.extend_from_slice(&self.max_packet_size.to_le_bytes()); buf.push(self.charset); @@ -35,9 +36,9 @@ impl Serialize<'_, Capabilities> for HandshakeResponse<'_> { let auth_response = self.auth_response.as_slice(); - if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) { + if self.capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) { buf.write_bytes_lenenc(auth_response); - } else if capabilities.contains(Capabilities::SECURE_CONNECTION) { + } else if self.capabilities.contains(Capabilities::SECURE_CONNECTION) { debug_assert!(auth_response.len() <= u8::max_value().into()); buf.reserve(1 + auth_response.len()); @@ -53,11 +54,11 @@ impl Serialize<'_, Capabilities> for HandshakeResponse<'_> { buf.push(b'\0'); } - if capabilities.contains(Capabilities::CONNECT_WITH_DB) { + if self.capabilities.contains(Capabilities::CONNECT_WITH_DB) { buf.write_maybe_str_nul(self.database); } - if capabilities.contains(Capabilities::PLUGIN_AUTH) { + if self.capabilities.contains(Capabilities::PLUGIN_AUTH) { buf.write_str_nul(self.auth_plugin_name); } diff --git a/sqlx-mysql/src/protocol/packet.rs b/sqlx-mysql/src/protocol/packet.rs new file mode 100644 index 00000000..79911a72 --- /dev/null +++ b/sqlx-mysql/src/protocol/packet.rs @@ -0,0 +1,32 @@ +use std::fmt::Debug; + +use bytes::Bytes; +use sqlx_core::io::Deserialize; +use sqlx_core::Result; + +#[derive(Debug)] +pub(crate) struct Packet { + pub(crate) bytes: Bytes, +} + +impl Packet { + #[inline] + pub(crate) fn deserialize<'de, T>(self) -> Result + where + T: Deserialize<'de> + Debug, + { + self.deserialize_with(()) + } + + #[inline] + pub(crate) fn deserialize_with<'de, T, Cx: 'de>(self, context: Cx) -> Result + where + T: Deserialize<'de, Cx> + Debug, + { + let packet = T::deserialize_with(self.bytes, context)?; + + log::trace!("read > {:?}", packet); + + Ok(packet) + } +} diff --git a/sqlx-mysql/src/protocol/ping.rs b/sqlx-mysql/src/protocol/ping.rs index 6edf33f3..ddc45a92 100644 --- a/sqlx-mysql/src/protocol/ping.rs +++ b/sqlx-mysql/src/protocol/ping.rs @@ -1,7 +1,7 @@ use sqlx_core::io::Serialize; use sqlx_core::Result; -use crate::protocol::{Capabilities, Command}; +use crate::protocol::Command; /// Check if the server is alive. /// @@ -11,8 +11,8 @@ use crate::protocol::{Capabilities, Command}; #[derive(Debug)] pub(crate) struct Ping; -impl Serialize<'_, Capabilities> for Ping { - fn serialize_with(&self, buf: &mut Vec, _: Capabilities) -> Result<()> { +impl Serialize<'_> for Ping { + fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { buf.push(0x0e); Ok(()) @@ -26,12 +26,11 @@ mod tests { use sqlx_core::io::Serialize; use super::Ping; - use crate::protocol::Capabilities; #[test] fn should_serialize() -> anyhow::Result<()> { let mut buf = Vec::new(); - Ping.serialize_with(&mut buf, Capabilities::empty())?; + Ping.serialize(&mut buf)?; assert_eq!(&buf, &[0x0e]); diff --git a/sqlx-mysql/src/protocol/query.rs b/sqlx-mysql/src/protocol/query.rs new file mode 100644 index 00000000..dcffb2af --- /dev/null +++ b/sqlx-mysql/src/protocol/query.rs @@ -0,0 +1,25 @@ +use sqlx_core::io::Serialize; +use sqlx_core::Result; + +use super::Command; + +/// Send the server a text-based query that is executed immediately. +/// +/// https://dev.mysql.com/doc/internals/en/com-query.html +/// https://mariadb.com/kb/en/com_query/ +/// +#[derive(Debug)] +pub(crate) struct Query<'q> { + pub(crate) sql: &'q str, +} + +impl Serialize<'_> for Query<'_> { + fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { + buf.push(0x03); + buf.extend_from_slice(self.sql.as_bytes()); + + Ok(()) + } +} + +impl Command for Query<'_> {} diff --git a/sqlx-mysql/src/protocol/query_response.rs b/sqlx-mysql/src/protocol/query_response.rs new file mode 100644 index 00000000..9cfcaebe --- /dev/null +++ b/sqlx-mysql/src/protocol/query_response.rs @@ -0,0 +1,47 @@ +use bytes::Bytes; +use sqlx_core::io::Deserialize; +use sqlx_core::{Error, Result}; + +use super::{Capabilities, OkPacket}; +use crate::io::MySqlBufExt; +use crate::MySqlDatabaseError; + +/// The query-response packet is a meta-packet that starts with one of: +/// +/// - OK packet +/// - ERR packet +/// - LOCAL INFILE request (unimplemented) +/// - Result Set +/// +/// A result set is *also* a meta-packet that starts with a length-encoded +/// integer for the number of columns. That is all we return from this +/// deserialization and expect the executor to follow up with reading +/// more from the stream. +/// +/// +/// +#[derive(Debug)] +pub(crate) enum QueryResponse { + Ok(OkPacket), + ResultSet { columns: u64 }, +} + +impl Deserialize<'_, Capabilities> for QueryResponse { + fn deserialize_with(mut buf: Bytes, capabilities: Capabilities) -> Result { + // .get does not consume the byte + match buf.get(0) { + Some(0x00) => OkPacket::deserialize_with(buf, capabilities).map(Self::Ok), + + // ERR packets are handled on a higher-level (in `recv_packet`), we will + // never receive them here + + // If its non-0, then its the number of columns and the start + // of a result set + Some(_) => Ok(Self::ResultSet { columns: buf.get_uint_lenenc() }), + + None => Err(Error::connect(MySqlDatabaseError::malformed_packet( + "Received no bytes for COM_QUERY response", + ))), + } + } +} diff --git a/sqlx-mysql/src/protocol/query_step.rs b/sqlx-mysql/src/protocol/query_step.rs new file mode 100644 index 00000000..7faf57b6 --- /dev/null +++ b/sqlx-mysql/src/protocol/query_step.rs @@ -0,0 +1,40 @@ +use bytes::Bytes; +use sqlx_core::io::Deserialize; +use sqlx_core::{Error, Result}; + +use super::{Capabilities, ColumnDefinition, OkPacket, Row}; +use crate::MySqlDatabaseError; + +/// +/// +#[derive(Debug)] +pub(crate) enum QueryStep { + Row(Row), + End(OkPacket), +} + +impl Deserialize<'_, (Capabilities, &'_ [ColumnDefinition])> for QueryStep { + fn deserialize_with( + buf: Bytes, + (capabilities, columns): (Capabilities, &'_ [ColumnDefinition]), + ) -> Result { + // .get does not consume the byte + match buf.get(0) { + // To safely confirm that a packet with a 0xFE header is an OK packet (OK_Packet) or an + // EOF packet (EOF_Packet), you must also check that the packet length is less than 0xFFFFFF + Some(0xfe) if buf.len() < 0xFF_FF_FF => { + OkPacket::deserialize_with(buf, capabilities).map(Self::End) + } + + // ERR packets are handled on a higher-level (in `recv_packet`), we will + // never receive them here + + // If its non-0, then its a Row + Some(_) => Row::deserialize_with(buf, columns).map(Self::Row), + + None => Err(Error::connect(MySqlDatabaseError::malformed_packet( + "Received no bytes for the next step in a result set", + ))), + } + } +} diff --git a/sqlx-mysql/src/protocol/quit.rs b/sqlx-mysql/src/protocol/quit.rs index d8263bcd..adfac8fc 100644 --- a/sqlx-mysql/src/protocol/quit.rs +++ b/sqlx-mysql/src/protocol/quit.rs @@ -1,7 +1,7 @@ use sqlx_core::io::Serialize; use sqlx_core::Result; -use crate::protocol::{Capabilities, Command}; +use crate::protocol::Command; /// Tells the server that the client wants to close the connection. /// @@ -10,8 +10,8 @@ use crate::protocol::{Capabilities, Command}; #[derive(Debug)] pub(crate) struct Quit; -impl Serialize<'_, Capabilities> for Quit { - fn serialize_with(&self, buf: &mut Vec, _: Capabilities) -> Result<()> { +impl Serialize<'_> for Quit { + fn serialize_with(&self, buf: &mut Vec, _: ()) -> Result<()> { buf.push(0x01); Ok(()) @@ -25,12 +25,11 @@ mod tests { use sqlx_core::io::Serialize; use super::Quit; - use crate::protocol::Capabilities; #[test] fn should_serialize() -> anyhow::Result<()> { let mut buf = Vec::new(); - Quit.serialize_with(&mut buf, Capabilities::empty())?; + Quit.serialize(&mut buf)?; assert_eq!(&buf, &[0x01]); diff --git a/sqlx-mysql/src/protocol/row.rs b/sqlx-mysql/src/protocol/row.rs new file mode 100644 index 00000000..429315a3 --- /dev/null +++ b/sqlx-mysql/src/protocol/row.rs @@ -0,0 +1,28 @@ +use bytes::{Buf, Bytes}; +use sqlx_core::io::Deserialize; +use sqlx_core::Result; + +use crate::io::MySqlBufExt; +use crate::protocol::ColumnDefinition; + +#[derive(Debug)] +pub(crate) struct Row { + pub(crate) values: Vec>, +} + +impl<'de> Deserialize<'de, &'de [ColumnDefinition]> for Row { + fn deserialize_with(mut buf: Bytes, columns: &'de [ColumnDefinition]) -> Result { + let mut values = Vec::with_capacity(columns.len()); + + for _ in columns { + values.push(if buf.get(0).copied() == Some(0xfb) { + buf.advance(1); + None + } else { + Some(buf.get_bytes_lenenc()) + }); + } + + Ok(Self { values }) + } +}