Add zero-allocation to MySQL query execution

WIP mysql compiles with types and executor commented out
This commit is contained in:
Ryan Leckey 2020-03-11 01:40:57 -07:00
parent de14a206ff
commit c9df8acc41
39 changed files with 1694 additions and 1366 deletions

View File

@ -74,6 +74,10 @@ required-features = [ "mysql", "macros" ]
name = "mysql"
required-features = [ "mysql" ]
[[test]]
name = "mysql-raw"
required-features = [ "mysql" ]
[[test]]
name = "postgres"
required-features = [ "postgres" ]

View File

@ -7,6 +7,8 @@ pub trait Buf {
fn get_uint<T: ByteOrder>(&mut self, n: usize) -> io::Result<u64>;
fn get_i8(&mut self) -> io::Result<i8>;
fn get_u8(&mut self) -> io::Result<u8>;
fn get_u16<T: ByteOrder>(&mut self) -> io::Result<u16>;
@ -17,6 +19,8 @@ pub trait Buf {
fn get_i32<T: ByteOrder>(&mut self) -> io::Result<i32>;
fn get_i64<T: ByteOrder>(&mut self) -> io::Result<i64>;
fn get_u32<T: ByteOrder>(&mut self) -> io::Result<u32>;
fn get_u64<T: ByteOrder>(&mut self) -> io::Result<u64>;
@ -40,6 +44,13 @@ impl<'a> Buf for &'a [u8] {
Ok(val)
}
fn get_i8(&mut self) -> io::Result<i8> {
let val = self[0];
self.advance(1);
Ok(val as i8)
}
fn get_u8(&mut self) -> io::Result<u8> {
let val = self[0];
self.advance(1);
@ -75,6 +86,13 @@ impl<'a> Buf for &'a [u8] {
Ok(val)
}
fn get_i64<T: ByteOrder>(&mut self) -> io::Result<i64> {
let val = T::read_i64(*self);
self.advance(4);
Ok(val)
}
fn get_u32<T: ByteOrder>(&mut self) -> io::Result<u32> {
let val = T::read_u32(*self);
self.advance(4);

View File

@ -1,7 +1,9 @@
//! Core of SQLx, the rust SQL toolkit. Not intended to be used directly.
#![forbid(unsafe_code)]
#![recursion_limit = "512"]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![allow(unused)]
#[macro_use]
pub mod error;

View File

@ -27,10 +27,10 @@ impl Arguments for MySqlArguments {
fn add<T>(&mut self, value: T)
where
Self::Database: Type<T>,
T: Type<Self::Database>,
T: Encode<Self::Database>,
{
let type_id = <MySql as Type<T>>::type_info();
let type_id = <T as Type<MySql>>::type_info();
let index = self.param_types.len();
self.param_types.push(type_id);

View File

@ -1,22 +1,25 @@
use std::collections::HashMap;
use std::convert::TryInto;
use std::io;
use std::sync::Arc;
use byteorder::{ByteOrder, LittleEndian};
use futures_core::future::BoxFuture;
use sha1::Sha1;
use std::net::Shutdown;
use crate::cache::StatementCache;
use crate::connection::{Connect, Connection};
use crate::io::{Buf, BufMut, BufStream, MaybeTlsStream};
use crate::mysql::error::MySqlError;
use crate::mysql::protocol::{
AuthPlugin, AuthSwitch, Capabilities, Decode, Encode, EofPacket, ErrPacket, Handshake,
AuthPlugin, AuthSwitch, Capabilities, ComPing, Decode, Encode, EofPacket, ErrPacket, Handshake,
HandshakeResponse, OkPacket, SslRequest,
};
use crate::mysql::rsa;
use crate::mysql::stream::MySqlStream;
use crate::mysql::util::xor_eq;
use crate::mysql::{rsa, tls};
use crate::url::Url;
use std::ops::Range;
// Size before a packet is split
const MAX_PACKET_SIZE: u32 = 1024;
@ -85,521 +88,206 @@ const COLLATE_UTF8MB4_UNICODE_CI: u8 = 224;
/// against the hostname in the server certificate, so they must be the same for the TLS
/// upgrade to succeed. `ssl-ca` must still be specified.
pub struct MySqlConnection {
pub(super) stream: BufStream<MaybeTlsStream>,
pub(super) stream: MySqlStream,
pub(super) is_ready: bool,
pub(super) cache_statement: HashMap<Box<str>, u32>,
// Active capabilities of the client _&_ the server
pub(super) capabilities: Capabilities,
// Cache of prepared statements
// Query (String) to StatementId to ColumnMap
pub(super) statement_cache: StatementCache<u32>,
// Packets are buffered into a second buffer from the stream
// as we may have compressed or split packets to figure out before
// decoding
pub(super) packet: Vec<u8>,
packet_len: usize,
// Packets in a command sequence have an incrementing sequence number
// This number must be 0 at the start of each command
pub(super) next_seq_no: u8,
// Work buffer for the value ranges of the current row
// This is used as the backing memory for each Row's value indexes
pub(super) current_row_values: Vec<Option<Range<usize>>>,
}
impl MySqlConnection {
/// Write the packet to the stream ( do not send to the server )
pub(crate) fn write(&mut self, packet: impl Encode) {
let buf = self.stream.buffer_mut();
fn to_asciz(s: &str) -> Vec<u8> {
let mut z = String::with_capacity(s.len() + 1);
z.push_str(s);
z.push('\0');
// Allocate room for the header that we write after the packet;
// so, we can get an accurate and cheap measure of packet length
z.into_bytes()
}
let header_offset = buf.len();
buf.advance(4);
async fn rsa_encrypt_with_nonce(
stream: &mut MySqlStream,
public_key_request_id: u8,
password: &str,
nonce: &[u8],
) -> crate::Result<Vec<u8>> {
// https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/
packet.encode(buf, self.capabilities);
// Determine length of encoded packet
// and write to allocated header
let len = buf.len() - header_offset - 4;
let mut header = &mut buf[header_offset..];
LittleEndian::write_u32(&mut header, len as u32); // len
// Take the last sequence number received, if any, and increment by 1
// If there was no sequence number, we only increment if we split packets
header[3] = self.next_seq_no;
self.next_seq_no = self.next_seq_no.wrapping_add(1);
if stream.is_tls() {
// If in a TLS stream, send the password directly in clear text
return Ok(to_asciz(password));
}
/// Send the packet to the database server
pub(crate) async fn send(&mut self, packet: impl Encode) -> crate::Result<()> {
self.write(packet);
self.stream.flush().await?;
// client sends a public key request
stream.send(&[public_key_request_id][..], false).await?;
Ok(())
}
// server sends a public key response
let packet = stream.receive().await?;
let rsa_pub_key = &packet[1..];
/// Send a [HandshakeResponse] packet to the database server
pub(crate) async fn send_handshake_response(
&mut self,
url: &Url,
auth_plugin: &AuthPlugin,
auth_response: &[u8],
) -> crate::Result<()> {
self.send(HandshakeResponse {
client_collation: COLLATE_UTF8MB4_UNICODE_CI,
max_packet_size: MAX_PACKET_SIZE,
username: url.username().unwrap_or("root"),
database: url.database(),
auth_plugin,
auth_response,
})
.await
}
// xor the password with the given nonce
let mut pass = to_asciz(password);
xor_eq(&mut pass, nonce);
/// Try to receive a packet from the database server. Returns `None` if the server has sent
/// no data.
pub(crate) async fn try_receive(&mut self) -> crate::Result<Option<()>> {
self.packet.clear();
// client sends an RSA encrypted password
rsa::encrypt::<Sha1>(rsa_pub_key, &pass)
}
// Read the packet header which contains the length and the sequence number
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html
// https://mariadb.com/kb/en/library/0-packet/#standard-packet
let mut header = ret_if_none!(self.stream.peek(4).await?);
self.packet_len = header.get_uint::<LittleEndian>(3)? as usize;
self.next_seq_no = header.get_u8()?.wrapping_add(1);
self.stream.consume(4);
// Read the packet body and copy it into our internal buf
// We must have a separate buffer around the stream as we can't operate directly
// on bytes returned from the stream. We have various kinds of payload manipulation
// that must be handled before decoding.
let payload = ret_if_none!(self.stream.peek(self.packet_len).await?);
self.packet.extend_from_slice(payload);
self.stream.consume(self.packet_len);
// TODO: Implement packet compression
// TODO: Implement packet joining
Ok(Some(()))
}
/// Receive a complete packet from the database server.
pub(crate) async fn receive(&mut self) -> crate::Result<&mut Self> {
self.try_receive()
.await?
.ok_or(io::ErrorKind::ConnectionAborted)?;
Ok(self)
}
/// Returns a reference to the most recently received packet data
#[inline]
pub(crate) fn packet(&self) -> &[u8] {
&self.packet[..self.packet_len]
}
/// Receive an [EofPacket] if we are supposed to receive them at all.
pub(crate) async fn receive_eof(&mut self) -> crate::Result<()> {
// When (legacy) EOFs are enabled, many things are terminated by an EOF packet
if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
let _eof = EofPacket::decode(self.receive().await?.packet())?;
async fn make_auth_response(
stream: &mut MySqlStream,
plugin: &AuthPlugin,
password: &str,
nonce: &[u8],
) -> crate::Result<Vec<u8>> {
match plugin {
AuthPlugin::CachingSha2Password | AuthPlugin::MySqlNativePassword => {
Ok(plugin.scramble(password, nonce))
}
Ok(())
}
/// Receive a [Handshake] packet. When connecting to the database server, this is immediately
/// received from the database server.
pub(crate) async fn receive_handshake(&mut self, url: &Url) -> crate::Result<Handshake> {
let handshake = Handshake::decode(self.receive().await?.packet())?;
let mut client_capabilities = Capabilities::PROTOCOL_41
| Capabilities::IGNORE_SPACE
| Capabilities::FOUND_ROWS
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::PLUGIN_AUTH;
if url.database().is_some() {
client_capabilities |= Capabilities::CONNECT_WITH_DB;
}
if cfg!(feature = "tls") {
client_capabilities |= Capabilities::SSL;
}
self.capabilities =
(client_capabilities & handshake.server_capabilities) | Capabilities::PROTOCOL_41;
Ok(handshake)
}
/// Receives an [OkPacket] from the database server. This is called at the end of
/// authentication to confirm the established connection.
pub(crate) fn receive_auth_ok<'a>(
&'a mut self,
plugin: &'a AuthPlugin,
password: &'a str,
nonce: &'a [u8],
) -> BoxFuture<'a, crate::Result<()>> {
Box::pin(async move {
self.receive().await?;
match self.packet[0] {
0x00 => self.handle_ok().map(drop),
0xfe => self.handle_auth_switch(password).await,
0xff => self.handle_err(),
_ => self.handle_auth_continue(plugin, password, nonce).await,
}
})
}
pub(crate) fn handle_ok(&mut self) -> crate::Result<OkPacket> {
let ok = OkPacket::decode(self.packet())?;
// An OK signifies the end of the current command sequence
self.next_seq_no = 0;
Ok(ok)
}
pub(crate) fn handle_err<T>(&mut self) -> crate::Result<T> {
let err = ErrPacket::decode(self.packet())?;
// An ERR signifies the end of the current command sequence
self.next_seq_no = 0;
Err(MySqlError(err).into())
}
pub(crate) fn handle_unexpected_packet<T>(&self, id: u8) -> crate::Result<T> {
Err(protocol_err!("unexpected packet identifier 0x{:X?}", id).into())
}
pub(crate) async fn handle_auth_continue(
&mut self,
plugin: &AuthPlugin,
password: &str,
nonce: &[u8],
) -> crate::Result<()> {
match plugin {
AuthPlugin::CachingSha2Password => {
if self.packet[0] == 1 {
match self.packet[1] {
// AUTH_OK
0x03 => {}
// AUTH_CONTINUE
0x04 => {
// client sends an RSA encrypted password
let ct = self.rsa_encrypt(0x02, password, nonce).await?;
self.send(&*ct).await?;
}
auth => {
return Err(protocol_err!("unexpected result from 'fast' authentication 0x{:x} when expecting OK (0x03) or CONTINUE (0x04)", auth).into());
}
}
// ends with server sending either OK_Packet or ERR_Packet
self.receive_auth_ok(plugin, password, nonce)
.await
.map(drop)
} else {
return self.handle_unexpected_packet(self.packet[0]);
}
}
// No other supported auth methods will be called through continue
_ => unreachable!(),
}
}
pub(crate) async fn handle_auth_switch(&mut self, password: &str) -> crate::Result<()> {
let auth = AuthSwitch::decode(self.packet())?;
let auth_response = self
.make_auth_initial_response(&auth.auth_plugin, password, &auth.auth_plugin_data)
.await?;
self.send(&*auth_response).await?;
self.receive_auth_ok(&auth.auth_plugin, password, &auth.auth_plugin_data)
.await
}
pub(crate) async fn make_auth_initial_response(
&mut self,
plugin: &AuthPlugin,
password: &str,
nonce: &[u8],
) -> crate::Result<Vec<u8>> {
match plugin {
AuthPlugin::CachingSha2Password | AuthPlugin::MySqlNativePassword => {
Ok(plugin.scramble(password, nonce))
}
AuthPlugin::Sha256Password => {
// Full RSA exchange and password encrypt up front with no "cache"
Ok(self.rsa_encrypt(0x01, password, nonce).await?.into_vec())
}
}
}
pub(crate) async fn rsa_encrypt(
&mut self,
public_key_request_id: u8,
password: &str,
nonce: &[u8],
) -> crate::Result<Box<[u8]>> {
// https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/
if self.stream.is_tls() {
// If in a TLS stream, send the password directly in clear text
let mut clear_text = String::with_capacity(password.len() + 1);
clear_text.push_str(password);
clear_text.push('\0');
return Ok(clear_text.into_bytes().into_boxed_slice());
}
// client sends a public key request
self.send(&[public_key_request_id][..]).await?;
// server sends a public key response
let packet = self.receive().await?.packet();
let rsa_pub_key = &packet[1..];
// The password string data must be NUL terminated
// Note: This is not in the documentation that I could find
let mut pass = password.as_bytes().to_vec();
pass.push(0);
xor_eq(&mut pass, nonce);
// client sends an RSA encrypted password
rsa::encrypt::<Sha1>(rsa_pub_key, &pass)
AuthPlugin::Sha256Password => rsa_encrypt_with_nonce(stream, 0x01, password, nonce).await,
}
}
impl MySqlConnection {
async fn new(url: &Url) -> crate::Result<Self> {
let stream = MaybeTlsStream::connect(url, 3306).await?;
async fn establish(stream: &mut MySqlStream, url: &Url) -> crate::Result<()> {
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase.html
// https://mariadb.com/kb/en/connection/
let mut capabilities = Capabilities::empty();
// Read a [Handshake] packet. When connecting to the database server, this is immediately
// received from the database server.
if cfg!(feature = "tls") {
capabilities |= Capabilities::SSL;
}
let handshake = Handshake::decode(stream.receive().await?)?;
let mut auth_plugin = handshake.auth_plugin;
let mut auth_plugin_data = handshake.auth_plugin_data;
Ok(Self {
stream: BufStream::new(stream),
capabilities,
packet: Vec::with_capacity(8192),
packet_len: 0,
next_seq_no: 0,
statement_cache: StatementCache::new(),
})
}
stream.capabilities &= handshake.server_capabilities;
stream.capabilities |= Capabilities::PROTOCOL_41;
async fn initialize(&mut self) -> crate::Result<()> {
// On connect, we want to establish a modern, Rust-compatible baseline so we
// tweak connection options to enable UTC for TIMESTAMP, UTF-8 for character types, etc.
// Depending on the ssl-mode and capabilities we should upgrade
// our connection to TLS
// TODO: Use batch support when we have it to handle the following in one execution
tls::upgrade_if_needed(stream, url).await?;
// https://mariadb.com/kb/en/sql-mode/
// Send a [HandshakeResponse] packet. This is returned in response to the [Handshake] packet
// that is immediately received.
// PIPES_AS_CONCAT - Allows using the pipe character (ASCII 124) as string concatenation operator.
// This means that "A" || "B" can be used in place of CONCAT("A", "B").
let password = &*url.password().unwrap_or_default();
let auth_response =
make_auth_response(stream, &auth_plugin, password, &auth_plugin_data).await?;
// NO_ENGINE_SUBSTITUTION - If not set, if the available storage engine specified by a CREATE TABLE is
// not available, a warning is given and the default storage
// engine is used instead.
// NO_ZERO_DATE - Don't allow '0000-00-00'. This is invalid in Rust.
// NO_ZERO_IN_DATE - Don't allow 'YYYY-00-00'. This is invalid in Rust.
// language=MySQL
self.execute_raw("SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE'))")
.await?;
// This allows us to assume that the output from a TIMESTAMP field is UTC
// language=MySQL
self.execute_raw("SET time_zone = '+00:00'").await?;
// https://mathiasbynens.be/notes/mysql-utf8mb4
// language=MySQL
self.execute_raw("SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci")
.await?;
Ok(())
}
#[cfg(feature = "tls")]
async fn try_ssl(
&mut self,
url: &Url,
ca_file: Option<&str>,
invalid_hostnames: bool,
) -> crate::Result<()> {
use crate::runtime::fs;
use async_native_tls::{Certificate, TlsConnector};
let mut connector = TlsConnector::new()
.danger_accept_invalid_certs(ca_file.is_none())
.danger_accept_invalid_hostnames(invalid_hostnames);
if let Some(ca_file) = ca_file {
let root_cert = fs::read(ca_file).await?;
connector = connector.add_root_certificate(Certificate::from_pem(&root_cert)?);
}
// send upgrade request and then immediately try TLS handshake
self.send(SslRequest {
client_collation: COLLATE_UTF8MB4_UNICODE_CI,
max_packet_size: MAX_PACKET_SIZE,
})
stream
.send(
HandshakeResponse {
client_collation: COLLATE_UTF8MB4_UNICODE_CI,
max_packet_size: MAX_PACKET_SIZE,
username: url.username().unwrap_or("root"),
database: url.database(),
auth_plugin: &auth_plugin,
auth_response: &auth_response,
},
false,
)
.await?;
self.stream.stream.upgrade(url, connector).await
loop {
// After sending the handshake response with our assumed auth method the server
// will send OK, fail, or tell us to change auth methods
let capabilities = stream.capabilities;
let packet = stream.receive().await?;
match packet[0] {
// OK
0x00 => {
break;
}
// ERROR
0xFF => {
return stream.handle_err();
}
// AUTH_SWITCH
0xFE => {
let auth = AuthSwitch::decode(packet)?;
auth_plugin = auth.auth_plugin;
auth_plugin_data = auth.auth_plugin_data;
let auth_response =
make_auth_response(stream, &auth_plugin, password, &auth_plugin_data).await?;
stream.send(&*auth_response, false).await?;
}
0x01 if auth_plugin == AuthPlugin::CachingSha2Password => {
match packet[1] {
// AUTH_OK
0x03 => {}
// AUTH_CONTINUE
0x04 => {
// The specific password is _not_ cached on the server
// We need to send a normal RSA-encrypted password for this
let enc = rsa_encrypt_with_nonce(stream, 0x02, password, &auth_plugin_data)
.await?;
stream.send(&*enc, false).await?;
}
unk => {
return Err(protocol_err!("unexpected result from 'fast' authentication 0x{:x} when expecting OK (0x03) or CONTINUE (0x04)", unk).into());
}
}
}
unk => {
return stream.handle_unexpected();
}
}
}
Ok(())
}
async fn close(mut stream: MySqlStream) -> crate::Result<()> {
// TODO: Actually tell MySQL that we're closing
stream.flush().await?;
stream.shutdown()?;
Ok(())
}
async fn ping(stream: &mut MySqlStream) -> crate::Result<()> {
stream.send(ComPing, true).await?;
match stream.receive().await?[0] {
0x00 | 0xFE => Ok(()),
0xFF => stream.handle_err(),
_ => stream.handle_unexpected(),
}
}
impl MySqlConnection {
pub(super) async fn establish(url: crate::Result<Url>) -> crate::Result<Self> {
pub(super) async fn new(url: crate::Result<Url>) -> crate::Result<Self> {
let url = url?;
let mut self_ = Self::new(&url).await?;
let mut stream = MySqlStream::new(&url).await?;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase.html
// https://mariadb.com/kb/en/connection/
establish(&mut stream, &url).await?;
// On connect, server immediately sends the handshake
let mut handshake = self_.receive_handshake(&url).await?;
let ca_file = url.param("ssl-ca");
let ssl_mode = url.param("ssl-mode").unwrap_or(
if ca_file.is_some() {
"VERIFY_CA"
} else {
"PREFERRED"
}
.into(),
);
let supports_ssl = handshake.server_capabilities.contains(Capabilities::SSL);
match &*ssl_mode {
"DISABLED" => (),
// don't try upgrade
#[cfg(feature = "tls")]
"PREFERRED" if !supports_ssl => {
log::warn!("server does not support TLS; using unencrypted connection")
}
// try to upgrade
#[cfg(feature = "tls")]
"PREFERRED" => {
if let Err(e) = self_.try_ssl(&url, None, true).await {
log::warn!("TLS handshake failed, falling back to insecure: {}", e);
// fallback, redo connection
self_ = Self::new(&url).await?;
handshake = self_.receive_handshake(&url).await?;
}
}
#[cfg(not(feature = "tls"))]
"PREFERRED" => log::info!("compiled without TLS, skipping upgrade"),
#[cfg(feature = "tls")]
"REQUIRED" if !supports_ssl => {
return Err(tls_err!("server does not support TLS").into())
}
#[cfg(feature = "tls")]
"REQUIRED" => self_.try_ssl(&url, None, true).await?,
#[cfg(feature = "tls")]
"VERIFY_CA" | "VERIFY_FULL" if ca_file.is_none() => {
return Err(
tls_err!("`ssl-mode` of {:?} requires `ssl-ca` to be set", ssl_mode).into(),
)
}
#[cfg(feature = "tls")]
"VERIFY_CA" | "VERIFY_FULL" => {
self_
.try_ssl(&url, ca_file.as_deref(), ssl_mode != "VERIFY_FULL")
.await?
}
#[cfg(not(feature = "tls"))]
"REQUIRED" | "VERIFY_CA" | "VERIFY_FULL" => {
return Err(tls_err!("compiled without TLS").into())
}
_ => return Err(tls_err!("unknown `ssl-mode` value: {:?}", ssl_mode).into()),
}
// Pre-generate an auth response by using the auth method in the [Handshake]
let password = url.password().unwrap_or_default();
let auth_response = self_
.make_auth_initial_response(
&handshake.auth_plugin,
&password,
&handshake.auth_plugin_data,
)
.await?;
self_
.send_handshake_response(&url, &handshake.auth_plugin, &auth_response)
.await?;
// After sending the handshake response with our assumed auth method the server
// will send OK, fail, or tell us to change auth methods
self_
.receive_auth_ok(
&handshake.auth_plugin,
&password,
&handshake.auth_plugin_data,
)
.await?;
let mut self_ = Self {
stream,
current_row_values: Vec::with_capacity(10),
is_ready: true,
cache_statement: HashMap::new(),
};
// After the connection is established, we initialize by configuring a few
// connection parameters
self_.initialize().await?;
// initialize().await?;
Ok(self_)
}
async fn close(mut self) -> crate::Result<()> {
// TODO: Actually tell MySQL that we're closing
self.stream.flush().await?;
self.stream.stream.shutdown(Shutdown::Both)?;
Ok(())
}
}
impl MySqlConnection {
#[deprecated(note = "please use 'connect' instead")]
pub fn open<T>(url: T) -> BoxFuture<'static, crate::Result<Self>>
where
T: TryInto<Url, Error = crate::Error>,
Self: Sized,
{
Box::pin(MySqlConnection::establish(url.try_into()))
}
}
impl Connect for MySqlConnection {
@ -608,12 +296,16 @@ impl Connect for MySqlConnection {
T: TryInto<Url, Error = crate::Error>,
Self: Sized,
{
Box::pin(MySqlConnection::establish(url.try_into()))
Box::pin(MySqlConnection::new(url.try_into()))
}
}
impl Connection for MySqlConnection {
fn close(self) -> BoxFuture<'static, crate::Result<()>> {
Box::pin(self.close())
Box::pin(close(self.stream))
}
fn ping(&mut self) -> BoxFuture<crate::Result<()>> {
Box::pin(ping(&mut self.stream))
}
}

View File

@ -0,0 +1,158 @@
use std::collections::HashMap;
use std::sync::Arc;
use futures_core::future::BoxFuture;
use crate::connection::{ConnectionSource, MaybeOwnedConnection};
use crate::cursor::Cursor;
use crate::executor::Execute;
use crate::mysql::protocol::{ColumnCount, ColumnDefinition, Decode, Row, Status, TypeId};
use crate::mysql::{MySql, MySqlArguments, MySqlConnection, MySqlRow};
use crate::pool::Pool;
pub struct MySqlCursor<'c, 'q> {
source: ConnectionSource<'c, MySqlConnection>,
query: Option<(&'q str, Option<MySqlArguments>)>,
column_names: Arc<HashMap<Box<str>, u16>>,
column_types: Vec<TypeId>,
binary: bool,
}
impl<'c, 'q> Cursor<'c, 'q> for MySqlCursor<'c, 'q> {
type Database = MySql;
#[doc(hidden)]
fn from_pool<E>(pool: &Pool<MySqlConnection>, query: E) -> Self
where
Self: Sized,
E: Execute<'q, MySql>,
{
Self {
source: ConnectionSource::Pool(pool.clone()),
column_names: Arc::default(),
column_types: Vec::new(),
binary: true,
query: Some(query.into_parts()),
}
}
#[doc(hidden)]
fn from_connection<E, C>(conn: C, query: E) -> Self
where
Self: Sized,
C: Into<MaybeOwnedConnection<'c, MySqlConnection>>,
E: Execute<'q, MySql>,
{
Self {
source: ConnectionSource::Connection(conn.into()),
column_names: Arc::default(),
column_types: Vec::new(),
binary: true,
query: Some(query.into_parts()),
}
}
fn next(&mut self) -> BoxFuture<crate::Result<Option<MySqlRow<'_>>>> {
Box::pin(next(self))
}
}
async fn next<'a, 'c: 'a, 'q: 'a>(
cursor: &'a mut MySqlCursor<'c, 'q>,
) -> crate::Result<Option<MySqlRow<'a>>> {
println!("[cursor::next]");
let mut conn = cursor.source.resolve_by_ref().await?;
// The first time [next] is called we need to actually execute our
// contained query. We guard against this happening on _all_ next calls
// by using [Option::take] which replaces the potential value in the Option with `None
let mut initial = if let Some((query, arguments)) = cursor.query.take() {
let statement = conn.run(query, arguments).await?;
// No statement ID = TEXT mode
cursor.binary = statement.is_some();
true
} else {
false
};
loop {
let mut packet_id = conn.stream.receive().await?[0];
println!("[cursor::next/iter] {:x}", packet_id);
match packet_id {
// OK or EOF packet
0x00 | 0xFE
if conn.stream.packet().len() < 0xFF_FF_FF && (packet_id != 0x00 || initial) =>
{
let ok = conn.stream.handle_ok()?;
if ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
// There is more to this query
initial = true;
} else {
conn.is_ready = true;
return Ok(None);
}
}
// ERR packet
0xFF => {
conn.is_ready = true;
return conn.stream.handle_err();
}
_ if initial => {
// At the start of the results we expect to see a
// COLUMN_COUNT followed by N COLUMN_DEF
let cc = ColumnCount::decode(conn.stream.packet())?;
// We use these definitions to get the actual column types that is critical
// in parsing the rows coming back soon
cursor.column_types.clear();
cursor.column_types.reserve(cc.columns as usize);
let mut column_names = HashMap::with_capacity(cc.columns as usize);
for i in 0..cc.columns {
let column = ColumnDefinition::decode(conn.stream.receive().await?)?;
cursor.column_types.push(column.type_id);
if let Some(name) = column.name() {
column_names.insert(name.to_owned().into_boxed_str(), i as u16);
}
}
cursor.column_names = Arc::new(column_names);
initial = false;
}
_ if !cursor.binary || packet_id == 0x00 => {
let row = Row::read(
conn.stream.packet(),
&cursor.column_types,
&mut conn.current_row_values,
// TODO: Text mode
cursor.binary,
)?;
let row = MySqlRow {
row,
columns: Arc::clone(&cursor.column_names),
// TODO: Text mode
binary: cursor.binary,
};
return Ok(Some(row));
}
_ => {
return conn.stream.handle_unexpected();
}
}
}
}

View File

@ -13,18 +13,18 @@ impl Database for MySql {
type TableId = Box<str>;
}
impl HasRow for MySql {
impl<'c> HasRow<'c> for MySql {
type Database = MySql;
type Row = super::MySqlRow;
type Row = super::MySqlRow<'c>;
}
impl<'a> HasCursor<'a> for MySql {
impl<'c, 'q> HasCursor<'c, 'q> for MySql {
type Database = MySql;
type Cursor = super::MySqlCursor<'a>;
type Cursor = super::MySqlCursor<'c, 'q>;
}
impl<'a> HasRawValue<'a> for MySql {
type RawValue = Option<&'a [u8]>;
impl<'c> HasRawValue<'c> for MySql {
type RawValue = Option<super::MySqlValue<'c>>;
}

View File

@ -4,341 +4,244 @@ use std::sync::Arc;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use crate::describe::{Column, Describe, Nullability};
use crate::executor::Executor;
use crate::cursor::Cursor;
use crate::describe::{Column, Describe};
use crate::executor::{Execute, Executor, RefExecutor};
use crate::mysql::protocol::{
Capabilities, ColumnCount, ColumnDefinition, ComQuery, ComStmtExecute, ComStmtPrepare,
ComStmtPrepareOk, Cursor, Decode, EofPacket, FieldFlags, OkPacket, Row, TypeId,
self, Capabilities, ColumnCount, ColumnDefinition, ComQuery, ComStmtExecute, ComStmtPrepare,
ComStmtPrepareOk, Decode, EofPacket, ErrPacket, FieldFlags, OkPacket, Row, TypeId,
};
use crate::mysql::{
MySql, MySqlArguments, MySqlConnection, MySqlCursor, MySqlError, MySqlRow, MySqlTypeInfo,
};
use crate::mysql::{MySql, MySqlArguments, MySqlConnection, MySqlRow, MySqlTypeInfo};
enum Step {
Command(u64),
Row(Row),
}
impl super::MySqlConnection {
async fn wait_until_ready(&mut self) -> crate::Result<()> {
if !self.is_ready {
loop {
let mut packet_id = self.stream.receive().await?[0];
match packet_id {
0xFE if self.stream.packet().len() < 0xFF_FF_FF => {
// OK or EOF packet
self.is_ready = true;
break;
}
enum OkOrResultSet {
Ok(OkPacket),
ResultSet(ColumnCount),
}
0xFF => {
// ERR packet
self.is_ready = true;
return self.stream.handle_err();
}
impl MySqlConnection {
async fn ignore_columns(&mut self, count: usize) -> crate::Result<()> {
for _ in 0..count {
let _column = ColumnDefinition::decode(self.receive().await?.packet())?;
}
if count > 0 {
self.receive_eof().await?;
}
Ok(())
}
async fn receive_ok_or_column_count(&mut self) -> crate::Result<OkOrResultSet> {
self.receive().await?;
match self.packet[0] {
0x00 | 0xfe if self.packet.len() < 0xffffff => self.handle_ok().map(OkOrResultSet::Ok),
0xff => self.handle_err(),
_ => Ok(OkOrResultSet::ResultSet(ColumnCount::decode(
self.packet(),
)?)),
}
}
async fn receive_column_types(&mut self, count: usize) -> crate::Result<Box<[TypeId]>> {
let mut columns: Vec<TypeId> = Vec::with_capacity(count);
for _ in 0..count {
let column: ColumnDefinition =
ColumnDefinition::decode(self.receive().await?.packet())?;
columns.push(column.type_id);
}
if count > 0 {
self.receive_eof().await?;
}
Ok(columns.into_boxed_slice())
}
async fn wait_for_ready(&mut self) -> crate::Result<()> {
if self.next_seq_no != 0 {
while let Some(_step) = self.step(&[], true).await? {
// Drain steps until we hit the end
_ => {
// Something else; skip
}
}
}
}
Ok(())
}
// Creates a prepared statement for the passed query string
async fn prepare(&mut self, query: &str) -> crate::Result<ComStmtPrepareOk> {
// Start by sending a COM_STMT_PREPARE
self.send(ComStmtPrepare { query }).await?;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_com_stmt_prepare.html
self.stream.send(ComStmtPrepare { query }, true).await?;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_prepare.html
// Should receive a COM_STMT_PREPARE_OK or ERR_PACKET
let packet = self.stream.receive().await?;
// First we should receive a COM_STMT_PREPARE_OK
self.receive().await?;
if self.packet[0] == 0xff {
// Oops, there was an error in the prepare command
return self.handle_err();
if packet[0] == 0xFF {
return self.stream.handle_err();
}
ComStmtPrepareOk::decode(self.packet())
ComStmtPrepareOk::decode(packet)
}
async fn prepare_with_cache(&mut self, query: &str) -> crate::Result<u32> {
if let Some(&id) = self.statement_cache.get(query) {
async fn drop_column_defs(&mut self, count: usize) -> crate::Result<()> {
for _ in 0..count {
let _column = ColumnDefinition::decode(self.stream.receive().await?)?;
}
if count > 0 {
self.stream.maybe_receive_eof().await?;
}
Ok(())
}
// Gets a cached prepared statement ID _or_ prepares the statement if not in the cache
// At the end we should have [cache_statement] and [cache_statement_columns] filled
async fn get_or_prepare(&mut self, query: &str) -> crate::Result<u32> {
if let Some(&id) = self.cache_statement.get(query) {
Ok(id)
} else {
let prepare_ok = self.prepare(query).await?;
let stmt = self.prepare(query).await?;
// Remember our statement ID, so we do'd do this again the next time
self.statement_cache
.put(query.to_owned(), prepare_ok.statement_id);
self.cache_statement.insert(query.into(), stmt.statement_id);
// Ignore input parameters
self.ignore_columns(prepare_ok.params as usize).await?;
// COM_STMT_PREPARE returns the input columns
// We make no use of that data, so cycle through and drop them
self.drop_column_defs(stmt.params as usize).await?;
// Collect output parameter names
let mut columns = HashMap::with_capacity(prepare_ok.columns as usize);
let mut index = 0_usize;
for _ in 0..prepare_ok.columns {
let column = ColumnDefinition::decode(self.receive().await?.packet())?;
// COM_STMT_PREPARE next returns the output columns
// We just drop these as we get these when we execute the query
self.drop_column_defs(stmt.columns as usize).await?;
if let Some(name) = column.column_alias.or(column.column) {
columns.insert(name, index);
Ok(stmt.statement_id)
}
}
pub(crate) async fn run(
&mut self,
query: &str,
arguments: Option<MySqlArguments>,
) -> crate::Result<Option<u32>> {
self.wait_until_ready().await?;
self.is_ready = false;
if let Some(arguments) = arguments {
let statement_id = self.get_or_prepare(query).await?;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_com_stmt_execute.html
self.stream
.send(
ComStmtExecute {
cursor: protocol::Cursor::NO_CURSOR,
statement_id,
params: &arguments.params,
null_bitmap: &arguments.null_bitmap,
param_types: &arguments.param_types,
},
true,
)
.await?;
Ok(Some(statement_id))
} else {
// https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_com_query.html
self.stream.send(ComQuery { query }, true).await?;
Ok(None)
}
}
async fn affected_rows(&mut self) -> crate::Result<u64> {
let mut rows = 0;
loop {
let id = self.stream.receive().await?[0];
match id {
0x00 | 0xFE if self.stream.packet().len() < 0xFF_FF_FF => {
// ResultSet row can begin with 0xfe byte (when using text protocol
// with a field length > 0xffffff)
if !self.stream.maybe_handle_eof()? {
rows += self.stream.handle_ok()?.affected_rows;
}
// EOF packets do not have affected rows
// So this function is actually useless if the server doesn't support
// proper OK packets
self.is_ready = true;
break;
}
index += 1;
}
if prepare_ok.columns > 0 {
self.receive_eof().await?;
}
// At the end of a command, this should go back to 0
self.next_seq_no = 0;
// Remember our column map in the statement cache
self.statement_cache
.put_columns(prepare_ok.statement_id, columns);
Ok(prepare_ok.statement_id)
}
}
// [COM_STMT_EXECUTE]
async fn execute_statement(&mut self, id: u32, args: MySqlArguments) -> crate::Result<()> {
self.send(ComStmtExecute {
cursor: Cursor::NO_CURSOR,
statement_id: id,
params: &args.params,
null_bitmap: &args.null_bitmap,
param_types: &args.param_types,
})
.await
}
async fn step(&mut self, columns: &[TypeId], binary: bool) -> crate::Result<Option<Step>> {
let capabilities = self.capabilities;
ret_if_none!(self.try_receive().await?);
match self.packet[0] {
0xfe if self.packet.len() < 0xffffff => {
// ResultSet row can begin with 0xfe byte (when using text protocol
// with a field length > 0xffffff)
if !capabilities.contains(Capabilities::DEPRECATE_EOF) {
let _eof = EofPacket::decode(self.packet())?;
// An EOF -here- signifies the end of the current command sequence
self.next_seq_no = 0;
Ok(None)
} else {
self.handle_ok()
.map(|ok| Some(Step::Command(ok.affected_rows)))
0xFF => {
return self.stream.handle_err();
}
}
0xff => self.handle_err(),
_ => Ok(Some(Step::Row(Row::decode(
self.packet(),
columns,
binary,
)?))),
}
}
}
impl MySqlConnection {
pub(super) async fn execute_raw(&mut self, query: &str) -> crate::Result<()> {
self.wait_for_ready().await?;
self.send(ComQuery { query }).await?;
// COM_QUERY can terminate before the result set with an ERR or OK packet
let num_columns = match self.receive_ok_or_column_count().await? {
OkOrResultSet::Ok(_) => {
self.next_seq_no = 0;
return Ok(());
}
OkOrResultSet::ResultSet(cc) => cc.columns as usize,
};
let columns = self.receive_column_types(num_columns as usize).await?;
while let Some(_step) = self.step(&columns, false).await? {
// Drop all responses
}
Ok(())
}
async fn execute(&mut self, query: &str, args: MySqlArguments) -> crate::Result<u64> {
self.wait_for_ready().await?;
let statement_id = self.prepare_with_cache(query).await?;
self.execute_statement(statement_id, args).await?;
// COM_STMT_EXECUTE can terminate before the result set with an ERR or OK packet
let num_columns = match self.receive_ok_or_column_count().await? {
OkOrResultSet::Ok(ok) => {
self.next_seq_no = 0;
return Ok(ok.affected_rows);
}
OkOrResultSet::ResultSet(cc) => cc.columns as usize,
};
self.ignore_columns(num_columns).await?;
let mut res = 0;
while let Some(step) = self.step(&[], true).await? {
if let Step::Command(affected) = step {
res = affected;
_ => {}
}
}
Ok(res)
Ok(rows)
}
async fn describe(&mut self, query: &str) -> crate::Result<Describe<MySql>> {
self.wait_for_ready().await?;
// method is not named describe to work around an intellijrust bug
// otherwise it marks someone trying to describe the connection as "method is private"
async fn do_describe(&mut self, query: &str) -> crate::Result<Describe<MySql>> {
self.wait_until_ready().await?;
let prepare_ok = self.prepare(query).await?;
let stmt = self.prepare(query).await?;
let mut param_types = Vec::with_capacity(prepare_ok.params as usize);
let mut result_columns = Vec::with_capacity(prepare_ok.columns as usize);
let mut param_types = Vec::with_capacity(stmt.params as usize);
let mut result_columns = Vec::with_capacity(stmt.columns as usize);
for _ in 0..prepare_ok.params {
let param = ColumnDefinition::decode(self.receive().await?.packet())?;
for _ in 0..stmt.params {
let param = ColumnDefinition::decode(self.stream.receive().await?)?;
param_types.push(MySqlTypeInfo::from_column_def(&param));
}
if prepare_ok.params > 0 {
self.receive_eof().await?;
if stmt.params > 0 {
self.stream.maybe_receive_eof().await?;
}
for _ in 0..prepare_ok.columns {
let column = ColumnDefinition::decode(self.receive().await?.packet())?;
for _ in 0..stmt.columns {
let column = ColumnDefinition::decode(self.stream.receive().await?)?;
result_columns.push(Column::<MySql> {
type_info: MySqlTypeInfo::from_column_def(&column),
name: column.column_alias.or(column.column),
table_id: column.table_alias.or(column.table),
// TODO(@abonander): Should this be None in some cases?
non_null: Some(column.flags.contains(FieldFlags::NOT_NULL)),
});
}
if prepare_ok.columns > 0 {
self.receive_eof().await?;
if stmt.columns > 0 {
self.stream.maybe_receive_eof().await?;
}
// Command sequence is over
self.next_seq_no = 0;
Ok(Describe {
param_types: param_types.into_boxed_slice(),
result_columns: result_columns.into_boxed_slice(),
})
}
}
fn fetch<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
args: MySqlArguments,
) -> BoxStream<'e, crate::Result<MySqlRow>> {
Box::pin(async_stream::try_stream! {
self.wait_for_ready().await?;
impl Executor for super::MySqlConnection {
type Database = MySql;
let statement_id = self.prepare_with_cache(query).await?;
fn execute<'e, 'q, E: 'e>(&'e mut self, query: E) -> BoxFuture<'e, crate::Result<u64>>
where
E: Execute<'q, Self::Database>,
{
Box::pin(async move {
let (query, arguments) = query.into_parts();
let columns = self.statement_cache.get_columns(statement_id);
self.execute_statement(statement_id, args).await?;
// COM_STMT_EXECUTE can terminate before the result set with an ERR or OK packet
let num_columns = match self.receive_ok_or_column_count().await? {
OkOrResultSet::Ok(_) => {
self.next_seq_no = 0;
return;
}
OkOrResultSet::ResultSet(cc) => {
cc.columns as usize
}
};
let column_types = self.receive_column_types(num_columns).await?;
while let Some(Step::Row(row)) = self.step(&column_types, true).await? {
yield MySqlRow { row, columns: Arc::clone(&columns) };
}
self.run(query, arguments).await?;
self.affected_rows().await
})
}
}
impl Executor for MySqlConnection {
type Database = super::MySql;
fn send<'e, 'q: 'e>(&'e mut self, query: &'q str) -> BoxFuture<'e, crate::Result<()>> {
Box::pin(self.execute_raw(query))
fn fetch<'q, E>(&mut self, query: E) -> MySqlCursor<'_, 'q>
where
E: Execute<'q, Self::Database>,
{
MySqlCursor::from_connection(self, query)
}
fn fetch<'e, 'q: 'e>(
fn describe<'e, 'q, E: 'e>(
&'e mut self,
query: &'q str,
args: MySqlArguments,
) -> BoxFuture<'e, crate::Result<u64>> {
Box::pin(self.execute(query, args))
}
fn fetch<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
args: MySqlArguments,
) -> BoxStream<'e, crate::Result<MySqlRow>> {
self.fetch(query, args)
}
fn describe<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>> {
Box::pin(self.describe(query))
query: E,
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>>
where
E: Execute<'q, Self::Database>,
{
Box::pin(async move { self.do_describe(query.into_parts().0).await })
}
}
impl_execute_for_query!(MySql);
impl<'c> RefExecutor<'c> for &'c mut super::MySqlConnection {
type Database = MySql;
fn fetch_by_ref<'q, E>(self, query: E) -> MySqlCursor<'c, 'q>
where
E: Execute<'q, Self::Database>,
{
MySqlCursor::from_connection(self, query)
}
}

View File

@ -1,7 +1,16 @@
//! **MySQL** database and connection types.
pub use arguments::MySqlArguments;
pub use connection::MySqlConnection;
pub use cursor::MySqlCursor;
pub use database::MySql;
pub use error::MySqlError;
pub use row::{MySqlRow, MySqlValue};
pub use types::MySqlTypeInfo;
mod arguments;
mod connection;
mod cursor;
mod database;
mod error;
mod executor;
@ -9,20 +18,15 @@ mod io;
mod protocol;
mod row;
mod rsa;
mod stream;
mod tls;
mod types;
mod util;
pub use database::MySql;
pub use arguments::MySqlArguments;
pub use connection::MySqlConnection;
pub use error::MySqlError;
pub use types::MySqlTypeInfo;
pub use row::MySqlRow;
/// An alias for [`Pool`], specialized for **MySQL**.
pub type MySqlPool = super::Pool<MySqlConnection>;
pub type MySqlPool = crate::pool::Pool<MySqlConnection>;
make_query_as!(MySqlQueryAs, MySql, MySqlRow);
impl_map_row_for_row!(MySql, MySqlRow);
impl_column_index_for_row!(MySql);
impl_from_row_for_tuples!(MySql, MySqlRow);

View File

@ -6,7 +6,7 @@ use sha2::Sha256;
use crate::mysql::util::xor_eq;
#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub enum AuthPlugin {
MySqlNativePassword,
CachingSha2Password,

View File

@ -27,6 +27,12 @@ pub struct ColumnDefinition {
pub decimals: u8,
}
impl ColumnDefinition {
pub fn name(&self) -> Option<&str> {
self.column_alias.as_deref().or(self.column.as_deref())
}
}
impl Decode for ColumnDefinition {
fn decode(mut buf: &[u8]) -> crate::Result<Self> {
// catalog : string<lenenc>

View File

@ -0,0 +1,16 @@
use byteorder::LittleEndian;
use crate::io::BufMut;
use crate::mysql::io::BufMutExt;
use crate::mysql::protocol::{Capabilities, Encode};
// https://dev.mysql.com/doc/internals/en/com-ping.html
#[derive(Debug)]
pub struct ComPing;
impl Encode for ComPing {
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
// COM_PING : int<1>
buf.put_u8(0x0e);
}
}

View File

@ -9,7 +9,8 @@ use crate::mysql::protocol::Decode;
pub struct ComStmtPrepareOk {
pub statement_id: u32,
/// Number of columns in the returned result set (or 0 if statement does not return result set).
/// Number of columns in the returned result set (or 0 if statement
/// does not return result set).
pub columns: u16,
/// Number of prepared statement parameters ('?' placeholders).

View File

@ -9,24 +9,34 @@ use crate::mysql::protocol::{Capabilities, Decode, Status};
#[derive(Debug)]
pub struct ErrPacket {
pub error_code: u16,
pub sql_state: Box<str>,
pub sql_state: Option<Box<str>>,
pub error_message: Box<str>,
}
impl Decode for ErrPacket {
fn decode(mut buf: &[u8]) -> crate::Result<Self>
impl ErrPacket {
pub(crate) fn decode(mut buf: &[u8], capabilities: Capabilities) -> crate::Result<Self>
where
Self: Sized,
{
let header = buf.get_u8()?;
if header != 0xFF {
return Err(protocol_err!("expected 0xFF; received 0x{:X}", header))?;
return Err(protocol_err!(
"expected 0xFF for ERR_PACKET; received 0x{:X}",
header
))?;
}
let error_code = buf.get_u16::<LittleEndian>()?;
let _sql_state_marker: u8 = buf.get_u8()?;
let sql_state = buf.get_str(5)?.into();
let mut sql_state = None;
if capabilities.contains(Capabilities::PROTOCOL_41) {
// If the next byte is '#' then we have a SQL STATE
if buf.get(0) == Some(&0x23) {
buf.advance(1);
sql_state = Some(buf.get_str(5)?.into())
}
}
let error_message = buf.get_str(buf.len())?.into();
@ -42,14 +52,25 @@ impl Decode for ErrPacket {
mod tests {
use super::{Capabilities, Decode, ErrPacket, Status};
const ERR_PACKETS_OUT_OF_ORDER: &[u8] = b"\xff\x84\x04Got packets out of order";
const ERR_HANDSHAKE_UNKNOWN_DB: &[u8] = b"\xff\x19\x04#42000Unknown database \'unknown\'";
#[test]
fn it_decodes_packets_out_of_order() {
let mut p = ErrPacket::decode(ERR_PACKETS_OUT_OF_ORDER, Capabilities::PROTOCOL_41).unwrap();
assert_eq!(&*p.error_message, "Got packets out of order");
assert_eq!(p.error_code, 1156);
assert_eq!(p.sql_state, None);
}
#[test]
fn it_decodes_ok_handshake() {
let mut p = ErrPacket::decode(ERR_HANDSHAKE_UNKNOWN_DB).unwrap();
let mut p = ErrPacket::decode(ERR_HANDSHAKE_UNKNOWN_DB, Capabilities::PROTOCOL_41).unwrap();
assert_eq!(p.error_code, 1049);
assert_eq!(&*p.sql_state, "42000");
assert_eq!(p.sql_state.as_deref(), Some("42000"));
assert_eq!(&*p.error_message, "Unknown database \'unknown\'");
}
}

View File

@ -20,12 +20,14 @@ pub use field::FieldFlags;
pub use r#type::TypeId;
pub use status::Status;
mod com_ping;
mod com_query;
mod com_set_option;
mod com_stmt_execute;
mod com_stmt_prepare;
mod handshake;
pub use com_ping::ComPing;
pub use com_query::ComQuery;
pub use com_set_option::{ComSetOption, SetOption};
pub use com_stmt_execute::{ComStmtExecute, Cursor};

View File

@ -6,73 +6,84 @@ use crate::io::Buf;
use crate::mysql::io::BufExt;
use crate::mysql::protocol::{Decode, TypeId};
pub struct Row {
buffer: Box<[u8]>,
values: Box<[Option<Range<usize>>]>,
pub struct Row<'c> {
buffer: &'c [u8],
values: &'c [Option<Range<usize>>],
binary: bool,
}
impl Row {
impl<'c> Row<'c> {
pub fn len(&self) -> usize {
self.values.len()
}
pub fn get(&self, index: usize) -> Option<&[u8]> {
pub fn get(&self, index: usize) -> Option<&'c [u8]> {
let range = self.values[index].as_ref()?;
Some(&self.buffer[(range.start as usize)..(range.end as usize)])
}
}
fn get_lenenc(buf: &[u8]) -> usize {
fn get_lenenc(buf: &[u8]) -> (usize, Option<usize>) {
match buf[0] {
0xFB => 1,
0xFB => (1, None),
0xFC => {
let len_size = 1 + 2;
let len = LittleEndian::read_u16(&buf[1..]);
len_size + len as usize
(len_size, Some(len as usize))
}
0xFD => {
let len_size = 1 + 3;
let len = LittleEndian::read_u24(&buf[1..]);
len_size + len as usize
(len_size, Some(len as usize))
}
0xFE => {
let len_size = 1 + 8;
let len = LittleEndian::read_u64(&buf[1..]);
len_size + len as usize
(len_size, Some(len as usize))
}
value => 1 + value as usize,
len => (1, Some(len as usize)),
}
}
impl Row {
pub fn decode(mut buf: &[u8], columns: &[TypeId], binary: bool) -> crate::Result<Self> {
impl<'c> Row<'c> {
pub fn read(
mut buf: &'c [u8],
columns: &[TypeId],
values: &'c mut Vec<Option<Range<usize>>>,
binary: bool,
) -> crate::Result<Self> {
let mut buffer = &*buf;
values.clear();
values.reserve(columns.len());
if !binary {
let buffer: Box<[u8]> = buf.into();
let mut values = Vec::with_capacity(columns.len());
let mut index = 0;
for column_idx in 0..columns.len() {
let size = get_lenenc(&buf[index..]);
let (len_size, size) = get_lenenc(&buf[index..]);
values.push(Some(index..(index + size)));
if let Some(size) = size {
values.push(Some((index + len_size)..(index + len_size + size)));
} else {
values.push(None);
}
index += size;
buf.advance(size);
index += (len_size + size.unwrap_or_default());
}
return Ok(Self {
buffer,
values: values.into_boxed_slice(),
binary,
values: &*values,
binary: false,
});
}
@ -88,7 +99,6 @@ impl Row {
buf.advance(null_len);
let buffer: Box<[u8]> = buf.into();
let mut values = Vec::with_capacity(columns.len());
let mut index = 0;
for column_idx in 0..columns.len() {
@ -117,7 +127,11 @@ impl Row {
| TypeId::LONG_BLOB
| TypeId::CHAR
| TypeId::TEXT
| TypeId::VAR_CHAR => get_lenenc(&buffer[index..]),
| TypeId::VAR_CHAR => {
let (len_size, len) = get_lenenc(&buffer[index..]);
len_size + len.unwrap_or_default()
}
id => {
unimplemented!("encountered unknown field type id: {:?}", id);
@ -130,174 +144,174 @@ impl Row {
}
Ok(Self {
buffer,
values: values.into_boxed_slice(),
buffer: buf,
values: &*values,
binary,
})
}
}
#[cfg(test)]
mod test {
use super::super::column_count::ColumnCount;
use super::super::column_def::ColumnDefinition;
use super::super::eof::EofPacket;
use super::*;
#[test]
fn null_bitmap_test() -> crate::Result<()> {
let column_len = ColumnCount::decode(&[26])?;
assert_eq!(column_len.columns, 26);
let types: Vec<TypeId> = vec![
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 2, 105, 100, 2, 105, 100, 12, 63, 0, 11, 0, 0,
0, 3, 11, 66, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 50, 6, 102, 105,
101, 108, 100, 50, 12, 224, 0, 120, 0, 0, 0, 253, 5, 64, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 51, 6, 102, 105,
101, 108, 100, 51, 12, 224, 0, 252, 3, 0, 0, 253, 1, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 52, 6, 102, 105,
101, 108, 100, 52, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 53, 6, 102, 105,
101, 108, 100, 53, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 54, 6, 102, 105,
101, 108, 100, 54, 12, 63, 0, 19, 0, 0, 0, 7, 128, 4, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 55, 6, 102, 105,
101, 108, 100, 55, 12, 63, 0, 4, 0, 0, 0, 1, 1, 64, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 56, 6, 102, 105,
101, 108, 100, 56, 12, 224, 0, 252, 255, 3, 0, 252, 16, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 57, 6, 102, 105,
101, 108, 100, 57, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 48, 7, 102,
105, 101, 108, 100, 49, 48, 12, 224, 0, 252, 3, 0, 0, 252, 16, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 49, 7, 102,
105, 101, 108, 100, 49, 49, 12, 224, 0, 252, 3, 0, 0, 252, 16, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 50, 7, 102,
105, 101, 108, 100, 49, 50, 12, 63, 0, 19, 0, 0, 0, 7, 129, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 51, 7, 102,
105, 101, 108, 100, 49, 51, 12, 63, 0, 4, 0, 0, 0, 1, 0, 64, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 52, 7, 102,
105, 101, 108, 100, 49, 52, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 53, 7, 102,
105, 101, 108, 100, 49, 53, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 54, 7, 102,
105, 101, 108, 100, 49, 54, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 55, 7, 102,
105, 101, 108, 100, 49, 55, 12, 224, 0, 0, 1, 0, 0, 253, 0, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 56, 7, 102,
105, 101, 108, 100, 49, 56, 12, 63, 0, 11, 0, 0, 0, 3, 1, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 57, 7, 102,
105, 101, 108, 100, 49, 57, 12, 63, 0, 11, 0, 0, 0, 3, 1, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 48, 7, 102,
105, 101, 108, 100, 50, 48, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 49, 7, 102,
105, 101, 108, 100, 50, 49, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 50, 7, 102,
105, 101, 108, 100, 50, 50, 12, 63, 0, 3, 0, 0, 0, 3, 0, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 51, 7, 102,
105, 101, 108, 100, 50, 51, 12, 63, 0, 6, 0, 0, 0, 3, 1, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 52, 7, 102,
105, 101, 108, 100, 50, 52, 12, 63, 0, 6, 0, 0, 0, 3, 1, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 53, 7, 102,
105, 101, 108, 100, 50, 53, 12, 63, 0, 20, 0, 0, 0, 8, 1, 0, 0, 0, 0,
])?,
ColumnDefinition::decode(&[
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 54, 7, 102,
105, 101, 108, 100, 50, 54, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0,
])?,
]
.into_iter()
.map(|def| def.type_id)
.collect();
EofPacket::decode(&[254, 0, 0, 34, 0])?;
Row::decode(
&[
0, 64, 90, 229, 0, 4, 0, 0, 0, 4, 114, 117, 115, 116, 0, 0, 7, 228, 7, 1, 16, 8,
10, 17, 0, 0, 4, 208, 7, 1, 1, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
],
&types,
true,
)?;
EofPacket::decode(&[254, 0, 0, 34, 0])?;
Ok(())
}
}
// #[cfg(test)]
// mod test {
// use super::super::column_count::ColumnCount;
// use super::super::column_def::ColumnDefinition;
// use super::super::eof::EofPacket;
// use super::*;
//
// #[test]
// fn null_bitmap_test() -> crate::Result<()> {
// let column_len = ColumnCount::decode(&[26])?;
// assert_eq!(column_len.columns, 26);
//
// let types: Vec<TypeId> = vec![
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 2, 105, 100, 2, 105, 100, 12, 63, 0, 11, 0, 0,
// 0, 3, 11, 66, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 50, 6, 102, 105,
// 101, 108, 100, 50, 12, 224, 0, 120, 0, 0, 0, 253, 5, 64, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 51, 6, 102, 105,
// 101, 108, 100, 51, 12, 224, 0, 252, 3, 0, 0, 253, 1, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 52, 6, 102, 105,
// 101, 108, 100, 52, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 53, 6, 102, 105,
// 101, 108, 100, 53, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 54, 6, 102, 105,
// 101, 108, 100, 54, 12, 63, 0, 19, 0, 0, 0, 7, 128, 4, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 55, 6, 102, 105,
// 101, 108, 100, 55, 12, 63, 0, 4, 0, 0, 0, 1, 1, 64, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 56, 6, 102, 105,
// 101, 108, 100, 56, 12, 224, 0, 252, 255, 3, 0, 252, 16, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 57, 6, 102, 105,
// 101, 108, 100, 57, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 48, 7, 102,
// 105, 101, 108, 100, 49, 48, 12, 224, 0, 252, 3, 0, 0, 252, 16, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 49, 7, 102,
// 105, 101, 108, 100, 49, 49, 12, 224, 0, 252, 3, 0, 0, 252, 16, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 50, 7, 102,
// 105, 101, 108, 100, 49, 50, 12, 63, 0, 19, 0, 0, 0, 7, 129, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 51, 7, 102,
// 105, 101, 108, 100, 49, 51, 12, 63, 0, 4, 0, 0, 0, 1, 0, 64, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 52, 7, 102,
// 105, 101, 108, 100, 49, 52, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 53, 7, 102,
// 105, 101, 108, 100, 49, 53, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 54, 7, 102,
// 105, 101, 108, 100, 49, 54, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 55, 7, 102,
// 105, 101, 108, 100, 49, 55, 12, 224, 0, 0, 1, 0, 0, 253, 0, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 56, 7, 102,
// 105, 101, 108, 100, 49, 56, 12, 63, 0, 11, 0, 0, 0, 3, 1, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 57, 7, 102,
// 105, 101, 108, 100, 49, 57, 12, 63, 0, 11, 0, 0, 0, 3, 1, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 48, 7, 102,
// 105, 101, 108, 100, 50, 48, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 49, 7, 102,
// 105, 101, 108, 100, 50, 49, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 50, 7, 102,
// 105, 101, 108, 100, 50, 50, 12, 63, 0, 3, 0, 0, 0, 3, 0, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 51, 7, 102,
// 105, 101, 108, 100, 50, 51, 12, 63, 0, 6, 0, 0, 0, 3, 1, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 52, 7, 102,
// 105, 101, 108, 100, 50, 52, 12, 63, 0, 6, 0, 0, 0, 3, 1, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 53, 7, 102,
// 105, 101, 108, 100, 50, 53, 12, 63, 0, 20, 0, 0, 0, 8, 1, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 54, 7, 102,
// 105, 101, 108, 100, 50, 54, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0,
// ])?,
// ]
// .into_iter()
// .map(|def| def.type_id)
// .collect();
//
// EofPacket::decode(&[254, 0, 0, 34, 0])?;
//
// Row::read(
// &[
// 0, 64, 90, 229, 0, 4, 0, 0, 0, 4, 114, 117, 115, 116, 0, 0, 7, 228, 7, 1, 16, 8,
// 10, 17, 0, 0, 4, 208, 7, 1, 1, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
// 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
// ],
// &types,
// true,
// )?;
//
// EofPacket::decode(&[254, 0, 0, 34, 0])?;
// Ok(())
// }
// }

View File

@ -1,59 +1,60 @@
use std::collections::HashMap;
use std::convert::TryFrom;
use std::str::{from_utf8, Utf8Error};
use std::sync::Arc;
use crate::decode::Decode;
use crate::error::UnexpectedNullError;
use crate::mysql::io::BufExt;
use crate::mysql::protocol;
use crate::mysql::MySql;
use crate::row::{Row, RowIndex};
use crate::row::{ColumnIndex, Row};
use crate::types::Type;
use byteorder::LittleEndian;
pub struct MySqlRow {
pub(super) row: protocol::Row,
pub(super) columns: Arc<HashMap<Box<str>, usize>>,
#[derive(Debug)]
pub enum MySqlValue<'c> {
Binary(&'c [u8]),
Text(&'c [u8]),
}
impl Row for MySqlRow {
impl<'c> TryFrom<Option<MySqlValue<'c>>> for MySqlValue<'c> {
type Error = crate::Error;
#[inline]
fn try_from(value: Option<MySqlValue<'c>>) -> Result<Self, Self::Error> {
match value {
Some(value) => Ok(value),
None => Err(crate::Error::decode(UnexpectedNullError)),
}
}
}
pub struct MySqlRow<'c> {
pub(super) row: protocol::Row<'c>,
pub(super) columns: Arc<HashMap<Box<str>, u16>>,
pub(super) binary: bool,
}
impl<'c> Row<'c> for MySqlRow<'c> {
type Database = MySql;
fn len(&self) -> usize {
self.row.len()
}
fn get<T, I>(&self, index: I) -> T
fn get_raw<'r, I>(&'r self, index: I) -> crate::Result<Option<MySqlValue<'c>>>
where
Self::Database: Type<T>,
I: RowIndex<Self>,
T: Decode<Self::Database>,
I: ColumnIndex<Self::Database>,
{
index.get(self).unwrap()
let index = index.resolve(self)?;
Ok(self.row.get(index).map(|mut buf| {
if self.binary {
MySqlValue::Binary(buf)
} else {
MySqlValue::Text(buf)
}
}))
}
}
impl RowIndex<MySqlRow> for usize {
fn get<T>(&self, row: &MySqlRow) -> crate::Result<T>
where
<MySqlRow as Row>::Database: Type<T>,
T: Decode<<MySqlRow as Row>::Database>,
{
Ok(Decode::decode_nullable(row.row.get(*self))?)
}
}
impl RowIndex<MySqlRow> for &'_ str {
fn get<T>(&self, row: &MySqlRow) -> crate::Result<T>
where
<MySqlRow as Row>::Database: Type<T>,
T: Decode<<MySqlRow as Row>::Database>,
{
let index = row
.columns
.get(*self)
.ok_or_else(|| crate::Error::ColumnNotFound((*self).into()))?;
let value = Decode::decode_nullable(row.row.get(*index))?;
Ok(value)
}
}
impl_from_row_for_row!(MySqlRow);

View File

@ -6,7 +6,7 @@ use rand::{thread_rng, Rng};
// For the love of crypto, please delete as much of this as possible and use the RSA crate
// directly when that PR is merged
pub fn encrypt<D: Digest>(key: &[u8], message: &[u8]) -> crate::Result<Box<[u8]>> {
pub fn encrypt<D: Digest>(key: &[u8], message: &[u8]) -> crate::Result<Vec<u8>> {
let key = std::str::from_utf8(key).map_err(|_err| {
// TODO(@abonander): protocol_err doesn't like referring to [err]
protocol_err!("unexpected error decoding what should be UTF-8")
@ -14,7 +14,7 @@ pub fn encrypt<D: Digest>(key: &[u8], message: &[u8]) -> crate::Result<Box<[u8]>
let key = parse(key)?;
Ok(oaep_encrypt::<_, D>(&mut thread_rng(), &key, message)?.into_boxed_slice())
Ok(oaep_encrypt::<_, D>(&mut thread_rng(), &key, message)?)
}
// https://github.com/RustCrypto/RSA/blob/9f1464c43831d422d9903574aad6ab072db9f2b0/src/internals.rs#L12

View File

@ -0,0 +1,193 @@
use std::net::Shutdown;
use byteorder::{ByteOrder, LittleEndian};
use crate::io::{Buf, BufMut, BufStream, MaybeTlsStream};
use crate::mysql::protocol::{Capabilities, Decode, Encode, EofPacket, ErrPacket, OkPacket};
use crate::mysql::MySqlError;
use crate::url::Url;
// Size before a packet is split
const MAX_PACKET_SIZE: u32 = 1024;
pub(crate) struct MySqlStream {
pub(super) stream: BufStream<MaybeTlsStream>,
// Active capabilities
pub(super) capabilities: Capabilities,
// Packets in a command sequence have an incrementing sequence number
// This number must be 0 at the start of each command
pub(super) seq_no: u8,
// Packets are buffered into a second buffer from the stream
// as we may have compressed or split packets to figure out before
// decoding
packet_buf: Vec<u8>,
packet_len: usize,
}
impl MySqlStream {
pub(super) async fn new(url: &Url) -> crate::Result<Self> {
let stream = MaybeTlsStream::connect(&url, 5432).await?;
let mut capabilities = Capabilities::PROTOCOL_41
| Capabilities::IGNORE_SPACE
| Capabilities::DEPRECATE_EOF
| Capabilities::FOUND_ROWS
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PLUGIN_AUTH;
if url.database().is_some() {
capabilities |= Capabilities::CONNECT_WITH_DB;
}
if cfg!(feature = "tls") {
capabilities |= Capabilities::SSL;
}
Ok(Self {
capabilities,
stream: BufStream::new(stream),
packet_buf: Vec::with_capacity(MAX_PACKET_SIZE as usize),
packet_len: 0,
seq_no: 0,
})
}
pub(super) fn is_tls(&self) -> bool {
self.stream.is_tls()
}
pub(super) fn shutdown(&self) -> crate::Result<()> {
Ok(self.stream.shutdown(Shutdown::Both)?)
}
#[inline]
pub(super) async fn send<T>(&mut self, packet: T, initial: bool) -> crate::Result<()>
where
T: Encode + std::fmt::Debug,
{
if initial {
self.seq_no = 0;
}
self.write(packet);
self.flush().await
}
#[inline]
pub(super) async fn flush(&mut self) -> crate::Result<()> {
Ok(self.stream.flush().await?)
}
/// Write the packet to the buffered stream ( do not send to the server )
pub(super) fn write<T>(&mut self, packet: T)
where
T: Encode,
{
let buf = self.stream.buffer_mut();
// Allocate room for the header that we write after the packet;
// so, we can get an accurate and cheap measure of packet length
let header_offset = buf.len();
buf.advance(4);
packet.encode(buf, self.capabilities);
// Determine length of encoded packet
// and write to allocated header
let len = buf.len() - header_offset - 4;
let mut header = &mut buf[header_offset..];
LittleEndian::write_u32(&mut header, len as u32);
// Take the last sequence number received, if any, and increment by 1
// If there was no sequence number, we only increment if we split packets
header[3] = self.seq_no;
self.seq_no = self.seq_no.wrapping_add(1);
}
#[inline]
pub(super) async fn receive(&mut self) -> crate::Result<&[u8]> {
self.read().await?;
Ok(self.packet())
}
pub(super) async fn read(&mut self) -> crate::Result<()> {
self.packet_buf.clear();
self.packet_len = 0;
// Read the packet header which contains the length and the sequence number
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html
// https://mariadb.com/kb/en/library/0-packet/#standard-packet
let mut header = self.stream.peek(4_usize).await?;
self.packet_len = header.get_uint::<LittleEndian>(3)? as usize;
self.seq_no = header.get_u8()?.wrapping_add(1);
self.stream.consume(4);
// Read the packet body and copy it into our internal buf
// We must have a separate buffer around the stream as we can't operate directly
// on bytes returned from the stream. We have various kinds of payload manipulation
// that must be handled before decoding.
let payload = self.stream.peek(self.packet_len).await?;
self.packet_buf.reserve(payload.len());
self.packet_buf.extend_from_slice(payload);
self.stream.consume(self.packet_len);
// TODO: Implement packet compression
// TODO: Implement packet joining
Ok(())
}
/// Returns a reference to the most recently received packet data.
/// A call to `read` invalidates this buffer.
#[inline]
pub(super) fn packet(&self) -> &[u8] {
&self.packet_buf[..self.packet_len]
}
}
impl MySqlStream {
pub(crate) async fn maybe_receive_eof(&mut self) -> crate::Result<()> {
if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
let _eof = EofPacket::decode(self.receive().await?)?;
}
Ok(())
}
pub(crate) fn maybe_handle_eof(&mut self) -> crate::Result<bool> {
if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
let _eof = EofPacket::decode(self.packet())?;
Ok(true)
} else {
Ok(false)
}
}
pub(crate) fn handle_unexpected<T>(&mut self) -> crate::Result<T> {
Err(protocol_err!("unexpected packet identifier 0x{:X?}", self.packet()[0]).into())
}
pub(crate) fn handle_err<T>(&mut self) -> crate::Result<T> {
Err(MySqlError(ErrPacket::decode(self.packet(), self.capabilities)?).into())
}
pub(crate) fn handle_ok(&mut self) -> crate::Result<OkPacket> {
OkPacket::decode(self.packet())
}
}

115
sqlx-core/src/mysql/tls.rs Normal file
View File

@ -0,0 +1,115 @@
use std::borrow::Cow;
use std::str::FromStr;
use crate::mysql::protocol::{Capabilities, SslRequest};
use crate::mysql::stream::MySqlStream;
use crate::url::Url;
pub(super) async fn upgrade_if_needed(stream: &mut MySqlStream, url: &Url) -> crate::Result<()> {
let ca_file = url.param("ssl-ca");
let ssl_mode = url.param("ssl-mode");
let supports_tls = stream.capabilities.contains(Capabilities::SSL);
// https://dev.mysql.com/doc/refman/5.7/en/connection-options.html#option_general_ssl-mode
match ssl_mode.as_deref() {
Some("DISABLED") => {}
#[cfg(feature = "tls")]
Some("PREFERRED") | None if !supports_tls => {}
#[cfg(feature = "tls")]
Some("PREFERRED") => {
if let Err(error) = try_upgrade(stream, &url, None, true).await {
// TLS upgrade failed; fall back to a normal connection
}
}
#[cfg(feature = "tls")]
Some(mode @ "REQUIRED") | Some(mode @ "VERIFY_CA") | Some(mode @ "VERIFY_IDENTITY")
if !supports_tls =>
{
return Err(tls_err!("server does not support TLS").into());
}
#[cfg(feature = "tls")]
Some(mode @ "VERIFY_CA") | Some(mode @ "VERIFY_IDENTITY") if ca_file.is_none() => {
return Err(
tls_err!("`ssl-mode` of {:?} requires `ssl-ca` to be set", ssl_mode).into(),
);
}
#[cfg(feature = "tls")]
Some(mode @ "REQUIRED") | Some(mode @ "VERIFY_CA") | Some(mode @ "VERIFY_IDENTITY") => {
try_upgrade(
stream,
url,
// false for both verify-ca and verify-full
ca_file.as_deref(),
// false for only verify-full
mode != "VERIFY_IDENTITY",
)
.await?;
}
#[cfg(not(feature = "tls"))]
None => {
// The user neither explicitly enabled TLS in the connection string
// nor did they turn the `tls` feature on
}
#[cfg(not(feature = "tls"))]
Some(mode @ "PREFERRED")
| Some(mode @ "REQUIRED")
| Some(mode @ "VERIFY_CA")
| Some(mode @ "VERIFY_IDENTITY") => {
return Err(tls_err!(
"ssl-mode {:?} unsupported; SQLx was compiled without `tls` feature",
mode
)
.into());
}
Some(mode) => {
return Err(tls_err!("unknown `ssl-mode` value: {:?}", mode).into());
}
}
Ok(())
}
#[cfg(feature = "tls")]
async fn try_upgrade(
stream: &mut MySqlStream,
url: &Url,
ca_file: Option<&str>,
accept_invalid_hostnames: bool,
) -> crate::Result<()> {
use crate::runtime::fs;
use async_native_tls::{Certificate, TlsConnector};
let mut connector = TlsConnector::new()
.danger_accept_invalid_certs(ca_file.is_none())
.danger_accept_invalid_hostnames(accept_invalid_hostnames);
if let Some(ca_file) = ca_file {
let root_cert = fs::read(ca_file).await?;
connector = connector.add_root_certificate(Certificate::from_pem(&root_cert)?);
}
// send upgrade request and then immediately try TLS handshake
stream
.send(
SslRequest {
client_collation: COLLATE_UTF8MB4_UNICODE_CI,
max_packet_size: MAX_PACKET_SIZE,
},
false,
)
.await?;
stream.stream.upgrade(url, connector).await
}

View File

@ -1,11 +1,14 @@
use crate::decode::{Decode, DecodeError};
use std::convert::TryInto;
use crate::decode::Decode;
use crate::encode::Encode;
use crate::error::UnexpectedNullError;
use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql;
use crate::mysql::{MySql, MySqlValue};
use crate::types::Type;
impl Type<bool> for MySql {
impl Type<MySql> for bool {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::TINY_INT)
}
@ -17,13 +20,18 @@ impl Encode<MySql> for bool {
}
}
impl Decode<MySql> for bool {
fn decode(buf: &[u8]) -> Result<Self, DecodeError> {
match buf.len() {
0 => Err(DecodeError::Message(Box::new(
"Expected minimum 1 byte but received none.",
))),
_ => Ok(buf[0] != 0),
impl<'de> Decode<'de, MySql> for bool {
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
match value.try_into()? {
MySqlValue::Binary(buf) => Ok(buf.get(0).map(|&b| b != 0).unwrap_or_default()),
MySqlValue::Text(b"0") => Ok(false),
MySqlValue::Text(b"1") => Ok(true),
MySqlValue::Text(s) => Err(crate::Error::Decode(
format!("unexpected value {:?} for boolean", s).into(),
)),
}
}
}

View File

@ -1,14 +1,16 @@
use byteorder::LittleEndian;
use crate::decode::{Decode, DecodeError};
use crate::decode::Decode;
use crate::encode::Encode;
use crate::error::UnexpectedNullError;
use crate::mysql::io::{BufExt, BufMutExt};
use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql;
use crate::mysql::{MySql, MySqlValue};
use crate::types::Type;
use std::convert::TryInto;
impl Type<[u8]> for MySql {
impl Type<MySql> for [u8] {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo {
id: TypeId::TEXT,
@ -19,9 +21,9 @@ impl Type<[u8]> for MySql {
}
}
impl Type<Vec<u8>> for MySql {
impl Type<MySql> for Vec<u8> {
fn type_info() -> MySqlTypeInfo {
<Self as Type<[u8]>>::type_info()
<[u8] as Type<MySql>>::type_info()
}
}
@ -37,11 +39,36 @@ impl Encode<MySql> for Vec<u8> {
}
}
impl Decode<MySql> for Vec<u8> {
fn decode(mut buf: &[u8]) -> Result<Self, DecodeError> {
Ok(buf
.get_bytes_lenenc::<LittleEndian>()?
.unwrap_or_default()
.to_vec())
impl<'de> Decode<'de, MySql> for Vec<u8> {
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
match value.try_into()? {
MySqlValue::Binary(mut buf) => {
let len = buf
.get_uint_lenenc::<LittleEndian>()
.map_err(crate::Error::decode)?
.unwrap_or_default();
Ok((&buf[..(len as usize)]).to_vec())
}
MySqlValue::Text(s) => Ok(s.to_vec()),
}
}
}
impl<'de> Decode<'de, MySql> for &'de [u8] {
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
match value.try_into()? {
MySqlValue::Binary(mut buf) => {
let len = buf
.get_uint_lenenc::<LittleEndian>()
.map_err(crate::Error::decode)?
.unwrap_or_default();
Ok(&buf[..(len as usize)])
}
MySqlValue::Text(s) => Ok(s),
}
}
}

View File

@ -1,17 +1,19 @@
use std::convert::TryFrom;
use std::convert::{TryFrom, TryInto};
use byteorder::{ByteOrder, LittleEndian};
use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc};
use crate::decode::{Decode, DecodeError};
use crate::decode::Decode;
use crate::encode::Encode;
use crate::io::{Buf, BufMut};
use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql;
use crate::mysql::{MySql, MySqlValue};
use crate::types::Type;
use crate::Error;
use bitflags::_core::str::from_utf8;
impl Type<DateTime<Utc>> for MySql {
impl Type<MySql> for DateTime<Utc> {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::TIMESTAMP)
}
@ -23,15 +25,15 @@ impl Encode<MySql> for DateTime<Utc> {
}
}
impl Decode<MySql> for DateTime<Utc> {
fn decode(buf: &[u8]) -> Result<Self, DecodeError> {
let naive: NaiveDateTime = Decode::<MySql>::decode(buf)?;
impl<'de> Decode<'de, MySql> for DateTime<Utc> {
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
let naive: NaiveDateTime = Decode::<MySql>::decode(value)?;
Ok(DateTime::from_utc(naive, Utc))
}
}
impl Type<NaiveTime> for MySql {
impl Type<MySql> for NaiveTime {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::TIME)
}
@ -63,24 +65,33 @@ impl Encode<MySql> for NaiveTime {
}
}
impl Decode<MySql> for NaiveTime {
fn decode(mut buf: &[u8]) -> Result<Self, DecodeError> {
// data length, expecting 8 or 12 (fractional seconds)
let len = buf.get_u8()?;
impl<'de> Decode<'de, MySql> for NaiveTime {
fn decode(buf: Option<MySqlValue<'de>>) -> crate::Result<Self> {
match buf.try_into()? {
MySqlValue::Binary(mut buf) => {
// data length, expecting 8 or 12 (fractional seconds)
let len = buf.get_u8()?;
// is negative : int<1>
let is_negative = buf.get_u8()?;
assert_eq!(is_negative, 0, "Negative dates/times are not supported");
// is negative : int<1>
let is_negative = buf.get_u8()?;
assert_eq!(is_negative, 0, "Negative dates/times are not supported");
// "date on 4 bytes little-endian format" (?)
// https://mariadb.com/kb/en/resultset-row/#timestamp-binary-encoding
buf.advance(4);
// "date on 4 bytes little-endian format" (?)
// https://mariadb.com/kb/en/resultset-row/#timestamp-binary-encoding
buf.advance(4);
decode_time(len - 5, buf)
decode_time(len - 5, buf)
}
MySqlValue::Text(buf) => {
let s = from_utf8(buf).map_err(Error::decode)?;
NaiveTime::parse_from_str(s, "%H:%M:%S%.f").map_err(Error::decode)
}
}
}
}
impl Type<NaiveDate> for MySql {
impl Type<MySql> for NaiveDate {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::DATE)
}
@ -98,13 +109,20 @@ impl Encode<MySql> for NaiveDate {
}
}
impl Decode<MySql> for NaiveDate {
fn decode(buf: &[u8]) -> Result<Self, DecodeError> {
Ok(decode_date(&buf[1..]))
impl<'de> Decode<'de, MySql> for NaiveDate {
fn decode(buf: Option<MySqlValue<'de>>) -> crate::Result<Self> {
match buf.try_into()? {
MySqlValue::Binary(buf) => Ok(decode_date(&buf[1..])),
MySqlValue::Text(buf) => {
let s = from_utf8(buf).map_err(Error::decode)?;
NaiveDate::parse_from_str(s, "%Y-%m-%d").map_err(Error::decode)
}
}
}
}
impl Type<NaiveDateTime> for MySql {
impl Type<MySql> for NaiveDateTime {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::DATETIME)
}
@ -144,18 +162,27 @@ impl Encode<MySql> for NaiveDateTime {
}
}
impl Decode<MySql> for NaiveDateTime {
fn decode(buf: &[u8]) -> Result<Self, DecodeError> {
let len = buf[0];
let date = decode_date(&buf[1..]);
impl<'de> Decode<'de, MySql> for NaiveDateTime {
fn decode(buf: Option<MySqlValue<'de>>) -> crate::Result<Self> {
match buf.try_into()? {
MySqlValue::Binary(buf) => {
let len = buf[0];
let date = decode_date(&buf[1..]);
let dt = if len > 4 {
date.and_time(decode_time(len - 4, &buf[5..])?)
} else {
date.and_hms(0, 0, 0)
};
let dt = if len > 4 {
date.and_time(decode_time(len - 4, &buf[5..])?)
} else {
date.and_hms(0, 0, 0)
};
Ok(dt)
Ok(dt)
},
MySqlValue::Text(buf) => {
let s = from_utf8(buf).map_err(Error::decode)?;
NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f").map_err(Error::decode)
}
}
}
}
@ -187,7 +214,7 @@ fn encode_time(time: &NaiveTime, include_micros: bool, buf: &mut Vec<u8>) {
}
}
fn decode_time(len: u8, mut buf: &[u8]) -> Result<NaiveTime, DecodeError> {
fn decode_time(len: u8, mut buf: &[u8]) -> crate::Result<NaiveTime> {
let hour = buf.get_u8()?;
let minute = buf.get_u8()?;
let seconds = buf.get_u8()?;

View File

@ -1,9 +1,16 @@
use crate::decode::{Decode, DecodeError};
use std::convert::TryInto;
use byteorder::{LittleEndian, ReadBytesExt};
use crate::decode::Decode;
use crate::encode::Encode;
use crate::error::UnexpectedNullError;
use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql;
use crate::mysql::{MySql, MySqlValue};
use crate::types::Type;
use crate::Error;
use std::str::from_utf8;
/// The equivalent MySQL type for `f32` is `FLOAT`.
///
@ -18,7 +25,7 @@ use crate::types::Type;
/// // (This is expected behavior for floating points and happens both in Rust and in MySQL)
/// assert_ne!(10.2f32 as f64, 10.2f64);
/// ```
impl Type<f32> for MySql {
impl Type<MySql> for f32 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::FLOAT)
}
@ -30,9 +37,19 @@ impl Encode<MySql> for f32 {
}
}
impl Decode<MySql> for f32 {
fn decode(buf: &[u8]) -> Result<Self, DecodeError> {
Ok(f32::from_bits(<i32 as Decode<MySql>>::decode(buf)? as u32))
impl<'de> Decode<'de, MySql> for f32 {
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
match value.try_into()? {
MySqlValue::Binary(mut buf) => buf
.read_i32::<LittleEndian>()
.map_err(crate::Error::decode)
.map(|value| f32::from_bits(value as u32)),
MySqlValue::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
}
}
@ -40,7 +57,7 @@ impl Decode<MySql> for f32 {
///
/// Note that `DOUBLE` is a floating-point type and cannot represent some fractional values
/// exactly.
impl Type<f64> for MySql {
impl Type<MySql> for f64 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::DOUBLE)
}
@ -52,8 +69,18 @@ impl Encode<MySql> for f64 {
}
}
impl Decode<MySql> for f64 {
fn decode(buf: &[u8]) -> Result<Self, DecodeError> {
Ok(f64::from_bits(<i64 as Decode<MySql>>::decode(buf)? as u64))
impl<'de> Decode<'de, MySql> for f64 {
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
match value.try_into()? {
MySqlValue::Binary(mut buf) => buf
.read_i64::<LittleEndian>()
.map_err(crate::Error::decode)
.map(|value| f64::from_bits(value as u64)),
MySqlValue::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
}
}

View File

@ -1,14 +1,18 @@
use byteorder::LittleEndian;
use std::convert::TryInto;
use std::str::from_utf8;
use crate::decode::{Decode, DecodeError};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use crate::decode::Decode;
use crate::encode::Encode;
use crate::io::{Buf, BufMut};
use crate::error::UnexpectedNullError;
use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql;
use crate::mysql::{MySql, MySqlValue};
use crate::types::Type;
use crate::Error;
impl Type<i8> for MySql {
impl Type<MySql> for i8 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::TINY_INT)
}
@ -16,17 +20,24 @@ impl Type<i8> for MySql {
impl Encode<MySql> for i8 {
fn encode(&self, buf: &mut Vec<u8>) {
buf.push(*self as u8);
buf.write_i8(*self);
}
}
impl Decode<MySql> for i8 {
fn decode(buf: &[u8]) -> Result<Self, DecodeError> {
Ok(buf[0] as i8)
impl<'de> Decode<'de, MySql> for i8 {
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
match value.try_into()? {
MySqlValue::Binary(mut buf) => buf.read_i8().map_err(Into::into),
MySqlValue::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
}
}
impl Type<i16> for MySql {
impl Type<MySql> for i16 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::SMALL_INT)
}
@ -34,17 +45,24 @@ impl Type<i16> for MySql {
impl Encode<MySql> for i16 {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_i16::<LittleEndian>(*self);
buf.write_i16::<LittleEndian>(*self);
}
}
impl Decode<MySql> for i16 {
fn decode(mut buf: &[u8]) -> Result<Self, DecodeError> {
buf.get_i16::<LittleEndian>().map_err(Into::into)
impl<'de> Decode<'de, MySql> for i16 {
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
match value.try_into()? {
MySqlValue::Binary(mut buf) => buf.read_i16::<LittleEndian>().map_err(Into::into),
MySqlValue::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
}
}
impl Type<i32> for MySql {
impl Type<MySql> for i32 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::INT)
}
@ -52,17 +70,24 @@ impl Type<i32> for MySql {
impl Encode<MySql> for i32 {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_i32::<LittleEndian>(*self);
buf.write_i32::<LittleEndian>(*self);
}
}
impl Decode<MySql> for i32 {
fn decode(mut buf: &[u8]) -> Result<Self, DecodeError> {
buf.get_i32::<LittleEndian>().map_err(Into::into)
impl<'de> Decode<'de, MySql> for i32 {
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
match value.try_into()? {
MySqlValue::Binary(mut buf) => buf.read_i32::<LittleEndian>().map_err(Into::into),
MySqlValue::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
}
}
impl Type<i64> for MySql {
impl Type<MySql> for i64 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::BIG_INT)
}
@ -70,14 +95,19 @@ impl Type<i64> for MySql {
impl Encode<MySql> for i64 {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_u64::<LittleEndian>(*self as u64);
buf.write_i64::<LittleEndian>(*self);
}
}
impl Decode<MySql> for i64 {
fn decode(mut buf: &[u8]) -> Result<Self, DecodeError> {
buf.get_u64::<LittleEndian>()
.map_err(Into::into)
.map(|val| val as i64)
impl<'de> Decode<'de, MySql> for i64 {
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
match value.try_into()? {
MySqlValue::Binary(mut buf) => buf.read_i64::<LittleEndian>().map_err(Into::into),
MySqlValue::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
}
}

View File

@ -10,8 +10,10 @@ mod chrono;
use std::fmt::{self, Debug, Display};
use crate::decode::Decode;
use crate::mysql::protocol::TypeId;
use crate::mysql::protocol::{ColumnDefinition, FieldFlags};
use crate::mysql::{MySql, MySqlValue};
use crate::types::TypeInfo;
#[derive(Clone, Debug, Default)]
@ -103,3 +105,14 @@ impl TypeInfo for MySqlTypeInfo {
}
}
}
impl<'de, T> Decode<'de, MySql> for Option<T>
where
T: Decode<'de, MySql>,
{
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
value
.map(|value| <T as Decode<MySql>>::decode(Some(value)))
.transpose()
}
}

View File

@ -2,15 +2,18 @@ use std::str;
use byteorder::LittleEndian;
use crate::decode::{Decode, DecodeError};
use crate::decode::Decode;
use crate::encode::Encode;
use crate::error::UnexpectedNullError;
use crate::mysql::io::{BufExt, BufMutExt};
use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql;
use crate::mysql::{MySql, MySqlValue};
use crate::types::Type;
use std::convert::TryInto;
use std::str::from_utf8;
impl Type<str> for MySql {
impl Type<MySql> for str {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo {
id: TypeId::TEXT,
@ -27,10 +30,9 @@ impl Encode<MySql> for str {
}
}
// TODO: Do we need the [HasSqlType] for String
impl Type<String> for MySql {
impl Type<MySql> for String {
fn type_info() -> MySqlTypeInfo {
<Self as Type<&str>>::type_info()
<str as Type<MySql>>::type_info()
}
}
@ -40,11 +42,25 @@ impl Encode<MySql> for String {
}
}
impl Decode<MySql> for String {
fn decode(mut buf: &[u8]) -> Result<Self, DecodeError> {
Ok(buf
.get_str_lenenc::<LittleEndian>()?
.unwrap_or_default()
.to_owned())
impl<'de> Decode<'de, MySql> for &'de str {
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
match value.try_into()? {
MySqlValue::Binary(mut buf) => {
let len = buf
.get_uint_lenenc::<LittleEndian>()
.map_err(crate::Error::decode)?
.unwrap_or_default();
from_utf8(&buf[..(len as usize)]).map_err(crate::Error::decode)
}
MySqlValue::Text(s) => from_utf8(s).map_err(crate::Error::decode),
}
}
}
impl<'de> Decode<'de, MySql> for String {
fn decode(buf: Option<MySqlValue<'de>>) -> crate::Result<Self> {
<&'de str>::decode(buf).map(ToOwned::to_owned)
}
}

View File

@ -1,14 +1,18 @@
use byteorder::LittleEndian;
use std::convert::TryInto;
use std::str::from_utf8;
use crate::decode::{Decode, DecodeError};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use crate::decode::Decode;
use crate::encode::Encode;
use crate::io::{Buf, BufMut};
use crate::error::UnexpectedNullError;
use crate::mysql::protocol::TypeId;
use crate::mysql::types::MySqlTypeInfo;
use crate::mysql::MySql;
use crate::mysql::{MySql, MySqlValue};
use crate::types::Type;
use crate::Error;
impl Type<u8> for MySql {
impl Type<MySql> for u8 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::unsigned(TypeId::TINY_INT)
}
@ -16,17 +20,24 @@ impl Type<u8> for MySql {
impl Encode<MySql> for u8 {
fn encode(&self, buf: &mut Vec<u8>) {
buf.push(*self);
buf.write_u8(*self);
}
}
impl Decode<MySql> for u8 {
fn decode(buf: &[u8]) -> Result<Self, DecodeError> {
Ok(buf[0])
impl<'de> Decode<'de, MySql> for u8 {
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
match value.try_into()? {
MySqlValue::Binary(mut buf) => buf.read_u8().map_err(Into::into),
MySqlValue::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
}
}
impl Type<u16> for MySql {
impl Type<MySql> for u16 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::unsigned(TypeId::SMALL_INT)
}
@ -34,17 +45,24 @@ impl Type<u16> for MySql {
impl Encode<MySql> for u16 {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_u16::<LittleEndian>(*self);
buf.write_u16::<LittleEndian>(*self);
}
}
impl Decode<MySql> for u16 {
fn decode(mut buf: &[u8]) -> Result<Self, DecodeError> {
buf.get_u16::<LittleEndian>().map_err(Into::into)
impl<'de> Decode<'de, MySql> for u16 {
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
match value.try_into()? {
MySqlValue::Binary(mut buf) => buf.read_u16::<LittleEndian>().map_err(Into::into),
MySqlValue::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
}
}
impl Type<u32> for MySql {
impl Type<MySql> for u32 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::unsigned(TypeId::INT)
}
@ -52,17 +70,24 @@ impl Type<u32> for MySql {
impl Encode<MySql> for u32 {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_u32::<LittleEndian>(*self);
buf.write_u32::<LittleEndian>(*self);
}
}
impl Decode<MySql> for u32 {
fn decode(mut buf: &[u8]) -> Result<Self, DecodeError> {
buf.get_u32::<LittleEndian>().map_err(Into::into)
impl<'de> Decode<'de, MySql> for u32 {
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
match value.try_into()? {
MySqlValue::Binary(mut buf) => buf.read_u32::<LittleEndian>().map_err(Into::into),
MySqlValue::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
}
}
impl Type<u64> for MySql {
impl Type<MySql> for u64 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::unsigned(TypeId::BIG_INT)
}
@ -70,12 +95,19 @@ impl Type<u64> for MySql {
impl Encode<MySql> for u64 {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_u64::<LittleEndian>(*self);
buf.write_u64::<LittleEndian>(*self);
}
}
impl Decode<MySql> for u64 {
fn decode(mut buf: &[u8]) -> Result<Self, DecodeError> {
buf.get_u64::<LittleEndian>().map_err(Into::into)
impl<'de> Decode<'de, MySql> for u64 {
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
match value.try_into()? {
MySqlValue::Binary(mut buf) => buf.read_u64::<LittleEndian>().map_err(Into::into),
MySqlValue::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
}
}

View File

@ -28,3 +28,4 @@ make_query_as!(PgQueryAs, Postgres, PgRow);
impl_map_row_for_row!(Postgres, PgRow);
impl_column_index_for_row!(Postgres);
impl_from_row_for_tuples!(Postgres, PgRow);
impl_execute_for_query!(Postgres);

View File

@ -190,7 +190,7 @@ macro_rules! impl_column_index_for_row {
row.columns
.get(self)
.ok_or_else(|| crate::Error::ColumnNotFound((*self).into()))
.map(|&index| index)
.map(|&index| index as usize)
}
}
};

View File

@ -78,6 +78,25 @@ pub(crate) fn expand_derive_decode(input: DeriveInput) -> syn::Result<proc_macro
));
}
if cfg!(feature = "mysql") {
let mut generics = generics.clone();
generics.params.insert(0, parse_quote!('de));
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::decode::Decode<'de, sqlx::MySql>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
impls.push(quote!(
impl #impl_generics sqlx::decode::Decode<'de, sqlx::MySql> for #ident #ty_generics #where_clause {
fn decode(value: <sqlx::MySql as sqlx::HasRawValue<'de>>::RawValue) -> sqlx::Result<Self> {
<#ty as sqlx::decode::Decode<'de, sqlx::MySql>>::decode(value).map(Self)
}
}
));
}
// panic!("{}", q)
Ok(quote!(#(#impls)*))
}

View File

@ -65,14 +65,15 @@ macro_rules! test_prepared_type {
let mut conn = sqlx_test::new::<$db>().await?;
$(
let query = format!("SELECT {} = $1, $1 as _1", $text);
let query = format!($crate::[< $db _query_for_test_prepared_type >]!(), $text);
let rec: (bool, $ty) = sqlx::query_as(&query)
.bind($value)
.bind($value)
.fetch_one(&mut conn)
.await?;
assert!(rec.0);
assert!(rec.0, "value returned from server: {:?}", rec.1);
assert!($value == rec.1);
)+
@ -81,3 +82,17 @@ macro_rules! test_prepared_type {
}
}
}
#[macro_export]
macro_rules! MySql_query_for_test_prepared_type {
() => {
"SELECT {} <=> ?, ? as _1"
};
}
#[macro_export]
macro_rules! Postgres_query_for_test_prepared_type {
() => {
"SELECT {} is not distinct form $1, $2 as _1"
};
}

View File

@ -70,4 +70,7 @@ pub mod prelude {
#[cfg(feature = "postgres")]
pub use super::postgres::PgQueryAs;
#[cfg(feature = "mysql")]
pub use super::mysql::MySqlQueryAs;
}

View File

@ -58,3 +58,17 @@ where
let decoded = Foo::decode(Some(sqlx::postgres::PgValue::Binary(&encoded))).unwrap();
assert_eq!(example, decoded);
}
#[cfg(feature = "mysql")]
fn decode_with_db()
where
Foo: for<'de> Decode<'de, sqlx::MySql> + Encode<sqlx::MySql>,
{
let example = Foo(0x1122_3344);
let mut encoded = Vec::new();
Encode::<sqlx::MySql>::encode(&example, &mut encoded);
let decoded = Foo::decode(Some(sqlx::mysql::MySqlValue::Binary(&encoded))).unwrap();
assert_eq!(example, decoded);
}

View File

@ -1,9 +1,10 @@
use sqlx::MySqlConnection;
use sqlx::MySql;
use sqlx_test::new;
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn macro_select_from_cte() -> anyhow::Result<()> {
let mut conn = connect().await?;
let mut conn = new::<MySql>().await?;
let account =
sqlx::query!("select * from (select (1) as id, 'Herp Derpinson' as name) accounts")
.fetch_one(&mut conn)
@ -18,7 +19,7 @@ async fn macro_select_from_cte() -> anyhow::Result<()> {
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn macro_select_from_cte_bind() -> anyhow::Result<()> {
let mut conn = connect().await?;
let mut conn = new::<MySql>().await?;
let account = sqlx::query!(
"select * from (select (1) as id, 'Herp Derpinson' as name) accounts where id = ?",
1i32
@ -41,7 +42,7 @@ struct RawAccount {
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn test_query_as_raw() -> anyhow::Result<()> {
let mut conn = connect().await?;
let mut conn = new::<MySql>().await?;
let account = sqlx::query_as!(
RawAccount,
@ -57,11 +58,3 @@ async fn test_query_as_raw() -> anyhow::Result<()> {
Ok(())
}
fn url() -> anyhow::Result<String> {
Ok(dotenv::var("DATABASE_URL")?)
}
async fn connect() -> anyhow::Result<MySqlConnection> {
Ok(MySqlConnection::open(url()?).await?)
}

56
tests/mysql-raw.rs Normal file
View File

@ -0,0 +1,56 @@
//! Tests for the raw (unprepared) query API for MySql.
use sqlx::{Cursor, Executor, MySql, Row};
use sqlx_test::new;
/// Test a simple select expression. This should return the row.
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn test_select_expression() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;
let mut cursor = conn.fetch("SELECT 5");
let row = cursor.next().await?.unwrap();
assert!(5i32 == row.get::<i32, _>(0)?);
Ok(())
}
/// Test that we can interleave reads and writes to the database
/// in one simple query. Using the `Cursor` API we should be
/// able to fetch from both queries in sequence.
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn test_multi_read_write() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;
let mut cursor = conn.fetch(
"
CREATE TEMPORARY TABLE messages (
id BIGINT PRIMARY KEY AUTO_INCREMENT,
text TEXT NOT NULL
);
SELECT 'Hello World' as _1;
INSERT INTO messages (text) VALUES ('this is a test');
SELECT id, text FROM messages;
",
);
let row = cursor.next().await?.unwrap();
assert!("Hello World" == row.get::<&str, _>("_1")?);
let row = cursor.next().await?.unwrap();
let id: i64 = row.get("id")?;
let text: &str = row.get("text")?;
assert_eq!(1_i64, id);
assert_eq!("this is a test", text);
Ok(())
}

View File

@ -1,87 +0,0 @@
use sqlx::types::chrono::{DateTime, NaiveDate, NaiveTime, Utc};
use sqlx::{mysql::MySqlConnection, Connection, Row};
async fn connect() -> anyhow::Result<MySqlConnection> {
Ok(MySqlConnection::open(dotenv::var("DATABASE_URL")?).await?)
}
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn mysql_chrono_date() -> anyhow::Result<()> {
let mut conn = connect().await?;
let value = NaiveDate::from_ymd(2019, 1, 2);
let row = sqlx::query!(
"SELECT (DATE '2019-01-02' = ?) as _1, CAST(? AS DATE) as _2",
value,
value
)
.fetch_one(&mut conn)
.await?;
assert!(row._1 != 0);
assert_eq!(value, row._2);
Ok(())
}
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn mysql_chrono_date_time() -> anyhow::Result<()> {
let mut conn = connect().await?;
let value = NaiveDate::from_ymd(2019, 1, 2).and_hms(5, 10, 20);
let row = sqlx::query("SELECT '2019-01-02 05:10:20' = ?, ?")
.bind(&value)
.bind(&value)
.fetch_one(&mut conn)
.await?;
assert!(row.get::<bool, _>(0));
assert_eq!(value, row.get(1));
Ok(())
}
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn mysql_chrono_time() -> anyhow::Result<()> {
let mut conn = connect().await?;
let value = NaiveTime::from_hms_micro(5, 10, 20, 115100);
let row = sqlx::query("SELECT TIME '05:10:20.115100' = ?, TIME '05:10:20.115100'")
.bind(&value)
.fetch_one(&mut conn)
.await?;
assert!(row.get::<bool, _>(0));
assert_eq!(value, row.get(1));
Ok(())
}
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn mysql_chrono_timestamp() -> anyhow::Result<()> {
let mut conn = connect().await?;
let value = DateTime::<Utc>::from_utc(
NaiveDate::from_ymd(2019, 1, 2).and_hms_micro(5, 10, 20, 115100),
Utc,
);
let row = sqlx::query(
"SELECT TIMESTAMP '2019-01-02 05:10:20.115100' = ?, TIMESTAMP '2019-01-02 05:10:20.115100'",
)
.bind(&value)
.fetch_one(&mut conn)
.await?;
assert!(row.get::<bool, _>(0));
assert_eq!(value, row.get(1));
Ok(())
}

View File

@ -1,94 +1,86 @@
use sqlx::{mysql::MySqlConnection, Connection, Row};
use sqlx::MySql;
use sqlx_test::test_type;
async fn connect() -> anyhow::Result<MySqlConnection> {
Ok(MySqlConnection::open(dotenv::var("DATABASE_URL")?).await?)
}
macro_rules! test {
($name:ident: $ty:ty: $($text:literal == $value:expr),+) => {
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn $name () -> anyhow::Result<()> {
let mut conn = connect().await?;
$(
let row = sqlx::query(&format!("SELECT {} = ?, ? as _1", $text))
.bind($value)
.bind($value)
.fetch_one(&mut conn)
.await?;
let value = row.get::<$ty, _>("_1");
assert_eq!(row.get::<i32, _>(0), 1, "value returned from server: {:?}", value);
assert_eq!($value, value);
)+
Ok(())
}
}
}
test!(mysql_bool: bool: "false" == false, "true" == true);
test!(mysql_tiny_unsigned: u8: "253" == 253_u8);
test!(mysql_tiny: i8: "5" == 5_i8);
test!(mysql_medium_unsigned: u16: "21415" == 21415_u16);
test!(mysql_short: i16: "21415" == 21415_i16);
test!(mysql_long_unsigned: u32: "2141512" == 2141512_u32);
test!(mysql_long: i32: "2141512" == 2141512_i32);
test!(mysql_longlong_unsigned: u64: "2141512" == 2141512_u64);
test!(mysql_longlong: i64: "2141512" == 2141512_i64);
// `DOUBLE` can be compared with decimal literals just fine but the same can't be said for `FLOAT`
test!(mysql_double: f64: "3.14159265" == 3.14159265f64);
test!(mysql_string: String: "'helloworld'" == "helloworld");
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn mysql_bytes() -> anyhow::Result<()> {
let mut conn = connect().await?;
let value = &b"Hello, World"[..];
let rec = sqlx::query!(
"SELECT (X'48656c6c6f2c20576f726c64' = ?) as _1, CAST(? as BINARY) as _2",
value,
value
)
.fetch_one(&mut conn)
.await?;
assert!(rec._1 != 0);
let output: Vec<u8> = rec._2;
assert_eq!(&value[..], &*output);
Ok(())
}
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn mysql_float() -> anyhow::Result<()> {
let mut conn = connect().await?;
let value = 10.2f32;
let row = sqlx::query("SELECT ? as _1")
.bind(value)
.fetch_one(&mut conn)
.await?;
// comparison between FLOAT and literal doesn't work as expected
// we get implicit widening to DOUBLE which gives a slightly different value
// however, round-trip does work as expected
let ret = row.get::<f32, _>("_1");
assert_eq!(value, ret);
Ok(())
test_type!(null(
MySql,
Option<i16>,
"NULL" == None::<i16>
));
test_type!(bool(MySql, bool, "false" == false, "true" == true));
test_type!(u8(MySql, u8, "253" == 253_u8));
test_type!(i8(MySql, i8, "5" == 5_i8, "0" == 0_i8));
test_type!(u16(MySql, u16, "21415" == 21415_u16));
test_type!(i16(MySql, i16, "21415" == 21415_i16));
test_type!(u32(MySql, u32, "2141512" == 2141512_u32));
test_type!(i32(MySql, i32, "2141512" == 2141512_i32));
test_type!(u64(MySql, u64, "2141512" == 2141512_u64));
test_type!(i64(MySql, i64, "2141512" == 2141512_i64));
test_type!(double(MySql, f64, "3.14159265" == 3.14159265f64));
// NOTE: This behavior can be very surprising. MySQL implicitly widens FLOAT bind parameters
// to DOUBLE. This results in the weirdness you see below. MySQL generally recommends to stay
// away from FLOATs.
test_type!(float(
MySql,
f32,
"3.1410000324249268" == 3.141f32 as f64 as f32
));
test_type!(string(
MySql,
String,
"'helloworld'" == "helloworld",
"''" == ""
));
test_type!(bytes(
MySql,
Vec<u8>,
"X'DEADBEEF'"
== vec![0xDE_u8, 0xAD, 0xBE, 0xEF],
"X''"
== Vec::<u8>::new(),
"X'0000000052'"
== vec![0_u8, 0, 0, 0, 0x52]
));
#[cfg(feature = "chrono")]
mod chrono {
use super::*;
use sqlx::types::chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
test_type!(chrono_date(
MySql,
NaiveDate,
"DATE '2001-01-05'" == NaiveDate::from_ymd(2001, 1, 5),
"DATE '2050-11-23'" == NaiveDate::from_ymd(2050, 11, 23)
));
test_type!(chrono_time(
MySql,
NaiveTime,
"TIME '05:10:20.115100'" == NaiveTime::from_hms_micro(5, 10, 20, 115100)
));
test_type!(chrono_date_time(
MySql,
NaiveDateTime,
"'2019-01-02 05:10:20'" == NaiveDate::from_ymd(2019, 1, 2).and_hms(5, 10, 20)
));
test_type!(chrono_date_time_tz(
MySql,
DateTime::<Utc>,
"TIMESTAMP '2019-01-02 05:10:20.115100'"
== DateTime::<Utc>::from_utc(
NaiveDate::from_ymd(2019, 1, 2).and_hms_micro(5, 10, 20, 115100),
Utc,
)
));
}

View File

@ -1,17 +1,24 @@
use futures::TryStreamExt;
use sqlx::{Connection as _, Executor as _, MySqlConnection, MySqlPool, Row as _};
use sqlx::{mysql::MySqlQueryAs, Connection, Executor, MySql, MySqlPool};
use sqlx_test::new;
use std::time::Duration;
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn it_connects() -> anyhow::Result<()> {
let mut conn = connect().await?;
Ok(new::<MySql>().await?.ping().await?)
}
let row = sqlx::query("select 1 + 1").fetch_one(&mut conn).await?;
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn it_drops_results_in_affected_rows() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;
assert_eq!(2, row.get(0));
// ~1800 rows should be iterated and dropped
let affected = conn.execute("select * from mysql.time_zone").await?;
conn.close().await?;
// In MySQL, rows being returned isn't enough to flag it as an _affected_ row
assert_eq!(0, affected);
Ok(())
}
@ -19,10 +26,10 @@ async fn it_connects() -> anyhow::Result<()> {
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn it_executes() -> anyhow::Result<()> {
let mut conn = connect().await?;
let mut conn = new::<MySql>().await?;
let _ = conn
.send(
.execute(
r#"
CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY)
"#,
@ -38,12 +45,9 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY)
assert_eq!(cnt, 1);
}
let sum: i32 = sqlx::query("SELECT id FROM users")
let sum: i32 = sqlx::query_as("SELECT id FROM users")
.fetch(&mut conn)
.try_fold(
0_i32,
|acc, x| async move { Ok(acc + x.get::<i32, _>("id")) },
)
.try_fold(0_i32, |acc, (x,): (i32,)| async move { Ok(acc + x) })
.await?;
assert_eq!(sum, 55);
@ -54,11 +58,9 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY)
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn it_selects_null() -> anyhow::Result<()> {
let mut conn = connect().await?;
let mut conn = new::<MySql>().await?;
let row = sqlx::query("SELECT NULL").fetch_one(&mut conn).await?;
let val: Option<i32> = row.get(0);
let (val,): (Option<i32>,) = sqlx::query_as("SELECT NULL").fetch_one(&mut conn).await?;
assert!(val.is_none());
@ -68,12 +70,10 @@ async fn it_selects_null() -> anyhow::Result<()> {
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn test_describe() -> anyhow::Result<()> {
use sqlx::describe::Nullability::*;
let mut conn = connect().await?;
let mut conn = new::<MySql>().await?;
let _ = conn
.send(
.execute(
r#"
CREATE TEMPORARY TABLE describe_test (
id int primary key auto_increment,
@ -88,13 +88,13 @@ async fn test_describe() -> anyhow::Result<()> {
.describe("select nt.*, false from describe_test nt")
.await?;
assert_eq!(describe.result_columns[0].nullability, NonNull);
assert_eq!(describe.result_columns[0].non_null, Some(true));
assert_eq!(describe.result_columns[0].type_info.type_name(), "INT");
assert_eq!(describe.result_columns[1].nullability, NonNull);
assert_eq!(describe.result_columns[1].non_null, Some(true));
assert_eq!(describe.result_columns[1].type_info.type_name(), "TEXT");
assert_eq!(describe.result_columns[2].nullability, Nullable);
assert_eq!(describe.result_columns[2].non_null, Some(false));
assert_eq!(describe.result_columns[2].type_info.type_name(), "TEXT");
assert_eq!(describe.result_columns[3].nullability, NonNull);
assert_eq!(describe.result_columns[3].non_null, Some(true));
let bool_ty_name = describe.result_columns[3].type_info.type_name();
@ -112,7 +112,7 @@ async fn test_describe() -> anyhow::Result<()> {
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn pool_immediately_fails_with_db_error() -> anyhow::Result<()> {
// Malform the database url by changing the password
let url = url()?.replace("password", "not-the-password");
let url = dotenv::var("DATABASE_URL")?.replace("password", "not-the-password");
let pool = MySqlPool::new(&url).await?;
@ -152,7 +152,7 @@ async fn pool_smoke_test() -> anyhow::Result<()> {
let pool = pool.clone();
spawn(async move {
loop {
if let Err(e) = sqlx::query("select 1 + 1").fetch_one(&mut &pool).await {
if let Err(e) = sqlx::query("select 1 + 1").execute(&mut &pool).await {
eprintln!("pool task {} dying due to {}", i, e);
break;
}
@ -185,11 +185,3 @@ async fn pool_smoke_test() -> anyhow::Result<()> {
Ok(())
}
fn url() -> anyhow::Result<String> {
Ok(dotenv::var("DATABASE_URL")?)
}
async fn connect() -> anyhow::Result<MySqlConnection> {
Ok(MySqlConnection::open(url()?).await?)
}