From 6166ec7b8febec91b483e33549b644a342af0374 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Thu, 15 Apr 2021 21:54:05 -0700 Subject: [PATCH] refactor(mysql): raise MySqlClientError for general client-side errors --- sqlx-mysql/src/error.rs | 72 ++----------------- sqlx-mysql/src/error/client.rs | 51 +++++++++++++ sqlx-mysql/src/error/database.rs | 59 +++++++++++++++ sqlx-mysql/src/lib.rs | 2 +- sqlx-mysql/src/protocol/auth_plugin.rs | 26 +------ .../src/protocol/auth_plugin/caching_sha2.rs | 19 +++-- sqlx-mysql/src/protocol/auth_plugin/dialog.rs | 12 ++-- sqlx-mysql/src/protocol/auth_plugin/rsa.rs | 19 +++-- sqlx-mysql/src/protocol/auth_plugin/sha256.rs | 16 +++-- sqlx-mysql/src/protocol/auth_response.rs | 7 +- sqlx-mysql/src/protocol/query_response.rs | 7 +- sqlx-mysql/src/protocol/query_step.rs | 10 +-- sqlx-mysql/src/stream.rs | 11 ++- sqlx-postgres/src/error/client.rs | 4 +- sqlx-postgres/src/protocol/backend/auth.rs | 2 +- 15 files changed, 175 insertions(+), 142 deletions(-) create mode 100644 sqlx-mysql/src/error/client.rs create mode 100644 sqlx-mysql/src/error/database.rs diff --git a/sqlx-mysql/src/error.rs b/sqlx-mysql/src/error.rs index e995191d..d0dcae06 100644 --- a/sqlx-mysql/src/error.rs +++ b/sqlx-mysql/src/error.rs @@ -1,69 +1,5 @@ -use std::error::Error as StdError; -use std::fmt::{self, Display, Formatter}; +mod client; +mod database; -use sqlx_core::DatabaseError; - -use crate::protocol::ErrPacket; - -/// An error returned from the MySQL database. -#[allow(clippy::module_name_repetitions)] -#[derive(Debug)] -pub struct MySqlDatabaseError(pub(crate) ErrPacket); - -impl MySqlDatabaseError { - /// Returns a human-readable error message. - pub fn message(&self) -> &str { - &*self.0.error_message - } - - /// Returns the error code. - /// - /// All possible error codes should be documented in - /// the [Server Error Message Reference]. Each code refers to a - /// unique error messasge. - /// - /// [Server Error Message Reference]: https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html - /// - pub const fn code(&self) -> u16 { - self.0.error_code - } - - /// Return the [SQLSTATE] error code. - /// - /// The error code consists of 5 characters with `"00000"` - /// meaning "no error". [SQLSTATE] values are defined by the SQL standard - /// and should be consistent across databases. - /// - /// [SQLSTATE]: https://en.wikipedia.org/wiki/SQLSTATE - /// - pub fn sql_state(&self) -> &str { - self.0.sql_state.as_deref().unwrap_or_default() - } -} - -impl MySqlDatabaseError { - pub(crate) fn new(code: u16, message: &str) -> Self { - Self(ErrPacket::new(code, message)) - } - - pub(crate) fn malformed_packet(message: &str) -> Self { - Self::new(2027, &format!("Malformed packet: {}", message)) - } -} - -impl DatabaseError for MySqlDatabaseError { - fn message(&self) -> &str { - &self.0.error_message - } -} - -impl Display for MySqlDatabaseError { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match &self.0.sql_state { - Some(state) => write!(f, "{} ({}): {}", self.0.error_code, state, self.message()), - None => write!(f, "{}: {}", self.0.error_code, self.message()), - } - } -} - -impl StdError for MySqlDatabaseError {} +pub use client::MySqlClientError; +pub use database::MySqlDatabaseError; diff --git a/sqlx-mysql/src/error/client.rs b/sqlx-mysql/src/error/client.rs new file mode 100644 index 00000000..076c4325 --- /dev/null +++ b/sqlx-mysql/src/error/client.rs @@ -0,0 +1,51 @@ +use std::error::Error as StdError; +use std::fmt::{self, Display, Formatter}; + +use crate::protocol::AuthPlugin; +use sqlx_core::{ClientError, Error}; + +#[derive(Debug)] +#[non_exhaustive] +pub enum MySqlClientError { + UnknownAuthPlugin(String), + AuthPlugin { plugin: &'static str, source: Box }, + EmptyPacket { context: &'static str }, + UnexpectedPacketSize { expected: usize, actual: usize }, +} + +impl MySqlClientError { + pub(crate) fn auth_plugin( + plugin: &impl AuthPlugin, + source: impl Into>, + ) -> Self { + Self::AuthPlugin { plugin: plugin.name(), source: source.into() } + } +} + +impl Display for MySqlClientError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::UnknownAuthPlugin(name) => write!(f, "unknown authentication plugin: {}", name), + + Self::AuthPlugin { plugin, source } => { + write!(f, "authentication plugin '{}' reported error: {}", plugin, source) + } + + Self::EmptyPacket { context } => write!(f, "received no bytes for {}", context), + + Self::UnexpectedPacketSize { actual, expected } => { + write!(f, "received {} bytes for packet but expecting {} bytes", actual, expected) + } + } + } +} + +impl StdError for MySqlClientError {} + +impl ClientError for MySqlClientError {} + +impl From for Error { + fn from(err: MySqlClientError) -> Self { + Self::client(err) + } +} diff --git a/sqlx-mysql/src/error/database.rs b/sqlx-mysql/src/error/database.rs new file mode 100644 index 00000000..f867dea0 --- /dev/null +++ b/sqlx-mysql/src/error/database.rs @@ -0,0 +1,59 @@ +use std::error::Error as StdError; +use std::fmt::{self, Display, Formatter}; + +use sqlx_core::DatabaseError; + +use crate::protocol::ErrPacket; + +/// An error returned from the MySQL database. +#[allow(clippy::module_name_repetitions)] +#[derive(Debug)] +pub struct MySqlDatabaseError(pub(crate) ErrPacket); + +impl MySqlDatabaseError { + /// Returns a human-readable error message. + pub fn message(&self) -> &str { + &*self.0.error_message + } + + /// Returns the error code. + /// + /// All possible error codes should be documented in + /// the [Server Error Message Reference]. Each code refers to a + /// unique error messasge. + /// + /// [Server Error Message Reference]: https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html + /// + pub const fn code(&self) -> u16 { + self.0.error_code + } + + /// Return the [SQLSTATE] error code. + /// + /// The error code consists of 5 characters with `"00000"` + /// meaning "no error". [SQLSTATE] values are defined by the SQL standard + /// and should be consistent across databases. + /// + /// [SQLSTATE]: https://en.wikipedia.org/wiki/SQLSTATE + /// + pub fn sql_state(&self) -> &str { + self.0.sql_state.as_deref().unwrap_or_default() + } +} + +impl DatabaseError for MySqlDatabaseError { + fn message(&self) -> &str { + &self.0.error_message + } +} + +impl Display for MySqlDatabaseError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match &self.0.sql_state { + Some(state) => write!(f, "{} ({}): {}", self.0.error_code, state, self.message()), + None => write!(f, "{}: {}", self.0.error_code, self.message()), + } + } +} + +impl StdError for MySqlDatabaseError {} diff --git a/sqlx-mysql/src/lib.rs b/sqlx-mysql/src/lib.rs index f56551a3..79a53db6 100644 --- a/sqlx-mysql/src/lib.rs +++ b/sqlx-mysql/src/lib.rs @@ -48,7 +48,7 @@ mod mock; pub use column::MySqlColumn; pub use connection::MySqlConnection; pub use database::MySql; -pub use error::MySqlDatabaseError; +pub use error::{MySqlClientError, MySqlDatabaseError}; pub use options::MySqlConnectOptions; pub use output::MySqlOutput; pub use query_result::MySqlQueryResult; diff --git a/sqlx-mysql/src/protocol/auth_plugin.rs b/sqlx-mysql/src/protocol/auth_plugin.rs index 5c8bfd16..585fc065 100644 --- a/sqlx-mysql/src/protocol/auth_plugin.rs +++ b/sqlx-mysql/src/protocol/auth_plugin.rs @@ -1,11 +1,10 @@ -use std::error::Error as StdError; use std::fmt::Debug; use bytes::buf::Chain; use bytes::Bytes; -use sqlx_core::{Error, Result}; +use sqlx_core::Result; -use crate::MySqlDatabaseError; +use crate::MySqlClientError; mod caching_sha2; mod dialog; @@ -44,11 +43,7 @@ impl dyn AuthPlugin { _ if s == NativeAuthPlugin.name() => Ok(Box::new(NativeAuthPlugin)), _ if s == DialogAuthPlugin.name() => Ok(Box::new(DialogAuthPlugin)), - _ => Err(MySqlDatabaseError::new( - 2059, - &format!("Authentication plugin '{}' cannot be loaded", s), - ) - .into()), + _ => Err(MySqlClientError::UnknownAuthPlugin(s.to_owned()).into()), } } } @@ -62,18 +57,3 @@ fn xor_eq(x: &mut [u8], y: &[u8]) { x[i] ^= y[i % y_len]; } } - -fn err_msg(plugin: &'static str, message: &str) -> Error { - MySqlDatabaseError::new( - 2061, - &format!("Authentication plugin '{}' reported error: {}", plugin, message), - ) - .into() -} - -fn err(plugin: &'static str, error: &E) -> Error -where - E: StdError, -{ - err_msg(plugin, &error.to_string()) -} diff --git a/sqlx-mysql/src/protocol/auth_plugin/caching_sha2.rs b/sqlx-mysql/src/protocol/auth_plugin/caching_sha2.rs index 6a754384..420c4247 100644 --- a/sqlx-mysql/src/protocol/auth_plugin/caching_sha2.rs +++ b/sqlx-mysql/src/protocol/auth_plugin/caching_sha2.rs @@ -1,3 +1,7 @@ +use super::rsa::encrypt as rsa_encrypt; +use super::xor_eq; +use crate::protocol::AuthPlugin; +use crate::MySqlClientError; use bytes::buf::Chain; use bytes::Bytes; use sha2::{Digest, Sha256}; @@ -12,7 +16,7 @@ use sqlx_core::Result; #[derive(Debug)] pub(crate) struct CachingSha2AuthPlugin; -impl super::AuthPlugin for CachingSha2AuthPlugin { +impl AuthPlugin for CachingSha2AuthPlugin { fn name(&self) -> &'static str { "caching_sha2_password" } @@ -44,7 +48,7 @@ impl super::AuthPlugin for CachingSha2AuthPlugin { // SHA256( nonce + SHA256( SHA256( password ) ) ) let nonce_pw_sha1_sha1 = hasher.finalize(); - super::xor_eq(&mut pw_sha2, &nonce_pw_sha1_sha1); + xor_eq(&mut pw_sha2, &nonce_pw_sha1_sha1); pw_sha2.to_vec() } @@ -60,10 +64,11 @@ impl super::AuthPlugin for CachingSha2AuthPlugin { const AUTH_CONTINUE: u8 = 0x4; if command != 0x01 { - return Err(super::err_msg( - self.name(), - &format!("Received 0x{:x} but expected 0x1 (MORE DATA)", command), - )); + return Err(MySqlClientError::auth_plugin( + self, + format!("received 0x{:x} but expected 0x1 (MORE DATA)", command), + ) + .into()); } match data[0] { @@ -78,7 +83,7 @@ impl super::AuthPlugin for CachingSha2AuthPlugin { _ => { let rsa_pub_key = data; - let encrypted = super::rsa::encrypt(self.name(), &rsa_pub_key, password, nonce)?; + let encrypted = rsa_encrypt(self, &rsa_pub_key, password, nonce)?; Ok(Some(encrypted)) } diff --git a/sqlx-mysql/src/protocol/auth_plugin/dialog.rs b/sqlx-mysql/src/protocol/auth_plugin/dialog.rs index 92adbd4e..3a6cbcee 100644 --- a/sqlx-mysql/src/protocol/auth_plugin/dialog.rs +++ b/sqlx-mysql/src/protocol/auth_plugin/dialog.rs @@ -1,8 +1,7 @@ -use std::borrow::Cow; - +use crate::MySqlClientError; use bytes::buf::Chain; use bytes::Bytes; -use sqlx_core::{Error, Result}; +use sqlx_core::Result; /// Dialog authentication implementation /// @@ -27,9 +26,10 @@ impl super::AuthPlugin for DialogAuthPlugin { _nonce: &Chain, _password: &str, ) -> Result>> { - Err(super::err_msg( - self.name(), + Err(MySqlClientError::auth_plugin( + self, "interactive dialog authentication is currently not supported", - )) + ) + .into()) } } diff --git a/sqlx-mysql/src/protocol/auth_plugin/rsa.rs b/sqlx-mysql/src/protocol/auth_plugin/rsa.rs index d6b21bc6..329a0c98 100644 --- a/sqlx-mysql/src/protocol/auth_plugin/rsa.rs +++ b/sqlx-mysql/src/protocol/auth_plugin/rsa.rs @@ -1,12 +1,15 @@ use std::str::from_utf8; +use crate::protocol::auth_plugin::xor_eq; +use crate::protocol::AuthPlugin; +use crate::MySqlClientError; use bytes::buf::Chain; use bytes::Bytes; use rsa::{PaddingScheme, PublicKey, RSAPublicKey}; use sqlx_core::Result; pub(crate) fn encrypt( - plugin: &'static str, + plugin: &impl AuthPlugin, key: &[u8], password: &str, nonce: &Chain, @@ -20,18 +23,20 @@ pub(crate) fn encrypt( nonce.extend_from_slice(&*a); nonce.extend_from_slice(&*b); - super::xor_eq(&mut pass, &*nonce); + xor_eq(&mut pass, &*nonce); // client sends an RSA encrypted password let public = parse_rsa_pub_key(plugin, key)?; let padding = PaddingScheme::new_oaep::(); - public.encrypt(&mut rng(), padding, &pass[..]).map_err(|err| super::err(plugin, &err)) + public + .encrypt(&mut rng(), padding, &pass[..]) + .map_err(|err| MySqlClientError::auth_plugin(plugin, err).into()) } // https://docs.rs/rsa/0.3.0/rsa/struct.RSAPublicKey.html?search=#example-1 -fn parse_rsa_pub_key(plugin: &'static str, key: &[u8]) -> Result { - let key = from_utf8(key).map_err(|err| super::err(plugin, &err))?; +fn parse_rsa_pub_key(plugin: &impl AuthPlugin, key: &[u8]) -> Result { + let key = from_utf8(key).map_err(|err| MySqlClientError::auth_plugin(plugin, err))?; // Takes advantage of the knowledge that we know // we are receiving a PKCS#8 RSA Public Key at all @@ -43,9 +48,9 @@ fn parse_rsa_pub_key(plugin: &'static str, key: &[u8]) -> Result { data }); - let der = base64::decode(&encoded).map_err(|err| super::err(plugin, &err))?; + let der = base64::decode(&encoded).map_err(|err| MySqlClientError::auth_plugin(plugin, err))?; - RSAPublicKey::from_pkcs8(&der).map_err(|err| super::err(plugin, &err)) + RSAPublicKey::from_pkcs8(&der).map_err(|err| MySqlClientError::auth_plugin(plugin, err).into()) } fn to_asciz(s: &str) -> Vec { diff --git a/sqlx-mysql/src/protocol/auth_plugin/sha256.rs b/sqlx-mysql/src/protocol/auth_plugin/sha256.rs index cc92d1f7..418a4fd8 100644 --- a/sqlx-mysql/src/protocol/auth_plugin/sha256.rs +++ b/sqlx-mysql/src/protocol/auth_plugin/sha256.rs @@ -1,3 +1,6 @@ +use super::rsa::encrypt as rsa_encrypt; +use crate::protocol::AuthPlugin; +use crate::MySqlClientError; use bytes::buf::Chain; use bytes::Bytes; use sqlx_core::Result; @@ -13,7 +16,7 @@ use sqlx_core::Result; #[derive(Debug)] pub(crate) struct Sha256AuthPlugin; -impl super::AuthPlugin for Sha256AuthPlugin { +impl AuthPlugin for Sha256AuthPlugin { fn name(&self) -> &'static str { "sha256_password" } @@ -36,14 +39,15 @@ impl super::AuthPlugin for Sha256AuthPlugin { password: &str, ) -> Result>> { if command != 0x01 { - return Err(super::err_msg( - self.name(), - &format!("Received 0x{:x} but expected 0x1 (MORE DATA)", command), - )); + return Err(MySqlClientError::auth_plugin( + self, + format!("Received 0x{:x} but expected 0x1 (MORE DATA)", command), + ) + .into()); } let rsa_pub_key = data; - let encrypted = super::rsa::encrypt(self.name(), &rsa_pub_key, password, nonce)?; + let encrypted = rsa_encrypt(self, &rsa_pub_key, password, nonce)?; Ok(Some(encrypted)) } diff --git a/sqlx-mysql/src/protocol/auth_response.rs b/sqlx-mysql/src/protocol/auth_response.rs index 8a9fed62..e109de2f 100644 --- a/sqlx-mysql/src/protocol/auth_response.rs +++ b/sqlx-mysql/src/protocol/auth_response.rs @@ -5,7 +5,7 @@ use sqlx_core::io::Deserialize; use sqlx_core::Result; use crate::protocol::{AuthSwitch, Capabilities, ResultPacket}; -use crate::MySqlDatabaseError; +use crate::{MySqlClientError, MySqlDatabaseError}; #[derive(Debug)] pub(crate) enum AuthResponse { @@ -28,10 +28,7 @@ impl Deserialize<'_, Capabilities> for AuthResponse { // send a command to the active auth plugin Some(command) => Ok(Self::Command(*command, buf.slice(1..))), - None => { - Err(MySqlDatabaseError::malformed_packet("Received no bytes for auth response") - .into()) - } + None => Err(MySqlClientError::EmptyPacket { context: "auth response" }.into()), } } } diff --git a/sqlx-mysql/src/protocol/query_response.rs b/sqlx-mysql/src/protocol/query_response.rs index 3558a835..caa90cb0 100644 --- a/sqlx-mysql/src/protocol/query_response.rs +++ b/sqlx-mysql/src/protocol/query_response.rs @@ -4,7 +4,7 @@ use sqlx_core::Result; use super::{Capabilities, ResultPacket}; use crate::io::MySqlBufExt; -use crate::MySqlDatabaseError; +use crate::{MySqlClientError, MySqlDatabaseError}; /// The query-response packet is a meta-packet that starts with one of: /// @@ -46,10 +46,7 @@ impl Deserialize<'_, Capabilities> for QueryResponse { Ok(Self::ResultSet { columns: columns as u16 }) } - None => Err(MySqlDatabaseError::malformed_packet( - "Received no bytes for COM_QUERY response", - ) - .into()), + None => Err(MySqlClientError::EmptyPacket { context: "COM_QUERY response" }.into()), } } } diff --git a/sqlx-mysql/src/protocol/query_step.rs b/sqlx-mysql/src/protocol/query_step.rs index 47e0d0ff..98acd8ae 100644 --- a/sqlx-mysql/src/protocol/query_step.rs +++ b/sqlx-mysql/src/protocol/query_step.rs @@ -4,7 +4,7 @@ use sqlx_core::Result; use super::{Capabilities, ResultPacket}; use crate::protocol::Packet; -use crate::MySqlDatabaseError; +use crate::MySqlClientError; /// /// @@ -30,10 +30,10 @@ impl Deserialize<'_, Capabilities> for QueryStep { // If its non-0, then its a Row Some(_) => Ok(Self::Row(Packet { bytes: buf })), - None => Err(MySqlDatabaseError::malformed_packet( - "Received no bytes for the next step in a result set", - ) - .into()), + None => { + Err(MySqlClientError::EmptyPacket { context: "the next step in a result set" } + .into()) + } } } } diff --git a/sqlx-mysql/src/stream.rs b/sqlx-mysql/src/stream.rs index 340f03d8..e7829e86 100644 --- a/sqlx-mysql/src/stream.rs +++ b/sqlx-mysql/src/stream.rs @@ -7,7 +7,7 @@ use sqlx_core::net::Stream as NetStream; use sqlx_core::{Result, Runtime}; use crate::protocol::{MaybeCommand, Packet, Quit}; -use crate::MySqlDatabaseError; +use crate::{MySqlClientError, MySqlDatabaseError}; /// Reads and writes packets to and from the MySQL database server. /// @@ -101,11 +101,10 @@ impl MySqlStream { if packet.bytes.len() != len { // BUG: something is very wrong somewhere if this branch is executed // either in the SQLx MySQL driver or in the MySQL server - return Err(MySqlDatabaseError::malformed_packet(&format!( - "Received {} bytes for packet but expecting {} bytes", - packet.bytes.len(), - len - )) + return Err(MySqlClientError::UnexpectedPacketSize { + expected: len, + actual: packet.bytes.len(), + } .into()); } diff --git a/sqlx-postgres/src/error/client.rs b/sqlx-postgres/src/error/client.rs index 86ec838e..c2934228 100644 --- a/sqlx-postgres/src/error/client.rs +++ b/sqlx-postgres/src/error/client.rs @@ -10,7 +10,7 @@ pub enum PgClientError { // attempting to interpret data from postgres as UTF-8, when it should // be UTF-8, but for some reason (data corruption?) it is not NotUtf8(Utf8Error), - UnknownAuthenticationMethod(u32), + UnknownAuthMethod(u32), UnknownMessageType(u8), UnknownTransactionStatus(u8), UnknownValueFormat(i16), @@ -22,7 +22,7 @@ impl Display for PgClientError { match self { Self::NotUtf8(source) => write!(f, "unexpected invalid utf-8: {}", source), - Self::UnknownAuthenticationMethod(method) => { + Self::UnknownAuthMethod(method) => { write!(f, "unknown authentication method: {}", method) } diff --git a/sqlx-postgres/src/protocol/backend/auth.rs b/sqlx-postgres/src/protocol/backend/auth.rs index d9899e1a..a2a6fb26 100644 --- a/sqlx-postgres/src/protocol/backend/auth.rs +++ b/sqlx-postgres/src/protocol/backend/auth.rs @@ -50,7 +50,7 @@ impl Deserialize<'_> for Authentication { 11 => AuthenticationSaslContinue::deserialize(buf).map(Self::SaslContinue), 12 => AuthenticationSaslFinal::deserialize(buf).map(Self::SaslFinal), - ty => Err(PgClientError::UnknownAuthenticationMethod(ty).into()), + ty => Err(PgClientError::UnknownAuthMethod(ty).into()), } } }