feat(mysql): impl full connect phase for MySQL with support for:

- mysql_native_password
 - caching_sha2_password
 - sha256_password
 - non-default auth plugin (new)
This commit is contained in:
Ryan Leckey 2021-01-08 15:28:26 -08:00
parent 91fa554063
commit 2557557935
No known key found for this signature in database
GPG Key ID: F8AA68C235AB08C9
28 changed files with 1160 additions and 290 deletions

View File

@ -30,11 +30,21 @@ async = ["futures-util", "sqlx-core/async", "futures-io"]
sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core" }
futures-util = { version = "0.3.8", optional = true }
either = "1.6.1"
log = "0.4.11"
bytestring = "1.0.0"
url = "2.2.0"
percent-encoding = "2.1.0"
futures-io = { version = "0.3", optional = true }
bytes = "1.0"
memchr = "2.3"
bitflags = "1.2"
string = { version = "0.2.1", default-features = false }
sha-1 = "0.9.2"
sha2 = "0.9.2"
rsa = "0.3.0"
base64 = "0.13.0"
rand = "0.7"
[dev-dependencies]
sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core", features = ["_mock"] }
futures-executor = "0.3.8"
anyhow = "1.0.37"

View File

@ -1,13 +0,0 @@
pub(crate) mod native;
// mod caching_sha2;
// mod sha256;
// XOR(x, y)
// If len(y) < len(x), wrap around inside y
fn xor_eq(x: &mut [u8], y: &[u8]) {
let y_len = y.len();
for i in 0..x.len() {
x[i] ^= y[i % y_len];
}
}

View File

@ -1,35 +0,0 @@
use bytes::{buf::Chain, Bytes};
use sha1::{Digest, Sha1};
use super::xor_eq;
// https://mariadb.com/kb/en/connection/#mysql_native_password-plugin
// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
pub(crate) fn scramble(nonce: &Chain<Bytes, Bytes>, password: &str) -> Vec<u8> {
// SHA1( password ) ^ SHA1( nonce + SHA1( SHA1( password ) ) )
let mut hasher = Sha1::new();
hasher.update(password);
// SHA1( password )
let mut pw_sha1 = hasher.finalize_reset();
hasher.update(&pw_sha1);
// SHA1( SHA1( password ) )
let pw_sha1_sha1 = hasher.finalize_reset();
// NOTE: use the first 20 bytes of the nonce, we MAY have gotten a nul terminator
hasher.update(nonce.first_ref());
hasher.update(&nonce.last_ref()[..20 - nonce.first_ref().len()]);
hasher.update(&pw_sha1_sha1);
// SHA1( seed + SHA1( SHA1( password ) ) )
let nonce_pw_sha1_sha1 = hasher.finalize();
xor_eq(&mut pw_sha1, &nonce_pw_sha1_sha1);
pw_sha1.to_vec()
}

View File

@ -1,2 +0,0 @@
mod connection;
mod options;

View File

@ -1,17 +0,0 @@
use sqlx_core::blocking::{Connection, Runtime};
use sqlx_core::Result;
use crate::MySqlConnection;
impl<Rt> Connection<Rt> for MySqlConnection<Rt>
where
Rt: Runtime,
{
fn close(self) -> Result<()> {
unimplemented!()
}
fn ping(&mut self) -> Result<()> {
unimplemented!()
}
}

View File

@ -1,17 +0,0 @@
use sqlx_core::blocking::{ConnectOptions, Connection, Runtime};
use sqlx_core::Result;
use crate::{MySqlConnectOptions, MySqlConnection};
impl<Rt> ConnectOptions<Rt> for MySqlConnectOptions<Rt>
where
Rt: Runtime,
Self::Connection: sqlx_core::Connection<Rt, Options = Self> + Connection<Rt>,
{
fn connect(&self) -> Result<MySqlConnection<Rt>> {
// let stream = <Rt as Runtime>::connect_tcp(self.get_host(), self.get_port())?;
//
// Ok(MySqlConnection { stream })
todo!()
}
}

View File

