diff --git a/Cargo.lock b/Cargo.lock index 7f1d5749..61491748 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -312,6 +312,11 @@ dependencies = [ "version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "fake-simd" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "fnv" version = "1.0.6" @@ -760,6 +765,16 @@ dependencies = [ "winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "num-bigint" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)", + "num-integer 0.1.41 (registry+https://github.com/rust-lang/crates.io-index)", + "num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "num-integer" version = "0.1.41" @@ -1022,6 +1037,28 @@ dependencies = [ "serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "sha-1" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "block-buffer 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)", + "digest 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)", + "fake-simd 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", + "opaque-debug 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "sha2" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "block-buffer 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)", + "digest 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)", + "fake-simd 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", + "opaque-debug 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "slab" version = "0.4.2" @@ -1064,15 +1101,22 @@ version = "0.1.1" dependencies = [ "async-std 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "async-stream 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", + "base64 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)", "bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)", "byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)", "chrono 0.4.10 (registry+https://github.com/rust-lang/crates.io-index)", + "digest 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)", "futures-core 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", "futures-util 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", + "generic-array 0.12.3 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)", "matches 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)", "md-5 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", "memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)", + "num-bigint 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", + "rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", + "sha-1 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)", + "sha2 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", "url 2.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "uuid 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)", ] @@ -1425,6 +1469,7 @@ dependencies = [ "checksum dotenv 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)" = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f" "checksum either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3" "checksum error-chain 0.12.1 (registry+https://github.com/rust-lang/crates.io-index)" = "3ab49e9dcb602294bc42f9a7dfc9bc6e936fca4418ea300dbfb84fe16de0b7d9" +"checksum fake-simd 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "e88a8acf291dafb59c2d96e8f59828f3838bb1a70398823ade51a84de6a6deed" "checksum fnv 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)" = "2fad85553e09a6f881f739c29f0b00b0f01357c743266d478b68951ce23285f3" "checksum fuchsia-zircon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2e9763c69ebaae630ba35f74888db465e49e259ba1bc0eda7d06f4a067615d82" "checksum fuchsia-zircon-sys 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7" @@ -1475,6 +1520,7 @@ dependencies = [ "checksum mio-uds 0.6.7 (registry+https://github.com/rust-lang/crates.io-index)" = "966257a94e196b11bb43aca423754d87429960a768de9414f3691d6957abf125" "checksum miow 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "8c1f2f3b1cf331de6896aabf6e9d55dca90356cc9960cca7eaaf408a355ae919" "checksum net2 0.2.33 (registry+https://github.com/rust-lang/crates.io-index)" = "42550d9fb7b6684a6d404d9fa7250c2eb2646df731d1c06afc06dcee9e1bcf88" +"checksum num-bigint 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "f9c3f34cdd24f334cb265d9bf8bfa8a241920d026916785747a92f0e55541a1a" "checksum num-integer 0.1.41 (registry+https://github.com/rust-lang/crates.io-index)" = "b85e541ef8255f6cf42bbfe4ef361305c6c135d10919ecc26126c4e5ae94bc09" "checksum num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)" = "d4c81ffc11c212fa327657cb19dd85eb7419e163b5b076bede2bdb5c974c07e4" "checksum num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "76dac5ed2a876980778b8b85f75a71b6cbf0db0b1232ee12f826bccb00d09d72" @@ -1509,6 +1555,8 @@ dependencies = [ "checksum serde_derive 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)" = "128f9e303a5a29922045a830221b8f78ec74a5f544944f3d5984f8ec3895ef64" "checksum serde_json 1.0.44 (registry+https://github.com/rust-lang/crates.io-index)" = "48c575e0cc52bdd09b47f330f646cf59afc586e9c4e3ccd6fc1f625b8ea1dad7" "checksum serde_qs 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "d43eef44996bbe16e99ac720e1577eefa16f7b76b5172165c98ced20ae9903e1" +"checksum sha-1 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)" = "23962131a91661d643c98940b20fcaffe62d776a823247be80a48fcb8b6fce68" +"checksum sha2 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7b4d8bfd0e469f417657573d8451fb33d16cfe0989359b93baf3a1ffc639543d" "checksum slab 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c111b5bd5695e56cffe5129854aa230b39c93a305372fdbb2668ca2394eea9f8" "checksum smallvec 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)" = "f7b0758c52e15a8b5e3691eae6cc559f08eee9406e548a4477ba4e67770a82b6" "checksum smallvec 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "44e59e0c9fa00817912ae6e4e6e3c4fe04455e75699d06eedc7d85917ed8e8f4" diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index a13731a0..f95e3ad8 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -15,22 +15,29 @@ authors = [ [features] default = [] unstable = [] -postgres = [] -mysql = [] +postgres = [ "md-5" ] +mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ] [dependencies] -async-stream = { version = "0.2.0", default-features = false } async-std = { version = "1.4.0", default-features = false, features = [ "unstable" ] } +async-stream = { version = "0.2.0", default-features = false } +base64 = { version = "0.11.0", default-features = false, optional = true, features = [ "std" ] } bitflags = { version = "1.2.1", default-features = false } +byteorder = { version = "1.3.2", default-features = false } +chrono = { version = "0.4.10", default-features = false, features = [ "clock" ], optional = true } +digest = { version = "0.8.1", default-features = false, optional = true, features = [ "std" ] } futures-core = { version = "0.3.1", default-features = false } futures-util = { version = "0.3.1", default-features = false } +generic-array = { version = "0.12.3", default-features = false, optional = true } log = { version = "0.4.8", default-features = false } -url = { version = "2.1.0", default-features = false } -byteorder = { version = "1.3.2", default-features = false } +md-5 = { version = "0.8.0", default-features = false, optional = true } memchr = { version = "2.2.1", default-features = false } -md-5 = { version = "0.8.0", default-features = false } +num-bigint = { version = "0.2.3", default-features = false, optional = true, features = [ "std" ] } +rand = { version = "0.7.2", default-features = false, optional = true, features = [ "std" ] } +sha-1 = { version = "0.8.1", default-features = false, optional = true } +sha2 = { version = "0.8.0", default-features = false, optional = true } +url = { version = "2.1.0", default-features = false } uuid = { version = "0.8.1", default-features = false, optional = true } -chrono = { version = "0.4.10", default-features = false, features = [ "clock" ], optional = true } [dev-dependencies] matches = "0.1.8" diff --git a/sqlx-core/src/mysql/connection.rs b/sqlx-core/src/mysql/connection.rs index 0592cd1f..7c3e7e9a 100644 --- a/sqlx-core/src/mysql/connection.rs +++ b/sqlx-core/src/mysql/connection.rs @@ -4,6 +4,8 @@ use std::io; use async_std::net::{Shutdown, TcpStream}; use byteorder::{ByteOrder, LittleEndian}; use futures_core::future::BoxFuture; +use sha1::Sha1; +use sha2::{Digest, Sha256}; use crate::cache::StatementCache; use crate::connection::Connection; @@ -11,10 +13,18 @@ use crate::executor::Executor; use crate::io::{Buf, BufMut, BufStream}; use crate::mysql::error::MySqlError; use crate::mysql::protocol::{ - Capabilities, Decode, Encode, EofPacket, ErrPacket, Handshake, HandshakeResponse, OkPacket, + AuthPlugin, AuthSwitch, Capabilities, Decode, Encode, EofPacket, ErrPacket, Handshake, + HandshakeResponse, OkPacket, }; +use crate::mysql::rsa; +use crate::mysql::util::xor_eq; use crate::url::Url; +// Size before a packet is split +const MAX_PACKET_SIZE: u32 = 1024; + +const COLLATE_UTF8MB4_UNICODE_CI: u8 = 224; + /// An asynchronous connection to a [MySql] database. /// /// The connection string expected by [Connection::open] should be a MySQL connection @@ -23,25 +33,27 @@ use crate::url::Url; pub struct MySqlConnection { pub(super) stream: BufStream, + // Active capabilities of the client _&_ the server pub(super) capabilities: Capabilities, + // Cache of prepared statements + // Query (String) to StatementId to ColumnMap pub(super) statement_cache: StatementCache, - rbuf: Vec, + // Packets are buffered into a second buffer from the stream + // as we may have compressed or split packets to figure out before + // decoding + pub(super) packet: Vec, + packet_len: usize, - next_seq_no: u8, - - pub(super) ready: bool, + // Packets in a command sequence have an incrementing sequence number + // This number must be 0 at the start of each command + pub(super) next_seq_no: u8, } impl MySqlConnection { - pub(super) fn begin_command_phase(&mut self) { - // At the start of the *command phase*, the sequence ID sent from the client - // must be 0 - self.next_seq_no = 0; - } - - pub(super) fn write(&mut self, packet: impl Encode + std::fmt::Debug) { + /// Write the packet to the stream ( do not send to the server ) + pub(crate) fn write(&mut self, packet: impl Encode) { let buf = self.stream.buffer_mut(); // Allocate room for the header that we write after the packet; @@ -66,51 +78,42 @@ impl MySqlConnection { self.next_seq_no = self.next_seq_no.wrapping_add(1); } - async fn receive_ok(&mut self) -> crate::Result { - let packet = self.receive().await?; - Ok(match packet[0] { - 0xfe | 0x00 => OkPacket::decode(packet)?, - - 0xff => { - return Err(MySqlError(ErrPacket::decode(packet)?).into()); - } - - id => { - return Err(protocol_err!( - "unexpected packet identifier 0x{:X?} when expecting 0xFE (OK) or 0xFF \ - (ERR)", - id - ) - .into()); - } - }) - } - - pub(super) async fn receive_eof(&mut self) -> crate::Result<()> { - // When (legacy) EOFs are enabled, the fixed number column definitions are further - // terminated by an EOF packet - if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) { - let _eof = EofPacket::decode(self.receive().await?)?; - } + /// Send the packet to the database server + pub(crate) async fn send(&mut self, packet: impl Encode) -> crate::Result<()> { + self.write(packet); + self.stream.flush().await?; Ok(()) } - pub(super) async fn receive(&mut self) -> crate::Result<&[u8]> { - Ok(self - .try_receive() - .await? - .ok_or(io::ErrorKind::UnexpectedEof)?) + /// Send a [HandshakeResponse] packet to the database server + pub(crate) async fn send_handshake_response( + &mut self, + url: &Url, + auth_plugin: &AuthPlugin, + auth_response: &[u8], + ) -> crate::Result<()> { + self.send(HandshakeResponse { + client_collation: COLLATE_UTF8MB4_UNICODE_CI, + max_packet_size: MAX_PACKET_SIZE, + username: url.username().unwrap_or("root"), + database: url.database(), + auth_plugin, + auth_response, + }) + .await } - pub(super) async fn try_receive(&mut self) -> crate::Result> { - self.rbuf.clear(); + /// Try to receive a packet from the database server. Returns `None` if the server has sent + /// no data. + pub(crate) async fn try_receive(&mut self) -> crate::Result> { + self.packet.clear(); // Read the packet header which contains the length and the sequence number // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html // https://mariadb.com/kb/en/library/0-packet/#standard-packet let mut header = ret_if_none!(self.stream.peek(4).await?); - let payload_len = header.get_uint::(3)? as usize; + self.packet_len = header.get_uint::(3)? as usize; self.next_seq_no = header.get_u8()?.wrapping_add(1); self.stream.consume(4); @@ -118,66 +121,221 @@ impl MySqlConnection { // We must have a separate buffer around the stream as we can't operate directly // on bytes returned from the stream. We have various kinds of payload manipulation // that must be handled before decoding. - let mut payload = ret_if_none!(self.stream.peek(payload_len).await?); - self.rbuf.extend_from_slice(payload); - self.stream.consume(payload_len); + let mut payload = ret_if_none!(self.stream.peek(self.packet_len).await?); + self.packet.extend_from_slice(payload); + self.stream.consume(self.packet_len); // TODO: Implement packet compression // TODO: Implement packet joining - Ok(Some(&self.rbuf[..payload_len])) + Ok(Some(())) } -} -impl MySqlConnection { - // TODO: Authentication ?! - pub(super) async fn open(url: crate::Result) -> crate::Result { - let url = url?; - let stream = TcpStream::connect((url.host(), url.port(3306))).await?; + /// Receive a complete packet from the database server. + pub(crate) async fn receive(&mut self) -> crate::Result<&mut Self> { + self.try_receive() + .await? + .ok_or(io::ErrorKind::UnexpectedEof)?; - let mut self_ = Self { - stream: BufStream::new(stream), - capabilities: Capabilities::empty(), - rbuf: Vec::with_capacity(8192), - next_seq_no: 0, - statement_cache: StatementCache::new(), - ready: true, - }; + Ok(self) + } - // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase.html - // https://mariadb.com/kb/en/connection/ + /// Returns a reference to the most recently received packet data + #[inline] + pub(crate) fn packet(&self) -> &[u8] { + &self.packet[..self.packet_len] + } - // First, we receive the Handshake + /// Receive an [EofPacket] if we are supposed to receive them at all. + pub(crate) async fn receive_eof(&mut self) -> crate::Result<()> { + // When (legacy) EOFs are enabled, many things are terminated by an EOF packet + if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) { + let _eof = EofPacket::decode(self.receive().await?.packet())?; + } - let handshake_packet = self_.receive().await?; - let handshake = Handshake::decode(handshake_packet)?; + Ok(()) + } - let mut client_capabilities = - Capabilities::PROTOCOL_41 | Capabilities::IGNORE_SPACE | Capabilities::FOUND_ROWS; + /// Receive a [Handshake] packet. When connecting to the database server, this is immediately + /// received from the database server. + pub(crate) async fn receive_handshake(&mut self, url: &Url) -> crate::Result { + let handshake = Handshake::decode(self.receive().await?.packet())?; + + let mut client_capabilities = Capabilities::PROTOCOL_41 + | Capabilities::IGNORE_SPACE + | Capabilities::FOUND_ROWS + | Capabilities::PLUGIN_AUTH; if url.database().is_some() { client_capabilities |= Capabilities::CONNECT_WITH_DB; } - // Fails if [Capabilities::PROTOCOL_41] is not in [server_capabilities] - self_.capabilities = + self.capabilities = (client_capabilities & handshake.server_capabilities) | Capabilities::PROTOCOL_41; - // Next we send the response + Ok(handshake) + } - self_.write(HandshakeResponse { - client_collation: 192, // utf8_unicode_ci - max_packet_size: 1024, - username: url.username().unwrap_or("root"), - database: url.database(), - auth_plugin_name: handshake.auth_plugin_name.as_deref(), - auth_response: None, - }); + /// Receives an [OkPacket] from the database server. This is called at the end of + /// authentication to confirm the established connection. + pub(crate) fn receive_auth_ok<'a>( + &'a mut self, + plugin: &'a AuthPlugin, + password: &'a str, + nonce: &'a [u8], + ) -> BoxFuture<'a, crate::Result<()>> { + Box::pin(async move { + self.receive().await?; - self_.stream.flush().await?; + match self.packet[0] { + 0x00 => self.handle_ok().map(drop), + 0xfe => self.handle_auth_switch(password).await, + 0xff => self.handle_err(), - let _ok = self_.receive_ok().await?; + _ => self.handle_auth_continue(plugin, password, nonce).await, + } + }) + } +} +impl MySqlConnection { + pub(crate) fn handle_ok(&mut self) -> crate::Result { + let ok = OkPacket::decode(self.packet())?; + + // An OK signifies the end of the current command sequence + self.next_seq_no = 0; + + Ok(ok) + } + + pub(crate) fn handle_err(&mut self) -> crate::Result { + let err = ErrPacket::decode(self.packet())?; + + // An ERR signifies the end of the current command sequence + self.next_seq_no = 0; + + Err(MySqlError(err).into()) + } + + pub(crate) fn handle_unexpected_packet(&self, id: u8) -> crate::Result { + Err(protocol_err!("unexpected packet identifier 0x{:X?}", id).into()) + } + + pub(crate) async fn handle_auth_continue( + &mut self, + plugin: &AuthPlugin, + password: &str, + nonce: &[u8], + ) -> crate::Result<()> { + match plugin { + AuthPlugin::CachingSha2Password => { + if self.packet[0] == 1 { + match self.packet[1] { + // AUTH_OK + 0x03 => {} + + // AUTH_CONTINUE + 0x04 => { + // client sends an RSA encrypted password + let ct = self.rsa_encrypt(0x02, password, nonce).await?; + + self.send(&*ct).await?; + } + + auth => { + return Err(protocol_err!("unexpected result from 'fast' authentication 0x{:x} when expecting OK (0x03) or CONTINUE (0x04)", auth).into()); + } + } + + // ends with server sending either OK_Packet or ERR_Packet + self.receive_auth_ok(plugin, password, nonce) + .await + .map(drop) + } else { + return self.handle_unexpected_packet(self.packet[0]); + } + } + + // No other supported auth methods will be called through continue + _ => unreachable!(), + } + } + + pub(crate) async fn handle_auth_switch(&mut self, password: &str) -> crate::Result<()> { + let auth = AuthSwitch::decode(self.packet())?; + + let auth_response = self + .make_auth_initial_response(&auth.auth_plugin, password, &auth.auth_plugin_data) + .await?; + + self.send(&*auth_response).await?; + + self.receive_auth_ok(&auth.auth_plugin, password, &auth.auth_plugin_data) + .await + } + + pub(crate) async fn make_auth_initial_response( + &mut self, + plugin: &AuthPlugin, + password: &str, + nonce: &[u8], + ) -> crate::Result> { + match plugin { + AuthPlugin::CachingSha2Password | AuthPlugin::MySqlNativePassword => { + Ok(plugin.scramble(password, nonce)) + } + + AuthPlugin::Sha256Password => { + // Full RSA exchange and password encrypt up front with no "cache" + Ok(self.rsa_encrypt(0x01, password, nonce).await?.into_vec()) + } + } + } + + pub(crate) async fn rsa_encrypt( + &mut self, + public_key_request_id: u8, + password: &str, + nonce: &[u8], + ) -> crate::Result> { + // https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/ + + // TODO: Handle SSL + + // client sends a public key request + self.send(&[public_key_request_id][..]).await?; + + // server sends a public key response + let mut packet = self.receive().await?.packet(); + let rsa_pub_key = &packet[1..]; + + // The password string data must be NUL terminated + // Note: This is not in the documentation that I could find + let mut pass = password.as_bytes().to_vec(); + pass.push(0); + + xor_eq(&mut pass, nonce); + + // client sends an RSA encrypted password + rsa::encrypt::(rsa_pub_key, &pass) + } +} + +impl MySqlConnection { + async fn new(url: &Url) -> crate::Result { + let stream = TcpStream::connect((url.host(), url.port(3306))).await?; + + Ok(Self { + stream: BufStream::new(stream), + capabilities: Capabilities::empty(), + packet: Vec::with_capacity(8192), + packet_len: 0, + next_seq_no: 0, + statement_cache: StatementCache::new(), + }) + } + + async fn initialize(&mut self) -> crate::Result<()> { // On connect, we want to establish a modern, Rust-compatible baseline so we // tweak connection options to enable UTC for TIMESTAMP, UTF-8 for character types, etc. @@ -194,25 +352,72 @@ impl MySqlConnection { // NO_ZERO_DATE - Don't allow '0000-00-00'. This is invalid in Rust. - // NO_ZERO_IN_DATE - Don't allow 'yyyy-00-00'. This is invalid in Rust. + // NO_ZERO_IN_DATE - Don't allow 'YYYY-00-00'. This is invalid in Rust. - self_.send("SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE'))") + // language=MySQL + self.execute_raw("SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE'))") .await?; // This allows us to assume that the output from a TIMESTAMP field is UTC - self_.send("SET time_zone = 'UTC'").await?; + // language=MySQL + self.execute_raw("SET time_zone = 'UTC'").await?; // https://mathiasbynens.be/notes/mysql-utf8mb4 - self_ - .send("SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci") + // language=MySQL + self.execute_raw("SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci") .await?; + Ok(()) + } +} + +impl MySqlConnection { + pub(super) async fn open(url: crate::Result) -> crate::Result { + let url = url?; + let mut self_ = Self::new(&url).await?; + + // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase.html + // https://mariadb.com/kb/en/connection/ + + // On connect, server immediately sends the handshake + let handshake = self_.receive_handshake(&url).await?; + + // Pre-generate an auth response by using the auth method in the [Handshake] + let password = url.password().unwrap_or_default(); + let auth_response = self_ + .make_auth_initial_response( + &handshake.auth_plugin, + password, + &handshake.auth_plugin_data, + ) + .await?; + + self_ + .send_handshake_response(&url, &handshake.auth_plugin, &auth_response) + .await?; + + // After sending the handshake response with our assumed auth method the server + // will send OK, fail, or tell us to change auth methods + self_ + .receive_auth_ok( + &handshake.auth_plugin, + password, + &handshake.auth_plugin_data, + ) + .await?; + + // After the connection is established, we initialize by configuring a few + // connection parameters + self_.initialize().await?; + Ok(self_) } async fn close(mut self) -> crate::Result<()> { + // TODO: Actually tell MySQL that we're closing + self.stream.flush().await?; self.stream.stream.shutdown(Shutdown::Both)?; diff --git a/sqlx-core/src/mysql/executor.rs b/sqlx-core/src/mysql/executor.rs index f4fb1a77..cb681451 100644 --- a/sqlx-core/src/mysql/executor.rs +++ b/sqlx-core/src/mysql/executor.rs @@ -26,7 +26,7 @@ enum OkOrResultSet { impl MySqlConnection { async fn ignore_columns(&mut self, count: usize) -> crate::Result<()> { for _ in 0..count { - let _column = ColumnDefinition::decode(self.receive().await?)?; + let _column = ColumnDefinition::decode(self.receive().await?.packet())?; } if count > 0 { @@ -37,35 +37,15 @@ impl MySqlConnection { } async fn receive_ok_or_column_count(&mut self) -> crate::Result { - let packet = self.receive().await?; + self.receive().await?; - match packet[0] { - 0xfe if packet.len() < 0xffffff => { - let ok = OkPacket::decode(packet)?; - self.ready = true; + match self.packet[0] { + 0x00 | 0xfe if self.packet.len() < 0xffffff => self.handle_ok().map(OkOrResultSet::Ok), + 0xff => self.handle_err(), - Ok(OkOrResultSet::Ok(ok)) - } - - 0x00 => { - let ok = OkPacket::decode(packet)?; - self.ready = true; - - Ok(OkOrResultSet::Ok(ok)) - } - - 0xff => { - let err = ErrPacket::decode(packet)?; - self.ready = true; - - Err(MySqlError(err).into()) - } - - _ => { - let cc = ColumnCount::decode(packet)?; - - Ok(OkOrResultSet::ResultSet(cc)) - } + _ => Ok(OkOrResultSet::ResultSet(ColumnCount::decode( + self.packet(), + )?)), } } @@ -73,8 +53,8 @@ impl MySqlConnection { let mut columns: Vec = Vec::with_capacity(count); for _ in 0..count { - let packet = self.receive().await?; - let column: ColumnDefinition = ColumnDefinition::decode(packet)?; + let column: ColumnDefinition = + ColumnDefinition::decode(self.receive().await?.packet())?; columns.push(column.r#type); } @@ -87,7 +67,7 @@ impl MySqlConnection { } async fn wait_for_ready(&mut self) -> crate::Result<()> { - if !self.ready { + if self.next_seq_no != 0 { while let Some(_step) = self.step(&[], true).await? { // Drain steps until we hit the end } @@ -98,21 +78,19 @@ impl MySqlConnection { async fn prepare(&mut self, query: &str) -> crate::Result { // Start by sending a COM_STMT_PREPARE - self.begin_command_phase(); - self.write(ComStmtPrepare { query }); - self.stream.flush().await?; + self.send(ComStmtPrepare { query }).await?; // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_prepare.html // First we should receive a COM_STMT_PREPARE_OK - let packet = self.receive().await?; + self.receive().await?; - if packet[0] == 0xff { + if self.packet[0] == 0xff { // Oops, there was an error in the prepare command - return Err(MySqlError(ErrPacket::decode(packet)?).into()); + return self.handle_err(); } - ComStmtPrepareOk::decode(packet) + ComStmtPrepareOk::decode(self.packet()) } async fn prepare_with_cache(&mut self, query: &str) -> crate::Result { @@ -132,7 +110,7 @@ impl MySqlConnection { let mut columns = HashMap::with_capacity(prepare_ok.columns as usize); let mut index = 0_usize; for _ in 0..prepare_ok.columns { - let column = ColumnDefinition::decode(self.receive().await?)?; + let column = ColumnDefinition::decode(self.receive().await?.packet())?; if let Some(name) = column.column_alias.or(column.column) { columns.insert(name, index); @@ -145,6 +123,9 @@ impl MySqlConnection { self.receive_eof().await?; } + // At the end of a command, this should go back to 0 + self.next_seq_no = 0; + // Remember our column map in the statement cache self.statement_cache .put_columns(prepare_ok.statement_id, columns); @@ -155,73 +136,59 @@ impl MySqlConnection { // [COM_STMT_EXECUTE] async fn execute_statement(&mut self, id: u32, args: MySqlArguments) -> crate::Result<()> { - self.begin_command_phase(); - self.ready = false; - - self.write(ComStmtExecute { + self.send(ComStmtExecute { cursor: Cursor::NO_CURSOR, statement_id: id, params: &args.params, null_bitmap: &args.null_bitmap, param_types: &args.param_types, - }); - - self.stream.flush().await?; - - Ok(()) + }) + .await } async fn step(&mut self, columns: &[Type], binary: bool) -> crate::Result> { let capabilities = self.capabilities; let packet = ret_if_none!(self.try_receive().await?); - match packet[0] { - 0xfe if packet.len() < 0xffffff => { - // Resultset row can begin with 0xfe byte (when using text protocol + match self.packet[0] { + 0xfe if self.packet.len() < 0xffffff => { + // ResultSet row can begin with 0xfe byte (when using text protocol // with a field length > 0xffffff) if !capabilities.contains(Capabilities::DEPRECATE_EOF) { - let _eof = EofPacket::decode(packet)?; - self.ready = true; + let _eof = EofPacket::decode(self.packet())?; - return Ok(None); + // An EOF -here- signifies the end of the current command sequence + self.next_seq_no = 0; + + Ok(None) } else { - let ok = OkPacket::decode(packet)?; - self.ready = true; - - return Ok(Some(Step::Command(ok.affected_rows))); + self.handle_ok() + .map(|ok| Some(Step::Command(ok.affected_rows))) } } - 0xff => { - let err = ErrPacket::decode(packet)?; - self.ready = true; + 0xff => self.handle_err(), - return Err(MySqlError(err).into()); - } - - _ => { - return Ok(Some(Step::Row(Row::decode(packet, columns, binary)?))); - } + _ => Ok(Some(Step::Row(Row::decode( + self.packet(), + columns, + binary, + )?))), } } } impl MySqlConnection { - async fn send(&mut self, query: &str) -> crate::Result<()> { + pub(super) async fn execute_raw(&mut self, query: &str) -> crate::Result<()> { self.wait_for_ready().await?; - self.begin_command_phase(); - self.ready = false; - - // enable multi-statement only for this query - self.write(ComQuery { query }); - - self.stream.flush().await?; + self.send(ComQuery { query }).await?; // COM_QUERY can terminate before the result set with an ERR or OK packet let num_columns = match self.receive_ok_or_column_count().await? { OkOrResultSet::Ok(_) => { + self.next_seq_no = 0; return Ok(()); } @@ -247,6 +214,8 @@ impl MySqlConnection { // COM_STMT_EXECUTE can terminate before the result set with an ERR or OK packet let num_columns = match self.receive_ok_or_column_count().await? { OkOrResultSet::Ok(ok) => { + self.next_seq_no = 0; + return Ok(ok.affected_rows); } @@ -275,7 +244,7 @@ impl MySqlConnection { let mut result_columns = Vec::with_capacity(prepare_ok.columns as usize); for _ in 0..prepare_ok.params { - let param = ColumnDefinition::decode(self.receive().await?)?; + let param = ColumnDefinition::decode(self.receive().await?.packet())?; param_types.push(param.r#type.0); } @@ -284,7 +253,7 @@ impl MySqlConnection { } for _ in 0..prepare_ok.columns { - let column = ColumnDefinition::decode(self.receive().await?)?; + let column = ColumnDefinition::decode(self.receive().await?.packet())?; result_columns.push(Column:: { name: column.column_alias.or(column.column), @@ -298,6 +267,9 @@ impl MySqlConnection { self.receive_eof().await?; } + // Command sequence is over + self.next_seq_no = 0; + Ok(Describe { param_types: param_types.into_boxed_slice(), result_columns: result_columns.into_boxed_slice(), @@ -321,6 +293,7 @@ impl MySqlConnection { // COM_STMT_EXECUTE can terminate before the result set with an ERR or OK packet let num_columns = match self.receive_ok_or_column_count().await? { OkOrResultSet::Ok(_) => { + self.next_seq_no = 0; return; } @@ -342,7 +315,7 @@ impl Executor for MySqlConnection { type Database = super::MySql; fn send<'e, 'q: 'e>(&'e mut self, query: &'q str) -> BoxFuture<'e, crate::Result<()>> { - Box::pin(self.send(query)) + Box::pin(self.execute_raw(query)) } fn execute<'e, 'q: 'e>( diff --git a/sqlx-core/src/mysql/mod.rs b/sqlx-core/src/mysql/mod.rs index 717864fc..e8585aef 100644 --- a/sqlx-core/src/mysql/mod.rs +++ b/sqlx-core/src/mysql/mod.rs @@ -8,7 +8,9 @@ mod executor; mod io; mod protocol; mod row; +mod rsa; mod types; +mod util; pub use database::MySql; diff --git a/sqlx-core/src/mysql/protocol/auth_plugin.rs b/sqlx-core/src/mysql/protocol/auth_plugin.rs new file mode 100644 index 00000000..cc54b150 --- /dev/null +++ b/sqlx-core/src/mysql/protocol/auth_plugin.rs @@ -0,0 +1,100 @@ +use digest::{Digest, FixedOutput}; +use generic_array::GenericArray; +use sha1::Sha1; +use sha2::Sha256; + +use crate::mysql::util::xor_eq; + +#[derive(Debug)] +pub enum AuthPlugin { + MySqlNativePassword, + CachingSha2Password, + Sha256Password, +} + +impl AuthPlugin { + pub(crate) fn from_opt_str(s: Option<&str>) -> crate::Result { + match s { + Some("mysql_native_password") | None => Ok(AuthPlugin::MySqlNativePassword), + Some("caching_sha2_password") => Ok(AuthPlugin::CachingSha2Password), + Some("sha256_password") => Ok(AuthPlugin::Sha256Password), + + Some(s) => { + Err(protocol_err!("requires unimplemented authentication plugin: {}", s).into()) + } + } + } + + pub(crate) fn as_str(&self) -> &'static str { + match self { + AuthPlugin::MySqlNativePassword => "mysql_native_password", + AuthPlugin::CachingSha2Password => "caching_sha2_password", + AuthPlugin::Sha256Password => "sha256_password", + } + } + + pub(crate) fn scramble(&self, password: &str, nonce: &[u8]) -> Vec { + match self { + AuthPlugin::MySqlNativePassword => { + // The [nonce] for mysql_native_password is nul terminated + scramble_sha1(password, &nonce[..(nonce.len() - 1)]).to_vec() + } + AuthPlugin::CachingSha2Password => scramble_sha256(password, nonce).to_vec(), + + _ => unimplemented!(), + } + } +} + +fn scramble_sha1( + password: &str, + seed: &[u8], +) -> GenericArray::OutputSize> { + // SHA1( password ) ^ SHA1( seed + SHA1( SHA1( password ) ) ) + // https://mariadb.com/kb/en/connection/#mysql_native_password-plugin + + let mut ctx = Sha1::new(); + + ctx.input(password); + + let mut pw_hash = ctx.result_reset(); + + ctx.input(&pw_hash); + + let pw_hash_hash = ctx.result_reset(); + + ctx.input(seed); + ctx.input(pw_hash_hash); + + let pw_seed_hash_hash = ctx.result(); + + xor_eq(&mut pw_hash, &pw_seed_hash_hash); + + pw_hash +} + +fn scramble_sha256( + password: &str, + seed: &[u8], +) -> GenericArray::OutputSize> { + // XOR(SHA256(password), SHA256(seed, SHA256(SHA256(password)))) + // https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/#sha-2-encrypted-password + let mut ctx = Sha256::new(); + + ctx.input(password); + + let mut pw_hash = ctx.result_reset(); + + ctx.input(&pw_hash); + + let pw_hash_hash = ctx.result_reset(); + + ctx.input(seed); + ctx.input(pw_hash_hash); + + let pw_seed_hash_hash = ctx.result(); + + xor_eq(&mut pw_hash, &pw_seed_hash_hash); + + pw_hash +} diff --git a/sqlx-core/src/mysql/protocol/auth_switch.rs b/sqlx-core/src/mysql/protocol/auth_switch.rs new file mode 100644 index 00000000..8c7315be --- /dev/null +++ b/sqlx-core/src/mysql/protocol/auth_switch.rs @@ -0,0 +1,34 @@ +use byteorder::LittleEndian; + +use crate::io::Buf; +use crate::mysql::protocol::{AuthPlugin, Capabilities, Decode, Status}; + +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_auth_switch_request.html +#[derive(Debug)] +pub struct AuthSwitch { + pub auth_plugin: AuthPlugin, + pub auth_plugin_data: Box<[u8]>, +} + +impl Decode for AuthSwitch { + fn decode(mut buf: &[u8]) -> crate::Result + where + Self: Sized, + { + let header = buf.get_u8()?; + if header != 0xFE { + return Err(protocol_err!( + "expected AUTH SWITCH (0xFE); received 0x{:X}", + header + ))?; + } + + let auth_plugin = AuthPlugin::from_opt_str(Some(buf.get_str_nul()?))?; + let auth_plugin_data = buf.get_bytes(buf.len())?.to_owned().into_boxed_slice(); + + Ok(Self { + auth_plugin_data, + auth_plugin, + }) + } +} diff --git a/sqlx-core/src/mysql/protocol/encode.rs b/sqlx-core/src/mysql/protocol/encode.rs index 1781acbe..154a577c 100644 --- a/sqlx-core/src/mysql/protocol/encode.rs +++ b/sqlx-core/src/mysql/protocol/encode.rs @@ -1,5 +1,12 @@ +use crate::io::BufMut; use crate::mysql::protocol::Capabilities; pub trait Encode { fn encode(&self, buf: &mut Vec, capabilities: Capabilities); } + +impl Encode for &'_ [u8] { + fn encode(&self, buf: &mut Vec, _: Capabilities) { + buf.put_bytes(self); + } +} diff --git a/sqlx-core/src/mysql/protocol/eof.rs b/sqlx-core/src/mysql/protocol/eof.rs index e01c3743..4fe6b2a7 100644 --- a/sqlx-core/src/mysql/protocol/eof.rs +++ b/sqlx-core/src/mysql/protocol/eof.rs @@ -34,19 +34,3 @@ impl Decode for EofPacket { }) } } - -//#[cfg(test)] -//mod tests { -// use super::{Capabilities, Decode, ErrPacket, Status}; -// -// const ERR_HANDSHAKE_UNKNOWN_DB: &[u8] = b"\xff\x19\x04#42000Unknown database \'unknown\'"; -// -// #[test] -// fn it_decodes_ok_handshake() { -// let mut p = ErrPacket::decode(ERR_HANDSHAKE_UNKNOWN_DB).unwrap(); -// -// assert_eq!(p.error_code, 1049); -// assert_eq!(&*p.sql_state, "42000"); -// assert_eq!(&*p.error_message, "Unknown database \'unknown\'"); -// } -//} diff --git a/sqlx-core/src/mysql/protocol/handshake.rs b/sqlx-core/src/mysql/protocol/handshake.rs index 459e6fff..9bfe3046 100644 --- a/sqlx-core/src/mysql/protocol/handshake.rs +++ b/sqlx-core/src/mysql/protocol/handshake.rs @@ -1,7 +1,7 @@ use byteorder::LittleEndian; use crate::io::Buf; -use crate::mysql::protocol::{Capabilities, Decode, Status}; +use crate::mysql::protocol::{AuthPlugin, Capabilities, Decode, Status}; // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_v10.html // https://mariadb.com/kb/en/connection/#initial-handshake-packet @@ -13,7 +13,7 @@ pub struct Handshake { pub server_capabilities: Capabilities, pub server_default_collation: u8, pub status: Status, - pub auth_plugin_name: Option>, + pub auth_plugin: AuthPlugin, pub auth_plugin_data: Box<[u8]>, } @@ -81,10 +81,10 @@ impl Decode for Handshake { buf.advance(1); } - let auth_plugin_name = if capabilities.contains(Capabilities::PLUGIN_AUTH) { - Some(buf.get_str_nul()?.to_owned().into()) + let auth_plugin = if capabilities.contains(Capabilities::PLUGIN_AUTH) { + AuthPlugin::from_opt_str(Some(buf.get_str_nul()?))? } else { - None + AuthPlugin::from_opt_str(None)? }; Ok(Self { @@ -94,7 +94,7 @@ impl Decode for Handshake { server_default_collation: char_set, connection_id, auth_plugin_data: scramble.into_boxed_slice(), - auth_plugin_name, + auth_plugin, status, }) } @@ -102,7 +102,8 @@ impl Decode for Handshake { #[cfg(test)] mod tests { - use super::{Capabilities, Decode, Handshake, Status}; + use super::{AuthPlugin, Capabilities, Decode, Handshake, Status}; + use matches::assert_matches; const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\" { pub client_collation: u8, pub username: &'a str, pub database: Option<&'a str>, - pub auth_plugin_name: Option<&'a str>, - pub auth_response: Option<&'a str>, + pub auth_plugin: &'a AuthPlugin, + pub auth_response: &'a [u8], } impl Encode for HandshakeResponse<'_> { @@ -43,15 +43,15 @@ impl Encode for HandshakeResponse<'_> { if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) { // auth_response : string - buf.put_str_lenenc::(self.auth_response.unwrap_or_default()); + buf.put_bytes_lenenc::(self.auth_response); } else { - let auth_response = self.auth_response.unwrap_or_default(); + let auth_response = self.auth_response; // auth_response_length : int<1> buf.put_u8(auth_response.len() as u8); // auth_response : string<{auth_response_length}> - buf.put_str(auth_response); + buf.put_bytes(auth_response); } if capabilities.contains(Capabilities::CONNECT_WITH_DB) { @@ -63,7 +63,7 @@ impl Encode for HandshakeResponse<'_> { if capabilities.contains(Capabilities::PLUGIN_AUTH) { // client_plugin_name : string - buf.put_str_nul(self.auth_plugin_name.unwrap_or_default()); + buf.put_str_nul(self.auth_plugin.as_str()); } } } diff --git a/sqlx-core/src/mysql/protocol/mod.rs b/sqlx-core/src/mysql/protocol/mod.rs index fa5858e3..dd974886 100644 --- a/sqlx-core/src/mysql/protocol/mod.rs +++ b/sqlx-core/src/mysql/protocol/mod.rs @@ -8,11 +8,13 @@ mod encode; pub use decode::Decode; pub use encode::Encode; +mod auth_plugin; mod capabilities; mod field; mod status; mod r#type; +pub use auth_plugin::AuthPlugin; pub use capabilities::Capabilities; pub use field::FieldFlags; pub use r#type::Type; @@ -30,6 +32,7 @@ pub use com_stmt_execute::{ComStmtExecute, Cursor}; pub use com_stmt_prepare::ComStmtPrepare; pub use handshake::Handshake; +mod auth_switch; mod column_count; mod column_def; mod com_stmt_prepare_ok; @@ -39,6 +42,7 @@ mod handshake_response; mod ok; mod row; +pub use auth_switch::AuthSwitch; pub use column_count::ColumnCount; pub use column_def::ColumnDefinition; pub use com_stmt_prepare_ok::ComStmtPrepareOk; diff --git a/sqlx-core/src/mysql/rsa.rs b/sqlx-core/src/mysql/rsa.rs new file mode 100644 index 00000000..4cd4528b --- /dev/null +++ b/sqlx-core/src/mysql/rsa.rs @@ -0,0 +1,265 @@ +use digest::{Digest, DynDigest}; +use num_bigint::BigUint; +use rand::{thread_rng, Rng}; + +// This is mostly taken from https://github.com/RustCrypto/RSA/pull/18 +// For the love of crypto, please delete as much of this as possible and use the RSA crate +// directly when that PR is merged + +pub fn encrypt(key: &[u8], message: &[u8]) -> crate::Result> { + let key = std::str::from_utf8(key).map_err(|_err| { + // TODO(@abonander): protocol_err doesn't like referring to [err] + protocol_err!("unexpected error decoding what should be UTF-8") + })?; + + let key = parse(key)?; + + Ok(oaep_encrypt::<_, D>(&mut thread_rng(), &key, message)?.into_boxed_slice()) +} + +// https://github.com/RustCrypto/RSA/blob/9f1464c43831d422d9903574aad6ab072db9f2b0/src/internals.rs#L12 +fn internals_encrypt(key: &PublicKey, m: &BigUint) -> BigUint { + m.modpow(&key.e, &key.n) +} + +// https://github.com/RustCrypto/RSA/blob/9f1464c43831d422d9903574aad6ab072db9f2b0/src/internals.rs#L184 +fn internals_copy_with_left_pad(dest: &mut [u8], src: &[u8]) { + // left pad with zeros + let padding_bytes = dest.len() - src.len(); + for el in dest.iter_mut().take(padding_bytes) { + *el = 0; + } + dest[padding_bytes..].copy_from_slice(src); +} + +// https://github.com/RustCrypto/RSA/blob/9f1464c43831d422d9903574aad6ab072db9f2b0/src/oaep.rs#L13 +fn internals_inc_counter(counter: &mut [u8]) { + if counter[3] == u8::max_value() { + counter[3] = 0; + } else { + counter[3] += 1; + return; + } + + if counter[2] == u8::max_value() { + counter[2] = 0; + } else { + counter[2] += 1; + return; + } + + if counter[1] == u8::max_value() { + counter[1] = 0; + } else { + counter[1] += 1; + return; + } + + if counter[0] == u8::max_value() { + counter[0] = 0u8; + counter[1] = 0u8; + counter[2] = 0u8; + counter[3] = 0u8; + } else { + counter[0] += 1; + } +} + +// https://github.com/RustCrypto/RSA/blob/9f1464c43831d422d9903574aad6ab072db9f2b0/src/oaep.rs#L46 +fn oeap_mgf1_xor(out: &mut [u8], digest: &mut D, seed: &[u8]) { + let mut counter = vec![0u8; 4]; + let mut i = 0; + + while i < out.len() { + let mut digest_input = vec![0u8; seed.len() + 4]; + digest_input[0..seed.len()].copy_from_slice(seed); + digest_input[seed.len()..].copy_from_slice(&counter); + + digest.input(digest_input.as_slice()); + let digest_output = &*digest.result_reset(); + let mut j = 0; + loop { + if j >= digest_output.len() || i >= out.len() { + break; + } + + out[i] ^= digest_output[j]; + j += 1; + i += 1; + } + internals_inc_counter(counter.as_mut_slice()); + } +} + +// https://github.com/RustCrypto/RSA/blob/9f1464c43831d422d9903574aad6ab072db9f2b0/src/oaep.rs#L75 +fn oaep_encrypt( + rng: &mut R, + pub_key: &PublicKey, + msg: &[u8], +) -> crate::Result> { + // size of [n] in bytes + let k = (pub_key.n.bits() + 7) / 8; + + let mut digest = D::new(); + let h_size = D::output_size(); + + if msg.len() > k - 2 * h_size - 2 { + return Err(protocol_err!("mysql: password too long").into()); + } + + let mut em = vec![0u8; k]; + + let (_, payload) = em.split_at_mut(1); + let (seed, db) = payload.split_at_mut(h_size); + rng.fill(seed); + + // Data block DB = pHash || PS || 01 || M + let db_len = k - h_size - 1; + + let p_hash = digest.result_reset(); + db[0..h_size].copy_from_slice(&*p_hash); + db[db_len - msg.len() - 1] = 1; + db[db_len - msg.len()..].copy_from_slice(msg); + + oeap_mgf1_xor(db, &mut digest, seed); + oeap_mgf1_xor(seed, &mut digest, db); + + { + let mut m = BigUint::from_bytes_be(&em); + let mut c = internals_encrypt(pub_key, &m).to_bytes_be(); + + internals_copy_with_left_pad(&mut em, &c); + } + + Ok(em) +} + +#[derive(Debug)] +struct PublicKey { + n: BigUint, + e: BigUint, +} + +fn parse(key: &str) -> crate::Result { + // This takes advantage of the knowledge that we know + // we are receiving a PKCS#8 RSA Public Key at all + // times from MySQL + + if !key.starts_with("-----BEGIN PUBLIC KEY-----\n") { + return Err(protocol_err!( + "unexpected format for RSA Public Key from MySQL (expected PKCS#8); first line: {:?}", + key.splitn(1, '\n').next() + ) + .into()); + } + + let key_with_trailer = key.trim_start_matches("-----BEGIN PUBLIC KEY-----\n"); + let trailer_pos = key_with_trailer.find('-').unwrap_or(0); + let inner_key = key_with_trailer[..trailer_pos].replace('\n', ""); + + let inner = base64::decode(&inner_key).map_err(|_err| { + // TODO(@abonander): protocol_err doesn't like referring to [err] + protocol_err!("unexpected error decoding what should be base64-encoded data") + })?; + + let len = inner.len(); + + let n_bytes = &inner[(len - 257 - 5)..(len - 5)]; + let e_bytes = &inner[(len - 3)..]; + + let n = BigUint::from_bytes_be(n_bytes); + let e = BigUint::from_bytes_be(e_bytes); + + Ok(PublicKey { n, e }) +} + +#[cfg(test)] +mod tests { + use super::{BigUint, PublicKey}; + use rand::rngs::adapter::ReadRng; + use sha1::Sha1; + use sha2::Sha256; + + const INPUT: &str = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAv9E+l0oFIoGnZmu6bdil\nI3WK79iug/hukj5QrWRrJVVCHL8rRxNsQGYPvQfXgqEnJW0Rqy2BBebNrnSMduny\nCazz1KM1h57hSI1xHGhg/o82Us1j9fUucKo0Pt3vg7xjVVcN0j1bwr96gEbt6B4Q\nt4eKZBhtle1bgoBcqFBhGfU17cnedSzMUCutM+kXTzzOTplKoqXeJpEZDTX8AP9F\nQ9JkoA22yTn8H2GROIAffm1UQS7DXXjI5OnzBJNs72oNSeK8i72xLkoSdfVw3vCu\ni+mpt4LJgAZLvzc2O4nLzu4Bljb+Mrch34HSWyxOfWzt1v9vpJfEVQ2/VZaIng6U\nUQIDAQAB\n-----END PUBLIC KEY-----\n"; + + #[test] + fn it_parses() { + let key = super::parse(INPUT).unwrap(); + + let n = &[ + 0xbf, 0xd1, 0x3e, 0x97, 0x4a, 0x5, 0x22, 0x81, 0xa7, 0x66, 0x6b, 0xba, 0x6d, 0xd8, + 0xa5, 0x23, 0x75, 0x8a, 0xef, 0xd8, 0xae, 0x83, 0xf8, 0x6e, 0x92, 0x3e, 0x50, 0xad, + 0x64, 0x6b, 0x25, 0x55, 0x42, 0x1c, 0xbf, 0x2b, 0x47, 0x13, 0x6c, 0x40, 0x66, 0xf, + 0xbd, 0x7, 0xd7, 0x82, 0xa1, 0x27, 0x25, 0x6d, 0x11, 0xab, 0x2d, 0x81, 0x5, 0xe6, 0xcd, + 0xae, 0x74, 0x8c, 0x76, 0xe9, 0xf2, 0x9, 0xac, 0xf3, 0xd4, 0xa3, 0x35, 0x87, 0x9e, + 0xe1, 0x48, 0x8d, 0x71, 0x1c, 0x68, 0x60, 0xfe, 0x8f, 0x36, 0x52, 0xcd, 0x63, 0xf5, + 0xf5, 0x2e, 0x70, 0xaa, 0x34, 0x3e, 0xdd, 0xef, 0x83, 0xbc, 0x63, 0x55, 0x57, 0xd, + 0xd2, 0x3d, 0x5b, 0xc2, 0xbf, 0x7a, 0x80, 0x46, 0xed, 0xe8, 0x1e, 0x10, 0xb7, 0x87, + 0x8a, 0x64, 0x18, 0x6d, 0x95, 0xed, 0x5b, 0x82, 0x80, 0x5c, 0xa8, 0x50, 0x61, 0x19, + 0xf5, 0x35, 0xed, 0xc9, 0xde, 0x75, 0x2c, 0xcc, 0x50, 0x2b, 0xad, 0x33, 0xe9, 0x17, + 0x4f, 0x3c, 0xce, 0x4e, 0x99, 0x4a, 0xa2, 0xa5, 0xde, 0x26, 0x91, 0x19, 0xd, 0x35, + 0xfc, 0x0, 0xff, 0x45, 0x43, 0xd2, 0x64, 0xa0, 0xd, 0xb6, 0xc9, 0x39, 0xfc, 0x1f, 0x61, + 0x91, 0x38, 0x80, 0x1f, 0x7e, 0x6d, 0x54, 0x41, 0x2e, 0xc3, 0x5d, 0x78, 0xc8, 0xe4, + 0xe9, 0xf3, 0x4, 0x93, 0x6c, 0xef, 0x6a, 0xd, 0x49, 0xe2, 0xbc, 0x8b, 0xbd, 0xb1, 0x2e, + 0x4a, 0x12, 0x75, 0xf5, 0x70, 0xde, 0xf0, 0xae, 0x8b, 0xe9, 0xa9, 0xb7, 0x82, 0xc9, + 0x80, 0x6, 0x4b, 0xbf, 0x37, 0x36, 0x3b, 0x89, 0xcb, 0xce, 0xee, 0x1, 0x96, 0x36, 0xfe, + 0x32, 0xb7, 0x21, 0xdf, 0x81, 0xd2, 0x5b, 0x2c, 0x4e, 0x7d, 0x6c, 0xed, 0xd6, 0xff, + 0x6f, 0xa4, 0x97, 0xc4, 0x55, 0xd, 0xbf, 0x55, 0x96, 0x88, 0x9e, 0xe, 0x94, 0x51, + ][..]; + + let e = &[0x1, 0x0, 0x1][..]; + + assert_eq!(key.n.to_bytes_be(), n); + assert_eq!(key.e.to_bytes_be(), e); + } + + #[test] + fn it_encrypts_sha1() { + // https://github.com/pyca/cryptography/blob/master/vectors/cryptography_vectors/asymmetric/RSA/pkcs-1v2-1d2-vec/oaep-int.txt + + let n = BigUint::from_bytes_be(&[ + 0xbb, 0xf8, 0x2f, 0x09, 0x06, 0x82, 0xce, 0x9c, 0x23, 0x38, 0xac, 0x2b, 0x9d, 0xa8, + 0x71, 0xf7, 0x36, 0x8d, 0x07, 0xee, 0xd4, 0x10, 0x43, 0xa4, 0x40, 0xd6, 0xb6, 0xf0, + 0x74, 0x54, 0xf5, 0x1f, 0xb8, 0xdf, 0xba, 0xaf, 0x03, 0x5c, 0x02, 0xab, 0x61, 0xea, + 0x48, 0xce, 0xeb, 0x6f, 0xcd, 0x48, 0x76, 0xed, 0x52, 0x0d, 0x60, 0xe1, 0xec, 0x46, + 0x19, 0x71, 0x9d, 0x8a, 0x5b, 0x8b, 0x80, 0x7f, 0xaf, 0xb8, 0xe0, 0xa3, 0xdf, 0xc7, + 0x37, 0x72, 0x3e, 0xe6, 0xb4, 0xb7, 0xd9, 0x3a, 0x25, 0x84, 0xee, 0x6a, 0x64, 0x9d, + 0x06, 0x09, 0x53, 0x74, 0x88, 0x34, 0xb2, 0x45, 0x45, 0x98, 0x39, 0x4e, 0xe0, 0xaa, + 0xb1, 0x2d, 0x7b, 0x61, 0xa5, 0x1f, 0x52, 0x7a, 0x9a, 0x41, 0xf6, 0xc1, 0x68, 0x7f, + 0xe2, 0x53, 0x72, 0x98, 0xca, 0x2a, 0x8f, 0x59, 0x46, 0xf8, 0xe5, 0xfd, 0x09, 0x1d, + 0xbd, 0xcb, + ]); + + let e = BigUint::from_bytes_be(&[0x11]); + + let pub_key = PublicKey { n, e }; + + let message = &[ + 0xd4, 0x36, 0xe9, 0x95, 0x69, 0xfd, 0x32, 0xa7, 0xc8, 0xa0, 0x5b, 0xbc, 0x90, 0xd3, + 0x2c, 0x49, + ]; + + let mut seed = &[ + 0xaa, 0xfd, 0x12, 0xf6, 0x59, 0xca, 0xe6, 0x34, 0x89, 0xb4, 0x79, 0xe5, 0x07, 0x6d, + 0xde, 0xc2, 0xf0, 0x6c, 0xb5, 0x8f, + ][..]; + + let mut rng = ReadRng::new(seed); + let cipher_text = super::oaep_encrypt::<_, Sha1>(&mut rng, &pub_key, message).unwrap(); + + let expected_cipher_text = &[ + 0x12, 0x53, 0xe0, 0x4d, 0xc0, 0xa5, 0x39, 0x7b, 0xb4, 0x4a, 0x7a, 0xb8, 0x7e, 0x9b, + 0xf2, 0xa0, 0x39, 0xa3, 0x3d, 0x1e, 0x99, 0x6f, 0xc8, 0x2a, 0x94, 0xcc, 0xd3, 0x00, + 0x74, 0xc9, 0x5d, 0xf7, 0x63, 0x72, 0x20, 0x17, 0x06, 0x9e, 0x52, 0x68, 0xda, 0x5d, + 0x1c, 0x0b, 0x4f, 0x87, 0x2c, 0xf6, 0x53, 0xc1, 0x1d, 0xf8, 0x23, 0x14, 0xa6, 0x79, + 0x68, 0xdf, 0xea, 0xe2, 0x8d, 0xef, 0x04, 0xbb, 0x6d, 0x84, 0xb1, 0xc3, 0x1d, 0x65, + 0x4a, 0x19, 0x70, 0xe5, 0x78, 0x3b, 0xd6, 0xeb, 0x96, 0xa0, 0x24, 0xc2, 0xca, 0x2f, + 0x4a, 0x90, 0xfe, 0x9f, 0x2e, 0xf5, 0xc9, 0xc1, 0x40, 0xe5, 0xbb, 0x48, 0xda, 0x95, + 0x36, 0xad, 0x87, 0x00, 0xc8, 0x4f, 0xc9, 0x13, 0x0a, 0xde, 0xa7, 0x4e, 0x55, 0x8d, + 0x51, 0xa7, 0x4d, 0xdf, 0x85, 0xd8, 0xb5, 0x0d, 0xe9, 0x68, 0x38, 0xd6, 0x06, 0x3e, + 0x09, 0x55, + ][..]; + + assert_eq!(&*expected_cipher_text, &*cipher_text); + } +} diff --git a/sqlx-core/src/mysql/util.rs b/sqlx-core/src/mysql/util.rs new file mode 100644 index 00000000..7feae9a8 --- /dev/null +++ b/sqlx-core/src/mysql/util.rs @@ -0,0 +1,9 @@ +// XOR(x, y) +// If len(y) < len(x), wrap around inside y +pub fn xor_eq(x: &mut [u8], y: &[u8]) { + let y_len = y.len(); + + for i in 0..x.len() { + x[i] ^= y[i % y_len]; + } +}