refactor(mysql): raise MySqlClientError for general client-side errors

This commit is contained in:
Ryan Leckey 2021-04-15 21:54:05 -07:00
parent 0267fe0482
commit 6166ec7b8f
15 changed files with 175 additions and 142 deletions

View File

@ -1,69 +1,5 @@
use std::error::Error as StdError;
use std::fmt::{self, Display, Formatter};
mod client;
mod database;
use sqlx_core::DatabaseError;
use crate::protocol::ErrPacket;
/// An error returned from the MySQL database.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct MySqlDatabaseError(pub(crate) ErrPacket);
impl MySqlDatabaseError {
/// Returns a human-readable error message.
pub fn message(&self) -> &str {
&*self.0.error_message
}
/// Returns the error code.
///
/// All possible error codes should be documented in
/// the [Server Error Message Reference]. Each code refers to a
/// unique error messasge.
///
/// [Server Error Message Reference]: https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html
///
pub const fn code(&self) -> u16 {
self.0.error_code
}
/// Return the [SQLSTATE] error code.
///
/// The error code consists of 5 characters with `"00000"`
/// meaning "no error". [SQLSTATE] values are defined by the SQL standard
/// and should be consistent across databases.
///
/// [SQLSTATE]: https://en.wikipedia.org/wiki/SQLSTATE
///
pub fn sql_state(&self) -> &str {
self.0.sql_state.as_deref().unwrap_or_default()
}
}
impl MySqlDatabaseError {
pub(crate) fn new(code: u16, message: &str) -> Self {
Self(ErrPacket::new(code, message))
}
pub(crate) fn malformed_packet(message: &str) -> Self {
Self::new(2027, &format!("Malformed packet: {}", message))
}
}
impl DatabaseError for MySqlDatabaseError {
fn message(&self) -> &str {
&self.0.error_message
}
}
impl Display for MySqlDatabaseError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match &self.0.sql_state {
Some(state) => write!(f, "{} ({}): {}", self.0.error_code, state, self.message()),
None => write!(f, "{}: {}", self.0.error_code, self.message()),
}
}
}
impl StdError for MySqlDatabaseError {}
pub use client::MySqlClientError;
pub use database::MySqlDatabaseError;

View File

@ -0,0 +1,51 @@
use std::error::Error as StdError;
use std::fmt::{self, Display, Formatter};
use crate::protocol::AuthPlugin;
use sqlx_core::{ClientError, Error};
#[derive(Debug)]
#[non_exhaustive]
pub enum MySqlClientError {
UnknownAuthPlugin(String),
AuthPlugin { plugin: &'static str, source: Box<dyn StdError + 'static + Send + Sync> },
EmptyPacket { context: &'static str },
UnexpectedPacketSize { expected: usize, actual: usize },
}
impl MySqlClientError {
pub(crate) fn auth_plugin(
plugin: &impl AuthPlugin,
source: impl Into<Box<dyn StdError + Send + Sync>>,
) -> Self {
Self::AuthPlugin { plugin: plugin.name(), source: source.into() }
}
}
impl Display for MySqlClientError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
Self::UnknownAuthPlugin(name) => write!(f, "unknown authentication plugin: {}", name),
Self::AuthPlugin { plugin, source } => {
write!(f, "authentication plugin '{}' reported error: {}", plugin, source)
}
Self::EmptyPacket { context } => write!(f, "received no bytes for {}", context),
Self::UnexpectedPacketSize { actual, expected } => {
write!(f, "received {} bytes for packet but expecting {} bytes", actual, expected)
}
}
}
}
impl StdError for MySqlClientError {}
impl ClientError for MySqlClientError {}
impl From<MySqlClientError> for Error {
fn from(err: MySqlClientError) -> Self {
Self::client(err)
}
}

View File