@ -6,16 +6,24 @@ use sqlx_core::{Connection, DefaultRuntime, Runtime};
use crate::protocol::Capabilities;
use crate::{MySql, MySqlConnectOptions};
#[cfg(feature = "async")]
pub(crate) mod establish;
#[cfg(any(feature = "async", feature = "blocking"))]
mod connect;
#[cfg(any(feature = "async", feature = "blocking"))]
mod stream;
#[allow(clippy::module_name_repetitions)]
pub struct MySqlConnection<Rt = DefaultRuntime>
where
Rt: Runtime,
{
stream: BufStream<Rt::TcpStream>,
connection_id: u32,
// the capability flags are used by the client and server to indicate which
// features they support and want to use.
capabilities: Capabilities,
// the sequence-id is incremented with each packet and may wrap around. It starts at 0 and is
// reset to 0 when a new command begins in the Command Phase.
sequence_id: u8,
@ -25,6 +33,7 @@ impl<Rt> MySqlConnection<Rt>
where
Rt: Runtime,
{
#[cfg(any(feature = "async", feature = "blocking"))]
pub(crate) fn new(stream: Rt::TcpStream) -> Self {
Self {
stream: BufStream::with_capacity(stream, 4096, 1024),
@ -82,3 +91,18 @@ where
unimplemented!()
}
}
#[cfg(feature = "blocking")]
impl<Rt> sqlx_core::blocking::Connection<Rt> for MySqlConnection<Rt>
where
Rt: sqlx_core::blocking::Runtime,
<Rt as Runtime>::TcpStream: std::io::Read + std::io::Write,
{
fn close(self) -> sqlx_core::Result<()> {
unimplemented!()
}
fn ping(&mut self) -> sqlx_core::Result<()> {
unimplemented!()
}
}

View File

@ -0,0 +1,406 @@
//! Implements the connection phase.
//!
//! The connection phase (establish) performs these tasks:
//!
//! - exchange the capabilities of client and server
//! - setup SSL communication channel if requested
//! - authenticate the client against the server
//!
//! The server may immediately send an ERR packet and finish the handshake
//! or send a `Handshake`.
//!
//! https://dev.mysql.com/doc/internals/en/connection-phase.html
//!
use sqlx_core::{Result, Runtime};
use crate::protocol::{Auth, AuthResponse, Handshake, HandshakeResponse};
use crate::{MySqlConnectOptions, MySqlConnection};
macro_rules! connect {
(@blocking @tcp $options:ident) => {
Rt::connect_tcp($options.get_host(), $options.get_port())?;
};
(@tcp $options:ident) => {
Rt::connect_tcp($options.get_host(), $options.get_port()).await?;
};
(@blocking @packet $self:ident) => {
$self.read_packet()?;
};
(@packet $self:ident) => {
$self.read_packet_async().await?;
};
($(@$blocking:ident)? $options:ident) => {{
// open a network stream to the database server
let stream = connect!($(@$blocking)? @tcp $options);
// construct a <MySqlConnection> around the network stream
// wraps the stream in a <BufStream> to buffer read and write
let mut self_ = Self::new(stream);
// immediately the server should emit a <Handshake> packet
let handshake: Handshake = connect!($(@$blocking)? @packet self_);
// & the declared server capabilities with our capabilities to find
// what rules the client should operate under
self_.capabilities &= handshake.capabilities;
// store the connection ID, mainly for debugging
self_.connection_id = handshake.connection_id;
// extract the auth plugin and data from the handshake
// this can get overwritten by an auth switch
let mut auth_plugin = handshake.auth_plugin;
let mut auth_plugin_data = handshake.auth_plugin_data;
let password = $options.get_password().unwrap_or_default();
// create the initial auth response
// this may just be a request for an RSA public key
let initial_auth_response = auth_plugin.invoke(&auth_plugin_data, password);
// the <HandshakeResponse> contains an initial guess at the correct encoding of
// the password and some other metadata like "which database", "which user", etc.
self_.write_packet(&HandshakeResponse {
auth_plugin_name: auth_plugin.name(),
auth_response: initial_auth_response,
charset: 45, // [utf8mb4]
database: $options.get_database(),
max_packet_size: 1024,
username: $options.get_username(),
})?;
loop {
match connect!($(@$blocking)? @packet self_) {
Auth::Ok(_) => {
// successful, simple authentication; good to go
break;
}
Auth::MoreData(data) => {
if let Some(data) = auth_plugin.handle(data, &auth_plugin_data, password)? {
// write the response from the plugin
self_.write_packet(&AuthResponse { data })?;
// let's try again
continue;
}
// all done, the plugin says we check out
break;
}
Auth::Switch(sw) => {
// switch to the new plugin
auth_plugin = sw.plugin;
auth_plugin_data = sw.plugin_data;
// generate an initial response from this plugin
let data = auth_plugin.invoke(&auth_plugin_data, password);
// write the response from the plugin
self_.write_packet(&AuthResponse { data })?;
// let's try again
continue;
}
}
}
Ok(self_)
}};
}
#[cfg(feature = "async")]
impl<Rt> MySqlConnection<Rt>
where
Rt: sqlx_core::AsyncRuntime,
<Rt as Runtime>::TcpStream: Unpin + futures_io::AsyncWrite + futures_io::AsyncRead,
{
pub(crate) async fn connect_async(options: &MySqlConnectOptions<Rt>) -> Result<Self> {
connect!(options)
}
}
#[cfg(feature = "blocking")]
impl<Rt> MySqlConnection<Rt>
where
Rt: sqlx_core::blocking::Runtime,
<Rt as Runtime>::TcpStream: std::io::Write + std::io::Read,
{
pub(crate) fn connect(options: &MySqlConnectOptions<Rt>) -> Result<Self> {
connect!(@blocking options)
}
}
#[cfg(all(test, feature = "async"))]
mod tests {
use futures_executor::block_on;
use sqlx_core::{ConnectOptions, Mock};
use crate::mock::MySqlMockStreamExt;
use crate::MySqlConnectOptions;
const SRV_HANDSHAKE_DEFAULT_OLD_AUTH: &[u8] = b"\n5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal\0)\0\0\04bo+$r4H\0\xfe\xf7-\x02\0\xff\x81\x15\0\0\0\0\0\0\x0f\0\0\0O5X>j}Ur]Y)^\0mysql_old_password\0";
const SRV_HANDSHAKE_DEFAULT_NATIVE_AUTH: &[u8] = b"\n5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal\0)\0\0\04bo+$r4H\0\xfe\xf7-\x02\0\xff\x81\x15\0\0\0\0\0\0\x0f\0\0\0O5X>j}Ur]Y)^\0mysql_native_password\0";
const SRV_HANDSHAKE_DEFAULT_CACHING_SHA2_AUTH: &[u8] = b"\n8.0.22\0\x08\0\0\0TIbl}%U#\0\xff\xff\xff\x02\0\xff\xc7\x15\0\0\0\0\0\0\0\0\0\0\x06\x12\x0e`5\x1b\x12\x0b\x13\x06_\x19\0caching_sha2_password\0";
const SRV_HANDSHAKE_DEFAULT_SHA256_AUTH: &[u8] = b"\n8.0.22\0\x0e\0\0\0\x1b\x02O\x04hL8D\0\xff\xff\xff\x02\0\xff\xc7\x15\0\0\0\0\0\0\0\0\0\0^*Nh\x19\x1f*)-\x0c\x07v\0sha256_password\0";
const SRV_PUBLIC_KEY: &[u8] = b"\x01-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwnXi3nr9TmN+NF49A3Y7\nUBnAVhApNJy2cmuf/y6vFM9eHFu5T80Ij1qYc6c79oAGA8nNNCFQL+0j5De88cln\nKrlzq/Ab3U+j5SqgNwk//F6Y3iyjV4L7feSDqjpcheFzkjEslbm/yoRwQ78AAU6s\nqA0hcFuh66mcvnotDrvZAGQ8U2EbbZa6oiR3wrgbzifSKq767g65zIrCpoyxzKMH\nAETSDIaMKpFio4dRATKT5ASQtPoIyxSBmjRtc22sqlhEeiejEMsJzd6Bliuait+A\nkTXL6G1Tbam26Dok/L88CnTAWAkLwTA3bjPcS8Zl9gTsJvoiMuwW1UPEVV/aJ11Z\n/wIDAQAB\n-----END PUBLIC KEY-----\n";
const SRV_AUTH_OK: &[u8] = b"\0\0\0\x02\0\0\0";
const SRV_AUTH_MORE_CONTINUE: &[u8] = b"\x01\x04";
const SRV_AUTH_MORE_OK: &[u8] = b"\x01\x03";
const SRV_SWITCH_CACHING_SHA2_AUTH: &[u8] =
b"\xfecaching_sha2_password\0\x12}Wz?0-M9sO*S\x03\nP\x1c]pe\0";
const SRV_SWITCH_NATIVE_AUTH: &[u8] =
b"\xfemysql_native_password\0\r.89j]CpA3Ov~\x1de\\/\x15,\r\0";
const RES_HANDSHAKE_NATIVE_AUTH: &[u8] = b"P\0\0\x01\x04\xa3(\x01\0\x04\0\0-\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0root\0\x14P\xaf\xf1\x12,\xe9\xad\xea\x7f\xa0\n\xcd\xa2\xb5<\x17\xa5\xc9J\xd0mysql_native_password\0";
const RES_HANDSHAKE_EMPTY_AUTH: &[u8] = b"<\0\0\x01\x04\xa3(\x01\0\x04\0\0-\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0root\0\0mysql_native_password\0";
const RES_HANDSHAKE_CACHING_SHA2_AUTH: &[u8] = b"\\\0\0\x01\x05\xa3(\x01\0\x04\0\0-\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0root\0 \x9d\x85T\x15\xfe\xa9u\x13\x02&\x9dlG\x17\x98\x1b`\x8a\x96\xfcI\x19\x17\xe0(I8\xba\xd7\xfax\xa9caching_sha2_password\0";
const RES_HANDSHAKE_SHA256_AUTH: &[u8] = b"7\0\0\x01\x05\xa3(\x01\0\x04\0\0-\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0root\0\x01\x01sha256_password\0";
const RES_ASK_RSA_KEY: &[u8] = b"\x01\0\0\x03\x02";
const RES_ASK_RSA_KEY_2: &[u8] = b"\x01\0\0\x05\x02";
const RES_RSA_PASSWORD_SHA256: &[u8] = b"\0\x01\0\x03\xc1*\xf5=\xc3\x86\x95U$=\x9c \x946_Rg\xdc\x9d\xa0M\xf2@\xba\xf7\x8f\rE\xdbrI\xac\x05\xfb\xd1\xaa\r0 '\xf2\xec\xb3Xu\x98\x82\xf2\x8d)\x80\xe7\xdcG\\\xde\x87\x0e\x07\x87f\xach\xbb\x0b\xdf\xe0\xd9\xd1N\x9f_\x17xT\xec\xd5\xff\xd3\xa35\x11PO\xca\xf2\x13?=n\xe7\xd5\xbb\xa0\xd0\xca\xc5\x80\xb0\0\xc0\xe9F\x90f\xa0a\xd1\xdb\xe4(\xed2\xd7@\xb8u\x859U\xd6\xa2\xc3\xa2\xbe\x9a\xeeSy\x92\x95\r\xd3\x14\x90\x80\xb1o#\xa6\x7f\x16\x7f\t-'\xf35\xa02zY\xaeP^e\xf9O\xed\x9d\xb5\x8b\x9d\x0cayA\xff\"-\x80\x8c<\xc4\x11e\xdf\x9c\xe2\x9b)\x8f\xb0\xe9\xe1\xbcj\xf9\xa0U\xe6\x95\x9b\x01 \xba\x7f\"\\\x0cF9\\'\xf2\xfcMD\x1a\xd8\xe3\x11\xdfN\xc4\xd3\x9e\xee\x8d\r\xda\x94\xc4\xafR\xf3\x1e8b\x8d$\x84Nj\x18~\xa7\xf1\x8bb&\x90\xc0\xad\xb1O\xec\xfa\x98h\xf0{.\x07R\n";
const RES_RSA_PASSWORD_CACHING_SHA2: &[u8] = b"\0\x01\0\x05#7\x8f\xd6\x8dCi9*\xee\x87\xb3\xb1,@\xdf\x94\xa8g\xbf\xed5\xf3\x1e\x9c\xfe\xda\xe8-6\x9c\x1eO\xb6\x80\x81]h\x0b\xd8\x10xx\xeb\x8b\xe9\x8a\x93\xd7\x83\xf7\x9a\xe1\xb94\xfd\xb0\x81\xeb\x0f\xecU:\xf4\x82\x11\xd3\xee\x8e+\x9e_rm\xb4\xbdM\xa0\x90\xff\xc3\x03V*\xa6|\x16\xdd\xea\xd2\x92\xef\xf5E\xb1t\n\xb7\xd9\x8bU\xbd\x94\xb8\x80|S+z\x1bO\x1e\xdf&\xf7(\xf0~\x97\x8b\xee1\xa4\xbb\x9f6\xc4\x88\xbf\x14$\xb2\xc0\xea\x9f\xdd\xfc\x99\xc8\xfe\x178\xf3X\x90\x01\xcc\xa8\x86\x9d\xe9\x98\xbf\xc2\xdc\xe8\xff\x96\xbd^\xf6\r \xb5\xe8\x0euo\xb5(\x80\xffW7\xf0\xdd\xcc\xaa\x9fYl\xef\xb7y\xf7A\xf4\xcf\x1f\xfc\rS\x7f\x13\xa9b\xadd\x1c\xcf\xf5\x98\x0ei\xc3\x0f\x9c\x8eqeTu\x8b\x17\xe7\xd47\xc5\xe9j=\xfc\x82\x04\x96}V.U?\x85\x14J\xe2\xd3.+:\xc5\xe0'm\x9a3\x85\x1e\xf7\xad\xf9J\xcf\xfc\xa7\xc2\x04@";
const RES_SHA_SCRAMBLE: &[u8] = b" \0\0\x03\xffjg\x06p\x1d\xeawto\xf3\xf6\xa0\x9f7\xa9Z\xb3\xa5\xf9\x0b\x80\x14j8WTb\xf1{f\xf5";
const RES_NATIVE_SCRAMBLE: &[u8] =
b"\x14\0\0\x031.Z\x95JON\x81\x9ak\xc7\xba\xe6{L\x0f\xe8\x03N\xef";
const RES_SWITCH_RSA_PASSWORD_CACHING_SHA2: &[u8] = b"\0\x01\0\x077fS:\x9d3\xec\xe47\xbe\xda\xd8a\x14\x7f\xa8\xa82\x15\xb3\xb8\xa4D\x8f\x8e,,\xc4\x7f\x9ck\x9cI2&\xc2a\xd4\xef\r\x04\xc2\xd1\x89\xb05\xab\xe2YL\xd2hz\xf6y\xb7\xcb\x08\x9a\x1d\xc0A\x7f\x97\xba*\x1e,c\xbcP\xab\xa2\xee\xfa\xcd^=\x1flj\x96\x8fGx\x8e\x9b\xfd\xea\xd05w\xcc\xf2\xfc\xf8\xb4Pm;\xc4\x94}A~=R\xbcr\xbb?\xd1]\r\xb1\xd9{\xf6\x1b%\x14iAe\x04a\x91\x144q\x1e\x92H\xcb\xe7z,+1!6#\x92\x8c\x12o\x8eyb\xe7g\xd2[\x11W\xfeJ\xe3.\x88C\x1a$\xa5\xfa\xfd\xe1\x1e\x0c4\xc5\xbf7\x94\xca$\x0c\xa6\xbc\x07d\x04\x0f\xe4\xfc\xbeZ\x1c7\xce\x0c^8@d; \xf9\xfe\x1dU\x15\x9e\x9f[b\xe6Z\xda\xa9\x17\xcf\xd9\xa8\x0b\x10\xf5\xe3\xa1\xc0\xe2Z\x8b\x9fq\xe9\xe8\x97f\x1bY\xec\xbc\x8b\x89\x9a\xeb\xffU\xe2\xfa#%\xa5d\xfa\xeb\x15\"\x8a\xf4R\x85\xdf\xe3\xcd";
#[test]
fn should_connect_default_native_auth() -> anyhow::Result<()> {
block_on(async {
let mut mock = Mock::stream();
mock.write_packet_async(0, SRV_HANDSHAKE_DEFAULT_NATIVE_AUTH).await?;
mock.write_packet_async(2, SRV_AUTH_OK).await?;
let _conn = MySqlConnectOptions::<Mock>::new()
.port(mock.port())
.username("root")
.password("password")
.connect()
.await?;
let buf = mock.read_all_async().await?;
assert_eq!(&buf, RES_HANDSHAKE_NATIVE_AUTH);
Ok(())
})
}
#[test]
fn should_connect_default_sha256_auth() -> anyhow::Result<()> {
block_on(async {
let mut mock = Mock::stream();
mock.write_packet_async(0, SRV_HANDSHAKE_DEFAULT_SHA256_AUTH).await?;
mock.write_packet_async(2, SRV_PUBLIC_KEY).await?;
mock.write_packet_async(4, SRV_AUTH_OK).await?;
let _conn = MySqlConnectOptions::<Mock>::new()
.port(mock.port())
.username("root")
.password("password")
.connect()
.await?;
let buf = mock.read_exact_async(RES_HANDSHAKE_SHA256_AUTH.len()).await?;
assert_eq!(&buf, RES_HANDSHAKE_SHA256_AUTH);
let buf = mock.read_all_async().await?;
assert_eq!(&buf, RES_RSA_PASSWORD_SHA256);
Ok(())
})
}
#[test]
fn should_connect_default_caching_sha2_auth() -> anyhow::Result<()> {
block_on(async {
let mut mock = Mock::stream();
mock.write_packet_async(0, SRV_HANDSHAKE_DEFAULT_CACHING_SHA2_AUTH).await?;
mock.write_packet_async(2, SRV_AUTH_MORE_CONTINUE).await?;
mock.write_packet_async(4, SRV_PUBLIC_KEY).await?;
mock.write_packet_async(6, SRV_AUTH_OK).await?;
let _conn = MySqlConnectOptions::<Mock>::new()
.port(mock.port())
.username("root")
.password("password")
.connect()
.await?;
let buf = mock.read_exact_async(RES_HANDSHAKE_CACHING_SHA2_AUTH.len()).await?;
assert_eq!(&buf, RES_HANDSHAKE_CACHING_SHA2_AUTH);
let buf = mock.read_exact_async(RES_ASK_RSA_KEY.len()).await?;
assert_eq!(&buf, RES_ASK_RSA_KEY);
let buf = mock.read_all_async().await?;
assert_eq!(&buf, RES_RSA_PASSWORD_CACHING_SHA2);
Ok(())
})
}
#[test]
fn should_reconnect_default_caching_sha2_auth() -> anyhow::Result<()> {
block_on(async {
let mut mock = Mock::stream();
mock.write_packet_async(0, SRV_HANDSHAKE_DEFAULT_CACHING_SHA2_AUTH).await?;
mock.write_packet_async(2, SRV_AUTH_MORE_OK).await?;
let _conn = MySqlConnectOptions::<Mock>::new()
.port(mock.port())
.username("root")
.password("password")
.connect()
.await?;
let buf = mock.read_all_async().await?;
assert_eq!(&buf, RES_HANDSHAKE_CACHING_SHA2_AUTH);
Ok(())
})
}
#[test]
fn should_connect_switch_native_auth() -> anyhow::Result<()> {
block_on(async {
let mut mock = Mock::stream();
mock.write_packet_async(0, SRV_HANDSHAKE_DEFAULT_CACHING_SHA2_AUTH).await?;
mock.write_packet_async(2, SRV_SWITCH_NATIVE_AUTH).await?;
mock.write_packet_async(4, SRV_AUTH_OK).await?;
let _conn = MySqlConnectOptions::<Mock>::new()
.port(mock.port())
.username("root")
.password("password")
.connect()
.await?;
let buf = mock.read_exact_async(RES_HANDSHAKE_CACHING_SHA2_AUTH.len()).await?;
assert_eq!(&buf, RES_HANDSHAKE_CACHING_SHA2_AUTH);
let buf = mock.read_all_async().await?;
assert_eq!(&buf, RES_NATIVE_SCRAMBLE);
Ok(())
})
}
#[test]
fn should_connect_switch_caching_sha2_auth() -> anyhow::Result<()> {
block_on(async {
let mut mock = Mock::stream();
mock.write_packet_async(0, SRV_HANDSHAKE_DEFAULT_NATIVE_AUTH).await?;
mock.write_packet_async(2, SRV_SWITCH_CACHING_SHA2_AUTH).await?;
mock.write_packet_async(4, SRV_AUTH_MORE_CONTINUE).await?;
mock.write_packet_async(6, SRV_PUBLIC_KEY).await?;
mock.write_packet_async(8, SRV_AUTH_OK).await?;
let _conn = MySqlConnectOptions::<Mock>::new()
.port(mock.port())
.username("root")
.password("password")
.connect()
.await?;
let buf = mock.read_exact_async(RES_HANDSHAKE_NATIVE_AUTH.len()).await?;
assert_eq!(&buf, RES_HANDSHAKE_NATIVE_AUTH);
let buf = mock.read_exact_async(RES_SHA_SCRAMBLE.len()).await?;
assert_eq!(&buf, RES_SHA_SCRAMBLE);
let buf = mock.read_exact_async(RES_ASK_RSA_KEY_2.len()).await?;
assert_eq!(&buf, RES_ASK_RSA_KEY_2);
let buf = mock.read_all_async().await?;
assert_eq!(&buf, RES_SWITCH_RSA_PASSWORD_CACHING_SHA2);
Ok(())
})
}
#[test]
fn should_reconnect_switch_caching_sha2_auth() -> anyhow::Result<()> {
block_on(async {
let mut mock = Mock::stream();
mock.write_packet_async(0, SRV_HANDSHAKE_DEFAULT_NATIVE_AUTH).await?;
mock.write_packet_async(2, SRV_SWITCH_CACHING_SHA2_AUTH).await?;
mock.write_packet_async(4, SRV_AUTH_MORE_OK).await?;
let _conn = MySqlConnectOptions::<Mock>::new()
.port(mock.port())
.username("root")
.password("password")
.connect()
.await?;
let buf = mock.read_exact_async(RES_HANDSHAKE_NATIVE_AUTH.len()).await?;
assert_eq!(&buf, RES_HANDSHAKE_NATIVE_AUTH);
let buf = mock.read_all_async().await?;
assert_eq!(&buf, RES_SHA_SCRAMBLE);
Ok(())
})
}
#[test]
fn should_connect_empty_auth() -> anyhow::Result<()> {
block_on(async {
let mut mock = Mock::stream();
mock.write_packet_async(0, SRV_HANDSHAKE_DEFAULT_NATIVE_AUTH).await?;
mock.write_packet_async(2, SRV_AUTH_OK).await?;
let _conn = MySqlConnectOptions::<Mock>::new()
.port(mock.port())
.username("root")
.connect()
.await?;
let buf = mock.read_all_async().await?;
assert_eq!(&buf, RES_HANDSHAKE_EMPTY_AUTH);
Ok(())
})
}
#[test]
fn should_not_connect_old_auth() -> anyhow::Result<()> {
block_on(async {
let mut mock = Mock::stream();
mock.write_packet_async(0, SRV_HANDSHAKE_DEFAULT_OLD_AUTH).await?;
let err = MySqlConnectOptions::<Mock>::new()
.port(mock.port())
.username("root")
.password("password")
.connect()
.await
.unwrap_err();
assert_eq!(
err.to_string(),
"2059 (HY000): Authentication plugin 'mysql_old_password' cannot be loaded"
);
Ok(())
})
}
}

View File

@ -1,139 +0,0 @@
use bytes::{buf::Chain, Buf, Bytes};
use futures_io::{AsyncRead, AsyncWrite};
use sqlx_core::io::{Deserialize, Serialize};
use sqlx_core::{AsyncRuntime, Error, Result, Runtime};
use crate::protocol::{Capabilities, ErrPacket, Handshake, HandshakeResponse, OkPacket};
use crate::{auth, MySqlConnectOptions, MySqlConnection, MySqlDatabaseError};
// https://dev.mysql.com/doc/internals/en/connection-phase.html
// the connection phase (establish) performs these tasks:
// - exchange the capabilities of client and server
// - setup SSL communication channel if requested
// - authenticate the client against the server
// the server may immediately send an ERR packet and finish the handshake
// or send a [InitialHandshake]
fn make_auth_response(
auth_plugin_name: Option<&str>,
username: Option<&str>,
password: Option<&str>,
nonce: &Chain<Bytes, Bytes>,
) -> Result<Option<Vec<u8>>> {
match (auth_plugin_name, password) {
// NOTE: for no authentication plugin, we assume mysql_native_password
// this means we have no support for mysql_old_password (pre mysql 4)
// if you need this, please open an issue
(Some("mysql_native_password"), Some(password)) | (None, Some(password)) => {
Ok(Some(auth::native::scramble(nonce, password)))
}
(_, None) => Ok(None),
// an unsupported plugin error looks like this in the official client:
// ERROR 2059 (HY000): Authentication plugin 'caching_sha2_password' cannot be loaded: /usr/local/mysql/lib/plugin/caching_sha2_password.so: cannot open shared object file: No such file or directory
// and renders like this in SQLx:
// Error: 2059 (HY000): Authentication plugin 'caching_sha2_password' cannot be loaded
(Some(plugin), _) => Err(Error::Connect(Box::new(MySqlDatabaseError(ErrPacket::new(
2059,
&format!("Authentication plugin '{}' cannot be loaded", plugin),
))))),
}
}
fn make_handshake_response<'a, Rt: Runtime>(
handshake: &'a Handshake,
options: &'a MySqlConnectOptions<Rt>,
) -> Result<HandshakeResponse<'a>> {
let auth_response = make_auth_response(
handshake.auth_plugin_name.as_deref(),
options.get_username(),
options.get_password(),
&handshake.auth_plugin_data,
)?;
Ok(HandshakeResponse {
auth_plugin_name: handshake.auth_plugin_name.as_deref(),
auth_response,
charset: 45, // [utf8mb4]
database: options.get_database(),
max_packet_size: 1024,
username: options.get_username(),
})
}
impl<Rt> MySqlConnection<Rt>
where
Rt: AsyncRuntime,
<Rt as Runtime>::TcpStream: Unpin + AsyncWrite + AsyncRead,
{
fn recv_handshake(&mut self, handshake: &Handshake) {
self.capabilities &= handshake.capabilities;
self.connection_id = handshake.connection_id;
}
pub(crate) async fn establish_async(options: &MySqlConnectOptions<Rt>) -> Result<Self> {
let stream = Rt::connect_tcp(options.get_host(), options.get_port()).await?;
let mut self_ = Self::new(stream);
let handshake = self_.read_packet_async().await?;
self_.recv_handshake(&handshake);
self_.write_packet(make_handshake_response(&handshake, options)?)?;
self_.stream.flush_async().await?;
let _ok: OkPacket = self_.read_packet_async().await?;
Ok(self_)
}
fn write_packet<'ser, T>(&'ser mut self, packet: T) -> Result<()>
where
T: Serialize<'ser, Capabilities>,
{
let mut wbuf = Vec::<u8>::with_capacity(1024);
packet.serialize_with(&mut wbuf, self.capabilities)?;
self.sequence_id = self.sequence_id.wrapping_add(1);
self.stream.reserve(wbuf.len() + 4);
self.stream.write(&(wbuf.len() as u32).to_le_bytes()[..3]);
self.stream.write(&[self.sequence_id]);
self.stream.write(&wbuf);
Ok(())
}
async fn read_packet_async<'de, T>(&'de mut self) -> Result<T>
where
T: Deserialize<'de, Capabilities>,
{
// https://dev.mysql.com/doc/internals/en/mysql-packet.html
self.stream.read_async(4).await?;
let payload_len: usize = self.stream.get(0, 3).get_uint_le(3) as usize;
// FIXME: handle split packets
assert_ne!(payload_len, 0xFF_FF_FF);
self.sequence_id = self.stream.get(3, 1).get_u8();
self.stream.read_async(4 + payload_len).await?;
self.stream.consume(4);
let payload = self.stream.take(payload_len);
if payload[0] == 0xff {
// if the first byte of the payload is 0xFF and the payload is an ERR packet
let err = ErrPacket::deserialize_with(payload, self.capabilities)?;
return Err(Error::Connect(Box::new(MySqlDatabaseError(err))));
}
T::deserialize_with(payload, self.capabilities)
}
}

