diff --git a/Cargo.lock b/Cargo.lock index 0d2fc084..067d7515 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1131,6 +1131,7 @@ dependencies = [ "bitflags", "bytes", "bytestring", + "conquer-once", "either", "futures-executor", "futures-io", diff --git a/sqlx-mysql/Cargo.toml b/sqlx-mysql/Cargo.toml index a675c569..87837209 100644 --- a/sqlx-mysql/Cargo.toml +++ b/sqlx-mysql/Cargo.toml @@ -48,3 +48,4 @@ rand = "0.7" sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core", features = ["_mock"] } futures-executor = "0.3.8" anyhow = "1.0.37" +conquer-once = "0.3.2" diff --git a/sqlx-mysql/src/connection/connect.rs b/sqlx-mysql/src/connection/connect.rs index d9aaef24..9d1f8b3d 100644 --- a/sqlx-mysql/src/connection/connect.rs +++ b/sqlx-mysql/src/connection/connect.rs @@ -15,7 +15,7 @@ use sqlx_core::net::Stream as NetStream; use sqlx_core::Result; use sqlx_core::Runtime; -use crate::protocol::{AuthResponse, Handshake, HandshakeResponse}; +use crate::protocol::{AuthResponse, Handshake, Capabilities, HandshakeResponse}; use crate::{MySqlConnectOptions, MySqlConnection}; impl MySqlConnection { @@ -24,6 +24,10 @@ impl MySqlConnection { options: &MySqlConnectOptions, handshake: &Handshake, ) -> Result<()> { + // IF the options specify a database, try to use the CONNECT_WITH_DB capability + // this lets us skip a round-trip after connect + self.capabilities |= Capabilities::CONNECT_WITH_DB; + // & the declared server capabilities with our capabilities to find // what rules the client should operate under self.capabilities &= handshake.capabilities; diff --git a/sqlx-mysql/src/lib.rs b/sqlx-mysql/src/lib.rs index 90be62bb..5ab709c1 100644 --- a/sqlx-mysql/src/lib.rs +++ b/sqlx-mysql/src/lib.rs @@ -26,11 +26,13 @@ mod database; mod error; mod io; mod options; +mod query_result; mod protocol; #[cfg(test)] mod mock; +pub use query_result::MySqlQueryResult; pub use connection::MySqlConnection; pub use database::MySql; pub use error::MySqlDatabaseError; diff --git a/sqlx-mysql/src/protocol.rs b/sqlx-mysql/src/protocol.rs index cd6be759..944b40af 100644 --- a/sqlx-mysql/src/protocol.rs +++ b/sqlx-mysql/src/protocol.rs @@ -12,12 +12,14 @@ mod ok; mod ping; mod query; mod query_response; +mod info; mod query_step; mod packet; mod quit; mod row; mod status; +pub(crate) use info::Info; pub(crate) use packet::Packet; pub(crate) use auth_plugin::AuthPlugin; pub(crate) use auth_response::AuthResponse; diff --git a/sqlx-mysql/src/protocol/info.rs b/sqlx-mysql/src/protocol/info.rs new file mode 100644 index 00000000..c83a75a0 --- /dev/null +++ b/sqlx-mysql/src/protocol/info.rs @@ -0,0 +1,68 @@ +// https://dev.mysql.com/doc/c-api/8.0/en/mysql-info.html +// https://mariadb.com/kb/en/mysql_info/ + +#[derive(Debug)] +pub(crate) struct Info { + pub(crate) records: u64, + pub(crate) duplicates: u64, + pub(crate) matched: u64, +} + +impl Info { + pub(crate) fn parse(info: &str) -> Self { + let mut records = 0; + let mut duplicates = 0; + let mut matched = 0; + + let mut failed = false; + + for item in info.split(" ") { + let mut item = item.split(": "); + + if let Some((key, value)) = item.next().zip(item.next()) { + let value: u64 = if let Ok(value) = value.parse() { + value + } else { + // remember failed, invalid value + failed = true; + 0 + }; + + match key { + "Records" => records = value, + "Duplicates" => duplicates = value, + "Rows matched" => matched = value, + + // unknown key + _ => failed = true, + } + } + } + + if failed { + log::warn!("failed to parse status information from OK packet: {:?}", info); + } + + Self { records, duplicates, matched } + } +} + +#[cfg(test)] +mod tests { + use super::Info; + + #[test] + fn parse_insert() { + let info = Info::parse("Records: 10 Duplicates: 5 Warnings: 0"); + + assert_eq!(info.records, 10); + assert_eq!(info.duplicates, 5); + } + + #[test] + fn parse_update() { + let info = Info::parse("Rows matched: 40 Changed: 5 Warnings: 0"); + + assert_eq!(info.matched, 40); + } +} diff --git a/sqlx-mysql/src/protocol/ok.rs b/sqlx-mysql/src/protocol/ok.rs index 13b348e0..41449cb1 100644 --- a/sqlx-mysql/src/protocol/ok.rs +++ b/sqlx-mysql/src/protocol/ok.rs @@ -1,11 +1,13 @@ use bytes::{Buf, Bytes}; +use bytestring::ByteString; use sqlx_core::io::Deserialize; use sqlx_core::Result; use crate::io::MySqlBufExt; -use crate::protocol::{Capabilities, Status}; +use crate::protocol::{Capabilities, Info, Status}; // https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html +// https://mariadb.com/kb/en/ok_packet/ /// An OK packet is sent from the server to the client to signal successful completion of a command. /// As of MySQL 5.7.5, OK packes are also used to indicate EOF, and EOF packets are deprecated. @@ -16,6 +18,9 @@ pub(crate) struct OkPacket { pub(crate) last_insert_id: u64, pub(crate) status: Status, pub(crate) warnings: u16, + + // human readable status information + pub(crate) info: Info, } impl Deserialize<'_, Capabilities> for OkPacket { @@ -36,7 +41,22 @@ impl Deserialize<'_, Capabilities> for OkPacket { let warnings = if capabilities.contains(Capabilities::PROTOCOL_41) { buf.get_u16_le() } else { 0 }; - Ok(Self { affected_rows, last_insert_id, status, warnings }) + let info = if buf.is_empty() { + // no info, end of buffer + ByteString::default() + } else { + // human readable status information + #[allow(unsafe_code)] + if capabilities.contains(Capabilities::SESSION_TRACK) { + // if [CLIENT_SESSION_TRACK] the info comes down as string + unsafe { buf.get_str_lenenc_unchecked() } + } else { + // otherwise the ASCII info is sent as string + unsafe { buf.get_str_eof_unchecked() } + } + }; + + Ok(Self { affected_rows, last_insert_id, status, warnings, info: Info::parse(&info) }) } } diff --git a/sqlx-mysql/src/query_result.rs b/sqlx-mysql/src/query_result.rs new file mode 100644 index 00000000..782a9f32 --- /dev/null +++ b/sqlx-mysql/src/query_result.rs @@ -0,0 +1,214 @@ +use std::fmt::{self, Debug, Formatter}; + +use crate::protocol::OkPacket; + +/// Represents the execution result of an operation on the database server. +/// +/// Returned from [`execute()`][sqlx_core::Executor::execute]. +/// +#[allow(clippy::module_name_repetitions)] +pub struct MySqlQueryResult(OkPacket); + +impl MySqlQueryResult { + /// Returns the number of rows changed, deleted, or inserted by the statement + /// if it was an `UPDATE`, `DELETE` or `INSERT`. For `SELECT` statements, returns + /// the number of rows returned. + /// + /// For more information, see the corresponding method in the official C API: + /// + /// + #[doc(alias = "affected_rows")] + #[must_use] + pub const fn rows_affected(&self) -> u64 { + self.0.affected_rows + } + + /// Return the number of rows matched by the `UPDATE` statement. + /// + /// This is in contrast to [`rows_affected()`] which will return the number + /// of rows actually changed by the `UPDATE statement. + /// + /// Returns `0` for all other statements. + /// + #[must_use] + pub const fn rows_matched(&self) -> u64 { + self.0.info.matched + } + + /// Returns the number of rows processed by the multi-row `INSERT` + /// or `ALTER TABLE` statement. + /// + /// For multi-row `INSERT`, this is not necessarily the number of rows actually + /// inserted because [`duplicates()`] can be non-zero. + /// + /// For `ALTER TABLE`, this is the number of rows that were copied while + /// making alterations. + /// + /// Returns `0` for all other statements. + /// + #[must_use] + pub const fn records(&self) -> u64 { + self.0.info.records + } + + /// Returns the number of rows that could not be inserted by a multi-row `INSERT` + /// statement because they would duplicate some existing unique index value. + /// + /// Returns `0` for all other statements. + /// + #[must_use] + pub const fn duplicates(&self) -> u64 { + self.0.info.duplicates + } + + /// Returns the integer generated for an `AUTO_INCREMENT` column by the + /// `INSERT` statement. + /// + /// When inserting multiple rows, returns the id of the _first_ row in + /// set of inserted rows. + /// + /// For more information, see the corresponding method in the official C API: + /// + /// + #[doc(alias = "last_insert_id")] + #[must_use] + pub const fn inserted_id(&self) -> Option { + // NOTE: a valid ID is never zero + if self.0.last_insert_id == 0 { None } else { Some(self.0.last_insert_id) } + } + + /// Returns the number of errors, warnings, and notes generated during + /// execution of the statement. + /// + /// To read the warning messages, execute + /// the [`SHOW WARNINGS`](https://dev.mysql.com/doc/refman/8.0/en/show-warnings.html) + /// statement on the same connection (and before executing any other statements). + /// + /// As an example, the statement `SELECT 1/0` will execute successfully and return `NULL` but + /// indicate 1 warning. + /// + #[doc(alias = "warnings_count")] + #[must_use] + pub const fn warnings(&self) -> u16 { + self.0.warnings + } +} + +impl Debug for MySqlQueryResult { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("MySqlQueryResult") + .field("inserted_id", &self.inserted_id()) + .field("rows_affected", &self.rows_affected()) + .field("rows_matched", &self.rows_matched()) + .field("records", &self.records()) + .field("duplicates", &self.duplicates()) + .field("warnings", &self.warnings()) + .finish() + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use conquer_once::Lazy; + use sqlx_core::io::Deserialize; + + use super::MySqlQueryResult; + use crate::protocol::{Capabilities, OkPacket}; + + static CAPABILITIES: Lazy = Lazy::new(|| { + Capabilities::PROTOCOL_41 | Capabilities::SESSION_TRACK | Capabilities::TRANSACTIONS + }); + + #[test] + fn insert_1() -> anyhow::Result<()> { + let packet = Bytes::from(&b"\0\x01\x01\x02\0\0\0"[..]); + let ok = OkPacket::deserialize_with(packet, *CAPABILITIES)?; + let res = MySqlQueryResult(ok); + + assert_eq!(res.rows_affected(), 1); + assert_eq!(res.inserted_id(), Some(1)); + + Ok(()) + } + + #[test] + fn insert_5() -> anyhow::Result<()> { + let packet = Bytes::from(&b"\0\x05\x02\x02\0\0\0&Records: 5 Duplicates: 0 Warnings: 0"[..]); + let ok = OkPacket::deserialize_with(packet, *CAPABILITIES)?; + let res = MySqlQueryResult(ok); + + assert_eq!(res.rows_affected(), 5); + assert_eq!(res.inserted_id(), Some(2)); + assert_eq!(res.records(), 5); + assert_eq!(res.duplicates(), 0); + + Ok(()) + } + + #[test] + fn insert_5_or_update_3() -> anyhow::Result<()> { + let packet = Bytes::from(&b"\0\x08\x07\x02\0\0\0&Records: 5 Duplicates: 3 Warnings: 0"[..]); + let ok = OkPacket::deserialize_with(packet, *CAPABILITIES)?; + let res = MySqlQueryResult(ok); + + assert_eq!(res.rows_affected(), 8); + assert_eq!(res.inserted_id(), Some(7)); + assert_eq!(res.records(), 5); + assert_eq!(res.duplicates(), 3); + + Ok(()) + } + + #[test] + fn update_7_change_3() -> anyhow::Result<()> { + let packet = Bytes::from(&b"\0\x03\0\"\0\0\0(Rows matched: 7 Changed: 3 Warnings: 0"[..]); + let ok = OkPacket::deserialize_with(packet, *CAPABILITIES)?; + let res = MySqlQueryResult(ok); + + assert_eq!(res.rows_affected(), 3); + assert_eq!(res.inserted_id(), None); + assert_eq!(res.rows_matched(), 7); + + Ok(()) + } + + #[test] + fn update_1_change_1() -> anyhow::Result<()> { + let packet = + Bytes::from(&b"\0\x01\0\x02\0\0\0(Rows matched: 1 Changed: 1 Warnings: 0"[..]); + + let ok = OkPacket::deserialize_with(packet, *CAPABILITIES)?; + let res = MySqlQueryResult(ok); + + assert_eq!(res.rows_affected(), 1); + assert_eq!(res.inserted_id(), None); + assert_eq!(res.rows_matched(), 1); + + Ok(()) + } + + #[test] + fn delete_1() -> anyhow::Result<()> { + let packet = Bytes::from(&b"\0\x01\0\x02\0\0\0"[..]); + let ok = OkPacket::deserialize_with(packet, *CAPABILITIES)?; + let res = MySqlQueryResult(ok); + + assert_eq!(res.rows_affected(), 1); + assert_eq!(res.inserted_id(), None); + + Ok(()) + } + + #[test] + fn delete_6() -> anyhow::Result<()> { + let packet = Bytes::from(&b"\0\x06\0\"\0\0\0"[..]); + let ok = OkPacket::deserialize_with(packet, *CAPABILITIES)?; + let res = MySqlQueryResult(ok); + + assert_eq!(res.rows_affected(), 6); + assert_eq!(res.inserted_id(), None); + + Ok(()) + } +}