@ -0,0 +1,59 @@
use std::error::Error as StdError;
use std::fmt::{self, Display, Formatter};
use sqlx_core::DatabaseError;
use crate::protocol::ErrPacket;
/// An error returned from the MySQL database.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct MySqlDatabaseError(pub(crate) ErrPacket);
impl MySqlDatabaseError {
/// Returns a human-readable error message.
pub fn message(&self) -> &str {
&*self.0.error_message
}
/// Returns the error code.
///
/// All possible error codes should be documented in
/// the [Server Error Message Reference]. Each code refers to a
/// unique error messasge.
///
/// [Server Error Message Reference]: https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html
///
pub const fn code(&self) -> u16 {
self.0.error_code
}
/// Return the [SQLSTATE] error code.
///
/// The error code consists of 5 characters with `"00000"`
/// meaning "no error". [SQLSTATE] values are defined by the SQL standard
/// and should be consistent across databases.
///
/// [SQLSTATE]: https://en.wikipedia.org/wiki/SQLSTATE
///
pub fn sql_state(&self) -> &str {
self.0.sql_state.as_deref().unwrap_or_default()
}
}
impl DatabaseError for MySqlDatabaseError {
fn message(&self) -> &str {
&self.0.error_message
}
}
impl Display for MySqlDatabaseError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match &self.0.sql_state {
Some(state) => write!(f, "{} ({}): {}", self.0.error_code, state, self.message()),
None => write!(f, "{}: {}", self.0.error_code, self.message()),
}
}
}
impl StdError for MySqlDatabaseError {}

View File

@ -48,7 +48,7 @@ mod mock;
pub use column::MySqlColumn;
pub use connection::MySqlConnection;
pub use database::MySql;
pub use error::MySqlDatabaseError;
pub use error::{MySqlClientError, MySqlDatabaseError};
pub use options::MySqlConnectOptions;
pub use output::MySqlOutput;
pub use query_result::MySqlQueryResult;

View File