View File

@ -0,0 +1,154 @@
//! Reads and writes packets to and from the MySQL database server.
//!
//! The logic for serializing data structures into the packets is found
//! mostly in `protocol/`.
//!
//! Packets in MySQL are prefixed by 4 bytes.
//! 3 for length (in LE) and a sequence id.
//!
//! Packets may only be as large as the communicated size in the initial
//! `HandshakeResponse`. By default, SQLx configures its chunk size to 16M. Sending
//! a larger payload is simply sending completely "full" packets, one after the
//! other, with an increasing sequence id.
//!
//! In other words, when we sent data, we:
//!
//! - Split the data into "packets" of size `2 ** 24 - 1` bytes.
//!
//! - Prepend each packet with a **packet header**, consisting of the length of that packet,
//! and the sequence number.
//!
//! https://dev.mysql.com/doc/internals/en/mysql-packet.html
//!
use bytes::{Buf, BufMut};
use sqlx_core::io::{Deserialize, Serialize};
use sqlx_core::{Error, Result, Runtime};
use crate::protocol::{Capabilities, ErrPacket};
use crate::{MySqlConnection, MySqlDatabaseError};
impl<Rt> MySqlConnection<Rt>
where
Rt: Runtime,
{
pub(super) fn write_packet<'ser, T>(&'ser mut self, packet: &T) -> Result<()>
where
T: Serialize<'ser, Capabilities>,
{
// the sequence-id is incremented with each packet and may
// wrap around. it starts at 0 and is reset to 0 when a new command
// begins in the Command Phase
self.sequence_id = self.sequence_id.wrapping_add(1);
// optimize for <16M packet sizes, in the case of >= 16M we would
// swap out the write buffer for a fresh buffer and then split it into
// 16M chunks separated by packet headers
let buf = self.stream.buffer();
let pos = buf.len();
// leave room for the length of the packet header at the start
buf.reserve(4);
buf.extend_from_slice(&[0_u8; 3]);
buf.push(self.sequence_id);
// serialize the passed packet structure directly into the write buffer
packet.serialize_with(buf, self.capabilities)?;
let payload_len = buf.len() - pos - 4;
// FIXME: handle split packets
assert!(payload_len < 0xFF_FF_FF);
// write back the length of the packet
#[allow(clippy::cast_possible_truncation)]
(&mut buf[pos..]).put_uint_le(payload_len as u64, 3);
Ok(())
}
fn recv_packet<'de, T>(&'de mut self, len: usize) -> Result<T>
where
T: Deserialize<'de, Capabilities>,
{
// FIXME: handle split packets
assert_ne!(len, 0xFF_FF_FF);
// We store the sequence id here. To respond to a packet, it should use a
// sequence id of n+1. It only "resets" at the start of a new command.
self.sequence_id = self.stream.get(3, 1).get_u8();
// tell the stream that we are done with the 4-byte header
self.stream.consume(4);
// and remove the remainder of the packet from the stream, the payload
let payload = self.stream.take(len);
if payload[0] == 0xff {
// if the first byte of the payload is 0xFF and the payload is an ERR packet
let err = ErrPacket::deserialize_with(payload, self.capabilities)?;
return Err(Error::connect(MySqlDatabaseError(err)));
}
T::deserialize_with(payload, self.capabilities)
}
}
macro_rules! read_packet {
($(@$blocking:ident)? $self:ident) => {{
// reads at least 4 bytes from the IO stream into the read buffer
read_packet!($(@$blocking)? @stream $self, 0, 4);
// the first 3 bytes will be the payload length of the packet (in LE)
// ALLOW: the max this len will be is 16M
#[allow(clippy::cast_possible_truncation)]
let payload_len: usize = $self.stream.get(0, 3).get_uint_le(3) as usize;
// read <payload_len> bytes _after_ the 4 byte packet header
// note that we have not yet told the stream we are done with any of
// these bytes yet. if this next read invocation were to never return (eg., the
// outer future was dropped), then the next time read_packet_async was called
// it will re-read the parsed-above packet header. Note that we have NOT
// mutated `self` _yet_. This is important.
read_packet!($(@$blocking)? @stream $self, 4, payload_len);
$self.recv_packet(payload_len)
}};
(@blocking @stream $self:ident, $offset:expr, $n:expr) => {
$self.stream.read($offset, $n)?;
};
(@stream $self:ident, $offset:expr, $n:expr) => {
$self.stream.read_async($offset, $n).await?;
};
}
#[cfg(feature = "async")]
impl<Rt> MySqlConnection<Rt>
where
Rt: sqlx_core::AsyncRuntime,
<Rt as Runtime>::TcpStream: Unpin + futures_io::AsyncWrite + futures_io::AsyncRead,
{
pub(super) async fn read_packet_async<'de, T>(&'de mut self) -> Result<T>
where
T: Deserialize<'de, Capabilities>,
{
read_packet!(self)
}
}
#[cfg(feature = "blocking")]
impl<Rt> MySqlConnection<Rt>
where
Rt: Runtime,
<Rt as Runtime>::TcpStream: std::io::Write + std::io::Read,
{
pub(super) fn read_packet<'de, T>(&'de mut self) -> Result<T>
where
T: Deserialize<'de, Capabilities>,
{
read_packet!(@blocking self)
}
}

