diff --git a/sqlx-mysql/src/connection/connect.rs b/sqlx-mysql/src/connection/connect.rs index 6411a29f..53e14949 100644 --- a/sqlx-mysql/src/connection/connect.rs +++ b/sqlx-mysql/src/connection/connect.rs @@ -70,8 +70,9 @@ impl MySqlConnection { return Ok(true); } - AuthResponse::MoreData(data) => { + AuthResponse::Command(command, data) => { if let Some(data) = handshake.auth_plugin.handle( + command, data, &handshake.auth_plugin_data, options.get_password().unwrap_or_default(), @@ -168,6 +169,8 @@ mod tests { const SRV_PUBLIC_KEY: &[u8] = b"\x01-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwnXi3nr9TmN+NF49A3Y7\nUBnAVhApNJy2cmuf/y6vFM9eHFu5T80Ij1qYc6c79oAGA8nNNCFQL+0j5De88cln\nKrlzq/Ab3U+j5SqgNwk//F6Y3iyjV4L7feSDqjpcheFzkjEslbm/yoRwQ78AAU6s\nqA0hcFuh66mcvnotDrvZAGQ8U2EbbZa6oiR3wrgbzifSKq767g65zIrCpoyxzKMH\nAETSDIaMKpFio4dRATKT5ASQtPoIyxSBmjRtc22sqlhEeiejEMsJzd6Bliuait+A\nkTXL6G1Tbam26Dok/L88CnTAWAkLwTA3bjPcS8Zl9gTsJvoiMuwW1UPEVV/aJ11Z\n/wIDAQAB\n-----END PUBLIC KEY-----\n"; const SRV_AUTH_OK: &[u8] = b"\0\0\0\x02\0\0\0"; + const SRV_AUTH_ERR: &[u8] = + b"\xff\x15\x04#28000Access denied for user 'root'@'172.17.0.1' (using password: YES)"; const SRV_AUTH_MORE_CONTINUE: &[u8] = b"\x01\x04"; const SRV_AUTH_MORE_OK: &[u8] = b"\x01\x03"; const SRV_SWITCH_CACHING_SHA2_AUTH: &[u8] = @@ -397,6 +400,30 @@ mod tests { }) } + #[test] + fn should_fail_connect_err() -> anyhow::Result<()> { + block_on(async { + let mut mock = Mock::stream(); + + mock.write_packet_async(0, SRV_HANDSHAKE_DEFAULT_NATIVE_AUTH).await?; + mock.write_packet_async(2, SRV_AUTH_ERR).await?; + + let err = MySqlConnectOptions::new() + .port(mock.port()) + .username("root") + .connect::, _>() + .await + .unwrap_err(); + + assert_eq!( + err.to_string(), + "1045 (28000): Access denied for user \'root\'@\'172.17.0.1\' (using password: YES)" + ); + + Ok(()) + }) + } + #[test] fn should_not_connect_old_auth() -> anyhow::Result<()> { block_on(async { diff --git a/sqlx-mysql/src/protocol/auth_plugin.rs b/sqlx-mysql/src/protocol/auth_plugin.rs index 498712f9..5c8bfd16 100644 --- a/sqlx-mysql/src/protocol/auth_plugin.rs +++ b/sqlx-mysql/src/protocol/auth_plugin.rs @@ -29,6 +29,7 @@ pub(crate) trait AuthPlugin: 'static + Debug + Send + Sync { // if the plugin returns Some(_) that is sent back to MySQL fn handle( &self, + command: u8, data: Bytes, nonce: &Chain, password: &str, diff --git a/sqlx-mysql/src/protocol/auth_plugin/caching_sha2.rs b/sqlx-mysql/src/protocol/auth_plugin/caching_sha2.rs index 4932e2d6..6a754384 100644 --- a/sqlx-mysql/src/protocol/auth_plugin/caching_sha2.rs +++ b/sqlx-mysql/src/protocol/auth_plugin/caching_sha2.rs @@ -51,6 +51,7 @@ impl super::AuthPlugin for CachingSha2AuthPlugin { fn handle( &self, + command: u8, data: Bytes, nonce: &Chain, password: &str, @@ -58,6 +59,13 @@ impl super::AuthPlugin for CachingSha2AuthPlugin { const AUTH_SUCCESS: u8 = 0x3; const AUTH_CONTINUE: u8 = 0x4; + if command != 0x01 { + return Err(super::err_msg( + self.name(), + &format!("Received 0x{:x} but expected 0x1 (MORE DATA)", command), + )); + } + match data[0] { // good to go, return nothing AUTH_SUCCESS => Ok(None), diff --git a/sqlx-mysql/src/protocol/auth_plugin/dialog.rs b/sqlx-mysql/src/protocol/auth_plugin/dialog.rs index f38a160c..d1edd375 100644 --- a/sqlx-mysql/src/protocol/auth_plugin/dialog.rs +++ b/sqlx-mysql/src/protocol/auth_plugin/dialog.rs @@ -1,7 +1,8 @@ +use std::borrow::Cow; + use bytes::buf::Chain; use bytes::Bytes; use sqlx_core::{Error, Result}; -use std::borrow::Cow; /// Dialog authentication implementation /// @@ -21,6 +22,7 @@ impl super::AuthPlugin for DialogAuthPlugin { fn handle( &self, + _command: u8, _data: Bytes, _nonce: &Chain, _password: &str, diff --git a/sqlx-mysql/src/protocol/auth_plugin/native.rs b/sqlx-mysql/src/protocol/auth_plugin/native.rs index 83c215fa..34c9a9e0 100644 --- a/sqlx-mysql/src/protocol/auth_plugin/native.rs +++ b/sqlx-mysql/src/protocol/auth_plugin/native.rs @@ -49,6 +49,7 @@ impl super::AuthPlugin for NativeAuthPlugin { fn handle( &self, + _command: u8, _data: Bytes, _nonce: &Chain, _password: &str, diff --git a/sqlx-mysql/src/protocol/auth_plugin/sha256.rs b/sqlx-mysql/src/protocol/auth_plugin/sha256.rs index 4696fcc3..cc92d1f7 100644 --- a/sqlx-mysql/src/protocol/auth_plugin/sha256.rs +++ b/sqlx-mysql/src/protocol/auth_plugin/sha256.rs @@ -30,10 +30,18 @@ impl super::AuthPlugin for Sha256AuthPlugin { fn handle( &self, + command: u8, data: Bytes, nonce: &Chain, password: &str, ) -> Result>> { + if command != 0x01 { + return Err(super::err_msg( + self.name(), + &format!("Received 0x{:x} but expected 0x1 (MORE DATA)", command), + )); + } + let rsa_pub_key = data; let encrypted = super::rsa::encrypt(self.name(), &rsa_pub_key, password, nonce)?; diff --git a/sqlx-mysql/src/protocol/auth_response.rs b/sqlx-mysql/src/protocol/auth_response.rs index c89021c5..8a9fed62 100644 --- a/sqlx-mysql/src/protocol/auth_response.rs +++ b/sqlx-mysql/src/protocol/auth_response.rs @@ -10,25 +10,28 @@ use crate::MySqlDatabaseError; #[derive(Debug)] pub(crate) enum AuthResponse { End(ResultPacket), - MoreData(Bytes), + Command(u8, Bytes), Switch(AuthSwitch), } impl Deserialize<'_, Capabilities> for AuthResponse { fn deserialize_with(buf: Bytes, capabilities: Capabilities) -> Result { match buf.get(0) { - Some(0x00) => ResultPacket::deserialize_with(buf, capabilities).map(Self::End), - Some(0x01) => Ok(Self::MoreData(buf.slice(1..))), + // OK or ERR -> end the auth cycle + Some(0x00) | Some(0xff) => { + ResultPacket::deserialize_with(buf, capabilities).map(Self::End) + } + + // switch to another auth plugin Some(0xfe) => AuthSwitch::deserialize(buf).map(Self::Switch), - Some(tag) => Err(MySqlDatabaseError::malformed_packet(&format!( - "Received 0x{:x} but expected one of: 0x0 (OK), 0x1 (MORE DATA), or 0xfe (SWITCH) for auth response", - tag - )).into()), + // send a command to the active auth plugin + Some(command) => Ok(Self::Command(*command, buf.slice(1..))), - None => Err(MySqlDatabaseError::malformed_packet( - "Received no bytes for auth response", - ).into()), + None => { + Err(MySqlDatabaseError::malformed_packet("Received no bytes for auth response") + .into()) + } } } }