@ -1,11 +1,10 @@
use std::error::Error as StdError;
use std::fmt::Debug;
use bytes::buf::Chain;
use bytes::Bytes;
use sqlx_core::{Error, Result};
use sqlx_core::Result;
use crate::MySqlDatabaseError;
use crate::MySqlClientError;
mod caching_sha2;
mod dialog;
@ -44,11 +43,7 @@ impl dyn AuthPlugin {
_ if s == NativeAuthPlugin.name() => Ok(Box::new(NativeAuthPlugin)),
_ if s == DialogAuthPlugin.name() => Ok(Box::new(DialogAuthPlugin)),
_ => Err(MySqlDatabaseError::new(
2059,
&format!("Authentication plugin '{}' cannot be loaded", s),
)
.into()),
_ => Err(MySqlClientError::UnknownAuthPlugin(s.to_owned()).into()),
}
}
}
@ -62,18 +57,3 @@ fn xor_eq(x: &mut [u8], y: &[u8]) {
x[i] ^= y[i % y_len];
}
}
fn err_msg(plugin: &'static str, message: &str) -> Error {
MySqlDatabaseError::new(
2061,
&format!("Authentication plugin '{}' reported error: {}", plugin, message),
)
.into()
}
fn err<E>(plugin: &'static str, error: &E) -> Error
where
E: StdError,
{
err_msg(plugin, &error.to_string())
}

View File

@ -1,3 +1,7 @@
use super::rsa::encrypt as rsa_encrypt;
use super::xor_eq;
use crate::protocol::AuthPlugin;
use crate::MySqlClientError;
use bytes::buf::Chain;
use bytes::Bytes;
use sha2::{Digest, Sha256};
@ -12,7 +16,7 @@ use sqlx_core::Result;
#[derive(Debug)]
pub(crate) struct CachingSha2AuthPlugin;
impl super::AuthPlugin for CachingSha2AuthPlugin {
impl AuthPlugin for CachingSha2AuthPlugin {
fn name(&self) -> &'static str {
"caching_sha2_password"
}
@ -44,7 +48,7 @@ impl super::AuthPlugin for CachingSha2AuthPlugin {
// SHA256( nonce + SHA256( SHA256( password ) ) )
let nonce_pw_sha1_sha1 = hasher.finalize();
super::xor_eq(&mut pw_sha2, &nonce_pw_sha1_sha1);
xor_eq(&mut pw_sha2, &nonce_pw_sha1_sha1);
pw_sha2.to_vec()
}
@ -60,10 +64,11 @@ impl super::AuthPlugin for CachingSha2AuthPlugin {
const AUTH_CONTINUE: u8 = 0x4;
if command != 0x01 {
return Err(super::err_msg(
self.name(),
&format!("Received 0x{:x} but expected 0x1 (MORE DATA)", command),
));
return Err(MySqlClientError::auth_plugin(
self,
format!("received 0x{:x} but expected 0x1 (MORE DATA)", command),
)
.into());
}
match data[0] {
@ -78,7 +83,7 @@ impl super::AuthPlugin for CachingSha2AuthPlugin {
_ => {
let rsa_pub_key = data;
let encrypted = super::rsa::encrypt(self.name(), &rsa_pub_key, password, nonce)?;
let encrypted = rsa_encrypt(self, &rsa_pub_key, password, nonce)?;
Ok(Some(encrypted))
}

View File

@ -1,8 +1,7 @@
use std::borrow::Cow;
use crate::MySqlClientError;
use bytes::buf::Chain;
use bytes::Bytes;
use sqlx_core::{Error, Result};
use sqlx_core::Result;
/// Dialog authentication implementation
///
@ -27,9 +26,10 @@ impl super::AuthPlugin for DialogAuthPlugin {
_nonce: &Chain<Bytes, Bytes>,
_password: &str,
) -> Result<Option<Vec<u8>>> {
Err(super::err_msg(
self.name(),
Err(MySqlClientError::auth_plugin(
self,
"interactive dialog authentication is currently not supported",
))
)
.into())
}
}

View File

@ -1,12 +1,15 @@
use std::str::from_utf8;
use crate::protocol::auth_plugin::xor_eq;
use crate::protocol::AuthPlugin;
use crate::MySqlClientError;
use bytes::buf::Chain;
use bytes::Bytes;
use rsa::{PaddingScheme, PublicKey, RSAPublicKey};
use sqlx_core::Result;
pub(crate) fn encrypt(
plugin: &'static str,
plugin: &impl AuthPlugin,
key: &[u8],
password: &str,
nonce: &Chain<Bytes, Bytes>,
@ -20,18 +23,20 @@ pub(crate) fn encrypt(
nonce.extend_from_slice(&*a);
nonce.extend_from_slice(&*b);
super::xor_eq(&mut pass, &*nonce);
xor_eq(&mut pass, &*nonce);
// client sends an RSA encrypted password
let public = parse_rsa_pub_key(plugin, key)?;
let padding = PaddingScheme::new_oaep::<sha1::Sha1>();
public.encrypt(&mut rng(), padding, &pass[..]).map_err(|err| super::err(plugin, &err))
public
.encrypt(&mut rng(), padding, &pass[..])
.map_err(|err| MySqlClientError::auth_plugin(plugin, err).into())
}
// https://docs.rs/rsa/0.3.0/rsa/struct.RSAPublicKey.html?search=#example-1
fn parse_rsa_pub_key(plugin: &'static str, key: &[u8]) -> Result<RSAPublicKey> {
let key = from_utf8(key).map_err(|err| super::err(plugin, &err))?;
fn parse_rsa_pub_key(plugin: &impl AuthPlugin, key: &[u8]) -> Result<RSAPublicKey> {
let key = from_utf8(key).map_err(|err| MySqlClientError::auth_plugin(plugin, err))?;
// Takes advantage of the knowledge that we know
// we are receiving a PKCS#8 RSA Public Key at all
@ -43,9 +48,9 @@ fn parse_rsa_pub_key(plugin: &'static str, key: &[u8]) -> Result<RSAPublicKey> {
data
});
let der = base64::decode(&encoded).map_err(|err| super::err(plugin, &err))?;
let der = base64::decode(&encoded).map_err(|err| MySqlClientError::auth_plugin(plugin, err))?;
RSAPublicKey::from_pkcs8(&der).map_err(|err| super::err(plugin, &err))
RSAPublicKey::from_pkcs8(&der).map_err(|err| MySqlClientError::auth_plugin(plugin, err).into())
}
fn to_asciz(s: &str) -> Vec<u8> {

View File

@ -1,3 +1,6 @@
use super::rsa::encrypt as rsa_encrypt;
use crate::protocol::AuthPlugin;
use crate::MySqlClientError;
use bytes::buf::Chain;
use bytes::Bytes;
use sqlx_core::Result;
@ -13,7 +16,7 @@ use sqlx_core::Result;
#[derive(Debug)]
pub(crate) struct Sha256AuthPlugin;
impl super::AuthPlugin for Sha256AuthPlugin {
impl AuthPlugin for Sha256AuthPlugin {
fn name(&self) -> &'static str {
"sha256_password"
}
@ -36,14 +39,15 @@ impl super::AuthPlugin for Sha256AuthPlugin {
password: &str,
) -> Result<Option<Vec<u8>>> {
if command != 0x01 {
return Err(super::err_msg(
self.name(),
&format!("Received 0x{:x} but expected 0x1 (MORE DATA)", command),
));
return Err(MySqlClientError::auth_plugin(
self,
format!("Received 0x{:x} but expected 0x1 (MORE DATA)", command),
)
.into());
}
let rsa_pub_key = data;
let encrypted = super::rsa::encrypt(self.name(), &rsa_pub_key, password, nonce)?;
let encrypted = rsa_encrypt(self, &rsa_pub_key, password, nonce)?;
Ok(Some(encrypted))
}

View File

@ -5,7 +5,7 @@ use sqlx_core::io::Deserialize;
use sqlx_core::Result;
use crate::protocol::{AuthSwitch, Capabilities, ResultPacket};
use crate::MySqlDatabaseError;
use crate::{MySqlClientError, MySqlDatabaseError};
#[derive(Debug)]
pub(crate) enum AuthResponse {
@ -28,10 +28,7 @@ impl Deserialize<'_, Capabilities> for AuthResponse {
// send a command to the active auth plugin
Some(command) => Ok(Self::Command(*command, buf.slice(1..))),
None => {
Err(MySqlDatabaseError::malformed_packet("Received no bytes for auth response")
.into())
}
None => Err(MySqlClientError::EmptyPacket { context: "auth response" }.into()),
}
}
}

View File

@ -4,7 +4,7 @@ use sqlx_core::Result;
use super::{Capabilities, ResultPacket};
use crate::io::MySqlBufExt;
use crate::MySqlDatabaseError;
use crate::{MySqlClientError, MySqlDatabaseError};
/// The query-response packet is a meta-packet that starts with one of:
///
@ -46,10 +46,7 @@ impl Deserialize<'_, Capabilities> for QueryResponse {
Ok(Self::ResultSet { columns: columns as u16 })
}
None => Err(MySqlDatabaseError::malformed_packet(
"Received no bytes for COM_QUERY response",
)
.into()),
None => Err(MySqlClientError::EmptyPacket { context: "COM_QUERY response" }.into()),
}
}
}

View File

@ -4,7 +4,7 @@ use sqlx_core::Result;
use super::{Capabilities, ResultPacket};
use crate::protocol::Packet;
use crate::MySqlDatabaseError;
use crate::MySqlClientError;
/// <https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset>
/// <https://mariadb.com/kb/en/result-set-packets/>
@ -30,10 +30,10 @@ impl Deserialize<'_, Capabilities> for QueryStep {
// If its non-0, then its a Row
Some(_) => Ok(Self::Row(Packet { bytes: buf })),
None => Err(MySqlDatabaseError::malformed_packet(
"Received no bytes for the next step in a result set",
)
.into()),
None => {
Err(MySqlClientError::EmptyPacket { context: "the next step in a result set" }
.into())
}
}
}
}