View File

@ -10,6 +10,12 @@ use crate::protocol::ErrPacket;
#[derive(Debug)]
pub struct MySqlDatabaseError(pub(crate) ErrPacket);
impl MySqlDatabaseError {
pub(crate) fn new(code: u16, message: &str) -> Self {
Self(ErrPacket::new(code, message))
}
}
impl DatabaseError for MySqlDatabaseError {
fn message(&self) -> &str {
&self.0.error_message

View File

@ -1,6 +1,6 @@
use bytes::{Buf, Bytes};
use bytestring::ByteString;
use sqlx_core::io::BufExt;
use string::String;
// UNSAFE: _unchecked string methods
// intended for use when the protocol is *known* to always produce
@ -10,10 +10,10 @@ pub(crate) trait MySqlBufExt: BufExt {
fn get_uint_lenenc(&mut self) -> u64;
#[allow(unsafe_code)]
unsafe fn get_str_lenenc_unchecked(&mut self) -> String<Bytes>;
unsafe fn get_str_lenenc_unchecked(&mut self) -> ByteString;
#[allow(unsafe_code)]
unsafe fn get_str_eof_unchecked(&mut self) -> String<Bytes>;
unsafe fn get_str_eof_unchecked(&mut self) -> ByteString;
fn get_bytes_lenenc(&mut self) -> Bytes;
}
@ -38,14 +38,14 @@ impl MySqlBufExt for Bytes {
}
#[allow(unsafe_code)]
unsafe fn get_str_lenenc_unchecked(&mut self) -> String<Bytes> {
unsafe fn get_str_lenenc_unchecked(&mut self) -> ByteString {
let len = self.get_uint_lenenc() as usize;
self.get_str_unchecked(len)
}
#[allow(unsafe_code)]
unsafe fn get_str_eof_unchecked(&mut self) -> String<Bytes> {
unsafe fn get_str_eof_unchecked(&mut self) -> ByteString {
self.get_str_unchecked(self.len())
}

View File

@ -3,6 +3,7 @@
//! [MySQL]: https://www.mysql.com/
//!
#![cfg_attr(doc_cfg, feature(doc_cfg))]
#![cfg_attr(not(any(feature = "async", feature = "blocking")), allow(unused))]
#![deny(unsafe_code)]
#![warn(rust_2018_idioms)]
#![warn(future_incompatible)]
@ -21,16 +22,15 @@
mod connection;
mod database;
mod error;
mod io;
mod options;
mod protocol;
mod error;
mod auth;
#[cfg(feature = "blocking")]
mod blocking;
#[cfg(test)]
mod mock;
pub use connection::MySqlConnection;
pub use database::MySql;
pub use options::MySqlConnectOptions;
pub use error::MySqlDatabaseError;
pub use options::MySqlConnectOptions;

67
sqlx-mysql/src/mock.rs Normal file
View File

@ -0,0 +1,67 @@
use std::io;
use sqlx_core::mock::MockStream;
pub(crate) trait MySqlMockStreamExt {
#[cfg(feature = "async")]
fn write_packet_async<'x>(
&'x mut self,
seq: u8,
packet: &'x [u8],
) -> futures_util::future::BoxFuture<'x, io::Result<()>>;
#[cfg(feature = "async")]
fn read_exact_async(
&mut self,
n: usize,
) -> futures_util::future::BoxFuture<'_, io::Result<Vec<u8>>>;
#[cfg(feature = "async")]
fn read_all_async(&mut self) -> futures_util::future::BoxFuture<'_, io::Result<Vec<u8>>>;
}
impl MySqlMockStreamExt for MockStream {
#[cfg(feature = "async")]
fn write_packet_async<'x>(
&'x mut self,
seq: u8,
packet: &'x [u8],
) -> futures_util::future::BoxFuture<'x, io::Result<()>> {
use futures_util::AsyncWriteExt;
Box::pin(async move {
self.write_all(&packet.len().to_le_bytes()[..3]).await?;
self.write_all(&[seq]).await?;
self.write_all(packet).await
})
}
#[cfg(feature = "async")]
fn read_exact_async(
&mut self,
n: usize,
) -> futures_util::future::BoxFuture<'_, io::Result<Vec<u8>>> {
use futures_util::AsyncReadExt;
Box::pin(async move {
let mut buf = vec![0; n];
let read = self.read(&mut buf).await?;
buf.truncate(read);
Ok(buf)
})
}
#[cfg(feature = "async")]
fn read_all_async(&mut self) -> futures_util::future::BoxFuture<'_, io::Result<Vec<u8>>> {
use futures_util::AsyncReadExt;
Box::pin(async move {
let mut buf = vec![0; 1024];
let read = self.read(&mut buf).await?;
buf.truncate(read);
Ok(buf)
})
}
}

