mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-10-02 15:25:32 +00:00
Add zero-allocation to MySQL query execution
WIP mysql compiles with types and executor commented out
This commit is contained in:
parent
de14a206ff
commit
c9df8acc41
@ -74,6 +74,10 @@ required-features = [ "mysql", "macros" ]
|
||||
name = "mysql"
|
||||
required-features = [ "mysql" ]
|
||||
|
||||
[[test]]
|
||||
name = "mysql-raw"
|
||||
required-features = [ "mysql" ]
|
||||
|
||||
[[test]]
|
||||
name = "postgres"
|
||||
required-features = [ "postgres" ]
|
||||
|
@ -7,6 +7,8 @@ pub trait Buf {
|
||||
|
||||
fn get_uint<T: ByteOrder>(&mut self, n: usize) -> io::Result<u64>;
|
||||
|
||||
fn get_i8(&mut self) -> io::Result<i8>;
|
||||
|
||||
fn get_u8(&mut self) -> io::Result<u8>;
|
||||
|
||||
fn get_u16<T: ByteOrder>(&mut self) -> io::Result<u16>;
|
||||
@ -17,6 +19,8 @@ pub trait Buf {
|
||||
|
||||
fn get_i32<T: ByteOrder>(&mut self) -> io::Result<i32>;
|
||||
|
||||
fn get_i64<T: ByteOrder>(&mut self) -> io::Result<i64>;
|
||||
|
||||
fn get_u32<T: ByteOrder>(&mut self) -> io::Result<u32>;
|
||||
|
||||
fn get_u64<T: ByteOrder>(&mut self) -> io::Result<u64>;
|
||||
@ -40,6 +44,13 @@ impl<'a> Buf for &'a [u8] {
|
||||
Ok(val)
|
||||
}
|
||||
|
||||
fn get_i8(&mut self) -> io::Result<i8> {
|
||||
let val = self[0];
|
||||
self.advance(1);
|
||||
|
||||
Ok(val as i8)
|
||||
}
|
||||
|
||||
fn get_u8(&mut self) -> io::Result<u8> {
|
||||
let val = self[0];
|
||||
self.advance(1);
|
||||
@ -75,6 +86,13 @@ impl<'a> Buf for &'a [u8] {
|
||||
Ok(val)
|
||||
}
|
||||
|
||||
fn get_i64<T: ByteOrder>(&mut self) -> io::Result<i64> {
|
||||
let val = T::read_i64(*self);
|
||||
self.advance(4);
|
||||
|
||||
Ok(val)
|
||||
}
|
||||
|
||||
fn get_u32<T: ByteOrder>(&mut self) -> io::Result<u32> {
|
||||
let val = T::read_u32(*self);
|
||||
self.advance(4);
|
||||
|
@ -1,7 +1,9 @@
|
||||
//! Core of SQLx, the rust SQL toolkit. Not intended to be used directly.
|
||||
|
||||
#![forbid(unsafe_code)]
|
||||
#![recursion_limit = "512"]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
#![allow(unused)]
|
||||
|
||||
#[macro_use]
|
||||
pub mod error;
|
||||
|
@ -27,10 +27,10 @@ impl Arguments for MySqlArguments {
|
||||
|
||||
fn add<T>(&mut self, value: T)
|
||||
where
|
||||
Self::Database: Type<T>,
|
||||
T: Type<Self::Database>,
|
||||
T: Encode<Self::Database>,
|
||||
{
|
||||
let type_id = <MySql as Type<T>>::type_info();
|
||||
let type_id = <T as Type<MySql>>::type_info();
|
||||
let index = self.param_types.len();
|
||||
|
||||
self.param_types.push(type_id);
|
||||
|
@ -1,22 +1,25 @@
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryInto;
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
use futures_core::future::BoxFuture;
|
||||
use sha1::Sha1;
|
||||
use std::net::Shutdown;
|
||||
|
||||
use crate::cache::StatementCache;
|
||||
use crate::connection::{Connect, Connection};
|
||||
use crate::io::{Buf, BufMut, BufStream, MaybeTlsStream};
|
||||
use crate::mysql::error::MySqlError;
|
||||
use crate::mysql::protocol::{
|
||||
AuthPlugin, AuthSwitch, Capabilities, Decode, Encode, EofPacket, ErrPacket, Handshake,
|
||||
AuthPlugin, AuthSwitch, Capabilities, ComPing, Decode, Encode, EofPacket, ErrPacket, Handshake,
|
||||
HandshakeResponse, OkPacket, SslRequest,
|
||||
};
|
||||
use crate::mysql::rsa;
|
||||
use crate::mysql::stream::MySqlStream;
|
||||
use crate::mysql::util::xor_eq;
|
||||
use crate::mysql::{rsa, tls};
|
||||
use crate::url::Url;
|
||||
use std::ops::Range;
|
||||
|
||||
// Size before a packet is split
|
||||
const MAX_PACKET_SIZE: u32 = 1024;
|
||||
@ -85,521 +88,206 @@ const COLLATE_UTF8MB4_UNICODE_CI: u8 = 224;
|
||||
/// against the hostname in the server certificate, so they must be the same for the TLS
|
||||
/// upgrade to succeed. `ssl-ca` must still be specified.
|
||||
pub struct MySqlConnection {
|
||||
pub(super) stream: BufStream<MaybeTlsStream>,
|
||||
pub(super) stream: MySqlStream,
|
||||
pub(super) is_ready: bool,
|
||||
pub(super) cache_statement: HashMap<Box<str>, u32>,
|
||||
|
||||
// Active capabilities of the client _&_ the server
|
||||
pub(super) capabilities: Capabilities,
|
||||
|
||||
// Cache of prepared statements
|
||||
// Query (String) to StatementId to ColumnMap
|
||||
pub(super) statement_cache: StatementCache<u32>,
|
||||
|
||||
// Packets are buffered into a second buffer from the stream
|
||||
// as we may have compressed or split packets to figure out before
|
||||
// decoding
|
||||
pub(super) packet: Vec<u8>,
|
||||
packet_len: usize,
|
||||
|
||||
// Packets in a command sequence have an incrementing sequence number
|
||||
// This number must be 0 at the start of each command
|
||||
pub(super) next_seq_no: u8,
|
||||
// Work buffer for the value ranges of the current row
|
||||
// This is used as the backing memory for each Row's value indexes
|
||||
pub(super) current_row_values: Vec<Option<Range<usize>>>,
|
||||
}
|
||||
|
||||
impl MySqlConnection {
|
||||
/// Write the packet to the stream ( do not send to the server )
|
||||
pub(crate) fn write(&mut self, packet: impl Encode) {
|
||||
let buf = self.stream.buffer_mut();
|
||||
fn to_asciz(s: &str) -> Vec<u8> {
|
||||
let mut z = String::with_capacity(s.len() + 1);
|
||||
z.push_str(s);
|
||||
z.push('\0');
|
||||
|
||||
// Allocate room for the header that we write after the packet;
|
||||
// so, we can get an accurate and cheap measure of packet length
|
||||
z.into_bytes()
|
||||
}
|
||||
|
||||
let header_offset = buf.len();
|
||||
buf.advance(4);
|
||||
async fn rsa_encrypt_with_nonce(
|
||||
stream: &mut MySqlStream,
|
||||
public_key_request_id: u8,
|
||||
password: &str,
|
||||
nonce: &[u8],
|
||||
) -> crate::Result<Vec<u8>> {
|
||||
// https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/
|
||||
|
||||
packet.encode(buf, self.capabilities);
|
||||
|
||||
// Determine length of encoded packet
|
||||
// and write to allocated header
|
||||
|
||||
let len = buf.len() - header_offset - 4;
|
||||
let mut header = &mut buf[header_offset..];
|
||||
|
||||
LittleEndian::write_u32(&mut header, len as u32); // len
|
||||
|
||||
// Take the last sequence number received, if any, and increment by 1
|
||||
// If there was no sequence number, we only increment if we split packets
|
||||
header[3] = self.next_seq_no;
|
||||
self.next_seq_no = self.next_seq_no.wrapping_add(1);
|
||||
if stream.is_tls() {
|
||||
// If in a TLS stream, send the password directly in clear text
|
||||
return Ok(to_asciz(password));
|
||||
}
|
||||
|
||||
/// Send the packet to the database server
|
||||
pub(crate) async fn send(&mut self, packet: impl Encode) -> crate::Result<()> {
|
||||
self.write(packet);
|
||||
self.stream.flush().await?;
|
||||
// client sends a public key request
|
||||
stream.send(&[public_key_request_id][..], false).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
// server sends a public key response
|
||||
let packet = stream.receive().await?;
|
||||
let rsa_pub_key = &packet[1..];
|
||||
|
||||
/// Send a [HandshakeResponse] packet to the database server
|
||||
pub(crate) async fn send_handshake_response(
|
||||
&mut self,
|
||||
url: &Url,
|
||||
auth_plugin: &AuthPlugin,
|
||||
auth_response: &[u8],
|
||||
) -> crate::Result<()> {
|
||||
self.send(HandshakeResponse {
|
||||
client_collation: COLLATE_UTF8MB4_UNICODE_CI,
|
||||
max_packet_size: MAX_PACKET_SIZE,
|
||||
username: url.username().unwrap_or("root"),
|
||||
database: url.database(),
|
||||
auth_plugin,
|
||||
auth_response,
|
||||
})
|
||||
.await
|
||||
}
|
||||
// xor the password with the given nonce
|
||||
let mut pass = to_asciz(password);
|
||||
xor_eq(&mut pass, nonce);
|
||||
|
||||
/// Try to receive a packet from the database server. Returns `None` if the server has sent
|
||||
/// no data.
|
||||
pub(crate) async fn try_receive(&mut self) -> crate::Result<Option<()>> {
|
||||
self.packet.clear();
|
||||
// client sends an RSA encrypted password
|
||||
rsa::encrypt::<Sha1>(rsa_pub_key, &pass)
|
||||
}
|
||||
|
||||
// Read the packet header which contains the length and the sequence number
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html
|
||||
// https://mariadb.com/kb/en/library/0-packet/#standard-packet
|
||||
let mut header = ret_if_none!(self.stream.peek(4).await?);
|
||||
self.packet_len = header.get_uint::<LittleEndian>(3)? as usize;
|
||||
self.next_seq_no = header.get_u8()?.wrapping_add(1);
|
||||
self.stream.consume(4);
|
||||
|
||||
// Read the packet body and copy it into our internal buf
|
||||
// We must have a separate buffer around the stream as we can't operate directly
|
||||
// on bytes returned from the stream. We have various kinds of payload manipulation
|
||||
// that must be handled before decoding.
|
||||
let payload = ret_if_none!(self.stream.peek(self.packet_len).await?);
|
||||
self.packet.extend_from_slice(payload);
|
||||
self.stream.consume(self.packet_len);
|
||||
|
||||
// TODO: Implement packet compression
|
||||
// TODO: Implement packet joining
|
||||
|
||||
Ok(Some(()))
|
||||
}
|
||||
|
||||
/// Receive a complete packet from the database server.
|
||||
pub(crate) async fn receive(&mut self) -> crate::Result<&mut Self> {
|
||||
self.try_receive()
|
||||
.await?
|
||||
.ok_or(io::ErrorKind::ConnectionAborted)?;
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Returns a reference to the most recently received packet data
|
||||
#[inline]
|
||||
pub(crate) fn packet(&self) -> &[u8] {
|
||||
&self.packet[..self.packet_len]
|
||||
}
|
||||
|
||||
/// Receive an [EofPacket] if we are supposed to receive them at all.
|
||||
pub(crate) async fn receive_eof(&mut self) -> crate::Result<()> {
|
||||
// When (legacy) EOFs are enabled, many things are terminated by an EOF packet
|
||||
if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
|
||||
let _eof = EofPacket::decode(self.receive().await?.packet())?;
|
||||
async fn make_auth_response(
|
||||
stream: &mut MySqlStream,
|
||||
plugin: &AuthPlugin,
|
||||
password: &str,
|
||||
nonce: &[u8],
|
||||
) -> crate::Result<Vec<u8>> {
|
||||
match plugin {
|
||||
AuthPlugin::CachingSha2Password | AuthPlugin::MySqlNativePassword => {
|
||||
Ok(plugin.scramble(password, nonce))
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Receive a [Handshake] packet. When connecting to the database server, this is immediately
|
||||
/// received from the database server.
|
||||
pub(crate) async fn receive_handshake(&mut self, url: &Url) -> crate::Result<Handshake> {
|
||||
let handshake = Handshake::decode(self.receive().await?.packet())?;
|
||||
|
||||
let mut client_capabilities = Capabilities::PROTOCOL_41
|
||||
| Capabilities::IGNORE_SPACE
|
||||
| Capabilities::FOUND_ROWS
|
||||
| Capabilities::TRANSACTIONS
|
||||
| Capabilities::SECURE_CONNECTION
|
||||
| Capabilities::PLUGIN_AUTH_LENENC_DATA
|
||||
| Capabilities::PLUGIN_AUTH;
|
||||
|
||||
if url.database().is_some() {
|
||||
client_capabilities |= Capabilities::CONNECT_WITH_DB;
|
||||
}
|
||||
|
||||
if cfg!(feature = "tls") {
|
||||
client_capabilities |= Capabilities::SSL;
|
||||
}
|
||||
|
||||
self.capabilities =
|
||||
(client_capabilities & handshake.server_capabilities) | Capabilities::PROTOCOL_41;
|
||||
|
||||
Ok(handshake)
|
||||
}
|
||||
|
||||
/// Receives an [OkPacket] from the database server. This is called at the end of
|
||||
/// authentication to confirm the established connection.
|
||||
pub(crate) fn receive_auth_ok<'a>(
|
||||
&'a mut self,
|
||||
plugin: &'a AuthPlugin,
|
||||
password: &'a str,
|
||||
nonce: &'a [u8],
|
||||
) -> BoxFuture<'a, crate::Result<()>> {
|
||||
Box::pin(async move {
|
||||
self.receive().await?;
|
||||
|
||||
match self.packet[0] {
|
||||
0x00 => self.handle_ok().map(drop),
|
||||
0xfe => self.handle_auth_switch(password).await,
|
||||
0xff => self.handle_err(),
|
||||
|
||||
_ => self.handle_auth_continue(plugin, password, nonce).await,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn handle_ok(&mut self) -> crate::Result<OkPacket> {
|
||||
let ok = OkPacket::decode(self.packet())?;
|
||||
|
||||
// An OK signifies the end of the current command sequence
|
||||
self.next_seq_no = 0;
|
||||
|
||||
Ok(ok)
|
||||
}
|
||||
|
||||
pub(crate) fn handle_err<T>(&mut self) -> crate::Result<T> {
|
||||
let err = ErrPacket::decode(self.packet())?;
|
||||
|
||||
// An ERR signifies the end of the current command sequence
|
||||
self.next_seq_no = 0;
|
||||
|
||||
Err(MySqlError(err).into())
|
||||
}
|
||||
|
||||
pub(crate) fn handle_unexpected_packet<T>(&self, id: u8) -> crate::Result<T> {
|
||||
Err(protocol_err!("unexpected packet identifier 0x{:X?}", id).into())
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_auth_continue(
|
||||
&mut self,
|
||||
plugin: &AuthPlugin,
|
||||
password: &str,
|
||||
nonce: &[u8],
|
||||
) -> crate::Result<()> {
|
||||
match plugin {
|
||||
AuthPlugin::CachingSha2Password => {
|
||||
if self.packet[0] == 1 {
|
||||
match self.packet[1] {
|
||||
// AUTH_OK
|
||||
0x03 => {}
|
||||
|
||||
// AUTH_CONTINUE
|
||||
0x04 => {
|
||||
// client sends an RSA encrypted password
|
||||
let ct = self.rsa_encrypt(0x02, password, nonce).await?;
|
||||
|
||||
self.send(&*ct).await?;
|
||||
}
|
||||
|
||||
auth => {
|
||||
return Err(protocol_err!("unexpected result from 'fast' authentication 0x{:x} when expecting OK (0x03) or CONTINUE (0x04)", auth).into());
|
||||
}
|
||||
}
|
||||
|
||||
// ends with server sending either OK_Packet or ERR_Packet
|
||||
self.receive_auth_ok(plugin, password, nonce)
|
||||
.await
|
||||
.map(drop)
|
||||
} else {
|
||||
return self.handle_unexpected_packet(self.packet[0]);
|
||||
}
|
||||
}
|
||||
|
||||
// No other supported auth methods will be called through continue
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_auth_switch(&mut self, password: &str) -> crate::Result<()> {
|
||||
let auth = AuthSwitch::decode(self.packet())?;
|
||||
|
||||
let auth_response = self
|
||||
.make_auth_initial_response(&auth.auth_plugin, password, &auth.auth_plugin_data)
|
||||
.await?;
|
||||
|
||||
self.send(&*auth_response).await?;
|
||||
|
||||
self.receive_auth_ok(&auth.auth_plugin, password, &auth.auth_plugin_data)
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn make_auth_initial_response(
|
||||
&mut self,
|
||||
plugin: &AuthPlugin,
|
||||
password: &str,
|
||||
nonce: &[u8],
|
||||
) -> crate::Result<Vec<u8>> {
|
||||
match plugin {
|
||||
AuthPlugin::CachingSha2Password | AuthPlugin::MySqlNativePassword => {
|
||||
Ok(plugin.scramble(password, nonce))
|
||||
}
|
||||
|
||||
AuthPlugin::Sha256Password => {
|
||||
// Full RSA exchange and password encrypt up front with no "cache"
|
||||
Ok(self.rsa_encrypt(0x01, password, nonce).await?.into_vec())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn rsa_encrypt(
|
||||
&mut self,
|
||||
public_key_request_id: u8,
|
||||
password: &str,
|
||||
nonce: &[u8],
|
||||
) -> crate::Result<Box<[u8]>> {
|
||||
// https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/
|
||||
|
||||
if self.stream.is_tls() {
|
||||
// If in a TLS stream, send the password directly in clear text
|
||||
let mut clear_text = String::with_capacity(password.len() + 1);
|
||||
clear_text.push_str(password);
|
||||
clear_text.push('\0');
|
||||
|
||||
return Ok(clear_text.into_bytes().into_boxed_slice());
|
||||
}
|
||||
|
||||
// client sends a public key request
|
||||
self.send(&[public_key_request_id][..]).await?;
|
||||
|
||||
// server sends a public key response
|
||||
let packet = self.receive().await?.packet();
|
||||
let rsa_pub_key = &packet[1..];
|
||||
|
||||
// The password string data must be NUL terminated
|
||||
// Note: This is not in the documentation that I could find
|
||||
let mut pass = password.as_bytes().to_vec();
|
||||
pass.push(0);
|
||||
|
||||
xor_eq(&mut pass, nonce);
|
||||
|
||||
// client sends an RSA encrypted password
|
||||
rsa::encrypt::<Sha1>(rsa_pub_key, &pass)
|
||||
AuthPlugin::Sha256Password => rsa_encrypt_with_nonce(stream, 0x01, password, nonce).await,
|
||||
}
|
||||
}
|
||||
|
||||
impl MySqlConnection {
|
||||
async fn new(url: &Url) -> crate::Result<Self> {
|
||||
let stream = MaybeTlsStream::connect(url, 3306).await?;
|
||||
async fn establish(stream: &mut MySqlStream, url: &Url) -> crate::Result<()> {
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase.html
|
||||
// https://mariadb.com/kb/en/connection/
|
||||
|
||||
let mut capabilities = Capabilities::empty();
|
||||
// Read a [Handshake] packet. When connecting to the database server, this is immediately
|
||||
// received from the database server.
|
||||
|
||||
if cfg!(feature = "tls") {
|
||||
capabilities |= Capabilities::SSL;
|
||||
}
|
||||
let handshake = Handshake::decode(stream.receive().await?)?;
|
||||
let mut auth_plugin = handshake.auth_plugin;
|
||||
let mut auth_plugin_data = handshake.auth_plugin_data;
|
||||
|
||||
Ok(Self {
|
||||
stream: BufStream::new(stream),
|
||||
capabilities,
|
||||
packet: Vec::with_capacity(8192),
|
||||
packet_len: 0,
|
||||
next_seq_no: 0,
|
||||
statement_cache: StatementCache::new(),
|
||||
})
|
||||
}
|
||||
stream.capabilities &= handshake.server_capabilities;
|
||||
stream.capabilities |= Capabilities::PROTOCOL_41;
|
||||
|
||||
async fn initialize(&mut self) -> crate::Result<()> {
|
||||
// On connect, we want to establish a modern, Rust-compatible baseline so we
|
||||
// tweak connection options to enable UTC for TIMESTAMP, UTF-8 for character types, etc.
|
||||
// Depending on the ssl-mode and capabilities we should upgrade
|
||||
// our connection to TLS
|
||||
|
||||
// TODO: Use batch support when we have it to handle the following in one execution
|
||||
tls::upgrade_if_needed(stream, url).await?;
|
||||
|
||||
// https://mariadb.com/kb/en/sql-mode/
|
||||
// Send a [HandshakeResponse] packet. This is returned in response to the [Handshake] packet
|
||||
// that is immediately received.
|
||||
|
||||
// PIPES_AS_CONCAT - Allows using the pipe character (ASCII 124) as string concatenation operator.
|
||||
// This means that "A" || "B" can be used in place of CONCAT("A", "B").
|
||||
let password = &*url.password().unwrap_or_default();
|
||||
let auth_response =
|
||||
make_auth_response(stream, &auth_plugin, password, &auth_plugin_data).await?;
|
||||
|
||||
// NO_ENGINE_SUBSTITUTION - If not set, if the available storage engine specified by a CREATE TABLE is
|
||||
// not available, a warning is given and the default storage
|
||||
// engine is used instead.
|
||||
|
||||
// NO_ZERO_DATE - Don't allow '0000-00-00'. This is invalid in Rust.
|
||||
|
||||
// NO_ZERO_IN_DATE - Don't allow 'YYYY-00-00'. This is invalid in Rust.
|
||||
|
||||
// language=MySQL
|
||||
self.execute_raw("SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE'))")
|
||||
.await?;
|
||||
|
||||
// This allows us to assume that the output from a TIMESTAMP field is UTC
|
||||
|
||||
// language=MySQL
|
||||
self.execute_raw("SET time_zone = '+00:00'").await?;
|
||||
|
||||
// https://mathiasbynens.be/notes/mysql-utf8mb4
|
||||
|
||||
// language=MySQL
|
||||
self.execute_raw("SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci")
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
async fn try_ssl(
|
||||
&mut self,
|
||||
url: &Url,
|
||||
ca_file: Option<&str>,
|
||||
invalid_hostnames: bool,
|
||||
) -> crate::Result<()> {
|
||||
use crate::runtime::fs;
|
||||
use async_native_tls::{Certificate, TlsConnector};
|
||||
|
||||
let mut connector = TlsConnector::new()
|
||||
.danger_accept_invalid_certs(ca_file.is_none())
|
||||
.danger_accept_invalid_hostnames(invalid_hostnames);
|
||||
|
||||
if let Some(ca_file) = ca_file {
|
||||
let root_cert = fs::read(ca_file).await?;
|
||||
connector = connector.add_root_certificate(Certificate::from_pem(&root_cert)?);
|
||||
}
|
||||
|
||||
// send upgrade request and then immediately try TLS handshake
|
||||
self.send(SslRequest {
|
||||
client_collation: COLLATE_UTF8MB4_UNICODE_CI,
|
||||
max_packet_size: MAX_PACKET_SIZE,
|
||||
})
|
||||
stream
|
||||
.send(
|
||||
HandshakeResponse {
|
||||
client_collation: COLLATE_UTF8MB4_UNICODE_CI,
|
||||
max_packet_size: MAX_PACKET_SIZE,
|
||||
username: url.username().unwrap_or("root"),
|
||||
database: url.database(),
|
||||
auth_plugin: &auth_plugin,
|
||||
auth_response: &auth_response,
|
||||
},
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.stream.stream.upgrade(url, connector).await
|
||||
loop {
|
||||
// After sending the handshake response with our assumed auth method the server
|
||||
// will send OK, fail, or tell us to change auth methods
|
||||
let capabilities = stream.capabilities;
|
||||
let packet = stream.receive().await?;
|
||||
|
||||
match packet[0] {
|
||||
// OK
|
||||
0x00 => {
|
||||
break;
|
||||
}
|
||||
|
||||
// ERROR
|
||||
0xFF => {
|
||||
return stream.handle_err();
|
||||
}
|
||||
|
||||
// AUTH_SWITCH
|
||||
0xFE => {
|
||||
let auth = AuthSwitch::decode(packet)?;
|
||||
auth_plugin = auth.auth_plugin;
|
||||
auth_plugin_data = auth.auth_plugin_data;
|
||||
|
||||
let auth_response =
|
||||
make_auth_response(stream, &auth_plugin, password, &auth_plugin_data).await?;
|
||||
|
||||
stream.send(&*auth_response, false).await?;
|
||||
}
|
||||
|
||||
0x01 if auth_plugin == AuthPlugin::CachingSha2Password => {
|
||||
match packet[1] {
|
||||
// AUTH_OK
|
||||
0x03 => {}
|
||||
|
||||
// AUTH_CONTINUE
|
||||
0x04 => {
|
||||
// The specific password is _not_ cached on the server
|
||||
// We need to send a normal RSA-encrypted password for this
|
||||
let enc = rsa_encrypt_with_nonce(stream, 0x02, password, &auth_plugin_data)
|
||||
.await?;
|
||||
|
||||
stream.send(&*enc, false).await?;
|
||||
}
|
||||
|
||||
unk => {
|
||||
return Err(protocol_err!("unexpected result from 'fast' authentication 0x{:x} when expecting OK (0x03) or CONTINUE (0x04)", unk).into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unk => {
|
||||
return stream.handle_unexpected();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn close(mut stream: MySqlStream) -> crate::Result<()> {
|
||||
// TODO: Actually tell MySQL that we're closing
|
||||
|
||||
stream.flush().await?;
|
||||
stream.shutdown()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ping(stream: &mut MySqlStream) -> crate::Result<()> {
|
||||
stream.send(ComPing, true).await?;
|
||||
|
||||
match stream.receive().await?[0] {
|
||||
0x00 | 0xFE => Ok(()),
|
||||
|
||||
0xFF => stream.handle_err(),
|
||||
|
||||
_ => stream.handle_unexpected(),
|
||||
}
|
||||
}
|
||||
|
||||
impl MySqlConnection {
|
||||
pub(super) async fn establish(url: crate::Result<Url>) -> crate::Result<Self> {
|
||||
pub(super) async fn new(url: crate::Result<Url>) -> crate::Result<Self> {
|
||||
let url = url?;
|
||||
let mut self_ = Self::new(&url).await?;
|
||||
let mut stream = MySqlStream::new(&url).await?;
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase.html
|
||||
// https://mariadb.com/kb/en/connection/
|
||||
establish(&mut stream, &url).await?;
|
||||
|
||||
// On connect, server immediately sends the handshake
|
||||
let mut handshake = self_.receive_handshake(&url).await?;
|
||||
|
||||
let ca_file = url.param("ssl-ca");
|
||||
|
||||
let ssl_mode = url.param("ssl-mode").unwrap_or(
|
||||
if ca_file.is_some() {
|
||||
"VERIFY_CA"
|
||||
} else {
|
||||
"PREFERRED"
|
||||
}
|
||||
.into(),
|
||||
);
|
||||
|
||||
let supports_ssl = handshake.server_capabilities.contains(Capabilities::SSL);
|
||||
|
||||
match &*ssl_mode {
|
||||
"DISABLED" => (),
|
||||
|
||||
// don't try upgrade
|
||||
#[cfg(feature = "tls")]
|
||||
"PREFERRED" if !supports_ssl => {
|
||||
log::warn!("server does not support TLS; using unencrypted connection")
|
||||
}
|
||||
|
||||
// try to upgrade
|
||||
#[cfg(feature = "tls")]
|
||||
"PREFERRED" => {
|
||||
if let Err(e) = self_.try_ssl(&url, None, true).await {
|
||||
log::warn!("TLS handshake failed, falling back to insecure: {}", e);
|
||||
// fallback, redo connection
|
||||
self_ = Self::new(&url).await?;
|
||||
handshake = self_.receive_handshake(&url).await?;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "tls"))]
|
||||
"PREFERRED" => log::info!("compiled without TLS, skipping upgrade"),
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
"REQUIRED" if !supports_ssl => {
|
||||
return Err(tls_err!("server does not support TLS").into())
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
"REQUIRED" => self_.try_ssl(&url, None, true).await?,
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
"VERIFY_CA" | "VERIFY_FULL" if ca_file.is_none() => {
|
||||
return Err(
|
||||
tls_err!("`ssl-mode` of {:?} requires `ssl-ca` to be set", ssl_mode).into(),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
"VERIFY_CA" | "VERIFY_FULL" => {
|
||||
self_
|
||||
.try_ssl(&url, ca_file.as_deref(), ssl_mode != "VERIFY_FULL")
|
||||
.await?
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "tls"))]
|
||||
"REQUIRED" | "VERIFY_CA" | "VERIFY_FULL" => {
|
||||
return Err(tls_err!("compiled without TLS").into())
|
||||
}
|
||||
_ => return Err(tls_err!("unknown `ssl-mode` value: {:?}", ssl_mode).into()),
|
||||
}
|
||||
|
||||
// Pre-generate an auth response by using the auth method in the [Handshake]
|
||||
let password = url.password().unwrap_or_default();
|
||||
let auth_response = self_
|
||||
.make_auth_initial_response(
|
||||
&handshake.auth_plugin,
|
||||
&password,
|
||||
&handshake.auth_plugin_data,
|
||||
)
|
||||
.await?;
|
||||
|
||||
self_
|
||||
.send_handshake_response(&url, &handshake.auth_plugin, &auth_response)
|
||||
.await?;
|
||||
|
||||
// After sending the handshake response with our assumed auth method the server
|
||||
// will send OK, fail, or tell us to change auth methods
|
||||
self_
|
||||
.receive_auth_ok(
|
||||
&handshake.auth_plugin,
|
||||
&password,
|
||||
&handshake.auth_plugin_data,
|
||||
)
|
||||
.await?;
|
||||
let mut self_ = Self {
|
||||
stream,
|
||||
current_row_values: Vec::with_capacity(10),
|
||||
is_ready: true,
|
||||
cache_statement: HashMap::new(),
|
||||
};
|
||||
|
||||
// After the connection is established, we initialize by configuring a few
|
||||
// connection parameters
|
||||
self_.initialize().await?;
|
||||
// initialize().await?;
|
||||
|
||||
Ok(self_)
|
||||
}
|
||||
|
||||
async fn close(mut self) -> crate::Result<()> {
|
||||
// TODO: Actually tell MySQL that we're closing
|
||||
|
||||
self.stream.flush().await?;
|
||||
self.stream.stream.shutdown(Shutdown::Both)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl MySqlConnection {
|
||||
#[deprecated(note = "please use 'connect' instead")]
|
||||
pub fn open<T>(url: T) -> BoxFuture<'static, crate::Result<Self>>
|
||||
where
|
||||
T: TryInto<Url, Error = crate::Error>,
|
||||
Self: Sized,
|
||||
{
|
||||
Box::pin(MySqlConnection::establish(url.try_into()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Connect for MySqlConnection {
|
||||
@ -608,12 +296,16 @@ impl Connect for MySqlConnection {
|
||||
T: TryInto<Url, Error = crate::Error>,
|
||||
Self: Sized,
|
||||
{
|
||||
Box::pin(MySqlConnection::establish(url.try_into()))
|
||||
Box::pin(MySqlConnection::new(url.try_into()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Connection for MySqlConnection {
|
||||
fn close(self) -> BoxFuture<'static, crate::Result<()>> {
|
||||
Box::pin(self.close())
|
||||
Box::pin(close(self.stream))
|
||||
}
|
||||
|
||||
fn ping(&mut self) -> BoxFuture<crate::Result<()>> {
|
||||
Box::pin(ping(&mut self.stream))
|
||||
}
|
||||
}
|
||||
|
158
sqlx-core/src/mysql/cursor.rs
Normal file
158
sqlx-core/src/mysql/cursor.rs
Normal file
@ -0,0 +1,158 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures_core::future::BoxFuture;
|
||||
|
||||
use crate::connection::{ConnectionSource, MaybeOwnedConnection};
|
||||
use crate::cursor::Cursor;
|
||||
use crate::executor::Execute;
|
||||
use crate::mysql::protocol::{ColumnCount, ColumnDefinition, Decode, Row, Status, TypeId};
|
||||
use crate::mysql::{MySql, MySqlArguments, MySqlConnection, MySqlRow};
|
||||
use crate::pool::Pool;
|
||||
|
||||
pub struct MySqlCursor<'c, 'q> {
|
||||
source: ConnectionSource<'c, MySqlConnection>,
|
||||
query: Option<(&'q str, Option<MySqlArguments>)>,
|
||||
column_names: Arc<HashMap<Box<str>, u16>>,
|
||||
column_types: Vec<TypeId>,
|
||||
binary: bool,
|
||||
}
|
||||
|
||||
impl<'c, 'q> Cursor<'c, 'q> for MySqlCursor<'c, 'q> {
|
||||
type Database = MySql;
|
||||
|
||||
#[doc(hidden)]
|
||||
fn from_pool<E>(pool: &Pool<MySqlConnection>, query: E) -> Self
|
||||
where
|
||||
Self: Sized,
|
||||
E: Execute<'q, MySql>,
|
||||
{
|
||||
Self {
|
||||
source: ConnectionSource::Pool(pool.clone()),
|
||||
column_names: Arc::default(),
|
||||
column_types: Vec::new(),
|
||||
binary: true,
|
||||
query: Some(query.into_parts()),
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
fn from_connection<E, C>(conn: C, query: E) -> Self
|
||||
where
|
||||
Self: Sized,
|
||||
C: Into<MaybeOwnedConnection<'c, MySqlConnection>>,
|
||||
E: Execute<'q, MySql>,
|
||||
{
|
||||
Self {
|
||||
source: ConnectionSource::Connection(conn.into()),
|
||||
column_names: Arc::default(),
|
||||
column_types: Vec::new(),
|
||||
binary: true,
|
||||
query: Some(query.into_parts()),
|
||||
}
|
||||
}
|
||||
|
||||
fn next(&mut self) -> BoxFuture<crate::Result<Option<MySqlRow<'_>>>> {
|
||||
Box::pin(next(self))
|
||||
}
|
||||
}
|
||||
|
||||
async fn next<'a, 'c: 'a, 'q: 'a>(
|
||||
cursor: &'a mut MySqlCursor<'c, 'q>,
|
||||
) -> crate::Result<Option<MySqlRow<'a>>> {
|
||||
println!("[cursor::next]");
|
||||
|
||||
let mut conn = cursor.source.resolve_by_ref().await?;
|
||||
|
||||
// The first time [next] is called we need to actually execute our
|
||||
// contained query. We guard against this happening on _all_ next calls
|
||||
// by using [Option::take] which replaces the potential value in the Option with `None
|
||||
let mut initial = if let Some((query, arguments)) = cursor.query.take() {
|
||||
let statement = conn.run(query, arguments).await?;
|
||||
|
||||
// No statement ID = TEXT mode
|
||||
cursor.binary = statement.is_some();
|
||||
|
||||
true
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
loop {
|
||||
let mut packet_id = conn.stream.receive().await?[0];
|
||||
println!("[cursor::next/iter] {:x}", packet_id);
|
||||
match packet_id {
|
||||
// OK or EOF packet
|
||||
0x00 | 0xFE
|
||||
if conn.stream.packet().len() < 0xFF_FF_FF && (packet_id != 0x00 || initial) =>
|
||||
{
|
||||
let ok = conn.stream.handle_ok()?;
|
||||
|
||||
if ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
|
||||
// There is more to this query
|
||||
initial = true;
|
||||
} else {
|
||||
conn.is_ready = true;
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
// ERR packet
|
||||
0xFF => {
|
||||
conn.is_ready = true;
|
||||
return conn.stream.handle_err();
|
||||
}
|
||||
|
||||
_ if initial => {
|
||||
// At the start of the results we expect to see a
|
||||
// COLUMN_COUNT followed by N COLUMN_DEF
|
||||
|
||||
let cc = ColumnCount::decode(conn.stream.packet())?;
|
||||
|
||||
// We use these definitions to get the actual column types that is critical
|
||||
// in parsing the rows coming back soon
|
||||
|
||||
cursor.column_types.clear();
|
||||
cursor.column_types.reserve(cc.columns as usize);
|
||||
|
||||
let mut column_names = HashMap::with_capacity(cc.columns as usize);
|
||||
|
||||
for i in 0..cc.columns {
|
||||
let column = ColumnDefinition::decode(conn.stream.receive().await?)?;
|
||||
|
||||
cursor.column_types.push(column.type_id);
|
||||
|
||||
if let Some(name) = column.name() {
|
||||
column_names.insert(name.to_owned().into_boxed_str(), i as u16);
|
||||
}
|
||||
}
|
||||
|
||||
cursor.column_names = Arc::new(column_names);
|
||||
initial = false;
|
||||
}
|
||||
|
||||
_ if !cursor.binary || packet_id == 0x00 => {
|
||||
let row = Row::read(
|
||||
conn.stream.packet(),
|
||||
&cursor.column_types,
|
||||
&mut conn.current_row_values,
|
||||
// TODO: Text mode
|
||||
cursor.binary,
|
||||
)?;
|
||||
|
||||
let row = MySqlRow {
|
||||
row,
|
||||
columns: Arc::clone(&cursor.column_names),
|
||||
// TODO: Text mode
|
||||
binary: cursor.binary,
|
||||
};
|
||||
|
||||
return Ok(Some(row));
|
||||
}
|
||||
|
||||
_ => {
|
||||
return conn.stream.handle_unexpected();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -13,18 +13,18 @@ impl Database for MySql {
|
||||
type TableId = Box<str>;
|
||||
}
|
||||
|
||||
impl HasRow for MySql {
|
||||
impl<'c> HasRow<'c> for MySql {
|
||||
type Database = MySql;
|
||||
|
||||
type Row = super::MySqlRow;
|
||||
type Row = super::MySqlRow<'c>;
|
||||
}
|
||||
|
||||
impl<'a> HasCursor<'a> for MySql {
|
||||
impl<'c, 'q> HasCursor<'c, 'q> for MySql {
|
||||
type Database = MySql;
|
||||
|
||||
type Cursor = super::MySqlCursor<'a>;
|
||||
type Cursor = super::MySqlCursor<'c, 'q>;
|
||||
}
|
||||
|
||||
impl<'a> HasRawValue<'a> for MySql {
|
||||
type RawValue = Option<&'a [u8]>;
|
||||
impl<'c> HasRawValue<'c> for MySql {
|
||||
type RawValue = Option<super::MySqlValue<'c>>;
|
||||
}
|
||||
|
@ -4,341 +4,244 @@ use std::sync::Arc;
|
||||
use futures_core::future::BoxFuture;
|
||||
use futures_core::stream::BoxStream;
|
||||
|
||||
use crate::describe::{Column, Describe, Nullability};
|
||||
use crate::executor::Executor;
|
||||
use crate::cursor::Cursor;
|
||||
use crate::describe::{Column, Describe};
|
||||
use crate::executor::{Execute, Executor, RefExecutor};
|
||||
use crate::mysql::protocol::{
|
||||
Capabilities, ColumnCount, ColumnDefinition, ComQuery, ComStmtExecute, ComStmtPrepare,
|
||||
ComStmtPrepareOk, Cursor, Decode, EofPacket, FieldFlags, OkPacket, Row, TypeId,
|
||||
self, Capabilities, ColumnCount, ColumnDefinition, ComQuery, ComStmtExecute, ComStmtPrepare,
|
||||
ComStmtPrepareOk, Decode, EofPacket, ErrPacket, FieldFlags, OkPacket, Row, TypeId,
|
||||
};
|
||||
use crate::mysql::{
|
||||
MySql, MySqlArguments, MySqlConnection, MySqlCursor, MySqlError, MySqlRow, MySqlTypeInfo,
|
||||
};
|
||||
use crate::mysql::{MySql, MySqlArguments, MySqlConnection, MySqlRow, MySqlTypeInfo};
|
||||
|
||||
enum Step {
|
||||
Command(u64),
|
||||
Row(Row),
|
||||
}
|
||||
impl super::MySqlConnection {
|
||||
async fn wait_until_ready(&mut self) -> crate::Result<()> {
|
||||
if !self.is_ready {
|
||||
loop {
|
||||
let mut packet_id = self.stream.receive().await?[0];
|
||||
match packet_id {
|
||||
0xFE if self.stream.packet().len() < 0xFF_FF_FF => {
|
||||
// OK or EOF packet
|
||||
self.is_ready = true;
|
||||
break;
|
||||
}
|
||||
|
||||
enum OkOrResultSet {
|
||||
Ok(OkPacket),
|
||||
ResultSet(ColumnCount),
|
||||
}
|
||||
0xFF => {
|
||||
// ERR packet
|
||||
self.is_ready = true;
|
||||
return self.stream.handle_err();
|
||||
}
|
||||
|
||||
impl MySqlConnection {
|
||||
async fn ignore_columns(&mut self, count: usize) -> crate::Result<()> {
|
||||
for _ in 0..count {
|
||||
let _column = ColumnDefinition::decode(self.receive().await?.packet())?;
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
self.receive_eof().await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn receive_ok_or_column_count(&mut self) -> crate::Result<OkOrResultSet> {
|
||||
self.receive().await?;
|
||||
|
||||
match self.packet[0] {
|
||||
0x00 | 0xfe if self.packet.len() < 0xffffff => self.handle_ok().map(OkOrResultSet::Ok),
|
||||
0xff => self.handle_err(),
|
||||
|
||||
_ => Ok(OkOrResultSet::ResultSet(ColumnCount::decode(
|
||||
self.packet(),
|
||||
)?)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn receive_column_types(&mut self, count: usize) -> crate::Result<Box<[TypeId]>> {
|
||||
let mut columns: Vec<TypeId> = Vec::with_capacity(count);
|
||||
|
||||
for _ in 0..count {
|
||||
let column: ColumnDefinition =
|
||||
ColumnDefinition::decode(self.receive().await?.packet())?;
|
||||
|
||||
columns.push(column.type_id);
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
self.receive_eof().await?;
|
||||
}
|
||||
|
||||
Ok(columns.into_boxed_slice())
|
||||
}
|
||||
|
||||
async fn wait_for_ready(&mut self) -> crate::Result<()> {
|
||||
if self.next_seq_no != 0 {
|
||||
while let Some(_step) = self.step(&[], true).await? {
|
||||
// Drain steps until we hit the end
|
||||
_ => {
|
||||
// Something else; skip
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Creates a prepared statement for the passed query string
|
||||
async fn prepare(&mut self, query: &str) -> crate::Result<ComStmtPrepareOk> {
|
||||
// Start by sending a COM_STMT_PREPARE
|
||||
self.send(ComStmtPrepare { query }).await?;
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_com_stmt_prepare.html
|
||||
self.stream.send(ComStmtPrepare { query }, true).await?;
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_prepare.html
|
||||
// Should receive a COM_STMT_PREPARE_OK or ERR_PACKET
|
||||
let packet = self.stream.receive().await?;
|
||||
|
||||
// First we should receive a COM_STMT_PREPARE_OK
|
||||
self.receive().await?;
|
||||
|
||||
if self.packet[0] == 0xff {
|
||||
// Oops, there was an error in the prepare command
|
||||
return self.handle_err();
|
||||
if packet[0] == 0xFF {
|
||||
return self.stream.handle_err();
|
||||
}
|
||||
|
||||
ComStmtPrepareOk::decode(self.packet())
|
||||
ComStmtPrepareOk::decode(packet)
|
||||
}
|
||||
|
||||
async fn prepare_with_cache(&mut self, query: &str) -> crate::Result<u32> {
|
||||
if let Some(&id) = self.statement_cache.get(query) {
|
||||
async fn drop_column_defs(&mut self, count: usize) -> crate::Result<()> {
|
||||
for _ in 0..count {
|
||||
let _column = ColumnDefinition::decode(self.stream.receive().await?)?;
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
self.stream.maybe_receive_eof().await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Gets a cached prepared statement ID _or_ prepares the statement if not in the cache
|
||||
// At the end we should have [cache_statement] and [cache_statement_columns] filled
|
||||
async fn get_or_prepare(&mut self, query: &str) -> crate::Result<u32> {
|
||||
if let Some(&id) = self.cache_statement.get(query) {
|
||||
Ok(id)
|
||||
} else {
|
||||
let prepare_ok = self.prepare(query).await?;
|
||||
let stmt = self.prepare(query).await?;
|
||||
|
||||
// Remember our statement ID, so we do'd do this again the next time
|
||||
self.statement_cache
|
||||
.put(query.to_owned(), prepare_ok.statement_id);
|
||||
self.cache_statement.insert(query.into(), stmt.statement_id);
|
||||
|
||||
// Ignore input parameters
|
||||
self.ignore_columns(prepare_ok.params as usize).await?;
|
||||
// COM_STMT_PREPARE returns the input columns
|
||||
// We make no use of that data, so cycle through and drop them
|
||||
self.drop_column_defs(stmt.params as usize).await?;
|
||||
|
||||
// Collect output parameter names
|
||||
let mut columns = HashMap::with_capacity(prepare_ok.columns as usize);
|
||||
let mut index = 0_usize;
|
||||
for _ in 0..prepare_ok.columns {
|
||||
let column = ColumnDefinition::decode(self.receive().await?.packet())?;
|
||||
// COM_STMT_PREPARE next returns the output columns
|
||||
// We just drop these as we get these when we execute the query
|
||||
self.drop_column_defs(stmt.columns as usize).await?;
|
||||
|
||||
if let Some(name) = column.column_alias.or(column.column) {
|
||||
columns.insert(name, index);
|
||||
Ok(stmt.statement_id)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn run(
|
||||
&mut self,
|
||||
query: &str,
|
||||
arguments: Option<MySqlArguments>,
|
||||
) -> crate::Result<Option<u32>> {
|
||||
self.wait_until_ready().await?;
|
||||
self.is_ready = false;
|
||||
|
||||
if let Some(arguments) = arguments {
|
||||
let statement_id = self.get_or_prepare(query).await?;
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_com_stmt_execute.html
|
||||
self.stream
|
||||
.send(
|
||||
ComStmtExecute {
|
||||
cursor: protocol::Cursor::NO_CURSOR,
|
||||
statement_id,
|
||||
params: &arguments.params,
|
||||
null_bitmap: &arguments.null_bitmap,
|
||||
param_types: &arguments.param_types,
|
||||
},
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(Some(statement_id))
|
||||
} else {
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_com_query.html
|
||||
self.stream.send(ComQuery { query }, true).await?;
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
async fn affected_rows(&mut self) -> crate::Result<u64> {
|
||||
let mut rows = 0;
|
||||
|
||||
loop {
|
||||
let id = self.stream.receive().await?[0];
|
||||
|
||||
match id {
|
||||
0x00 | 0xFE if self.stream.packet().len() < 0xFF_FF_FF => {
|
||||
// ResultSet row can begin with 0xfe byte (when using text protocol
|
||||
// with a field length > 0xffffff)
|
||||
|
||||
if !self.stream.maybe_handle_eof()? {
|
||||
rows += self.stream.handle_ok()?.affected_rows;
|
||||
}
|
||||
|
||||
// EOF packets do not have affected rows
|
||||
// So this function is actually useless if the server doesn't support
|
||||
// proper OK packets
|
||||
|
||||
self.is_ready = true;
|
||||
break;
|
||||
}
|
||||
|
||||
index += 1;
|
||||
}
|
||||
|
||||
if prepare_ok.columns > 0 {
|
||||
self.receive_eof().await?;
|
||||
}
|
||||
|
||||
// At the end of a command, this should go back to 0
|
||||
self.next_seq_no = 0;
|
||||
|
||||
// Remember our column map in the statement cache
|
||||
self.statement_cache
|
||||
.put_columns(prepare_ok.statement_id, columns);
|
||||
|
||||
Ok(prepare_ok.statement_id)
|
||||
}
|
||||
}
|
||||
|
||||
// [COM_STMT_EXECUTE]
|
||||
async fn execute_statement(&mut self, id: u32, args: MySqlArguments) -> crate::Result<()> {
|
||||
self.send(ComStmtExecute {
|
||||
cursor: Cursor::NO_CURSOR,
|
||||
statement_id: id,
|
||||
params: &args.params,
|
||||
null_bitmap: &args.null_bitmap,
|
||||
param_types: &args.param_types,
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn step(&mut self, columns: &[TypeId], binary: bool) -> crate::Result<Option<Step>> {
|
||||
let capabilities = self.capabilities;
|
||||
ret_if_none!(self.try_receive().await?);
|
||||
|
||||
match self.packet[0] {
|
||||
0xfe if self.packet.len() < 0xffffff => {
|
||||
// ResultSet row can begin with 0xfe byte (when using text protocol
|
||||
// with a field length > 0xffffff)
|
||||
|
||||
if !capabilities.contains(Capabilities::DEPRECATE_EOF) {
|
||||
let _eof = EofPacket::decode(self.packet())?;
|
||||
|
||||
// An EOF -here- signifies the end of the current command sequence
|
||||
self.next_seq_no = 0;
|
||||
|
||||
Ok(None)
|
||||
} else {
|
||||
self.handle_ok()
|
||||
.map(|ok| Some(Step::Command(ok.affected_rows)))
|
||||
0xFF => {
|
||||
return self.stream.handle_err();
|
||||
}
|
||||
}
|
||||
|
||||
0xff => self.handle_err(),
|
||||
|
||||
_ => Ok(Some(Step::Row(Row::decode(
|
||||
self.packet(),
|
||||
columns,
|
||||
binary,
|
||||
)?))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MySqlConnection {
|
||||
pub(super) async fn execute_raw(&mut self, query: &str) -> crate::Result<()> {
|
||||
self.wait_for_ready().await?;
|
||||
|
||||
self.send(ComQuery { query }).await?;
|
||||
|
||||
// COM_QUERY can terminate before the result set with an ERR or OK packet
|
||||
let num_columns = match self.receive_ok_or_column_count().await? {
|
||||
OkOrResultSet::Ok(_) => {
|
||||
self.next_seq_no = 0;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
OkOrResultSet::ResultSet(cc) => cc.columns as usize,
|
||||
};
|
||||
|
||||
let columns = self.receive_column_types(num_columns as usize).await?;
|
||||
|
||||
while let Some(_step) = self.step(&columns, false).await? {
|
||||
// Drop all responses
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn execute(&mut self, query: &str, args: MySqlArguments) -> crate::Result<u64> {
|
||||
self.wait_for_ready().await?;
|
||||
|
||||
let statement_id = self.prepare_with_cache(query).await?;
|
||||
|
||||
self.execute_statement(statement_id, args).await?;
|
||||
|
||||
// COM_STMT_EXECUTE can terminate before the result set with an ERR or OK packet
|
||||
let num_columns = match self.receive_ok_or_column_count().await? {
|
||||
OkOrResultSet::Ok(ok) => {
|
||||
self.next_seq_no = 0;
|
||||
|
||||
return Ok(ok.affected_rows);
|
||||
}
|
||||
|
||||
OkOrResultSet::ResultSet(cc) => cc.columns as usize,
|
||||
};
|
||||
|
||||
self.ignore_columns(num_columns).await?;
|
||||
|
||||
let mut res = 0;
|
||||
|
||||
while let Some(step) = self.step(&[], true).await? {
|
||||
if let Step::Command(affected) = step {
|
||||
res = affected;
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(res)
|
||||
Ok(rows)
|
||||
}
|
||||
|
||||
async fn describe(&mut self, query: &str) -> crate::Result<Describe<MySql>> {
|
||||
self.wait_for_ready().await?;
|
||||
// method is not named describe to work around an intellijrust bug
|
||||
// otherwise it marks someone trying to describe the connection as "method is private"
|
||||
async fn do_describe(&mut self, query: &str) -> crate::Result<Describe<MySql>> {
|
||||
self.wait_until_ready().await?;
|
||||
|
||||
let prepare_ok = self.prepare(query).await?;
|
||||
let stmt = self.prepare(query).await?;
|
||||
|
||||
let mut param_types = Vec::with_capacity(prepare_ok.params as usize);
|
||||
let mut result_columns = Vec::with_capacity(prepare_ok.columns as usize);
|
||||
let mut param_types = Vec::with_capacity(stmt.params as usize);
|
||||
let mut result_columns = Vec::with_capacity(stmt.columns as usize);
|
||||
|
||||
for _ in 0..prepare_ok.params {
|
||||
let param = ColumnDefinition::decode(self.receive().await?.packet())?;
|
||||
for _ in 0..stmt.params {
|
||||
let param = ColumnDefinition::decode(self.stream.receive().await?)?;
|
||||
param_types.push(MySqlTypeInfo::from_column_def(¶m));
|
||||
}
|
||||
|
||||
if prepare_ok.params > 0 {
|
||||
self.receive_eof().await?;
|
||||
if stmt.params > 0 {
|
||||
self.stream.maybe_receive_eof().await?;
|
||||
}
|
||||
|
||||
for _ in 0..prepare_ok.columns {
|
||||
let column = ColumnDefinition::decode(self.receive().await?.packet())?;
|
||||
for _ in 0..stmt.columns {
|
||||
let column = ColumnDefinition::decode(self.stream.receive().await?)?;
|
||||
|
||||
result_columns.push(Column::<MySql> {
|
||||
type_info: MySqlTypeInfo::from_column_def(&column),
|
||||
name: column.column_alias.or(column.column),
|
||||
table_id: column.table_alias.or(column.table),
|
||||
// TODO(@abonander): Should this be None in some cases?
|
||||
non_null: Some(column.flags.contains(FieldFlags::NOT_NULL)),
|
||||
});
|
||||
}
|
||||
|
||||
if prepare_ok.columns > 0 {
|
||||
self.receive_eof().await?;
|
||||
if stmt.columns > 0 {
|
||||
self.stream.maybe_receive_eof().await?;
|
||||
}
|
||||
|
||||
// Command sequence is over
|
||||
self.next_seq_no = 0;
|
||||
|
||||
Ok(Describe {
|
||||
param_types: param_types.into_boxed_slice(),
|
||||
result_columns: result_columns.into_boxed_slice(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn fetch<'e, 'q: 'e>(
|
||||
&'e mut self,
|
||||
query: &'q str,
|
||||
args: MySqlArguments,
|
||||
) -> BoxStream<'e, crate::Result<MySqlRow>> {
|
||||
Box::pin(async_stream::try_stream! {
|
||||
self.wait_for_ready().await?;
|
||||
impl Executor for super::MySqlConnection {
|
||||
type Database = MySql;
|
||||
|
||||
let statement_id = self.prepare_with_cache(query).await?;
|
||||
fn execute<'e, 'q, E: 'e>(&'e mut self, query: E) -> BoxFuture<'e, crate::Result<u64>>
|
||||
where
|
||||
E: Execute<'q, Self::Database>,
|
||||
{
|
||||
Box::pin(async move {
|
||||
let (query, arguments) = query.into_parts();
|
||||
|
||||
let columns = self.statement_cache.get_columns(statement_id);
|
||||
|
||||
self.execute_statement(statement_id, args).await?;
|
||||
|
||||
// COM_STMT_EXECUTE can terminate before the result set with an ERR or OK packet
|
||||
let num_columns = match self.receive_ok_or_column_count().await? {
|
||||
OkOrResultSet::Ok(_) => {
|
||||
self.next_seq_no = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
OkOrResultSet::ResultSet(cc) => {
|
||||
cc.columns as usize
|
||||
}
|
||||
};
|
||||
|
||||
let column_types = self.receive_column_types(num_columns).await?;
|
||||
|
||||
while let Some(Step::Row(row)) = self.step(&column_types, true).await? {
|
||||
yield MySqlRow { row, columns: Arc::clone(&columns) };
|
||||
}
|
||||
self.run(query, arguments).await?;
|
||||
self.affected_rows().await
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Executor for MySqlConnection {
|
||||
type Database = super::MySql;
|
||||
|
||||
fn send<'e, 'q: 'e>(&'e mut self, query: &'q str) -> BoxFuture<'e, crate::Result<()>> {
|
||||
Box::pin(self.execute_raw(query))
|
||||
fn fetch<'q, E>(&mut self, query: E) -> MySqlCursor<'_, 'q>
|
||||
where
|
||||
E: Execute<'q, Self::Database>,
|
||||
{
|
||||
MySqlCursor::from_connection(self, query)
|
||||
}
|
||||
|
||||
fn fetch<'e, 'q: 'e>(
|
||||
fn describe<'e, 'q, E: 'e>(
|
||||
&'e mut self,
|
||||
query: &'q str,
|
||||
args: MySqlArguments,
|
||||
) -> BoxFuture<'e, crate::Result<u64>> {
|
||||
Box::pin(self.execute(query, args))
|
||||
}
|
||||
|
||||
fn fetch<'e, 'q: 'e>(
|
||||
&'e mut self,
|
||||
query: &'q str,
|
||||
args: MySqlArguments,
|
||||
) -> BoxStream<'e, crate::Result<MySqlRow>> {
|
||||
self.fetch(query, args)
|
||||
}
|
||||
|
||||
fn describe<'e, 'q: 'e>(
|
||||
&'e mut self,
|
||||
query: &'q str,
|
||||
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>> {
|
||||
Box::pin(self.describe(query))
|
||||
query: E,
|
||||
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>>
|
||||
where
|
||||
E: Execute<'q, Self::Database>,
|
||||
{
|
||||
Box::pin(async move { self.do_describe(query.into_parts().0).await })
|
||||
}
|
||||
}
|
||||
|
||||
impl_execute_for_query!(MySql);
|
||||
impl<'c> RefExecutor<'c> for &'c mut super::MySqlConnection {
|
||||
type Database = MySql;
|
||||
|
||||
fn fetch_by_ref<'q, E>(self, query: E) -> MySqlCursor<'c, 'q>
|
||||
where
|
||||
E: Execute<'q, Self::Database>,
|
||||
{
|
||||
MySqlCursor::from_connection(self, query)
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,16 @@
|
||||
//! **MySQL** database and connection types.
|
||||
|
||||
pub use arguments::MySqlArguments;
|
||||
pub use connection::MySqlConnection;
|
||||
pub use cursor::MySqlCursor;
|
||||
pub use database::MySql;
|
||||
pub use error::MySqlError;
|
||||
pub use row::{MySqlRow, MySqlValue};
|
||||
pub use types::MySqlTypeInfo;
|
||||
|
||||
mod arguments;
|
||||
mod connection;
|
||||
mod cursor;
|
||||
mod database;
|
||||
mod error;
|
||||
mod executor;
|
||||
@ -9,20 +18,15 @@ mod io;
|
||||
mod protocol;
|
||||
mod row;
|
||||
mod rsa;
|
||||
mod stream;
|
||||
mod tls;
|
||||
mod types;
|
||||
mod util;
|
||||
|
||||
pub use database::MySql;
|
||||
|
||||
pub use arguments::MySqlArguments;
|
||||
|
||||
pub use connection::MySqlConnection;
|
||||
|
||||
pub use error::MySqlError;
|
||||
|
||||
pub use types::MySqlTypeInfo;
|
||||
|
||||
pub use row::MySqlRow;
|
||||
|
||||
/// An alias for [`Pool`], specialized for **MySQL**.
|
||||
pub type MySqlPool = super::Pool<MySqlConnection>;
|
||||
pub type MySqlPool = crate::pool::Pool<MySqlConnection>;
|
||||
|
||||
make_query_as!(MySqlQueryAs, MySql, MySqlRow);
|
||||
impl_map_row_for_row!(MySql, MySqlRow);
|
||||
impl_column_index_for_row!(MySql);
|
||||
impl_from_row_for_tuples!(MySql, MySqlRow);
|
||||
|
@ -6,7 +6,7 @@ use sha2::Sha256;
|
||||
|
||||
use crate::mysql::util::xor_eq;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum AuthPlugin {
|
||||
MySqlNativePassword,
|
||||
CachingSha2Password,
|
||||
|
@ -27,6 +27,12 @@ pub struct ColumnDefinition {
|
||||
pub decimals: u8,
|
||||
}
|
||||
|
||||
impl ColumnDefinition {
|
||||
pub fn name(&self) -> Option<&str> {
|
||||
self.column_alias.as_deref().or(self.column.as_deref())
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode for ColumnDefinition {
|
||||
fn decode(mut buf: &[u8]) -> crate::Result<Self> {
|
||||
// catalog : string<lenenc>
|
||||
|
16
sqlx-core/src/mysql/protocol/com_ping.rs
Normal file
16
sqlx-core/src/mysql/protocol/com_ping.rs
Normal file
@ -0,0 +1,16 @@
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
use crate::io::BufMut;
|
||||
use crate::mysql::io::BufMutExt;
|
||||
use crate::mysql::protocol::{Capabilities, Encode};
|
||||
|
||||
// https://dev.mysql.com/doc/internals/en/com-ping.html
|
||||
#[derive(Debug)]
|
||||
pub struct ComPing;
|
||||
|
||||
impl Encode for ComPing {
|
||||
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
|
||||
// COM_PING : int<1>
|
||||
buf.put_u8(0x0e);
|
||||
}
|
||||
}
|
@ -9,7 +9,8 @@ use crate::mysql::protocol::Decode;
|
||||
pub struct ComStmtPrepareOk {
|
||||
pub statement_id: u32,
|
||||
|
||||
/// Number of columns in the returned result set (or 0 if statement does not return result set).
|
||||
/// Number of columns in the returned result set (or 0 if statement
|
||||
/// does not return result set).
|
||||
pub columns: u16,
|
||||
|
||||
/// Number of prepared statement parameters ('?' placeholders).
|
||||
|
@ -9,24 +9,34 @@ use crate::mysql::protocol::{Capabilities, Decode, Status};
|
||||
#[derive(Debug)]
|
||||
pub struct ErrPacket {
|
||||
pub error_code: u16,
|
||||
pub sql_state: Box<str>,
|
||||
pub sql_state: Option<Box<str>>,
|
||||
pub error_message: Box<str>,
|
||||
}
|
||||
|
||||
impl Decode for ErrPacket {
|
||||
fn decode(mut buf: &[u8]) -> crate::Result<Self>
|
||||
impl ErrPacket {
|
||||
pub(crate) fn decode(mut buf: &[u8], capabilities: Capabilities) -> crate::Result<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let header = buf.get_u8()?;
|
||||
if header != 0xFF {
|
||||
return Err(protocol_err!("expected 0xFF; received 0x{:X}", header))?;
|
||||
return Err(protocol_err!(
|
||||
"expected 0xFF for ERR_PACKET; received 0x{:X}",
|
||||
header
|
||||
))?;
|
||||
}
|
||||
|
||||
let error_code = buf.get_u16::<LittleEndian>()?;
|
||||
|
||||
let _sql_state_marker: u8 = buf.get_u8()?;
|
||||
let sql_state = buf.get_str(5)?.into();
|
||||
let mut sql_state = None;
|
||||
|
||||
if capabilities.contains(Capabilities::PROTOCOL_41) {
|
||||
// If the next byte is '#' then we have a SQL STATE
|
||||
if buf.get(0) == Some(&0x23) {
|
||||
buf.advance(1);
|
||||
sql_state = Some(buf.get_str(5)?.into())
|
||||
}
|
||||
}
|
||||
|
||||
let error_message = buf.get_str(buf.len())?.into();
|
||||
|
||||
@ -42,14 +52,25 @@ impl Decode for ErrPacket {
|
||||
mod tests {
|
||||
use super::{Capabilities, Decode, ErrPacket, Status};
|
||||
|
||||
const ERR_PACKETS_OUT_OF_ORDER: &[u8] = b"\xff\x84\x04Got packets out of order";
|
||||
|
||||
const ERR_HANDSHAKE_UNKNOWN_DB: &[u8] = b"\xff\x19\x04#42000Unknown database \'unknown\'";
|
||||
|
||||
#[test]
|
||||
fn it_decodes_packets_out_of_order() {
|
||||
let mut p = ErrPacket::decode(ERR_PACKETS_OUT_OF_ORDER, Capabilities::PROTOCOL_41).unwrap();
|
||||
|
||||
assert_eq!(&*p.error_message, "Got packets out of order");
|
||||
assert_eq!(p.error_code, 1156);
|
||||
assert_eq!(p.sql_state, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_decodes_ok_handshake() {
|
||||
let mut p = ErrPacket::decode(ERR_HANDSHAKE_UNKNOWN_DB).unwrap();
|
||||
let mut p = ErrPacket::decode(ERR_HANDSHAKE_UNKNOWN_DB, Capabilities::PROTOCOL_41).unwrap();
|
||||
|
||||
assert_eq!(p.error_code, 1049);
|
||||
assert_eq!(&*p.sql_state, "42000");
|
||||
assert_eq!(p.sql_state.as_deref(), Some("42000"));
|
||||
assert_eq!(&*p.error_message, "Unknown database \'unknown\'");
|
||||
}
|
||||
}
|
||||
|
@ -20,12 +20,14 @@ pub use field::FieldFlags;
|
||||
pub use r#type::TypeId;
|
||||
pub use status::Status;
|
||||
|
||||
mod com_ping;
|
||||
mod com_query;
|
||||
mod com_set_option;
|
||||
mod com_stmt_execute;
|
||||
mod com_stmt_prepare;
|
||||
mod handshake;
|
||||
|
||||
pub use com_ping::ComPing;
|
||||
pub use com_query::ComQuery;
|
||||
pub use com_set_option::{ComSetOption, SetOption};
|
||||
pub use com_stmt_execute::{ComStmtExecute, Cursor};
|
||||
|
@ -6,73 +6,84 @@ use crate::io::Buf;
|
||||
use crate::mysql::io::BufExt;
|
||||
use crate::mysql::protocol::{Decode, TypeId};
|
||||
|
||||
pub struct Row {
|
||||
buffer: Box<[u8]>,
|
||||
values: Box<[Option<Range<usize>>]>,
|
||||
pub struct Row<'c> {
|
||||
buffer: &'c [u8],
|
||||
values: &'c [Option<Range<usize>>],
|
||||
binary: bool,
|
||||
}
|
||||
|
||||
impl Row {
|
||||
impl<'c> Row<'c> {
|
||||
pub fn len(&self) -> usize {
|
||||
self.values.len()
|
||||
}
|
||||
|
||||
pub fn get(&self, index: usize) -> Option<&[u8]> {
|
||||
pub fn get(&self, index: usize) -> Option<&'c [u8]> {
|
||||
let range = self.values[index].as_ref()?;
|
||||
|
||||
Some(&self.buffer[(range.start as usize)..(range.end as usize)])
|
||||
}
|
||||
}
|
||||
|
||||
fn get_lenenc(buf: &[u8]) -> usize {
|
||||
fn get_lenenc(buf: &[u8]) -> (usize, Option<usize>) {
|
||||
match buf[0] {
|
||||
0xFB => 1,
|
||||
0xFB => (1, None),
|
||||
|
||||
0xFC => {
|
||||
let len_size = 1 + 2;
|
||||
let len = LittleEndian::read_u16(&buf[1..]);
|
||||
|
||||
len_size + len as usize
|
||||
(len_size, Some(len as usize))
|
||||
}
|
||||
|
||||
0xFD => {
|
||||
let len_size = 1 + 3;
|
||||
let len = LittleEndian::read_u24(&buf[1..]);
|
||||
|
||||
len_size + len as usize
|
||||
(len_size, Some(len as usize))
|
||||
}
|
||||
|
||||
0xFE => {
|
||||
let len_size = 1 + 8;
|
||||
let len = LittleEndian::read_u64(&buf[1..]);
|
||||
|
||||
len_size + len as usize
|
||||
(len_size, Some(len as usize))
|
||||
}
|
||||
|
||||
value => 1 + value as usize,
|
||||
len => (1, Some(len as usize)),
|
||||
}
|
||||
}
|
||||
|
||||
impl Row {
|
||||
pub fn decode(mut buf: &[u8], columns: &[TypeId], binary: bool) -> crate::Result<Self> {
|
||||
impl<'c> Row<'c> {
|
||||
pub fn read(
|
||||
mut buf: &'c [u8],
|
||||
columns: &[TypeId],
|
||||
values: &'c mut Vec<Option<Range<usize>>>,
|
||||
binary: bool,
|
||||
) -> crate::Result<Self> {
|
||||
let mut buffer = &*buf;
|
||||
|
||||
values.clear();
|
||||
values.reserve(columns.len());
|
||||
|
||||
if !binary {
|
||||
let buffer: Box<[u8]> = buf.into();
|
||||
let mut values = Vec::with_capacity(columns.len());
|
||||
let mut index = 0;
|
||||
|
||||
for column_idx in 0..columns.len() {
|
||||
let size = get_lenenc(&buf[index..]);
|
||||
let (len_size, size) = get_lenenc(&buf[index..]);
|
||||
|
||||
values.push(Some(index..(index + size)));
|
||||
if let Some(size) = size {
|
||||
values.push(Some((index + len_size)..(index + len_size + size)));
|
||||
} else {
|
||||
values.push(None);
|
||||
}
|
||||
|
||||
index += size;
|
||||
buf.advance(size);
|
||||
index += (len_size + size.unwrap_or_default());
|
||||
}
|
||||
|
||||
return Ok(Self {
|
||||
buffer,
|
||||
values: values.into_boxed_slice(),
|
||||
binary,
|
||||
values: &*values,
|
||||
binary: false,
|
||||
});
|
||||
}
|
||||
|
||||
@ -88,7 +99,6 @@ impl Row {
|
||||
buf.advance(null_len);
|
||||
|
||||
let buffer: Box<[u8]> = buf.into();
|
||||
let mut values = Vec::with_capacity(columns.len());
|
||||
let mut index = 0;
|
||||
|
||||
for column_idx in 0..columns.len() {
|
||||
@ -117,7 +127,11 @@ impl Row {
|
||||
| TypeId::LONG_BLOB
|
||||
| TypeId::CHAR
|
||||
| TypeId::TEXT
|
||||
| TypeId::VAR_CHAR => get_lenenc(&buffer[index..]),
|
||||
| TypeId::VAR_CHAR => {
|
||||
let (len_size, len) = get_lenenc(&buffer[index..]);
|
||||
|
||||
len_size + len.unwrap_or_default()
|
||||
}
|
||||
|
||||
id => {
|
||||
unimplemented!("encountered unknown field type id: {:?}", id);
|
||||
@ -130,174 +144,174 @@ impl Row {
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
values: values.into_boxed_slice(),
|
||||
buffer: buf,
|
||||
values: &*values,
|
||||
binary,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::super::column_count::ColumnCount;
|
||||
use super::super::column_def::ColumnDefinition;
|
||||
use super::super::eof::EofPacket;
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn null_bitmap_test() -> crate::Result<()> {
|
||||
let column_len = ColumnCount::decode(&[26])?;
|
||||
assert_eq!(column_len.columns, 26);
|
||||
|
||||
let types: Vec<TypeId> = vec![
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 2, 105, 100, 2, 105, 100, 12, 63, 0, 11, 0, 0,
|
||||
0, 3, 11, 66, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 50, 6, 102, 105,
|
||||
101, 108, 100, 50, 12, 224, 0, 120, 0, 0, 0, 253, 5, 64, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 51, 6, 102, 105,
|
||||
101, 108, 100, 51, 12, 224, 0, 252, 3, 0, 0, 253, 1, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 52, 6, 102, 105,
|
||||
101, 108, 100, 52, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 53, 6, 102, 105,
|
||||
101, 108, 100, 53, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 54, 6, 102, 105,
|
||||
101, 108, 100, 54, 12, 63, 0, 19, 0, 0, 0, 7, 128, 4, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 55, 6, 102, 105,
|
||||
101, 108, 100, 55, 12, 63, 0, 4, 0, 0, 0, 1, 1, 64, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 56, 6, 102, 105,
|
||||
101, 108, 100, 56, 12, 224, 0, 252, 255, 3, 0, 252, 16, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 57, 6, 102, 105,
|
||||
101, 108, 100, 57, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 48, 7, 102,
|
||||
105, 101, 108, 100, 49, 48, 12, 224, 0, 252, 3, 0, 0, 252, 16, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 49, 7, 102,
|
||||
105, 101, 108, 100, 49, 49, 12, 224, 0, 252, 3, 0, 0, 252, 16, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 50, 7, 102,
|
||||
105, 101, 108, 100, 49, 50, 12, 63, 0, 19, 0, 0, 0, 7, 129, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 51, 7, 102,
|
||||
105, 101, 108, 100, 49, 51, 12, 63, 0, 4, 0, 0, 0, 1, 0, 64, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 52, 7, 102,
|
||||
105, 101, 108, 100, 49, 52, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 53, 7, 102,
|
||||
105, 101, 108, 100, 49, 53, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 54, 7, 102,
|
||||
105, 101, 108, 100, 49, 54, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 55, 7, 102,
|
||||
105, 101, 108, 100, 49, 55, 12, 224, 0, 0, 1, 0, 0, 253, 0, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 56, 7, 102,
|
||||
105, 101, 108, 100, 49, 56, 12, 63, 0, 11, 0, 0, 0, 3, 1, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 57, 7, 102,
|
||||
105, 101, 108, 100, 49, 57, 12, 63, 0, 11, 0, 0, 0, 3, 1, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 48, 7, 102,
|
||||
105, 101, 108, 100, 50, 48, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 49, 7, 102,
|
||||
105, 101, 108, 100, 50, 49, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 50, 7, 102,
|
||||
105, 101, 108, 100, 50, 50, 12, 63, 0, 3, 0, 0, 0, 3, 0, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 51, 7, 102,
|
||||
105, 101, 108, 100, 50, 51, 12, 63, 0, 6, 0, 0, 0, 3, 1, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 52, 7, 102,
|
||||
105, 101, 108, 100, 50, 52, 12, 63, 0, 6, 0, 0, 0, 3, 1, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 53, 7, 102,
|
||||
105, 101, 108, 100, 50, 53, 12, 63, 0, 20, 0, 0, 0, 8, 1, 0, 0, 0, 0,
|
||||
])?,
|
||||
ColumnDefinition::decode(&[
|
||||
3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 54, 7, 102,
|
||||
105, 101, 108, 100, 50, 54, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0,
|
||||
])?,
|
||||
]
|
||||
.into_iter()
|
||||
.map(|def| def.type_id)
|
||||
.collect();
|
||||
|
||||
EofPacket::decode(&[254, 0, 0, 34, 0])?;
|
||||
|
||||
Row::decode(
|
||||
&[
|
||||
0, 64, 90, 229, 0, 4, 0, 0, 0, 4, 114, 117, 115, 116, 0, 0, 7, 228, 7, 1, 16, 8,
|
||||
10, 17, 0, 0, 4, 208, 7, 1, 1, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
],
|
||||
&types,
|
||||
true,
|
||||
)?;
|
||||
|
||||
EofPacket::decode(&[254, 0, 0, 34, 0])?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
// #[cfg(test)]
|
||||
// mod test {
|
||||
// use super::super::column_count::ColumnCount;
|
||||
// use super::super::column_def::ColumnDefinition;
|
||||
// use super::super::eof::EofPacket;
|
||||
// use super::*;
|
||||
//
|
||||
// #[test]
|
||||
// fn null_bitmap_test() -> crate::Result<()> {
|
||||
// let column_len = ColumnCount::decode(&[26])?;
|
||||
// assert_eq!(column_len.columns, 26);
|
||||
//
|
||||
// let types: Vec<TypeId> = vec![
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 2, 105, 100, 2, 105, 100, 12, 63, 0, 11, 0, 0,
|
||||
// 0, 3, 11, 66, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 50, 6, 102, 105,
|
||||
// 101, 108, 100, 50, 12, 224, 0, 120, 0, 0, 0, 253, 5, 64, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 51, 6, 102, 105,
|
||||
// 101, 108, 100, 51, 12, 224, 0, 252, 3, 0, 0, 253, 1, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 52, 6, 102, 105,
|
||||
// 101, 108, 100, 52, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 53, 6, 102, 105,
|
||||
// 101, 108, 100, 53, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 54, 6, 102, 105,
|
||||
// 101, 108, 100, 54, 12, 63, 0, 19, 0, 0, 0, 7, 128, 4, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 55, 6, 102, 105,
|
||||
// 101, 108, 100, 55, 12, 63, 0, 4, 0, 0, 0, 1, 1, 64, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 56, 6, 102, 105,
|
||||
// 101, 108, 100, 56, 12, 224, 0, 252, 255, 3, 0, 252, 16, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 57, 6, 102, 105,
|
||||
// 101, 108, 100, 57, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 48, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 48, 12, 224, 0, 252, 3, 0, 0, 252, 16, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 49, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 49, 12, 224, 0, 252, 3, 0, 0, 252, 16, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 50, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 50, 12, 63, 0, 19, 0, 0, 0, 7, 129, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 51, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 51, 12, 63, 0, 4, 0, 0, 0, 1, 0, 64, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 52, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 52, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 53, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 53, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 54, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 54, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 55, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 55, 12, 224, 0, 0, 1, 0, 0, 253, 0, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 56, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 56, 12, 63, 0, 11, 0, 0, 0, 3, 1, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 57, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 57, 12, 63, 0, 11, 0, 0, 0, 3, 1, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 48, 7, 102,
|
||||
// 105, 101, 108, 100, 50, 48, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 49, 7, 102,
|
||||
// 105, 101, 108, 100, 50, 49, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 50, 7, 102,
|
||||
// 105, 101, 108, 100, 50, 50, 12, 63, 0, 3, 0, 0, 0, 3, 0, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 51, 7, 102,
|
||||
// 105, 101, 108, 100, 50, 51, 12, 63, 0, 6, 0, 0, 0, 3, 1, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 52, 7, 102,
|
||||
// 105, 101, 108, 100, 50, 52, 12, 63, 0, 6, 0, 0, 0, 3, 1, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 53, 7, 102,
|
||||
// 105, 101, 108, 100, 50, 53, 12, 63, 0, 20, 0, 0, 0, 8, 1, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 54, 7, 102,
|
||||
// 105, 101, 108, 100, 50, 54, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ]
|
||||
// .into_iter()
|
||||
// .map(|def| def.type_id)
|
||||
// .collect();
|
||||
//
|
||||
// EofPacket::decode(&[254, 0, 0, 34, 0])?;
|
||||
//
|
||||
// Row::read(
|
||||
// &[
|
||||
// 0, 64, 90, 229, 0, 4, 0, 0, 0, 4, 114, 117, 115, 116, 0, 0, 7, 228, 7, 1, 16, 8,
|
||||
// 10, 17, 0, 0, 4, 208, 7, 1, 1, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
// 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
// ],
|
||||
// &types,
|
||||
// true,
|
||||
// )?;
|
||||
//
|
||||
// EofPacket::decode(&[254, 0, 0, 34, 0])?;
|
||||
// Ok(())
|
||||
// }
|
||||
// }
|
||||
|
@ -1,59 +1,60 @@
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryFrom;
|
||||
use std::str::{from_utf8, Utf8Error};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::decode::Decode;
|
||||
use crate::error::UnexpectedNullError;
|
||||
use crate::mysql::io::BufExt;
|
||||
use crate::mysql::protocol;
|
||||
use crate::mysql::MySql;
|
||||
use crate::row::{Row, RowIndex};
|
||||
use crate::row::{ColumnIndex, Row};
|
||||
use crate::types::Type;
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
pub struct MySqlRow {
|
||||
pub(super) row: protocol::Row,
|
||||
pub(super) columns: Arc<HashMap<Box<str>, usize>>,
|
||||
#[derive(Debug)]
|
||||
pub enum MySqlValue<'c> {
|
||||
Binary(&'c [u8]),
|
||||
Text(&'c [u8]),
|
||||
}
|
||||
|
||||
impl Row for MySqlRow {
|
||||
impl<'c> TryFrom<Option<MySqlValue<'c>>> for MySqlValue<'c> {
|
||||
type Error = crate::Error;
|
||||
|
||||
#[inline]
|
||||
fn try_from(value: Option<MySqlValue<'c>>) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
Some(value) => Ok(value),
|
||||
None => Err(crate::Error::decode(UnexpectedNullError)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MySqlRow<'c> {
|
||||
pub(super) row: protocol::Row<'c>,
|
||||
pub(super) columns: Arc<HashMap<Box<str>, u16>>,
|
||||
pub(super) binary: bool,
|
||||
}
|
||||
|
||||
impl<'c> Row<'c> for MySqlRow<'c> {
|
||||
type Database = MySql;
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.row.len()
|
||||
}
|
||||
|
||||
fn get<T, I>(&self, index: I) -> T
|
||||
fn get_raw<'r, I>(&'r self, index: I) -> crate::Result<Option<MySqlValue<'c>>>
|
||||
where
|
||||
Self::Database: Type<T>,
|
||||
I: RowIndex<Self>,
|
||||
T: Decode<Self::Database>,
|
||||
I: ColumnIndex<Self::Database>,
|
||||
{
|
||||
index.get(self).unwrap()
|
||||
let index = index.resolve(self)?;
|
||||
|
||||
Ok(self.row.get(index).map(|mut buf| {
|
||||
if self.binary {
|
||||
MySqlValue::Binary(buf)
|
||||
} else {
|
||||
MySqlValue::Text(buf)
|
||||
}
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl RowIndex<MySqlRow> for usize {
|
||||
fn get<T>(&self, row: &MySqlRow) -> crate::Result<T>
|
||||
where
|
||||
<MySqlRow as Row>::Database: Type<T>,
|
||||
T: Decode<<MySqlRow as Row>::Database>,
|
||||
{
|
||||
Ok(Decode::decode_nullable(row.row.get(*self))?)
|
||||
}
|
||||
}
|
||||
|
||||
impl RowIndex<MySqlRow> for &'_ str {
|
||||
fn get<T>(&self, row: &MySqlRow) -> crate::Result<T>
|
||||
where
|
||||
<MySqlRow as Row>::Database: Type<T>,
|
||||
T: Decode<<MySqlRow as Row>::Database>,
|
||||
{
|
||||
let index = row
|
||||
.columns
|
||||
.get(*self)
|
||||
.ok_or_else(|| crate::Error::ColumnNotFound((*self).into()))?;
|
||||
|
||||
let value = Decode::decode_nullable(row.row.get(*index))?;
|
||||
|
||||
Ok(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl_from_row_for_row!(MySqlRow);
|
||||
|
@ -6,7 +6,7 @@ use rand::{thread_rng, Rng};
|
||||
// For the love of crypto, please delete as much of this as possible and use the RSA crate
|
||||
// directly when that PR is merged
|
||||
|
||||
pub fn encrypt<D: Digest>(key: &[u8], message: &[u8]) -> crate::Result<Box<[u8]>> {
|
||||
pub fn encrypt<D: Digest>(key: &[u8], message: &[u8]) -> crate::Result<Vec<u8>> {
|
||||
let key = std::str::from_utf8(key).map_err(|_err| {
|
||||
// TODO(@abonander): protocol_err doesn't like referring to [err]
|
||||
protocol_err!("unexpected error decoding what should be UTF-8")
|
||||
@ -14,7 +14,7 @@ pub fn encrypt<D: Digest>(key: &[u8], message: &[u8]) -> crate::Result<Box<[u8]>
|
||||
|
||||
let key = parse(key)?;
|
||||
|
||||
Ok(oaep_encrypt::<_, D>(&mut thread_rng(), &key, message)?.into_boxed_slice())
|
||||
Ok(oaep_encrypt::<_, D>(&mut thread_rng(), &key, message)?)
|
||||
}
|
||||
|
||||
// https://github.com/RustCrypto/RSA/blob/9f1464c43831d422d9903574aad6ab072db9f2b0/src/internals.rs#L12
|
||||
|
193
sqlx-core/src/mysql/stream.rs
Normal file
193
sqlx-core/src/mysql/stream.rs
Normal file
@ -0,0 +1,193 @@
|
||||
use std::net::Shutdown;
|
||||
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
|
||||
use crate::io::{Buf, BufMut, BufStream, MaybeTlsStream};
|
||||
use crate::mysql::protocol::{Capabilities, Decode, Encode, EofPacket, ErrPacket, OkPacket};
|
||||
use crate::mysql::MySqlError;
|
||||
use crate::url::Url;
|
||||
|
||||
// Size before a packet is split
|
||||
const MAX_PACKET_SIZE: u32 = 1024;
|
||||
|
||||
pub(crate) struct MySqlStream {
|
||||
pub(super) stream: BufStream<MaybeTlsStream>,
|
||||
|
||||
// Active capabilities
|
||||
pub(super) capabilities: Capabilities,
|
||||
|
||||
// Packets in a command sequence have an incrementing sequence number
|
||||
// This number must be 0 at the start of each command
|
||||
pub(super) seq_no: u8,
|
||||
|
||||
// Packets are buffered into a second buffer from the stream
|
||||
// as we may have compressed or split packets to figure out before
|
||||
// decoding
|
||||
packet_buf: Vec<u8>,
|
||||
packet_len: usize,
|
||||
}
|
||||
|
||||
impl MySqlStream {
|
||||
pub(super) async fn new(url: &Url) -> crate::Result<Self> {
|
||||
let stream = MaybeTlsStream::connect(&url, 5432).await?;
|
||||
|
||||
let mut capabilities = Capabilities::PROTOCOL_41
|
||||
| Capabilities::IGNORE_SPACE
|
||||
| Capabilities::DEPRECATE_EOF
|
||||
| Capabilities::FOUND_ROWS
|
||||
| Capabilities::TRANSACTIONS
|
||||
| Capabilities::SECURE_CONNECTION
|
||||
| Capabilities::PLUGIN_AUTH_LENENC_DATA
|
||||
| Capabilities::MULTI_STATEMENTS
|
||||
| Capabilities::MULTI_RESULTS
|
||||
| Capabilities::PLUGIN_AUTH;
|
||||
|
||||
if url.database().is_some() {
|
||||
capabilities |= Capabilities::CONNECT_WITH_DB;
|
||||
}
|
||||
|
||||
if cfg!(feature = "tls") {
|
||||
capabilities |= Capabilities::SSL;
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
capabilities,
|
||||
stream: BufStream::new(stream),
|
||||
packet_buf: Vec::with_capacity(MAX_PACKET_SIZE as usize),
|
||||
packet_len: 0,
|
||||
seq_no: 0,
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn is_tls(&self) -> bool {
|
||||
self.stream.is_tls()
|
||||
}
|
||||
|
||||
pub(super) fn shutdown(&self) -> crate::Result<()> {
|
||||
Ok(self.stream.shutdown(Shutdown::Both)?)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(super) async fn send<T>(&mut self, packet: T, initial: bool) -> crate::Result<()>
|
||||
where
|
||||
T: Encode + std::fmt::Debug,
|
||||
{
|
||||
if initial {
|
||||
self.seq_no = 0;
|
||||
}
|
||||
|
||||
self.write(packet);
|
||||
self.flush().await
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(super) async fn flush(&mut self) -> crate::Result<()> {
|
||||
Ok(self.stream.flush().await?)
|
||||
}
|
||||
|
||||
/// Write the packet to the buffered stream ( do not send to the server )
|
||||
pub(super) fn write<T>(&mut self, packet: T)
|
||||
where
|
||||
T: Encode,
|
||||
{
|
||||
let buf = self.stream.buffer_mut();
|
||||
|
||||
// Allocate room for the header that we write after the packet;
|
||||
// so, we can get an accurate and cheap measure of packet length
|
||||
|
||||
let header_offset = buf.len();
|
||||
buf.advance(4);
|
||||
|
||||
packet.encode(buf, self.capabilities);
|
||||
|
||||
// Determine length of encoded packet
|
||||
// and write to allocated header
|
||||
|
||||
let len = buf.len() - header_offset - 4;
|
||||
let mut header = &mut buf[header_offset..];
|
||||
|
||||
LittleEndian::write_u32(&mut header, len as u32);
|
||||
|
||||
// Take the last sequence number received, if any, and increment by 1
|
||||
// If there was no sequence number, we only increment if we split packets
|
||||
header[3] = self.seq_no;
|
||||
self.seq_no = self.seq_no.wrapping_add(1);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(super) async fn receive(&mut self) -> crate::Result<&[u8]> {
|
||||
self.read().await?;
|
||||
|
||||
Ok(self.packet())
|
||||
}
|
||||
|
||||
pub(super) async fn read(&mut self) -> crate::Result<()> {
|
||||
self.packet_buf.clear();
|
||||
self.packet_len = 0;
|
||||
|
||||
// Read the packet header which contains the length and the sequence number
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html
|
||||
// https://mariadb.com/kb/en/library/0-packet/#standard-packet
|
||||
let mut header = self.stream.peek(4_usize).await?;
|
||||
|
||||
self.packet_len = header.get_uint::<LittleEndian>(3)? as usize;
|
||||
self.seq_no = header.get_u8()?.wrapping_add(1);
|
||||
|
||||
self.stream.consume(4);
|
||||
|
||||
// Read the packet body and copy it into our internal buf
|
||||
// We must have a separate buffer around the stream as we can't operate directly
|
||||
// on bytes returned from the stream. We have various kinds of payload manipulation
|
||||
// that must be handled before decoding.
|
||||
let payload = self.stream.peek(self.packet_len).await?;
|
||||
|
||||
self.packet_buf.reserve(payload.len());
|
||||
self.packet_buf.extend_from_slice(payload);
|
||||
|
||||
self.stream.consume(self.packet_len);
|
||||
|
||||
// TODO: Implement packet compression
|
||||
// TODO: Implement packet joining
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns a reference to the most recently received packet data.
|
||||
/// A call to `read` invalidates this buffer.
|
||||
#[inline]
|
||||
pub(super) fn packet(&self) -> &[u8] {
|
||||
&self.packet_buf[..self.packet_len]
|
||||
}
|
||||
}
|
||||
|
||||
impl MySqlStream {
|
||||
pub(crate) async fn maybe_receive_eof(&mut self) -> crate::Result<()> {
|
||||
if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
|
||||
let _eof = EofPacket::decode(self.receive().await?)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn maybe_handle_eof(&mut self) -> crate::Result<bool> {
|
||||
if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
|
||||
let _eof = EofPacket::decode(self.packet())?;
|
||||
|
||||
Ok(true)
|
||||
} else {
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn handle_unexpected<T>(&mut self) -> crate::Result<T> {
|
||||
Err(protocol_err!("unexpected packet identifier 0x{:X?}", self.packet()[0]).into())
|
||||
}
|
||||
|
||||
pub(crate) fn handle_err<T>(&mut self) -> crate::Result<T> {
|
||||
Err(MySqlError(ErrPacket::decode(self.packet(), self.capabilities)?).into())
|
||||
}
|
||||
|
||||
pub(crate) fn handle_ok(&mut self) -> crate::Result<OkPacket> {
|
||||
OkPacket::decode(self.packet())
|
||||
}
|
||||
}
|
115
sqlx-core/src/mysql/tls.rs
Normal file
115
sqlx-core/src/mysql/tls.rs
Normal file
@ -0,0 +1,115 @@
|
||||
use std::borrow::Cow;
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::mysql::protocol::{Capabilities, SslRequest};
|
||||
use crate::mysql::stream::MySqlStream;
|
||||
use crate::url::Url;
|
||||
|
||||
pub(super) async fn upgrade_if_needed(stream: &mut MySqlStream, url: &Url) -> crate::Result<()> {
|
||||
let ca_file = url.param("ssl-ca");
|
||||
|
||||
let ssl_mode = url.param("ssl-mode");
|
||||
|
||||
let supports_tls = stream.capabilities.contains(Capabilities::SSL);
|
||||
|
||||
// https://dev.mysql.com/doc/refman/5.7/en/connection-options.html#option_general_ssl-mode
|
||||
match ssl_mode.as_deref() {
|
||||
Some("DISABLED") => {}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
Some("PREFERRED") | None if !supports_tls => {}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
Some("PREFERRED") => {
|
||||
if let Err(error) = try_upgrade(stream, &url, None, true).await {
|
||||
// TLS upgrade failed; fall back to a normal connection
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
Some(mode @ "REQUIRED") | Some(mode @ "VERIFY_CA") | Some(mode @ "VERIFY_IDENTITY")
|
||||
if !supports_tls =>
|
||||
{
|
||||
return Err(tls_err!("server does not support TLS").into());
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
Some(mode @ "VERIFY_CA") | Some(mode @ "VERIFY_IDENTITY") if ca_file.is_none() => {
|
||||
return Err(
|
||||
tls_err!("`ssl-mode` of {:?} requires `ssl-ca` to be set", ssl_mode).into(),
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
Some(mode @ "REQUIRED") | Some(mode @ "VERIFY_CA") | Some(mode @ "VERIFY_IDENTITY") => {
|
||||
try_upgrade(
|
||||
stream,
|
||||
url,
|
||||
// false for both verify-ca and verify-full
|
||||
ca_file.as_deref(),
|
||||
// false for only verify-full
|
||||
mode != "VERIFY_IDENTITY",
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "tls"))]
|
||||
None => {
|
||||
// The user neither explicitly enabled TLS in the connection string
|
||||
// nor did they turn the `tls` feature on
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "tls"))]
|
||||
Some(mode @ "PREFERRED")
|
||||
| Some(mode @ "REQUIRED")
|
||||
| Some(mode @ "VERIFY_CA")
|
||||
| Some(mode @ "VERIFY_IDENTITY") => {
|
||||
return Err(tls_err!(
|
||||
"ssl-mode {:?} unsupported; SQLx was compiled without `tls` feature",
|
||||
mode
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
Some(mode) => {
|
||||
return Err(tls_err!("unknown `ssl-mode` value: {:?}", mode).into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
async fn try_upgrade(
|
||||
stream: &mut MySqlStream,
|
||||
url: &Url,
|
||||
ca_file: Option<&str>,
|
||||
accept_invalid_hostnames: bool,
|
||||
) -> crate::Result<()> {
|
||||
use crate::runtime::fs;
|
||||
|
||||
use async_native_tls::{Certificate, TlsConnector};
|
||||
|
||||
let mut connector = TlsConnector::new()
|
||||
.danger_accept_invalid_certs(ca_file.is_none())
|
||||
.danger_accept_invalid_hostnames(accept_invalid_hostnames);
|
||||
|
||||
if let Some(ca_file) = ca_file {
|
||||
let root_cert = fs::read(ca_file).await?;
|
||||
|
||||
connector = connector.add_root_certificate(Certificate::from_pem(&root_cert)?);
|
||||
}
|
||||
|
||||
// send upgrade request and then immediately try TLS handshake
|
||||
stream
|
||||
.send(
|
||||
SslRequest {
|
||||
client_collation: COLLATE_UTF8MB4_UNICODE_CI,
|
||||
max_packet_size: MAX_PACKET_SIZE,
|
||||
},
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
|
||||
stream.stream.upgrade(url, connector).await
|
||||
}
|
@ -1,11 +1,14 @@
|
||||
use crate::decode::{Decode, DecodeError};
|
||||
use std::convert::TryInto;
|
||||
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::Encode;
|
||||
use crate::error::UnexpectedNullError;
|
||||
use crate::mysql::protocol::TypeId;
|
||||
use crate::mysql::types::MySqlTypeInfo;
|
||||
use crate::mysql::MySql;
|
||||
use crate::mysql::{MySql, MySqlValue};
|
||||
use crate::types::Type;
|
||||
|
||||
impl Type<bool> for MySql {
|
||||
impl Type<MySql> for bool {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::TINY_INT)
|
||||
}
|
||||
@ -17,13 +20,18 @@ impl Encode<MySql> for bool {
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<MySql> for bool {
|
||||
fn decode(buf: &[u8]) -> Result<Self, DecodeError> {
|
||||
match buf.len() {
|
||||
0 => Err(DecodeError::Message(Box::new(
|
||||
"Expected minimum 1 byte but received none.",
|
||||
))),
|
||||
_ => Ok(buf[0] != 0),
|
||||
impl<'de> Decode<'de, MySql> for bool {
|
||||
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
match value.try_into()? {
|
||||
MySqlValue::Binary(buf) => Ok(buf.get(0).map(|&b| b != 0).unwrap_or_default()),
|
||||
|
||||
MySqlValue::Text(b"0") => Ok(false),
|
||||
|
||||
MySqlValue::Text(b"1") => Ok(true),
|
||||
|
||||
MySqlValue::Text(s) => Err(crate::Error::Decode(
|
||||
format!("unexpected value {:?} for boolean", s).into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,14 +1,16 @@
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
use crate::decode::{Decode, DecodeError};
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::Encode;
|
||||
use crate::error::UnexpectedNullError;
|
||||
use crate::mysql::io::{BufExt, BufMutExt};
|
||||
use crate::mysql::protocol::TypeId;
|
||||
use crate::mysql::types::MySqlTypeInfo;
|
||||
use crate::mysql::MySql;
|
||||
use crate::mysql::{MySql, MySqlValue};
|
||||
use crate::types::Type;
|
||||
use std::convert::TryInto;
|
||||
|
||||
impl Type<[u8]> for MySql {
|
||||
impl Type<MySql> for [u8] {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo {
|
||||
id: TypeId::TEXT,
|
||||
@ -19,9 +21,9 @@ impl Type<[u8]> for MySql {
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<Vec<u8>> for MySql {
|
||||
impl Type<MySql> for Vec<u8> {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
<Self as Type<[u8]>>::type_info()
|
||||
<[u8] as Type<MySql>>::type_info()
|
||||
}
|
||||
}
|
||||
|
||||
@ -37,11 +39,36 @@ impl Encode<MySql> for Vec<u8> {
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<MySql> for Vec<u8> {
|
||||
fn decode(mut buf: &[u8]) -> Result<Self, DecodeError> {
|
||||
Ok(buf
|
||||
.get_bytes_lenenc::<LittleEndian>()?
|
||||
.unwrap_or_default()
|
||||
.to_vec())
|
||||
impl<'de> Decode<'de, MySql> for Vec<u8> {
|
||||
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
match value.try_into()? {
|
||||
MySqlValue::Binary(mut buf) => {
|
||||
let len = buf
|
||||
.get_uint_lenenc::<LittleEndian>()
|
||||
.map_err(crate::Error::decode)?
|
||||
.unwrap_or_default();
|
||||
|
||||
Ok((&buf[..(len as usize)]).to_vec())
|
||||
}
|
||||
|
||||
MySqlValue::Text(s) => Ok(s.to_vec()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for &'de [u8] {
|
||||
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
match value.try_into()? {
|
||||
MySqlValue::Binary(mut buf) => {
|
||||
let len = buf
|
||||
.get_uint_lenenc::<LittleEndian>()
|
||||
.map_err(crate::Error::decode)?
|
||||
.unwrap_or_default();
|
||||
|
||||
Ok(&buf[..(len as usize)])
|
||||
}
|
||||
|
||||
MySqlValue::Text(s) => Ok(s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,17 +1,19 @@
|
||||
use std::convert::TryFrom;
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc};
|
||||
|
||||
use crate::decode::{Decode, DecodeError};
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::Encode;
|
||||
use crate::io::{Buf, BufMut};
|
||||
use crate::mysql::protocol::TypeId;
|
||||
use crate::mysql::types::MySqlTypeInfo;
|
||||
use crate::mysql::MySql;
|
||||
use crate::mysql::{MySql, MySqlValue};
|
||||
use crate::types::Type;
|
||||
use crate::Error;
|
||||
use bitflags::_core::str::from_utf8;
|
||||
|
||||
impl Type<DateTime<Utc>> for MySql {
|
||||
impl Type<MySql> for DateTime<Utc> {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::TIMESTAMP)
|
||||
}
|
||||
@ -23,15 +25,15 @@ impl Encode<MySql> for DateTime<Utc> {
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<MySql> for DateTime<Utc> {
|
||||
fn decode(buf: &[u8]) -> Result<Self, DecodeError> {
|
||||
let naive: NaiveDateTime = Decode::<MySql>::decode(buf)?;
|
||||
impl<'de> Decode<'de, MySql> for DateTime<Utc> {
|
||||
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
let naive: NaiveDateTime = Decode::<MySql>::decode(value)?;
|
||||
|
||||
Ok(DateTime::from_utc(naive, Utc))
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<NaiveTime> for MySql {
|
||||
impl Type<MySql> for NaiveTime {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::TIME)
|
||||
}
|
||||
@ -63,24 +65,33 @@ impl Encode<MySql> for NaiveTime {
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<MySql> for NaiveTime {
|
||||
fn decode(mut buf: &[u8]) -> Result<Self, DecodeError> {
|
||||
// data length, expecting 8 or 12 (fractional seconds)
|
||||
let len = buf.get_u8()?;
|
||||
impl<'de> Decode<'de, MySql> for NaiveTime {
|
||||
fn decode(buf: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
match buf.try_into()? {
|
||||
MySqlValue::Binary(mut buf) => {
|
||||
// data length, expecting 8 or 12 (fractional seconds)
|
||||
let len = buf.get_u8()?;
|
||||
|
||||
// is negative : int<1>
|
||||
let is_negative = buf.get_u8()?;
|
||||
assert_eq!(is_negative, 0, "Negative dates/times are not supported");
|
||||
// is negative : int<1>
|
||||
let is_negative = buf.get_u8()?;
|
||||
assert_eq!(is_negative, 0, "Negative dates/times are not supported");
|
||||
|
||||
// "date on 4 bytes little-endian format" (?)
|
||||
// https://mariadb.com/kb/en/resultset-row/#timestamp-binary-encoding
|
||||
buf.advance(4);
|
||||
// "date on 4 bytes little-endian format" (?)
|
||||
// https://mariadb.com/kb/en/resultset-row/#timestamp-binary-encoding
|
||||
buf.advance(4);
|
||||
|
||||
decode_time(len - 5, buf)
|
||||
decode_time(len - 5, buf)
|
||||
}
|
||||
|
||||
MySqlValue::Text(buf) => {
|
||||
let s = from_utf8(buf).map_err(Error::decode)?;
|
||||
NaiveTime::parse_from_str(s, "%H:%M:%S%.f").map_err(Error::decode)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<NaiveDate> for MySql {
|
||||
impl Type<MySql> for NaiveDate {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::DATE)
|
||||
}
|
||||
@ -98,13 +109,20 @@ impl Encode<MySql> for NaiveDate {
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<MySql> for NaiveDate {
|
||||
fn decode(buf: &[u8]) -> Result<Self, DecodeError> {
|
||||
Ok(decode_date(&buf[1..]))
|
||||
impl<'de> Decode<'de, MySql> for NaiveDate {
|
||||
fn decode(buf: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
match buf.try_into()? {
|
||||
MySqlValue::Binary(buf) => Ok(decode_date(&buf[1..])),
|
||||
|
||||
MySqlValue::Text(buf) => {
|
||||
let s = from_utf8(buf).map_err(Error::decode)?;
|
||||
NaiveDate::parse_from_str(s, "%Y-%m-%d").map_err(Error::decode)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<NaiveDateTime> for MySql {
|
||||
impl Type<MySql> for NaiveDateTime {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::DATETIME)
|
||||
}
|
||||
@ -144,18 +162,27 @@ impl Encode<MySql> for NaiveDateTime {
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<MySql> for NaiveDateTime {
|
||||
fn decode(buf: &[u8]) -> Result<Self, DecodeError> {
|
||||
let len = buf[0];
|
||||
let date = decode_date(&buf[1..]);
|
||||
impl<'de> Decode<'de, MySql> for NaiveDateTime {
|
||||
fn decode(buf: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
match buf.try_into()? {
|
||||
MySqlValue::Binary(buf) => {
|
||||
let len = buf[0];
|
||||
let date = decode_date(&buf[1..]);
|
||||
|
||||
let dt = if len > 4 {
|
||||
date.and_time(decode_time(len - 4, &buf[5..])?)
|
||||
} else {
|
||||
date.and_hms(0, 0, 0)
|
||||
};
|
||||
let dt = if len > 4 {
|
||||
date.and_time(decode_time(len - 4, &buf[5..])?)
|
||||
} else {
|
||||
date.and_hms(0, 0, 0)
|
||||
};
|
||||
|
||||
Ok(dt)
|
||||
Ok(dt)
|
||||
},
|
||||
|
||||
MySqlValue::Text(buf) => {
|
||||
let s = from_utf8(buf).map_err(Error::decode)?;
|
||||
NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f").map_err(Error::decode)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -187,7 +214,7 @@ fn encode_time(time: &NaiveTime, include_micros: bool, buf: &mut Vec<u8>) {
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_time(len: u8, mut buf: &[u8]) -> Result<NaiveTime, DecodeError> {
|
||||
fn decode_time(len: u8, mut buf: &[u8]) -> crate::Result<NaiveTime> {
|
||||
let hour = buf.get_u8()?;
|
||||
let minute = buf.get_u8()?;
|
||||
let seconds = buf.get_u8()?;
|
||||
|
@ -1,9 +1,16 @@
|
||||
use crate::decode::{Decode, DecodeError};
|
||||
use std::convert::TryInto;
|
||||
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::Encode;
|
||||
use crate::error::UnexpectedNullError;
|
||||
use crate::mysql::protocol::TypeId;
|
||||
use crate::mysql::types::MySqlTypeInfo;
|
||||
use crate::mysql::MySql;
|
||||
use crate::mysql::{MySql, MySqlValue};
|
||||
use crate::types::Type;
|
||||
use crate::Error;
|
||||
use std::str::from_utf8;
|
||||
|
||||
/// The equivalent MySQL type for `f32` is `FLOAT`.
|
||||
///
|
||||
@ -18,7 +25,7 @@ use crate::types::Type;
|
||||
/// // (This is expected behavior for floating points and happens both in Rust and in MySQL)
|
||||
/// assert_ne!(10.2f32 as f64, 10.2f64);
|
||||
/// ```
|
||||
impl Type<f32> for MySql {
|
||||
impl Type<MySql> for f32 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::FLOAT)
|
||||
}
|
||||
@ -30,9 +37,19 @@ impl Encode<MySql> for f32 {
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<MySql> for f32 {
|
||||
fn decode(buf: &[u8]) -> Result<Self, DecodeError> {
|
||||
Ok(f32::from_bits(<i32 as Decode<MySql>>::decode(buf)? as u32))
|
||||
impl<'de> Decode<'de, MySql> for f32 {
|
||||
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
match value.try_into()? {
|
||||
MySqlValue::Binary(mut buf) => buf
|
||||
.read_i32::<LittleEndian>()
|
||||
.map_err(crate::Error::decode)
|
||||
.map(|value| f32::from_bits(value as u32)),
|
||||
|
||||
MySqlValue::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -40,7 +57,7 @@ impl Decode<MySql> for f32 {
|
||||
///
|
||||
/// Note that `DOUBLE` is a floating-point type and cannot represent some fractional values
|
||||
/// exactly.
|
||||
impl Type<f64> for MySql {
|
||||
impl Type<MySql> for f64 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::DOUBLE)
|
||||
}
|
||||
@ -52,8 +69,18 @@ impl Encode<MySql> for f64 {
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<MySql> for f64 {
|
||||
fn decode(buf: &[u8]) -> Result<Self, DecodeError> {
|
||||
Ok(f64::from_bits(<i64 as Decode<MySql>>::decode(buf)? as u64))
|
||||
impl<'de> Decode<'de, MySql> for f64 {
|
||||
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
match value.try_into()? {
|
||||
MySqlValue::Binary(mut buf) => buf
|
||||
.read_i64::<LittleEndian>()
|
||||
.map_err(crate::Error::decode)
|
||||
.map(|value| f64::from_bits(value as u64)),
|
||||
|
||||
MySqlValue::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,14 +1,18 @@
|
||||
use byteorder::LittleEndian;
|
||||
use std::convert::TryInto;
|
||||
use std::str::from_utf8;
|
||||
|
||||
use crate::decode::{Decode, DecodeError};
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::Encode;
|
||||
use crate::io::{Buf, BufMut};
|
||||
use crate::error::UnexpectedNullError;
|
||||
use crate::mysql::protocol::TypeId;
|
||||
use crate::mysql::types::MySqlTypeInfo;
|
||||
use crate::mysql::MySql;
|
||||
use crate::mysql::{MySql, MySqlValue};
|
||||
use crate::types::Type;
|
||||
use crate::Error;
|
||||
|
||||
impl Type<i8> for MySql {
|
||||
impl Type<MySql> for i8 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::TINY_INT)
|
||||
}
|
||||
@ -16,17 +20,24 @@ impl Type<i8> for MySql {
|
||||
|
||||
impl Encode<MySql> for i8 {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
buf.push(*self as u8);
|
||||
buf.write_i8(*self);
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<MySql> for i8 {
|
||||
fn decode(buf: &[u8]) -> Result<Self, DecodeError> {
|
||||
Ok(buf[0] as i8)
|
||||
impl<'de> Decode<'de, MySql> for i8 {
|
||||
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
match value.try_into()? {
|
||||
MySqlValue::Binary(mut buf) => buf.read_i8().map_err(Into::into),
|
||||
|
||||
MySqlValue::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<i16> for MySql {
|
||||
impl Type<MySql> for i16 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::SMALL_INT)
|
||||
}
|
||||
@ -34,17 +45,24 @@ impl Type<i16> for MySql {
|
||||
|
||||
impl Encode<MySql> for i16 {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
buf.put_i16::<LittleEndian>(*self);
|
||||
buf.write_i16::<LittleEndian>(*self);
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<MySql> for i16 {
|
||||
fn decode(mut buf: &[u8]) -> Result<Self, DecodeError> {
|
||||
buf.get_i16::<LittleEndian>().map_err(Into::into)
|
||||
impl<'de> Decode<'de, MySql> for i16 {
|
||||
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
match value.try_into()? {
|
||||
MySqlValue::Binary(mut buf) => buf.read_i16::<LittleEndian>().map_err(Into::into),
|
||||
|
||||
MySqlValue::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<i32> for MySql {
|
||||
impl Type<MySql> for i32 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::INT)
|
||||
}
|
||||
@ -52,17 +70,24 @@ impl Type<i32> for MySql {
|
||||
|
||||
impl Encode<MySql> for i32 {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
buf.put_i32::<LittleEndian>(*self);
|
||||
buf.write_i32::<LittleEndian>(*self);
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<MySql> for i32 {
|
||||
fn decode(mut buf: &[u8]) -> Result<Self, DecodeError> {
|
||||
buf.get_i32::<LittleEndian>().map_err(Into::into)
|
||||
impl<'de> Decode<'de, MySql> for i32 {
|
||||
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
match value.try_into()? {
|
||||
MySqlValue::Binary(mut buf) => buf.read_i32::<LittleEndian>().map_err(Into::into),
|
||||
|
||||
MySqlValue::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<i64> for MySql {
|
||||
impl Type<MySql> for i64 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::BIG_INT)
|
||||
}
|
||||
@ -70,14 +95,19 @@ impl Type<i64> for MySql {
|
||||
|
||||
impl Encode<MySql> for i64 {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
buf.put_u64::<LittleEndian>(*self as u64);
|
||||
buf.write_i64::<LittleEndian>(*self);
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<MySql> for i64 {
|
||||
fn decode(mut buf: &[u8]) -> Result<Self, DecodeError> {
|
||||
buf.get_u64::<LittleEndian>()
|
||||
.map_err(Into::into)
|
||||
.map(|val| val as i64)
|
||||
impl<'de> Decode<'de, MySql> for i64 {
|
||||
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
match value.try_into()? {
|
||||
MySqlValue::Binary(mut buf) => buf.read_i64::<LittleEndian>().map_err(Into::into),
|
||||
|
||||
MySqlValue::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -10,8 +10,10 @@ mod chrono;
|
||||
|
||||
use std::fmt::{self, Debug, Display};
|
||||
|
||||
use crate::decode::Decode;
|
||||
use crate::mysql::protocol::TypeId;
|
||||
use crate::mysql::protocol::{ColumnDefinition, FieldFlags};
|
||||
use crate::mysql::{MySql, MySqlValue};
|
||||
use crate::types::TypeInfo;
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
@ -103,3 +105,14 @@ impl TypeInfo for MySqlTypeInfo {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, T> Decode<'de, MySql> for Option<T>
|
||||
where
|
||||
T: Decode<'de, MySql>,
|
||||
{
|
||||
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
value
|
||||
.map(|value| <T as Decode<MySql>>::decode(Some(value)))
|
||||
.transpose()
|
||||
}
|
||||
}
|
||||
|
@ -2,15 +2,18 @@ use std::str;
|
||||
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
use crate::decode::{Decode, DecodeError};
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::Encode;
|
||||
use crate::error::UnexpectedNullError;
|
||||
use crate::mysql::io::{BufExt, BufMutExt};
|
||||
use crate::mysql::protocol::TypeId;
|
||||
use crate::mysql::types::MySqlTypeInfo;
|
||||
use crate::mysql::MySql;
|
||||
use crate::mysql::{MySql, MySqlValue};
|
||||
use crate::types::Type;
|
||||
use std::convert::TryInto;
|
||||
use std::str::from_utf8;
|
||||
|
||||
impl Type<str> for MySql {
|
||||
impl Type<MySql> for str {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo {
|
||||
id: TypeId::TEXT,
|
||||
@ -27,10 +30,9 @@ impl Encode<MySql> for str {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Do we need the [HasSqlType] for String
|
||||
impl Type<String> for MySql {
|
||||
impl Type<MySql> for String {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
<Self as Type<&str>>::type_info()
|
||||
<str as Type<MySql>>::type_info()
|
||||
}
|
||||
}
|
||||
|
||||
@ -40,11 +42,25 @@ impl Encode<MySql> for String {
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<MySql> for String {
|
||||
fn decode(mut buf: &[u8]) -> Result<Self, DecodeError> {
|
||||
Ok(buf
|
||||
.get_str_lenenc::<LittleEndian>()?
|
||||
.unwrap_or_default()
|
||||
.to_owned())
|
||||
impl<'de> Decode<'de, MySql> for &'de str {
|
||||
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
match value.try_into()? {
|
||||
MySqlValue::Binary(mut buf) => {
|
||||
let len = buf
|
||||
.get_uint_lenenc::<LittleEndian>()
|
||||
.map_err(crate::Error::decode)?
|
||||
.unwrap_or_default();
|
||||
|
||||
from_utf8(&buf[..(len as usize)]).map_err(crate::Error::decode)
|
||||
}
|
||||
|
||||
MySqlValue::Text(s) => from_utf8(s).map_err(crate::Error::decode),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for String {
|
||||
fn decode(buf: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
<&'de str>::decode(buf).map(ToOwned::to_owned)
|
||||
}
|
||||
}
|
||||
|
@ -1,14 +1,18 @@
|
||||
use byteorder::LittleEndian;
|
||||
use std::convert::TryInto;
|
||||
use std::str::from_utf8;
|
||||
|
||||
use crate::decode::{Decode, DecodeError};
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::Encode;
|
||||
use crate::io::{Buf, BufMut};
|
||||
use crate::error::UnexpectedNullError;
|
||||
use crate::mysql::protocol::TypeId;
|
||||
use crate::mysql::types::MySqlTypeInfo;
|
||||
use crate::mysql::MySql;
|
||||
use crate::mysql::{MySql, MySqlValue};
|
||||
use crate::types::Type;
|
||||
use crate::Error;
|
||||
|
||||
impl Type<u8> for MySql {
|
||||
impl Type<MySql> for u8 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::unsigned(TypeId::TINY_INT)
|
||||
}
|
||||
@ -16,17 +20,24 @@ impl Type<u8> for MySql {
|
||||
|
||||
impl Encode<MySql> for u8 {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
buf.push(*self);
|
||||
buf.write_u8(*self);
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<MySql> for u8 {
|
||||
fn decode(buf: &[u8]) -> Result<Self, DecodeError> {
|
||||
Ok(buf[0])
|
||||
impl<'de> Decode<'de, MySql> for u8 {
|
||||
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
match value.try_into()? {
|
||||
MySqlValue::Binary(mut buf) => buf.read_u8().map_err(Into::into),
|
||||
|
||||
MySqlValue::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<u16> for MySql {
|
||||
impl Type<MySql> for u16 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::unsigned(TypeId::SMALL_INT)
|
||||
}
|
||||
@ -34,17 +45,24 @@ impl Type<u16> for MySql {
|
||||
|
||||
impl Encode<MySql> for u16 {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
buf.put_u16::<LittleEndian>(*self);
|
||||
buf.write_u16::<LittleEndian>(*self);
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<MySql> for u16 {
|
||||
fn decode(mut buf: &[u8]) -> Result<Self, DecodeError> {
|
||||
buf.get_u16::<LittleEndian>().map_err(Into::into)
|
||||
impl<'de> Decode<'de, MySql> for u16 {
|
||||
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
match value.try_into()? {
|
||||
MySqlValue::Binary(mut buf) => buf.read_u16::<LittleEndian>().map_err(Into::into),
|
||||
|
||||
MySqlValue::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<u32> for MySql {
|
||||
impl Type<MySql> for u32 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::unsigned(TypeId::INT)
|
||||
}
|
||||
@ -52,17 +70,24 @@ impl Type<u32> for MySql {
|
||||
|
||||
impl Encode<MySql> for u32 {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
buf.put_u32::<LittleEndian>(*self);
|
||||
buf.write_u32::<LittleEndian>(*self);
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<MySql> for u32 {
|
||||
fn decode(mut buf: &[u8]) -> Result<Self, DecodeError> {
|
||||
buf.get_u32::<LittleEndian>().map_err(Into::into)
|
||||
impl<'de> Decode<'de, MySql> for u32 {
|
||||
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
match value.try_into()? {
|
||||
MySqlValue::Binary(mut buf) => buf.read_u32::<LittleEndian>().map_err(Into::into),
|
||||
|
||||
MySqlValue::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<u64> for MySql {
|
||||
impl Type<MySql> for u64 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::unsigned(TypeId::BIG_INT)
|
||||
}
|
||||
@ -70,12 +95,19 @@ impl Type<u64> for MySql {
|
||||
|
||||
impl Encode<MySql> for u64 {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
buf.put_u64::<LittleEndian>(*self);
|
||||
buf.write_u64::<LittleEndian>(*self);
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<MySql> for u64 {
|
||||
fn decode(mut buf: &[u8]) -> Result<Self, DecodeError> {
|
||||
buf.get_u64::<LittleEndian>().map_err(Into::into)
|
||||
impl<'de> Decode<'de, MySql> for u64 {
|
||||
fn decode(value: Option<MySqlValue<'de>>) -> crate::Result<Self> {
|
||||
match value.try_into()? {
|
||||
MySqlValue::Binary(mut buf) => buf.read_u64::<LittleEndian>().map_err(Into::into),
|
||||
|
||||
MySqlValue::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -28,3 +28,4 @@ make_query_as!(PgQueryAs, Postgres, PgRow);
|
||||
impl_map_row_for_row!(Postgres, PgRow);
|
||||
impl_column_index_for_row!(Postgres);
|
||||
impl_from_row_for_tuples!(Postgres, PgRow);
|
||||
impl_execute_for_query!(Postgres);
|
||||
|
@ -190,7 +190,7 @@ macro_rules! impl_column_index_for_row {
|
||||
row.columns
|
||||
.get(self)
|
||||
.ok_or_else(|| crate::Error::ColumnNotFound((*self).into()))
|
||||
.map(|&index| index)
|
||||
.map(|&index| index as usize)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -78,6 +78,25 @@ pub(crate) fn expand_derive_decode(input: DeriveInput) -> syn::Result<proc_macro
|
||||
));
|
||||
}
|
||||
|
||||
if cfg!(feature = "mysql") {
|
||||
let mut generics = generics.clone();
|
||||
generics.params.insert(0, parse_quote!('de));
|
||||
generics
|
||||
.make_where_clause()
|
||||
.predicates
|
||||
.push(parse_quote!(#ty: sqlx::decode::Decode<'de, sqlx::MySql>));
|
||||
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
|
||||
impls.push(quote!(
|
||||
impl #impl_generics sqlx::decode::Decode<'de, sqlx::MySql> for #ident #ty_generics #where_clause {
|
||||
fn decode(value: <sqlx::MySql as sqlx::HasRawValue<'de>>::RawValue) -> sqlx::Result<Self> {
|
||||
<#ty as sqlx::decode::Decode<'de, sqlx::MySql>>::decode(value).map(Self)
|
||||
}
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
// panic!("{}", q)
|
||||
Ok(quote!(#(#impls)*))
|
||||
}
|
||||
|
@ -65,14 +65,15 @@ macro_rules! test_prepared_type {
|
||||
let mut conn = sqlx_test::new::<$db>().await?;
|
||||
|
||||
$(
|
||||
let query = format!("SELECT {} = $1, $1 as _1", $text);
|
||||
let query = format!($crate::[< $db _query_for_test_prepared_type >]!(), $text);
|
||||
|
||||
let rec: (bool, $ty) = sqlx::query_as(&query)
|
||||
.bind($value)
|
||||
.bind($value)
|
||||
.fetch_one(&mut conn)
|
||||
.await?;
|
||||
|
||||
assert!(rec.0);
|
||||
assert!(rec.0, "value returned from server: {:?}", rec.1);
|
||||
assert!($value == rec.1);
|
||||
)+
|
||||
|
||||
@ -81,3 +82,17 @@ macro_rules! test_prepared_type {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! MySql_query_for_test_prepared_type {
|
||||
() => {
|
||||
"SELECT {} <=> ?, ? as _1"
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! Postgres_query_for_test_prepared_type {
|
||||
() => {
|
||||
"SELECT {} is not distinct form $1, $2 as _1"
|
||||
};
|
||||
}
|
||||
|
@ -70,4 +70,7 @@ pub mod prelude {
|
||||
|
||||
#[cfg(feature = "postgres")]
|
||||
pub use super::postgres::PgQueryAs;
|
||||
|
||||
#[cfg(feature = "mysql")]
|
||||
pub use super::mysql::MySqlQueryAs;
|
||||
}
|
||||
|
@ -58,3 +58,17 @@ where
|
||||
let decoded = Foo::decode(Some(sqlx::postgres::PgValue::Binary(&encoded))).unwrap();
|
||||
assert_eq!(example, decoded);
|
||||
}
|
||||
|
||||
#[cfg(feature = "mysql")]
|
||||
fn decode_with_db()
|
||||
where
|
||||
Foo: for<'de> Decode<'de, sqlx::MySql> + Encode<sqlx::MySql>,
|
||||
{
|
||||
let example = Foo(0x1122_3344);
|
||||
|
||||
let mut encoded = Vec::new();
|
||||
Encode::<sqlx::MySql>::encode(&example, &mut encoded);
|
||||
|
||||
let decoded = Foo::decode(Some(sqlx::mysql::MySqlValue::Binary(&encoded))).unwrap();
|
||||
assert_eq!(example, decoded);
|
||||
}
|
||||
|
@ -1,9 +1,10 @@
|
||||
use sqlx::MySqlConnection;
|
||||
use sqlx::MySql;
|
||||
use sqlx_test::new;
|
||||
|
||||
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn macro_select_from_cte() -> anyhow::Result<()> {
|
||||
let mut conn = connect().await?;
|
||||
let mut conn = new::<MySql>().await?;
|
||||
let account =
|
||||
sqlx::query!("select * from (select (1) as id, 'Herp Derpinson' as name) accounts")
|
||||
.fetch_one(&mut conn)
|
||||
@ -18,7 +19,7 @@ async fn macro_select_from_cte() -> anyhow::Result<()> {
|
||||
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn macro_select_from_cte_bind() -> anyhow::Result<()> {
|
||||
let mut conn = connect().await?;
|
||||
let mut conn = new::<MySql>().await?;
|
||||
let account = sqlx::query!(
|
||||
"select * from (select (1) as id, 'Herp Derpinson' as name) accounts where id = ?",
|
||||
1i32
|
||||
@ -41,7 +42,7 @@ struct RawAccount {
|
||||
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn test_query_as_raw() -> anyhow::Result<()> {
|
||||
let mut conn = connect().await?;
|
||||
let mut conn = new::<MySql>().await?;
|
||||
|
||||
let account = sqlx::query_as!(
|
||||
RawAccount,
|
||||
@ -57,11 +58,3 @@ async fn test_query_as_raw() -> anyhow::Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn url() -> anyhow::Result<String> {
|
||||
Ok(dotenv::var("DATABASE_URL")?)
|
||||
}
|
||||
|
||||
async fn connect() -> anyhow::Result<MySqlConnection> {
|
||||
Ok(MySqlConnection::open(url()?).await?)
|
||||
}
|
||||
|
56
tests/mysql-raw.rs
Normal file
56
tests/mysql-raw.rs
Normal file
@ -0,0 +1,56 @@
|
||||
//! Tests for the raw (unprepared) query API for MySql.
|
||||
|
||||
use sqlx::{Cursor, Executor, MySql, Row};
|
||||
use sqlx_test::new;
|
||||
|
||||
/// Test a simple select expression. This should return the row.
|
||||
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn test_select_expression() -> anyhow::Result<()> {
|
||||
let mut conn = new::<MySql>().await?;
|
||||
|
||||
let mut cursor = conn.fetch("SELECT 5");
|
||||
let row = cursor.next().await?.unwrap();
|
||||
|
||||
assert!(5i32 == row.get::<i32, _>(0)?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Test that we can interleave reads and writes to the database
|
||||
/// in one simple query. Using the `Cursor` API we should be
|
||||
/// able to fetch from both queries in sequence.
|
||||
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn test_multi_read_write() -> anyhow::Result<()> {
|
||||
let mut conn = new::<MySql>().await?;
|
||||
|
||||
let mut cursor = conn.fetch(
|
||||
"
|
||||
CREATE TEMPORARY TABLE messages (
|
||||
id BIGINT PRIMARY KEY AUTO_INCREMENT,
|
||||
text TEXT NOT NULL
|
||||
);
|
||||
|
||||
SELECT 'Hello World' as _1;
|
||||
|
||||
INSERT INTO messages (text) VALUES ('this is a test');
|
||||
|
||||
SELECT id, text FROM messages;
|
||||
",
|
||||
);
|
||||
|
||||
let row = cursor.next().await?.unwrap();
|
||||
|
||||
assert!("Hello World" == row.get::<&str, _>("_1")?);
|
||||
|
||||
let row = cursor.next().await?.unwrap();
|
||||
|
||||
let id: i64 = row.get("id")?;
|
||||
let text: &str = row.get("text")?;
|
||||
|
||||
assert_eq!(1_i64, id);
|
||||
assert_eq!("this is a test", text);
|
||||
|
||||
Ok(())
|
||||
}
|
@ -1,87 +0,0 @@
|
||||
use sqlx::types::chrono::{DateTime, NaiveDate, NaiveTime, Utc};
|
||||
use sqlx::{mysql::MySqlConnection, Connection, Row};
|
||||
|
||||
async fn connect() -> anyhow::Result<MySqlConnection> {
|
||||
Ok(MySqlConnection::open(dotenv::var("DATABASE_URL")?).await?)
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn mysql_chrono_date() -> anyhow::Result<()> {
|
||||
let mut conn = connect().await?;
|
||||
|
||||
let value = NaiveDate::from_ymd(2019, 1, 2);
|
||||
|
||||
let row = sqlx::query!(
|
||||
"SELECT (DATE '2019-01-02' = ?) as _1, CAST(? AS DATE) as _2",
|
||||
value,
|
||||
value
|
||||
)
|
||||
.fetch_one(&mut conn)
|
||||
.await?;
|
||||
|
||||
assert!(row._1 != 0);
|
||||
assert_eq!(value, row._2);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn mysql_chrono_date_time() -> anyhow::Result<()> {
|
||||
let mut conn = connect().await?;
|
||||
|
||||
let value = NaiveDate::from_ymd(2019, 1, 2).and_hms(5, 10, 20);
|
||||
|
||||
let row = sqlx::query("SELECT '2019-01-02 05:10:20' = ?, ?")
|
||||
.bind(&value)
|
||||
.bind(&value)
|
||||
.fetch_one(&mut conn)
|
||||
.await?;
|
||||
|
||||
assert!(row.get::<bool, _>(0));
|
||||
assert_eq!(value, row.get(1));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn mysql_chrono_time() -> anyhow::Result<()> {
|
||||
let mut conn = connect().await?;
|
||||
|
||||
let value = NaiveTime::from_hms_micro(5, 10, 20, 115100);
|
||||
|
||||
let row = sqlx::query("SELECT TIME '05:10:20.115100' = ?, TIME '05:10:20.115100'")
|
||||
.bind(&value)
|
||||
.fetch_one(&mut conn)
|
||||
.await?;
|
||||
|
||||
assert!(row.get::<bool, _>(0));
|
||||
assert_eq!(value, row.get(1));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn mysql_chrono_timestamp() -> anyhow::Result<()> {
|
||||
let mut conn = connect().await?;
|
||||
|
||||
let value = DateTime::<Utc>::from_utc(
|
||||
NaiveDate::from_ymd(2019, 1, 2).and_hms_micro(5, 10, 20, 115100),
|
||||
Utc,
|
||||
);
|
||||
|
||||
let row = sqlx::query(
|
||||
"SELECT TIMESTAMP '2019-01-02 05:10:20.115100' = ?, TIMESTAMP '2019-01-02 05:10:20.115100'",
|
||||
)
|
||||
.bind(&value)
|
||||
.fetch_one(&mut conn)
|
||||
.await?;
|
||||
|
||||
assert!(row.get::<bool, _>(0));
|
||||
assert_eq!(value, row.get(1));
|
||||
|
||||
Ok(())
|
||||
}
|
@ -1,94 +1,86 @@
|
||||
use sqlx::{mysql::MySqlConnection, Connection, Row};
|
||||
use sqlx::MySql;
|
||||
use sqlx_test::test_type;
|
||||
|
||||
async fn connect() -> anyhow::Result<MySqlConnection> {
|
||||
Ok(MySqlConnection::open(dotenv::var("DATABASE_URL")?).await?)
|
||||
}
|
||||
|
||||
macro_rules! test {
|
||||
($name:ident: $ty:ty: $($text:literal == $value:expr),+) => {
|
||||
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn $name () -> anyhow::Result<()> {
|
||||
let mut conn = connect().await?;
|
||||
|
||||
$(
|
||||
let row = sqlx::query(&format!("SELECT {} = ?, ? as _1", $text))
|
||||
.bind($value)
|
||||
.bind($value)
|
||||
.fetch_one(&mut conn)
|
||||
.await?;
|
||||
|
||||
let value = row.get::<$ty, _>("_1");
|
||||
|
||||
assert_eq!(row.get::<i32, _>(0), 1, "value returned from server: {:?}", value);
|
||||
|
||||
assert_eq!($value, value);
|
||||
)+
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test!(mysql_bool: bool: "false" == false, "true" == true);
|
||||
|
||||
test!(mysql_tiny_unsigned: u8: "253" == 253_u8);
|
||||
test!(mysql_tiny: i8: "5" == 5_i8);
|
||||
|
||||
test!(mysql_medium_unsigned: u16: "21415" == 21415_u16);
|
||||
test!(mysql_short: i16: "21415" == 21415_i16);
|
||||
|
||||
test!(mysql_long_unsigned: u32: "2141512" == 2141512_u32);
|
||||
test!(mysql_long: i32: "2141512" == 2141512_i32);
|
||||
|
||||
test!(mysql_longlong_unsigned: u64: "2141512" == 2141512_u64);
|
||||
test!(mysql_longlong: i64: "2141512" == 2141512_i64);
|
||||
|
||||
// `DOUBLE` can be compared with decimal literals just fine but the same can't be said for `FLOAT`
|
||||
test!(mysql_double: f64: "3.14159265" == 3.14159265f64);
|
||||
|
||||
test!(mysql_string: String: "'helloworld'" == "helloworld");
|
||||
|
||||
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn mysql_bytes() -> anyhow::Result<()> {
|
||||
let mut conn = connect().await?;
|
||||
|
||||
let value = &b"Hello, World"[..];
|
||||
|
||||
let rec = sqlx::query!(
|
||||
"SELECT (X'48656c6c6f2c20576f726c64' = ?) as _1, CAST(? as BINARY) as _2",
|
||||
value,
|
||||
value
|
||||
)
|
||||
.fetch_one(&mut conn)
|
||||
.await?;
|
||||
|
||||
assert!(rec._1 != 0);
|
||||
|
||||
let output: Vec<u8> = rec._2;
|
||||
|
||||
assert_eq!(&value[..], &*output);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn mysql_float() -> anyhow::Result<()> {
|
||||
let mut conn = connect().await?;
|
||||
|
||||
let value = 10.2f32;
|
||||
let row = sqlx::query("SELECT ? as _1")
|
||||
.bind(value)
|
||||
.fetch_one(&mut conn)
|
||||
.await?;
|
||||
|
||||
// comparison between FLOAT and literal doesn't work as expected
|
||||
// we get implicit widening to DOUBLE which gives a slightly different value
|
||||
// however, round-trip does work as expected
|
||||
let ret = row.get::<f32, _>("_1");
|
||||
assert_eq!(value, ret);
|
||||
|
||||
Ok(())
|
||||
test_type!(null(
|
||||
MySql,
|
||||
Option<i16>,
|
||||
"NULL" == None::<i16>
|
||||
));
|
||||
|
||||
test_type!(bool(MySql, bool, "false" == false, "true" == true));
|
||||
|
||||
test_type!(u8(MySql, u8, "253" == 253_u8));
|
||||
test_type!(i8(MySql, i8, "5" == 5_i8, "0" == 0_i8));
|
||||
|
||||
test_type!(u16(MySql, u16, "21415" == 21415_u16));
|
||||
test_type!(i16(MySql, i16, "21415" == 21415_i16));
|
||||
|
||||
test_type!(u32(MySql, u32, "2141512" == 2141512_u32));
|
||||
test_type!(i32(MySql, i32, "2141512" == 2141512_i32));
|
||||
|
||||
test_type!(u64(MySql, u64, "2141512" == 2141512_u64));
|
||||
test_type!(i64(MySql, i64, "2141512" == 2141512_i64));
|
||||
|
||||
test_type!(double(MySql, f64, "3.14159265" == 3.14159265f64));
|
||||
|
||||
// NOTE: This behavior can be very surprising. MySQL implicitly widens FLOAT bind parameters
|
||||
// to DOUBLE. This results in the weirdness you see below. MySQL generally recommends to stay
|
||||
// away from FLOATs.
|
||||
test_type!(float(
|
||||
MySql,
|
||||
f32,
|
||||
"3.1410000324249268" == 3.141f32 as f64 as f32
|
||||
));
|
||||
|
||||
test_type!(string(
|
||||
MySql,
|
||||
String,
|
||||
"'helloworld'" == "helloworld",
|
||||
"''" == ""
|
||||
));
|
||||
|
||||
test_type!(bytes(
|
||||
MySql,
|
||||
Vec<u8>,
|
||||
"X'DEADBEEF'"
|
||||
== vec![0xDE_u8, 0xAD, 0xBE, 0xEF],
|
||||
"X''"
|
||||
== Vec::<u8>::new(),
|
||||
"X'0000000052'"
|
||||
== vec![0_u8, 0, 0, 0, 0x52]
|
||||
));
|
||||
|
||||
#[cfg(feature = "chrono")]
|
||||
mod chrono {
|
||||
use super::*;
|
||||
use sqlx::types::chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
|
||||
|
||||
test_type!(chrono_date(
|
||||
MySql,
|
||||
NaiveDate,
|
||||
"DATE '2001-01-05'" == NaiveDate::from_ymd(2001, 1, 5),
|
||||
"DATE '2050-11-23'" == NaiveDate::from_ymd(2050, 11, 23)
|
||||
));
|
||||
|
||||
test_type!(chrono_time(
|
||||
MySql,
|
||||
NaiveTime,
|
||||
"TIME '05:10:20.115100'" == NaiveTime::from_hms_micro(5, 10, 20, 115100)
|
||||
));
|
||||
|
||||
test_type!(chrono_date_time(
|
||||
MySql,
|
||||
NaiveDateTime,
|
||||
"'2019-01-02 05:10:20'" == NaiveDate::from_ymd(2019, 1, 2).and_hms(5, 10, 20)
|
||||
));
|
||||
|
||||
test_type!(chrono_date_time_tz(
|
||||
MySql,
|
||||
DateTime::<Utc>,
|
||||
"TIMESTAMP '2019-01-02 05:10:20.115100'"
|
||||
== DateTime::<Utc>::from_utc(
|
||||
NaiveDate::from_ymd(2019, 1, 2).and_hms_micro(5, 10, 20, 115100),
|
||||
Utc,
|
||||
)
|
||||
));
|
||||
}
|
||||
|
@ -1,17 +1,24 @@
|
||||
use futures::TryStreamExt;
|
||||
use sqlx::{Connection as _, Executor as _, MySqlConnection, MySqlPool, Row as _};
|
||||
use sqlx::{mysql::MySqlQueryAs, Connection, Executor, MySql, MySqlPool};
|
||||
use sqlx_test::new;
|
||||
use std::time::Duration;
|
||||
|
||||
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn it_connects() -> anyhow::Result<()> {
|
||||
let mut conn = connect().await?;
|
||||
Ok(new::<MySql>().await?.ping().await?)
|
||||
}
|
||||
|
||||
let row = sqlx::query("select 1 + 1").fetch_one(&mut conn).await?;
|
||||
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn it_drops_results_in_affected_rows() -> anyhow::Result<()> {
|
||||
let mut conn = new::<MySql>().await?;
|
||||
|
||||
assert_eq!(2, row.get(0));
|
||||
// ~1800 rows should be iterated and dropped
|
||||
let affected = conn.execute("select * from mysql.time_zone").await?;
|
||||
|
||||
conn.close().await?;
|
||||
// In MySQL, rows being returned isn't enough to flag it as an _affected_ row
|
||||
assert_eq!(0, affected);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -19,10 +26,10 @@ async fn it_connects() -> anyhow::Result<()> {
|
||||
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn it_executes() -> anyhow::Result<()> {
|
||||
let mut conn = connect().await?;
|
||||
let mut conn = new::<MySql>().await?;
|
||||
|
||||
let _ = conn
|
||||
.send(
|
||||
.execute(
|
||||
r#"
|
||||
CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY)
|
||||
"#,
|
||||
@ -38,12 +45,9 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY)
|
||||
assert_eq!(cnt, 1);
|
||||
}
|
||||
|
||||
let sum: i32 = sqlx::query("SELECT id FROM users")
|
||||
let sum: i32 = sqlx::query_as("SELECT id FROM users")
|
||||
.fetch(&mut conn)
|
||||
.try_fold(
|
||||
0_i32,
|
||||
|acc, x| async move { Ok(acc + x.get::<i32, _>("id")) },
|
||||
)
|
||||
.try_fold(0_i32, |acc, (x,): (i32,)| async move { Ok(acc + x) })
|
||||
.await?;
|
||||
|
||||
assert_eq!(sum, 55);
|
||||
@ -54,11 +58,9 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY)
|
||||
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn it_selects_null() -> anyhow::Result<()> {
|
||||
let mut conn = connect().await?;
|
||||
let mut conn = new::<MySql>().await?;
|
||||
|
||||
let row = sqlx::query("SELECT NULL").fetch_one(&mut conn).await?;
|
||||
|
||||
let val: Option<i32> = row.get(0);
|
||||
let (val,): (Option<i32>,) = sqlx::query_as("SELECT NULL").fetch_one(&mut conn).await?;
|
||||
|
||||
assert!(val.is_none());
|
||||
|
||||
@ -68,12 +70,10 @@ async fn it_selects_null() -> anyhow::Result<()> {
|
||||
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn test_describe() -> anyhow::Result<()> {
|
||||
use sqlx::describe::Nullability::*;
|
||||
|
||||
let mut conn = connect().await?;
|
||||
let mut conn = new::<MySql>().await?;
|
||||
|
||||
let _ = conn
|
||||
.send(
|
||||
.execute(
|
||||
r#"
|
||||
CREATE TEMPORARY TABLE describe_test (
|
||||
id int primary key auto_increment,
|
||||
@ -88,13 +88,13 @@ async fn test_describe() -> anyhow::Result<()> {
|
||||
.describe("select nt.*, false from describe_test nt")
|
||||
.await?;
|
||||
|
||||
assert_eq!(describe.result_columns[0].nullability, NonNull);
|
||||
assert_eq!(describe.result_columns[0].non_null, Some(true));
|
||||
assert_eq!(describe.result_columns[0].type_info.type_name(), "INT");
|
||||
assert_eq!(describe.result_columns[1].nullability, NonNull);
|
||||
assert_eq!(describe.result_columns[1].non_null, Some(true));
|
||||
assert_eq!(describe.result_columns[1].type_info.type_name(), "TEXT");
|
||||
assert_eq!(describe.result_columns[2].nullability, Nullable);
|
||||
assert_eq!(describe.result_columns[2].non_null, Some(false));
|
||||
assert_eq!(describe.result_columns[2].type_info.type_name(), "TEXT");
|
||||
assert_eq!(describe.result_columns[3].nullability, NonNull);
|
||||
assert_eq!(describe.result_columns[3].non_null, Some(true));
|
||||
|
||||
let bool_ty_name = describe.result_columns[3].type_info.type_name();
|
||||
|
||||
@ -112,7 +112,7 @@ async fn test_describe() -> anyhow::Result<()> {
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn pool_immediately_fails_with_db_error() -> anyhow::Result<()> {
|
||||
// Malform the database url by changing the password
|
||||
let url = url()?.replace("password", "not-the-password");
|
||||
let url = dotenv::var("DATABASE_URL")?.replace("password", "not-the-password");
|
||||
|
||||
let pool = MySqlPool::new(&url).await?;
|
||||
|
||||
@ -152,7 +152,7 @@ async fn pool_smoke_test() -> anyhow::Result<()> {
|
||||
let pool = pool.clone();
|
||||
spawn(async move {
|
||||
loop {
|
||||
if let Err(e) = sqlx::query("select 1 + 1").fetch_one(&mut &pool).await {
|
||||
if let Err(e) = sqlx::query("select 1 + 1").execute(&mut &pool).await {
|
||||
eprintln!("pool task {} dying due to {}", i, e);
|
||||
break;
|
||||
}
|
||||
@ -185,11 +185,3 @@ async fn pool_smoke_test() -> anyhow::Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn url() -> anyhow::Result<String> {
|
||||
Ok(dotenv::var("DATABASE_URL")?)
|
||||
}
|
||||
|
||||
async fn connect() -> anyhow::Result<MySqlConnection> {
|
||||
Ok(MySqlConnection::open(url()?).await?)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user