From 255755793592ca98ac18fb72b973fb917985d4a0 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Fri, 8 Jan 2021 15:28:26 -0800 Subject: [PATCH] feat(mysql): impl full connect phase for MySQL with support for: - mysql_native_password - caching_sha2_password - sha256_password - non-default auth plugin (new) --- sqlx-mysql/Cargo.toml | 12 +- sqlx-mysql/src/auth.rs | 13 - sqlx-mysql/src/auth/native.rs | 35 -- sqlx-mysql/src/blocking.rs | 2 - sqlx-mysql/src/blocking/connection.rs | 17 - sqlx-mysql/src/blocking/options.rs | 17 - sqlx-mysql/src/connection.rs | 28 +- sqlx-mysql/src/connection/connect.rs | 406 ++++++++++++++++++ sqlx-mysql/src/connection/establish.rs | 139 ------ sqlx-mysql/src/connection/stream.rs | 154 +++++++ sqlx-mysql/src/error.rs | 6 + sqlx-mysql/src/io/buf.rs | 10 +- sqlx-mysql/src/lib.rs | 10 +- sqlx-mysql/src/mock.rs | 67 +++ sqlx-mysql/src/options.rs | 18 +- sqlx-mysql/src/options/parse.rs | 2 +- sqlx-mysql/src/protocol.rs | 12 +- sqlx-mysql/src/protocol/auth.rs | 46 ++ sqlx-mysql/src/protocol/auth_plugin.rs | 76 ++++ .../src/protocol/auth_plugin/caching_sha2.rs | 79 ++++ sqlx-mysql/src/protocol/auth_plugin/native.rs | 60 +++ sqlx-mysql/src/protocol/auth_plugin/rsa.rs | 70 +++ sqlx-mysql/src/protocol/auth_plugin/sha256.rs | 42 ++ sqlx-mysql/src/protocol/auth_switch.rs | 39 ++ sqlx-mysql/src/protocol/err.rs | 17 +- sqlx-mysql/src/protocol/handshake.rs | 62 +-- sqlx-mysql/src/protocol/handshake_response.rs | 9 +- sqlx-mysql/src/protocol/ok.rs | 2 +- 28 files changed, 1160 insertions(+), 290 deletions(-) delete mode 100644 sqlx-mysql/src/auth.rs delete mode 100644 sqlx-mysql/src/auth/native.rs delete mode 100644 sqlx-mysql/src/blocking.rs delete mode 100644 sqlx-mysql/src/blocking/connection.rs delete mode 100644 sqlx-mysql/src/blocking/options.rs create mode 100644 sqlx-mysql/src/connection/connect.rs delete mode 100644 sqlx-mysql/src/connection/establish.rs create mode 100644 sqlx-mysql/src/connection/stream.rs create mode 100644 sqlx-mysql/src/mock.rs create mode 100644 sqlx-mysql/src/protocol/auth.rs create mode 100644 sqlx-mysql/src/protocol/auth_plugin.rs create mode 100644 sqlx-mysql/src/protocol/auth_plugin/caching_sha2.rs create mode 100644 sqlx-mysql/src/protocol/auth_plugin/native.rs create mode 100644 sqlx-mysql/src/protocol/auth_plugin/rsa.rs create mode 100644 sqlx-mysql/src/protocol/auth_plugin/sha256.rs create mode 100644 sqlx-mysql/src/protocol/auth_switch.rs diff --git a/sqlx-mysql/Cargo.toml b/sqlx-mysql/Cargo.toml index 28c94e3b..a675c569 100644 --- a/sqlx-mysql/Cargo.toml +++ b/sqlx-mysql/Cargo.toml @@ -30,11 +30,21 @@ async = ["futures-util", "sqlx-core/async", "futures-io"] sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core" } futures-util = { version = "0.3.8", optional = true } either = "1.6.1" +log = "0.4.11" +bytestring = "1.0.0" url = "2.2.0" percent-encoding = "2.1.0" futures-io = { version = "0.3", optional = true } bytes = "1.0" memchr = "2.3" bitflags = "1.2" -string = { version = "0.2.1", default-features = false } sha-1 = "0.9.2" +sha2 = "0.9.2" +rsa = "0.3.0" +base64 = "0.13.0" +rand = "0.7" + +[dev-dependencies] +sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core", features = ["_mock"] } +futures-executor = "0.3.8" +anyhow = "1.0.37" diff --git a/sqlx-mysql/src/auth.rs b/sqlx-mysql/src/auth.rs deleted file mode 100644 index 7ddbbc6d..00000000 --- a/sqlx-mysql/src/auth.rs +++ /dev/null @@ -1,13 +0,0 @@ -pub(crate) mod native; -// mod caching_sha2; -// mod sha256; - -// XOR(x, y) -// If len(y) < len(x), wrap around inside y -fn xor_eq(x: &mut [u8], y: &[u8]) { - let y_len = y.len(); - - for i in 0..x.len() { - x[i] ^= y[i % y_len]; - } -} diff --git a/sqlx-mysql/src/auth/native.rs b/sqlx-mysql/src/auth/native.rs deleted file mode 100644 index 220d8aae..00000000 --- a/sqlx-mysql/src/auth/native.rs +++ /dev/null @@ -1,35 +0,0 @@ -use bytes::{buf::Chain, Bytes}; -use sha1::{Digest, Sha1}; - -use super::xor_eq; - -// https://mariadb.com/kb/en/connection/#mysql_native_password-plugin -// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html - -pub(crate) fn scramble(nonce: &Chain, password: &str) -> Vec { - // SHA1( password ) ^ SHA1( nonce + SHA1( SHA1( password ) ) ) - - let mut hasher = Sha1::new(); - - hasher.update(password); - - // SHA1( password ) - let mut pw_sha1 = hasher.finalize_reset(); - - hasher.update(&pw_sha1); - - // SHA1( SHA1( password ) ) - let pw_sha1_sha1 = hasher.finalize_reset(); - - // NOTE: use the first 20 bytes of the nonce, we MAY have gotten a nul terminator - hasher.update(nonce.first_ref()); - hasher.update(&nonce.last_ref()[..20 - nonce.first_ref().len()]); - hasher.update(&pw_sha1_sha1); - - // SHA1( seed + SHA1( SHA1( password ) ) ) - let nonce_pw_sha1_sha1 = hasher.finalize(); - - xor_eq(&mut pw_sha1, &nonce_pw_sha1_sha1); - - pw_sha1.to_vec() -} diff --git a/sqlx-mysql/src/blocking.rs b/sqlx-mysql/src/blocking.rs deleted file mode 100644 index 217280ea..00000000 --- a/sqlx-mysql/src/blocking.rs +++ /dev/null @@ -1,2 +0,0 @@ -mod connection; -mod options; diff --git a/sqlx-mysql/src/blocking/connection.rs b/sqlx-mysql/src/blocking/connection.rs deleted file mode 100644 index 199f9343..00000000 --- a/sqlx-mysql/src/blocking/connection.rs +++ /dev/null @@ -1,17 +0,0 @@ -use sqlx_core::blocking::{Connection, Runtime}; -use sqlx_core::Result; - -use crate::MySqlConnection; - -impl Connection for MySqlConnection -where - Rt: Runtime, -{ - fn close(self) -> Result<()> { - unimplemented!() - } - - fn ping(&mut self) -> Result<()> { - unimplemented!() - } -} diff --git a/sqlx-mysql/src/blocking/options.rs b/sqlx-mysql/src/blocking/options.rs deleted file mode 100644 index 3f8f4177..00000000 --- a/sqlx-mysql/src/blocking/options.rs +++ /dev/null @@ -1,17 +0,0 @@ -use sqlx_core::blocking::{ConnectOptions, Connection, Runtime}; -use sqlx_core::Result; - -use crate::{MySqlConnectOptions, MySqlConnection}; - -impl ConnectOptions for MySqlConnectOptions -where - Rt: Runtime, - Self::Connection: sqlx_core::Connection + Connection, -{ - fn connect(&self) -> Result> { - // let stream = ::connect_tcp(self.get_host(), self.get_port())?; - // - // Ok(MySqlConnection { stream }) - todo!() - } -} diff --git a/sqlx-mysql/src/connection.rs b/sqlx-mysql/src/connection.rs index 37d58cf3..5380c022 100644 --- a/sqlx-mysql/src/connection.rs +++ b/sqlx-mysql/src/connection.rs @@ -6,16 +6,24 @@ use sqlx_core::{Connection, DefaultRuntime, Runtime}; use crate::protocol::Capabilities; use crate::{MySql, MySqlConnectOptions}; -#[cfg(feature = "async")] -pub(crate) mod establish; +#[cfg(any(feature = "async", feature = "blocking"))] +mod connect; +#[cfg(any(feature = "async", feature = "blocking"))] +mod stream; + +#[allow(clippy::module_name_repetitions)] pub struct MySqlConnection where Rt: Runtime, { stream: BufStream, connection_id: u32, + + // the capability flags are used by the client and server to indicate which + // features they support and want to use. capabilities: Capabilities, + // the sequence-id is incremented with each packet and may wrap around. It starts at 0 and is // reset to 0 when a new command begins in the Command Phase. sequence_id: u8, @@ -25,6 +33,7 @@ impl MySqlConnection where Rt: Runtime, { + #[cfg(any(feature = "async", feature = "blocking"))] pub(crate) fn new(stream: Rt::TcpStream) -> Self { Self { stream: BufStream::with_capacity(stream, 4096, 1024), @@ -82,3 +91,18 @@ where unimplemented!() } } + +#[cfg(feature = "blocking")] +impl sqlx_core::blocking::Connection for MySqlConnection +where + Rt: sqlx_core::blocking::Runtime, + ::TcpStream: std::io::Read + std::io::Write, +{ + fn close(self) -> sqlx_core::Result<()> { + unimplemented!() + } + + fn ping(&mut self) -> sqlx_core::Result<()> { + unimplemented!() + } +} diff --git a/sqlx-mysql/src/connection/connect.rs b/sqlx-mysql/src/connection/connect.rs new file mode 100644 index 00000000..34c135b7 --- /dev/null +++ b/sqlx-mysql/src/connection/connect.rs @@ -0,0 +1,406 @@ +//! Implements the connection phase. +//! +//! The connection phase (establish) performs these tasks: +//! +//! - exchange the capabilities of client and server +//! - setup SSL communication channel if requested +//! - authenticate the client against the server +//! +//! The server may immediately send an ERR packet and finish the handshake +//! or send a `Handshake`. +//! +//! https://dev.mysql.com/doc/internals/en/connection-phase.html +//! +use sqlx_core::{Result, Runtime}; + +use crate::protocol::{Auth, AuthResponse, Handshake, HandshakeResponse}; +use crate::{MySqlConnectOptions, MySqlConnection}; + +macro_rules! connect { + (@blocking @tcp $options:ident) => { + Rt::connect_tcp($options.get_host(), $options.get_port())?; + }; + + (@tcp $options:ident) => { + Rt::connect_tcp($options.get_host(), $options.get_port()).await?; + }; + + (@blocking @packet $self:ident) => { + $self.read_packet()?; + }; + + (@packet $self:ident) => { + $self.read_packet_async().await?; + }; + + ($(@$blocking:ident)? $options:ident) => {{ + // open a network stream to the database server + let stream = connect!($(@$blocking)? @tcp $options); + + // construct a around the network stream + // wraps the stream in a to buffer read and write + let mut self_ = Self::new(stream); + + // immediately the server should emit a packet + let handshake: Handshake = connect!($(@$blocking)? @packet self_); + + // & the declared server capabilities with our capabilities to find + // what rules the client should operate under + self_.capabilities &= handshake.capabilities; + + // store the connection ID, mainly for debugging + self_.connection_id = handshake.connection_id; + + // extract the auth plugin and data from the handshake + // this can get overwritten by an auth switch + let mut auth_plugin = handshake.auth_plugin; + let mut auth_plugin_data = handshake.auth_plugin_data; + let password = $options.get_password().unwrap_or_default(); + + // create the initial auth response + // this may just be a request for an RSA public key + let initial_auth_response = auth_plugin.invoke(&auth_plugin_data, password); + + // the contains an initial guess at the correct encoding of + // the password and some other metadata like "which database", "which user", etc. + self_.write_packet(&HandshakeResponse { + auth_plugin_name: auth_plugin.name(), + auth_response: initial_auth_response, + charset: 45, // [utf8mb4] + database: $options.get_database(), + max_packet_size: 1024, + username: $options.get_username(), + })?; + + loop { + match connect!($(@$blocking)? @packet self_) { + Auth::Ok(_) => { + // successful, simple authentication; good to go + break; + } + + Auth::MoreData(data) => { + if let Some(data) = auth_plugin.handle(data, &auth_plugin_data, password)? { + // write the response from the plugin + self_.write_packet(&AuthResponse { data })?; + + // let's try again + continue; + } + + // all done, the plugin says we check out + break; + } + + Auth::Switch(sw) => { + // switch to the new plugin + auth_plugin = sw.plugin; + auth_plugin_data = sw.plugin_data; + + // generate an initial response from this plugin + let data = auth_plugin.invoke(&auth_plugin_data, password); + + // write the response from the plugin + self_.write_packet(&AuthResponse { data })?; + + // let's try again + continue; + } + } + } + + Ok(self_) + }}; +} + +#[cfg(feature = "async")] +impl MySqlConnection +where + Rt: sqlx_core::AsyncRuntime, + ::TcpStream: Unpin + futures_io::AsyncWrite + futures_io::AsyncRead, +{ + pub(crate) async fn connect_async(options: &MySqlConnectOptions) -> Result { + connect!(options) + } +} + +#[cfg(feature = "blocking")] +impl MySqlConnection +where + Rt: sqlx_core::blocking::Runtime, + ::TcpStream: std::io::Write + std::io::Read, +{ + pub(crate) fn connect(options: &MySqlConnectOptions) -> Result { + connect!(@blocking options) + } +} + +#[cfg(all(test, feature = "async"))] +mod tests { + use futures_executor::block_on; + use sqlx_core::{ConnectOptions, Mock}; + + use crate::mock::MySqlMockStreamExt; + use crate::MySqlConnectOptions; + + const SRV_HANDSHAKE_DEFAULT_OLD_AUTH: &[u8] = b"\n5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal\0)\0\0\04bo+$r4H\0\xfe\xf7-\x02\0\xff\x81\x15\0\0\0\0\0\0\x0f\0\0\0O5X>j}Ur]Y)^\0mysql_old_password\0"; + const SRV_HANDSHAKE_DEFAULT_NATIVE_AUTH: &[u8] = b"\n5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal\0)\0\0\04bo+$r4H\0\xfe\xf7-\x02\0\xff\x81\x15\0\0\0\0\0\0\x0f\0\0\0O5X>j}Ur]Y)^\0mysql_native_password\0"; + const SRV_HANDSHAKE_DEFAULT_CACHING_SHA2_AUTH: &[u8] = b"\n8.0.22\0\x08\0\0\0TIbl}%U#\0\xff\xff\xff\x02\0\xff\xc7\x15\0\0\0\0\0\0\0\0\0\0\x06\x12\x0e`5\x1b\x12\x0b\x13\x06_\x19\0caching_sha2_password\0"; + const SRV_HANDSHAKE_DEFAULT_SHA256_AUTH: &[u8] = b"\n8.0.22\0\x0e\0\0\0\x1b\x02O\x04hL8D\0\xff\xff\xff\x02\0\xff\xc7\x15\0\0\0\0\0\0\0\0\0\0^*Nh\x19\x1f*)-\x0c\x07v\0sha256_password\0"; + + const SRV_PUBLIC_KEY: &[u8] = b"\x01-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwnXi3nr9TmN+NF49A3Y7\nUBnAVhApNJy2cmuf/y6vFM9eHFu5T80Ij1qYc6c79oAGA8nNNCFQL+0j5De88cln\nKrlzq/Ab3U+j5SqgNwk//F6Y3iyjV4L7feSDqjpcheFzkjEslbm/yoRwQ78AAU6s\nqA0hcFuh66mcvnotDrvZAGQ8U2EbbZa6oiR3wrgbzifSKq767g65zIrCpoyxzKMH\nAETSDIaMKpFio4dRATKT5ASQtPoIyxSBmjRtc22sqlhEeiejEMsJzd6Bliuait+A\nkTXL6G1Tbam26Dok/L88CnTAWAkLwTA3bjPcS8Zl9gTsJvoiMuwW1UPEVV/aJ11Z\n/wIDAQAB\n-----END PUBLIC KEY-----\n"; + const SRV_AUTH_OK: &[u8] = b"\0\0\0\x02\0\0\0"; + const SRV_AUTH_MORE_CONTINUE: &[u8] = b"\x01\x04"; + const SRV_AUTH_MORE_OK: &[u8] = b"\x01\x03"; + const SRV_SWITCH_CACHING_SHA2_AUTH: &[u8] = + b"\xfecaching_sha2_password\0\x12}Wz?0-M9sO*S\x03\nP\x1c]pe\0"; + const SRV_SWITCH_NATIVE_AUTH: &[u8] = + b"\xfemysql_native_password\0\r.89j]CpA3Ov~\x1de\\/\x15,\r\0"; + + const RES_HANDSHAKE_NATIVE_AUTH: &[u8] = b"P\0\0\x01\x04\xa3(\x01\0\x04\0\0-\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0root\0\x14P\xaf\xf1\x12,\xe9\xad\xea\x7f\xa0\n\xcd\xa2\xb5<\x17\xa5\xc9J\xd0mysql_native_password\0"; + const RES_HANDSHAKE_EMPTY_AUTH: &[u8] = b"<\0\0\x01\x04\xa3(\x01\0\x04\0\0-\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0root\0\0mysql_native_password\0"; + const RES_HANDSHAKE_CACHING_SHA2_AUTH: &[u8] = b"\\\0\0\x01\x05\xa3(\x01\0\x04\0\0-\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0root\0 \x9d\x85T\x15\xfe\xa9u\x13\x02&\x9dlG\x17\x98\x1b`\x8a\x96\xfcI\x19\x17\xe0(I8\xba\xd7\xfax\xa9caching_sha2_password\0"; + const RES_HANDSHAKE_SHA256_AUTH: &[u8] = b"7\0\0\x01\x05\xa3(\x01\0\x04\0\0-\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0root\0\x01\x01sha256_password\0"; + + const RES_ASK_RSA_KEY: &[u8] = b"\x01\0\0\x03\x02"; + const RES_ASK_RSA_KEY_2: &[u8] = b"\x01\0\0\x05\x02"; + const RES_RSA_PASSWORD_SHA256: &[u8] = b"\0\x01\0\x03\xc1*\xf5=\xc3\x86\x95U$=\x9c \x946_Rg\xdc\x9d\xa0M\xf2@\xba\xf7\x8f\rE\xdbrI\xac\x05\xfb\xd1\xaa\r0 '\xf2\xec\xb3Xu\x98\x82\xf2\x8d)\x80\xe7\xdcG\\\xde\x87\x0e\x07\x87f\xach\xbb\x0b\xdf\xe0\xd9\xd1N\x9f_\x17xT\xec\xd5\xff\xd3\xa35\x11PO\xca\xf2\x13?=n\xe7\xd5\xbb\xa0\xd0\xca\xc5\x80\xb0\0\xc0\xe9F\x90f\xa0a\xd1\xdb\xe4(\xed2\xd7@\xb8u\x859U\xd6\xa2\xc3\xa2\xbe\x9a\xeeSy\x92\x95\r\xd3\x14\x90\x80\xb1o#\xa6\x7f\x16\x7f\t-'\xf35\xa02zY\xaeP^e\xf9O\xed\x9d\xb5\x8b\x9d\x0cayA\xff\"-\x80\x8c<\xc4\x11e\xdf\x9c\xe2\x9b)\x8f\xb0\xe9\xe1\xbcj\xf9\xa0U\xe6\x95\x9b\x01 \xba\x7f\"\\\x0cF9\\'\xf2\xfcMD\x1a\xd8\xe3\x11\xdfN\xc4\xd3\x9e\xee\x8d\r\xda\x94\xc4\xafR\xf3\x1e8b\x8d$\x84Nj\x18~\xa7\xf1\x8bb&\x90\xc0\xad\xb1O\xec\xfa\x98h\xf0{.\x07R\n"; + const RES_RSA_PASSWORD_CACHING_SHA2: &[u8] = b"\0\x01\0\x05#7\x8f\xd6\x8dCi9*\xee\x87\xb3\xb1,@\xdf\x94\xa8g\xbf\xed5\xf3\x1e\x9c\xfe\xda\xe8-6\x9c\x1eO\xb6\x80\x81]h\x0b\xd8\x10xx\xeb\x8b\xe9\x8a\x93\xd7\x83\xf7\x9a\xe1\xb94\xfd\xb0\x81\xeb\x0f\xecU:\xf4\x82\x11\xd3\xee\x8e+\x9e_rm\xb4\xbdM\xa0\x90\xff\xc3\x03V*\xa6|\x16\xdd\xea\xd2\x92\xef\xf5E\xb1t\n\xb7\xd9\x8bU\xbd\x94\xb8\x80|S+z\x1bO\x1e\xdf&\xf7(\xf0~\x97\x8b\xee1\xa4\xbb\x9f6\xc4\x88\xbf\x14$\xb2\xc0\xea\x9f\xdd\xfc\x99\xc8\xfe\x178\xf3X\x90\x01\xcc\xa8\x86\x9d\xe9\x98\xbf\xc2\xdc\xe8\xff\x96\xbd^\xf6\r \xb5\xe8\x0euo\xb5(\x80\xffW7\xf0\xdd\xcc\xaa\x9fYl\xef\xb7y\xf7A\xf4\xcf\x1f\xfc\rS\x7f\x13\xa9b\xadd\x1c\xcf\xf5\x98\x0ei\xc3\x0f\x9c\x8eqeTu\x8b\x17\xe7\xd47\xc5\xe9j=\xfc\x82\x04\x96}V.U?\x85\x14J\xe2\xd3.+:\xc5\xe0'm\x9a3\x85\x1e\xf7\xad\xf9J\xcf\xfc\xa7\xc2\x04@"; + const RES_SHA_SCRAMBLE: &[u8] = b" \0\0\x03\xffjg\x06p\x1d\xeawto\xf3\xf6\xa0\x9f7\xa9Z\xb3\xa5\xf9\x0b\x80\x14j8WTb\xf1{f\xf5"; + const RES_NATIVE_SCRAMBLE: &[u8] = + b"\x14\0\0\x031.Z\x95JON\x81\x9ak\xc7\xba\xe6{L\x0f\xe8\x03N\xef"; + const RES_SWITCH_RSA_PASSWORD_CACHING_SHA2: &[u8] = b"\0\x01\0\x077fS:\x9d3\xec\xe47\xbe\xda\xd8a\x14\x7f\xa8\xa82\x15\xb3\xb8\xa4D\x8f\x8e,,\xc4\x7f\x9ck\x9cI2&\xc2a\xd4\xef\r\x04\xc2\xd1\x89\xb05\xab\xe2YL\xd2hz\xf6y\xb7\xcb\x08\x9a\x1d\xc0A\x7f\x97\xba*\x1e,c\xbcP\xab\xa2\xee\xfa\xcd^=\x1flj\x96\x8fGx\x8e\x9b\xfd\xea\xd05w\xcc\xf2\xfc\xf8\xb4Pm;\xc4\x94}A~=R\xbcr\xbb?\xd1]\r\xb1\xd9{\xf6\x1b%\x14iAe\x04a\x91\x144q\x1e\x92H\xcb\xe7z,+1!6#\x92\x8c\x12o\x8eyb\xe7g\xd2[\x11W\xfeJ\xe3.\x88C\x1a$\xa5\xfa\xfd\xe1\x1e\x0c4\xc5\xbf7\x94\xca$\x0c\xa6\xbc\x07d\x04\x0f\xe4\xfc\xbeZ\x1c7\xce\x0c^8@d; \xf9\xfe\x1dU\x15\x9e\x9f[b\xe6Z\xda\xa9\x17\xcf\xd9\xa8\x0b\x10\xf5\xe3\xa1\xc0\xe2Z\x8b\x9fq\xe9\xe8\x97f\x1bY\xec\xbc\x8b\x89\x9a\xeb\xffU\xe2\xfa#%\xa5d\xfa\xeb\x15\"\x8a\xf4R\x85\xdf\xe3\xcd"; + + #[test] + fn should_connect_default_native_auth() -> anyhow::Result<()> { + block_on(async { + let mut mock = Mock::stream(); + + mock.write_packet_async(0, SRV_HANDSHAKE_DEFAULT_NATIVE_AUTH).await?; + mock.write_packet_async(2, SRV_AUTH_OK).await?; + + let _conn = MySqlConnectOptions::::new() + .port(mock.port()) + .username("root") + .password("password") + .connect() + .await?; + + let buf = mock.read_all_async().await?; + + assert_eq!(&buf, RES_HANDSHAKE_NATIVE_AUTH); + + Ok(()) + }) + } + + #[test] + fn should_connect_default_sha256_auth() -> anyhow::Result<()> { + block_on(async { + let mut mock = Mock::stream(); + + mock.write_packet_async(0, SRV_HANDSHAKE_DEFAULT_SHA256_AUTH).await?; + mock.write_packet_async(2, SRV_PUBLIC_KEY).await?; + mock.write_packet_async(4, SRV_AUTH_OK).await?; + + let _conn = MySqlConnectOptions::::new() + .port(mock.port()) + .username("root") + .password("password") + .connect() + .await?; + + let buf = mock.read_exact_async(RES_HANDSHAKE_SHA256_AUTH.len()).await?; + assert_eq!(&buf, RES_HANDSHAKE_SHA256_AUTH); + + let buf = mock.read_all_async().await?; + assert_eq!(&buf, RES_RSA_PASSWORD_SHA256); + + Ok(()) + }) + } + + #[test] + fn should_connect_default_caching_sha2_auth() -> anyhow::Result<()> { + block_on(async { + let mut mock = Mock::stream(); + + mock.write_packet_async(0, SRV_HANDSHAKE_DEFAULT_CACHING_SHA2_AUTH).await?; + mock.write_packet_async(2, SRV_AUTH_MORE_CONTINUE).await?; + mock.write_packet_async(4, SRV_PUBLIC_KEY).await?; + mock.write_packet_async(6, SRV_AUTH_OK).await?; + + let _conn = MySqlConnectOptions::::new() + .port(mock.port()) + .username("root") + .password("password") + .connect() + .await?; + + let buf = mock.read_exact_async(RES_HANDSHAKE_CACHING_SHA2_AUTH.len()).await?; + assert_eq!(&buf, RES_HANDSHAKE_CACHING_SHA2_AUTH); + + let buf = mock.read_exact_async(RES_ASK_RSA_KEY.len()).await?; + assert_eq!(&buf, RES_ASK_RSA_KEY); + + let buf = mock.read_all_async().await?; + assert_eq!(&buf, RES_RSA_PASSWORD_CACHING_SHA2); + + Ok(()) + }) + } + + #[test] + fn should_reconnect_default_caching_sha2_auth() -> anyhow::Result<()> { + block_on(async { + let mut mock = Mock::stream(); + + mock.write_packet_async(0, SRV_HANDSHAKE_DEFAULT_CACHING_SHA2_AUTH).await?; + mock.write_packet_async(2, SRV_AUTH_MORE_OK).await?; + + let _conn = MySqlConnectOptions::::new() + .port(mock.port()) + .username("root") + .password("password") + .connect() + .await?; + + let buf = mock.read_all_async().await?; + assert_eq!(&buf, RES_HANDSHAKE_CACHING_SHA2_AUTH); + + Ok(()) + }) + } + + #[test] + fn should_connect_switch_native_auth() -> anyhow::Result<()> { + block_on(async { + let mut mock = Mock::stream(); + + mock.write_packet_async(0, SRV_HANDSHAKE_DEFAULT_CACHING_SHA2_AUTH).await?; + mock.write_packet_async(2, SRV_SWITCH_NATIVE_AUTH).await?; + mock.write_packet_async(4, SRV_AUTH_OK).await?; + + let _conn = MySqlConnectOptions::::new() + .port(mock.port()) + .username("root") + .password("password") + .connect() + .await?; + + let buf = mock.read_exact_async(RES_HANDSHAKE_CACHING_SHA2_AUTH.len()).await?; + assert_eq!(&buf, RES_HANDSHAKE_CACHING_SHA2_AUTH); + + let buf = mock.read_all_async().await?; + assert_eq!(&buf, RES_NATIVE_SCRAMBLE); + + Ok(()) + }) + } + + #[test] + fn should_connect_switch_caching_sha2_auth() -> anyhow::Result<()> { + block_on(async { + let mut mock = Mock::stream(); + + mock.write_packet_async(0, SRV_HANDSHAKE_DEFAULT_NATIVE_AUTH).await?; + mock.write_packet_async(2, SRV_SWITCH_CACHING_SHA2_AUTH).await?; + mock.write_packet_async(4, SRV_AUTH_MORE_CONTINUE).await?; + mock.write_packet_async(6, SRV_PUBLIC_KEY).await?; + mock.write_packet_async(8, SRV_AUTH_OK).await?; + + let _conn = MySqlConnectOptions::::new() + .port(mock.port()) + .username("root") + .password("password") + .connect() + .await?; + + let buf = mock.read_exact_async(RES_HANDSHAKE_NATIVE_AUTH.len()).await?; + assert_eq!(&buf, RES_HANDSHAKE_NATIVE_AUTH); + + let buf = mock.read_exact_async(RES_SHA_SCRAMBLE.len()).await?; + assert_eq!(&buf, RES_SHA_SCRAMBLE); + + let buf = mock.read_exact_async(RES_ASK_RSA_KEY_2.len()).await?; + assert_eq!(&buf, RES_ASK_RSA_KEY_2); + + let buf = mock.read_all_async().await?; + assert_eq!(&buf, RES_SWITCH_RSA_PASSWORD_CACHING_SHA2); + + Ok(()) + }) + } + + #[test] + fn should_reconnect_switch_caching_sha2_auth() -> anyhow::Result<()> { + block_on(async { + let mut mock = Mock::stream(); + + mock.write_packet_async(0, SRV_HANDSHAKE_DEFAULT_NATIVE_AUTH).await?; + mock.write_packet_async(2, SRV_SWITCH_CACHING_SHA2_AUTH).await?; + mock.write_packet_async(4, SRV_AUTH_MORE_OK).await?; + + let _conn = MySqlConnectOptions::::new() + .port(mock.port()) + .username("root") + .password("password") + .connect() + .await?; + + let buf = mock.read_exact_async(RES_HANDSHAKE_NATIVE_AUTH.len()).await?; + assert_eq!(&buf, RES_HANDSHAKE_NATIVE_AUTH); + + let buf = mock.read_all_async().await?; + assert_eq!(&buf, RES_SHA_SCRAMBLE); + + Ok(()) + }) + } + + #[test] + fn should_connect_empty_auth() -> anyhow::Result<()> { + block_on(async { + let mut mock = Mock::stream(); + + mock.write_packet_async(0, SRV_HANDSHAKE_DEFAULT_NATIVE_AUTH).await?; + mock.write_packet_async(2, SRV_AUTH_OK).await?; + + let _conn = MySqlConnectOptions::::new() + .port(mock.port()) + .username("root") + .connect() + .await?; + + let buf = mock.read_all_async().await?; + + assert_eq!(&buf, RES_HANDSHAKE_EMPTY_AUTH); + + Ok(()) + }) + } + + #[test] + fn should_not_connect_old_auth() -> anyhow::Result<()> { + block_on(async { + let mut mock = Mock::stream(); + + mock.write_packet_async(0, SRV_HANDSHAKE_DEFAULT_OLD_AUTH).await?; + + let err = MySqlConnectOptions::::new() + .port(mock.port()) + .username("root") + .password("password") + .connect() + .await + .unwrap_err(); + + assert_eq!( + err.to_string(), + "2059 (HY000): Authentication plugin 'mysql_old_password' cannot be loaded" + ); + + Ok(()) + }) + } +} diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs deleted file mode 100644 index 2adf4986..00000000 --- a/sqlx-mysql/src/connection/establish.rs +++ /dev/null @@ -1,139 +0,0 @@ -use bytes::{buf::Chain, Buf, Bytes}; -use futures_io::{AsyncRead, AsyncWrite}; -use sqlx_core::io::{Deserialize, Serialize}; -use sqlx_core::{AsyncRuntime, Error, Result, Runtime}; - -use crate::protocol::{Capabilities, ErrPacket, Handshake, HandshakeResponse, OkPacket}; -use crate::{auth, MySqlConnectOptions, MySqlConnection, MySqlDatabaseError}; - -// https://dev.mysql.com/doc/internals/en/connection-phase.html - -// the connection phase (establish) performs these tasks: -// - exchange the capabilities of client and server -// - setup SSL communication channel if requested -// - authenticate the client against the server - -// the server may immediately send an ERR packet and finish the handshake -// or send a [InitialHandshake] - -fn make_auth_response( - auth_plugin_name: Option<&str>, - username: Option<&str>, - password: Option<&str>, - nonce: &Chain, -) -> Result>> { - match (auth_plugin_name, password) { - // NOTE: for no authentication plugin, we assume mysql_native_password - // this means we have no support for mysql_old_password (pre mysql 4) - // if you need this, please open an issue - (Some("mysql_native_password"), Some(password)) | (None, Some(password)) => { - Ok(Some(auth::native::scramble(nonce, password))) - } - - (_, None) => Ok(None), - - // an unsupported plugin error looks like this in the official client: - // ERROR 2059 (HY000): Authentication plugin 'caching_sha2_password' cannot be loaded: /usr/local/mysql/lib/plugin/caching_sha2_password.so: cannot open shared object file: No such file or directory - - // and renders like this in SQLx: - // Error: 2059 (HY000): Authentication plugin 'caching_sha2_password' cannot be loaded - (Some(plugin), _) => Err(Error::Connect(Box::new(MySqlDatabaseError(ErrPacket::new( - 2059, - &format!("Authentication plugin '{}' cannot be loaded", plugin), - ))))), - } -} - -fn make_handshake_response<'a, Rt: Runtime>( - handshake: &'a Handshake, - options: &'a MySqlConnectOptions, -) -> Result> { - let auth_response = make_auth_response( - handshake.auth_plugin_name.as_deref(), - options.get_username(), - options.get_password(), - &handshake.auth_plugin_data, - )?; - - Ok(HandshakeResponse { - auth_plugin_name: handshake.auth_plugin_name.as_deref(), - auth_response, - charset: 45, // [utf8mb4] - database: options.get_database(), - max_packet_size: 1024, - username: options.get_username(), - }) -} - -impl MySqlConnection -where - Rt: AsyncRuntime, - ::TcpStream: Unpin + AsyncWrite + AsyncRead, -{ - fn recv_handshake(&mut self, handshake: &Handshake) { - self.capabilities &= handshake.capabilities; - self.connection_id = handshake.connection_id; - } - - pub(crate) async fn establish_async(options: &MySqlConnectOptions) -> Result { - let stream = Rt::connect_tcp(options.get_host(), options.get_port()).await?; - let mut self_ = Self::new(stream); - - let handshake = self_.read_packet_async().await?; - self_.recv_handshake(&handshake); - - self_.write_packet(make_handshake_response(&handshake, options)?)?; - - self_.stream.flush_async().await?; - - let _ok: OkPacket = self_.read_packet_async().await?; - - Ok(self_) - } - - fn write_packet<'ser, T>(&'ser mut self, packet: T) -> Result<()> - where - T: Serialize<'ser, Capabilities>, - { - let mut wbuf = Vec::::with_capacity(1024); - - packet.serialize_with(&mut wbuf, self.capabilities)?; - - self.sequence_id = self.sequence_id.wrapping_add(1); - - self.stream.reserve(wbuf.len() + 4); - self.stream.write(&(wbuf.len() as u32).to_le_bytes()[..3]); - self.stream.write(&[self.sequence_id]); - self.stream.write(&wbuf); - - Ok(()) - } - - async fn read_packet_async<'de, T>(&'de mut self) -> Result - where - T: Deserialize<'de, Capabilities>, - { - // https://dev.mysql.com/doc/internals/en/mysql-packet.html - self.stream.read_async(4).await?; - - let payload_len: usize = self.stream.get(0, 3).get_uint_le(3) as usize; - - // FIXME: handle split packets - assert_ne!(payload_len, 0xFF_FF_FF); - - self.sequence_id = self.stream.get(3, 1).get_u8(); - - self.stream.read_async(4 + payload_len).await?; - - self.stream.consume(4); - let payload = self.stream.take(payload_len); - - if payload[0] == 0xff { - // if the first byte of the payload is 0xFF and the payload is an ERR packet - let err = ErrPacket::deserialize_with(payload, self.capabilities)?; - return Err(Error::Connect(Box::new(MySqlDatabaseError(err)))); - } - - T::deserialize_with(payload, self.capabilities) - } -} diff --git a/sqlx-mysql/src/connection/stream.rs b/sqlx-mysql/src/connection/stream.rs new file mode 100644 index 00000000..aefb506a --- /dev/null +++ b/sqlx-mysql/src/connection/stream.rs @@ -0,0 +1,154 @@ +//! Reads and writes packets to and from the MySQL database server. +//! +//! The logic for serializing data structures into the packets is found +//! mostly in `protocol/`. +//! +//! Packets in MySQL are prefixed by 4 bytes. +//! 3 for length (in LE) and a sequence id. +//! +//! Packets may only be as large as the communicated size in the initial +//! `HandshakeResponse`. By default, SQLx configures its chunk size to 16M. Sending +//! a larger payload is simply sending completely "full" packets, one after the +//! other, with an increasing sequence id. +//! +//! In other words, when we sent data, we: +//! +//! - Split the data into "packets" of size `2 ** 24 - 1` bytes. +//! +//! - Prepend each packet with a **packet header**, consisting of the length of that packet, +//! and the sequence number. +//! +//! https://dev.mysql.com/doc/internals/en/mysql-packet.html +//! +use bytes::{Buf, BufMut}; +use sqlx_core::io::{Deserialize, Serialize}; +use sqlx_core::{Error, Result, Runtime}; + +use crate::protocol::{Capabilities, ErrPacket}; +use crate::{MySqlConnection, MySqlDatabaseError}; + +impl MySqlConnection +where + Rt: Runtime, +{ + pub(super) fn write_packet<'ser, T>(&'ser mut self, packet: &T) -> Result<()> + where + T: Serialize<'ser, Capabilities>, + { + // the sequence-id is incremented with each packet and may + // wrap around. it starts at 0 and is reset to 0 when a new command + // begins in the Command Phase + + self.sequence_id = self.sequence_id.wrapping_add(1); + + // optimize for <16M packet sizes, in the case of >= 16M we would + // swap out the write buffer for a fresh buffer and then split it into + // 16M chunks separated by packet headers + + let buf = self.stream.buffer(); + let pos = buf.len(); + + // leave room for the length of the packet header at the start + buf.reserve(4); + buf.extend_from_slice(&[0_u8; 3]); + buf.push(self.sequence_id); + + // serialize the passed packet structure directly into the write buffer + packet.serialize_with(buf, self.capabilities)?; + + let payload_len = buf.len() - pos - 4; + + // FIXME: handle split packets + assert!(payload_len < 0xFF_FF_FF); + + // write back the length of the packet + #[allow(clippy::cast_possible_truncation)] + (&mut buf[pos..]).put_uint_le(payload_len as u64, 3); + + Ok(()) + } + + fn recv_packet<'de, T>(&'de mut self, len: usize) -> Result + where + T: Deserialize<'de, Capabilities>, + { + // FIXME: handle split packets + assert_ne!(len, 0xFF_FF_FF); + + // We store the sequence id here. To respond to a packet, it should use a + // sequence id of n+1. It only "resets" at the start of a new command. + self.sequence_id = self.stream.get(3, 1).get_u8(); + + // tell the stream that we are done with the 4-byte header + self.stream.consume(4); + + // and remove the remainder of the packet from the stream, the payload + let payload = self.stream.take(len); + + if payload[0] == 0xff { + // if the first byte of the payload is 0xFF and the payload is an ERR packet + let err = ErrPacket::deserialize_with(payload, self.capabilities)?; + return Err(Error::connect(MySqlDatabaseError(err))); + } + + T::deserialize_with(payload, self.capabilities) + } +} + +macro_rules! read_packet { + ($(@$blocking:ident)? $self:ident) => {{ + // reads at least 4 bytes from the IO stream into the read buffer + read_packet!($(@$blocking)? @stream $self, 0, 4); + + // the first 3 bytes will be the payload length of the packet (in LE) + // ALLOW: the max this len will be is 16M + #[allow(clippy::cast_possible_truncation)] + let payload_len: usize = $self.stream.get(0, 3).get_uint_le(3) as usize; + + // read bytes _after_ the 4 byte packet header + // note that we have not yet told the stream we are done with any of + // these bytes yet. if this next read invocation were to never return (eg., the + // outer future was dropped), then the next time read_packet_async was called + // it will re-read the parsed-above packet header. Note that we have NOT + // mutated `self` _yet_. This is important. + read_packet!($(@$blocking)? @stream $self, 4, payload_len); + + $self.recv_packet(payload_len) + }}; + + (@blocking @stream $self:ident, $offset:expr, $n:expr) => { + $self.stream.read($offset, $n)?; + }; + + (@stream $self:ident, $offset:expr, $n:expr) => { + $self.stream.read_async($offset, $n).await?; + }; +} + +#[cfg(feature = "async")] +impl MySqlConnection +where + Rt: sqlx_core::AsyncRuntime, + ::TcpStream: Unpin + futures_io::AsyncWrite + futures_io::AsyncRead, +{ + pub(super) async fn read_packet_async<'de, T>(&'de mut self) -> Result + where + T: Deserialize<'de, Capabilities>, + { + read_packet!(self) + } +} + +#[cfg(feature = "blocking")] +impl MySqlConnection +where + Rt: Runtime, + ::TcpStream: std::io::Write + std::io::Read, +{ + pub(super) fn read_packet<'de, T>(&'de mut self) -> Result + where + T: Deserialize<'de, Capabilities>, + { + read_packet!(@blocking self) + } +} diff --git a/sqlx-mysql/src/error.rs b/sqlx-mysql/src/error.rs index a8de4e14..ef01dfb4 100644 --- a/sqlx-mysql/src/error.rs +++ b/sqlx-mysql/src/error.rs @@ -10,6 +10,12 @@ use crate::protocol::ErrPacket; #[derive(Debug)] pub struct MySqlDatabaseError(pub(crate) ErrPacket); +impl MySqlDatabaseError { + pub(crate) fn new(code: u16, message: &str) -> Self { + Self(ErrPacket::new(code, message)) + } +} + impl DatabaseError for MySqlDatabaseError { fn message(&self) -> &str { &self.0.error_message diff --git a/sqlx-mysql/src/io/buf.rs b/sqlx-mysql/src/io/buf.rs index 62143748..41b7f7b1 100644 --- a/sqlx-mysql/src/io/buf.rs +++ b/sqlx-mysql/src/io/buf.rs @@ -1,6 +1,6 @@ use bytes::{Buf, Bytes}; +use bytestring::ByteString; use sqlx_core::io::BufExt; -use string::String; // UNSAFE: _unchecked string methods // intended for use when the protocol is *known* to always produce @@ -10,10 +10,10 @@ pub(crate) trait MySqlBufExt: BufExt { fn get_uint_lenenc(&mut self) -> u64; #[allow(unsafe_code)] - unsafe fn get_str_lenenc_unchecked(&mut self) -> String; + unsafe fn get_str_lenenc_unchecked(&mut self) -> ByteString; #[allow(unsafe_code)] - unsafe fn get_str_eof_unchecked(&mut self) -> String; + unsafe fn get_str_eof_unchecked(&mut self) -> ByteString; fn get_bytes_lenenc(&mut self) -> Bytes; } @@ -38,14 +38,14 @@ impl MySqlBufExt for Bytes { } #[allow(unsafe_code)] - unsafe fn get_str_lenenc_unchecked(&mut self) -> String { + unsafe fn get_str_lenenc_unchecked(&mut self) -> ByteString { let len = self.get_uint_lenenc() as usize; self.get_str_unchecked(len) } #[allow(unsafe_code)] - unsafe fn get_str_eof_unchecked(&mut self) -> String { + unsafe fn get_str_eof_unchecked(&mut self) -> ByteString { self.get_str_unchecked(self.len()) } diff --git a/sqlx-mysql/src/lib.rs b/sqlx-mysql/src/lib.rs index a0f0fcfd..402ff353 100644 --- a/sqlx-mysql/src/lib.rs +++ b/sqlx-mysql/src/lib.rs @@ -3,6 +3,7 @@ //! [MySQL]: https://www.mysql.com/ //! #![cfg_attr(doc_cfg, feature(doc_cfg))] +#![cfg_attr(not(any(feature = "async", feature = "blocking")), allow(unused))] #![deny(unsafe_code)] #![warn(rust_2018_idioms)] #![warn(future_incompatible)] @@ -21,16 +22,15 @@ mod connection; mod database; +mod error; mod io; mod options; mod protocol; -mod error; -mod auth; -#[cfg(feature = "blocking")] -mod blocking; +#[cfg(test)] +mod mock; pub use connection::MySqlConnection; pub use database::MySql; -pub use options::MySqlConnectOptions; pub use error::MySqlDatabaseError; +pub use options::MySqlConnectOptions; diff --git a/sqlx-mysql/src/mock.rs b/sqlx-mysql/src/mock.rs new file mode 100644 index 00000000..31ec05a5 --- /dev/null +++ b/sqlx-mysql/src/mock.rs @@ -0,0 +1,67 @@ +use std::io; + +use sqlx_core::mock::MockStream; + +pub(crate) trait MySqlMockStreamExt { + #[cfg(feature = "async")] + fn write_packet_async<'x>( + &'x mut self, + seq: u8, + packet: &'x [u8], + ) -> futures_util::future::BoxFuture<'x, io::Result<()>>; + + #[cfg(feature = "async")] + fn read_exact_async( + &mut self, + n: usize, + ) -> futures_util::future::BoxFuture<'_, io::Result>>; + + #[cfg(feature = "async")] + fn read_all_async(&mut self) -> futures_util::future::BoxFuture<'_, io::Result>>; +} + +impl MySqlMockStreamExt for MockStream { + #[cfg(feature = "async")] + fn write_packet_async<'x>( + &'x mut self, + seq: u8, + packet: &'x [u8], + ) -> futures_util::future::BoxFuture<'x, io::Result<()>> { + use futures_util::AsyncWriteExt; + + Box::pin(async move { + self.write_all(&packet.len().to_le_bytes()[..3]).await?; + self.write_all(&[seq]).await?; + self.write_all(packet).await + }) + } + + #[cfg(feature = "async")] + fn read_exact_async( + &mut self, + n: usize, + ) -> futures_util::future::BoxFuture<'_, io::Result>> { + use futures_util::AsyncReadExt; + + Box::pin(async move { + let mut buf = vec![0; n]; + let read = self.read(&mut buf).await?; + buf.truncate(read); + + Ok(buf) + }) + } + + #[cfg(feature = "async")] + fn read_all_async(&mut self) -> futures_util::future::BoxFuture<'_, io::Result>> { + use futures_util::AsyncReadExt; + + Box::pin(async move { + let mut buf = vec![0; 1024]; + let read = self.read(&mut buf).await?; + buf.truncate(read); + + Ok(buf) + }) + } +} diff --git a/sqlx-mysql/src/options.rs b/sqlx-mysql/src/options.rs index d08afc17..fe9c8bdb 100644 --- a/sqlx-mysql/src/options.rs +++ b/sqlx-mysql/src/options.rs @@ -11,6 +11,8 @@ mod builder; mod default; mod parse; +// TODO: RSA Public Key (to avoid the key exchange for caching_sha2 and sha256 plugins) + /// Options which can be used to configure how a MySQL connection is opened. /// /// A value of `MySqlConnectOptions` can be parsed from a connection URL, @@ -135,6 +137,20 @@ where Rt: sqlx_core::AsyncRuntime, ::TcpStream: futures_io::AsyncRead + futures_io::AsyncWrite + Unpin, { - futures_util::FutureExt::boxed(MySqlConnection::establish_async(self)) + Box::pin(MySqlConnection::connect_async(self)) + } +} + +#[cfg(feature = "blocking")] +impl sqlx_core::blocking::ConnectOptions for MySqlConnectOptions +where + Rt: sqlx_core::blocking::Runtime, + ::TcpStream: std::io::Read + std::io::Write, +{ + fn connect(&self) -> sqlx_core::Result + where + Self::Connection: Sized, + { + >::connect(self) } } diff --git a/sqlx-mysql/src/options/parse.rs b/sqlx-mysql/src/options/parse.rs index 4a504abe..2c200fd1 100644 --- a/sqlx-mysql/src/options/parse.rs +++ b/sqlx-mysql/src/options/parse.rs @@ -1,5 +1,5 @@ -use std::str::FromStr; use std::borrow::Cow; +use std::str::FromStr; use percent_encoding::percent_decode_str; use sqlx_core::{Error, Runtime}; diff --git a/sqlx-mysql/src/protocol.rs b/sqlx-mysql/src/protocol.rs index 68c07dc7..ac500d5d 100644 --- a/sqlx-mysql/src/protocol.rs +++ b/sqlx-mysql/src/protocol.rs @@ -1,13 +1,19 @@ +mod auth; +mod auth_plugin; +mod auth_switch; mod capabilities; +mod err; mod handshake; mod handshake_response; mod ok; mod status; -mod err; -pub(crate) use err::ErrPacket; -pub(crate) use ok::OkPacket; +pub(crate) use auth::{Auth, AuthResponse}; +pub(crate) use auth_plugin::AuthPlugin; +pub(crate) use auth_switch::AuthSwitch; pub(crate) use capabilities::Capabilities; +pub(crate) use err::ErrPacket; pub(crate) use handshake::Handshake; pub(crate) use handshake_response::HandshakeResponse; +pub(crate) use ok::OkPacket; pub(crate) use status::Status; diff --git a/sqlx-mysql/src/protocol/auth.rs b/sqlx-mysql/src/protocol/auth.rs new file mode 100644 index 00000000..b2dfde21 --- /dev/null +++ b/sqlx-mysql/src/protocol/auth.rs @@ -0,0 +1,46 @@ +use std::fmt::Debug; + +use bytes::Bytes; +use sqlx_core::io::{Deserialize, Serialize}; +use sqlx_core::{Error, Result}; + +use crate::protocol::{AuthSwitch, Capabilities, OkPacket}; +use crate::MySqlDatabaseError; + +#[derive(Debug)] +pub(crate) enum Auth { + Ok(OkPacket), + MoreData(Bytes), + Switch(AuthSwitch), +} + +impl Deserialize<'_, Capabilities> for Auth { + fn deserialize_with(buf: Bytes, capabilities: Capabilities) -> Result { + match buf[0] { + 0x00 => OkPacket::deserialize_with(buf, capabilities).map(Self::Ok), + 0x01 => Ok(Self::MoreData(buf.slice(1..))), + 0xfe => AuthSwitch::deserialize_with(buf, capabilities).map(Self::Switch), + + tag => Err(Error::connect(MySqlDatabaseError::new( + 2027, + &format!( + "Malformed packet: Received 0x{:x} but expected one of: 0x0, 0x1, or 0xfe", + tag + ), + ))), + } + } +} + +#[derive(Debug)] +pub(crate) struct AuthResponse { + pub(crate) data: Vec, +} + +impl Serialize<'_, Capabilities> for AuthResponse { + fn serialize_with(&self, buf: &mut Vec, _context: Capabilities) -> Result<()> { + buf.extend_from_slice(&self.data); + + Ok(()) + } +} diff --git a/sqlx-mysql/src/protocol/auth_plugin.rs b/sqlx-mysql/src/protocol/auth_plugin.rs new file mode 100644 index 00000000..59eb05ac --- /dev/null +++ b/sqlx-mysql/src/protocol/auth_plugin.rs @@ -0,0 +1,76 @@ +use std::error::Error as StdError; +use std::fmt::Debug; +use std::str::FromStr; + +use bytes::buf::Chain; +use bytes::Bytes; +use sqlx_core::{Error, Result}; + +use crate::MySqlDatabaseError; + +mod caching_sha2; +mod native; +mod rsa; +mod sha256; + +pub(crate) use self::caching_sha2::CachingSha2AuthPlugin; +pub(crate) use self::native::NativeAuthPlugin; +pub(crate) use self::sha256::Sha256AuthPlugin; + +pub(crate) trait AuthPlugin: 'static + Debug + Send + Sync { + fn name(&self) -> &'static str; + + // Invoke the auth plugin and return the auth response + fn invoke(&self, nonce: &Chain, password: &str) -> Vec; + + // Handle "more data" from the MySQL server + // which tells the plugin some plugin-specific information + // if the plugin returns Some(_) that is sent back to MySQL + fn handle( + &self, + data: Bytes, + nonce: &Chain, + password: &str, + ) -> Result>>; +} + +impl FromStr for Box { + type Err = Error; + + fn from_str(s: &str) -> Result { + match s { + _ if s == CachingSha2AuthPlugin.name() => Ok(Box::new(CachingSha2AuthPlugin)), + _ if s == Sha256AuthPlugin.name() => Ok(Box::new(Sha256AuthPlugin)), + _ if s == NativeAuthPlugin.name() => Ok(Box::new(NativeAuthPlugin)), + + _ => Err(Error::connect(MySqlDatabaseError::new( + 2059, + &format!("Authentication plugin '{}' cannot be loaded", s), + ))), + } + } +} + +// XOR(x, y) +// If len(y) < len(x), wrap around inside y +fn xor_eq(x: &mut [u8], y: &[u8]) { + let y_len = y.len(); + + for i in 0..x.len() { + x[i] ^= y[i % y_len]; + } +} + +fn err_msg(plugin: &'static str, message: &str) -> Error { + Error::connect(MySqlDatabaseError::new( + 2061, + &format!("Authentication plugin '{}' reported error: {}", plugin, message), + )) +} + +fn err(plugin: &'static str, error: E) -> Error +where + E: StdError, +{ + err_msg(plugin, &error.to_string()) +} diff --git a/sqlx-mysql/src/protocol/auth_plugin/caching_sha2.rs b/sqlx-mysql/src/protocol/auth_plugin/caching_sha2.rs new file mode 100644 index 00000000..d1ab6175 --- /dev/null +++ b/sqlx-mysql/src/protocol/auth_plugin/caching_sha2.rs @@ -0,0 +1,79 @@ +use bytes::buf::Chain; +use bytes::Bytes; +use sha2::{Digest, Sha256}; +use sqlx_core::Result; + +/// Implements SHA-256 authentication but uses caching on the server-side for better performance. +/// After the first authentication, a fast path is used that doesn't involve the RSA key exchange. +/// +/// https://dev.mysql.com/doc/refman/8.0/en/caching-sha2-pluggable-authentication.html +/// https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/ +/// +#[derive(Debug)] +pub(crate) struct CachingSha2AuthPlugin; + +impl super::AuthPlugin for CachingSha2AuthPlugin { + fn name(&self) -> &'static str { + "caching_sha2_password" + } + + fn invoke(&self, nonce: &Chain, password: &str) -> Vec { + if password.is_empty() { + // empty password => no scramble + return vec![]; + } + + // SHA256( password ) ^ SHA256( nonce + SHA256( SHA256( password ) ) ) + + let mut hasher = Sha256::new(); + + hasher.update(password); + + // SHA256( password ) + let mut pw_sha2 = hasher.finalize_reset(); + + hasher.update(&pw_sha2); + + // SHA256( SHA256( password ) ) + let pw_sha2_sha2 = hasher.finalize_reset(); + + hasher.update(pw_sha2_sha2); + hasher.update(nonce.first_ref()); + hasher.update(nonce.last_ref()); + + // SHA256( nonce + SHA256( SHA256( password ) ) ) + let nonce_pw_sha1_sha1 = hasher.finalize(); + + super::xor_eq(&mut pw_sha2, &nonce_pw_sha1_sha1); + + pw_sha2.to_vec() + } + + fn handle( + &self, + data: Bytes, + nonce: &Chain, + password: &str, + ) -> Result>> { + const AUTH_SUCCESS: u8 = 0x3; + const AUTH_CONTINUE: u8 = 0x4; + + match data[0] { + // good to go, return nothing + AUTH_SUCCESS => Ok(None), + + AUTH_CONTINUE => { + // ruh roh, we need to ask for the RSA public key, so we can + // encrypt our password directly and send it + return Ok(Some(vec![0x2_u8])); + } + + _ => { + let rsa_pub_key = data; + let encrypted = super::rsa::encrypt(self.name(), &rsa_pub_key, password, nonce)?; + + Ok(Some(encrypted)) + } + } + } +} diff --git a/sqlx-mysql/src/protocol/auth_plugin/native.rs b/sqlx-mysql/src/protocol/auth_plugin/native.rs new file mode 100644 index 00000000..83c215fa --- /dev/null +++ b/sqlx-mysql/src/protocol/auth_plugin/native.rs @@ -0,0 +1,60 @@ +use bytes::{buf::Chain, Bytes}; +use sha1::{Digest, Sha1}; +use sqlx_core::Result; + +use super::xor_eq; + +// https://mariadb.com/kb/en/connection/#mysql_native_password-plugin +// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html + +#[derive(Debug)] +pub(crate) struct NativeAuthPlugin; + +impl super::AuthPlugin for NativeAuthPlugin { + fn name(&self) -> &'static str { + "mysql_native_password" + } + + fn invoke(&self, nonce: &Chain, password: &str) -> Vec { + if password.is_empty() { + // no password => empty scramble + return vec![]; + } + + // SHA1( password ) ^ SHA1( nonce + SHA1( SHA1( password ) ) ) + + let mut hasher = Sha1::new(); + + hasher.update(password); + + // SHA1( password ) + let mut pw_sha1 = hasher.finalize_reset(); + + hasher.update(&pw_sha1); + + // SHA1( SHA1( password ) ) + let pw_sha1_sha1 = hasher.finalize_reset(); + + hasher.update(nonce.first_ref()); + hasher.update(&nonce.last_ref()); + hasher.update(&pw_sha1_sha1); + + // SHA1( seed + SHA1( SHA1( password ) ) ) + let nonce_pw_sha1_sha1 = hasher.finalize(); + + xor_eq(&mut pw_sha1, &nonce_pw_sha1_sha1); + + pw_sha1.to_vec() + } + + fn handle( + &self, + _data: Bytes, + _nonce: &Chain, + _password: &str, + ) -> Result>> { + // MySQL should not be returning any additional data for + // the native mysql auth plugin + unreachable!() + } +} diff --git a/sqlx-mysql/src/protocol/auth_plugin/rsa.rs b/sqlx-mysql/src/protocol/auth_plugin/rsa.rs new file mode 100644 index 00000000..2f086229 --- /dev/null +++ b/sqlx-mysql/src/protocol/auth_plugin/rsa.rs @@ -0,0 +1,70 @@ +use std::str::from_utf8; + +use bytes::buf::Chain; +use bytes::Bytes; +use rsa::{PaddingScheme, PublicKey, RSAPublicKey}; +use sqlx_core::Result; + +pub(crate) fn encrypt( + plugin: &'static str, + key: &[u8], + password: &str, + nonce: &Chain, +) -> Result> { + // xor the password with the given nonce + let mut pass = to_asciz(password); + + let (a, b) = (nonce.first_ref(), nonce.last_ref()); + let mut nonce = Vec::with_capacity(a.len() + b.len()); + + nonce.extend_from_slice(&*a); + nonce.extend_from_slice(&*b); + + super::xor_eq(&mut pass, &*nonce); + + // client sends an RSA encrypted password + let pkey = parse_rsa_pub_key(plugin, key)?; + let padding = PaddingScheme::new_oaep::(); + + pkey.encrypt(&mut rng(), padding, &pass[..]).map_err(|err| super::err(plugin, err)) +} + +// https://docs.rs/rsa/0.3.0/rsa/struct.RSAPublicKey.html?search=#example-1 +fn parse_rsa_pub_key(plugin: &'static str, key: &[u8]) -> Result { + let key = from_utf8(key).map_err(|err| super::err(plugin, err))?; + + // Takes advantage of the knowledge that we know + // we are receiving a PKCS#8 RSA Public Key at all + // times from MySQL + + let encoded = + key.lines().filter(|line| !line.starts_with("-")).fold(String::new(), |mut data, line| { + data.push_str(&line); + data + }); + + let der = base64::decode(&encoded).map_err(|err| super::err(plugin, err))?; + + RSAPublicKey::from_pkcs8(&der).map_err(|err| super::err(plugin, err)) +} + +fn to_asciz(s: &str) -> Vec { + let mut z = String::with_capacity(s.len() + 1); + z.push_str(s); + z.push('\0'); + + z.into_bytes() +} + +// use a stable stream of numbers for encryption +// during tests to assert the result of [encrypt] + +#[cfg(not(test))] +fn rng() -> rand::rngs::ThreadRng { + rand::thread_rng() +} + +#[cfg(test)] +fn rng() -> rand::rngs::mock::StepRng { + rand::rngs::mock::StepRng::new(0, 1) +} diff --git a/sqlx-mysql/src/protocol/auth_plugin/sha256.rs b/sqlx-mysql/src/protocol/auth_plugin/sha256.rs new file mode 100644 index 00000000..4696fcc3 --- /dev/null +++ b/sqlx-mysql/src/protocol/auth_plugin/sha256.rs @@ -0,0 +1,42 @@ +use bytes::buf::Chain; +use bytes::Bytes; +use sqlx_core::Result; + +/// Implements SHA-256 authentication. +/// +/// Each time we connect we have to do an RSA key exchange. +/// This slows down auth quite a bit. +/// +/// https://dev.mysql.com/doc/refman/8.0/en/sha256-pluggable-authentication.html +/// https://mariadb.com/kb/en/sha256_password-plugin/ +/// +#[derive(Debug)] +pub(crate) struct Sha256AuthPlugin; + +impl super::AuthPlugin for Sha256AuthPlugin { + fn name(&self) -> &'static str { + "sha256_password" + } + + fn invoke(&self, _nonce: &Chain, password: &str) -> Vec { + if password.is_empty() { + // no password => do not ask for RSA key + return vec![]; + } + + // ask for the RSA key + vec![0x01] + } + + fn handle( + &self, + data: Bytes, + nonce: &Chain, + password: &str, + ) -> Result>> { + let rsa_pub_key = data; + let encrypted = super::rsa::encrypt(self.name(), &rsa_pub_key, password, nonce)?; + + Ok(Some(encrypted)) + } +} diff --git a/sqlx-mysql/src/protocol/auth_switch.rs b/sqlx-mysql/src/protocol/auth_switch.rs new file mode 100644 index 00000000..1fd84ca4 --- /dev/null +++ b/sqlx-mysql/src/protocol/auth_switch.rs @@ -0,0 +1,39 @@ +use std::str::FromStr; + +use bytes::{buf::Chain, Buf, Bytes}; +use sqlx_core::io::{BufExt, Deserialize}; +use sqlx_core::Result; + +use super::Capabilities; +use crate::protocol::AuthPlugin; + +// https://dev.mysql.com/doc/internals/en/authentication-method-change.html +// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest + +#[derive(Debug)] +pub(crate) struct AuthSwitch { + pub(crate) plugin: Box, + pub(crate) plugin_data: Chain, +} + +impl Deserialize<'_, Capabilities> for AuthSwitch { + fn deserialize_with(mut buf: Bytes, _capabilities: Capabilities) -> Result { + let tag = buf.get_u8(); + debug_assert_eq!(tag, 0xfe); + + // SAFE: auth plugins are ASCII only + #[allow(unsafe_code)] + let name = unsafe { buf.get_str_nul_unchecked()? }; + + if buf.ends_with(&[0]) { + // if this terminates in a NUL; drop the NUL + buf.truncate(buf.len() - 1); + } + + let plugin_data = buf.chain(Bytes::new()); + + let plugin = >::from_str(&*name)?; + + Ok(Self { plugin, plugin_data }) + } +} diff --git a/sqlx-mysql/src/protocol/err.rs b/sqlx-mysql/src/protocol/err.rs index bbf70e40..963c660b 100644 --- a/sqlx-mysql/src/protocol/err.rs +++ b/sqlx-mysql/src/protocol/err.rs @@ -1,7 +1,7 @@ use bytes::{Buf, Bytes}; +use bytestring::ByteString; use sqlx_core::io::{BufExt, Deserialize}; use sqlx_core::Result; -use string::String; use crate::io::MySqlBufExt; use crate::protocol::Capabilities; @@ -14,21 +14,14 @@ use crate::protocol::Capabilities; #[derive(Debug)] pub(crate) struct ErrPacket { pub(crate) error_code: u16, - pub(crate) sql_state: Option>, - pub(crate) error_message: String, + pub(crate) sql_state: Option, + pub(crate) error_message: ByteString, } impl ErrPacket { pub(crate) fn new(code: u16, message: &str) -> Self { - let message_bytes = Bytes::copy_from_slice(message.as_bytes()); - let state_bytes = Bytes::from_static(b"HY000"); - - // UNSAFE: the UTF-8 string is converted to bytes right above. The string crate has a - // safe method for creation from Rust str but it pulls in an old version of Bytes - #[allow(unsafe_code)] - let (message, state) = unsafe { - (String::from_utf8_unchecked(message_bytes), String::from_utf8_unchecked(state_bytes)) - }; + let message = ByteString::from(message); + let state = ByteString::from_static("HY000"); Self { error_code: code, sql_state: Some(state), error_message: message } } diff --git a/sqlx-mysql/src/protocol/handshake.rs b/sqlx-mysql/src/protocol/handshake.rs index f31db5e7..4f4183f7 100644 --- a/sqlx-mysql/src/protocol/handshake.rs +++ b/sqlx-mysql/src/protocol/handshake.rs @@ -1,10 +1,14 @@ +use std::str::FromStr; + use bytes::buf::Chain; use bytes::{Buf, Bytes}; +use bytestring::ByteString; use memchr::memchr; use sqlx_core::io::{BufExt, Deserialize}; use sqlx_core::Result; -use crate::protocol::{Capabilities, Status}; +use crate::protocol::auth_plugin::NativeAuthPlugin; +use crate::protocol::{AuthPlugin, Capabilities, Status}; // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeV10 // https://mariadb.com/kb/en/connection/#initial-handshake-packet @@ -15,7 +19,7 @@ pub(crate) struct Handshake { pub(crate) protocol_version: u8, // human-readable server version - pub(crate) server_version: string::String, + pub(crate) server_version: ByteString, pub(crate) connection_id: u32, @@ -25,10 +29,8 @@ pub(crate) struct Handshake { // default server character set pub(crate) charset: Option, + pub(crate) auth_plugin: Box, pub(crate) auth_plugin_data: Chain, - - // name of the auth_method that the auth_plugin_data belongs to - pub(crate) auth_plugin_name: Option>, } impl Deserialize<'_, Capabilities> for Handshake { @@ -88,6 +90,11 @@ impl Deserialize<'_, Capabilities> for Handshake { auth_plugin_data_2 = buf.split_to(len as usize); + if auth_plugin_data_2.ends_with(&[0]) { + // if this terminates in a NUL; drop the NUL + auth_plugin_data_2.truncate(auth_plugin_data_2.len() - 1); + } + if capabilities.contains(Capabilities::PLUGIN_AUTH) { // due to Bug#59453 the auth-plugin-name is missing the terminating NUL-char // in versions prior to 5.5.10 and 5.6.2 @@ -96,8 +103,7 @@ impl Deserialize<'_, Capabilities> for Handshake { // read to NUL or read to the end if we can't find a NUL - let auth_plugin_name_end = - memchr(b'\0', &buf).unwrap_or(buf.len()); + let auth_plugin_name_end = memchr(b'\0', &buf).unwrap_or(buf.len()); // UNSAFE: auth plugin names are known to be ASCII #[allow(unsafe_code)] @@ -116,7 +122,10 @@ impl Deserialize<'_, Capabilities> for Handshake { capabilities, status, auth_plugin_data: auth_plugin_data_1.chain(auth_plugin_data_2), - auth_plugin_name, + auth_plugin: auth_plugin_name + .map(|name| >::from_str(&name)) + .transpose()? + .unwrap_or_else(|| Box::new(NativeAuthPlugin)), }) } } @@ -167,11 +176,11 @@ mod tests { assert_eq!(h.charset, Some(255)); assert_eq!(h.status, Status::AUTOCOMMIT); - assert_eq!(h.auth_plugin_name.as_deref(), Some("caching_sha2_password")); + assert_eq!(h.auth_plugin.name(), "caching_sha2_password"); assert_eq!( &*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()), - &[17, 52, 97, 66, 48, 99, 6, 103, 116, 76, 3, 115, 15, 91, 52, 13, 108, 52, 46, 32, 0] + &[17, 52, 97, 66, 48, 99, 6, 103, 116, 76, 3, 115, 15, 91, 52, 13, 108, 52, 46, 32] ); } @@ -211,14 +220,11 @@ mod tests { assert_eq!(h.charset, Some(8)); assert_eq!(h.status, Status::AUTOCOMMIT); - assert_eq!(h.auth_plugin_name.as_deref(), Some("mysql_native_password")); + assert_eq!(h.auth_plugin.name(), "mysql_native_password"); assert_eq!( &*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()), - &[ - 116, 54, 76, 92, 106, 34, 100, 83, 85, 49, 52, 79, 112, 104, 57, 34, 60, 72, 53, - 110, 0 - ] + &[116, 54, 76, 92, 106, 34, 100, 83, 85, 49, 52, 79, 112, 104, 57, 34, 60, 72, 53, 110] ); } @@ -258,14 +264,11 @@ mod tests { assert_eq!(h.charset, Some(45)); assert_eq!(h.status, Status::AUTOCOMMIT); - assert_eq!(h.auth_plugin_name.as_deref(), Some("mysql_native_password")); + assert_eq!(h.auth_plugin.name(), "mysql_native_password"); assert_eq!( &*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()), - &[ - 39, 80, 66, 57, 52, 57, 99, 102, 85, 89, 62, 104, 114, 38, 96, 51, 123, 53, 53, 72, - 0 - ] + &[39, 80, 66, 57, 52, 57, 99, 102, 85, 89, 62, 104, 114, 38, 96, 51, 123, 53, 53, 72,] ); } @@ -305,11 +308,11 @@ mod tests { assert_eq!(h.charset, Some(8)); assert_eq!(h.status, Status::AUTOCOMMIT); - assert_eq!(h.auth_plugin_name.as_deref(), Some("mysql_native_password")); + assert_eq!(h.auth_plugin.name(), "mysql_native_password"); assert_eq!( &*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()), - &[45, 86, 76, 89, 90, 58, 80, 100, 39, 50, 102, 43, 66, 76, 56, 110, 71, 86, 91, 71, 0] + &[45, 86, 76, 89, 90, 58, 80, 100, 39, 50, 102, 43, 66, 76, 56, 110, 71, 86, 91, 71] ); } @@ -334,13 +337,13 @@ mod tests { assert_eq!(h.charset, Some(8)); assert_eq!(h.status, Status::AUTOCOMMIT); - assert_eq!(h.auth_plugin_name, None); + assert_eq!(h.auth_plugin.name(), "mysql_native_password"); assert_eq!( &*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()), &[ 98, 115, 61, 115, 78, 105, 71, 101, 73, 122, 77, 80, 41, 121, 76, 76, 120, 59, 91, - 57, 0 + 57 ] ); } @@ -373,14 +376,11 @@ mod tests { assert_eq!(h.charset, Some(8)); assert_eq!(h.status, Status::AUTOCOMMIT); - assert_eq!(h.auth_plugin_name, None); + assert_eq!(h.auth_plugin.name(), "mysql_native_password"); assert_eq!( &*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()), - &[ - 60, 102, 108, 108, 90, 92, 66, 115, 60, 113, 69, 67, 95, 56, 55, 74, 79, 47, 57, - 113, 0 - ] + &[60, 102, 108, 108, 90, 92, 66, 115, 60, 113, 69, 67, 95, 56, 55, 74, 79, 47, 57, 113] ); } @@ -416,13 +416,13 @@ mod tests { assert_eq!(h.charset, Some(8)); assert_eq!(h.status, Status::AUTOCOMMIT); - assert_eq!(h.auth_plugin_name.as_deref(), Some("mysql_native_password")); + assert_eq!(h.auth_plugin.name(), "mysql_native_password"); assert_eq!( &*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()), &[ 96, 111, 45, 47, 67, 69, 112, 39, 107, 102, 64, 74, 53, 106, 54, 110, 74, 102, 65, - 80, 0 + 80 ] ); } diff --git a/sqlx-mysql/src/protocol/handshake_response.rs b/sqlx-mysql/src/protocol/handshake_response.rs index a7763ed7..4d3af879 100644 --- a/sqlx-mysql/src/protocol/handshake_response.rs +++ b/sqlx-mysql/src/protocol/handshake_response.rs @@ -1,4 +1,3 @@ -use bytes::BufMut; use sqlx_core::io::{Serialize, WriteExt}; use sqlx_core::Result; @@ -14,8 +13,8 @@ pub(crate) struct HandshakeResponse<'a> { pub(crate) max_packet_size: u32, pub(crate) charset: u8, pub(crate) username: Option<&'a str>, - pub(crate) auth_plugin_name: Option<&'a str>, - pub(crate) auth_response: Option>, + pub(crate) auth_plugin_name: &'a str, + pub(crate) auth_response: Vec, } impl Serialize<'_, Capabilities> for HandshakeResponse<'_> { @@ -29,7 +28,7 @@ impl Serialize<'_, Capabilities> for HandshakeResponse<'_> { buf.write_maybe_str_nul(self.username); - let auth_response = self.auth_response.as_deref().unwrap_or_default(); + let auth_response = self.auth_response.as_slice(); if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) { buf.write_bytes_lenenc(auth_response); @@ -50,7 +49,7 @@ impl Serialize<'_, Capabilities> for HandshakeResponse<'_> { } if capabilities.contains(Capabilities::PLUGIN_AUTH) { - buf.write_maybe_str_nul(self.auth_plugin_name); + buf.write_str_nul(self.auth_plugin_name); } Ok(()) diff --git a/sqlx-mysql/src/protocol/ok.rs b/sqlx-mysql/src/protocol/ok.rs index 61ad029b..13b348e0 100644 --- a/sqlx-mysql/src/protocol/ok.rs +++ b/sqlx-mysql/src/protocol/ok.rs @@ -42,7 +42,7 @@ impl Deserialize<'_, Capabilities> for OkPacket { #[cfg(test)] mod tests { - use super::{OkPacket, Capabilities, Deserialize, Status}; + use super::{Capabilities, Deserialize, OkPacket, Status}; #[test] fn test_empty_ok_packet() {