View File

@ -11,6 +11,8 @@ mod builder;
mod default;
mod parse;
// TODO: RSA Public Key (to avoid the key exchange for caching_sha2 and sha256 plugins)
/// Options which can be used to configure how a MySQL connection is opened.
///
/// A value of `MySqlConnectOptions` can be parsed from a connection URL,
@ -135,6 +137,20 @@ where
Rt: sqlx_core::AsyncRuntime,
<Rt as Runtime>::TcpStream: futures_io::AsyncRead + futures_io::AsyncWrite + Unpin,
{
futures_util::FutureExt::boxed(MySqlConnection::establish_async(self))
Box::pin(MySqlConnection::connect_async(self))
}
}
#[cfg(feature = "blocking")]
impl<Rt> sqlx_core::blocking::ConnectOptions<Rt> for MySqlConnectOptions<Rt>
where
Rt: sqlx_core::blocking::Runtime,
<Rt as Runtime>::TcpStream: std::io::Read + std::io::Write,
{
fn connect(&self) -> sqlx_core::Result<Self::Connection>
where
Self::Connection: Sized,
{
<MySqlConnection<Rt>>::connect(self)
}
}

View File

@ -1,5 +1,5 @@
use std::str::FromStr;
use std::borrow::Cow;
use std::str::FromStr;
use percent_encoding::percent_decode_str;
use sqlx_core::{Error, Runtime};