View File

@ -7,7 +7,7 @@ use sqlx_core::net::Stream as NetStream;
use sqlx_core::{Result, Runtime};
use crate::protocol::{MaybeCommand, Packet, Quit};
use crate::MySqlDatabaseError;
use crate::{MySqlClientError, MySqlDatabaseError};
/// Reads and writes packets to and from the MySQL database server.
///
@ -101,11 +101,10 @@ impl<Rt: Runtime> MySqlStream<Rt> {
if packet.bytes.len() != len {
// BUG: something is very wrong somewhere if this branch is executed
// either in the SQLx MySQL driver or in the MySQL server
return Err(MySqlDatabaseError::malformed_packet(&format!(
"Received {} bytes for packet but expecting {} bytes",
packet.bytes.len(),
len
))
return Err(MySqlClientError::UnexpectedPacketSize {
expected: len,
actual: packet.bytes.len(),
}
.into());
}

View File

@ -10,7 +10,7 @@ pub enum PgClientError {
// attempting to interpret data from postgres as UTF-8, when it should
// be UTF-8, but for some reason (data corruption?) it is not
NotUtf8(Utf8Error),
UnknownAuthenticationMethod(u32),
UnknownAuthMethod(u32),
UnknownMessageType(u8),
UnknownTransactionStatus(u8),
UnknownValueFormat(i16),
@ -22,7 +22,7 @@ impl Display for PgClientError {
match self {
Self::NotUtf8(source) => write!(f, "unexpected invalid utf-8: {}", source),
Self::UnknownAuthenticationMethod(method) => {
Self::UnknownAuthMethod(method) => {
write!(f, "unknown authentication method: {}", method)
}

View File

@ -50,7 +50,7 @@ impl Deserialize<'_> for Authentication {
11 => AuthenticationSaslContinue::deserialize(buf).map(Self::SaslContinue),
12 => AuthenticationSaslFinal::deserialize(buf).map(Self::SaslFinal),
ty => Err(PgClientError::UnknownAuthenticationMethod(ty).into()),
ty => Err(PgClientError::UnknownAuthMethod(ty).into()),
}
}
}