From c8b9e047ac251c2bc7f43a0aee20f96908ef7b32 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Sat, 2 Jan 2021 10:45:28 -0800 Subject: [PATCH] chore(mysql): add ERR and OK packets --- sqlx-mysql/src/protocol.rs | 6 +- sqlx-mysql/src/protocol/err.rs | 86 +++++++++++++++++++ sqlx-mysql/src/protocol/handshake.rs | 46 +++++----- sqlx-mysql/src/protocol/handshake_response.rs | 2 +- sqlx-mysql/src/protocol/ok.rs | 60 +++++++++++++ sqlx-mysql/src/protocol/status.rs | 2 +- 6 files changed, 176 insertions(+), 26 deletions(-) create mode 100644 sqlx-mysql/src/protocol/err.rs create mode 100644 sqlx-mysql/src/protocol/ok.rs diff --git a/sqlx-mysql/src/protocol.rs b/sqlx-mysql/src/protocol.rs index fb96dba0..68c07dc7 100644 --- a/sqlx-mysql/src/protocol.rs +++ b/sqlx-mysql/src/protocol.rs @@ -1,9 +1,13 @@ mod capabilities; mod handshake; mod handshake_response; +mod ok; mod status; +mod err; +pub(crate) use err::ErrPacket; +pub(crate) use ok::OkPacket; pub(crate) use capabilities::Capabilities; pub(crate) use handshake::Handshake; pub(crate) use handshake_response::HandshakeResponse; -pub(crate) use status::ServerStatus; +pub(crate) use status::Status; diff --git a/sqlx-mysql/src/protocol/err.rs b/sqlx-mysql/src/protocol/err.rs new file mode 100644 index 00000000..b9fe9f37 --- /dev/null +++ b/sqlx-mysql/src/protocol/err.rs @@ -0,0 +1,86 @@ +use bytes::{Buf, Bytes}; +use sqlx_core::io::{BufExt, Deserialize}; +use sqlx_core::Result; +use string::String; + +use crate::io::MySqlBufExt; +use crate::protocol::Capabilities; + +// https://dev.mysql.com/doc/internals/en/packet-ERR_Packet.html +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_err_packet.html +// https://mariadb.com/kb/en/err_packet/ + +#[allow(clippy::module_name_repetitions)] +#[derive(Debug)] +pub(crate) struct ErrPacket { + pub(crate) error_code: u16, + pub(crate) sql_state: Option>, + pub(crate) error_message: String, +} + +impl Deserialize<'_, Capabilities> for ErrPacket { + fn deserialize_with(mut buf: Bytes, capabilities: Capabilities) -> Result { + let tag = buf.get_u8(); + debug_assert!(tag == 0xff); + + let error_code = buf.get_u16_le(); + + let sql_state = if capabilities.contains(Capabilities::PROTOCOL_41) && buf[0] == b'#' { + // if the next byte is '#' then we have the SQL STATE + buf.advance(1); + + // UNSAFE: the SQL STATE is an ASCII error code + #[allow(unsafe_code)] + Some(unsafe { buf.get_str_unchecked(5) }) + } else { + None + }; + + // UNSAFE: the human-readable error message is UTF-8 + #[allow(unsafe_code)] + let error_message = unsafe { buf.get_str_eof_unchecked() }; + + Ok(Self { sql_state, error_code, error_message }) + } +} + +#[cfg(test)] +mod tests { + use super::{Capabilities, Deserialize, ErrPacket}; + + #[test] + fn test_err_connect_auth() { + const DATA: &[u8] = b"\xff\xe3\x04Client does not support authentication protocol requested by server; consider upgrading MySQL client"; + + let capabilities = Capabilities::PROTOCOL_41; + let ok = ErrPacket::deserialize_with(DATA.into(), capabilities).unwrap(); + + assert_eq!(ok.sql_state, None); + assert_eq!(ok.error_code, 1251); + assert_eq!(&ok.error_message, "Client does not support authentication protocol requested by server; consider upgrading MySQL client"); + } + + #[test] + fn test_err_out_of_order() { + const DATA: &[u8] = b"\xff\x84\x04Got packets out of order"; + + let capabilities = Capabilities::PROTOCOL_41; + let ok = ErrPacket::deserialize_with(DATA.into(), capabilities).unwrap(); + + assert_eq!(ok.sql_state, None); + assert_eq!(ok.error_code, 1156); + assert_eq!(&ok.error_message, "Got packets out of order"); + } + + #[test] + fn test_err_unknown_database() { + const DATA: &[u8] = b"\xff\x19\x04#42000Unknown database \'unknown\'"; + + let capabilities = Capabilities::PROTOCOL_41; + let ok = ErrPacket::deserialize_with(DATA.into(), capabilities).unwrap(); + + assert_eq!(ok.sql_state.as_deref(), Some("42000")); + assert_eq!(ok.error_code, 1049); + assert_eq!(&ok.error_message, "Unknown database \'unknown\'"); + } +} diff --git a/sqlx-mysql/src/protocol/handshake.rs b/sqlx-mysql/src/protocol/handshake.rs index 68bbe0c7..f31db5e7 100644 --- a/sqlx-mysql/src/protocol/handshake.rs +++ b/sqlx-mysql/src/protocol/handshake.rs @@ -4,7 +4,7 @@ use memchr::memchr; use sqlx_core::io::{BufExt, Deserialize}; use sqlx_core::Result; -use crate::protocol::{Capabilities, ServerStatus}; +use crate::protocol::{Capabilities, Status}; // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeV10 // https://mariadb.com/kb/en/connection/#initial-handshake-packet @@ -20,7 +20,7 @@ pub(crate) struct Handshake { pub(crate) connection_id: u32, pub(crate) capabilities: Capabilities, - pub(crate) status: ServerStatus, + pub(crate) status: Status, // default server character set pub(crate) charset: Option, @@ -31,10 +31,8 @@ pub(crate) struct Handshake { pub(crate) auth_plugin_name: Option>, } -impl Deserialize<'_> for Handshake { - fn deserialize_with(mut buf: Bytes, _: ()) -> Result { - println!("{:?}", buf); - +impl Deserialize<'_, Capabilities> for Handshake { + fn deserialize_with(mut buf: Bytes, _: Capabilities) -> Result { let protocol_version = buf.get_u8(); // UNSAFE: server version is known to be ASCII @@ -56,9 +54,9 @@ impl Deserialize<'_> for Handshake { let charset = if buf.is_empty() { None } else { Some(buf.get_u8()) }; let status = if buf.is_empty() { - ServerStatus::empty() + Status::empty() } else { - ServerStatus::from_bits_truncate(buf.get_u16_le()) + Status::from_bits_truncate(buf.get_u16_le()) }; if !buf.is_empty() { @@ -128,13 +126,15 @@ mod tests { use bytes::Buf; use sqlx_core::io::Deserialize; - use super::{Capabilities, Handshake, ServerStatus}; + use super::{Capabilities, Handshake, Status}; + + const EMPTY: Capabilities = Capabilities::empty(); #[test] fn handshake_mysql_8_0_18() { const HANDSHAKE_MYSQL_8_0_18: &[u8] = b"\n8.0.18\x00\x19\x00\x00\x00\x114aB0c\x06g\x00\xff\xff\xff\x02\x00\xff\xc7\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00tL\x03s\x0f[4\rl4. \x00caching_sha2_password\x00"; - let mut h = Handshake::deserialize(HANDSHAKE_MYSQL_8_0_18.into()).unwrap(); + let mut h = Handshake::deserialize_with(HANDSHAKE_MYSQL_8_0_18.into(), EMPTY).unwrap(); assert_eq!(h.protocol_version, 10); @@ -166,7 +166,7 @@ mod tests { ); assert_eq!(h.charset, Some(255)); - assert_eq!(h.status, ServerStatus::AUTOCOMMIT); + assert_eq!(h.status, Status::AUTOCOMMIT); assert_eq!(h.auth_plugin_name.as_deref(), Some("caching_sha2_password")); assert_eq!( @@ -179,7 +179,7 @@ mod tests { fn handshake_mariadb_10_4_7() { const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\"hr&`3{55H\0mysql_native_password\0"; - let mut h = Handshake::deserialize(HANDSHAKE_MARIA_DB_10_5_8.into()).unwrap(); + let mut h = Handshake::deserialize_with(HANDSHAKE_MARIA_DB_10_5_8.into(), EMPTY).unwrap(); assert_eq!(h.protocol_version, 10); assert_eq!(&*h.server_version, "5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal"); @@ -257,7 +257,7 @@ mod tests { ); assert_eq!(h.charset, Some(45)); - assert_eq!(h.status, ServerStatus::AUTOCOMMIT); + assert_eq!(h.status, Status::AUTOCOMMIT); assert_eq!(h.auth_plugin_name.as_deref(), Some("mysql_native_password")); assert_eq!( @@ -273,7 +273,7 @@ mod tests { fn handshake_mysql_5_6_50() { const HANDSHAKE_MYSQL_5_6_50: &[u8] = b"\n5.6.50\0\x01\0\0\0-VLYZ:Pd\0\xff\xf7\x08\x02\0\x7f\x80\x15\0\0\0\0\0\0\0\0\0\0'2f+BL8nGV[G\0mysql_native_password\0"; - let mut h = Handshake::deserialize(HANDSHAKE_MYSQL_5_6_50.into()).unwrap(); + let mut h = Handshake::deserialize_with(HANDSHAKE_MYSQL_5_6_50.into(), EMPTY).unwrap(); assert_eq!(h.protocol_version, 10); @@ -304,7 +304,7 @@ mod tests { ); assert_eq!(h.charset, Some(8)); - assert_eq!(h.status, ServerStatus::AUTOCOMMIT); + assert_eq!(h.status, Status::AUTOCOMMIT); assert_eq!(h.auth_plugin_name.as_deref(), Some("mysql_native_password")); assert_eq!( @@ -317,7 +317,7 @@ mod tests { fn handshake_mysql_5_0_96() { const HANDSHAKE_MYSQL_5_0_96: &[u8] = b"\n5.0.96\0\x03\0\0\0bs=sNiGe\0,\xa2\x08\x02\0\0\0\0\0\0\0\0\0\0\0\0\0\0IzMP)yLLx;[9\0"; - let mut h = Handshake::deserialize(HANDSHAKE_MYSQL_5_0_96.into()).unwrap(); + let mut h = Handshake::deserialize_with(HANDSHAKE_MYSQL_5_0_96.into(), EMPTY).unwrap(); assert_eq!(h.protocol_version, 10); assert_eq!(&*h.server_version, "5.0.96"); @@ -333,7 +333,7 @@ mod tests { ); assert_eq!(h.charset, Some(8)); - assert_eq!(h.status, ServerStatus::AUTOCOMMIT); + assert_eq!(h.status, Status::AUTOCOMMIT); assert_eq!(h.auth_plugin_name, None); assert_eq!( @@ -349,7 +349,7 @@ mod tests { fn handshake_mysql_5_1_73() { const HANDSHAKE_MYSQL_5_1_73: &[u8] = b"\n5.1.73\0\x01\0\0\0 for HandshakeResponse<'_> { fn serialize_with(&self, buf: &mut Vec, capabilities: Capabilities) -> Result<()> { buf.extend_from_slice(&(capabilities.bits() as u32).to_le_bytes()); buf.extend_from_slice(&self.max_packet_size.to_le_bytes()); - buf.extend_from_slice(&self.charset.to_le_bytes()); + buf.push(self.charset); // reserved (all 0) buf.extend_from_slice(&[0_u8; 23]); diff --git a/sqlx-mysql/src/protocol/ok.rs b/sqlx-mysql/src/protocol/ok.rs new file mode 100644 index 00000000..61ad029b --- /dev/null +++ b/sqlx-mysql/src/protocol/ok.rs @@ -0,0 +1,60 @@ +use bytes::{Buf, Bytes}; +use sqlx_core::io::Deserialize; +use sqlx_core::Result; + +use crate::io::MySqlBufExt; +use crate::protocol::{Capabilities, Status}; + +// https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html + +/// An OK packet is sent from the server to the client to signal successful completion of a command. +/// As of MySQL 5.7.5, OK packes are also used to indicate EOF, and EOF packets are deprecated. +#[allow(clippy::module_name_repetitions)] +#[derive(Debug)] +pub(crate) struct OkPacket { + pub(crate) affected_rows: u64, + pub(crate) last_insert_id: u64, + pub(crate) status: Status, + pub(crate) warnings: u16, +} + +impl Deserialize<'_, Capabilities> for OkPacket { + fn deserialize_with(mut buf: Bytes, capabilities: Capabilities) -> Result { + let tag = buf.get_u8(); + debug_assert!(tag == 0x00 || tag == 0xfe); + + let affected_rows = buf.get_uint_lenenc(); + let last_insert_id = buf.get_uint_lenenc(); + + let status = + if capabilities.intersects(Capabilities::PROTOCOL_41 | Capabilities::TRANSACTIONS) { + Status::from_bits_truncate(buf.get_u16_le()) + } else { + Status::empty() + }; + + let warnings = + if capabilities.contains(Capabilities::PROTOCOL_41) { buf.get_u16_le() } else { 0 }; + + Ok(Self { affected_rows, last_insert_id, status, warnings }) + } +} + +#[cfg(test)] +mod tests { + use super::{OkPacket, Capabilities, Deserialize, Status}; + + #[test] + fn test_empty_ok_packet() { + const DATA: &[u8] = b"\x00\x00\x00\x02@\x00\x00"; + + let capabilities = Capabilities::PROTOCOL_41 | Capabilities::TRANSACTIONS; + + let ok = OkPacket::deserialize_with(DATA.into(), capabilities).unwrap(); + + assert_eq!(ok.affected_rows, 0); + assert_eq!(ok.last_insert_id, 0); + assert_eq!(ok.warnings, 0); + assert_eq!(ok.status, Status::AUTOCOMMIT | Status::SESSION_STATE_CHANGED); + } +} diff --git a/sqlx-mysql/src/protocol/status.rs b/sqlx-mysql/src/protocol/status.rs index f78534cf..9c1cb9a2 100644 --- a/sqlx-mysql/src/protocol/status.rs +++ b/sqlx-mysql/src/protocol/status.rs @@ -2,7 +2,7 @@ // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/mysql__com_8h.html#a1d854e841086925be1883e4d7b4e8cad // https://mariadb.com/kb/en/library/mariadb-connectorc-types-and-definitions/#server-status bitflags::bitflags! { - pub struct ServerStatus: u16 { + pub struct Status: u16 { // Is raised when a multi-statement transaction has been started, either explicitly, // by means of BEGIN or COMMIT AND CHAIN, or implicitly, by the first // transactional statement, when autocommit=off.