View File

@ -1,13 +1,19 @@
mod auth;
mod auth_plugin;
mod auth_switch;
mod capabilities;
mod err;
mod handshake;
mod handshake_response;
mod ok;
mod status;
mod err;
pub(crate) use err::ErrPacket;
pub(crate) use ok::OkPacket;
pub(crate) use auth::{Auth, AuthResponse};
pub(crate) use auth_plugin::AuthPlugin;
pub(crate) use auth_switch::AuthSwitch;
pub(crate) use capabilities::Capabilities;
pub(crate) use err::ErrPacket;
pub(crate) use handshake::Handshake;
pub(crate) use handshake_response::HandshakeResponse;
pub(crate) use ok::OkPacket;
pub(crate) use status::Status;

View File

@ -0,0 +1,46 @@
use std::fmt::Debug;
use bytes::Bytes;
use sqlx_core::io::{Deserialize, Serialize};
use sqlx_core::{Error, Result};
use crate::protocol::{AuthSwitch, Capabilities, OkPacket};
use crate::MySqlDatabaseError;
#[derive(Debug)]
pub(crate) enum Auth {
Ok(OkPacket),
MoreData(Bytes),
Switch(AuthSwitch),
}
impl Deserialize<'_, Capabilities> for Auth {
fn deserialize_with(buf: Bytes, capabilities: Capabilities) -> Result<Self> {
match buf[0] {
0x00 => OkPacket::deserialize_with(buf, capabilities).map(Self::Ok),
0x01 => Ok(Self::MoreData(buf.slice(1..))),
0xfe => AuthSwitch::deserialize_with(buf, capabilities).map(Self::Switch),
tag => Err(Error::connect(MySqlDatabaseError::new(
2027,
&format!(
"Malformed packet: Received 0x{:x} but expected one of: 0x0, 0x1, or 0xfe",
tag
),
))),
}
}
}
#[derive(Debug)]
pub(crate) struct AuthResponse {
pub(crate) data: Vec<u8>,
}
impl Serialize<'_, Capabilities> for AuthResponse {
fn serialize_with(&self, buf: &mut Vec<u8>, _context: Capabilities) -> Result<()> {
buf.extend_from_slice(&self.data);
Ok(())
}
}

View File

@ -0,0 +1,76 @@
use std::error::Error as StdError;
use std::fmt::Debug;
use std::str::FromStr;
use bytes::buf::Chain;
use bytes::Bytes;
use sqlx_core::{Error, Result};
use crate::MySqlDatabaseError;
mod caching_sha2;
mod native;
mod rsa;
mod sha256;
pub(crate) use self::caching_sha2::CachingSha2AuthPlugin;
pub(crate) use self::native::NativeAuthPlugin;
pub(crate) use self::sha256::Sha256AuthPlugin;
pub(crate) trait AuthPlugin: 'static + Debug + Send + Sync {
fn name(&self) -> &'static str;
// Invoke the auth plugin and return the auth response
fn invoke(&self, nonce: &Chain<Bytes, Bytes>, password: &str) -> Vec<u8>;
// Handle "more data" from the MySQL server
// which tells the plugin some plugin-specific information
// if the plugin returns Some(_) that is sent back to MySQL
fn handle(
&self,
data: Bytes,
nonce: &Chain<Bytes, Bytes>,
password: &str,
) -> Result<Option<Vec<u8>>>;
}
impl FromStr for Box<dyn AuthPlugin> {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
match s {
_ if s == CachingSha2AuthPlugin.name() => Ok(Box::new(CachingSha2AuthPlugin)),
_ if s == Sha256AuthPlugin.name() => Ok(Box::new(Sha256AuthPlugin)),
_ if s == NativeAuthPlugin.name() => Ok(Box::new(NativeAuthPlugin)),
_ => Err(Error::connect(MySqlDatabaseError::new(
2059,
&format!("Authentication plugin '{}' cannot be loaded", s),
))),
}
}
}
// XOR(x, y)
// If len(y) < len(x), wrap around inside y
fn xor_eq(x: &mut [u8], y: &[u8]) {
let y_len = y.len();
for i in 0..x.len() {
x[i] ^= y[i % y_len];
}
}
fn err_msg(plugin: &'static str, message: &str) -> Error {
Error::connect(MySqlDatabaseError::new(
2061,
&format!("Authentication plugin '{}' reported error: {}", plugin, message),
))
}
fn err<E>(plugin: &'static str, error: E) -> Error
where
E: StdError,
{
err_msg(plugin, &error.to_string())
}

View File

@ -0,0 +1,79 @@
use bytes::buf::Chain;
use bytes::Bytes;
use sha2::{Digest, Sha256};
use sqlx_core::Result;
/// Implements SHA-256 authentication but uses caching on the server-side for better performance.
/// After the first authentication, a fast path is used that doesn't involve the RSA key exchange.
///
/// https://dev.mysql.com/doc/refman/8.0/en/caching-sha2-pluggable-authentication.html
/// https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/
///
#[derive(Debug)]
pub(crate) struct CachingSha2AuthPlugin;
impl super::AuthPlugin for CachingSha2AuthPlugin {
fn name(&self) -> &'static str {
"caching_sha2_password"
}
fn invoke(&self, nonce: &Chain<Bytes, Bytes>, password: &str) -> Vec<u8> {
if password.is_empty() {
// empty password => no scramble
return vec![];
}
// SHA256( password ) ^ SHA256( nonce + SHA256( SHA256( password ) ) )
let mut hasher = Sha256::new();
hasher.update(password);
// SHA256( password )
let mut pw_sha2 = hasher.finalize_reset();
hasher.update(&pw_sha2);
// SHA256( SHA256( password ) )
let pw_sha2_sha2 = hasher.finalize_reset();
hasher.update(pw_sha2_sha2);
hasher.update(nonce.first_ref());
hasher.update(nonce.last_ref());
// SHA256( nonce + SHA256( SHA256( password ) ) )
let nonce_pw_sha1_sha1 = hasher.finalize();
super::xor_eq(&mut pw_sha2, &nonce_pw_sha1_sha1);
pw_sha2.to_vec()
}
fn handle(
&self,
data: Bytes,
nonce: &Chain<Bytes, Bytes>,
password: &str,
) -> Result<Option<Vec<u8>>> {
const AUTH_SUCCESS: u8 = 0x3;
const AUTH_CONTINUE: u8 = 0x4;
match data[0] {
// good to go, return nothing
AUTH_SUCCESS => Ok(None),
AUTH_CONTINUE => {
// ruh roh, we need to ask for the RSA public key, so we can
// encrypt our password directly and send it
return Ok(Some(vec![0x2_u8]));
}
_ => {
let rsa_pub_key = data;
let encrypted = super::rsa::encrypt(self.name(), &rsa_pub_key, password, nonce)?;
Ok(Some(encrypted))
}
}
}
}

