diff --git a/Cargo.toml b/Cargo.toml index 519daa2a..cd89e536 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,6 +74,10 @@ required-features = [ "mysql", "macros" ] name = "mysql" required-features = [ "mysql" ] +[[test]] +name = "mysql-raw" +required-features = [ "mysql" ] + [[test]] name = "postgres" required-features = [ "postgres" ] diff --git a/sqlx-core/src/io/buf.rs b/sqlx-core/src/io/buf.rs index e80fe446..96316b65 100644 --- a/sqlx-core/src/io/buf.rs +++ b/sqlx-core/src/io/buf.rs @@ -7,6 +7,8 @@ pub trait Buf { fn get_uint(&mut self, n: usize) -> io::Result; + fn get_i8(&mut self) -> io::Result; + fn get_u8(&mut self) -> io::Result; fn get_u16(&mut self) -> io::Result; @@ -17,6 +19,8 @@ pub trait Buf { fn get_i32(&mut self) -> io::Result; + fn get_i64(&mut self) -> io::Result; + fn get_u32(&mut self) -> io::Result; fn get_u64(&mut self) -> io::Result; @@ -40,6 +44,13 @@ impl<'a> Buf for &'a [u8] { Ok(val) } + fn get_i8(&mut self) -> io::Result { + let val = self[0]; + self.advance(1); + + Ok(val as i8) + } + fn get_u8(&mut self) -> io::Result { let val = self[0]; self.advance(1); @@ -75,6 +86,13 @@ impl<'a> Buf for &'a [u8] { Ok(val) } + fn get_i64(&mut self) -> io::Result { + let val = T::read_i64(*self); + self.advance(4); + + Ok(val) + } + fn get_u32(&mut self) -> io::Result { let val = T::read_u32(*self); self.advance(4); diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index c0c77ad5..f4bdb09f 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -1,7 +1,9 @@ //! Core of SQLx, the rust SQL toolkit. Not intended to be used directly. #![forbid(unsafe_code)] +#![recursion_limit = "512"] #![cfg_attr(docsrs, feature(doc_cfg))] +#![allow(unused)] #[macro_use] pub mod error; diff --git a/sqlx-core/src/mysql/arguments.rs b/sqlx-core/src/mysql/arguments.rs index f7e455a0..b0755ac5 100644 --- a/sqlx-core/src/mysql/arguments.rs +++ b/sqlx-core/src/mysql/arguments.rs @@ -27,10 +27,10 @@ impl Arguments for MySqlArguments { fn add(&mut self, value: T) where - Self::Database: Type, + T: Type, T: Encode, { - let type_id = >::type_info(); + let type_id = >::type_info(); let index = self.param_types.len(); self.param_types.push(type_id); diff --git a/sqlx-core/src/mysql/connection.rs b/sqlx-core/src/mysql/connection.rs index a9b93101..110f341d 100644 --- a/sqlx-core/src/mysql/connection.rs +++ b/sqlx-core/src/mysql/connection.rs @@ -1,22 +1,25 @@ +use std::collections::HashMap; use std::convert::TryInto; use std::io; +use std::sync::Arc; use byteorder::{ByteOrder, LittleEndian}; use futures_core::future::BoxFuture; use sha1::Sha1; use std::net::Shutdown; -use crate::cache::StatementCache; use crate::connection::{Connect, Connection}; use crate::io::{Buf, BufMut, BufStream, MaybeTlsStream}; use crate::mysql::error::MySqlError; use crate::mysql::protocol::{ - AuthPlugin, AuthSwitch, Capabilities, Decode, Encode, EofPacket, ErrPacket, Handshake, + AuthPlugin, AuthSwitch, Capabilities, ComPing, Decode, Encode, EofPacket, ErrPacket, Handshake, HandshakeResponse, OkPacket, SslRequest, }; -use crate::mysql::rsa; +use crate::mysql::stream::MySqlStream; use crate::mysql::util::xor_eq; +use crate::mysql::{rsa, tls}; use crate::url::Url; +use std::ops::Range; // Size before a packet is split const MAX_PACKET_SIZE: u32 = 1024; @@ -85,521 +88,206 @@ const COLLATE_UTF8MB4_UNICODE_CI: u8 = 224; /// against the hostname in the server certificate, so they must be the same for the TLS /// upgrade to succeed. `ssl-ca` must still be specified. pub struct MySqlConnection { - pub(super) stream: BufStream, + pub(super) stream: MySqlStream, + pub(super) is_ready: bool, + pub(super) cache_statement: HashMap, u32>, - // Active capabilities of the client _&_ the server - pub(super) capabilities: Capabilities, - - // Cache of prepared statements - // Query (String) to StatementId to ColumnMap - pub(super) statement_cache: StatementCache, - - // Packets are buffered into a second buffer from the stream - // as we may have compressed or split packets to figure out before - // decoding - pub(super) packet: Vec, - packet_len: usize, - - // Packets in a command sequence have an incrementing sequence number - // This number must be 0 at the start of each command - pub(super) next_seq_no: u8, + // Work buffer for the value ranges of the current row + // This is used as the backing memory for each Row's value indexes + pub(super) current_row_values: Vec>>, } -impl MySqlConnection { - /// Write the packet to the stream ( do not send to the server ) - pub(crate) fn write(&mut self, packet: impl Encode) { - let buf = self.stream.buffer_mut(); +fn to_asciz(s: &str) -> Vec { + let mut z = String::with_capacity(s.len() + 1); + z.push_str(s); + z.push('\0'); - // Allocate room for the header that we write after the packet; - // so, we can get an accurate and cheap measure of packet length + z.into_bytes() +} - let header_offset = buf.len(); - buf.advance(4); +async fn rsa_encrypt_with_nonce( + stream: &mut MySqlStream, + public_key_request_id: u8, + password: &str, + nonce: &[u8], +) -> crate::Result> { + // https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/ - packet.encode(buf, self.capabilities); - - // Determine length of encoded packet - // and write to allocated header - - let len = buf.len() - header_offset - 4; - let mut header = &mut buf[header_offset..]; - - LittleEndian::write_u32(&mut header, len as u32); // len - - // Take the last sequence number received, if any, and increment by 1 - // If there was no sequence number, we only increment if we split packets - header[3] = self.next_seq_no; - self.next_seq_no = self.next_seq_no.wrapping_add(1); + if stream.is_tls() { + // If in a TLS stream, send the password directly in clear text + return Ok(to_asciz(password)); } - /// Send the packet to the database server - pub(crate) async fn send(&mut self, packet: impl Encode) -> crate::Result<()> { - self.write(packet); - self.stream.flush().await?; + // client sends a public key request + stream.send(&[public_key_request_id][..], false).await?; - Ok(()) - } + // server sends a public key response + let packet = stream.receive().await?; + let rsa_pub_key = &packet[1..]; - /// Send a [HandshakeResponse] packet to the database server - pub(crate) async fn send_handshake_response( - &mut self, - url: &Url, - auth_plugin: &AuthPlugin, - auth_response: &[u8], - ) -> crate::Result<()> { - self.send(HandshakeResponse { - client_collation: COLLATE_UTF8MB4_UNICODE_CI, - max_packet_size: MAX_PACKET_SIZE, - username: url.username().unwrap_or("root"), - database: url.database(), - auth_plugin, - auth_response, - }) - .await - } + // xor the password with the given nonce + let mut pass = to_asciz(password); + xor_eq(&mut pass, nonce); - /// Try to receive a packet from the database server. Returns `None` if the server has sent - /// no data. - pub(crate) async fn try_receive(&mut self) -> crate::Result> { - self.packet.clear(); + // client sends an RSA encrypted password + rsa::encrypt::(rsa_pub_key, &pass) +} - // Read the packet header which contains the length and the sequence number - // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html - // https://mariadb.com/kb/en/library/0-packet/#standard-packet - let mut header = ret_if_none!(self.stream.peek(4).await?); - self.packet_len = header.get_uint::(3)? as usize; - self.next_seq_no = header.get_u8()?.wrapping_add(1); - self.stream.consume(4); - - // Read the packet body and copy it into our internal buf - // We must have a separate buffer around the stream as we can't operate directly - // on bytes returned from the stream. We have various kinds of payload manipulation - // that must be handled before decoding. - let payload = ret_if_none!(self.stream.peek(self.packet_len).await?); - self.packet.extend_from_slice(payload); - self.stream.consume(self.packet_len); - - // TODO: Implement packet compression - // TODO: Implement packet joining - - Ok(Some(())) - } - - /// Receive a complete packet from the database server. - pub(crate) async fn receive(&mut self) -> crate::Result<&mut Self> { - self.try_receive() - .await? - .ok_or(io::ErrorKind::ConnectionAborted)?; - - Ok(self) - } - - /// Returns a reference to the most recently received packet data - #[inline] - pub(crate) fn packet(&self) -> &[u8] { - &self.packet[..self.packet_len] - } - - /// Receive an [EofPacket] if we are supposed to receive them at all. - pub(crate) async fn receive_eof(&mut self) -> crate::Result<()> { - // When (legacy) EOFs are enabled, many things are terminated by an EOF packet - if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) { - let _eof = EofPacket::decode(self.receive().await?.packet())?; +async fn make_auth_response( + stream: &mut MySqlStream, + plugin: &AuthPlugin, + password: &str, + nonce: &[u8], +) -> crate::Result> { + match plugin { + AuthPlugin::CachingSha2Password | AuthPlugin::MySqlNativePassword => { + Ok(plugin.scramble(password, nonce)) } - Ok(()) - } - - /// Receive a [Handshake] packet. When connecting to the database server, this is immediately - /// received from the database server. - pub(crate) async fn receive_handshake(&mut self, url: &Url) -> crate::Result { - let handshake = Handshake::decode(self.receive().await?.packet())?; - - let mut client_capabilities = Capabilities::PROTOCOL_41 - | Capabilities::IGNORE_SPACE - | Capabilities::FOUND_ROWS - | Capabilities::TRANSACTIONS - | Capabilities::SECURE_CONNECTION - | Capabilities::PLUGIN_AUTH_LENENC_DATA - | Capabilities::PLUGIN_AUTH; - - if url.database().is_some() { - client_capabilities |= Capabilities::CONNECT_WITH_DB; - } - - if cfg!(feature = "tls") { - client_capabilities |= Capabilities::SSL; - } - - self.capabilities = - (client_capabilities & handshake.server_capabilities) | Capabilities::PROTOCOL_41; - - Ok(handshake) - } - - /// Receives an [OkPacket] from the database server. This is called at the end of - /// authentication to confirm the established connection. - pub(crate) fn receive_auth_ok<'a>( - &'a mut self, - plugin: &'a AuthPlugin, - password: &'a str, - nonce: &'a [u8], - ) -> BoxFuture<'a, crate::Result<()>> { - Box::pin(async move { - self.receive().await?; - - match self.packet[0] { - 0x00 => self.handle_ok().map(drop), - 0xfe => self.handle_auth_switch(password).await, - 0xff => self.handle_err(), - - _ => self.handle_auth_continue(plugin, password, nonce).await, - } - }) - } - - pub(crate) fn handle_ok(&mut self) -> crate::Result { - let ok = OkPacket::decode(self.packet())?; - - // An OK signifies the end of the current command sequence - self.next_seq_no = 0; - - Ok(ok) - } - - pub(crate) fn handle_err(&mut self) -> crate::Result { - let err = ErrPacket::decode(self.packet())?; - - // An ERR signifies the end of the current command sequence - self.next_seq_no = 0; - - Err(MySqlError(err).into()) - } - - pub(crate) fn handle_unexpected_packet(&self, id: u8) -> crate::Result { - Err(protocol_err!("unexpected packet identifier 0x{:X?}", id).into()) - } - - pub(crate) async fn handle_auth_continue( - &mut self, - plugin: &AuthPlugin, - password: &str, - nonce: &[u8], - ) -> crate::Result<()> { - match plugin { - AuthPlugin::CachingSha2Password => { - if self.packet[0] == 1 { - match self.packet[1] { - // AUTH_OK - 0x03 => {} - - // AUTH_CONTINUE - 0x04 => { - // client sends an RSA encrypted password - let ct = self.rsa_encrypt(0x02, password, nonce).await?; - - self.send(&*ct).await?; - } - - auth => { - return Err(protocol_err!("unexpected result from 'fast' authentication 0x{:x} when expecting OK (0x03) or CONTINUE (0x04)", auth).into()); - } - } - - // ends with server sending either OK_Packet or ERR_Packet - self.receive_auth_ok(plugin, password, nonce) - .await - .map(drop) - } else { - return self.handle_unexpected_packet(self.packet[0]); - } - } - - // No other supported auth methods will be called through continue - _ => unreachable!(), - } - } - - pub(crate) async fn handle_auth_switch(&mut self, password: &str) -> crate::Result<()> { - let auth = AuthSwitch::decode(self.packet())?; - - let auth_response = self - .make_auth_initial_response(&auth.auth_plugin, password, &auth.auth_plugin_data) - .await?; - - self.send(&*auth_response).await?; - - self.receive_auth_ok(&auth.auth_plugin, password, &auth.auth_plugin_data) - .await - } - - pub(crate) async fn make_auth_initial_response( - &mut self, - plugin: &AuthPlugin, - password: &str, - nonce: &[u8], - ) -> crate::Result> { - match plugin { - AuthPlugin::CachingSha2Password | AuthPlugin::MySqlNativePassword => { - Ok(plugin.scramble(password, nonce)) - } - - AuthPlugin::Sha256Password => { - // Full RSA exchange and password encrypt up front with no "cache" - Ok(self.rsa_encrypt(0x01, password, nonce).await?.into_vec()) - } - } - } - - pub(crate) async fn rsa_encrypt( - &mut self, - public_key_request_id: u8, - password: &str, - nonce: &[u8], - ) -> crate::Result> { - // https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/ - - if self.stream.is_tls() { - // If in a TLS stream, send the password directly in clear text - let mut clear_text = String::with_capacity(password.len() + 1); - clear_text.push_str(password); - clear_text.push('\0'); - - return Ok(clear_text.into_bytes().into_boxed_slice()); - } - - // client sends a public key request - self.send(&[public_key_request_id][..]).await?; - - // server sends a public key response - let packet = self.receive().await?.packet(); - let rsa_pub_key = &packet[1..]; - - // The password string data must be NUL terminated - // Note: This is not in the documentation that I could find - let mut pass = password.as_bytes().to_vec(); - pass.push(0); - - xor_eq(&mut pass, nonce); - - // client sends an RSA encrypted password - rsa::encrypt::(rsa_pub_key, &pass) + AuthPlugin::Sha256Password => rsa_encrypt_with_nonce(stream, 0x01, password, nonce).await, } } -impl MySqlConnection { - async fn new(url: &Url) -> crate::Result { - let stream = MaybeTlsStream::connect(url, 3306).await?; +async fn establish(stream: &mut MySqlStream, url: &Url) -> crate::Result<()> { + // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase.html + // https://mariadb.com/kb/en/connection/ - let mut capabilities = Capabilities::empty(); + // Read a [Handshake] packet. When connecting to the database server, this is immediately + // received from the database server. - if cfg!(feature = "tls") { - capabilities |= Capabilities::SSL; - } + let handshake = Handshake::decode(stream.receive().await?)?; + let mut auth_plugin = handshake.auth_plugin; + let mut auth_plugin_data = handshake.auth_plugin_data; - Ok(Self { - stream: BufStream::new(stream), - capabilities, - packet: Vec::with_capacity(8192), - packet_len: 0, - next_seq_no: 0, - statement_cache: StatementCache::new(), - }) - } + stream.capabilities &= handshake.server_capabilities; + stream.capabilities |= Capabilities::PROTOCOL_41; - async fn initialize(&mut self) -> crate::Result<()> { - // On connect, we want to establish a modern, Rust-compatible baseline so we - // tweak connection options to enable UTC for TIMESTAMP, UTF-8 for character types, etc. + // Depending on the ssl-mode and capabilities we should upgrade + // our connection to TLS - // TODO: Use batch support when we have it to handle the following in one execution + tls::upgrade_if_needed(stream, url).await?; - // https://mariadb.com/kb/en/sql-mode/ + // Send a [HandshakeResponse] packet. This is returned in response to the [Handshake] packet + // that is immediately received. - // PIPES_AS_CONCAT - Allows using the pipe character (ASCII 124) as string concatenation operator. - // This means that "A" || "B" can be used in place of CONCAT("A", "B"). + let password = &*url.password().unwrap_or_default(); + let auth_response = + make_auth_response(stream, &auth_plugin, password, &auth_plugin_data).await?; - // NO_ENGINE_SUBSTITUTION - If not set, if the available storage engine specified by a CREATE TABLE is - // not available, a warning is given and the default storage - // engine is used instead. - - // NO_ZERO_DATE - Don't allow '0000-00-00'. This is invalid in Rust. - - // NO_ZERO_IN_DATE - Don't allow 'YYYY-00-00'. This is invalid in Rust. - - // language=MySQL - self.execute_raw("SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE'))") - .await?; - - // This allows us to assume that the output from a TIMESTAMP field is UTC - - // language=MySQL - self.execute_raw("SET time_zone = '+00:00'").await?; - - // https://mathiasbynens.be/notes/mysql-utf8mb4 - - // language=MySQL - self.execute_raw("SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci") - .await?; - - Ok(()) - } - - #[cfg(feature = "tls")] - async fn try_ssl( - &mut self, - url: &Url, - ca_file: Option<&str>, - invalid_hostnames: bool, - ) -> crate::Result<()> { - use crate::runtime::fs; - use async_native_tls::{Certificate, TlsConnector}; - - let mut connector = TlsConnector::new() - .danger_accept_invalid_certs(ca_file.is_none()) - .danger_accept_invalid_hostnames(invalid_hostnames); - - if let Some(ca_file) = ca_file { - let root_cert = fs::read(ca_file).await?; - connector = connector.add_root_certificate(Certificate::from_pem(&root_cert)?); - } - - // send upgrade request and then immediately try TLS handshake - self.send(SslRequest { - client_collation: COLLATE_UTF8MB4_UNICODE_CI, - max_packet_size: MAX_PACKET_SIZE, - }) + stream + .send( + HandshakeResponse { + client_collation: COLLATE_UTF8MB4_UNICODE_CI, + max_packet_size: MAX_PACKET_SIZE, + username: url.username().unwrap_or("root"), + database: url.database(), + auth_plugin: &auth_plugin, + auth_response: &auth_response, + }, + false, + ) .await?; - self.stream.stream.upgrade(url, connector).await + loop { + // After sending the handshake response with our assumed auth method the server + // will send OK, fail, or tell us to change auth methods + let capabilities = stream.capabilities; + let packet = stream.receive().await?; + + match packet[0] { + // OK + 0x00 => { + break; + } + + // ERROR + 0xFF => { + return stream.handle_err(); + } + + // AUTH_SWITCH + 0xFE => { + let auth = AuthSwitch::decode(packet)?; + auth_plugin = auth.auth_plugin; + auth_plugin_data = auth.auth_plugin_data; + + let auth_response = + make_auth_response(stream, &auth_plugin, password, &auth_plugin_data).await?; + + stream.send(&*auth_response, false).await?; + } + + 0x01 if auth_plugin == AuthPlugin::CachingSha2Password => { + match packet[1] { + // AUTH_OK + 0x03 => {} + + // AUTH_CONTINUE + 0x04 => { + // The specific password is _not_ cached on the server + // We need to send a normal RSA-encrypted password for this + let enc = rsa_encrypt_with_nonce(stream, 0x02, password, &auth_plugin_data) + .await?; + + stream.send(&*enc, false).await?; + } + + unk => { + return Err(protocol_err!("unexpected result from 'fast' authentication 0x{:x} when expecting OK (0x03) or CONTINUE (0x04)", unk).into()); + } + } + } + + unk => { + return stream.handle_unexpected(); + } + } + } + + Ok(()) +} + +async fn close(mut stream: MySqlStream) -> crate::Result<()> { + // TODO: Actually tell MySQL that we're closing + + stream.flush().await?; + stream.shutdown()?; + + Ok(()) +} + +async fn ping(stream: &mut MySqlStream) -> crate::Result<()> { + stream.send(ComPing, true).await?; + + match stream.receive().await?[0] { + 0x00 | 0xFE => Ok(()), + + 0xFF => stream.handle_err(), + + _ => stream.handle_unexpected(), } } impl MySqlConnection { - pub(super) async fn establish(url: crate::Result) -> crate::Result { + pub(super) async fn new(url: crate::Result) -> crate::Result { let url = url?; - let mut self_ = Self::new(&url).await?; + let mut stream = MySqlStream::new(&url).await?; - // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase.html - // https://mariadb.com/kb/en/connection/ + establish(&mut stream, &url).await?; - // On connect, server immediately sends the handshake - let mut handshake = self_.receive_handshake(&url).await?; - - let ca_file = url.param("ssl-ca"); - - let ssl_mode = url.param("ssl-mode").unwrap_or( - if ca_file.is_some() { - "VERIFY_CA" - } else { - "PREFERRED" - } - .into(), - ); - - let supports_ssl = handshake.server_capabilities.contains(Capabilities::SSL); - - match &*ssl_mode { - "DISABLED" => (), - - // don't try upgrade - #[cfg(feature = "tls")] - "PREFERRED" if !supports_ssl => { - log::warn!("server does not support TLS; using unencrypted connection") - } - - // try to upgrade - #[cfg(feature = "tls")] - "PREFERRED" => { - if let Err(e) = self_.try_ssl(&url, None, true).await { - log::warn!("TLS handshake failed, falling back to insecure: {}", e); - // fallback, redo connection - self_ = Self::new(&url).await?; - handshake = self_.receive_handshake(&url).await?; - } - } - - #[cfg(not(feature = "tls"))] - "PREFERRED" => log::info!("compiled without TLS, skipping upgrade"), - - #[cfg(feature = "tls")] - "REQUIRED" if !supports_ssl => { - return Err(tls_err!("server does not support TLS").into()) - } - - #[cfg(feature = "tls")] - "REQUIRED" => self_.try_ssl(&url, None, true).await?, - - #[cfg(feature = "tls")] - "VERIFY_CA" | "VERIFY_FULL" if ca_file.is_none() => { - return Err( - tls_err!("`ssl-mode` of {:?} requires `ssl-ca` to be set", ssl_mode).into(), - ) - } - - #[cfg(feature = "tls")] - "VERIFY_CA" | "VERIFY_FULL" => { - self_ - .try_ssl(&url, ca_file.as_deref(), ssl_mode != "VERIFY_FULL") - .await? - } - - #[cfg(not(feature = "tls"))] - "REQUIRED" | "VERIFY_CA" | "VERIFY_FULL" => { - return Err(tls_err!("compiled without TLS").into()) - } - _ => return Err(tls_err!("unknown `ssl-mode` value: {:?}", ssl_mode).into()), - } - - // Pre-generate an auth response by using the auth method in the [Handshake] - let password = url.password().unwrap_or_default(); - let auth_response = self_ - .make_auth_initial_response( - &handshake.auth_plugin, - &password, - &handshake.auth_plugin_data, - ) - .await?; - - self_ - .send_handshake_response(&url, &handshake.auth_plugin, &auth_response) - .await?; - - // After sending the handshake response with our assumed auth method the server - // will send OK, fail, or tell us to change auth methods - self_ - .receive_auth_ok( - &handshake.auth_plugin, - &password, - &handshake.auth_plugin_data, - ) - .await?; + let mut self_ = Self { + stream, + current_row_values: Vec::with_capacity(10), + is_ready: true, + cache_statement: HashMap::new(), + }; // After the connection is established, we initialize by configuring a few // connection parameters - self_.initialize().await?; + // initialize().await?; Ok(self_) } - - async fn close(mut self) -> crate::Result<()> { - // TODO: Actually tell MySQL that we're closing - - self.stream.flush().await?; - self.stream.stream.shutdown(Shutdown::Both)?; - - Ok(()) - } -} - -impl MySqlConnection { - #[deprecated(note = "please use 'connect' instead")] - pub fn open(url: T) -> BoxFuture<'static, crate::Result> - where - T: TryInto, - Self: Sized, - { - Box::pin(MySqlConnection::establish(url.try_into())) - } } impl Connect for MySqlConnection { @@ -608,12 +296,16 @@ impl Connect for MySqlConnection { T: TryInto, Self: Sized, { - Box::pin(MySqlConnection::establish(url.try_into())) + Box::pin(MySqlConnection::new(url.try_into())) } } impl Connection for MySqlConnection { fn close(self) -> BoxFuture<'static, crate::Result<()>> { - Box::pin(self.close()) + Box::pin(close(self.stream)) + } + + fn ping(&mut self) -> BoxFuture> { + Box::pin(ping(&mut self.stream)) } } diff --git a/sqlx-core/src/mysql/cursor.rs b/sqlx-core/src/mysql/cursor.rs new file mode 100644 index 00000000..902e534b --- /dev/null +++ b/sqlx-core/src/mysql/cursor.rs @@ -0,0 +1,158 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use futures_core::future::BoxFuture; + +use crate::connection::{ConnectionSource, MaybeOwnedConnection}; +use crate::cursor::Cursor; +use crate::executor::Execute; +use crate::mysql::protocol::{ColumnCount, ColumnDefinition, Decode, Row, Status, TypeId}; +use crate::mysql::{MySql, MySqlArguments, MySqlConnection, MySqlRow}; +use crate::pool::Pool; + +pub struct MySqlCursor<'c, 'q> { + source: ConnectionSource<'c, MySqlConnection>, + query: Option<(&'q str, Option)>, + column_names: Arc, u16>>, + column_types: Vec, + binary: bool, +} + +impl<'c, 'q> Cursor<'c, 'q> for MySqlCursor<'c, 'q> { + type Database = MySql; + + #[doc(hidden)] + fn from_pool(pool: &Pool, query: E) -> Self + where + Self: Sized, + E: Execute<'q, MySql>, + { + Self { + source: ConnectionSource::Pool(pool.clone()), + column_names: Arc::default(), + column_types: Vec::new(), + binary: true, + query: Some(query.into_parts()), + } + } + + #[doc(hidden)] + fn from_connection(conn: C, query: E) -> Self + where + Self: Sized, + C: Into>, + E: Execute<'q, MySql>, + { + Self { + source: ConnectionSource::Connection(conn.into()), + column_names: Arc::default(), + column_types: Vec::new(), + binary: true, + query: Some(query.into_parts()), + } + } + + fn next(&mut self) -> BoxFuture>>> { + Box::pin(next(self)) + } +} + +async fn next<'a, 'c: 'a, 'q: 'a>( + cursor: &'a mut MySqlCursor<'c, 'q>, +) -> crate::Result>> { + println!("[cursor::next]"); + + let mut conn = cursor.source.resolve_by_ref().await?; + + // The first time [next] is called we need to actually execute our + // contained query. We guard against this happening on _all_ next calls + // by using [Option::take] which replaces the potential value in the Option with `None + let mut initial = if let Some((query, arguments)) = cursor.query.take() { + let statement = conn.run(query, arguments).await?; + + // No statement ID = TEXT mode + cursor.binary = statement.is_some(); + + true + } else { + false + }; + + loop { + let mut packet_id = conn.stream.receive().await?[0]; + println!("[cursor::next/iter] {:x}", packet_id); + match packet_id { + // OK or EOF packet + 0x00 | 0xFE + if conn.stream.packet().len() < 0xFF_FF_FF && (packet_id != 0x00 || initial) => + { + let ok = conn.stream.handle_ok()?; + + if ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { + // There is more to this query + initial = true; + } else { + conn.is_ready = true; + return Ok(None); + } + } + + // ERR packet + 0xFF => { + conn.is_ready = true; + return conn.stream.handle_err(); + } + + _ if initial => { + // At the start of the results we expect to see a + // COLUMN_COUNT followed by N COLUMN_DEF + + let cc = ColumnCount::decode(conn.stream.packet())?; + + // We use these definitions to get the actual column types that is critical + // in parsing the rows coming back soon + + cursor.column_types.clear(); + cursor.column_types.reserve(cc.columns as usize); + + let mut column_names = HashMap::with_capacity(cc.columns as usize); + + for i in 0..cc.columns { + let column = ColumnDefinition::decode(conn.stream.receive().await?)?; + + cursor.column_types.push(column.type_id); + + if let Some(name) = column.name() { + column_names.insert(name.to_owned().into_boxed_str(), i as u16); + } + } + + cursor.column_names = Arc::new(column_names); + initial = false; + } + + _ if !cursor.binary || packet_id == 0x00 => { + let row = Row::read( + conn.stream.packet(), + &cursor.column_types, + &mut conn.current_row_values, + // TODO: Text mode + cursor.binary, + )?; + + let row = MySqlRow { + row, + columns: Arc::clone(&cursor.column_names), + // TODO: Text mode + binary: cursor.binary, + }; + + return Ok(Some(row)); + } + + _ => { + return conn.stream.handle_unexpected(); + } + } + } +} diff --git a/sqlx-core/src/mysql/database.rs b/sqlx-core/src/mysql/database.rs index cd4bb427..62a22898 100644 --- a/sqlx-core/src/mysql/database.rs +++ b/sqlx-core/src/mysql/database.rs @@ -13,18 +13,18 @@ impl Database for MySql { type TableId = Box; } -impl HasRow for MySql { +impl<'c> HasRow<'c> for MySql { type Database = MySql; - type Row = super::MySqlRow; + type Row = super::MySqlRow<'c>; } -impl<'a> HasCursor<'a> for MySql { +impl<'c, 'q> HasCursor<'c, 'q> for MySql { type Database = MySql; - type Cursor = super::MySqlCursor<'a>; + type Cursor = super::MySqlCursor<'c, 'q>; } -impl<'a> HasRawValue<'a> for MySql { - type RawValue = Option<&'a [u8]>; +impl<'c> HasRawValue<'c> for MySql { + type RawValue = Option>; } diff --git a/sqlx-core/src/mysql/executor.rs b/sqlx-core/src/mysql/executor.rs index de92ac07..b81f297e 100644 --- a/sqlx-core/src/mysql/executor.rs +++ b/sqlx-core/src/mysql/executor.rs @@ -4,341 +4,244 @@ use std::sync::Arc; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; -use crate::describe::{Column, Describe, Nullability}; -use crate::executor::Executor; +use crate::cursor::Cursor; +use crate::describe::{Column, Describe}; +use crate::executor::{Execute, Executor, RefExecutor}; use crate::mysql::protocol::{ - Capabilities, ColumnCount, ColumnDefinition, ComQuery, ComStmtExecute, ComStmtPrepare, - ComStmtPrepareOk, Cursor, Decode, EofPacket, FieldFlags, OkPacket, Row, TypeId, + self, Capabilities, ColumnCount, ColumnDefinition, ComQuery, ComStmtExecute, ComStmtPrepare, + ComStmtPrepareOk, Decode, EofPacket, ErrPacket, FieldFlags, OkPacket, Row, TypeId, +}; +use crate::mysql::{ + MySql, MySqlArguments, MySqlConnection, MySqlCursor, MySqlError, MySqlRow, MySqlTypeInfo, }; -use crate::mysql::{MySql, MySqlArguments, MySqlConnection, MySqlRow, MySqlTypeInfo}; -enum Step { - Command(u64), - Row(Row), -} +impl super::MySqlConnection { + async fn wait_until_ready(&mut self) -> crate::Result<()> { + if !self.is_ready { + loop { + let mut packet_id = self.stream.receive().await?[0]; + match packet_id { + 0xFE if self.stream.packet().len() < 0xFF_FF_FF => { + // OK or EOF packet + self.is_ready = true; + break; + } -enum OkOrResultSet { - Ok(OkPacket), - ResultSet(ColumnCount), -} + 0xFF => { + // ERR packet + self.is_ready = true; + return self.stream.handle_err(); + } -impl MySqlConnection { - async fn ignore_columns(&mut self, count: usize) -> crate::Result<()> { - for _ in 0..count { - let _column = ColumnDefinition::decode(self.receive().await?.packet())?; - } - - if count > 0 { - self.receive_eof().await?; - } - - Ok(()) - } - - async fn receive_ok_or_column_count(&mut self) -> crate::Result { - self.receive().await?; - - match self.packet[0] { - 0x00 | 0xfe if self.packet.len() < 0xffffff => self.handle_ok().map(OkOrResultSet::Ok), - 0xff => self.handle_err(), - - _ => Ok(OkOrResultSet::ResultSet(ColumnCount::decode( - self.packet(), - )?)), - } - } - - async fn receive_column_types(&mut self, count: usize) -> crate::Result> { - let mut columns: Vec = Vec::with_capacity(count); - - for _ in 0..count { - let column: ColumnDefinition = - ColumnDefinition::decode(self.receive().await?.packet())?; - - columns.push(column.type_id); - } - - if count > 0 { - self.receive_eof().await?; - } - - Ok(columns.into_boxed_slice()) - } - - async fn wait_for_ready(&mut self) -> crate::Result<()> { - if self.next_seq_no != 0 { - while let Some(_step) = self.step(&[], true).await? { - // Drain steps until we hit the end + _ => { + // Something else; skip + } + } } } Ok(()) } + // Creates a prepared statement for the passed query string async fn prepare(&mut self, query: &str) -> crate::Result { - // Start by sending a COM_STMT_PREPARE - self.send(ComStmtPrepare { query }).await?; + // https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_com_stmt_prepare.html + self.stream.send(ComStmtPrepare { query }, true).await?; - // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_prepare.html + // Should receive a COM_STMT_PREPARE_OK or ERR_PACKET + let packet = self.stream.receive().await?; - // First we should receive a COM_STMT_PREPARE_OK - self.receive().await?; - - if self.packet[0] == 0xff { - // Oops, there was an error in the prepare command - return self.handle_err(); + if packet[0] == 0xFF { + return self.stream.handle_err(); } - ComStmtPrepareOk::decode(self.packet()) + ComStmtPrepareOk::decode(packet) } - async fn prepare_with_cache(&mut self, query: &str) -> crate::Result { - if let Some(&id) = self.statement_cache.get(query) { + async fn drop_column_defs(&mut self, count: usize) -> crate::Result<()> { + for _ in 0..count { + let _column = ColumnDefinition::decode(self.stream.receive().await?)?; + } + + if count > 0 { + self.stream.maybe_receive_eof().await?; + } + + Ok(()) + } + + // Gets a cached prepared statement ID _or_ prepares the statement if not in the cache + // At the end we should have [cache_statement] and [cache_statement_columns] filled + async fn get_or_prepare(&mut self, query: &str) -> crate::Result { + if let Some(&id) = self.cache_statement.get(query) { Ok(id) } else { - let prepare_ok = self.prepare(query).await?; + let stmt = self.prepare(query).await?; - // Remember our statement ID, so we do'd do this again the next time - self.statement_cache - .put(query.to_owned(), prepare_ok.statement_id); + self.cache_statement.insert(query.into(), stmt.statement_id); - // Ignore input parameters - self.ignore_columns(prepare_ok.params as usize).await?; + // COM_STMT_PREPARE returns the input columns + // We make no use of that data, so cycle through and drop them + self.drop_column_defs(stmt.params as usize).await?; - // Collect output parameter names - let mut columns = HashMap::with_capacity(prepare_ok.columns as usize); - let mut index = 0_usize; - for _ in 0..prepare_ok.columns { - let column = ColumnDefinition::decode(self.receive().await?.packet())?; + // COM_STMT_PREPARE next returns the output columns + // We just drop these as we get these when we execute the query + self.drop_column_defs(stmt.columns as usize).await?; - if let Some(name) = column.column_alias.or(column.column) { - columns.insert(name, index); + Ok(stmt.statement_id) + } + } + + pub(crate) async fn run( + &mut self, + query: &str, + arguments: Option, + ) -> crate::Result> { + self.wait_until_ready().await?; + self.is_ready = false; + + if let Some(arguments) = arguments { + let statement_id = self.get_or_prepare(query).await?; + + // https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_com_stmt_execute.html + self.stream + .send( + ComStmtExecute { + cursor: protocol::Cursor::NO_CURSOR, + statement_id, + params: &arguments.params, + null_bitmap: &arguments.null_bitmap, + param_types: &arguments.param_types, + }, + true, + ) + .await?; + + Ok(Some(statement_id)) + } else { + // https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_com_query.html + self.stream.send(ComQuery { query }, true).await?; + + Ok(None) + } + } + + async fn affected_rows(&mut self) -> crate::Result { + let mut rows = 0; + + loop { + let id = self.stream.receive().await?[0]; + + match id { + 0x00 | 0xFE if self.stream.packet().len() < 0xFF_FF_FF => { + // ResultSet row can begin with 0xfe byte (when using text protocol + // with a field length > 0xffffff) + + if !self.stream.maybe_handle_eof()? { + rows += self.stream.handle_ok()?.affected_rows; + } + + // EOF packets do not have affected rows + // So this function is actually useless if the server doesn't support + // proper OK packets + + self.is_ready = true; + break; } - index += 1; - } - - if prepare_ok.columns > 0 { - self.receive_eof().await?; - } - - // At the end of a command, this should go back to 0 - self.next_seq_no = 0; - - // Remember our column map in the statement cache - self.statement_cache - .put_columns(prepare_ok.statement_id, columns); - - Ok(prepare_ok.statement_id) - } - } - - // [COM_STMT_EXECUTE] - async fn execute_statement(&mut self, id: u32, args: MySqlArguments) -> crate::Result<()> { - self.send(ComStmtExecute { - cursor: Cursor::NO_CURSOR, - statement_id: id, - params: &args.params, - null_bitmap: &args.null_bitmap, - param_types: &args.param_types, - }) - .await - } - - async fn step(&mut self, columns: &[TypeId], binary: bool) -> crate::Result> { - let capabilities = self.capabilities; - ret_if_none!(self.try_receive().await?); - - match self.packet[0] { - 0xfe if self.packet.len() < 0xffffff => { - // ResultSet row can begin with 0xfe byte (when using text protocol - // with a field length > 0xffffff) - - if !capabilities.contains(Capabilities::DEPRECATE_EOF) { - let _eof = EofPacket::decode(self.packet())?; - - // An EOF -here- signifies the end of the current command sequence - self.next_seq_no = 0; - - Ok(None) - } else { - self.handle_ok() - .map(|ok| Some(Step::Command(ok.affected_rows))) + 0xFF => { + return self.stream.handle_err(); } - } - 0xff => self.handle_err(), - - _ => Ok(Some(Step::Row(Row::decode( - self.packet(), - columns, - binary, - )?))), - } - } -} - -impl MySqlConnection { - pub(super) async fn execute_raw(&mut self, query: &str) -> crate::Result<()> { - self.wait_for_ready().await?; - - self.send(ComQuery { query }).await?; - - // COM_QUERY can terminate before the result set with an ERR or OK packet - let num_columns = match self.receive_ok_or_column_count().await? { - OkOrResultSet::Ok(_) => { - self.next_seq_no = 0; - return Ok(()); - } - - OkOrResultSet::ResultSet(cc) => cc.columns as usize, - }; - - let columns = self.receive_column_types(num_columns as usize).await?; - - while let Some(_step) = self.step(&columns, false).await? { - // Drop all responses - } - - Ok(()) - } - - async fn execute(&mut self, query: &str, args: MySqlArguments) -> crate::Result { - self.wait_for_ready().await?; - - let statement_id = self.prepare_with_cache(query).await?; - - self.execute_statement(statement_id, args).await?; - - // COM_STMT_EXECUTE can terminate before the result set with an ERR or OK packet - let num_columns = match self.receive_ok_or_column_count().await? { - OkOrResultSet::Ok(ok) => { - self.next_seq_no = 0; - - return Ok(ok.affected_rows); - } - - OkOrResultSet::ResultSet(cc) => cc.columns as usize, - }; - - self.ignore_columns(num_columns).await?; - - let mut res = 0; - - while let Some(step) = self.step(&[], true).await? { - if let Step::Command(affected) = step { - res = affected; + _ => {} } } - Ok(res) + Ok(rows) } - async fn describe(&mut self, query: &str) -> crate::Result> { - self.wait_for_ready().await?; + // method is not named describe to work around an intellijrust bug + // otherwise it marks someone trying to describe the connection as "method is private" + async fn do_describe(&mut self, query: &str) -> crate::Result> { + self.wait_until_ready().await?; - let prepare_ok = self.prepare(query).await?; + let stmt = self.prepare(query).await?; - let mut param_types = Vec::with_capacity(prepare_ok.params as usize); - let mut result_columns = Vec::with_capacity(prepare_ok.columns as usize); + let mut param_types = Vec::with_capacity(stmt.params as usize); + let mut result_columns = Vec::with_capacity(stmt.columns as usize); - for _ in 0..prepare_ok.params { - let param = ColumnDefinition::decode(self.receive().await?.packet())?; + for _ in 0..stmt.params { + let param = ColumnDefinition::decode(self.stream.receive().await?)?; param_types.push(MySqlTypeInfo::from_column_def(¶m)); } - if prepare_ok.params > 0 { - self.receive_eof().await?; + if stmt.params > 0 { + self.stream.maybe_receive_eof().await?; } - for _ in 0..prepare_ok.columns { - let column = ColumnDefinition::decode(self.receive().await?.packet())?; + for _ in 0..stmt.columns { + let column = ColumnDefinition::decode(self.stream.receive().await?)?; result_columns.push(Column:: { type_info: MySqlTypeInfo::from_column_def(&column), name: column.column_alias.or(column.column), table_id: column.table_alias.or(column.table), + // TODO(@abonander): Should this be None in some cases? non_null: Some(column.flags.contains(FieldFlags::NOT_NULL)), }); } - if prepare_ok.columns > 0 { - self.receive_eof().await?; + if stmt.columns > 0 { + self.stream.maybe_receive_eof().await?; } - // Command sequence is over - self.next_seq_no = 0; - Ok(Describe { param_types: param_types.into_boxed_slice(), result_columns: result_columns.into_boxed_slice(), }) } +} - fn fetch<'e, 'q: 'e>( - &'e mut self, - query: &'q str, - args: MySqlArguments, - ) -> BoxStream<'e, crate::Result> { - Box::pin(async_stream::try_stream! { - self.wait_for_ready().await?; +impl Executor for super::MySqlConnection { + type Database = MySql; - let statement_id = self.prepare_with_cache(query).await?; + fn execute<'e, 'q, E: 'e>(&'e mut self, query: E) -> BoxFuture<'e, crate::Result> + where + E: Execute<'q, Self::Database>, + { + Box::pin(async move { + let (query, arguments) = query.into_parts(); - let columns = self.statement_cache.get_columns(statement_id); - - self.execute_statement(statement_id, args).await?; - - // COM_STMT_EXECUTE can terminate before the result set with an ERR or OK packet - let num_columns = match self.receive_ok_or_column_count().await? { - OkOrResultSet::Ok(_) => { - self.next_seq_no = 0; - return; - } - - OkOrResultSet::ResultSet(cc) => { - cc.columns as usize - } - }; - - let column_types = self.receive_column_types(num_columns).await?; - - while let Some(Step::Row(row)) = self.step(&column_types, true).await? { - yield MySqlRow { row, columns: Arc::clone(&columns) }; - } + self.run(query, arguments).await?; + self.affected_rows().await }) } -} -impl Executor for MySqlConnection { - type Database = super::MySql; - - fn send<'e, 'q: 'e>(&'e mut self, query: &'q str) -> BoxFuture<'e, crate::Result<()>> { - Box::pin(self.execute_raw(query)) + fn fetch<'q, E>(&mut self, query: E) -> MySqlCursor<'_, 'q> + where + E: Execute<'q, Self::Database>, + { + MySqlCursor::from_connection(self, query) } - fn fetch<'e, 'q: 'e>( + fn describe<'e, 'q, E: 'e>( &'e mut self, - query: &'q str, - args: MySqlArguments, - ) -> BoxFuture<'e, crate::Result> { - Box::pin(self.execute(query, args)) - } - - fn fetch<'e, 'q: 'e>( - &'e mut self, - query: &'q str, - args: MySqlArguments, - ) -> BoxStream<'e, crate::Result> { - self.fetch(query, args) - } - - fn describe<'e, 'q: 'e>( - &'e mut self, - query: &'q str, - ) -> BoxFuture<'e, crate::Result>> { - Box::pin(self.describe(query)) + query: E, + ) -> BoxFuture<'e, crate::Result>> + where + E: Execute<'q, Self::Database>, + { + Box::pin(async move { self.do_describe(query.into_parts().0).await }) } } -impl_execute_for_query!(MySql); +impl<'c> RefExecutor<'c> for &'c mut super::MySqlConnection { + type Database = MySql; + + fn fetch_by_ref<'q, E>(self, query: E) -> MySqlCursor<'c, 'q> + where + E: Execute<'q, Self::Database>, + { + MySqlCursor::from_connection(self, query) + } +} diff --git a/sqlx-core/src/mysql/mod.rs b/sqlx-core/src/mysql/mod.rs index 1f037b3f..7efddc38 100644 --- a/sqlx-core/src/mysql/mod.rs +++ b/sqlx-core/src/mysql/mod.rs @@ -1,7 +1,16 @@ //! **MySQL** database and connection types. +pub use arguments::MySqlArguments; +pub use connection::MySqlConnection; +pub use cursor::MySqlCursor; +pub use database::MySql; +pub use error::MySqlError; +pub use row::{MySqlRow, MySqlValue}; +pub use types::MySqlTypeInfo; + mod arguments; mod connection; +mod cursor; mod database; mod error; mod executor; @@ -9,20 +18,15 @@ mod io; mod protocol; mod row; mod rsa; +mod stream; +mod tls; mod types; mod util; -pub use database::MySql; - -pub use arguments::MySqlArguments; - -pub use connection::MySqlConnection; - -pub use error::MySqlError; - -pub use types::MySqlTypeInfo; - -pub use row::MySqlRow; - /// An alias for [`Pool`], specialized for **MySQL**. -pub type MySqlPool = super::Pool; +pub type MySqlPool = crate::pool::Pool; + +make_query_as!(MySqlQueryAs, MySql, MySqlRow); +impl_map_row_for_row!(MySql, MySqlRow); +impl_column_index_for_row!(MySql); +impl_from_row_for_tuples!(MySql, MySqlRow); diff --git a/sqlx-core/src/mysql/protocol/auth_plugin.rs b/sqlx-core/src/mysql/protocol/auth_plugin.rs index a2b667e6..2ffaa8af 100644 --- a/sqlx-core/src/mysql/protocol/auth_plugin.rs +++ b/sqlx-core/src/mysql/protocol/auth_plugin.rs @@ -6,7 +6,7 @@ use sha2::Sha256; use crate::mysql::util::xor_eq; -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum AuthPlugin { MySqlNativePassword, CachingSha2Password, diff --git a/sqlx-core/src/mysql/protocol/column_def.rs b/sqlx-core/src/mysql/protocol/column_def.rs index 2640130e..b3fb6f7c 100644 --- a/sqlx-core/src/mysql/protocol/column_def.rs +++ b/sqlx-core/src/mysql/protocol/column_def.rs @@ -27,6 +27,12 @@ pub struct ColumnDefinition { pub decimals: u8, } +impl ColumnDefinition { + pub fn name(&self) -> Option<&str> { + self.column_alias.as_deref().or(self.column.as_deref()) + } +} + impl Decode for ColumnDefinition { fn decode(mut buf: &[u8]) -> crate::Result { // catalog : string diff --git a/sqlx-core/src/mysql/protocol/com_ping.rs b/sqlx-core/src/mysql/protocol/com_ping.rs new file mode 100644 index 00000000..8ebfed87 --- /dev/null +++ b/sqlx-core/src/mysql/protocol/com_ping.rs @@ -0,0 +1,16 @@ +use byteorder::LittleEndian; + +use crate::io::BufMut; +use crate::mysql::io::BufMutExt; +use crate::mysql::protocol::{Capabilities, Encode}; + +// https://dev.mysql.com/doc/internals/en/com-ping.html +#[derive(Debug)] +pub struct ComPing; + +impl Encode for ComPing { + fn encode(&self, buf: &mut Vec, _: Capabilities) { + // COM_PING : int<1> + buf.put_u8(0x0e); + } +} diff --git a/sqlx-core/src/mysql/protocol/com_stmt_prepare_ok.rs b/sqlx-core/src/mysql/protocol/com_stmt_prepare_ok.rs index 9620c161..d2814582 100644 --- a/sqlx-core/src/mysql/protocol/com_stmt_prepare_ok.rs +++ b/sqlx-core/src/mysql/protocol/com_stmt_prepare_ok.rs @@ -9,7 +9,8 @@ use crate::mysql::protocol::Decode; pub struct ComStmtPrepareOk { pub statement_id: u32, - /// Number of columns in the returned result set (or 0 if statement does not return result set). + /// Number of columns in the returned result set (or 0 if statement + /// does not return result set). pub columns: u16, /// Number of prepared statement parameters ('?' placeholders). diff --git a/sqlx-core/src/mysql/protocol/err.rs b/sqlx-core/src/mysql/protocol/err.rs index 5aa70a28..965d5413 100644 --- a/sqlx-core/src/mysql/protocol/err.rs +++ b/sqlx-core/src/mysql/protocol/err.rs @@ -9,24 +9,34 @@ use crate::mysql::protocol::{Capabilities, Decode, Status}; #[derive(Debug)] pub struct ErrPacket { pub error_code: u16, - pub sql_state: Box, + pub sql_state: Option>, pub error_message: Box, } -impl Decode for ErrPacket { - fn decode(mut buf: &[u8]) -> crate::Result +impl ErrPacket { + pub(crate) fn decode(mut buf: &[u8], capabilities: Capabilities) -> crate::Result where Self: Sized, { let header = buf.get_u8()?; if header != 0xFF { - return Err(protocol_err!("expected 0xFF; received 0x{:X}", header))?; + return Err(protocol_err!( + "expected 0xFF for ERR_PACKET; received 0x{:X}", + header + ))?; } let error_code = buf.get_u16::()?; - let _sql_state_marker: u8 = buf.get_u8()?; - let sql_state = buf.get_str(5)?.into(); + let mut sql_state = None; + + if capabilities.contains(Capabilities::PROTOCOL_41) { + // If the next byte is '#' then we have a SQL STATE + if buf.get(0) == Some(&0x23) { + buf.advance(1); + sql_state = Some(buf.get_str(5)?.into()) + } + } let error_message = buf.get_str(buf.len())?.into(); @@ -42,14 +52,25 @@ impl Decode for ErrPacket { mod tests { use super::{Capabilities, Decode, ErrPacket, Status}; + const ERR_PACKETS_OUT_OF_ORDER: &[u8] = b"\xff\x84\x04Got packets out of order"; + const ERR_HANDSHAKE_UNKNOWN_DB: &[u8] = b"\xff\x19\x04#42000Unknown database \'unknown\'"; + #[test] + fn it_decodes_packets_out_of_order() { + let mut p = ErrPacket::decode(ERR_PACKETS_OUT_OF_ORDER, Capabilities::PROTOCOL_41).unwrap(); + + assert_eq!(&*p.error_message, "Got packets out of order"); + assert_eq!(p.error_code, 1156); + assert_eq!(p.sql_state, None); + } + #[test] fn it_decodes_ok_handshake() { - let mut p = ErrPacket::decode(ERR_HANDSHAKE_UNKNOWN_DB).unwrap(); + let mut p = ErrPacket::decode(ERR_HANDSHAKE_UNKNOWN_DB, Capabilities::PROTOCOL_41).unwrap(); assert_eq!(p.error_code, 1049); - assert_eq!(&*p.sql_state, "42000"); + assert_eq!(p.sql_state.as_deref(), Some("42000")); assert_eq!(&*p.error_message, "Unknown database \'unknown\'"); } } diff --git a/sqlx-core/src/mysql/protocol/mod.rs b/sqlx-core/src/mysql/protocol/mod.rs index aebf713d..e6e0afdc 100644 --- a/sqlx-core/src/mysql/protocol/mod.rs +++ b/sqlx-core/src/mysql/protocol/mod.rs @@ -20,12 +20,14 @@ pub use field::FieldFlags; pub use r#type::TypeId; pub use status::Status; +mod com_ping; mod com_query; mod com_set_option; mod com_stmt_execute; mod com_stmt_prepare; mod handshake; +pub use com_ping::ComPing; pub use com_query::ComQuery; pub use com_set_option::{ComSetOption, SetOption}; pub use com_stmt_execute::{ComStmtExecute, Cursor}; diff --git a/sqlx-core/src/mysql/protocol/row.rs b/sqlx-core/src/mysql/protocol/row.rs index d266f5c6..2c275a22 100644 --- a/sqlx-core/src/mysql/protocol/row.rs +++ b/sqlx-core/src/mysql/protocol/row.rs @@ -6,73 +6,84 @@ use crate::io::Buf; use crate::mysql::io::BufExt; use crate::mysql::protocol::{Decode, TypeId}; -pub struct Row { - buffer: Box<[u8]>, - values: Box<[Option>]>, +pub struct Row<'c> { + buffer: &'c [u8], + values: &'c [Option>], binary: bool, } -impl Row { +impl<'c> Row<'c> { pub fn len(&self) -> usize { self.values.len() } - pub fn get(&self, index: usize) -> Option<&[u8]> { + pub fn get(&self, index: usize) -> Option<&'c [u8]> { let range = self.values[index].as_ref()?; Some(&self.buffer[(range.start as usize)..(range.end as usize)]) } } -fn get_lenenc(buf: &[u8]) -> usize { +fn get_lenenc(buf: &[u8]) -> (usize, Option) { match buf[0] { - 0xFB => 1, + 0xFB => (1, None), 0xFC => { let len_size = 1 + 2; let len = LittleEndian::read_u16(&buf[1..]); - len_size + len as usize + (len_size, Some(len as usize)) } 0xFD => { let len_size = 1 + 3; let len = LittleEndian::read_u24(&buf[1..]); - len_size + len as usize + (len_size, Some(len as usize)) } 0xFE => { let len_size = 1 + 8; let len = LittleEndian::read_u64(&buf[1..]); - len_size + len as usize + (len_size, Some(len as usize)) } - value => 1 + value as usize, + len => (1, Some(len as usize)), } } -impl Row { - pub fn decode(mut buf: &[u8], columns: &[TypeId], binary: bool) -> crate::Result { +impl<'c> Row<'c> { + pub fn read( + mut buf: &'c [u8], + columns: &[TypeId], + values: &'c mut Vec>>, + binary: bool, + ) -> crate::Result { + let mut buffer = &*buf; + + values.clear(); + values.reserve(columns.len()); + if !binary { - let buffer: Box<[u8]> = buf.into(); - let mut values = Vec::with_capacity(columns.len()); let mut index = 0; for column_idx in 0..columns.len() { - let size = get_lenenc(&buf[index..]); + let (len_size, size) = get_lenenc(&buf[index..]); - values.push(Some(index..(index + size))); + if let Some(size) = size { + values.push(Some((index + len_size)..(index + len_size + size))); + } else { + values.push(None); + } - index += size; - buf.advance(size); + index += (len_size + size.unwrap_or_default()); } return Ok(Self { buffer, - values: values.into_boxed_slice(), - binary, + values: &*values, + binary: false, }); } @@ -88,7 +99,6 @@ impl Row { buf.advance(null_len); let buffer: Box<[u8]> = buf.into(); - let mut values = Vec::with_capacity(columns.len()); let mut index = 0; for column_idx in 0..columns.len() { @@ -117,7 +127,11 @@ impl Row { | TypeId::LONG_BLOB | TypeId::CHAR | TypeId::TEXT - | TypeId::VAR_CHAR => get_lenenc(&buffer[index..]), + | TypeId::VAR_CHAR => { + let (len_size, len) = get_lenenc(&buffer[index..]); + + len_size + len.unwrap_or_default() + } id => { unimplemented!("encountered unknown field type id: {:?}", id); @@ -130,174 +144,174 @@ impl Row { } Ok(Self { - buffer, - values: values.into_boxed_slice(), + buffer: buf, + values: &*values, binary, }) } } -#[cfg(test)] -mod test { - use super::super::column_count::ColumnCount; - use super::super::column_def::ColumnDefinition; - use super::super::eof::EofPacket; - use super::*; - - #[test] - fn null_bitmap_test() -> crate::Result<()> { - let column_len = ColumnCount::decode(&[26])?; - assert_eq!(column_len.columns, 26); - - let types: Vec = vec![ - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 2, 105, 100, 2, 105, 100, 12, 63, 0, 11, 0, 0, - 0, 3, 11, 66, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 50, 6, 102, 105, - 101, 108, 100, 50, 12, 224, 0, 120, 0, 0, 0, 253, 5, 64, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 51, 6, 102, 105, - 101, 108, 100, 51, 12, 224, 0, 252, 3, 0, 0, 253, 1, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 52, 6, 102, 105, - 101, 108, 100, 52, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 53, 6, 102, 105, - 101, 108, 100, 53, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 54, 6, 102, 105, - 101, 108, 100, 54, 12, 63, 0, 19, 0, 0, 0, 7, 128, 4, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 55, 6, 102, 105, - 101, 108, 100, 55, 12, 63, 0, 4, 0, 0, 0, 1, 1, 64, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 56, 6, 102, 105, - 101, 108, 100, 56, 12, 224, 0, 252, 255, 3, 0, 252, 16, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 57, 6, 102, 105, - 101, 108, 100, 57, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 48, 7, 102, - 105, 101, 108, 100, 49, 48, 12, 224, 0, 252, 3, 0, 0, 252, 16, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 49, 7, 102, - 105, 101, 108, 100, 49, 49, 12, 224, 0, 252, 3, 0, 0, 252, 16, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 50, 7, 102, - 105, 101, 108, 100, 49, 50, 12, 63, 0, 19, 0, 0, 0, 7, 129, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 51, 7, 102, - 105, 101, 108, 100, 49, 51, 12, 63, 0, 4, 0, 0, 0, 1, 0, 64, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 52, 7, 102, - 105, 101, 108, 100, 49, 52, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 53, 7, 102, - 105, 101, 108, 100, 49, 53, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 54, 7, 102, - 105, 101, 108, 100, 49, 54, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 55, 7, 102, - 105, 101, 108, 100, 49, 55, 12, 224, 0, 0, 1, 0, 0, 253, 0, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 56, 7, 102, - 105, 101, 108, 100, 49, 56, 12, 63, 0, 11, 0, 0, 0, 3, 1, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 57, 7, 102, - 105, 101, 108, 100, 49, 57, 12, 63, 0, 11, 0, 0, 0, 3, 1, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 48, 7, 102, - 105, 101, 108, 100, 50, 48, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 49, 7, 102, - 105, 101, 108, 100, 50, 49, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 50, 7, 102, - 105, 101, 108, 100, 50, 50, 12, 63, 0, 3, 0, 0, 0, 3, 0, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 51, 7, 102, - 105, 101, 108, 100, 50, 51, 12, 63, 0, 6, 0, 0, 0, 3, 1, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 52, 7, 102, - 105, 101, 108, 100, 50, 52, 12, 63, 0, 6, 0, 0, 0, 3, 1, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 53, 7, 102, - 105, 101, 108, 100, 50, 53, 12, 63, 0, 20, 0, 0, 0, 8, 1, 0, 0, 0, 0, - ])?, - ColumnDefinition::decode(&[ - 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, - 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 54, 7, 102, - 105, 101, 108, 100, 50, 54, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0, - ])?, - ] - .into_iter() - .map(|def| def.type_id) - .collect(); - - EofPacket::decode(&[254, 0, 0, 34, 0])?; - - Row::decode( - &[ - 0, 64, 90, 229, 0, 4, 0, 0, 0, 4, 114, 117, 115, 116, 0, 0, 7, 228, 7, 1, 16, 8, - 10, 17, 0, 0, 4, 208, 7, 1, 1, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ], - &types, - true, - )?; - - EofPacket::decode(&[254, 0, 0, 34, 0])?; - Ok(()) - } -} +// #[cfg(test)] +// mod test { +// use super::super::column_count::ColumnCount; +// use super::super::column_def::ColumnDefinition; +// use super::super::eof::EofPacket; +// use super::*; +// +// #[test] +// fn null_bitmap_test() -> crate::Result<()> { +// let column_len = ColumnCount::decode(&[26])?; +// assert_eq!(column_len.columns, 26); +// +// let types: Vec = vec![ +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 2, 105, 100, 2, 105, 100, 12, 63, 0, 11, 0, 0, +// 0, 3, 11, 66, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 50, 6, 102, 105, +// 101, 108, 100, 50, 12, 224, 0, 120, 0, 0, 0, 253, 5, 64, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 51, 6, 102, 105, +// 101, 108, 100, 51, 12, 224, 0, 252, 3, 0, 0, 253, 1, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 52, 6, 102, 105, +// 101, 108, 100, 52, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 53, 6, 102, 105, +// 101, 108, 100, 53, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 54, 6, 102, 105, +// 101, 108, 100, 54, 12, 63, 0, 19, 0, 0, 0, 7, 128, 4, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 55, 6, 102, 105, +// 101, 108, 100, 55, 12, 63, 0, 4, 0, 0, 0, 1, 1, 64, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 56, 6, 102, 105, +// 101, 108, 100, 56, 12, 224, 0, 252, 255, 3, 0, 252, 16, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 57, 6, 102, 105, +// 101, 108, 100, 57, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 48, 7, 102, +// 105, 101, 108, 100, 49, 48, 12, 224, 0, 252, 3, 0, 0, 252, 16, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 49, 7, 102, +// 105, 101, 108, 100, 49, 49, 12, 224, 0, 252, 3, 0, 0, 252, 16, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 50, 7, 102, +// 105, 101, 108, 100, 49, 50, 12, 63, 0, 19, 0, 0, 0, 7, 129, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 51, 7, 102, +// 105, 101, 108, 100, 49, 51, 12, 63, 0, 4, 0, 0, 0, 1, 0, 64, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 52, 7, 102, +// 105, 101, 108, 100, 49, 52, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 53, 7, 102, +// 105, 101, 108, 100, 49, 53, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 54, 7, 102, +// 105, 101, 108, 100, 49, 54, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 55, 7, 102, +// 105, 101, 108, 100, 49, 55, 12, 224, 0, 0, 1, 0, 0, 253, 0, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 56, 7, 102, +// 105, 101, 108, 100, 49, 56, 12, 63, 0, 11, 0, 0, 0, 3, 1, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 57, 7, 102, +// 105, 101, 108, 100, 49, 57, 12, 63, 0, 11, 0, 0, 0, 3, 1, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 48, 7, 102, +// 105, 101, 108, 100, 50, 48, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 49, 7, 102, +// 105, 101, 108, 100, 50, 49, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 50, 7, 102, +// 105, 101, 108, 100, 50, 50, 12, 63, 0, 3, 0, 0, 0, 3, 0, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 51, 7, 102, +// 105, 101, 108, 100, 50, 51, 12, 63, 0, 6, 0, 0, 0, 3, 1, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 52, 7, 102, +// 105, 101, 108, 100, 50, 52, 12, 63, 0, 6, 0, 0, 0, 3, 1, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 53, 7, 102, +// 105, 101, 108, 100, 50, 53, 12, 63, 0, 20, 0, 0, 0, 8, 1, 0, 0, 0, 0, +// ])?, +// ColumnDefinition::decode(&[ +// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, +// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 54, 7, 102, +// 105, 101, 108, 100, 50, 54, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0, +// ])?, +// ] +// .into_iter() +// .map(|def| def.type_id) +// .collect(); +// +// EofPacket::decode(&[254, 0, 0, 34, 0])?; +// +// Row::read( +// &[ +// 0, 64, 90, 229, 0, 4, 0, 0, 0, 4, 114, 117, 115, 116, 0, 0, 7, 228, 7, 1, 16, 8, +// 10, 17, 0, 0, 4, 208, 7, 1, 1, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +// 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +// ], +// &types, +// true, +// )?; +// +// EofPacket::decode(&[254, 0, 0, 34, 0])?; +// Ok(()) +// } +// } diff --git a/sqlx-core/src/mysql/row.rs b/sqlx-core/src/mysql/row.rs index 0b1ca2e3..7c4a6465 100644 --- a/sqlx-core/src/mysql/row.rs +++ b/sqlx-core/src/mysql/row.rs @@ -1,59 +1,60 @@ use std::collections::HashMap; +use std::convert::TryFrom; +use std::str::{from_utf8, Utf8Error}; use std::sync::Arc; use crate::decode::Decode; +use crate::error::UnexpectedNullError; +use crate::mysql::io::BufExt; use crate::mysql::protocol; use crate::mysql::MySql; -use crate::row::{Row, RowIndex}; +use crate::row::{ColumnIndex, Row}; use crate::types::Type; +use byteorder::LittleEndian; -pub struct MySqlRow { - pub(super) row: protocol::Row, - pub(super) columns: Arc, usize>>, +#[derive(Debug)] +pub enum MySqlValue<'c> { + Binary(&'c [u8]), + Text(&'c [u8]), } -impl Row for MySqlRow { +impl<'c> TryFrom>> for MySqlValue<'c> { + type Error = crate::Error; + + #[inline] + fn try_from(value: Option>) -> Result { + match value { + Some(value) => Ok(value), + None => Err(crate::Error::decode(UnexpectedNullError)), + } + } +} + +pub struct MySqlRow<'c> { + pub(super) row: protocol::Row<'c>, + pub(super) columns: Arc, u16>>, + pub(super) binary: bool, +} + +impl<'c> Row<'c> for MySqlRow<'c> { type Database = MySql; fn len(&self) -> usize { self.row.len() } - fn get(&self, index: I) -> T + fn get_raw<'r, I>(&'r self, index: I) -> crate::Result>> where - Self::Database: Type, - I: RowIndex, - T: Decode, + I: ColumnIndex, { - index.get(self).unwrap() + let index = index.resolve(self)?; + + Ok(self.row.get(index).map(|mut buf| { + if self.binary { + MySqlValue::Binary(buf) + } else { + MySqlValue::Text(buf) + } + })) } } - -impl RowIndex for usize { - fn get(&self, row: &MySqlRow) -> crate::Result - where - ::Database: Type, - T: Decode<::Database>, - { - Ok(Decode::decode_nullable(row.row.get(*self))?) - } -} - -impl RowIndex for &'_ str { - fn get(&self, row: &MySqlRow) -> crate::Result - where - ::Database: Type, - T: Decode<::Database>, - { - let index = row - .columns - .get(*self) - .ok_or_else(|| crate::Error::ColumnNotFound((*self).into()))?; - - let value = Decode::decode_nullable(row.row.get(*index))?; - - Ok(value) - } -} - -impl_from_row_for_row!(MySqlRow); diff --git a/sqlx-core/src/mysql/rsa.rs b/sqlx-core/src/mysql/rsa.rs index 715a499b..a4e242a0 100644 --- a/sqlx-core/src/mysql/rsa.rs +++ b/sqlx-core/src/mysql/rsa.rs @@ -6,7 +6,7 @@ use rand::{thread_rng, Rng}; // For the love of crypto, please delete as much of this as possible and use the RSA crate // directly when that PR is merged -pub fn encrypt(key: &[u8], message: &[u8]) -> crate::Result> { +pub fn encrypt(key: &[u8], message: &[u8]) -> crate::Result> { let key = std::str::from_utf8(key).map_err(|_err| { // TODO(@abonander): protocol_err doesn't like referring to [err] protocol_err!("unexpected error decoding what should be UTF-8") @@ -14,7 +14,7 @@ pub fn encrypt(key: &[u8], message: &[u8]) -> crate::Result let key = parse(key)?; - Ok(oaep_encrypt::<_, D>(&mut thread_rng(), &key, message)?.into_boxed_slice()) + Ok(oaep_encrypt::<_, D>(&mut thread_rng(), &key, message)?) } // https://github.com/RustCrypto/RSA/blob/9f1464c43831d422d9903574aad6ab072db9f2b0/src/internals.rs#L12 diff --git a/sqlx-core/src/mysql/stream.rs b/sqlx-core/src/mysql/stream.rs new file mode 100644 index 00000000..243b32e3 --- /dev/null +++ b/sqlx-core/src/mysql/stream.rs @@ -0,0 +1,193 @@ +use std::net::Shutdown; + +use byteorder::{ByteOrder, LittleEndian}; + +use crate::io::{Buf, BufMut, BufStream, MaybeTlsStream}; +use crate::mysql::protocol::{Capabilities, Decode, Encode, EofPacket, ErrPacket, OkPacket}; +use crate::mysql::MySqlError; +use crate::url::Url; + +// Size before a packet is split +const MAX_PACKET_SIZE: u32 = 1024; + +pub(crate) struct MySqlStream { + pub(super) stream: BufStream, + + // Active capabilities + pub(super) capabilities: Capabilities, + + // Packets in a command sequence have an incrementing sequence number + // This number must be 0 at the start of each command + pub(super) seq_no: u8, + + // Packets are buffered into a second buffer from the stream + // as we may have compressed or split packets to figure out before + // decoding + packet_buf: Vec, + packet_len: usize, +} + +impl MySqlStream { + pub(super) async fn new(url: &Url) -> crate::Result { + let stream = MaybeTlsStream::connect(&url, 5432).await?; + + let mut capabilities = Capabilities::PROTOCOL_41 + | Capabilities::IGNORE_SPACE + | Capabilities::DEPRECATE_EOF + | Capabilities::FOUND_ROWS + | Capabilities::TRANSACTIONS + | Capabilities::SECURE_CONNECTION + | Capabilities::PLUGIN_AUTH_LENENC_DATA + | Capabilities::MULTI_STATEMENTS + | Capabilities::MULTI_RESULTS + | Capabilities::PLUGIN_AUTH; + + if url.database().is_some() { + capabilities |= Capabilities::CONNECT_WITH_DB; + } + + if cfg!(feature = "tls") { + capabilities |= Capabilities::SSL; + } + + Ok(Self { + capabilities, + stream: BufStream::new(stream), + packet_buf: Vec::with_capacity(MAX_PACKET_SIZE as usize), + packet_len: 0, + seq_no: 0, + }) + } + + pub(super) fn is_tls(&self) -> bool { + self.stream.is_tls() + } + + pub(super) fn shutdown(&self) -> crate::Result<()> { + Ok(self.stream.shutdown(Shutdown::Both)?) + } + + #[inline] + pub(super) async fn send(&mut self, packet: T, initial: bool) -> crate::Result<()> + where + T: Encode + std::fmt::Debug, + { + if initial { + self.seq_no = 0; + } + + self.write(packet); + self.flush().await + } + + #[inline] + pub(super) async fn flush(&mut self) -> crate::Result<()> { + Ok(self.stream.flush().await?) + } + + /// Write the packet to the buffered stream ( do not send to the server ) + pub(super) fn write(&mut self, packet: T) + where + T: Encode, + { + let buf = self.stream.buffer_mut(); + + // Allocate room for the header that we write after the packet; + // so, we can get an accurate and cheap measure of packet length + + let header_offset = buf.len(); + buf.advance(4); + + packet.encode(buf, self.capabilities); + + // Determine length of encoded packet + // and write to allocated header + + let len = buf.len() - header_offset - 4; + let mut header = &mut buf[header_offset..]; + + LittleEndian::write_u32(&mut header, len as u32); + + // Take the last sequence number received, if any, and increment by 1 + // If there was no sequence number, we only increment if we split packets + header[3] = self.seq_no; + self.seq_no = self.seq_no.wrapping_add(1); + } + + #[inline] + pub(super) async fn receive(&mut self) -> crate::Result<&[u8]> { + self.read().await?; + + Ok(self.packet()) + } + + pub(super) async fn read(&mut self) -> crate::Result<()> { + self.packet_buf.clear(); + self.packet_len = 0; + + // Read the packet header which contains the length and the sequence number + // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html + // https://mariadb.com/kb/en/library/0-packet/#standard-packet + let mut header = self.stream.peek(4_usize).await?; + + self.packet_len = header.get_uint::(3)? as usize; + self.seq_no = header.get_u8()?.wrapping_add(1); + + self.stream.consume(4); + + // Read the packet body and copy it into our internal buf + // We must have a separate buffer around the stream as we can't operate directly + // on bytes returned from the stream. We have various kinds of payload manipulation + // that must be handled before decoding. + let payload = self.stream.peek(self.packet_len).await?; + + self.packet_buf.reserve(payload.len()); + self.packet_buf.extend_from_slice(payload); + + self.stream.consume(self.packet_len); + + // TODO: Implement packet compression + // TODO: Implement packet joining + + Ok(()) + } + + /// Returns a reference to the most recently received packet data. + /// A call to `read` invalidates this buffer. + #[inline] + pub(super) fn packet(&self) -> &[u8] { + &self.packet_buf[..self.packet_len] + } +} + +impl MySqlStream { + pub(crate) async fn maybe_receive_eof(&mut self) -> crate::Result<()> { + if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) { + let _eof = EofPacket::decode(self.receive().await?)?; + } + + Ok(()) + } + + pub(crate) fn maybe_handle_eof(&mut self) -> crate::Result { + if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) { + let _eof = EofPacket::decode(self.packet())?; + + Ok(true) + } else { + Ok(false) + } + } + + pub(crate) fn handle_unexpected(&mut self) -> crate::Result { + Err(protocol_err!("unexpected packet identifier 0x{:X?}", self.packet()[0]).into()) + } + + pub(crate) fn handle_err(&mut self) -> crate::Result { + Err(MySqlError(ErrPacket::decode(self.packet(), self.capabilities)?).into()) + } + + pub(crate) fn handle_ok(&mut self) -> crate::Result { + OkPacket::decode(self.packet()) + } +} diff --git a/sqlx-core/src/mysql/tls.rs b/sqlx-core/src/mysql/tls.rs new file mode 100644 index 00000000..6af05019 --- /dev/null +++ b/sqlx-core/src/mysql/tls.rs @@ -0,0 +1,115 @@ +use std::borrow::Cow; +use std::str::FromStr; + +use crate::mysql::protocol::{Capabilities, SslRequest}; +use crate::mysql::stream::MySqlStream; +use crate::url::Url; + +pub(super) async fn upgrade_if_needed(stream: &mut MySqlStream, url: &Url) -> crate::Result<()> { + let ca_file = url.param("ssl-ca"); + + let ssl_mode = url.param("ssl-mode"); + + let supports_tls = stream.capabilities.contains(Capabilities::SSL); + + // https://dev.mysql.com/doc/refman/5.7/en/connection-options.html#option_general_ssl-mode + match ssl_mode.as_deref() { + Some("DISABLED") => {} + + #[cfg(feature = "tls")] + Some("PREFERRED") | None if !supports_tls => {} + + #[cfg(feature = "tls")] + Some("PREFERRED") => { + if let Err(error) = try_upgrade(stream, &url, None, true).await { + // TLS upgrade failed; fall back to a normal connection + } + } + + #[cfg(feature = "tls")] + Some(mode @ "REQUIRED") | Some(mode @ "VERIFY_CA") | Some(mode @ "VERIFY_IDENTITY") + if !supports_tls => + { + return Err(tls_err!("server does not support TLS").into()); + } + + #[cfg(feature = "tls")] + Some(mode @ "VERIFY_CA") | Some(mode @ "VERIFY_IDENTITY") if ca_file.is_none() => { + return Err( + tls_err!("`ssl-mode` of {:?} requires `ssl-ca` to be set", ssl_mode).into(), + ); + } + + #[cfg(feature = "tls")] + Some(mode @ "REQUIRED") | Some(mode @ "VERIFY_CA") | Some(mode @ "VERIFY_IDENTITY") => { + try_upgrade( + stream, + url, + // false for both verify-ca and verify-full + ca_file.as_deref(), + // false for only verify-full + mode != "VERIFY_IDENTITY", + ) + .await?; + } + + #[cfg(not(feature = "tls"))] + None => { + // The user neither explicitly enabled TLS in the connection string + // nor did they turn the `tls` feature on + } + + #[cfg(not(feature = "tls"))] + Some(mode @ "PREFERRED") + | Some(mode @ "REQUIRED") + | Some(mode @ "VERIFY_CA") + | Some(mode @ "VERIFY_IDENTITY") => { + return Err(tls_err!( + "ssl-mode {:?} unsupported; SQLx was compiled without `tls` feature", + mode + ) + .into()); + } + + Some(mode) => { + return Err(tls_err!("unknown `ssl-mode` value: {:?}", mode).into()); + } + } + + Ok(()) +} + +#[cfg(feature = "tls")] +async fn try_upgrade( + stream: &mut MySqlStream, + url: &Url, + ca_file: Option<&str>, + accept_invalid_hostnames: bool, +) -> crate::Result<()> { + use crate::runtime::fs; + + use async_native_tls::{Certificate, TlsConnector}; + + let mut connector = TlsConnector::new() + .danger_accept_invalid_certs(ca_file.is_none()) + .danger_accept_invalid_hostnames(accept_invalid_hostnames); + + if let Some(ca_file) = ca_file { + let root_cert = fs::read(ca_file).await?; + + connector = connector.add_root_certificate(Certificate::from_pem(&root_cert)?); + } + + // send upgrade request and then immediately try TLS handshake + stream + .send( + SslRequest { + client_collation: COLLATE_UTF8MB4_UNICODE_CI, + max_packet_size: MAX_PACKET_SIZE, + }, + false, + ) + .await?; + + stream.stream.upgrade(url, connector).await +} diff --git a/sqlx-core/src/mysql/types/bool.rs b/sqlx-core/src/mysql/types/bool.rs index 182a371a..c37bbcae 100644 --- a/sqlx-core/src/mysql/types/bool.rs +++ b/sqlx-core/src/mysql/types/bool.rs @@ -1,11 +1,14 @@ -use crate::decode::{Decode, DecodeError}; +use std::convert::TryInto; + +use crate::decode::Decode; use crate::encode::Encode; +use crate::error::UnexpectedNullError; use crate::mysql::protocol::TypeId; use crate::mysql::types::MySqlTypeInfo; -use crate::mysql::MySql; +use crate::mysql::{MySql, MySqlValue}; use crate::types::Type; -impl Type for MySql { +impl Type for bool { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::TINY_INT) } @@ -17,13 +20,18 @@ impl Encode for bool { } } -impl Decode for bool { - fn decode(buf: &[u8]) -> Result { - match buf.len() { - 0 => Err(DecodeError::Message(Box::new( - "Expected minimum 1 byte but received none.", - ))), - _ => Ok(buf[0] != 0), +impl<'de> Decode<'de, MySql> for bool { + fn decode(value: Option>) -> crate::Result { + match value.try_into()? { + MySqlValue::Binary(buf) => Ok(buf.get(0).map(|&b| b != 0).unwrap_or_default()), + + MySqlValue::Text(b"0") => Ok(false), + + MySqlValue::Text(b"1") => Ok(true), + + MySqlValue::Text(s) => Err(crate::Error::Decode( + format!("unexpected value {:?} for boolean", s).into(), + )), } } } diff --git a/sqlx-core/src/mysql/types/bytes.rs b/sqlx-core/src/mysql/types/bytes.rs index ec4429d9..d31991f5 100644 --- a/sqlx-core/src/mysql/types/bytes.rs +++ b/sqlx-core/src/mysql/types/bytes.rs @@ -1,14 +1,16 @@ use byteorder::LittleEndian; -use crate::decode::{Decode, DecodeError}; +use crate::decode::Decode; use crate::encode::Encode; +use crate::error::UnexpectedNullError; use crate::mysql::io::{BufExt, BufMutExt}; use crate::mysql::protocol::TypeId; use crate::mysql::types::MySqlTypeInfo; -use crate::mysql::MySql; +use crate::mysql::{MySql, MySqlValue}; use crate::types::Type; +use std::convert::TryInto; -impl Type<[u8]> for MySql { +impl Type for [u8] { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo { id: TypeId::TEXT, @@ -19,9 +21,9 @@ impl Type<[u8]> for MySql { } } -impl Type> for MySql { +impl Type for Vec { fn type_info() -> MySqlTypeInfo { - >::type_info() + <[u8] as Type>::type_info() } } @@ -37,11 +39,36 @@ impl Encode for Vec { } } -impl Decode for Vec { - fn decode(mut buf: &[u8]) -> Result { - Ok(buf - .get_bytes_lenenc::()? - .unwrap_or_default() - .to_vec()) +impl<'de> Decode<'de, MySql> for Vec { + fn decode(value: Option>) -> crate::Result { + match value.try_into()? { + MySqlValue::Binary(mut buf) => { + let len = buf + .get_uint_lenenc::() + .map_err(crate::Error::decode)? + .unwrap_or_default(); + + Ok((&buf[..(len as usize)]).to_vec()) + } + + MySqlValue::Text(s) => Ok(s.to_vec()), + } + } +} + +impl<'de> Decode<'de, MySql> for &'de [u8] { + fn decode(value: Option>) -> crate::Result { + match value.try_into()? { + MySqlValue::Binary(mut buf) => { + let len = buf + .get_uint_lenenc::() + .map_err(crate::Error::decode)? + .unwrap_or_default(); + + Ok(&buf[..(len as usize)]) + } + + MySqlValue::Text(s) => Ok(s), + } } } diff --git a/sqlx-core/src/mysql/types/chrono.rs b/sqlx-core/src/mysql/types/chrono.rs index b9ef5446..9c380d0d 100644 --- a/sqlx-core/src/mysql/types/chrono.rs +++ b/sqlx-core/src/mysql/types/chrono.rs @@ -1,17 +1,19 @@ -use std::convert::TryFrom; +use std::convert::{TryFrom, TryInto}; use byteorder::{ByteOrder, LittleEndian}; use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc}; -use crate::decode::{Decode, DecodeError}; +use crate::decode::Decode; use crate::encode::Encode; use crate::io::{Buf, BufMut}; use crate::mysql::protocol::TypeId; use crate::mysql::types::MySqlTypeInfo; -use crate::mysql::MySql; +use crate::mysql::{MySql, MySqlValue}; use crate::types::Type; +use crate::Error; +use bitflags::_core::str::from_utf8; -impl Type> for MySql { +impl Type for DateTime { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::TIMESTAMP) } @@ -23,15 +25,15 @@ impl Encode for DateTime { } } -impl Decode for DateTime { - fn decode(buf: &[u8]) -> Result { - let naive: NaiveDateTime = Decode::::decode(buf)?; +impl<'de> Decode<'de, MySql> for DateTime { + fn decode(value: Option>) -> crate::Result { + let naive: NaiveDateTime = Decode::::decode(value)?; Ok(DateTime::from_utc(naive, Utc)) } } -impl Type for MySql { +impl Type for NaiveTime { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::TIME) } @@ -63,24 +65,33 @@ impl Encode for NaiveTime { } } -impl Decode for NaiveTime { - fn decode(mut buf: &[u8]) -> Result { - // data length, expecting 8 or 12 (fractional seconds) - let len = buf.get_u8()?; +impl<'de> Decode<'de, MySql> for NaiveTime { + fn decode(buf: Option>) -> crate::Result { + match buf.try_into()? { + MySqlValue::Binary(mut buf) => { + // data length, expecting 8 or 12 (fractional seconds) + let len = buf.get_u8()?; - // is negative : int<1> - let is_negative = buf.get_u8()?; - assert_eq!(is_negative, 0, "Negative dates/times are not supported"); + // is negative : int<1> + let is_negative = buf.get_u8()?; + assert_eq!(is_negative, 0, "Negative dates/times are not supported"); - // "date on 4 bytes little-endian format" (?) - // https://mariadb.com/kb/en/resultset-row/#timestamp-binary-encoding - buf.advance(4); + // "date on 4 bytes little-endian format" (?) + // https://mariadb.com/kb/en/resultset-row/#timestamp-binary-encoding + buf.advance(4); - decode_time(len - 5, buf) + decode_time(len - 5, buf) + } + + MySqlValue::Text(buf) => { + let s = from_utf8(buf).map_err(Error::decode)?; + NaiveTime::parse_from_str(s, "%H:%M:%S%.f").map_err(Error::decode) + } + } } } -impl Type for MySql { +impl Type for NaiveDate { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::DATE) } @@ -98,13 +109,20 @@ impl Encode for NaiveDate { } } -impl Decode for NaiveDate { - fn decode(buf: &[u8]) -> Result { - Ok(decode_date(&buf[1..])) +impl<'de> Decode<'de, MySql> for NaiveDate { + fn decode(buf: Option>) -> crate::Result { + match buf.try_into()? { + MySqlValue::Binary(buf) => Ok(decode_date(&buf[1..])), + + MySqlValue::Text(buf) => { + let s = from_utf8(buf).map_err(Error::decode)?; + NaiveDate::parse_from_str(s, "%Y-%m-%d").map_err(Error::decode) + } + } } } -impl Type for MySql { +impl Type for NaiveDateTime { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::DATETIME) } @@ -144,18 +162,27 @@ impl Encode for NaiveDateTime { } } -impl Decode for NaiveDateTime { - fn decode(buf: &[u8]) -> Result { - let len = buf[0]; - let date = decode_date(&buf[1..]); +impl<'de> Decode<'de, MySql> for NaiveDateTime { + fn decode(buf: Option>) -> crate::Result { + match buf.try_into()? { + MySqlValue::Binary(buf) => { + let len = buf[0]; + let date = decode_date(&buf[1..]); - let dt = if len > 4 { - date.and_time(decode_time(len - 4, &buf[5..])?) - } else { - date.and_hms(0, 0, 0) - }; + let dt = if len > 4 { + date.and_time(decode_time(len - 4, &buf[5..])?) + } else { + date.and_hms(0, 0, 0) + }; - Ok(dt) + Ok(dt) + }, + + MySqlValue::Text(buf) => { + let s = from_utf8(buf).map_err(Error::decode)?; + NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f").map_err(Error::decode) + } + } } } @@ -187,7 +214,7 @@ fn encode_time(time: &NaiveTime, include_micros: bool, buf: &mut Vec) { } } -fn decode_time(len: u8, mut buf: &[u8]) -> Result { +fn decode_time(len: u8, mut buf: &[u8]) -> crate::Result { let hour = buf.get_u8()?; let minute = buf.get_u8()?; let seconds = buf.get_u8()?; diff --git a/sqlx-core/src/mysql/types/float.rs b/sqlx-core/src/mysql/types/float.rs index 62250ff7..f37f5077 100644 --- a/sqlx-core/src/mysql/types/float.rs +++ b/sqlx-core/src/mysql/types/float.rs @@ -1,9 +1,16 @@ -use crate::decode::{Decode, DecodeError}; +use std::convert::TryInto; + +use byteorder::{LittleEndian, ReadBytesExt}; + +use crate::decode::Decode; use crate::encode::Encode; +use crate::error::UnexpectedNullError; use crate::mysql::protocol::TypeId; use crate::mysql::types::MySqlTypeInfo; -use crate::mysql::MySql; +use crate::mysql::{MySql, MySqlValue}; use crate::types::Type; +use crate::Error; +use std::str::from_utf8; /// The equivalent MySQL type for `f32` is `FLOAT`. /// @@ -18,7 +25,7 @@ use crate::types::Type; /// // (This is expected behavior for floating points and happens both in Rust and in MySQL) /// assert_ne!(10.2f32 as f64, 10.2f64); /// ``` -impl Type for MySql { +impl Type for f32 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::FLOAT) } @@ -30,9 +37,19 @@ impl Encode for f32 { } } -impl Decode for f32 { - fn decode(buf: &[u8]) -> Result { - Ok(f32::from_bits(>::decode(buf)? as u32)) +impl<'de> Decode<'de, MySql> for f32 { + fn decode(value: Option>) -> crate::Result { + match value.try_into()? { + MySqlValue::Binary(mut buf) => buf + .read_i32::() + .map_err(crate::Error::decode) + .map(|value| f32::from_bits(value as u32)), + + MySqlValue::Text(s) => from_utf8(s) + .map_err(Error::decode)? + .parse() + .map_err(Error::decode), + } } } @@ -40,7 +57,7 @@ impl Decode for f32 { /// /// Note that `DOUBLE` is a floating-point type and cannot represent some fractional values /// exactly. -impl Type for MySql { +impl Type for f64 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::DOUBLE) } @@ -52,8 +69,18 @@ impl Encode for f64 { } } -impl Decode for f64 { - fn decode(buf: &[u8]) -> Result { - Ok(f64::from_bits(>::decode(buf)? as u64)) +impl<'de> Decode<'de, MySql> for f64 { + fn decode(value: Option>) -> crate::Result { + match value.try_into()? { + MySqlValue::Binary(mut buf) => buf + .read_i64::() + .map_err(crate::Error::decode) + .map(|value| f64::from_bits(value as u64)), + + MySqlValue::Text(s) => from_utf8(s) + .map_err(Error::decode)? + .parse() + .map_err(Error::decode), + } } } diff --git a/sqlx-core/src/mysql/types/int.rs b/sqlx-core/src/mysql/types/int.rs index a5d63146..cc257fa5 100644 --- a/sqlx-core/src/mysql/types/int.rs +++ b/sqlx-core/src/mysql/types/int.rs @@ -1,14 +1,18 @@ -use byteorder::LittleEndian; +use std::convert::TryInto; +use std::str::from_utf8; -use crate::decode::{Decode, DecodeError}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +use crate::decode::Decode; use crate::encode::Encode; -use crate::io::{Buf, BufMut}; +use crate::error::UnexpectedNullError; use crate::mysql::protocol::TypeId; use crate::mysql::types::MySqlTypeInfo; -use crate::mysql::MySql; +use crate::mysql::{MySql, MySqlValue}; use crate::types::Type; +use crate::Error; -impl Type for MySql { +impl Type for i8 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::TINY_INT) } @@ -16,17 +20,24 @@ impl Type for MySql { impl Encode for i8 { fn encode(&self, buf: &mut Vec) { - buf.push(*self as u8); + buf.write_i8(*self); } } -impl Decode for i8 { - fn decode(buf: &[u8]) -> Result { - Ok(buf[0] as i8) +impl<'de> Decode<'de, MySql> for i8 { + fn decode(value: Option>) -> crate::Result { + match value.try_into()? { + MySqlValue::Binary(mut buf) => buf.read_i8().map_err(Into::into), + + MySqlValue::Text(s) => from_utf8(s) + .map_err(Error::decode)? + .parse() + .map_err(Error::decode), + } } } -impl Type for MySql { +impl Type for i16 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::SMALL_INT) } @@ -34,17 +45,24 @@ impl Type for MySql { impl Encode for i16 { fn encode(&self, buf: &mut Vec) { - buf.put_i16::(*self); + buf.write_i16::(*self); } } -impl Decode for i16 { - fn decode(mut buf: &[u8]) -> Result { - buf.get_i16::().map_err(Into::into) +impl<'de> Decode<'de, MySql> for i16 { + fn decode(value: Option>) -> crate::Result { + match value.try_into()? { + MySqlValue::Binary(mut buf) => buf.read_i16::().map_err(Into::into), + + MySqlValue::Text(s) => from_utf8(s) + .map_err(Error::decode)? + .parse() + .map_err(Error::decode), + } } } -impl Type for MySql { +impl Type for i32 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::INT) } @@ -52,17 +70,24 @@ impl Type for MySql { impl Encode for i32 { fn encode(&self, buf: &mut Vec) { - buf.put_i32::(*self); + buf.write_i32::(*self); } } -impl Decode for i32 { - fn decode(mut buf: &[u8]) -> Result { - buf.get_i32::().map_err(Into::into) +impl<'de> Decode<'de, MySql> for i32 { + fn decode(value: Option>) -> crate::Result { + match value.try_into()? { + MySqlValue::Binary(mut buf) => buf.read_i32::().map_err(Into::into), + + MySqlValue::Text(s) => from_utf8(s) + .map_err(Error::decode)? + .parse() + .map_err(Error::decode), + } } } -impl Type for MySql { +impl Type for i64 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::new(TypeId::BIG_INT) } @@ -70,14 +95,19 @@ impl Type for MySql { impl Encode for i64 { fn encode(&self, buf: &mut Vec) { - buf.put_u64::(*self as u64); + buf.write_i64::(*self); } } -impl Decode for i64 { - fn decode(mut buf: &[u8]) -> Result { - buf.get_u64::() - .map_err(Into::into) - .map(|val| val as i64) +impl<'de> Decode<'de, MySql> for i64 { + fn decode(value: Option>) -> crate::Result { + match value.try_into()? { + MySqlValue::Binary(mut buf) => buf.read_i64::().map_err(Into::into), + + MySqlValue::Text(s) => from_utf8(s) + .map_err(Error::decode)? + .parse() + .map_err(Error::decode), + } } } diff --git a/sqlx-core/src/mysql/types/mod.rs b/sqlx-core/src/mysql/types/mod.rs index 475370ca..3c1311b0 100644 --- a/sqlx-core/src/mysql/types/mod.rs +++ b/sqlx-core/src/mysql/types/mod.rs @@ -10,8 +10,10 @@ mod chrono; use std::fmt::{self, Debug, Display}; +use crate::decode::Decode; use crate::mysql::protocol::TypeId; use crate::mysql::protocol::{ColumnDefinition, FieldFlags}; +use crate::mysql::{MySql, MySqlValue}; use crate::types::TypeInfo; #[derive(Clone, Debug, Default)] @@ -103,3 +105,14 @@ impl TypeInfo for MySqlTypeInfo { } } } + +impl<'de, T> Decode<'de, MySql> for Option +where + T: Decode<'de, MySql>, +{ + fn decode(value: Option>) -> crate::Result { + value + .map(|value| >::decode(Some(value))) + .transpose() + } +} diff --git a/sqlx-core/src/mysql/types/str.rs b/sqlx-core/src/mysql/types/str.rs index 46087bf9..7a4c98af 100644 --- a/sqlx-core/src/mysql/types/str.rs +++ b/sqlx-core/src/mysql/types/str.rs @@ -2,15 +2,18 @@ use std::str; use byteorder::LittleEndian; -use crate::decode::{Decode, DecodeError}; +use crate::decode::Decode; use crate::encode::Encode; +use crate::error::UnexpectedNullError; use crate::mysql::io::{BufExt, BufMutExt}; use crate::mysql::protocol::TypeId; use crate::mysql::types::MySqlTypeInfo; -use crate::mysql::MySql; +use crate::mysql::{MySql, MySqlValue}; use crate::types::Type; +use std::convert::TryInto; +use std::str::from_utf8; -impl Type for MySql { +impl Type for str { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo { id: TypeId::TEXT, @@ -27,10 +30,9 @@ impl Encode for str { } } -// TODO: Do we need the [HasSqlType] for String -impl Type for MySql { +impl Type for String { fn type_info() -> MySqlTypeInfo { - >::type_info() + >::type_info() } } @@ -40,11 +42,25 @@ impl Encode for String { } } -impl Decode for String { - fn decode(mut buf: &[u8]) -> Result { - Ok(buf - .get_str_lenenc::()? - .unwrap_or_default() - .to_owned()) +impl<'de> Decode<'de, MySql> for &'de str { + fn decode(value: Option>) -> crate::Result { + match value.try_into()? { + MySqlValue::Binary(mut buf) => { + let len = buf + .get_uint_lenenc::() + .map_err(crate::Error::decode)? + .unwrap_or_default(); + + from_utf8(&buf[..(len as usize)]).map_err(crate::Error::decode) + } + + MySqlValue::Text(s) => from_utf8(s).map_err(crate::Error::decode), + } + } +} + +impl<'de> Decode<'de, MySql> for String { + fn decode(buf: Option>) -> crate::Result { + <&'de str>::decode(buf).map(ToOwned::to_owned) } } diff --git a/sqlx-core/src/mysql/types/uint.rs b/sqlx-core/src/mysql/types/uint.rs index c6db5a72..d5114809 100644 --- a/sqlx-core/src/mysql/types/uint.rs +++ b/sqlx-core/src/mysql/types/uint.rs @@ -1,14 +1,18 @@ -use byteorder::LittleEndian; +use std::convert::TryInto; +use std::str::from_utf8; -use crate::decode::{Decode, DecodeError}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +use crate::decode::Decode; use crate::encode::Encode; -use crate::io::{Buf, BufMut}; +use crate::error::UnexpectedNullError; use crate::mysql::protocol::TypeId; use crate::mysql::types::MySqlTypeInfo; -use crate::mysql::MySql; +use crate::mysql::{MySql, MySqlValue}; use crate::types::Type; +use crate::Error; -impl Type for MySql { +impl Type for u8 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::unsigned(TypeId::TINY_INT) } @@ -16,17 +20,24 @@ impl Type for MySql { impl Encode for u8 { fn encode(&self, buf: &mut Vec) { - buf.push(*self); + buf.write_u8(*self); } } -impl Decode for u8 { - fn decode(buf: &[u8]) -> Result { - Ok(buf[0]) +impl<'de> Decode<'de, MySql> for u8 { + fn decode(value: Option>) -> crate::Result { + match value.try_into()? { + MySqlValue::Binary(mut buf) => buf.read_u8().map_err(Into::into), + + MySqlValue::Text(s) => from_utf8(s) + .map_err(Error::decode)? + .parse() + .map_err(Error::decode), + } } } -impl Type for MySql { +impl Type for u16 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::unsigned(TypeId::SMALL_INT) } @@ -34,17 +45,24 @@ impl Type for MySql { impl Encode for u16 { fn encode(&self, buf: &mut Vec) { - buf.put_u16::(*self); + buf.write_u16::(*self); } } -impl Decode for u16 { - fn decode(mut buf: &[u8]) -> Result { - buf.get_u16::().map_err(Into::into) +impl<'de> Decode<'de, MySql> for u16 { + fn decode(value: Option>) -> crate::Result { + match value.try_into()? { + MySqlValue::Binary(mut buf) => buf.read_u16::().map_err(Into::into), + + MySqlValue::Text(s) => from_utf8(s) + .map_err(Error::decode)? + .parse() + .map_err(Error::decode), + } } } -impl Type for MySql { +impl Type for u32 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::unsigned(TypeId::INT) } @@ -52,17 +70,24 @@ impl Type for MySql { impl Encode for u32 { fn encode(&self, buf: &mut Vec) { - buf.put_u32::(*self); + buf.write_u32::(*self); } } -impl Decode for u32 { - fn decode(mut buf: &[u8]) -> Result { - buf.get_u32::().map_err(Into::into) +impl<'de> Decode<'de, MySql> for u32 { + fn decode(value: Option>) -> crate::Result { + match value.try_into()? { + MySqlValue::Binary(mut buf) => buf.read_u32::().map_err(Into::into), + + MySqlValue::Text(s) => from_utf8(s) + .map_err(Error::decode)? + .parse() + .map_err(Error::decode), + } } } -impl Type for MySql { +impl Type for u64 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::unsigned(TypeId::BIG_INT) } @@ -70,12 +95,19 @@ impl Type for MySql { impl Encode for u64 { fn encode(&self, buf: &mut Vec) { - buf.put_u64::(*self); + buf.write_u64::(*self); } } -impl Decode for u64 { - fn decode(mut buf: &[u8]) -> Result { - buf.get_u64::().map_err(Into::into) +impl<'de> Decode<'de, MySql> for u64 { + fn decode(value: Option>) -> crate::Result { + match value.try_into()? { + MySqlValue::Binary(mut buf) => buf.read_u64::().map_err(Into::into), + + MySqlValue::Text(s) => from_utf8(s) + .map_err(Error::decode)? + .parse() + .map_err(Error::decode), + } } } diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index 7ae97c9e..3d7d6a73 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -28,3 +28,4 @@ make_query_as!(PgQueryAs, Postgres, PgRow); impl_map_row_for_row!(Postgres, PgRow); impl_column_index_for_row!(Postgres); impl_from_row_for_tuples!(Postgres, PgRow); +impl_execute_for_query!(Postgres); diff --git a/sqlx-core/src/row.rs b/sqlx-core/src/row.rs index dfb88579..26d1eca5 100644 --- a/sqlx-core/src/row.rs +++ b/sqlx-core/src/row.rs @@ -190,7 +190,7 @@ macro_rules! impl_column_index_for_row { row.columns .get(self) .ok_or_else(|| crate::Error::ColumnNotFound((*self).into())) - .map(|&index| index) + .map(|&index| index as usize) } } }; diff --git a/sqlx-macros/src/derives.rs b/sqlx-macros/src/derives.rs index 87bc50bb..1d4a7172 100644 --- a/sqlx-macros/src/derives.rs +++ b/sqlx-macros/src/derives.rs @@ -78,6 +78,25 @@ pub(crate) fn expand_derive_decode(input: DeriveInput) -> syn::Result)); + + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + impls.push(quote!( + impl #impl_generics sqlx::decode::Decode<'de, sqlx::MySql> for #ident #ty_generics #where_clause { + fn decode(value: >::RawValue) -> sqlx::Result { + <#ty as sqlx::decode::Decode<'de, sqlx::MySql>>::decode(value).map(Self) + } + } + )); + } + // panic!("{}", q) Ok(quote!(#(#impls)*)) } diff --git a/sqlx-test/src/lib.rs b/sqlx-test/src/lib.rs index 9bc3763d..09e3937d 100644 --- a/sqlx-test/src/lib.rs +++ b/sqlx-test/src/lib.rs @@ -65,14 +65,15 @@ macro_rules! test_prepared_type { let mut conn = sqlx_test::new::<$db>().await?; $( - let query = format!("SELECT {} = $1, $1 as _1", $text); + let query = format!($crate::[< $db _query_for_test_prepared_type >]!(), $text); let rec: (bool, $ty) = sqlx::query_as(&query) + .bind($value) .bind($value) .fetch_one(&mut conn) .await?; - assert!(rec.0); + assert!(rec.0, "value returned from server: {:?}", rec.1); assert!($value == rec.1); )+ @@ -81,3 +82,17 @@ macro_rules! test_prepared_type { } } } + +#[macro_export] +macro_rules! MySql_query_for_test_prepared_type { + () => { + "SELECT {} <=> ?, ? as _1" + }; +} + +#[macro_export] +macro_rules! Postgres_query_for_test_prepared_type { + () => { + "SELECT {} is not distinct form $1, $2 as _1" + }; +} diff --git a/src/lib.rs b/src/lib.rs index 1f78550f..81135020 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -70,4 +70,7 @@ pub mod prelude { #[cfg(feature = "postgres")] pub use super::postgres::PgQueryAs; + + #[cfg(feature = "mysql")] + pub use super::mysql::MySqlQueryAs; } diff --git a/tests/derives.rs b/tests/derives.rs index 7f581e1e..facdbea9 100644 --- a/tests/derives.rs +++ b/tests/derives.rs @@ -58,3 +58,17 @@ where let decoded = Foo::decode(Some(sqlx::postgres::PgValue::Binary(&encoded))).unwrap(); assert_eq!(example, decoded); } + +#[cfg(feature = "mysql")] +fn decode_with_db() +where + Foo: for<'de> Decode<'de, sqlx::MySql> + Encode, +{ + let example = Foo(0x1122_3344); + + let mut encoded = Vec::new(); + Encode::::encode(&example, &mut encoded); + + let decoded = Foo::decode(Some(sqlx::mysql::MySqlValue::Binary(&encoded))).unwrap(); + assert_eq!(example, decoded); +} diff --git a/tests/mysql-macros.rs b/tests/mysql-macros.rs index 28b19743..6921a242 100644 --- a/tests/mysql-macros.rs +++ b/tests/mysql-macros.rs @@ -1,9 +1,10 @@ -use sqlx::MySqlConnection; +use sqlx::MySql; +use sqlx_test::new; #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn macro_select_from_cte() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; let account = sqlx::query!("select * from (select (1) as id, 'Herp Derpinson' as name) accounts") .fetch_one(&mut conn) @@ -18,7 +19,7 @@ async fn macro_select_from_cte() -> anyhow::Result<()> { #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn macro_select_from_cte_bind() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; let account = sqlx::query!( "select * from (select (1) as id, 'Herp Derpinson' as name) accounts where id = ?", 1i32 @@ -41,7 +42,7 @@ struct RawAccount { #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn test_query_as_raw() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; let account = sqlx::query_as!( RawAccount, @@ -57,11 +58,3 @@ async fn test_query_as_raw() -> anyhow::Result<()> { Ok(()) } - -fn url() -> anyhow::Result { - Ok(dotenv::var("DATABASE_URL")?) -} - -async fn connect() -> anyhow::Result { - Ok(MySqlConnection::open(url()?).await?) -} diff --git a/tests/mysql-raw.rs b/tests/mysql-raw.rs new file mode 100644 index 00000000..1138714f --- /dev/null +++ b/tests/mysql-raw.rs @@ -0,0 +1,56 @@ +//! Tests for the raw (unprepared) query API for MySql. + +use sqlx::{Cursor, Executor, MySql, Row}; +use sqlx_test::new; + +/// Test a simple select expression. This should return the row. +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn test_select_expression() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let mut cursor = conn.fetch("SELECT 5"); + let row = cursor.next().await?.unwrap(); + + assert!(5i32 == row.get::(0)?); + + Ok(()) +} + +/// Test that we can interleave reads and writes to the database +/// in one simple query. Using the `Cursor` API we should be +/// able to fetch from both queries in sequence. +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn test_multi_read_write() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let mut cursor = conn.fetch( + " +CREATE TEMPORARY TABLE messages ( + id BIGINT PRIMARY KEY AUTO_INCREMENT, + text TEXT NOT NULL +); + +SELECT 'Hello World' as _1; + +INSERT INTO messages (text) VALUES ('this is a test'); + +SELECT id, text FROM messages; + ", + ); + + let row = cursor.next().await?.unwrap(); + + assert!("Hello World" == row.get::<&str, _>("_1")?); + + let row = cursor.next().await?.unwrap(); + + let id: i64 = row.get("id")?; + let text: &str = row.get("text")?; + + assert_eq!(1_i64, id); + assert_eq!("this is a test", text); + + Ok(()) +} diff --git a/tests/mysql-types-chrono.rs b/tests/mysql-types-chrono.rs deleted file mode 100644 index 72a148a2..00000000 --- a/tests/mysql-types-chrono.rs +++ /dev/null @@ -1,87 +0,0 @@ -use sqlx::types::chrono::{DateTime, NaiveDate, NaiveTime, Utc}; -use sqlx::{mysql::MySqlConnection, Connection, Row}; - -async fn connect() -> anyhow::Result { - Ok(MySqlConnection::open(dotenv::var("DATABASE_URL")?).await?) -} - -#[cfg_attr(feature = "runtime-async-std", async_std::test)] -#[cfg_attr(feature = "runtime-tokio", tokio::test)] -async fn mysql_chrono_date() -> anyhow::Result<()> { - let mut conn = connect().await?; - - let value = NaiveDate::from_ymd(2019, 1, 2); - - let row = sqlx::query!( - "SELECT (DATE '2019-01-02' = ?) as _1, CAST(? AS DATE) as _2", - value, - value - ) - .fetch_one(&mut conn) - .await?; - - assert!(row._1 != 0); - assert_eq!(value, row._2); - - Ok(()) -} - -#[cfg_attr(feature = "runtime-async-std", async_std::test)] -#[cfg_attr(feature = "runtime-tokio", tokio::test)] -async fn mysql_chrono_date_time() -> anyhow::Result<()> { - let mut conn = connect().await?; - - let value = NaiveDate::from_ymd(2019, 1, 2).and_hms(5, 10, 20); - - let row = sqlx::query("SELECT '2019-01-02 05:10:20' = ?, ?") - .bind(&value) - .bind(&value) - .fetch_one(&mut conn) - .await?; - - assert!(row.get::(0)); - assert_eq!(value, row.get(1)); - - Ok(()) -} - -#[cfg_attr(feature = "runtime-async-std", async_std::test)] -#[cfg_attr(feature = "runtime-tokio", tokio::test)] -async fn mysql_chrono_time() -> anyhow::Result<()> { - let mut conn = connect().await?; - - let value = NaiveTime::from_hms_micro(5, 10, 20, 115100); - - let row = sqlx::query("SELECT TIME '05:10:20.115100' = ?, TIME '05:10:20.115100'") - .bind(&value) - .fetch_one(&mut conn) - .await?; - - assert!(row.get::(0)); - assert_eq!(value, row.get(1)); - - Ok(()) -} - -#[cfg_attr(feature = "runtime-async-std", async_std::test)] -#[cfg_attr(feature = "runtime-tokio", tokio::test)] -async fn mysql_chrono_timestamp() -> anyhow::Result<()> { - let mut conn = connect().await?; - - let value = DateTime::::from_utc( - NaiveDate::from_ymd(2019, 1, 2).and_hms_micro(5, 10, 20, 115100), - Utc, - ); - - let row = sqlx::query( - "SELECT TIMESTAMP '2019-01-02 05:10:20.115100' = ?, TIMESTAMP '2019-01-02 05:10:20.115100'", - ) - .bind(&value) - .fetch_one(&mut conn) - .await?; - - assert!(row.get::(0)); - assert_eq!(value, row.get(1)); - - Ok(()) -} diff --git a/tests/mysql-types.rs b/tests/mysql-types.rs index 01d9192a..89cc935c 100644 --- a/tests/mysql-types.rs +++ b/tests/mysql-types.rs @@ -1,94 +1,86 @@ -use sqlx::{mysql::MySqlConnection, Connection, Row}; +use sqlx::MySql; +use sqlx_test::test_type; -async fn connect() -> anyhow::Result { - Ok(MySqlConnection::open(dotenv::var("DATABASE_URL")?).await?) -} - -macro_rules! test { - ($name:ident: $ty:ty: $($text:literal == $value:expr),+) => { - #[cfg_attr(feature = "runtime-async-std", async_std::test)] - #[cfg_attr(feature = "runtime-tokio", tokio::test)] - async fn $name () -> anyhow::Result<()> { - let mut conn = connect().await?; - - $( - let row = sqlx::query(&format!("SELECT {} = ?, ? as _1", $text)) - .bind($value) - .bind($value) - .fetch_one(&mut conn) - .await?; - - let value = row.get::<$ty, _>("_1"); - - assert_eq!(row.get::(0), 1, "value returned from server: {:?}", value); - - assert_eq!($value, value); - )+ - - Ok(()) - } - } -} - -test!(mysql_bool: bool: "false" == false, "true" == true); - -test!(mysql_tiny_unsigned: u8: "253" == 253_u8); -test!(mysql_tiny: i8: "5" == 5_i8); - -test!(mysql_medium_unsigned: u16: "21415" == 21415_u16); -test!(mysql_short: i16: "21415" == 21415_i16); - -test!(mysql_long_unsigned: u32: "2141512" == 2141512_u32); -test!(mysql_long: i32: "2141512" == 2141512_i32); - -test!(mysql_longlong_unsigned: u64: "2141512" == 2141512_u64); -test!(mysql_longlong: i64: "2141512" == 2141512_i64); - -// `DOUBLE` can be compared with decimal literals just fine but the same can't be said for `FLOAT` -test!(mysql_double: f64: "3.14159265" == 3.14159265f64); - -test!(mysql_string: String: "'helloworld'" == "helloworld"); - -#[cfg_attr(feature = "runtime-async-std", async_std::test)] -#[cfg_attr(feature = "runtime-tokio", tokio::test)] -async fn mysql_bytes() -> anyhow::Result<()> { - let mut conn = connect().await?; - - let value = &b"Hello, World"[..]; - - let rec = sqlx::query!( - "SELECT (X'48656c6c6f2c20576f726c64' = ?) as _1, CAST(? as BINARY) as _2", - value, - value - ) - .fetch_one(&mut conn) - .await?; - - assert!(rec._1 != 0); - - let output: Vec = rec._2; - - assert_eq!(&value[..], &*output); - - Ok(()) -} - -#[cfg_attr(feature = "runtime-async-std", async_std::test)] -#[cfg_attr(feature = "runtime-tokio", tokio::test)] -async fn mysql_float() -> anyhow::Result<()> { - let mut conn = connect().await?; - - let value = 10.2f32; - let row = sqlx::query("SELECT ? as _1") - .bind(value) - .fetch_one(&mut conn) - .await?; - - // comparison between FLOAT and literal doesn't work as expected - // we get implicit widening to DOUBLE which gives a slightly different value - // however, round-trip does work as expected - let ret = row.get::("_1"); - assert_eq!(value, ret); - - Ok(()) +test_type!(null( + MySql, + Option, + "NULL" == None:: +)); + +test_type!(bool(MySql, bool, "false" == false, "true" == true)); + +test_type!(u8(MySql, u8, "253" == 253_u8)); +test_type!(i8(MySql, i8, "5" == 5_i8, "0" == 0_i8)); + +test_type!(u16(MySql, u16, "21415" == 21415_u16)); +test_type!(i16(MySql, i16, "21415" == 21415_i16)); + +test_type!(u32(MySql, u32, "2141512" == 2141512_u32)); +test_type!(i32(MySql, i32, "2141512" == 2141512_i32)); + +test_type!(u64(MySql, u64, "2141512" == 2141512_u64)); +test_type!(i64(MySql, i64, "2141512" == 2141512_i64)); + +test_type!(double(MySql, f64, "3.14159265" == 3.14159265f64)); + +// NOTE: This behavior can be very surprising. MySQL implicitly widens FLOAT bind parameters +// to DOUBLE. This results in the weirdness you see below. MySQL generally recommends to stay +// away from FLOATs. +test_type!(float( + MySql, + f32, + "3.1410000324249268" == 3.141f32 as f64 as f32 +)); + +test_type!(string( + MySql, + String, + "'helloworld'" == "helloworld", + "''" == "" +)); + +test_type!(bytes( + MySql, + Vec, + "X'DEADBEEF'" + == vec![0xDE_u8, 0xAD, 0xBE, 0xEF], + "X''" + == Vec::::new(), + "X'0000000052'" + == vec![0_u8, 0, 0, 0, 0x52] +)); + +#[cfg(feature = "chrono")] +mod chrono { + use super::*; + use sqlx::types::chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; + + test_type!(chrono_date( + MySql, + NaiveDate, + "DATE '2001-01-05'" == NaiveDate::from_ymd(2001, 1, 5), + "DATE '2050-11-23'" == NaiveDate::from_ymd(2050, 11, 23) + )); + + test_type!(chrono_time( + MySql, + NaiveTime, + "TIME '05:10:20.115100'" == NaiveTime::from_hms_micro(5, 10, 20, 115100) + )); + + test_type!(chrono_date_time( + MySql, + NaiveDateTime, + "'2019-01-02 05:10:20'" == NaiveDate::from_ymd(2019, 1, 2).and_hms(5, 10, 20) + )); + + test_type!(chrono_date_time_tz( + MySql, + DateTime::, + "TIMESTAMP '2019-01-02 05:10:20.115100'" + == DateTime::::from_utc( + NaiveDate::from_ymd(2019, 1, 2).and_hms_micro(5, 10, 20, 115100), + Utc, + ) + )); } diff --git a/tests/mysql.rs b/tests/mysql.rs index 5eb53bec..42e1e636 100644 --- a/tests/mysql.rs +++ b/tests/mysql.rs @@ -1,17 +1,24 @@ use futures::TryStreamExt; -use sqlx::{Connection as _, Executor as _, MySqlConnection, MySqlPool, Row as _}; +use sqlx::{mysql::MySqlQueryAs, Connection, Executor, MySql, MySqlPool}; +use sqlx_test::new; use std::time::Duration; #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn it_connects() -> anyhow::Result<()> { - let mut conn = connect().await?; + Ok(new::().await?.ping().await?) +} - let row = sqlx::query("select 1 + 1").fetch_one(&mut conn).await?; +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn it_drops_results_in_affected_rows() -> anyhow::Result<()> { + let mut conn = new::().await?; - assert_eq!(2, row.get(0)); + // ~1800 rows should be iterated and dropped + let affected = conn.execute("select * from mysql.time_zone").await?; - conn.close().await?; + // In MySQL, rows being returned isn't enough to flag it as an _affected_ row + assert_eq!(0, affected); Ok(()) } @@ -19,10 +26,10 @@ async fn it_connects() -> anyhow::Result<()> { #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn it_executes() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; let _ = conn - .send( + .execute( r#" CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY) "#, @@ -38,12 +45,9 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY) assert_eq!(cnt, 1); } - let sum: i32 = sqlx::query("SELECT id FROM users") + let sum: i32 = sqlx::query_as("SELECT id FROM users") .fetch(&mut conn) - .try_fold( - 0_i32, - |acc, x| async move { Ok(acc + x.get::("id")) }, - ) + .try_fold(0_i32, |acc, (x,): (i32,)| async move { Ok(acc + x) }) .await?; assert_eq!(sum, 55); @@ -54,11 +58,9 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY) #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn it_selects_null() -> anyhow::Result<()> { - let mut conn = connect().await?; + let mut conn = new::().await?; - let row = sqlx::query("SELECT NULL").fetch_one(&mut conn).await?; - - let val: Option = row.get(0); + let (val,): (Option,) = sqlx::query_as("SELECT NULL").fetch_one(&mut conn).await?; assert!(val.is_none()); @@ -68,12 +70,10 @@ async fn it_selects_null() -> anyhow::Result<()> { #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn test_describe() -> anyhow::Result<()> { - use sqlx::describe::Nullability::*; - - let mut conn = connect().await?; + let mut conn = new::().await?; let _ = conn - .send( + .execute( r#" CREATE TEMPORARY TABLE describe_test ( id int primary key auto_increment, @@ -88,13 +88,13 @@ async fn test_describe() -> anyhow::Result<()> { .describe("select nt.*, false from describe_test nt") .await?; - assert_eq!(describe.result_columns[0].nullability, NonNull); + assert_eq!(describe.result_columns[0].non_null, Some(true)); assert_eq!(describe.result_columns[0].type_info.type_name(), "INT"); - assert_eq!(describe.result_columns[1].nullability, NonNull); + assert_eq!(describe.result_columns[1].non_null, Some(true)); assert_eq!(describe.result_columns[1].type_info.type_name(), "TEXT"); - assert_eq!(describe.result_columns[2].nullability, Nullable); + assert_eq!(describe.result_columns[2].non_null, Some(false)); assert_eq!(describe.result_columns[2].type_info.type_name(), "TEXT"); - assert_eq!(describe.result_columns[3].nullability, NonNull); + assert_eq!(describe.result_columns[3].non_null, Some(true)); let bool_ty_name = describe.result_columns[3].type_info.type_name(); @@ -112,7 +112,7 @@ async fn test_describe() -> anyhow::Result<()> { #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn pool_immediately_fails_with_db_error() -> anyhow::Result<()> { // Malform the database url by changing the password - let url = url()?.replace("password", "not-the-password"); + let url = dotenv::var("DATABASE_URL")?.replace("password", "not-the-password"); let pool = MySqlPool::new(&url).await?; @@ -152,7 +152,7 @@ async fn pool_smoke_test() -> anyhow::Result<()> { let pool = pool.clone(); spawn(async move { loop { - if let Err(e) = sqlx::query("select 1 + 1").fetch_one(&mut &pool).await { + if let Err(e) = sqlx::query("select 1 + 1").execute(&mut &pool).await { eprintln!("pool task {} dying due to {}", i, e); break; } @@ -185,11 +185,3 @@ async fn pool_smoke_test() -> anyhow::Result<()> { Ok(()) } - -fn url() -> anyhow::Result { - Ok(dotenv::var("DATABASE_URL")?) -} - -async fn connect() -> anyhow::Result { - Ok(MySqlConnection::open(url()?).await?) -}