View File

@ -0,0 +1,60 @@
use bytes::{buf::Chain, Bytes};
use sha1::{Digest, Sha1};
use sqlx_core::Result;
use super::xor_eq;
// https://mariadb.com/kb/en/connection/#mysql_native_password-plugin
// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
#[derive(Debug)]
pub(crate) struct NativeAuthPlugin;
impl super::AuthPlugin for NativeAuthPlugin {
fn name(&self) -> &'static str {
"mysql_native_password"
}
fn invoke(&self, nonce: &Chain<Bytes, Bytes>, password: &str) -> Vec<u8> {
if password.is_empty() {
// no password => empty scramble
return vec![];
}
// SHA1( password ) ^ SHA1( nonce + SHA1( SHA1( password ) ) )
let mut hasher = Sha1::new();
hasher.update(password);
// SHA1( password )
let mut pw_sha1 = hasher.finalize_reset();
hasher.update(&pw_sha1);
// SHA1( SHA1( password ) )
let pw_sha1_sha1 = hasher.finalize_reset();
hasher.update(nonce.first_ref());
hasher.update(&nonce.last_ref());
hasher.update(&pw_sha1_sha1);
// SHA1( seed + SHA1( SHA1( password ) ) )
let nonce_pw_sha1_sha1 = hasher.finalize();
xor_eq(&mut pw_sha1, &nonce_pw_sha1_sha1);
pw_sha1.to_vec()
}
fn handle(
&self,
_data: Bytes,
_nonce: &Chain<Bytes, Bytes>,
_password: &str,
) -> Result<Option<Vec<u8>>> {
// MySQL should not be returning any additional data for
// the native mysql auth plugin
unreachable!()
}
}

View File

@ -0,0 +1,70 @@
use std::str::from_utf8;
use bytes::buf::Chain;
use bytes::Bytes;
use rsa::{PaddingScheme, PublicKey, RSAPublicKey};
use sqlx_core::Result;
pub(crate) fn encrypt(
plugin: &'static str,
key: &[u8],
password: &str,
nonce: &Chain<Bytes, Bytes>,
) -> Result<Vec<u8>> {
// xor the password with the given nonce
let mut pass = to_asciz(password);
let (a, b) = (nonce.first_ref(), nonce.last_ref());
let mut nonce = Vec::with_capacity(a.len() + b.len());
nonce.extend_from_slice(&*a);
nonce.extend_from_slice(&*b);
super::xor_eq(&mut pass, &*nonce);
// client sends an RSA encrypted password
let pkey = parse_rsa_pub_key(plugin, key)?;
let padding = PaddingScheme::new_oaep::<sha1::Sha1>();
pkey.encrypt(&mut rng(), padding, &pass[..]).map_err(|err| super::err(plugin, err))
}
// https://docs.rs/rsa/0.3.0/rsa/struct.RSAPublicKey.html?search=#example-1
fn parse_rsa_pub_key(plugin: &'static str, key: &[u8]) -> Result<RSAPublicKey> {
let key = from_utf8(key).map_err(|err| super::err(plugin, err))?;
// Takes advantage of the knowledge that we know
// we are receiving a PKCS#8 RSA Public Key at all
// times from MySQL
let encoded =
key.lines().filter(|line| !line.starts_with("-")).fold(String::new(), |mut data, line| {
data.push_str(&line);
data
});
let der = base64::decode(&encoded).map_err(|err| super::err(plugin, err))?;
RSAPublicKey::from_pkcs8(&der).map_err(|err| super::err(plugin, err))
}
fn to_asciz(s: &str) -> Vec<u8> {
let mut z = String::with_capacity(s.len() + 1);
z.push_str(s);
z.push('\0');
z.into_bytes()
}
// use a stable stream of numbers for encryption
// during tests to assert the result of [encrypt]
#[cfg(not(test))]
fn rng() -> rand::rngs::ThreadRng {
rand::thread_rng()
}
#[cfg(test)]
fn rng() -> rand::rngs::mock::StepRng {
rand::rngs::mock::StepRng::new(0, 1)
}

View File

@ -0,0 +1,42 @@
use bytes::buf::Chain;
use bytes::Bytes;
use sqlx_core::Result;
/// Implements SHA-256 authentication.
///
/// Each time we connect we have to do an RSA key exchange.
/// This slows down auth quite a bit.
///
/// https://dev.mysql.com/doc/refman/8.0/en/sha256-pluggable-authentication.html
/// https://mariadb.com/kb/en/sha256_password-plugin/
///
#[derive(Debug)]
pub(crate) struct Sha256AuthPlugin;
impl super::AuthPlugin for Sha256AuthPlugin {
fn name(&self) -> &'static str {
"sha256_password"
}
fn invoke(&self, _nonce: &Chain<Bytes, Bytes>, password: &str) -> Vec<u8> {
if password.is_empty() {
// no password => do not ask for RSA key
return vec![];
}
// ask for the RSA key
vec![0x01]
}
fn handle(
&self,
data: Bytes,
nonce: &Chain<Bytes, Bytes>,
password: &str,
) -> Result<Option<Vec<u8>>> {
let rsa_pub_key = data;
let encrypted = super::rsa::encrypt(self.name(), &rsa_pub_key, password, nonce)?;
Ok(Some(encrypted))
}
}

View File

@ -0,0 +1,39 @@
use std::str::FromStr;
use bytes::{buf::Chain, Buf, Bytes};
use sqlx_core::io::{BufExt, Deserialize};
use sqlx_core::Result;
use super::Capabilities;
use crate::protocol::AuthPlugin;
// https://dev.mysql.com/doc/internals/en/authentication-method-change.html
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
#[derive(Debug)]
pub(crate) struct AuthSwitch {
pub(crate) plugin: Box<dyn AuthPlugin>,
pub(crate) plugin_data: Chain<Bytes, Bytes>,
}
impl Deserialize<'_, Capabilities> for AuthSwitch {
fn deserialize_with(mut buf: Bytes, _capabilities: Capabilities) -> Result<Self> {
let tag = buf.get_u8();
debug_assert_eq!(tag, 0xfe);
// SAFE: auth plugins are ASCII only
#[allow(unsafe_code)]
let name = unsafe { buf.get_str_nul_unchecked()? };
if buf.ends_with(&[0]) {
// if this terminates in a NUL; drop the NUL
buf.truncate(buf.len() - 1);
}
let plugin_data = buf.chain(Bytes::new());
let plugin = <Box<dyn AuthPlugin>>::from_str(&*name)?;
Ok(Self { plugin, plugin_data })
}
}

View File

@ -1,7 +1,7 @@
use bytes::{Buf, Bytes};
use bytestring::ByteString;
use sqlx_core::io::{BufExt, Deserialize};
use sqlx_core::Result;
use string::String;
use crate::io::MySqlBufExt;
use crate::protocol::Capabilities;
@ -14,21 +14,14 @@ use crate::protocol::Capabilities;
#[derive(Debug)]
pub(crate) struct ErrPacket {
pub(crate) error_code: u16,
pub(crate) sql_state: Option<String<Bytes>>,
pub(crate) error_message: String<Bytes>,
pub(crate) sql_state: Option<ByteString>,
pub(crate) error_message: ByteString,
}
impl ErrPacket {
pub(crate) fn new(code: u16, message: &str) -> Self {
let message_bytes = Bytes::copy_from_slice(message.as_bytes());
let state_bytes = Bytes::from_static(b"HY000");
// UNSAFE: the UTF-8 string is converted to bytes right above. The string crate has a
// safe method for creation from Rust str but it pulls in an old version of Bytes
#[allow(unsafe_code)]
let (message, state) = unsafe {
(String::from_utf8_unchecked(message_bytes), String::from_utf8_unchecked(state_bytes))
};
let message = ByteString::from(message);
let state = ByteString::from_static("HY000");
Self { error_code: code, sql_state: Some(state), error_message: message }
}

View File

@ -1,10 +1,14 @@
use std::str::FromStr;
use bytes::buf::Chain;
use bytes::{Buf, Bytes};
use bytestring::ByteString;
use memchr::memchr;
use sqlx_core::io::{BufExt, Deserialize};
use sqlx_core::Result;
use crate::protocol::{Capabilities, Status};
use crate::protocol::auth_plugin::NativeAuthPlugin;
use crate::protocol::{AuthPlugin, Capabilities, Status};
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeV10
// https://mariadb.com/kb/en/connection/#initial-handshake-packet
@ -15,7 +19,7 @@ pub(crate) struct Handshake {
pub(crate) protocol_version: u8,
// human-readable server version
pub(crate) server_version: string::String<Bytes>,
pub(crate) server_version: ByteString,
pub(crate) connection_id: u32,
@ -25,10 +29,8 @@ pub(crate) struct Handshake {
// default server character set
pub(crate) charset: Option<u8>,
pub(crate) auth_plugin: Box<dyn AuthPlugin>,
pub(crate) auth_plugin_data: Chain<Bytes, Bytes>,
// name of the auth_method that the auth_plugin_data belongs to
pub(crate) auth_plugin_name: Option<string::String<Bytes>>,
}
impl Deserialize<'_, Capabilities> for Handshake {
@ -88,6 +90,11 @@ impl Deserialize<'_, Capabilities> for Handshake {
auth_plugin_data_2 = buf.split_to(len as usize);
if auth_plugin_data_2.ends_with(&[0]) {
// if this terminates in a NUL; drop the NUL
auth_plugin_data_2.truncate(auth_plugin_data_2.len() - 1);
}
if capabilities.contains(Capabilities::PLUGIN_AUTH) {
// due to Bug#59453 the auth-plugin-name is missing the terminating NUL-char
// in versions prior to 5.5.10 and 5.6.2
@ -96,8 +103,7 @@ impl Deserialize<'_, Capabilities> for Handshake {
// read to NUL or read to the end if we can't find a NUL
let auth_plugin_name_end =
memchr(b'\0', &buf).unwrap_or(buf.len());
let auth_plugin_name_end = memchr(b'\0', &buf).unwrap_or(buf.len());
// UNSAFE: auth plugin names are known to be ASCII
#[allow(unsafe_code)]
@ -116,7 +122,10 @@ impl Deserialize<'_, Capabilities> for Handshake {
capabilities,
status,
auth_plugin_data: auth_plugin_data_1.chain(auth_plugin_data_2),
auth_plugin_name,
auth_plugin: auth_plugin_name
.map(|name| <Box<dyn AuthPlugin>>::from_str(&name))
.transpose()?
.unwrap_or_else(|| Box::new(NativeAuthPlugin)),
})
}
}
@ -167,11 +176,11 @@ mod tests {
assert_eq!(h.charset, Some(255));
assert_eq!(h.status, Status::AUTOCOMMIT);
assert_eq!(h.auth_plugin_name.as_deref(), Some("caching_sha2_password"));
assert_eq!(h.auth_plugin.name(), "caching_sha2_password");
assert_eq!(
&*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()),
&[17, 52, 97, 66, 48, 99, 6, 103, 116, 76, 3, 115, 15, 91, 52, 13, 108, 52, 46, 32, 0]
&[17, 52, 97, 66, 48, 99, 6, 103, 116, 76, 3, 115, 15, 91, 52, 13, 108, 52, 46, 32]
);
}
@ -211,14 +220,11 @@ mod tests {
assert_eq!(h.charset, Some(8));
assert_eq!(h.status, Status::AUTOCOMMIT);
assert_eq!(h.auth_plugin_name.as_deref(), Some("mysql_native_password"));
assert_eq!(h.auth_plugin.name(), "mysql_native_password");
assert_eq!(
&*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()),
&[
116, 54, 76, 92, 106, 34, 100, 83, 85, 49, 52, 79, 112, 104, 57, 34, 60, 72, 53,
110, 0
]
&[116, 54, 76, 92, 106, 34, 100, 83, 85, 49, 52, 79, 112, 104, 57, 34, 60, 72, 53, 110]
);
}
@ -258,14 +264,11 @@ mod tests {
assert_eq!(h.charset, Some(45));
assert_eq!(h.status, Status::AUTOCOMMIT);
assert_eq!(h.auth_plugin_name.as_deref(), Some("mysql_native_password"));
assert_eq!(h.auth_plugin.name(), "mysql_native_password");
assert_eq!(
&*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()),
&[
39, 80, 66, 57, 52, 57, 99, 102, 85, 89, 62, 104, 114, 38, 96, 51, 123, 53, 53, 72,
0
]
&[39, 80, 66, 57, 52, 57, 99, 102, 85, 89, 62, 104, 114, 38, 96, 51, 123, 53, 53, 72,]
);
}
@ -305,11 +308,11 @@ mod tests {
assert_eq!(h.charset, Some(8));
assert_eq!(h.status, Status::AUTOCOMMIT);
assert_eq!(h.auth_plugin_name.as_deref(), Some("mysql_native_password"));
assert_eq!(h.auth_plugin.name(), "mysql_native_password");
assert_eq!(
&*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()),
&[45, 86, 76, 89, 90, 58, 80, 100, 39, 50, 102, 43, 66, 76, 56, 110, 71, 86, 91, 71, 0]
&[45, 86, 76, 89, 90, 58, 80, 100, 39, 50, 102, 43, 66, 76, 56, 110, 71, 86, 91, 71]
);
}
@ -334,13 +337,13 @@ mod tests {
assert_eq!(h.charset, Some(8));
assert_eq!(h.status, Status::AUTOCOMMIT);
assert_eq!(h.auth_plugin_name, None);
assert_eq!(h.auth_plugin.name(), "mysql_native_password");
assert_eq!(
&*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()),
&[
98, 115, 61, 115, 78, 105, 71, 101, 73, 122, 77, 80, 41, 121, 76, 76, 120, 59, 91,
57, 0
57
]
);
}
@ -373,14 +376,11 @@ mod tests {
assert_eq!(h.charset, Some(8));
assert_eq!(h.status, Status::AUTOCOMMIT);
assert_eq!(h.auth_plugin_name, None);
assert_eq!(h.auth_plugin.name(), "mysql_native_password");
assert_eq!(
&*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()),
&[
60, 102, 108, 108, 90, 92, 66, 115, 60, 113, 69, 67, 95, 56, 55, 74, 79, 47, 57,
113, 0
]
&[60, 102, 108, 108, 90, 92, 66, 115, 60, 113, 69, 67, 95, 56, 55, 74, 79, 47, 57, 113]
);
}
@ -416,13 +416,13 @@ mod tests {
assert_eq!(h.charset, Some(8));
assert_eq!(h.status, Status::AUTOCOMMIT);
assert_eq!(h.auth_plugin_name.as_deref(), Some("mysql_native_password"));
assert_eq!(h.auth_plugin.name(), "mysql_native_password");
assert_eq!(
&*h.auth_plugin_data.copy_to_bytes(h.auth_plugin_data.remaining()),
&[
96, 111, 45, 47, 67, 69, 112, 39, 107, 102, 64, 74, 53, 106, 54, 110, 74, 102, 65,
80, 0
80
]
);
}

View File

@ -1,4 +1,3 @@
use bytes::BufMut;
use sqlx_core::io::{Serialize, WriteExt};
use sqlx_core::Result;
@ -14,8 +13,8 @@ pub(crate) struct HandshakeResponse<'a> {
pub(crate) max_packet_size: u32,
pub(crate) charset: u8,
pub(crate) username: Option<&'a str>,
pub(crate) auth_plugin_name: Option<&'a str>,
pub(crate) auth_response: Option<Vec<u8>>,
pub(crate) auth_plugin_name: &'a str,
pub(crate) auth_response: Vec<u8>,
}
impl Serialize<'_, Capabilities> for HandshakeResponse<'_> {
@ -29,7 +28,7 @@ impl Serialize<'_, Capabilities> for HandshakeResponse<'_> {
buf.write_maybe_str_nul(self.username);
let auth_response = self.auth_response.as_deref().unwrap_or_default();
let auth_response = self.auth_response.as_slice();
if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) {
buf.write_bytes_lenenc(auth_response);
@ -50,7 +49,7 @@ impl Serialize<'_, Capabilities> for HandshakeResponse<'_> {
}
if capabilities.contains(Capabilities::PLUGIN_AUTH) {
buf.write_maybe_str_nul(self.auth_plugin_name);
buf.write_str_nul(self.auth_plugin_name);
}
Ok(())

View File

@ -42,7 +42,7 @@ impl Deserialize<'_, Capabilities> for OkPacket {
#[cfg(test)]
mod tests {
use super::{OkPacket, Capabilities, Deserialize, Status};
use super::{Capabilities, Deserialize, OkPacket, Status};
#[test]
fn test_empty_ok_packet() {