mirror of
https://github.com/launchbadge/sqlx.git
synced 2026-03-27 13:46:32 +00:00
refactor(mysql): adapt to the 0.4.x core refactor
This commit is contained in:
@@ -1,43 +1,51 @@
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use crate::arguments::Arguments;
|
||||
use crate::encode::{Encode, IsNull};
|
||||
use crate::mysql::type_info::MySqlTypeInfo;
|
||||
use crate::mysql::MySql;
|
||||
use crate::types::Type;
|
||||
use crate::mysql::{MySql, MySqlTypeInfo};
|
||||
|
||||
#[derive(Default)]
|
||||
/// Implementation of [`Arguments`] for MySQL.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct MySqlArguments {
|
||||
pub(crate) param_types: Vec<MySqlTypeInfo>,
|
||||
pub(crate) params: Vec<u8>,
|
||||
pub(crate) values: Vec<u8>,
|
||||
pub(crate) types: Vec<MySqlTypeInfo>,
|
||||
pub(crate) null_bitmap: Vec<u8>,
|
||||
}
|
||||
|
||||
impl Arguments for MySqlArguments {
|
||||
impl<'q> Arguments<'q> for MySqlArguments {
|
||||
type Database = MySql;
|
||||
|
||||
fn reserve(&mut self, len: usize, size: usize) {
|
||||
self.param_types.reserve(len);
|
||||
self.params.reserve(size);
|
||||
|
||||
// ensure we have enough size in the bitmap to hold at least `len` extra bits
|
||||
// the second `& 7` gives us 0 spare bits when param_types.len() is a multiple of 8
|
||||
let spare_bits = (8 - (self.param_types.len()) & 7) & 7;
|
||||
// ensure that if there are no spare bits left, `len = 1` reserves another byte
|
||||
self.null_bitmap.reserve((len + 7 - spare_bits) / 8);
|
||||
self.types.reserve(len);
|
||||
self.values.reserve(size);
|
||||
}
|
||||
|
||||
fn add<T>(&mut self, value: T)
|
||||
where
|
||||
T: Type<Self::Database>,
|
||||
T: Encode<Self::Database>,
|
||||
T: Encode<'q, Self::Database>,
|
||||
{
|
||||
let type_id = <T as Type<MySql>>::type_info();
|
||||
let index = self.param_types.len();
|
||||
let ty = value.produces();
|
||||
let index = self.types.len();
|
||||
|
||||
self.param_types.push(type_id);
|
||||
self.types.push(ty);
|
||||
self.null_bitmap.resize((index / 8) + 1, 0);
|
||||
|
||||
if let IsNull::Yes = value.encode_nullable(&mut self.params) {
|
||||
if let IsNull::Yes = value.encode(self) {
|
||||
self.null_bitmap[index / 8] |= (1 << index % 8) as u8;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for MySqlArguments {
|
||||
type Target = Vec<u8>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.values
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for MySqlArguments {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.values
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,345 +0,0 @@
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryInto;
|
||||
use std::ops::Range;
|
||||
|
||||
use futures_core::future::BoxFuture;
|
||||
use sha1::Sha1;
|
||||
|
||||
use crate::connection::{Connect, Connection};
|
||||
use crate::executor::Executor;
|
||||
use crate::mysql::protocol::{
|
||||
AuthPlugin, AuthSwitch, Capabilities, ComPing, Handshake, HandshakeResponse,
|
||||
};
|
||||
use crate::mysql::stream::MySqlStream;
|
||||
use crate::mysql::util::xor_eq;
|
||||
|
||||
use crate::mysql::{rsa, tls};
|
||||
use crate::url::Url;
|
||||
|
||||
// Size before a packet is split
|
||||
pub(super) const MAX_PACKET_SIZE: u32 = 1024;
|
||||
|
||||
pub(super) const COLLATE_UTF8MB4_UNICODE_CI: u8 = 224;
|
||||
|
||||
/// An asynchronous connection to a [`MySql`] database.
|
||||
///
|
||||
/// The connection string expected by `MySqlConnection` should be a MySQL connection
|
||||
/// string, as documented at
|
||||
/// <https://dev.mysql.com/doc/refman/8.0/en/connecting-using-uri-or-key-value-pairs.html#connecting-using-uri>
|
||||
///
|
||||
/// ### TLS Support (requires `tls` feature)
|
||||
/// This connection type supports some of the same flags as the `mysql` CLI application for SSL
|
||||
/// connections, but they must be specified via the query segment of the connection string
|
||||
/// rather than as program arguments.
|
||||
///
|
||||
/// The same options for `--ssl-mode` are supported as the `ssl-mode` query parameter:
|
||||
/// <https://dev.mysql.com/doc/refman/8.0/en/connection-options.html#option_general_ssl-mode>
|
||||
///
|
||||
/// ```text
|
||||
/// mysql://<user>[:<password>]@<host>[:<port>]/<database>[?ssl-mode=<ssl-mode>[&ssl-ca=<path>]]
|
||||
/// ```
|
||||
/// where
|
||||
/// ```text
|
||||
/// ssl-mode = DISABLED | PREFERRED | REQUIRED | VERIFY_CA | VERIFY_IDENTITY
|
||||
/// path = percent (URL) encoded path on the local machine
|
||||
/// ```
|
||||
///
|
||||
/// If the `tls` feature is not enabled, `ssl-mode=DISABLED` and `ssl-mode=PREFERRED` are no-ops and
|
||||
/// `ssl-mode=REQUIRED`, `ssl-mode=VERIFY_CA` and `ssl-mode=VERIFY_IDENTITY` are forbidden
|
||||
/// (attempting to connect with these will return an error).
|
||||
///
|
||||
/// If the `tls` feature is enabled, an upgrade to TLS is attempted on every connection by default
|
||||
/// (equivalent to `ssl-mode=PREFERRED`). If the server does not support TLS (because `--ssl=0` was
|
||||
/// passed to the server or an invalid certificate or key was used:
|
||||
/// <https://dev.mysql.com/doc/refman/8.0/en/using-encrypted-connections.html>)
|
||||
/// then it falls back to an unsecured connection and logs a warning.
|
||||
///
|
||||
/// Add `ssl-mode=REQUIRED` to your connection string to emit an error if the TLS upgrade fails.
|
||||
///
|
||||
/// However, like with `mysql` the server certificate is **not** checked for validity by default.
|
||||
///
|
||||
/// Specifying `ssl-mode=VERIFY_CA` will cause the TLS upgrade to verify the server's SSL
|
||||
/// certificate against a local CA root certificate; this is not the system root certificate
|
||||
/// but is instead expected to be specified as a local path with the `ssl-ca` query parameter
|
||||
/// (percent-encoded so the URL remains valid).
|
||||
///
|
||||
/// If you're running MySQL locally it might look something like this (for `VERIFY_CA`):
|
||||
/// ```text
|
||||
/// mysql://root:password@localhost/my_database?ssl-mode=VERIFY_CA&ssl-ca=%2Fvar%2Flib%2Fmysql%2Fca.pem
|
||||
/// ```
|
||||
///
|
||||
/// `%2F` is the percent-encoding for forward slash (`/`). In the example we give `/var/lib/mysql/ca.pem`
|
||||
/// as the CA certificate path, which is generated by the MySQL server automatically if
|
||||
/// no certificate is manually specified. Note that the path may vary based on the default `my.cnf`
|
||||
/// packaged with MySQL for your Linux distribution. Also note that unlike MySQL, MariaDB does *not*
|
||||
/// generate certificates automatically and they must always be passed in to enable TLS.
|
||||
///
|
||||
/// If `ssl-ca` is not specified or the file cannot be read, then an error is returned.
|
||||
/// `ssl-ca` implies `ssl-mode=VERIFY_CA` so you only actually need to specify the former
|
||||
/// but you may prefer having both to be more explicit.
|
||||
///
|
||||
/// If `ssl-mode=VERIFY_IDENTITY` is specified, in addition to checking the certificate as with
|
||||
/// `ssl-mode=VERIFY_CA`, the hostname in the connection string will be verified
|
||||
/// against the hostname in the server certificate, so they must be the same for the TLS
|
||||
/// upgrade to succeed. `ssl-ca` must still be specified.
|
||||
pub struct MySqlConnection {
|
||||
pub(super) stream: MySqlStream,
|
||||
pub(super) is_ready: bool,
|
||||
pub(super) cache_statement: HashMap<Box<str>, u32>,
|
||||
|
||||
// Work buffer for the value ranges of the current row
|
||||
// This is used as the backing memory for each Row's value indexes
|
||||
pub(super) current_row_values: Vec<Option<Range<usize>>>,
|
||||
}
|
||||
|
||||
fn to_asciz(s: &str) -> Vec<u8> {
|
||||
let mut z = String::with_capacity(s.len() + 1);
|
||||
z.push_str(s);
|
||||
z.push('\0');
|
||||
|
||||
z.into_bytes()
|
||||
}
|
||||
|
||||
async fn rsa_encrypt_with_nonce(
|
||||
stream: &mut MySqlStream,
|
||||
public_key_request_id: u8,
|
||||
password: &str,
|
||||
nonce: &[u8],
|
||||
) -> crate::Result<Vec<u8>> {
|
||||
// https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/
|
||||
|
||||
if stream.is_tls() {
|
||||
// If in a TLS stream, send the password directly in clear text
|
||||
return Ok(to_asciz(password));
|
||||
}
|
||||
|
||||
// client sends a public key request
|
||||
stream.send(&[public_key_request_id][..], false).await?;
|
||||
|
||||
// server sends a public key response
|
||||
let packet = stream.receive().await?;
|
||||
let rsa_pub_key = &packet[1..];
|
||||
|
||||
// xor the password with the given nonce
|
||||
let mut pass = to_asciz(password);
|
||||
xor_eq(&mut pass, nonce);
|
||||
|
||||
// client sends an RSA encrypted password
|
||||
rsa::encrypt::<Sha1>(rsa_pub_key, &pass)
|
||||
}
|
||||
|
||||
async fn make_auth_response(
|
||||
stream: &mut MySqlStream,
|
||||
plugin: &AuthPlugin,
|
||||
password: &str,
|
||||
nonce: &[u8],
|
||||
) -> crate::Result<Vec<u8>> {
|
||||
if password.is_empty() {
|
||||
// Empty password should not be sent
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
match plugin {
|
||||
AuthPlugin::CachingSha2Password | AuthPlugin::MySqlNativePassword => {
|
||||
Ok(plugin.scramble(password, nonce))
|
||||
}
|
||||
|
||||
AuthPlugin::Sha256Password => rsa_encrypt_with_nonce(stream, 0x01, password, nonce).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn establish(stream: &mut MySqlStream, url: &Url) -> crate::Result<()> {
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase.html
|
||||
// https://mariadb.com/kb/en/connection/
|
||||
|
||||
// Read a [Handshake] packet. When connecting to the database server, this is immediately
|
||||
// received from the database server.
|
||||
|
||||
let handshake = Handshake::read(stream.receive().await?)?;
|
||||
let mut auth_plugin = handshake.auth_plugin;
|
||||
let mut auth_plugin_data = handshake.auth_plugin_data;
|
||||
|
||||
stream.capabilities &= handshake.server_capabilities;
|
||||
stream.capabilities |= Capabilities::PROTOCOL_41;
|
||||
|
||||
log::trace!("using capability flags: {:?}", stream.capabilities);
|
||||
|
||||
// Depending on the ssl-mode and capabilities we should upgrade
|
||||
// our connection to TLS
|
||||
|
||||
tls::upgrade_if_needed(stream, url).await?;
|
||||
|
||||
// Send a [HandshakeResponse] packet. This is returned in response to the [Handshake] packet
|
||||
// that is immediately received.
|
||||
|
||||
let password = &*url.password().unwrap_or_default();
|
||||
let auth_response =
|
||||
make_auth_response(stream, &auth_plugin, password, &auth_plugin_data).await?;
|
||||
|
||||
stream
|
||||
.send(
|
||||
HandshakeResponse {
|
||||
client_collation: COLLATE_UTF8MB4_UNICODE_CI,
|
||||
max_packet_size: MAX_PACKET_SIZE,
|
||||
username: &url.username().unwrap_or(Cow::Borrowed("root")),
|
||||
database: url.database(),
|
||||
auth_plugin: &auth_plugin,
|
||||
auth_response: &auth_response,
|
||||
},
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
|
||||
loop {
|
||||
// After sending the handshake response with our assumed auth method the server
|
||||
// will send OK, fail, or tell us to change auth methods
|
||||
let packet = stream.receive().await?;
|
||||
|
||||
match packet[0] {
|
||||
// OK
|
||||
0x00 => {
|
||||
break;
|
||||
}
|
||||
|
||||
// ERROR
|
||||
0xFF => {
|
||||
return stream.handle_err();
|
||||
}
|
||||
|
||||
// AUTH_SWITCH
|
||||
0xFE => {
|
||||
let auth = AuthSwitch::read(packet)?;
|
||||
auth_plugin = auth.auth_plugin;
|
||||
auth_plugin_data = auth.auth_plugin_data;
|
||||
|
||||
let auth_response =
|
||||
make_auth_response(stream, &auth_plugin, password, &auth_plugin_data).await?;
|
||||
|
||||
stream.send(&*auth_response, false).await?;
|
||||
}
|
||||
|
||||
0x01 if auth_plugin == AuthPlugin::CachingSha2Password => {
|
||||
match packet[1] {
|
||||
// AUTH_OK
|
||||
0x03 => {}
|
||||
|
||||
// AUTH_CONTINUE
|
||||
0x04 => {
|
||||
// The specific password is _not_ cached on the server
|
||||
// We need to send a normal RSA-encrypted password for this
|
||||
let enc = rsa_encrypt_with_nonce(stream, 0x02, password, &auth_plugin_data)
|
||||
.await?;
|
||||
|
||||
stream.send(&*enc, false).await?;
|
||||
}
|
||||
|
||||
unk => {
|
||||
return Err(protocol_err!("unexpected result from 'fast' authentication 0x{:x} when expecting OK (0x03) or CONTINUE (0x04)", unk).into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_ => {
|
||||
return stream.handle_unexpected();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn close(mut stream: MySqlStream) -> crate::Result<()> {
|
||||
// TODO: Actually tell MySQL that we're closing
|
||||
|
||||
stream.flush().await?;
|
||||
stream.shutdown()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ping(stream: &mut MySqlStream) -> crate::Result<()> {
|
||||
stream.wait_until_ready().await?;
|
||||
stream.is_ready = false;
|
||||
|
||||
stream.send(ComPing, true).await?;
|
||||
|
||||
match stream.receive().await?[0] {
|
||||
0x00 | 0xFE => stream.handle_ok().map(drop),
|
||||
|
||||
0xFF => stream.handle_err(),
|
||||
|
||||
_ => stream.handle_unexpected(),
|
||||
}
|
||||
}
|
||||
|
||||
impl MySqlConnection {
|
||||
pub(super) async fn new(url: std::result::Result<Url, url::ParseError>) -> crate::Result<Self> {
|
||||
let url = url?;
|
||||
let mut stream = MySqlStream::new(&url).await?;
|
||||
|
||||
establish(&mut stream, &url).await?;
|
||||
|
||||
let mut self_ = Self {
|
||||
stream,
|
||||
current_row_values: Vec::with_capacity(10),
|
||||
is_ready: true,
|
||||
cache_statement: HashMap::new(),
|
||||
};
|
||||
|
||||
// After the connection is established, we initialize by configuring a few
|
||||
// connection parameters
|
||||
|
||||
// https://mariadb.com/kb/en/sql-mode/
|
||||
|
||||
// PIPES_AS_CONCAT - Allows using the pipe character (ASCII 124) as string concatenation operator.
|
||||
// This means that "A" || "B" can be used in place of CONCAT("A", "B").
|
||||
|
||||
// NO_ENGINE_SUBSTITUTION - If not set, if the available storage engine specified by a CREATE TABLE is
|
||||
// not available, a warning is given and the default storage
|
||||
// engine is used instead.
|
||||
|
||||
// NO_ZERO_DATE - Don't allow '0000-00-00'. This is invalid in Rust.
|
||||
|
||||
// NO_ZERO_IN_DATE - Don't allow 'YYYY-00-00'. This is invalid in Rust.
|
||||
|
||||
// --
|
||||
|
||||
// Setting the time zone allows us to assume that the output
|
||||
// from a TIMESTAMP field is UTC
|
||||
|
||||
// --
|
||||
|
||||
// https://mathiasbynens.be/notes/mysql-utf8mb4
|
||||
|
||||
self_.execute(r#"
|
||||
SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE'));
|
||||
SET time_zone = '+00:00';
|
||||
SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||
"#).await?;
|
||||
|
||||
Ok(self_)
|
||||
}
|
||||
}
|
||||
|
||||
impl Connect for MySqlConnection {
|
||||
fn connect<T>(url: T) -> BoxFuture<'static, crate::Result<MySqlConnection>>
|
||||
where
|
||||
T: TryInto<Url, Error = url::ParseError>,
|
||||
Self: Sized,
|
||||
{
|
||||
Box::pin(MySqlConnection::new(url.try_into()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Connection for MySqlConnection {
|
||||
#[inline]
|
||||
fn close(self) -> BoxFuture<'static, crate::Result<()>> {
|
||||
Box::pin(close(self.stream))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn ping(&mut self) -> BoxFuture<crate::Result<()>> {
|
||||
Box::pin(ping(&mut self.stream))
|
||||
}
|
||||
}
|
||||
176
sqlx-core/src/mysql/connection/auth.rs
Normal file
176
sqlx-core/src/mysql/connection/auth.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
use bytes::buf::ext::Chain;
|
||||
use bytes::Bytes;
|
||||
use digest::{Digest, FixedOutput};
|
||||
use generic_array::GenericArray;
|
||||
use sha1::Sha1;
|
||||
use sha2::Sha256;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::mysql::connection::stream::MySqlStream;
|
||||
use crate::mysql::protocol::auth::AuthPlugin;
|
||||
use crate::mysql::protocol::rsa;
|
||||
use crate::mysql::protocol::Packet;
|
||||
|
||||
impl AuthPlugin {
|
||||
pub(super) async fn scramble(
|
||||
self,
|
||||
stream: &mut MySqlStream,
|
||||
password: &str,
|
||||
nonce: &Chain<Bytes, Bytes>,
|
||||
) -> Result<Vec<u8>, Error> {
|
||||
match self {
|
||||
// https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/
|
||||
AuthPlugin::CachingSha2Password => Ok(scramble_sha256(password, nonce).to_vec()),
|
||||
|
||||
AuthPlugin::MySqlNativePassword => Ok(scramble_sha1(password, nonce).to_vec()),
|
||||
|
||||
// https://mariadb.com/kb/en/sha256_password-plugin/
|
||||
AuthPlugin::Sha256Password => encrypt_rsa(stream, 0x01, password, nonce).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle(
|
||||
self,
|
||||
stream: &mut MySqlStream,
|
||||
packet: Packet<Bytes>,
|
||||
password: &str,
|
||||
nonce: &Chain<Bytes, Bytes>,
|
||||
) -> Result<bool, Error> {
|
||||
match self {
|
||||
AuthPlugin::CachingSha2Password if packet[0] == 0x01 => {
|
||||
match packet[1] {
|
||||
// AUTH_OK
|
||||
0x03 => Ok(true),
|
||||
|
||||
// AUTH_CONTINUE
|
||||
0x04 => {
|
||||
let payload = encrypt_rsa(stream, 0x02, password, nonce).await?;
|
||||
|
||||
stream.write_packet(&*payload);
|
||||
stream.flush().await?;
|
||||
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
v => {
|
||||
Err(err_protocol!("unexpected result from fast authentication 0x{:x} when expecting 0x03 (AUTH_OK) or 0x04 (AUTH_CONTINUE)", v))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_ => Err(err_protocol!(
|
||||
"unexpected packet 0x{:02x} for auth plugin '{}' during authentication",
|
||||
packet[0],
|
||||
self.name()
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn scramble_sha1(
|
||||
password: &str,
|
||||
nonce: &Chain<Bytes, Bytes>,
|
||||
) -> GenericArray<u8, <Sha1 as FixedOutput>::OutputSize> {
|
||||
// SHA1( password ) ^ SHA1( seed + SHA1( SHA1( password ) ) )
|
||||
// https://mariadb.com/kb/en/connection/#mysql_native_password-plugin
|
||||
|
||||
let mut ctx = Sha1::new();
|
||||
|
||||
ctx.input(password);
|
||||
|
||||
let mut pw_hash = ctx.result_reset();
|
||||
|
||||
ctx.input(&pw_hash);
|
||||
|
||||
let pw_hash_hash = ctx.result_reset();
|
||||
|
||||
ctx.input(nonce.first_ref());
|
||||
ctx.input(nonce.last_ref());
|
||||
ctx.input(pw_hash_hash);
|
||||
|
||||
let pw_seed_hash_hash = ctx.result();
|
||||
|
||||
xor_eq(&mut pw_hash, &pw_seed_hash_hash);
|
||||
|
||||
pw_hash
|
||||
}
|
||||
|
||||
fn scramble_sha256(
|
||||
password: &str,
|
||||
nonce: &Chain<Bytes, Bytes>,
|
||||
) -> GenericArray<u8, <Sha256 as FixedOutput>::OutputSize> {
|
||||
// XOR(SHA256(password), SHA256(seed, SHA256(SHA256(password))))
|
||||
// https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/#sha-2-encrypted-password
|
||||
let mut ctx = Sha256::new();
|
||||
|
||||
ctx.input(password);
|
||||
|
||||
let mut pw_hash = ctx.result_reset();
|
||||
|
||||
ctx.input(&pw_hash);
|
||||
|
||||
let pw_hash_hash = ctx.result_reset();
|
||||
|
||||
ctx.input(nonce.first_ref());
|
||||
ctx.input(nonce.last_ref());
|
||||
ctx.input(pw_hash_hash);
|
||||
|
||||
let pw_seed_hash_hash = ctx.result();
|
||||
|
||||
xor_eq(&mut pw_hash, &pw_seed_hash_hash);
|
||||
|
||||
pw_hash
|
||||
}
|
||||
|
||||
async fn encrypt_rsa<'s>(
|
||||
stream: &'s mut MySqlStream,
|
||||
public_key_request_id: u8,
|
||||
password: &'s str,
|
||||
nonce: &'s Chain<Bytes, Bytes>,
|
||||
) -> Result<Vec<u8>, Error> {
|
||||
// https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/
|
||||
|
||||
if stream.is_tls() {
|
||||
// If in a TLS stream, send the password directly in clear text
|
||||
return Ok(to_asciz(password));
|
||||
}
|
||||
|
||||
// client sends a public key request
|
||||
stream.write_packet(&[public_key_request_id][..]);
|
||||
stream.flush().await?;
|
||||
|
||||
// server sends a public key response
|
||||
let packet = stream.recv_packet().await?;
|
||||
let rsa_pub_key = &packet[1..];
|
||||
|
||||
// xor the password with the given nonce
|
||||
let mut pass = to_asciz(password);
|
||||
|
||||
let (a, b) = (nonce.first_ref(), nonce.last_ref());
|
||||
let mut nonce = Vec::with_capacity(a.len() + b.len());
|
||||
nonce.extend_from_slice(&*a);
|
||||
nonce.extend_from_slice(&*b);
|
||||
|
||||
xor_eq(&mut pass, &*nonce);
|
||||
|
||||
// client sends an RSA encrypted password
|
||||
rsa::encrypt::<Sha1>(rsa_pub_key, &pass)
|
||||
}
|
||||
|
||||
// XOR(x, y)
|
||||
// If len(y) < len(x), wrap around inside y
|
||||
fn xor_eq(x: &mut [u8], y: &[u8]) {
|
||||
let y_len = y.len();
|
||||
|
||||
for i in 0..x.len() {
|
||||
x[i] ^= y[i % y_len];
|
||||
}
|
||||
}
|
||||
|
||||
fn to_asciz(s: &str) -> Vec<u8> {
|
||||
let mut z = String::with_capacity(s.len() + 1);
|
||||
z.push_str(s);
|
||||
z.push('\0');
|
||||
|
||||
z.into_bytes()
|
||||
}
|
||||
101
sqlx-core/src/mysql/connection/establish.rs
Normal file
101
sqlx-core/src/mysql/connection/establish.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
use bytes::Bytes;
|
||||
use hashbrown::HashMap;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::mysql::connection::{tls, MySqlStream, COLLATE_UTF8MB4_UNICODE_CI, MAX_PACKET_SIZE};
|
||||
use crate::mysql::protocol::connect::{
|
||||
AuthSwitchRequest, AuthSwitchResponse, Handshake, HandshakeResponse,
|
||||
};
|
||||
use crate::mysql::protocol::Capabilities;
|
||||
use crate::mysql::{MySqlConnectOptions, MySqlConnection};
|
||||
use bytes::buf::BufExt;
|
||||
|
||||
impl MySqlConnection {
|
||||
pub(crate) async fn establish(options: &MySqlConnectOptions) -> Result<Self, Error> {
|
||||
let mut stream: MySqlStream = MySqlStream::connect(options).await?;
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase.html
|
||||
// https://mariadb.com/kb/en/connection/
|
||||
|
||||
let handshake: Handshake = stream.recv_packet().await?.decode()?;
|
||||
|
||||
let mut plugin = handshake.auth_plugin;
|
||||
let mut nonce = handshake.auth_plugin_data;
|
||||
|
||||
stream.capabilities &= handshake.server_capabilities;
|
||||
stream.capabilities |= Capabilities::PROTOCOL_41;
|
||||
|
||||
// Upgrade to TLS if we were asked to and the server supports it
|
||||
tls::maybe_upgrade(&mut stream, options).await?;
|
||||
|
||||
let auth_response = if let (Some(plugin), Some(password)) = (plugin, &options.password) {
|
||||
Some(plugin.scramble(&mut stream, password, &nonce).await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
stream.write_packet(HandshakeResponse {
|
||||
char_set: COLLATE_UTF8MB4_UNICODE_CI,
|
||||
max_packet_size: MAX_PACKET_SIZE,
|
||||
username: &options.username,
|
||||
database: options.database.as_deref(),
|
||||
auth_plugin: plugin,
|
||||
auth_response: auth_response.as_deref(),
|
||||
});
|
||||
|
||||
stream.flush().await?;
|
||||
|
||||
loop {
|
||||
let packet = stream.recv_packet().await?;
|
||||
match packet[0] {
|
||||
0x00 => {
|
||||
let _ok = packet.ok()?;
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
0xfe => {
|
||||
let switch: AuthSwitchRequest = packet.decode()?;
|
||||
|
||||
plugin = Some(switch.plugin);
|
||||
nonce = switch.data.chain(Bytes::new());
|
||||
|
||||
let response = switch
|
||||
.plugin
|
||||
.scramble(
|
||||
&mut stream,
|
||||
options.password.as_deref().unwrap_or_default(),
|
||||
&nonce,
|
||||
)
|
||||
.await?;
|
||||
|
||||
stream.write_packet(AuthSwitchResponse(response));
|
||||
stream.flush().await?;
|
||||
}
|
||||
|
||||
id => {
|
||||
if let (Some(plugin), Some(password)) = (plugin, &options.password) {
|
||||
if plugin.handle(&mut stream, packet, password, &nonce).await? {
|
||||
// plugin signaled authentication is ok
|
||||
break;
|
||||
}
|
||||
|
||||
// plugin signaled to continue authentication
|
||||
} else {
|
||||
return Err(err_protocol!(
|
||||
"unexpected packet 0x{:02x} during authentication",
|
||||
id
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
stream,
|
||||
cache_statement: HashMap::new(),
|
||||
scratch_row_columns: Default::default(),
|
||||
scratch_row_column_names: Default::default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
285
sqlx-core/src/mysql/connection/executor.rs
Normal file
285
sqlx-core/src/mysql/connection/executor.rs
Normal file
@@ -0,0 +1,285 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_stream::try_stream;
|
||||
use bytes::Bytes;
|
||||
use either::Either;
|
||||
use futures_core::future::BoxFuture;
|
||||
use futures_core::stream::BoxStream;
|
||||
use futures_core::Stream;
|
||||
use futures_util::{pin_mut, TryStreamExt};
|
||||
|
||||
use crate::describe::{Column, Describe};
|
||||
use crate::error::Error;
|
||||
use crate::executor::{Execute, Executor};
|
||||
use crate::ext::ustr::UStr;
|
||||
use crate::mysql::io::MySqlBufExt;
|
||||
use crate::mysql::protocol::response::Status;
|
||||
use crate::mysql::protocol::statement::{
|
||||
BinaryRow, Execute as StatementExecute, Prepare, PrepareOk,
|
||||
};
|
||||
use crate::mysql::protocol::text::{ColumnDefinition, ColumnFlags, Query, TextRow};
|
||||
use crate::mysql::protocol::Packet;
|
||||
use crate::mysql::row::MySqlColumn;
|
||||
use crate::mysql::{
|
||||
MySql, MySqlArguments, MySqlConnection, MySqlRow, MySqlTypeInfo, MySqlValueFormat,
|
||||
};
|
||||
|
||||
impl MySqlConnection {
|
||||
async fn prepare(&mut self, query: &str) -> Result<u32, Error> {
|
||||
if let Some(&statement) = self.cache_statement.get(query) {
|
||||
return Ok(statement);
|
||||
}
|
||||
|
||||
// https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html
|
||||
// https://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html#packet-COM_STMT_PREPARE_OK
|
||||
|
||||
self.stream.send_packet(Prepare { query }).await?;
|
||||
|
||||
let ok: PrepareOk = self.stream.recv().await?;
|
||||
|
||||
// the parameter definitions are very unreliable so we skip over them
|
||||
// as we have little use
|
||||
|
||||
if ok.params > 0 {
|
||||
for _ in 0..ok.params {
|
||||
let _def: ColumnDefinition = self.stream.recv().await?;
|
||||
}
|
||||
|
||||
self.stream.maybe_recv_eof().await?;
|
||||
}
|
||||
|
||||
// the column definitions are berefit the type information from the
|
||||
// to-be-bound parameters; we will receive the output column definitions
|
||||
// once more on execute so we wait for that
|
||||
|
||||
if ok.columns > 0 {
|
||||
for _ in 0..(ok.columns as usize) {
|
||||
let _def: ColumnDefinition = self.stream.recv().await?;
|
||||
}
|
||||
|
||||
self.stream.maybe_recv_eof().await?;
|
||||
}
|
||||
|
||||
self.cache_statement
|
||||
.insert(query.to_owned(), ok.statement_id);
|
||||
|
||||
Ok(ok.statement_id)
|
||||
}
|
||||
|
||||
async fn recv_result_metadata(&mut self, mut packet: Packet<Bytes>) -> Result<(), Error> {
|
||||
let num_columns: u64 = packet.get_uint_lenenc(); // column count
|
||||
|
||||
// the result-set metadata is primarily a listing of each output
|
||||
// column in the result-set
|
||||
|
||||
let column_names = Arc::make_mut(&mut self.scratch_row_column_names);
|
||||
let columns = Arc::make_mut(&mut self.scratch_row_columns);
|
||||
|
||||
columns.clear();
|
||||
column_names.clear();
|
||||
|
||||
for i in 0..num_columns {
|
||||
let def: ColumnDefinition = self.stream.recv().await?;
|
||||
|
||||
let name = (match (def.name()?, def.alias()?) {
|
||||
(_, alias) if !alias.is_empty() => Some(alias),
|
||||
|
||||
(name, _) if !name.is_empty() => Some(name),
|
||||
|
||||
_ => None,
|
||||
})
|
||||
.map(UStr::new);
|
||||
|
||||
if let Some(name) = &name {
|
||||
column_names.insert(name.clone(), i as usize);
|
||||
}
|
||||
|
||||
let type_info = MySqlTypeInfo::from_column(&def);
|
||||
|
||||
columns.push(MySqlColumn { name, type_info });
|
||||
}
|
||||
|
||||
self.stream.maybe_recv_eof().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run<'c>(
|
||||
&'c mut self,
|
||||
query: &str,
|
||||
arguments: Option<MySqlArguments>,
|
||||
) -> Result<impl Stream<Item = Result<Either<u64, MySqlRow>, Error>> + 'c, Error> {
|
||||
self.stream.wait_until_ready().await?;
|
||||
self.stream.busy = true;
|
||||
|
||||
let format = if let Some(arguments) = arguments {
|
||||
let statement = self.prepare(query).await?;
|
||||
|
||||
// https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
|
||||
self.stream
|
||||
.send_packet(StatementExecute {
|
||||
statement,
|
||||
arguments: &arguments,
|
||||
})
|
||||
.await?;
|
||||
|
||||
MySqlValueFormat::Binary
|
||||
} else {
|
||||
// https://dev.mysql.com/doc/internals/en/com-query.html
|
||||
self.stream.send_packet(Query(query)).await?;
|
||||
|
||||
MySqlValueFormat::Text
|
||||
};
|
||||
|
||||
Ok(try_stream! {
|
||||
loop {
|
||||
// query response is a meta-packet which may be one of:
|
||||
// Ok, Err, ResultSet, or (unhandled) LocalInfileRequest
|
||||
let mut packet = self.stream.recv_packet().await?;
|
||||
|
||||
if packet[0] == 0x00 || packet[0] == 0xff {
|
||||
// first packet in a query response is OK or ERR
|
||||
// this indicates either a successful query with no rows at all or a failed query
|
||||
let ok = packet.ok()?;
|
||||
let v = Either::Left(ok.affected_rows);
|
||||
|
||||
yield v;
|
||||
|
||||
if ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
|
||||
// more result sets exist, continue to the next one
|
||||
continue;
|
||||
}
|
||||
|
||||
self.stream.busy = false;
|
||||
return;
|
||||
}
|
||||
|
||||
// otherwise, this first packet is the start of the result-set metadata,
|
||||
self.recv_result_metadata(packet).await?;
|
||||
|
||||
// finally, there will be none or many result-rows
|
||||
loop {
|
||||
let packet = self.stream.recv_packet().await?;
|
||||
|
||||
if packet[0] == 0xfe && packet.len() < 9 {
|
||||
let eof = packet.eof(self.stream.capabilities)?;
|
||||
let v = Either::Left(0);
|
||||
|
||||
yield v;
|
||||
|
||||
if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
|
||||
// more result sets exist, continue to the next one
|
||||
break;
|
||||
}
|
||||
|
||||
self.stream.busy = false;
|
||||
return;
|
||||
}
|
||||
|
||||
let row = match format {
|
||||
MySqlValueFormat::Binary => packet.decode_with::<BinaryRow, _>(&self.scratch_row_columns)?.0,
|
||||
MySqlValueFormat::Text => packet.decode_with::<TextRow, _>(&self.scratch_row_columns)?.0,
|
||||
};
|
||||
|
||||
let v = Either::Right(MySqlRow {
|
||||
row,
|
||||
format,
|
||||
columns: Arc::clone(&self.scratch_row_columns),
|
||||
column_names: Arc::clone(&self.scratch_row_column_names),
|
||||
});
|
||||
|
||||
yield v;
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'c> Executor<'c> for &'c mut MySqlConnection {
|
||||
type Database = MySql;
|
||||
|
||||
fn fetch_many<'q: 'c, E>(
|
||||
self,
|
||||
mut query: E,
|
||||
) -> BoxStream<'c, Result<Either<u64, MySqlRow>, Error>>
|
||||
where
|
||||
E: Execute<'q, Self::Database>,
|
||||
{
|
||||
let s = query.query();
|
||||
let arguments = query.take_arguments();
|
||||
|
||||
Box::pin(try_stream! {
|
||||
let s = self.run(s, arguments).await?;
|
||||
pin_mut!(s);
|
||||
|
||||
while let Some(v) = s.try_next().await? {
|
||||
yield v;
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn fetch_optional<'q: 'c, E>(self, query: E) -> BoxFuture<'c, Result<Option<MySqlRow>, Error>>
|
||||
where
|
||||
E: Execute<'q, Self::Database>,
|
||||
{
|
||||
let mut s = self.fetch_many(query);
|
||||
|
||||
Box::pin(async move {
|
||||
while let Some(v) = s.try_next().await? {
|
||||
if let Either::Right(r) = v {
|
||||
return Ok(Some(r));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
})
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
fn describe<'q: 'c, E>(self, query: E) -> BoxFuture<'c, Result<Describe<MySql>, Error>>
|
||||
where
|
||||
E: Execute<'q, Self::Database>,
|
||||
{
|
||||
let query = query.query();
|
||||
|
||||
Box::pin(async move {
|
||||
self.stream.send_packet(Prepare { query }).await?;
|
||||
|
||||
let ok: PrepareOk = self.stream.recv().await?;
|
||||
|
||||
let mut params = Vec::with_capacity(ok.params as usize);
|
||||
let mut columns = Vec::with_capacity(ok.columns as usize);
|
||||
|
||||
if ok.params > 0 {
|
||||
for _ in 0..ok.params {
|
||||
let def: ColumnDefinition = self.stream.recv().await?;
|
||||
|
||||
params.push(MySqlTypeInfo::from_column(&def));
|
||||
}
|
||||
|
||||
self.stream.maybe_recv_eof().await?;
|
||||
}
|
||||
|
||||
// the column definitions are berefit the type information from the
|
||||
// to-be-bound parameters; we will receive the output column definitions
|
||||
// once more on execute so we wait for that
|
||||
|
||||
if ok.columns > 0 {
|
||||
for _ in 0..(ok.columns as usize) {
|
||||
let def: ColumnDefinition = self.stream.recv().await?;
|
||||
let ty = MySqlTypeInfo::from_column(&def);
|
||||
|
||||
columns.push(Column {
|
||||
name: def.name()?.to_owned(),
|
||||
type_info: ty,
|
||||
not_null: Some(def.flags.contains(ColumnFlags::NOT_NULL)),
|
||||
})
|
||||
}
|
||||
|
||||
self.stream.maybe_recv_eof().await?;
|
||||
}
|
||||
|
||||
Ok(Describe { params, columns })
|
||||
})
|
||||
}
|
||||
}
|
||||
116
sqlx-core/src/mysql/connection/mod.rs
Normal file
116
sqlx-core/src/mysql/connection/mod.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use std::fmt::{self, Debug, Formatter};
|
||||
use std::net::Shutdown;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures_core::future::BoxFuture;
|
||||
use hashbrown::HashMap;
|
||||
|
||||
use crate::connection::{Connect, Connection};
|
||||
use crate::error::Error;
|
||||
use crate::executor::Executor;
|
||||
use crate::ext::ustr::UStr;
|
||||
use crate::mysql::protocol::text::{Ping, Quit};
|
||||
use crate::mysql::row::MySqlColumn;
|
||||
use crate::mysql::{MySql, MySqlConnectOptions};
|
||||
|
||||
mod auth;
|
||||
mod establish;
|
||||
mod executor;
|
||||
mod stream;
|
||||
mod tls;
|
||||
|
||||
pub(crate) use stream::MySqlStream;
|
||||
|
||||
const COLLATE_UTF8MB4_UNICODE_CI: u8 = 224;
|
||||
|
||||
const MAX_PACKET_SIZE: u32 = 1024;
|
||||
|
||||
/// A connection to a MySQL database.
|
||||
pub struct MySqlConnection {
|
||||
// underlying TCP stream,
|
||||
// wrapped in a potentially TLS stream,
|
||||
// wrapped in a buffered stream
|
||||
stream: MySqlStream,
|
||||
|
||||
// cache by query string to the statement id
|
||||
cache_statement: HashMap<String, u32>,
|
||||
|
||||
// working memory for the active row's column information
|
||||
// this allows us to re-use these allocations unless the user is persisting the
|
||||
// Row type past a stream iteration (clone-on-write)
|
||||
scratch_row_columns: Arc<Vec<MySqlColumn>>,
|
||||
scratch_row_column_names: Arc<HashMap<UStr, usize>>,
|
||||
}
|
||||
|
||||
impl Debug for MySqlConnection {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("MySqlConnection").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Connection for MySqlConnection {
|
||||
type Database = MySql;
|
||||
|
||||
fn close(mut self) -> BoxFuture<'static, Result<(), Error>> {
|
||||
Box::pin(async move {
|
||||
self.stream.send_packet(Quit).await?;
|
||||
self.stream.shutdown(Shutdown::Both)?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> {
|
||||
Box::pin(async move {
|
||||
self.stream.wait_until_ready().await?;
|
||||
self.stream.send_packet(Ping).await?;
|
||||
self.stream.recv_ok().await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Connect for MySqlConnection {
|
||||
type Options = MySqlConnectOptions;
|
||||
|
||||
#[inline]
|
||||
fn connect_with(options: &Self::Options) -> BoxFuture<'_, Result<Self, Error>> {
|
||||
Box::pin(async move {
|
||||
let mut conn = MySqlConnection::establish(options).await?;
|
||||
|
||||
// After the connection is established, we initialize by configuring a few
|
||||
// connection parameters
|
||||
|
||||
// https://mariadb.com/kb/en/sql-mode/
|
||||
|
||||
// PIPES_AS_CONCAT - Allows using the pipe character (ASCII 124) as string concatenation operator.
|
||||
// This means that "A" || "B" can be used in place of CONCAT("A", "B").
|
||||
|
||||
// NO_ENGINE_SUBSTITUTION - If not set, if the available storage engine specified by a CREATE TABLE is
|
||||
// not available, a warning is given and the default storage
|
||||
// engine is used instead.
|
||||
|
||||
// NO_ZERO_DATE - Don't allow '0000-00-00'. This is invalid in Rust.
|
||||
|
||||
// NO_ZERO_IN_DATE - Don't allow 'YYYY-00-00'. This is invalid in Rust.
|
||||
|
||||
// --
|
||||
|
||||
// Setting the time zone allows us to assume that the output
|
||||
// from a TIMESTAMP field is UTC
|
||||
|
||||
// --
|
||||
|
||||
// https://mathiasbynens.be/notes/mysql-utf8mb4
|
||||
|
||||
conn.execute(r#"
|
||||
SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE'));
|
||||
SET time_zone = '+00:00';
|
||||
SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||
"#).await?;
|
||||
|
||||
Ok(conn)
|
||||
})
|
||||
}
|
||||
}
|
||||
150
sqlx-core/src/mysql/connection/stream.rs
Normal file
150
sqlx-core/src/mysql/connection/stream.rs
Normal file
@@ -0,0 +1,150 @@
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use bytes::{Buf, Bytes};
|
||||
use sqlx_rt::TcpStream;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::{BufStream, Decode, Encode};
|
||||
use crate::mysql::protocol::response::{EofPacket, ErrPacket, OkPacket};
|
||||
use crate::mysql::protocol::{Capabilities, Packet};
|
||||
use crate::mysql::{MySqlConnectOptions, MySqlDatabaseError};
|
||||
use crate::net::MaybeTlsStream;
|
||||
|
||||
pub struct MySqlStream {
|
||||
stream: BufStream<MaybeTlsStream<TcpStream>>,
|
||||
pub(super) capabilities: Capabilities,
|
||||
pub(super) sequence_id: u8,
|
||||
pub(crate) busy: bool,
|
||||
}
|
||||
|
||||
impl MySqlStream {
|
||||
pub(super) async fn connect(options: &MySqlConnectOptions) -> Result<Self, Error> {
|
||||
let stream = TcpStream::connect((&*options.host, options.port)).await?;
|
||||
|
||||
let mut capabilities = Capabilities::PROTOCOL_41
|
||||
| Capabilities::IGNORE_SPACE
|
||||
| Capabilities::DEPRECATE_EOF
|
||||
| Capabilities::FOUND_ROWS
|
||||
| Capabilities::TRANSACTIONS
|
||||
| Capabilities::SECURE_CONNECTION
|
||||
| Capabilities::PLUGIN_AUTH_LENENC_DATA
|
||||
| Capabilities::MULTI_STATEMENTS
|
||||
| Capabilities::MULTI_RESULTS
|
||||
| Capabilities::PLUGIN_AUTH
|
||||
| Capabilities::PS_MULTI_RESULTS
|
||||
| Capabilities::SSL;
|
||||
|
||||
if options.database.is_some() {
|
||||
capabilities |= Capabilities::CONNECT_WITH_DB;
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
busy: false,
|
||||
capabilities,
|
||||
sequence_id: 0,
|
||||
stream: BufStream::new(MaybeTlsStream::Raw(stream)),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
|
||||
if self.busy {
|
||||
loop {
|
||||
let packet = self.recv_packet().await?;
|
||||
match packet[0] {
|
||||
0xfe if packet.len() < 9 => {
|
||||
// OK or EOF packet
|
||||
self.busy = false;
|
||||
break;
|
||||
}
|
||||
|
||||
_ => {
|
||||
// Something else; skip
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn send_packet<'en, T>(&mut self, payload: T) -> Result<(), Error>
|
||||
where
|
||||
T: Encode<'en, Capabilities>,
|
||||
{
|
||||
self.sequence_id = 0;
|
||||
self.write_packet(payload);
|
||||
self.flush().await
|
||||
}
|
||||
|
||||
pub(crate) fn write_packet<'en, T>(&mut self, payload: T)
|
||||
where
|
||||
T: Encode<'en, Capabilities>,
|
||||
{
|
||||
self.stream
|
||||
.write_with(Packet(payload), (self.capabilities, &mut self.sequence_id));
|
||||
}
|
||||
|
||||
// receive the next packet from the database server
|
||||
// may block (async) on more data from the server
|
||||
pub(crate) async fn recv_packet(&mut self) -> Result<Packet<Bytes>, Error> {
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html
|
||||
// https://mariadb.com/kb/en/library/0-packet/#standard-packet
|
||||
|
||||
let mut header: Bytes = self.stream.read(4).await?;
|
||||
|
||||
let packet_size = header.get_uint_le(3) as usize;
|
||||
let sequence_id = header.get_u8();
|
||||
|
||||
self.sequence_id = sequence_id.wrapping_add(1);
|
||||
|
||||
let payload: Bytes = self.stream.read(packet_size).await?;
|
||||
|
||||
// TODO: packet compression
|
||||
// TODO: packet joining
|
||||
|
||||
if payload[0] == 0xff {
|
||||
self.busy = false;
|
||||
|
||||
// instead of letting this packet be looked at everywhere, we check here
|
||||
// and emit a proper Error
|
||||
return Err(
|
||||
MySqlDatabaseError(ErrPacket::decode_with(payload, self.capabilities)?).into(),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(Packet(payload))
|
||||
}
|
||||
|
||||
pub(crate) async fn recv<'de, T>(&mut self) -> Result<T, Error>
|
||||
where
|
||||
T: Decode<'de, Capabilities>,
|
||||
{
|
||||
self.recv_packet().await?.decode_with(self.capabilities)
|
||||
}
|
||||
|
||||
pub(crate) async fn recv_ok(&mut self) -> Result<OkPacket, Error> {
|
||||
self.recv_packet().await?.ok()
|
||||
}
|
||||
|
||||
pub(crate) async fn maybe_recv_eof(&mut self) -> Result<Option<EofPacket>, Error> {
|
||||
if self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
|
||||
Ok(None)
|
||||
} else {
|
||||
self.recv().await.map(Some)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for MySqlStream {
|
||||
type Target = BufStream<MaybeTlsStream<TcpStream>>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.stream
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for MySqlStream {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.stream
|
||||
}
|
||||
}
|
||||
79
sqlx-core/src/mysql/connection/tls.rs
Normal file
79
sqlx-core/src/mysql/connection/tls.rs
Normal file
@@ -0,0 +1,79 @@
|
||||
use sqlx_rt::{
|
||||
fs,
|
||||
native_tls::{Certificate, TlsConnector},
|
||||
};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::mysql::connection::MySqlStream;
|
||||
use crate::mysql::protocol::connect::SslRequest;
|
||||
use crate::mysql::protocol::Capabilities;
|
||||
use crate::mysql::{MySqlConnectOptions, MySqlSslMode};
|
||||
|
||||
pub(super) async fn maybe_upgrade(
|
||||
stream: &mut MySqlStream,
|
||||
options: &MySqlConnectOptions,
|
||||
) -> Result<(), Error> {
|
||||
// https://www.postgresql.org/docs/12/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS
|
||||
match options.ssl_mode {
|
||||
MySqlSslMode::Disabled => {}
|
||||
|
||||
MySqlSslMode::Preferred => {
|
||||
// try upgrade, but its okay if we fail
|
||||
upgrade(stream, options).await?;
|
||||
}
|
||||
|
||||
MySqlSslMode::Required | MySqlSslMode::VerifyIdentity | MySqlSslMode::VerifyCa => {
|
||||
if !upgrade(stream, options).await? {
|
||||
// upgrade failed, die
|
||||
return Err(Error::Tls("server does not support TLS".into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn upgrade(stream: &mut MySqlStream, options: &MySqlConnectOptions) -> Result<bool, Error> {
|
||||
if !stream.capabilities.contains(Capabilities::SSL) {
|
||||
// server does not support TLS
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
stream.write_packet(SslRequest {
|
||||
max_packet_size: super::MAX_PACKET_SIZE,
|
||||
char_set: super::COLLATE_UTF8MB4_UNICODE_CI,
|
||||
});
|
||||
|
||||
stream.flush().await?;
|
||||
|
||||
// FIXME: de-duplicate with postgres/connection/tls.rs
|
||||
|
||||
let accept_invalid_certs = !matches!(
|
||||
options.ssl_mode,
|
||||
MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
|
||||
);
|
||||
|
||||
let mut builder = TlsConnector::builder();
|
||||
builder
|
||||
.danger_accept_invalid_certs(accept_invalid_certs)
|
||||
.danger_accept_invalid_hostnames(!matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity));
|
||||
|
||||
if !accept_invalid_certs {
|
||||
if let Some(ca) = &options.ssl_ca {
|
||||
let data = fs::read(ca).await?;
|
||||
let cert = Certificate::from_pem(&data).map_err(Error::tls)?;
|
||||
|
||||
builder.add_root_certificate(cert);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "runtime-async-std"))]
|
||||
let connector = builder.build().map_err(Error::tls)?;
|
||||
|
||||
#[cfg(feature = "runtime-async-std")]
|
||||
let connector = builder;
|
||||
|
||||
stream.upgrade(&options.host, connector.into()).await?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
@@ -1,164 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures_core::future::BoxFuture;
|
||||
|
||||
use crate::connection::ConnectionSource;
|
||||
use crate::cursor::Cursor;
|
||||
use crate::executor::Execute;
|
||||
use crate::mysql::protocol::{ColumnCount, ColumnDefinition, Row, Status};
|
||||
use crate::mysql::{MySql, MySqlArguments, MySqlConnection, MySqlRow, MySqlTypeInfo};
|
||||
use crate::pool::Pool;
|
||||
|
||||
pub struct MySqlCursor<'c, 'q> {
|
||||
source: ConnectionSource<'c, MySqlConnection>,
|
||||
query: Option<(&'q str, Option<MySqlArguments>)>,
|
||||
column_names: Arc<HashMap<Box<str>, u16>>,
|
||||
column_types: Vec<MySqlTypeInfo>,
|
||||
binary: bool,
|
||||
}
|
||||
|
||||
impl crate::cursor::private::Sealed for MySqlCursor<'_, '_> {}
|
||||
|
||||
impl<'c, 'q> Cursor<'c, 'q> for MySqlCursor<'c, 'q> {
|
||||
type Database = MySql;
|
||||
|
||||
#[doc(hidden)]
|
||||
fn from_pool<E>(pool: &Pool<MySqlConnection>, query: E) -> Self
|
||||
where
|
||||
Self: Sized,
|
||||
E: Execute<'q, MySql>,
|
||||
{
|
||||
Self {
|
||||
source: ConnectionSource::Pool(pool.clone()),
|
||||
column_names: Arc::default(),
|
||||
column_types: Vec::new(),
|
||||
binary: true,
|
||||
query: Some(query.into_parts()),
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
fn from_connection<E>(conn: &'c mut MySqlConnection, query: E) -> Self
|
||||
where
|
||||
Self: Sized,
|
||||
E: Execute<'q, MySql>,
|
||||
{
|
||||
Self {
|
||||
source: ConnectionSource::ConnectionRef(conn),
|
||||
column_names: Arc::default(),
|
||||
column_types: Vec::new(),
|
||||
binary: true,
|
||||
query: Some(query.into_parts()),
|
||||
}
|
||||
}
|
||||
|
||||
fn next(&mut self) -> BoxFuture<crate::Result<Option<MySqlRow<'_>>>> {
|
||||
Box::pin(next(self))
|
||||
}
|
||||
}
|
||||
|
||||
async fn next<'a, 'c: 'a, 'q: 'a>(
|
||||
cursor: &'a mut MySqlCursor<'c, 'q>,
|
||||
) -> crate::Result<Option<MySqlRow<'a>>> {
|
||||
let mut conn = cursor.source.resolve().await?;
|
||||
|
||||
// The first time [next] is called we need to actually execute our
|
||||
// contained query. We guard against this happening on _all_ next calls
|
||||
// by using [Option::take] which replaces the potential value in the Option with `None
|
||||
let mut initial = if let Some((query, arguments)) = cursor.query.take() {
|
||||
let statement = conn.run(query, arguments).await?;
|
||||
|
||||
// No statement ID = TEXT mode
|
||||
cursor.binary = statement.is_some();
|
||||
|
||||
true
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
loop {
|
||||
let packet_id = conn.stream.receive().await?[0];
|
||||
|
||||
match packet_id {
|
||||
// OK or EOF packet
|
||||
0x00 | 0xFE
|
||||
if conn.stream.packet().len() < 0xFF_FF_FF && (packet_id != 0x00 || initial) =>
|
||||
{
|
||||
let status = if let Some(eof) = conn.stream.maybe_handle_eof()? {
|
||||
eof.status
|
||||
} else {
|
||||
conn.stream.handle_ok()?.status
|
||||
};
|
||||
|
||||
if status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
|
||||
// There is more to this query
|
||||
initial = true;
|
||||
} else {
|
||||
conn.is_ready = true;
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
// ERR packet
|
||||
0xFF => {
|
||||
conn.is_ready = true;
|
||||
return conn.stream.handle_err();
|
||||
}
|
||||
|
||||
_ if initial => {
|
||||
// At the start of the results we expect to see a
|
||||
// COLUMN_COUNT followed by N COLUMN_DEF
|
||||
|
||||
let cc = ColumnCount::read(conn.stream.packet())?;
|
||||
|
||||
// We use these definitions to get the actual column types that is critical
|
||||
// in parsing the rows coming back soon
|
||||
|
||||
cursor.column_types.clear();
|
||||
cursor.column_types.reserve(cc.columns as usize);
|
||||
|
||||
let mut column_names = HashMap::with_capacity(cc.columns as usize);
|
||||
|
||||
for i in 0..cc.columns {
|
||||
let column = ColumnDefinition::read(conn.stream.receive().await?)?;
|
||||
|
||||
cursor
|
||||
.column_types
|
||||
.push(MySqlTypeInfo::from_nullable_column_def(&column));
|
||||
|
||||
if let Some(name) = column.name() {
|
||||
column_names.insert(name.to_owned().into_boxed_str(), i as u16);
|
||||
}
|
||||
}
|
||||
|
||||
if cc.columns > 0 {
|
||||
conn.stream.maybe_receive_eof().await?;
|
||||
}
|
||||
|
||||
cursor.column_names = Arc::new(column_names);
|
||||
initial = false;
|
||||
}
|
||||
|
||||
_ if !cursor.binary || packet_id == 0x00 => {
|
||||
let row = Row::read(
|
||||
conn.stream.packet(),
|
||||
&cursor.column_types,
|
||||
&mut conn.current_row_values,
|
||||
cursor.binary,
|
||||
)?;
|
||||
|
||||
let row = MySqlRow {
|
||||
row,
|
||||
names: Arc::clone(&cursor.column_names),
|
||||
};
|
||||
|
||||
return Ok(Some(row));
|
||||
}
|
||||
|
||||
_ => {
|
||||
return conn.stream.handle_unexpected();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,41 +1,31 @@
|
||||
use crate::cursor::HasCursor;
|
||||
use crate::database::Database;
|
||||
use crate::mysql::error::MySqlError;
|
||||
use crate::row::HasRow;
|
||||
use crate::value::HasRawValue;
|
||||
use crate::database::{Database, HasArguments, HasValueRef};
|
||||
use crate::mysql::value::{MySqlValue, MySqlValueRef};
|
||||
use crate::mysql::{MySqlArguments, MySqlConnection, MySqlRow, MySqlTypeInfo};
|
||||
|
||||
/// **MySQL** database driver.
|
||||
/// PostgreSQL database driver.
|
||||
#[derive(Debug)]
|
||||
pub struct MySql;
|
||||
|
||||
impl Database for MySql {
|
||||
type Connection = super::MySqlConnection;
|
||||
type Connection = MySqlConnection;
|
||||
|
||||
type Arguments = super::MySqlArguments;
|
||||
type Row = MySqlRow;
|
||||
|
||||
type TypeInfo = super::MySqlTypeInfo;
|
||||
type TypeInfo = MySqlTypeInfo;
|
||||
|
||||
type TableId = Box<str>;
|
||||
|
||||
type RawBuffer = Vec<u8>;
|
||||
|
||||
type Error = MySqlError;
|
||||
type Value = MySqlValue;
|
||||
}
|
||||
|
||||
impl<'c> HasRow<'c> for MySql {
|
||||
impl<'r> HasValueRef<'r> for MySql {
|
||||
type Database = MySql;
|
||||
|
||||
type Row = super::MySqlRow<'c>;
|
||||
type ValueRef = MySqlValueRef<'r>;
|
||||
}
|
||||
|
||||
impl<'c, 'q> HasCursor<'c, 'q> for MySql {
|
||||
impl HasArguments<'_> for MySql {
|
||||
type Database = MySql;
|
||||
|
||||
type Cursor = super::MySqlCursor<'c, 'q>;
|
||||
}
|
||||
type Arguments = MySqlArguments;
|
||||
|
||||
impl<'c> HasRawValue<'c> for MySql {
|
||||
type Database = MySql;
|
||||
|
||||
type RawValue = super::MySqlValue<'c>;
|
||||
type ArgumentBuffer = Vec<u8>;
|
||||
}
|
||||
|
||||
@@ -1,63 +1,64 @@
|
||||
use std::error::Error as StdError;
|
||||
use std::fmt::{self, Display};
|
||||
use std::error::Error;
|
||||
use std::fmt::{self, Debug, Display, Formatter};
|
||||
|
||||
use crate::error::DatabaseError;
|
||||
use crate::mysql::protocol::ErrPacket;
|
||||
use crate::mysql::protocol::response::ErrPacket;
|
||||
use smallvec::alloc::borrow::Cow;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MySqlError(pub(super) ErrPacket);
|
||||
/// An error returned from the MySQL database.
|
||||
pub struct MySqlDatabaseError(pub(super) ErrPacket);
|
||||
|
||||
impl Display for MySqlError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.pad(self.message())
|
||||
}
|
||||
}
|
||||
|
||||
impl DatabaseError for MySqlError {
|
||||
fn message(&self) -> &str {
|
||||
&*self.0.error_message
|
||||
}
|
||||
|
||||
fn code(&self) -> Option<&str> {
|
||||
impl MySqlDatabaseError {
|
||||
/// The [SQLSTATE](https://dev.mysql.com/doc/refman/8.0/en/server-error-reference.html) code for this error.
|
||||
pub fn code(&self) -> Option<&str> {
|
||||
self.0.sql_state.as_deref()
|
||||
}
|
||||
|
||||
fn as_ref_err(&self) -> &(dyn StdError + Send + Sync + 'static) {
|
||||
self
|
||||
/// The [number](https://dev.mysql.com/doc/refman/8.0/en/server-error-reference.html)
|
||||
/// for this error.
|
||||
///
|
||||
/// MySQL tends to use SQLSTATE as a general error category, and the error number as a more
|
||||
/// granular indication of the error.
|
||||
pub fn number(&self) -> u16 {
|
||||
self.0.error_code
|
||||
}
|
||||
|
||||
fn as_mut_err(&mut self) -> &mut (dyn StdError + Send + Sync + 'static) {
|
||||
self
|
||||
}
|
||||
|
||||
fn into_box_err(self: Box<Self>) -> Box<dyn StdError + Send + Sync + 'static> {
|
||||
self
|
||||
/// The human-readable error message.
|
||||
pub fn message(&self) -> &str {
|
||||
&self.0.error_message
|
||||
}
|
||||
}
|
||||
|
||||
impl StdError for MySqlError {}
|
||||
|
||||
impl From<MySqlError> for crate::Error {
|
||||
fn from(err: MySqlError) -> Self {
|
||||
crate::Error::Database(Box::new(err))
|
||||
impl Debug for MySqlDatabaseError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("MySqlDatabaseError")
|
||||
.field("code", &self.code())
|
||||
.field("number", &self.number())
|
||||
.field("message", &self.message())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_downcasting() {
|
||||
let error = MySqlError(ErrPacket {
|
||||
error_code: 0xABCD,
|
||||
sql_state: None,
|
||||
error_message: "".into(),
|
||||
});
|
||||
|
||||
let error = crate::Error::from(error);
|
||||
|
||||
let db_err = match error {
|
||||
crate::Error::Database(db_err) => db_err,
|
||||
e => panic!("expected crate::Error::Database, got {:?}", e),
|
||||
};
|
||||
|
||||
assert_eq!(db_err.downcast_ref::<MySqlError>().0.error_code, 0xABCD);
|
||||
assert_eq!(db_err.downcast::<MySqlError>().0.error_code, 0xABCD);
|
||||
impl Display for MySqlDatabaseError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
if let Some(code) = &self.code() {
|
||||
write!(f, "{} ({}): {}", self.number(), code, self.message())
|
||||
} else {
|
||||
write!(f, "{}: {}", self.number(), self.message())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for MySqlDatabaseError {}
|
||||
|
||||
impl DatabaseError for MySqlDatabaseError {
|
||||
#[inline]
|
||||
fn message(&self) -> &str {
|
||||
self.message()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn code(&self) -> Option<Cow<str>> {
|
||||
self.code().map(Cow::Borrowed)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,223 +0,0 @@
|
||||
use futures_core::future::BoxFuture;
|
||||
|
||||
use crate::cursor::Cursor;
|
||||
use crate::describe::{Column, Describe};
|
||||
use crate::executor::{Execute, Executor, RefExecutor};
|
||||
use crate::mysql::protocol::{
|
||||
self, ColumnDefinition, ComQuery, ComStmtExecute, ComStmtPrepare, ComStmtPrepareOk, FieldFlags,
|
||||
Status,
|
||||
};
|
||||
use crate::mysql::{MySql, MySqlArguments, MySqlCursor, MySqlTypeInfo};
|
||||
|
||||
impl super::MySqlConnection {
|
||||
// Creates a prepared statement for the passed query string
|
||||
async fn prepare(&mut self, query: &str) -> crate::Result<ComStmtPrepareOk> {
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_com_stmt_prepare.html
|
||||
self.stream.send(ComStmtPrepare { query }, true).await?;
|
||||
|
||||
// Should receive a COM_STMT_PREPARE_OK or ERR_PACKET
|
||||
let packet = self.stream.receive().await?;
|
||||
|
||||
if packet[0] == 0xFF {
|
||||
return self.stream.handle_err();
|
||||
}
|
||||
|
||||
ComStmtPrepareOk::read(packet)
|
||||
}
|
||||
|
||||
async fn drop_column_defs(&mut self, count: usize) -> crate::Result<()> {
|
||||
for _ in 0..count {
|
||||
let _column = ColumnDefinition::read(self.stream.receive().await?)?;
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
self.stream.maybe_receive_eof().await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Gets a cached prepared statement ID _or_ prepares the statement if not in the cache
|
||||
// At the end we should have [cache_statement] and [cache_statement_columns] filled
|
||||
async fn get_or_prepare(&mut self, query: &str) -> crate::Result<u32> {
|
||||
if let Some(&id) = self.cache_statement.get(query) {
|
||||
Ok(id)
|
||||
} else {
|
||||
let stmt = self.prepare(query).await?;
|
||||
|
||||
self.cache_statement.insert(query.into(), stmt.statement_id);
|
||||
|
||||
// COM_STMT_PREPARE returns the input columns
|
||||
// We make no use of that data, so cycle through and drop them
|
||||
self.drop_column_defs(stmt.params as usize).await?;
|
||||
|
||||
// COM_STMT_PREPARE next returns the output columns
|
||||
// We just drop these as we get these when we execute the query
|
||||
self.drop_column_defs(stmt.columns as usize).await?;
|
||||
|
||||
Ok(stmt.statement_id)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn run(
|
||||
&mut self,
|
||||
query: &str,
|
||||
arguments: Option<MySqlArguments>,
|
||||
) -> crate::Result<Option<u32>> {
|
||||
self.stream.wait_until_ready().await?;
|
||||
self.stream.is_ready = false;
|
||||
|
||||
if let Some(arguments) = arguments {
|
||||
let statement_id = self.get_or_prepare(query).await?;
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_com_stmt_execute.html
|
||||
self.stream
|
||||
.send(
|
||||
ComStmtExecute {
|
||||
cursor: protocol::Cursor::NO_CURSOR,
|
||||
statement_id,
|
||||
params: &arguments.params,
|
||||
null_bitmap: &arguments.null_bitmap,
|
||||
param_types: &arguments.param_types,
|
||||
},
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(Some(statement_id))
|
||||
} else {
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_com_query.html
|
||||
self.stream.send(ComQuery { query }, true).await?;
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
async fn affected_rows(&mut self) -> crate::Result<u64> {
|
||||
let mut rows = 0;
|
||||
|
||||
loop {
|
||||
let id = self.stream.receive().await?[0];
|
||||
|
||||
match id {
|
||||
0x00 | 0xFE if self.stream.packet().len() < 0xFF_FF_FF => {
|
||||
// ResultSet row can begin with 0xfe byte (when using text protocol
|
||||
// with a field length > 0xffffff)
|
||||
|
||||
let status = if let Some(eof) = self.stream.maybe_handle_eof()? {
|
||||
eof.status
|
||||
} else {
|
||||
let ok = self.stream.handle_ok()?;
|
||||
|
||||
rows += ok.affected_rows;
|
||||
ok.status
|
||||
};
|
||||
|
||||
if !status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
|
||||
self.is_ready = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
0xFF => {
|
||||
return self.stream.handle_err();
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(rows)
|
||||
}
|
||||
|
||||
// method is not named describe to work around an intellijrust bug
|
||||
// otherwise it marks someone trying to describe the connection as "method is private"
|
||||
async fn do_describe(&mut self, query: &str) -> crate::Result<Describe<MySql>> {
|
||||
self.stream.wait_until_ready().await?;
|
||||
|
||||
let stmt = self.prepare(query).await?;
|
||||
|
||||
let mut param_types = Vec::with_capacity(stmt.params as usize);
|
||||
let mut result_columns = Vec::with_capacity(stmt.columns as usize);
|
||||
|
||||
for _ in 0..stmt.params {
|
||||
let param = ColumnDefinition::read(self.stream.receive().await?)?;
|
||||
param_types.push(MySqlTypeInfo::from_column_def(¶m));
|
||||
}
|
||||
|
||||
if stmt.params > 0 {
|
||||
self.stream.maybe_receive_eof().await?;
|
||||
}
|
||||
|
||||
for _ in 0..stmt.columns {
|
||||
let column = ColumnDefinition::read(self.stream.receive().await?)?;
|
||||
|
||||
result_columns.push(Column::<MySql> {
|
||||
type_info: MySqlTypeInfo::from_column_def(&column),
|
||||
name: column.column_alias.or(column.column),
|
||||
table_id: column.table_alias.or(column.table),
|
||||
// TODO(@abonander): Should this be None in some cases?
|
||||
non_null: Some(column.flags.contains(FieldFlags::NOT_NULL)),
|
||||
});
|
||||
}
|
||||
|
||||
if stmt.columns > 0 {
|
||||
self.stream.maybe_receive_eof().await?;
|
||||
}
|
||||
|
||||
Ok(Describe {
|
||||
param_types: param_types.into_boxed_slice(),
|
||||
result_columns: result_columns.into_boxed_slice(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Executor for super::MySqlConnection {
|
||||
type Database = MySql;
|
||||
|
||||
fn execute<'e, 'q: 'e, 'c: 'e, E: 'e>(
|
||||
&'c mut self,
|
||||
query: E,
|
||||
) -> BoxFuture<'e, crate::Result<u64>>
|
||||
where
|
||||
E: Execute<'q, Self::Database>,
|
||||
{
|
||||
log_execution!(query, {
|
||||
Box::pin(async move {
|
||||
let (query, arguments) = query.into_parts();
|
||||
|
||||
self.run(query, arguments).await?;
|
||||
self.affected_rows().await
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn fetch<'q, E>(&mut self, query: E) -> MySqlCursor<'_, 'q>
|
||||
where
|
||||
E: Execute<'q, Self::Database>,
|
||||
{
|
||||
log_execution!(query, { MySqlCursor::from_connection(self, query) })
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
fn describe<'e, 'q, E: 'e>(
|
||||
&'e mut self,
|
||||
query: E,
|
||||
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>>
|
||||
where
|
||||
E: Execute<'q, Self::Database>,
|
||||
{
|
||||
Box::pin(async move { self.do_describe(query.into_parts().0).await })
|
||||
}
|
||||
}
|
||||
|
||||
impl<'c> RefExecutor<'c> for &'c mut super::MySqlConnection {
|
||||
type Database = MySql;
|
||||
|
||||
fn fetch_by_ref<'q, E>(self, query: E) -> MySqlCursor<'c, 'q>
|
||||
where
|
||||
E: Execute<'q, Self::Database>,
|
||||
{
|
||||
log_execution!(query, { MySqlCursor::from_connection(self, query) })
|
||||
}
|
||||
}
|
||||
40
sqlx-core/src/mysql/io/buf.rs
Normal file
40
sqlx-core/src/mysql/io/buf.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
use bytes::{Buf, Bytes};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::BufExt;
|
||||
|
||||
pub trait MySqlBufExt: Buf {
|
||||
// Read a length-encoded integer.
|
||||
// NOTE: 0xfb or NULL is only returned for binary value encoding to indicate NULL.
|
||||
// NOTE: 0xff is only returned during a result set to indicate ERR.
|
||||
// <https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger>
|
||||
fn get_uint_lenenc(&mut self) -> u64;
|
||||
|
||||
// Read a length-encoded string.
|
||||
fn get_str_lenenc(&mut self) -> Result<String, Error>;
|
||||
|
||||
// Read a length-encoded byte sequence.
|
||||
fn get_bytes_lenenc(&mut self) -> Bytes;
|
||||
}
|
||||
|
||||
impl MySqlBufExt for Bytes {
|
||||
fn get_uint_lenenc(&mut self) -> u64 {
|
||||
match self.get_u8() {
|
||||
0xfc => u64::from(self.get_u16_le()),
|
||||
0xfd => u64::from(self.get_uint_le(3)),
|
||||
0xfe => u64::from(self.get_u64_le()),
|
||||
|
||||
v => u64::from(v),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_str_lenenc(&mut self) -> Result<String, Error> {
|
||||
let size = self.get_uint_lenenc();
|
||||
self.get_str(size as usize)
|
||||
}
|
||||
|
||||
fn get_bytes_lenenc(&mut self) -> Bytes {
|
||||
let size = self.get_uint_lenenc();
|
||||
self.split_to(size as usize)
|
||||
}
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
use std::io;
|
||||
|
||||
use byteorder::ByteOrder;
|
||||
|
||||
use crate::io::Buf;
|
||||
|
||||
pub trait BufExt {
|
||||
fn get_uint_lenenc<T: ByteOrder>(&mut self) -> io::Result<Option<u64>>;
|
||||
|
||||
fn get_str_lenenc<T: ByteOrder>(&mut self) -> io::Result<Option<&str>>;
|
||||
|
||||
fn get_bytes_lenenc<T: ByteOrder>(&mut self) -> io::Result<Option<&[u8]>>;
|
||||
}
|
||||
|
||||
impl BufExt for &'_ [u8] {
|
||||
fn get_uint_lenenc<T: ByteOrder>(&mut self) -> io::Result<Option<u64>> {
|
||||
Ok(match self.get_u8()? {
|
||||
0xFB => None,
|
||||
0xFC => Some(u64::from(self.get_u16::<T>()?)),
|
||||
0xFD => Some(u64::from(self.get_u24::<T>()?)),
|
||||
0xFE => Some(self.get_u64::<T>()?),
|
||||
|
||||
value => Some(u64::from(value)),
|
||||
})
|
||||
}
|
||||
|
||||
fn get_str_lenenc<T: ByteOrder>(&mut self) -> io::Result<Option<&str>> {
|
||||
self.get_uint_lenenc::<T>()?
|
||||
.map(move |len| self.get_str(len as usize))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn get_bytes_lenenc<T: ByteOrder>(&mut self) -> io::Result<Option<&[u8]>> {
|
||||
self.get_uint_lenenc::<T>()?
|
||||
.map(move |len| self.get_bytes(len as usize))
|
||||
.transpose()
|
||||
}
|
||||
}
|
||||
134
sqlx-core/src/mysql/io/buf_mut.rs
Normal file
134
sqlx-core/src/mysql/io/buf_mut.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
use bytes::BufMut;
|
||||
|
||||
pub trait MySqlBufMutExt: BufMut {
|
||||
fn put_uint_lenenc(&mut self, v: u64);
|
||||
|
||||
fn put_str_lenenc(&mut self, v: &str);
|
||||
|
||||
fn put_bytes_lenenc(&mut self, v: &[u8]);
|
||||
}
|
||||
|
||||
impl MySqlBufMutExt for Vec<u8> {
|
||||
fn put_uint_lenenc(&mut self, v: u64) {
|
||||
// https://dev.mysql.com/doc/internals/en/integer.html
|
||||
// https://mariadb.com/kb/en/library/protocol-data-types/#length-encoded-integers
|
||||
|
||||
if v < 251 {
|
||||
self.push(v as u8);
|
||||
} else if v < 0x1_00_00 {
|
||||
self.push(0xfc);
|
||||
self.extend(&(v as u16).to_le_bytes());
|
||||
} else if v < 0x1_00_00_00 {
|
||||
self.push(0xfd);
|
||||
self.extend(&(v as u32).to_le_bytes()[..3]);
|
||||
} else {
|
||||
self.push(0xfe);
|
||||
self.extend(&v.to_le_bytes());
|
||||
}
|
||||
}
|
||||
|
||||
fn put_str_lenenc(&mut self, v: &str) {
|
||||
self.put_bytes_lenenc(v.as_bytes());
|
||||
}
|
||||
|
||||
fn put_bytes_lenenc(&mut self, v: &[u8]) {
|
||||
self.put_uint_lenenc(v.len() as u64);
|
||||
self.extend(v);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encodes_int_lenenc_u8() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc(0xFA as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFA");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encodes_int_lenenc_u16() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc(std::u16::MAX as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encodes_int_lenenc_u24() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc(0xFF_FF_FF as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFD\xFF\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encodes_int_lenenc_u64() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc(std::u64::MAX);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encodes_int_lenenc_fb() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc(0xFB as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFB\x00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encodes_int_lenenc_fc() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc(0xFC as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFC\x00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encodes_int_lenenc_fd() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc(0xFD as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFD\x00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encodes_int_lenenc_fe() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc(0xFE as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFE\x00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encodes_int_lenenc_ff() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc(0xFF as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFF\x00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encodes_string_lenenc() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_str_lenenc("random_string");
|
||||
|
||||
assert_eq!(&buf[..], b"\x0Drandom_string");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encodes_string_null() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_str_nul("random_string");
|
||||
|
||||
assert_eq!(&buf[..], b"random_string\0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encodes_byte_lenenc() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_bytes_lenenc(b"random_string");
|
||||
|
||||
assert_eq!(&buf[..], b"\x0Drandom_string");
|
||||
}
|
||||
@@ -1,217 +0,0 @@
|
||||
use std::{u16, u32, u64, u8};
|
||||
|
||||
use byteorder::ByteOrder;
|
||||
|
||||
use crate::io::BufMut;
|
||||
|
||||
pub trait BufMutExt {
|
||||
fn put_uint_lenenc<T: ByteOrder, U: Into<Option<u64>>>(&mut self, val: U);
|
||||
|
||||
fn put_str_lenenc<T: ByteOrder>(&mut self, val: &str);
|
||||
|
||||
fn put_bytes_lenenc<T: ByteOrder>(&mut self, val: &[u8]);
|
||||
}
|
||||
|
||||
impl BufMutExt for Vec<u8> {
|
||||
fn put_uint_lenenc<T: ByteOrder, U: Into<Option<u64>>>(&mut self, value: U) {
|
||||
if let Some(value) = value.into() {
|
||||
// https://mariadb.com/kb/en/library/protocol-data-types/#length-encoded-integers
|
||||
if value > 0xFF_FF_FF {
|
||||
// Integer value is encoded in the next 8 bytes (9 bytes total)
|
||||
self.push(0xFE);
|
||||
self.put_u64::<T>(value);
|
||||
} else if value > u64::from(u16::MAX) {
|
||||
// Integer value is encoded in the next 3 bytes (4 bytes total)
|
||||
self.push(0xFD);
|
||||
self.put_u24::<T>(value as u32);
|
||||
} else if value > u64::from(u8::MAX) {
|
||||
// Integer value is encoded in the next 2 bytes (3 bytes total)
|
||||
self.push(0xFC);
|
||||
self.put_u16::<T>(value as u16);
|
||||
} else {
|
||||
match value {
|
||||
// If the value is of size u8 and one of the key bytes used in length encoding
|
||||
// we must put that single byte as a u16
|
||||
0xFB | 0xFC | 0xFD | 0xFE | 0xFF => {
|
||||
self.push(0xFC);
|
||||
self.put_u16::<T>(value as u16);
|
||||
}
|
||||
|
||||
_ => {
|
||||
self.push(value as u8);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
self.push(0xFB);
|
||||
}
|
||||
}
|
||||
|
||||
fn put_str_lenenc<T: ByteOrder>(&mut self, val: &str) {
|
||||
self.put_uint_lenenc::<T, _>(val.len() as u64);
|
||||
self.extend_from_slice(val.as_bytes());
|
||||
}
|
||||
|
||||
fn put_bytes_lenenc<T: ByteOrder>(&mut self, val: &[u8]) {
|
||||
self.put_uint_lenenc::<T, _>(val.len() as u64);
|
||||
self.extend_from_slice(val);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{BufMut, BufMutExt};
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_none() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(None);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFB");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_u8() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(0xFA as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFA");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_u16() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(std::u16::MAX as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_u24() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(0xFF_FF_FF as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFD\xFF\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_u64() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(std::u64::MAX);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_fb() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(0xFB as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFB\x00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_fc() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(0xFC as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFC\x00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_fd() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(0xFD as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFD\x00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_fe() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(0xFE as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFE\x00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_ff() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(0xFF as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFF\x00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_u64() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_u64::<LittleEndian>(std::u64::MAX);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_u32() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_u32::<LittleEndian>(std::u32::MAX);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_u24() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_u24::<LittleEndian>(0xFF_FF_FF as u32);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFF\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_u16() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_u16::<LittleEndian>(std::u16::MAX);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_u8() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_u8(std::u8::MAX);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_string_lenenc() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_str_lenenc::<LittleEndian>("random_string");
|
||||
|
||||
assert_eq!(&buf[..], b"\x0Drandom_string");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_string_fix() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_str("random_string");
|
||||
|
||||
assert_eq!(&buf[..], b"random_string");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_string_null() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_str_nul("random_string");
|
||||
|
||||
assert_eq!(&buf[..], b"random_string\0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_byte_lenenc() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_bytes_lenenc::<LittleEndian>(b"random_string");
|
||||
|
||||
assert_eq!(&buf[..], b"\x0Drandom_string");
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
mod buf_ext;
|
||||
mod buf_mut_ext;
|
||||
mod buf;
|
||||
mod buf_mut;
|
||||
|
||||
pub use buf_ext::BufExt;
|
||||
pub use buf_mut_ext::BufMutExt;
|
||||
pub use buf::MySqlBufExt;
|
||||
pub use buf_mut::MySqlBufMutExt;
|
||||
|
||||
@@ -1,35 +1,25 @@
|
||||
//! **MySQL** database and connection types.
|
||||
|
||||
pub use arguments::MySqlArguments;
|
||||
pub use connection::MySqlConnection;
|
||||
pub use cursor::MySqlCursor;
|
||||
pub use database::MySql;
|
||||
pub use error::MySqlError;
|
||||
pub use row::MySqlRow;
|
||||
pub use type_info::MySqlTypeInfo;
|
||||
pub use value::{MySqlData, MySqlValue};
|
||||
//! **MySQL** database driver.
|
||||
|
||||
mod arguments;
|
||||
mod connection;
|
||||
mod cursor;
|
||||
mod database;
|
||||
mod error;
|
||||
mod executor;
|
||||
mod io;
|
||||
mod options;
|
||||
mod protocol;
|
||||
mod row;
|
||||
mod rsa;
|
||||
mod stream;
|
||||
mod tls;
|
||||
mod type_info;
|
||||
pub mod types;
|
||||
mod util;
|
||||
mod value;
|
||||
|
||||
/// An alias for [`crate::pool::Pool`], specialized for **MySQL**.
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "mysql")))]
|
||||
pub type MySqlPool = crate::pool::Pool<MySqlConnection>;
|
||||
pub use arguments::MySqlArguments;
|
||||
pub use connection::MySqlConnection;
|
||||
pub use database::MySql;
|
||||
pub use error::MySqlDatabaseError;
|
||||
pub use options::{MySqlConnectOptions, MySqlSslMode};
|
||||
pub use row::MySqlRow;
|
||||
pub use type_info::MySqlTypeInfo;
|
||||
pub use value::{MySqlValue, MySqlValueFormat, MySqlValueRef};
|
||||
|
||||
make_query_as!(MySqlQueryAs, MySql, MySqlRow);
|
||||
impl_map_row_for_row!(MySql, MySqlRow);
|
||||
impl_from_row_for_tuples!(MySql, MySqlRow);
|
||||
/// An alias for [`Pool`][crate::pool::Pool], specialized for MySQL.
|
||||
pub type MySqlPool = crate::pool::Pool<MySqlConnection>;
|
||||
|
||||
232
sqlx-core/src/mysql/options.rs
Normal file
232
sqlx-core/src/mysql/options.rs
Normal file
@@ -0,0 +1,232 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::str::FromStr;
|
||||
use url::Url;
|
||||
|
||||
use crate::error::{BoxDynError, Error};
|
||||
|
||||
/// Options for controlling the desired security state of the connection to the MySQL server.
|
||||
///
|
||||
/// It is used by the [`ssl_mode`](MySqlConnectOptions::ssl_mode) method.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum MySqlSslMode {
|
||||
/// Establish an unencrypted connection.
|
||||
Disabled,
|
||||
|
||||
/// Establish an encrypted connection if the server supports encrypted connections, falling
|
||||
/// back to an unencrypted connection if an encrypted connection cannot be established.
|
||||
///
|
||||
/// This is the default if `ssl_mode` is not specified.
|
||||
Preferred,
|
||||
|
||||
/// Establish an encrypted connection if the server supports encrypted connections.
|
||||
/// The connection attempt fails if an encrypted connection cannot be established.
|
||||
Required,
|
||||
|
||||
/// Like `Required`, but additionally verify the server Certificate Authority (CA)
|
||||
/// certificate against the configured CA certificates. The connection attempt fails
|
||||
/// if no valid matching CA certificates are found.
|
||||
VerifyCa,
|
||||
|
||||
/// Like `VerifyCa`, but additionally perform host name identity verification by
|
||||
/// checking the host name the client uses for connecting to the server against the
|
||||
/// identity in the certificate that the server sends to the client.
|
||||
VerifyIdentity,
|
||||
}
|
||||
|
||||
impl Default for MySqlSslMode {
|
||||
fn default() -> Self {
|
||||
MySqlSslMode::Preferred
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for MySqlSslMode {
|
||||
type Err = Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Error> {
|
||||
Ok(match s {
|
||||
"DISABLED" => MySqlSslMode::Disabled,
|
||||
"PREFERRED" => MySqlSslMode::Preferred,
|
||||
"REQUIRED" => MySqlSslMode::Required,
|
||||
"VERIFY_CA" => MySqlSslMode::VerifyCa,
|
||||
"VERIFY_IDENTITY" => MySqlSslMode::VerifyIdentity,
|
||||
|
||||
_ => {
|
||||
return Err(err_protocol!("unknown SSL mode value: {:?}", s));
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Options and flags which can be used to configure a MySQL connection.
|
||||
///
|
||||
/// A value of `PgConnectOptions` can be parsed from a connection URI,
|
||||
/// as described by [MySQL](https://dev.mysql.com/doc/connector-j/8.0/en/connector-j-reference-jdbc-url-format.html).
|
||||
///
|
||||
/// The generic format of the connection URL:
|
||||
///
|
||||
/// ```text
|
||||
/// mysql://[host][/database][?properties]
|
||||
/// ```
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// # use sqlx_core::error::Error;
|
||||
/// # use sqlx_core::connection::Connect;
|
||||
/// # use sqlx_core::mysql::{MySqlConnectOptions, MySqlConnection, MySqlSslMode};
|
||||
/// #
|
||||
/// # #[sqlx_rt::main]
|
||||
/// # async fn main() -> Result<(), Error> {
|
||||
/// // URI connection string
|
||||
/// let conn = MySqlConnection::connect("mysql://root:password@localhost/db").await?;
|
||||
///
|
||||
/// // Manually-constructed options
|
||||
/// let conn = MySqlConnection::connect_with(&MySqlConnectOptions::new()
|
||||
/// .host("localhost")
|
||||
/// .username("root")
|
||||
/// .password("password")
|
||||
/// .database("db")
|
||||
/// ).await?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MySqlConnectOptions {
|
||||
pub(crate) host: String,
|
||||
pub(crate) port: u16,
|
||||
pub(crate) username: String,
|
||||
pub(crate) password: Option<String>,
|
||||
pub(crate) database: Option<String>,
|
||||
pub(crate) ssl_mode: MySqlSslMode,
|
||||
pub(crate) ssl_ca: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl MySqlConnectOptions {
|
||||
/// Creates a new, default set of options ready for configuration
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
port: 3306,
|
||||
host: String::from("localhost"),
|
||||
username: String::from("root"),
|
||||
password: None,
|
||||
database: None,
|
||||
ssl_mode: MySqlSslMode::Preferred,
|
||||
ssl_ca: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the name of the host to connect to.
|
||||
///
|
||||
/// The default behavior when the host is not specified,
|
||||
/// is to connect to localhost.
|
||||
pub fn host(mut self, host: &str) -> Self {
|
||||
self.host = host.to_owned();
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the port to connect to at the server host.
|
||||
///
|
||||
/// The default port for MySQL is `3306`.
|
||||
pub fn port(mut self, port: u16) -> Self {
|
||||
self.port = port;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the username to connect as.
|
||||
pub fn username(mut self, username: &str) -> Self {
|
||||
self.username = username.to_owned();
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the password to connect with.
|
||||
pub fn password(mut self, password: &str) -> Self {
|
||||
self.password = Some(password.to_owned());
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the database name.
|
||||
pub fn database(mut self, database: &str) -> Self {
|
||||
self.database = Some(database.to_owned());
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets whether or with what priority a secure SSL TCP/IP connection will be negotiated
|
||||
/// with the server.
|
||||
///
|
||||
/// By default, the SSL mode is [`Preferred`](MySqlSslMode::Preferred), and the client will
|
||||
/// first attempt an SSL connection but fallback to a non-SSL connection on failure.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # use sqlx_core::mysql::{MySqlSslMode, MySqlConnectOptions};
|
||||
/// let options = MySqlConnectOptions::new()
|
||||
/// .ssl_mode(MySqlSslMode::Required);
|
||||
/// ```
|
||||
pub fn ssl_mode(mut self, mode: MySqlSslMode) -> Self {
|
||||
self.ssl_mode = mode;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the name of a file containing a list of trusted SSL Certificate Authorities.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # use sqlx_core::mysql::{MySqlSslMode, MySqlConnectOptions};
|
||||
/// let options = MySqlConnectOptions::new()
|
||||
/// .ssl_mode(MySqlSslMode::VerifyCa)
|
||||
/// .ssl_ca("path/to/ca.crt");
|
||||
/// ```
|
||||
pub fn ssl_ca(mut self, file_name: impl AsRef<Path>) -> Self {
|
||||
self.ssl_ca = Some(file_name.as_ref().to_owned());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for MySqlConnectOptions {
|
||||
type Err = BoxDynError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, BoxDynError> {
|
||||
let url: Url = s.parse()?;
|
||||
let mut options = Self::new();
|
||||
|
||||
if let Some(host) = url.host_str() {
|
||||
options = options.host(host);
|
||||
}
|
||||
|
||||
if let Some(port) = url.port() {
|
||||
options = options.port(port);
|
||||
}
|
||||
|
||||
let username = url.username();
|
||||
if !username.is_empty() {
|
||||
options = options.username(username);
|
||||
}
|
||||
|
||||
if let Some(password) = url.password() {
|
||||
options = options.password(password);
|
||||
}
|
||||
|
||||
let path = url.path().trim_start_matches('/');
|
||||
if !path.is_empty() {
|
||||
options = options.database(path);
|
||||
}
|
||||
|
||||
for (key, value) in url.query_pairs().into_iter() {
|
||||
match &*key {
|
||||
"ssl-mode" => {
|
||||
options = options.ssl_mode(value.parse()?);
|
||||
}
|
||||
|
||||
"ssl-ca" => {
|
||||
options = options.ssl_ca(&*value);
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(options)
|
||||
}
|
||||
}
|
||||
34
sqlx-core/src/mysql/protocol/auth.rs
Normal file
34
sqlx-core/src/mysql/protocol/auth.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::error::Error;
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub enum AuthPlugin {
|
||||
MySqlNativePassword,
|
||||
CachingSha2Password,
|
||||
Sha256Password,
|
||||
}
|
||||
|
||||
impl AuthPlugin {
|
||||
pub(crate) fn name(self) -> &'static str {
|
||||
match self {
|
||||
AuthPlugin::MySqlNativePassword => "mysql_native_password",
|
||||
AuthPlugin::CachingSha2Password => "caching_sha2_password",
|
||||
AuthPlugin::Sha256Password => "sha256_password",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for AuthPlugin {
|
||||
type Err = Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"mysql_native_password" => Ok(AuthPlugin::MySqlNativePassword),
|
||||
"caching_sha2_password" => Ok(AuthPlugin::CachingSha2Password),
|
||||
"sha256_password" => Ok(AuthPlugin::Sha256Password),
|
||||
|
||||
_ => Err(err_protocol!("unknown authentication plugin: {}", s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,103 +0,0 @@
|
||||
use digest::{Digest, FixedOutput};
|
||||
use generic_array::GenericArray;
|
||||
use memchr::memchr;
|
||||
use sha1::Sha1;
|
||||
use sha2::Sha256;
|
||||
|
||||
use crate::mysql::util::xor_eq;
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum AuthPlugin {
|
||||
MySqlNativePassword,
|
||||
CachingSha2Password,
|
||||
Sha256Password,
|
||||
}
|
||||
|
||||
impl AuthPlugin {
|
||||
pub(crate) fn from_opt_str(s: Option<&str>) -> crate::Result<AuthPlugin> {
|
||||
match s {
|
||||
Some("mysql_native_password") | None => Ok(AuthPlugin::MySqlNativePassword),
|
||||
Some("caching_sha2_password") => Ok(AuthPlugin::CachingSha2Password),
|
||||
Some("sha256_password") => Ok(AuthPlugin::Sha256Password),
|
||||
|
||||
Some(s) => {
|
||||
Err(protocol_err!("requires unimplemented authentication plugin: {}", s).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
AuthPlugin::MySqlNativePassword => "mysql_native_password",
|
||||
AuthPlugin::CachingSha2Password => "caching_sha2_password",
|
||||
AuthPlugin::Sha256Password => "sha256_password",
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn scramble(&self, password: &str, nonce: &[u8]) -> Vec<u8> {
|
||||
match self {
|
||||
AuthPlugin::MySqlNativePassword => {
|
||||
// The [nonce] for mysql_native_password is (optionally) nul terminated
|
||||
let end = memchr(b'\0', nonce).unwrap_or(nonce.len());
|
||||
|
||||
scramble_sha1(password, &nonce[..end]).to_vec()
|
||||
}
|
||||
AuthPlugin::CachingSha2Password => scramble_sha256(password, nonce).to_vec(),
|
||||
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn scramble_sha1(
|
||||
password: &str,
|
||||
seed: &[u8],
|
||||
) -> GenericArray<u8, <Sha1 as FixedOutput>::OutputSize> {
|
||||
// SHA1( password ) ^ SHA1( seed + SHA1( SHA1( password ) ) )
|
||||
// https://mariadb.com/kb/en/connection/#mysql_native_password-plugin
|
||||
|
||||
let mut ctx = Sha1::new();
|
||||
|
||||
ctx.input(password);
|
||||
|
||||
let mut pw_hash = ctx.result_reset();
|
||||
|
||||
ctx.input(&pw_hash);
|
||||
|
||||
let pw_hash_hash = ctx.result_reset();
|
||||
|
||||
ctx.input(seed);
|
||||
ctx.input(pw_hash_hash);
|
||||
|
||||
let pw_seed_hash_hash = ctx.result();
|
||||
|
||||
xor_eq(&mut pw_hash, &pw_seed_hash_hash);
|
||||
|
||||
pw_hash
|
||||
}
|
||||
|
||||
fn scramble_sha256(
|
||||
password: &str,
|
||||
seed: &[u8],
|
||||
) -> GenericArray<u8, <Sha256 as FixedOutput>::OutputSize> {
|
||||
// XOR(SHA256(password), SHA256(seed, SHA256(SHA256(password))))
|
||||
// https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/#sha-2-encrypted-password
|
||||
let mut ctx = Sha256::new();
|
||||
|
||||
ctx.input(password);
|
||||
|
||||
let mut pw_hash = ctx.result_reset();
|
||||
|
||||
ctx.input(&pw_hash);
|
||||
|
||||
let pw_hash_hash = ctx.result_reset();
|
||||
|
||||
ctx.input(seed);
|
||||
ctx.input(pw_hash_hash);
|
||||
|
||||
let pw_seed_hash_hash = ctx.result();
|
||||
|
||||
xor_eq(&mut pw_hash, &pw_seed_hash_hash);
|
||||
|
||||
pw_hash
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
use crate::io::Buf;
|
||||
use crate::mysql::protocol::AuthPlugin;
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_auth_switch_request.html
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct AuthSwitch {
|
||||
pub(crate) auth_plugin: AuthPlugin,
|
||||
pub(crate) auth_plugin_data: Box<[u8]>,
|
||||
}
|
||||
|
||||
impl AuthSwitch {
|
||||
pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let header = buf.get_u8()?;
|
||||
if header != 0xFE {
|
||||
return Err(protocol_err!(
|
||||
"expected AUTH SWITCH (0xFE); received 0x{:X}",
|
||||
header
|
||||
))?;
|
||||
}
|
||||
|
||||
let auth_plugin = AuthPlugin::from_opt_str(Some(buf.get_str_nul()?))?;
|
||||
let auth_plugin_data = buf.get_bytes(buf.len())?.to_owned().into_boxed_slice();
|
||||
|
||||
Ok(Self {
|
||||
auth_plugin_data,
|
||||
auth_plugin,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
use crate::mysql::io::BufExt;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ColumnCount {
|
||||
pub columns: u64,
|
||||
}
|
||||
|
||||
impl ColumnCount {
|
||||
pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self> {
|
||||
let columns = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or(0);
|
||||
|
||||
Ok(Self { columns })
|
||||
}
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
use crate::io::Buf;
|
||||
use crate::mysql::io::BufExt;
|
||||
use crate::mysql::protocol::{FieldFlags, TypeId};
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_query_response_text_resultset_column_definition.html
|
||||
// https://mariadb.com/kb/en/resultset/#column-definition-packet
|
||||
#[derive(Debug)]
|
||||
pub struct ColumnDefinition {
|
||||
pub schema: Option<Box<str>>,
|
||||
|
||||
pub table_alias: Option<Box<str>>,
|
||||
pub table: Option<Box<str>>,
|
||||
|
||||
pub column_alias: Option<Box<str>>,
|
||||
pub column: Option<Box<str>>,
|
||||
|
||||
pub char_set: u16,
|
||||
|
||||
pub max_size: u32,
|
||||
|
||||
pub type_id: TypeId,
|
||||
|
||||
pub flags: FieldFlags,
|
||||
|
||||
pub decimals: u8,
|
||||
}
|
||||
|
||||
impl ColumnDefinition {
|
||||
pub fn name(&self) -> Option<&str> {
|
||||
self.column_alias.as_deref().or(self.column.as_deref())
|
||||
}
|
||||
}
|
||||
|
||||
impl ColumnDefinition {
|
||||
pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self> {
|
||||
// catalog : string<lenenc>
|
||||
let catalog = buf.get_str_lenenc::<LittleEndian>()?;
|
||||
|
||||
if catalog != Some("def") {
|
||||
return Err(protocol_err!(
|
||||
"expected ColumnDefinition (\"def\"); received {:?}",
|
||||
catalog
|
||||
))?;
|
||||
}
|
||||
|
||||
let schema = buf.get_str_lenenc::<LittleEndian>()?.map(Into::into);
|
||||
let table_alias = buf.get_str_lenenc::<LittleEndian>()?.map(Into::into);
|
||||
let table = buf.get_str_lenenc::<LittleEndian>()?.map(Into::into);
|
||||
let column_alias = buf.get_str_lenenc::<LittleEndian>()?.map(Into::into);
|
||||
let column = buf.get_str_lenenc::<LittleEndian>()?.map(Into::into);
|
||||
|
||||
let len_fixed_fields = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or(0);
|
||||
|
||||
if len_fixed_fields != 0x0c {
|
||||
return Err(protocol_err!(
|
||||
"expected ColumnDefinition (0x0c); received {:?}",
|
||||
len_fixed_fields
|
||||
))?;
|
||||
}
|
||||
|
||||
let char_set = buf.get_u16::<LittleEndian>()?;
|
||||
let max_size = buf.get_u32::<LittleEndian>()?;
|
||||
|
||||
let type_id = buf.get_u8()?;
|
||||
let flags = buf.get_u16::<LittleEndian>()?;
|
||||
let decimals = buf.get_u8()?;
|
||||
|
||||
Ok(Self {
|
||||
schema,
|
||||
table,
|
||||
table_alias,
|
||||
column,
|
||||
column_alias,
|
||||
char_set,
|
||||
max_size,
|
||||
type_id: TypeId(type_id),
|
||||
flags: FieldFlags::from_bits_truncate(flags),
|
||||
decimals,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
use crate::io::BufMut;
|
||||
use crate::mysql::protocol::{Capabilities, Encode};
|
||||
|
||||
// https://dev.mysql.com/doc/internals/en/com-ping.html
|
||||
#[derive(Debug)]
|
||||
pub struct ComPing;
|
||||
|
||||
impl Encode for ComPing {
|
||||
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
|
||||
// COM_PING : int<1>
|
||||
buf.put_u8(0x0e);
|
||||
}
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
use crate::io::BufMut;
|
||||
use crate::mysql::protocol::{Capabilities, Encode};
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_query.html
|
||||
#[derive(Debug)]
|
||||
pub struct ComQuery<'a> {
|
||||
pub query: &'a str,
|
||||
}
|
||||
|
||||
impl Encode for ComQuery<'_> {
|
||||
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
|
||||
// COM_QUERY : int<1>
|
||||
buf.put_u8(0x03);
|
||||
|
||||
// query : string<EOF>
|
||||
buf.put_str(self.query);
|
||||
}
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
use crate::io::BufMut;
|
||||
use crate::mysql::protocol::{Capabilities, Encode};
|
||||
use crate::mysql::type_info::MySqlTypeInfo;
|
||||
|
||||
bitflags::bitflags! {
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/mysql__com_8h.html#a3e5e9e744ff6f7b989a604fd669977da
|
||||
// https://mariadb.com/kb/en/library/com_stmt_execute/#flag
|
||||
pub struct Cursor: u8 {
|
||||
const NO_CURSOR = 0;
|
||||
const READ_ONLY = 1;
|
||||
const FOR_UPDATE = 2;
|
||||
const SCROLLABLE = 4;
|
||||
}
|
||||
}
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_execute.html
|
||||
#[derive(Debug)]
|
||||
pub struct ComStmtExecute<'a> {
|
||||
pub statement_id: u32,
|
||||
pub cursor: Cursor,
|
||||
pub params: &'a [u8],
|
||||
pub null_bitmap: &'a [u8],
|
||||
pub param_types: &'a [MySqlTypeInfo],
|
||||
}
|
||||
|
||||
impl Encode for ComStmtExecute<'_> {
|
||||
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
|
||||
// COM_STMT_EXECUTE : int<1>
|
||||
buf.put_u8(0x17);
|
||||
|
||||
// statement_id : int<4>
|
||||
buf.put_u32::<LittleEndian>(self.statement_id);
|
||||
|
||||
// cursor : int<1>
|
||||
buf.put_u8(self.cursor.bits());
|
||||
|
||||
// iterations (always 1) : int<4>
|
||||
buf.put_u32::<LittleEndian>(1);
|
||||
|
||||
if !self.param_types.is_empty() {
|
||||
// null bitmap : byte<(param_count + 7)/8>
|
||||
buf.put_bytes(self.null_bitmap);
|
||||
|
||||
// send type to server (0 / 1) : byte<1>
|
||||
buf.put_u8(1);
|
||||
|
||||
for ty in self.param_types {
|
||||
// field type : byte<1>
|
||||
buf.put_u8(ty.id.0);
|
||||
|
||||
// parameter flag : byte<1>
|
||||
buf.put_u8(if ty.is_unsigned { 0x80 } else { 0 });
|
||||
}
|
||||
|
||||
// byte<n> binary parameter value
|
||||
buf.put_bytes(self.params);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
use crate::io::BufMut;
|
||||
use crate::mysql::protocol::{Capabilities, Encode};
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_prepare.html
|
||||
#[derive(Debug)]
|
||||
pub struct ComStmtPrepare<'a> {
|
||||
pub query: &'a str,
|
||||
}
|
||||
|
||||
impl Encode for ComStmtPrepare<'_> {
|
||||
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
|
||||
// COM_STMT_PREPARE : int<1>
|
||||
buf.put_u8(0x16);
|
||||
|
||||
// query : string<EOF>
|
||||
buf.put_str(self.query);
|
||||
}
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
use crate::io::Buf;
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_prepare.html#sect_protocol_com_stmt_prepare_response_ok
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ComStmtPrepareOk {
|
||||
pub(crate) statement_id: u32,
|
||||
|
||||
/// Number of columns in the returned result set (or 0 if statement
|
||||
/// does not return result set).
|
||||
pub(crate) columns: u16,
|
||||
|
||||
/// Number of prepared statement parameters ('?' placeholders).
|
||||
pub(crate) params: u16,
|
||||
|
||||
/// Number of warnings.
|
||||
pub(crate) warnings: u16,
|
||||
}
|
||||
|
||||
impl ComStmtPrepareOk {
|
||||
pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self> {
|
||||
let header = buf.get_u8()?;
|
||||
|
||||
if header != 0x00 {
|
||||
return Err(protocol_err!(
|
||||
"expected COM_STMT_PREPARE_OK (0x00); received 0x{:X}",
|
||||
header
|
||||
))?;
|
||||
}
|
||||
|
||||
let statement_id = buf.get_u32::<LittleEndian>()?;
|
||||
let columns = buf.get_u16::<LittleEndian>()?;
|
||||
let params = buf.get_u16::<LittleEndian>()?;
|
||||
|
||||
// -not used- : string<1>
|
||||
buf.advance(1);
|
||||
|
||||
let warnings = buf.get_u16::<LittleEndian>()?;
|
||||
|
||||
Ok(Self {
|
||||
statement_id,
|
||||
columns,
|
||||
params,
|
||||
warnings,
|
||||
})
|
||||
}
|
||||
}
|
||||
41
sqlx-core/src/mysql/protocol/connect/auth_switch.rs
Normal file
41
sqlx-core/src/mysql/protocol/connect/auth_switch.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use bytes::{Buf, Bytes};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::Encode;
|
||||
use crate::io::{BufExt, Decode};
|
||||
use crate::mysql::protocol::auth::AuthPlugin;
|
||||
use crate::mysql::protocol::Capabilities;
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_auth_switch_request.html
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthSwitchRequest {
|
||||
pub plugin: AuthPlugin,
|
||||
pub data: Bytes,
|
||||
}
|
||||
|
||||
impl Decode<'_> for AuthSwitchRequest {
|
||||
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
|
||||
let header = buf.get_u8();
|
||||
if header != 0xfe {
|
||||
return Err(err_protocol!(
|
||||
"expected 0xfe (AUTH_SWITCH) but found 0x{:x}",
|
||||
header
|
||||
));
|
||||
}
|
||||
|
||||
let plugin = buf.get_str_nul()?.parse()?;
|
||||
let data = buf.get_bytes(buf.len());
|
||||
|
||||
Ok(Self { plugin, data })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthSwitchResponse(pub Vec<u8>);
|
||||
|
||||
impl Encode<'_, Capabilities> for AuthSwitchResponse {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
|
||||
buf.extend_from_slice(&self.0);
|
||||
}
|
||||
}
|
||||
194
sqlx-core/src/mysql/protocol/connect/handshake.rs
Normal file
194
sqlx-core/src/mysql/protocol/connect/handshake.rs
Normal file
@@ -0,0 +1,194 @@
|
||||
use bytes::buf::ext::Chain;
|
||||
use bytes::buf::BufExt as _;
|
||||
use bytes::{Buf, Bytes};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::{BufExt, Decode};
|
||||
use crate::mysql::protocol::auth::AuthPlugin;
|
||||
use crate::mysql::protocol::response::Status;
|
||||
use crate::mysql::protocol::Capabilities;
|
||||
|
||||
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
|
||||
// https://mariadb.com/kb/en/connection/#initial-handshake-packet
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Handshake {
|
||||
pub(crate) protocol_version: u8,
|
||||
pub(crate) server_version: String,
|
||||
pub(crate) connection_id: u32,
|
||||
pub(crate) server_capabilities: Capabilities,
|
||||
pub(crate) server_default_collation: u8,
|
||||
pub(crate) status: Status,
|
||||
pub(crate) auth_plugin: Option<AuthPlugin>,
|
||||
pub(crate) auth_plugin_data: Chain<Bytes, Bytes>,
|
||||
}
|
||||
|
||||
impl Decode<'_> for Handshake {
|
||||
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
|
||||
let protocol_version = buf.get_u8(); // int<1>
|
||||
let server_version = buf.get_str_nul()?; // string<NUL>
|
||||
let connection_id = buf.get_u32_le(); // int<4>
|
||||
let auth_plugin_data_1 = buf.get_bytes(8); // string<8>
|
||||
|
||||
buf.advance(1); // reserved: string<1>
|
||||
|
||||
let capabilities_1 = buf.get_u16_le(); // int<2>
|
||||
let mut capabilities = Capabilities::from_bits_truncate(capabilities_1.into());
|
||||
|
||||
let collation = buf.get_u8(); // int<1>
|
||||
let status = Status::from_bits_truncate(buf.get_u16_le());
|
||||
|
||||
let capabilities_2 = buf.get_u16_le(); // int<2>
|
||||
capabilities |= Capabilities::from_bits_truncate(((capabilities_2 as u32) << 16).into());
|
||||
|
||||
let auth_plugin_data_len = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
|
||||
buf.get_u8()
|
||||
} else {
|
||||
buf.advance(1); // int<1>
|
||||
0
|
||||
};
|
||||
|
||||
buf.advance(6); // reserved: string<6>
|
||||
|
||||
if capabilities.contains(Capabilities::MYSQL) {
|
||||
buf.advance(4); // reserved: string<4>
|
||||
} else {
|
||||
let capabilities_3 = buf.get_u32_le(); // int<4>
|
||||
capabilities |= Capabilities::from_bits_truncate((capabilities_3 as u64) << 32);
|
||||
}
|
||||
|
||||
let auth_plugin_data_2 = if capabilities.contains(Capabilities::SECURE_CONNECTION) {
|
||||
let len = ((auth_plugin_data_len as isize) - 9).max(12) as usize;
|
||||
let v = buf.get_bytes(len);
|
||||
buf.advance(1); // NUL-terminator
|
||||
|
||||
v
|
||||
} else {
|
||||
Bytes::new()
|
||||
};
|
||||
|
||||
let auth_plugin = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
|
||||
Some(buf.get_str_nul()?.parse()?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
protocol_version,
|
||||
server_version,
|
||||
connection_id,
|
||||
server_default_collation: collation,
|
||||
status,
|
||||
server_capabilities: capabilities,
|
||||
auth_plugin,
|
||||
auth_plugin_data: auth_plugin_data_1.chain(auth_plugin_data_2),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_handshake_mysql_8_0_18() {
|
||||
const HANDSHAKE_MYSQL_8_0_18: &[u8] = b"\n8.0.18\x00\x19\x00\x00\x00\x114aB0c\x06g\x00\xff\xff\xff\x02\x00\xff\xc7\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00tL\x03s\x0f[4\rl4. \x00caching_sha2_password\x00";
|
||||
|
||||
let mut p = Handshake::decode(HANDSHAKE_MYSQL_8_0_18.into()).unwrap();
|
||||
|
||||
assert_eq!(p.protocol_version, 10);
|
||||
|
||||
p.server_capabilities.toggle(
|
||||
Capabilities::MYSQL
|
||||
| Capabilities::FOUND_ROWS
|
||||
| Capabilities::LONG_FLAG
|
||||
| Capabilities::CONNECT_WITH_DB
|
||||
| Capabilities::NO_SCHEMA
|
||||
| Capabilities::COMPRESS
|
||||
| Capabilities::ODBC
|
||||
| Capabilities::LOCAL_FILES
|
||||
| Capabilities::IGNORE_SPACE
|
||||
| Capabilities::PROTOCOL_41
|
||||
| Capabilities::INTERACTIVE
|
||||
| Capabilities::SSL
|
||||
| Capabilities::TRANSACTIONS
|
||||
| Capabilities::SECURE_CONNECTION
|
||||
| Capabilities::MULTI_STATEMENTS
|
||||
| Capabilities::MULTI_RESULTS
|
||||
| Capabilities::PS_MULTI_RESULTS
|
||||
| Capabilities::PLUGIN_AUTH
|
||||
| Capabilities::CONNECT_ATTRS
|
||||
| Capabilities::PLUGIN_AUTH_LENENC_DATA
|
||||
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
|
||||
| Capabilities::SESSION_TRACK
|
||||
| Capabilities::DEPRECATE_EOF
|
||||
| Capabilities::ZSTD_COMPRESSION_ALGORITHM
|
||||
| Capabilities::SSL_VERIFY_SERVER_CERT
|
||||
| Capabilities::OPTIONAL_RESULTSET_METADATA
|
||||
| Capabilities::REMEMBER_OPTIONS,
|
||||
);
|
||||
|
||||
assert!(p.server_capabilities.is_empty());
|
||||
|
||||
assert_eq!(p.server_default_collation, 255);
|
||||
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
|
||||
|
||||
assert!(matches!(
|
||||
p.auth_plugin,
|
||||
Some(AuthPlugin::CachingSha2Password)
|
||||
));
|
||||
|
||||
assert_eq!(
|
||||
&*p.auth_plugin_data.to_bytes(),
|
||||
&[17, 52, 97, 66, 48, 99, 6, 103, 116, 76, 3, 115, 15, 91, 52, 13, 108, 52, 46, 32,]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_handshake_mariadb_10_4_7() {
|
||||
const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\"<H5n\x00mysql_native_password\x00";
|
||||
|
||||
let mut p = Handshake::decode(HANDSHAKE_MARIA_DB_10_4_7.into()).unwrap();
|
||||
|
||||
assert_eq!(p.protocol_version, 10);
|
||||
|
||||
assert_eq!(
|
||||
&*p.server_version,
|
||||
"5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic"
|
||||
);
|
||||
|
||||
p.server_capabilities.toggle(
|
||||
Capabilities::FOUND_ROWS
|
||||
| Capabilities::LONG_FLAG
|
||||
| Capabilities::CONNECT_WITH_DB
|
||||
| Capabilities::NO_SCHEMA
|
||||
| Capabilities::COMPRESS
|
||||
| Capabilities::ODBC
|
||||
| Capabilities::LOCAL_FILES
|
||||
| Capabilities::IGNORE_SPACE
|
||||
| Capabilities::PROTOCOL_41
|
||||
| Capabilities::INTERACTIVE
|
||||
| Capabilities::TRANSACTIONS
|
||||
| Capabilities::SECURE_CONNECTION
|
||||
| Capabilities::MULTI_STATEMENTS
|
||||
| Capabilities::MULTI_RESULTS
|
||||
| Capabilities::PS_MULTI_RESULTS
|
||||
| Capabilities::PLUGIN_AUTH
|
||||
| Capabilities::CONNECT_ATTRS
|
||||
| Capabilities::PLUGIN_AUTH_LENENC_DATA
|
||||
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
|
||||
| Capabilities::SESSION_TRACK
|
||||
| Capabilities::DEPRECATE_EOF
|
||||
| Capabilities::REMEMBER_OPTIONS,
|
||||
);
|
||||
|
||||
assert!(p.server_capabilities.is_empty());
|
||||
|
||||
assert_eq!(p.server_default_collation, 8);
|
||||
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
|
||||
assert!(matches!(
|
||||
p.auth_plugin,
|
||||
Some(AuthPlugin::MySqlNativePassword)
|
||||
));
|
||||
|
||||
assert_eq!(
|
||||
&*p.auth_plugin_data.to_bytes(),
|
||||
&[116, 54, 76, 92, 106, 34, 100, 83, 85, 49, 52, 79, 112, 104, 57, 34, 60, 72, 53, 110,]
|
||||
);
|
||||
}
|
||||
73
sqlx-core/src/mysql/protocol/connect/handshake_response.rs
Normal file
73
sqlx-core/src/mysql/protocol/connect/handshake_response.rs
Normal file
@@ -0,0 +1,73 @@
|
||||
use crate::io::{BufMutExt, Encode};
|
||||
use crate::mysql::io::MySqlBufMutExt;
|
||||
use crate::mysql::protocol::auth::AuthPlugin;
|
||||
use crate::mysql::protocol::connect::ssl_request::SslRequest;
|
||||
use crate::mysql::protocol::Capabilities;
|
||||
|
||||
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
|
||||
// https://mariadb.com/kb/en/connection/#client-handshake-response
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct HandshakeResponse<'a> {
|
||||
pub database: Option<&'a str>,
|
||||
|
||||
/// Max size of a command packet that the client wants to send to the server
|
||||
pub max_packet_size: u32,
|
||||
|
||||
/// Default character set for the connection
|
||||
pub char_set: u8,
|
||||
|
||||
/// Name of the SQL account which client wants to log in
|
||||
pub username: &'a str,
|
||||
|
||||
/// Authentication method used by the client
|
||||
pub auth_plugin: Option<AuthPlugin>,
|
||||
|
||||
/// Opaque authentication response
|
||||
pub auth_response: Option<&'a [u8]>,
|
||||
}
|
||||
|
||||
impl Encode<'_, Capabilities> for HandshakeResponse<'_> {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, mut capabilities: Capabilities) {
|
||||
if self.auth_plugin.is_none() {
|
||||
// ensure PLUGIN_AUTH is set *only* if we have a defined plugin
|
||||
capabilities.remove(Capabilities::PLUGIN_AUTH);
|
||||
}
|
||||
|
||||
// NOTE: Half of this packet is identical to the SSL Request packet
|
||||
SslRequest {
|
||||
max_packet_size: self.max_packet_size,
|
||||
char_set: self.char_set,
|
||||
}
|
||||
.encode_with(buf, capabilities);
|
||||
|
||||
buf.put_str_nul(self.username);
|
||||
|
||||
if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) {
|
||||
buf.put_bytes_lenenc(self.auth_response.unwrap_or_default());
|
||||
} else if capabilities.contains(Capabilities::SECURE_CONNECTION) {
|
||||
let response = self.auth_response.unwrap_or_default();
|
||||
|
||||
buf.push(response.len() as u8);
|
||||
buf.extend(response);
|
||||
} else {
|
||||
buf.push(0);
|
||||
}
|
||||
|
||||
if capabilities.contains(Capabilities::CONNECT_WITH_DB) {
|
||||
if let Some(database) = &self.database {
|
||||
buf.put_str_nul(database);
|
||||
} else {
|
||||
buf.push(0);
|
||||
}
|
||||
}
|
||||
|
||||
if capabilities.contains(Capabilities::PLUGIN_AUTH) {
|
||||
if let Some(plugin) = &self.auth_plugin {
|
||||
buf.put_str_nul(plugin.name());
|
||||
} else {
|
||||
buf.push(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
13
sqlx-core/src/mysql/protocol/connect/mod.rs
Normal file
13
sqlx-core/src/mysql/protocol/connect/mod.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
//! Connection Phase
|
||||
//!
|
||||
//! <https://dev.mysql.com/doc/internals/en/connection-phase.html>
|
||||
|
||||
mod auth_switch;
|
||||
mod handshake;
|
||||
mod handshake_response;
|
||||
mod ssl_request;
|
||||
|
||||
pub(crate) use auth_switch::{AuthSwitchRequest, AuthSwitchResponse};
|
||||
pub(crate) use handshake::Handshake;
|
||||
pub(crate) use handshake_response::HandshakeResponse;
|
||||
pub(crate) use ssl_request::SslRequest;
|
||||
30
sqlx-core/src/mysql/protocol/connect/ssl_request.rs
Normal file
30
sqlx-core/src/mysql/protocol/connect/ssl_request.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
use crate::io::Encode;
|
||||
use crate::mysql::protocol::Capabilities;
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_response.html
|
||||
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SslRequest {
|
||||
pub max_packet_size: u32,
|
||||
pub char_set: u8,
|
||||
}
|
||||
|
||||
impl Encode<'_, Capabilities> for SslRequest {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, capabilities: Capabilities) {
|
||||
buf.extend(&(capabilities.bits() as u32).to_le_bytes());
|
||||
buf.extend(&self.max_packet_size.to_le_bytes());
|
||||
buf.push(self.char_set);
|
||||
|
||||
// reserved: string<19>
|
||||
buf.extend(&[0_u8; 19]);
|
||||
|
||||
if capabilities.contains(Capabilities::MYSQL) {
|
||||
// reserved: string<4>
|
||||
buf.extend(&[0_u8; 4]);
|
||||
} else {
|
||||
// extended client capabilities (MariaDB-specified): int<4>
|
||||
buf.extend(&((capabilities.bits() >> 32) as u32).to_le_bytes());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
use crate::io::Buf;
|
||||
use crate::mysql::protocol::Status;
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_eof_packet.html
|
||||
// https://mariadb.com/kb/en/eof_packet/
|
||||
#[derive(Debug)]
|
||||
pub struct EofPacket {
|
||||
pub warnings: u16,
|
||||
pub status: Status,
|
||||
}
|
||||
|
||||
impl EofPacket {
|
||||
pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let header = buf.get_u8()?;
|
||||
if header != 0xFE {
|
||||
return Err(protocol_err!(
|
||||
"expected EOF (0xFE); received 0x{:X}",
|
||||
header
|
||||
))?;
|
||||
}
|
||||
|
||||
let warnings = buf.get_u16::<LittleEndian>()?;
|
||||
let status = buf.get_u16::<LittleEndian>()?;
|
||||
|
||||
Ok(Self {
|
||||
warnings,
|
||||
status: Status::from_bits_truncate(status),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
use crate::io::Buf;
|
||||
use crate::mysql::protocol::Capabilities;
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_err_packet.html
|
||||
// https://mariadb.com/kb/en/err_packet/
|
||||
#[derive(Debug)]
|
||||
pub struct ErrPacket {
|
||||
pub error_code: u16,
|
||||
pub sql_state: Option<Box<str>>,
|
||||
pub error_message: Box<str>,
|
||||
}
|
||||
|
||||
impl ErrPacket {
|
||||
pub(crate) fn read(mut buf: &[u8], capabilities: Capabilities) -> crate::Result<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let header = buf.get_u8()?;
|
||||
if header != 0xFF {
|
||||
return Err(protocol_err!(
|
||||
"expected 0xFF for ERR_PACKET; received 0x{:X}",
|
||||
header
|
||||
))?;
|
||||
}
|
||||
|
||||
let error_code = buf.get_u16::<LittleEndian>()?;
|
||||
|
||||
let mut sql_state = None;
|
||||
|
||||
if capabilities.contains(Capabilities::PROTOCOL_41) {
|
||||
// If the next byte is '#' then we have a SQL STATE
|
||||
if buf.get(0) == Some(&0x23) {
|
||||
buf.advance(1);
|
||||
sql_state = Some(buf.get_str(5)?.into())
|
||||
}
|
||||
}
|
||||
|
||||
let error_message = buf.get_str(buf.len())?.into();
|
||||
|
||||
Ok(Self {
|
||||
error_code,
|
||||
sql_state,
|
||||
error_message,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{Capabilities, ErrPacket};
|
||||
|
||||
const ERR_PACKETS_OUT_OF_ORDER: &[u8] = b"\xff\x84\x04Got packets out of order";
|
||||
|
||||
const ERR_HANDSHAKE_UNKNOWN_DB: &[u8] = b"\xff\x19\x04#42000Unknown database \'unknown\'";
|
||||
|
||||
#[test]
|
||||
fn it_decodes_packets_out_of_order() {
|
||||
let p = ErrPacket::read(ERR_PACKETS_OUT_OF_ORDER, Capabilities::PROTOCOL_41).unwrap();
|
||||
|
||||
assert_eq!(&*p.error_message, "Got packets out of order");
|
||||
assert_eq!(p.error_code, 1156);
|
||||
assert_eq!(p.sql_state, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_decodes_ok_handshake() {
|
||||
let p = ErrPacket::read(ERR_HANDSHAKE_UNKNOWN_DB, Capabilities::PROTOCOL_41).unwrap();
|
||||
|
||||
assert_eq!(p.error_code, 1049);
|
||||
assert_eq!(p.sql_state.as_deref(), Some("42000"));
|
||||
assert_eq!(&*p.error_message, "Unknown database \'unknown\'");
|
||||
}
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
// https://mariadb.com/kb/en/library/resultset/#field-detail-flag
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__column__definition__flags.html
|
||||
bitflags::bitflags! {
|
||||
pub struct FieldFlags: u16 {
|
||||
/// Field cannot be NULL
|
||||
const NOT_NULL = 1;
|
||||
|
||||
/// Field is **part of** a primary key
|
||||
const PRIMARY_KEY = 2;
|
||||
|
||||
/// Field is **part of** a unique key/constraint
|
||||
const UNIQUE_KEY = 4;
|
||||
|
||||
/// Field is **part of** a unique or primary key
|
||||
const MULTIPLE_KEY = 8;
|
||||
|
||||
/// Field is a blob.
|
||||
const BLOB = 16;
|
||||
|
||||
/// Field is unsigned
|
||||
const UNSIGNED = 32;
|
||||
|
||||
/// Field is zero filled.
|
||||
const ZEROFILL = 64;
|
||||
|
||||
/// Field is binary (set for strings)
|
||||
const BINARY = 128;
|
||||
|
||||
/// Field is an enumeration
|
||||
const ENUM = 256;
|
||||
|
||||
/// Field is an auto-increment field
|
||||
const AUTO_INCREMENT = 512;
|
||||
|
||||
/// Field is a timestamp
|
||||
const TIMESTAMP = 1024;
|
||||
|
||||
/// Field is a set
|
||||
const SET = 2048;
|
||||
|
||||
/// Field does not have a default value
|
||||
const NO_DEFAULT_VALUE = 4096;
|
||||
|
||||
/// Field is set to NOW on UPDATE
|
||||
const ON_UPDATE_NOW = 8192;
|
||||
|
||||
/// Field is a number
|
||||
const NUM = 32768;
|
||||
}
|
||||
}
|
||||
@@ -1,208 +0,0 @@
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
use crate::io::Buf;
|
||||
use crate::mysql::protocol::{AuthPlugin, Capabilities, Status};
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_v10.html
|
||||
// https://mariadb.com/kb/en/connection/#initial-handshake-packet
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Handshake {
|
||||
pub(crate) protocol_version: u8,
|
||||
pub(crate) server_version: Box<str>,
|
||||
pub(crate) connection_id: u32,
|
||||
pub(crate) server_capabilities: Capabilities,
|
||||
pub(crate) server_default_collation: u8,
|
||||
pub(crate) status: Status,
|
||||
pub(crate) auth_plugin: AuthPlugin,
|
||||
pub(crate) auth_plugin_data: Box<[u8]>,
|
||||
}
|
||||
|
||||
impl Handshake {
|
||||
pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let protocol_version = buf.get_u8()?;
|
||||
let server_version = buf.get_str_nul()?.into();
|
||||
let connection_id = buf.get_u32::<LittleEndian>()?;
|
||||
|
||||
let mut scramble = Vec::with_capacity(8);
|
||||
|
||||
// scramble first part : string<8>
|
||||
scramble.extend_from_slice(&buf[..8]);
|
||||
buf.advance(8);
|
||||
|
||||
// reserved : string<1>
|
||||
buf.advance(1);
|
||||
|
||||
// capability_flags_1 : int<2>
|
||||
let capabilities_1 = buf.get_u16::<LittleEndian>()?;
|
||||
let mut capabilities = Capabilities::from_bits_truncate(capabilities_1.into());
|
||||
|
||||
// character_set : int<1>
|
||||
let char_set = buf.get_u8()?;
|
||||
|
||||
// status_flags : int<2>
|
||||
let status = buf.get_u16::<LittleEndian>()?;
|
||||
let status = Status::from_bits_truncate(status);
|
||||
|
||||
// capability_flags_2 : int<2>
|
||||
let capabilities_2 = buf.get_u16::<LittleEndian>()?;
|
||||
capabilities |= Capabilities::from_bits_truncate(((capabilities_2 as u32) << 16).into());
|
||||
|
||||
let auth_plugin_data_len = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
|
||||
// plugin data length : int<1>
|
||||
buf.get_u8()?
|
||||
} else {
|
||||
// 0x00 : int<1>
|
||||
buf.advance(1);
|
||||
0
|
||||
};
|
||||
|
||||
// reserved: string<6>
|
||||
buf.advance(6);
|
||||
|
||||
if capabilities.contains(Capabilities::MYSQL) {
|
||||
// reserved: string<4>
|
||||
buf.advance(4);
|
||||
} else {
|
||||
// capability_flags_3 : int<4>
|
||||
let capabilities_3 = buf.get_u32::<LittleEndian>()?;
|
||||
capabilities |= Capabilities::from_bits_truncate((capabilities_3 as u64) << 32);
|
||||
}
|
||||
|
||||
if capabilities.contains(Capabilities::SECURE_CONNECTION) {
|
||||
// scramble 2nd part : string<n> ( Length = max(12, plugin data length - 9) )
|
||||
let len = ((auth_plugin_data_len as isize) - 9).max(12) as usize;
|
||||
scramble.extend_from_slice(&buf[..len]);
|
||||
buf.advance(len);
|
||||
|
||||
// reserved : string<1>
|
||||
buf.advance(1);
|
||||
}
|
||||
|
||||
let auth_plugin = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
|
||||
AuthPlugin::from_opt_str(Some(buf.get_str_nul()?))?
|
||||
} else {
|
||||
AuthPlugin::from_opt_str(None)?
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
protocol_version,
|
||||
server_capabilities: capabilities,
|
||||
server_version,
|
||||
server_default_collation: char_set,
|
||||
connection_id,
|
||||
auth_plugin_data: scramble.into_boxed_slice(),
|
||||
auth_plugin,
|
||||
status,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{AuthPlugin, Capabilities, Handshake, Status};
|
||||
|
||||
const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\"<H5n\x00mysql_native_password\x00";
|
||||
const HANDSHAKE_MYSQL_8_0_18: &[u8] = b"\n8.0.18\x00\x19\x00\x00\x00\x114aB0c\x06g\x00\xff\xff\xff\x02\x00\xff\xc7\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00tL\x03s\x0f[4\rl4. \x00caching_sha2_password\x00";
|
||||
|
||||
#[test]
|
||||
fn it_reads_handshake_mysql_8_0_18() {
|
||||
let mut p = Handshake::read(HANDSHAKE_MYSQL_8_0_18).unwrap();
|
||||
|
||||
assert_eq!(p.protocol_version, 10);
|
||||
|
||||
p.server_capabilities.toggle(
|
||||
Capabilities::MYSQL
|
||||
| Capabilities::FOUND_ROWS
|
||||
| Capabilities::LONG_FLAG
|
||||
| Capabilities::CONNECT_WITH_DB
|
||||
| Capabilities::NO_SCHEMA
|
||||
| Capabilities::COMPRESS
|
||||
| Capabilities::ODBC
|
||||
| Capabilities::LOCAL_FILES
|
||||
| Capabilities::IGNORE_SPACE
|
||||
| Capabilities::PROTOCOL_41
|
||||
| Capabilities::INTERACTIVE
|
||||
| Capabilities::SSL
|
||||
| Capabilities::TRANSACTIONS
|
||||
| Capabilities::SECURE_CONNECTION
|
||||
| Capabilities::MULTI_STATEMENTS
|
||||
| Capabilities::MULTI_RESULTS
|
||||
| Capabilities::PS_MULTI_RESULTS
|
||||
| Capabilities::PLUGIN_AUTH
|
||||
| Capabilities::CONNECT_ATTRS
|
||||
| Capabilities::PLUGIN_AUTH_LENENC_DATA
|
||||
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
|
||||
| Capabilities::SESSION_TRACK
|
||||
| Capabilities::DEPRECATE_EOF
|
||||
| Capabilities::ZSTD_COMPRESSION_ALGORITHM
|
||||
| Capabilities::SSL_VERIFY_SERVER_CERT
|
||||
| Capabilities::OPTIONAL_RESULTSET_METADATA
|
||||
| Capabilities::REMEMBER_OPTIONS,
|
||||
);
|
||||
|
||||
assert!(p.server_capabilities.is_empty());
|
||||
|
||||
assert_eq!(p.server_default_collation, 255);
|
||||
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
|
||||
assert!(matches!(p.auth_plugin, AuthPlugin::CachingSha2Password));
|
||||
|
||||
assert_eq!(
|
||||
&*p.auth_plugin_data,
|
||||
&[17, 52, 97, 66, 48, 99, 6, 103, 116, 76, 3, 115, 15, 91, 52, 13, 108, 52, 46, 32,]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_reads_handshake_mariadb_10_4_7() {
|
||||
let mut p = Handshake::read(HANDSHAKE_MARIA_DB_10_4_7).unwrap();
|
||||
|
||||
assert_eq!(p.protocol_version, 10);
|
||||
|
||||
assert_eq!(
|
||||
&*p.server_version,
|
||||
"5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic"
|
||||
);
|
||||
|
||||
p.server_capabilities.toggle(
|
||||
Capabilities::FOUND_ROWS
|
||||
| Capabilities::LONG_FLAG
|
||||
| Capabilities::CONNECT_WITH_DB
|
||||
| Capabilities::NO_SCHEMA
|
||||
| Capabilities::COMPRESS
|
||||
| Capabilities::ODBC
|
||||
| Capabilities::LOCAL_FILES
|
||||
| Capabilities::IGNORE_SPACE
|
||||
| Capabilities::PROTOCOL_41
|
||||
| Capabilities::INTERACTIVE
|
||||
| Capabilities::TRANSACTIONS
|
||||
| Capabilities::SECURE_CONNECTION
|
||||
| Capabilities::MULTI_STATEMENTS
|
||||
| Capabilities::MULTI_RESULTS
|
||||
| Capabilities::PS_MULTI_RESULTS
|
||||
| Capabilities::PLUGIN_AUTH
|
||||
| Capabilities::CONNECT_ATTRS
|
||||
| Capabilities::PLUGIN_AUTH_LENENC_DATA
|
||||
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
|
||||
| Capabilities::SESSION_TRACK
|
||||
| Capabilities::DEPRECATE_EOF
|
||||
| Capabilities::REMEMBER_OPTIONS,
|
||||
);
|
||||
|
||||
assert!(p.server_capabilities.is_empty());
|
||||
|
||||
assert_eq!(p.server_default_collation, 8);
|
||||
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
|
||||
assert!(matches!(p.auth_plugin, AuthPlugin::MySqlNativePassword));
|
||||
|
||||
assert_eq!(
|
||||
&*p.auth_plugin_data,
|
||||
&[
|
||||
116, 54, 76, 92, 106, 34, 100, 83, 85, 49, 52, 79, 112, 104, 57, 34, 60, 72, 53,
|
||||
110,
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
use crate::io::BufMut;
|
||||
use crate::mysql::io::BufMutExt;
|
||||
use crate::mysql::protocol::{AuthPlugin, Capabilities, Encode};
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_response.html
|
||||
// https://mariadb.com/kb/en/connection/#handshake-response-packet
|
||||
#[derive(Debug)]
|
||||
pub struct HandshakeResponse<'a> {
|
||||
pub max_packet_size: u32,
|
||||
pub client_collation: u8,
|
||||
pub username: &'a str,
|
||||
pub database: Option<&'a str>,
|
||||
pub auth_plugin: &'a AuthPlugin,
|
||||
pub auth_response: &'a [u8],
|
||||
}
|
||||
|
||||
impl Encode for HandshakeResponse<'_> {
|
||||
fn encode(&self, buf: &mut Vec<u8>, capabilities: Capabilities) {
|
||||
// client capabilities : int<4>
|
||||
buf.put_u32::<LittleEndian>(capabilities.bits() as u32);
|
||||
|
||||
// max packet size : int<4>
|
||||
buf.put_u32::<LittleEndian>(self.max_packet_size);
|
||||
|
||||
// client character collation : int<1>
|
||||
buf.put_u8(self.client_collation);
|
||||
|
||||
// reserved : string<19>
|
||||
buf.advance(19);
|
||||
|
||||
if capabilities.contains(Capabilities::MYSQL) {
|
||||
// reserved : string<4>
|
||||
buf.advance(4);
|
||||
} else {
|
||||
// extended client capabilities : int<4>
|
||||
buf.put_u32::<LittleEndian>((capabilities.bits() >> 32) as u32);
|
||||
}
|
||||
|
||||
// username : string<NUL>
|
||||
buf.put_str_nul(self.username);
|
||||
|
||||
if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) {
|
||||
// auth_response : string<lenenc>
|
||||
buf.put_bytes_lenenc::<LittleEndian>(self.auth_response);
|
||||
} else if capabilities.contains(Capabilities::SECURE_CONNECTION) {
|
||||
let auth_response = self.auth_response;
|
||||
|
||||
// auth_response_length : int<1>
|
||||
buf.put_u8(auth_response.len() as u8);
|
||||
|
||||
// auth_response : string<{auth_response_length}>
|
||||
buf.put_bytes(auth_response);
|
||||
} else {
|
||||
// no auth : int<1>
|
||||
buf.put_u8(0);
|
||||
}
|
||||
|
||||
if capabilities.contains(Capabilities::CONNECT_WITH_DB) {
|
||||
if let Some(database) = self.database {
|
||||
// database : string<NUL>
|
||||
buf.put_str_nul(database);
|
||||
}
|
||||
}
|
||||
|
||||
if capabilities.contains(Capabilities::PLUGIN_AUTH) {
|
||||
// client_plugin_name : string<NUL>
|
||||
buf.put_str_nul(self.auth_plugin.as_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,59 +1,13 @@
|
||||
mod auth_plugin;
|
||||
pub(crate) mod auth;
|
||||
mod capabilities;
|
||||
mod field;
|
||||
mod status;
|
||||
mod r#type;
|
||||
|
||||
pub(crate) use auth_plugin::AuthPlugin;
|
||||
pub(crate) use capabilities::Capabilities;
|
||||
pub(crate) use field::FieldFlags;
|
||||
pub(crate) use r#type::TypeId;
|
||||
pub(crate) use status::Status;
|
||||
|
||||
mod com_ping;
|
||||
mod com_query;
|
||||
mod com_stmt_execute;
|
||||
mod com_stmt_prepare;
|
||||
mod handshake;
|
||||
|
||||
pub(crate) use com_ping::ComPing;
|
||||
pub(crate) use com_query::ComQuery;
|
||||
pub(crate) use com_stmt_execute::{ComStmtExecute, Cursor};
|
||||
pub(crate) use com_stmt_prepare::ComStmtPrepare;
|
||||
pub(crate) use handshake::Handshake;
|
||||
|
||||
mod auth_switch;
|
||||
mod column_count;
|
||||
mod column_def;
|
||||
mod com_stmt_prepare_ok;
|
||||
mod eof;
|
||||
mod err;
|
||||
mod handshake_response;
|
||||
mod ok;
|
||||
pub(crate) mod connect;
|
||||
mod packet;
|
||||
pub(crate) mod response;
|
||||
mod row;
|
||||
#[cfg_attr(not(feature = "tls"), allow(unused_imports, dead_code))]
|
||||
mod ssl_request;
|
||||
pub(crate) mod rsa;
|
||||
pub(crate) mod statement;
|
||||
pub(crate) mod text;
|
||||
|
||||
pub(crate) use auth_switch::AuthSwitch;
|
||||
pub(crate) use column_count::ColumnCount;
|
||||
pub(crate) use column_def::ColumnDefinition;
|
||||
pub(crate) use com_stmt_prepare_ok::ComStmtPrepareOk;
|
||||
pub(crate) use eof::EofPacket;
|
||||
pub(crate) use err::ErrPacket;
|
||||
pub(crate) use handshake_response::HandshakeResponse;
|
||||
pub(crate) use ok::OkPacket;
|
||||
pub(crate) use capabilities::Capabilities;
|
||||
pub(crate) use packet::Packet;
|
||||
pub(crate) use row::Row;
|
||||
#[cfg_attr(not(feature = "tls"), allow(unused_imports, dead_code))]
|
||||
pub(crate) use ssl_request::SslRequest;
|
||||
|
||||
pub(crate) trait Encode {
|
||||
fn encode(&self, buf: &mut Vec<u8>, capabilities: Capabilities);
|
||||
}
|
||||
|
||||
impl Encode for &'_ [u8] {
|
||||
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
|
||||
use crate::io::BufMut;
|
||||
|
||||
buf.put_bytes(self);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
use crate::io::Buf;
|
||||
use crate::mysql::io::BufExt;
|
||||
use crate::mysql::protocol::Status;
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_ok_packet.html
|
||||
// https://mariadb.com/kb/en/ok_packet/
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct OkPacket {
|
||||
pub(crate) affected_rows: u64,
|
||||
pub(crate) last_insert_id: u64,
|
||||
pub(crate) status: Status,
|
||||
pub(crate) warnings: u16,
|
||||
pub(crate) info: Box<str>,
|
||||
}
|
||||
|
||||
impl OkPacket {
|
||||
pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let header = buf.get_u8()?;
|
||||
if header != 0 && header != 0xFE {
|
||||
return Err(protocol_err!(
|
||||
"expected 0x00 or 0xFE; received 0x{:X}",
|
||||
header
|
||||
))?;
|
||||
}
|
||||
|
||||
let affected_rows = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or(0); // 0
|
||||
let last_insert_id = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or(0); // 2
|
||||
let status = Status::from_bits_truncate(buf.get_u16::<LittleEndian>()?); //
|
||||
let warnings = buf.get_u16::<LittleEndian>()?;
|
||||
let info = buf.get_str(buf.len())?.into();
|
||||
|
||||
Ok(Self {
|
||||
affected_rows,
|
||||
last_insert_id,
|
||||
status,
|
||||
warnings,
|
||||
info,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{OkPacket, Status};
|
||||
|
||||
const OK_HANDSHAKE: &[u8] = b"\x00\x00\x00\x02@\x00\x00";
|
||||
|
||||
#[test]
|
||||
fn it_decodes_ok_handshake() {
|
||||
let p = OkPacket::read(OK_HANDSHAKE).unwrap();
|
||||
|
||||
assert_eq!(p.affected_rows, 0);
|
||||
assert_eq!(p.last_insert_id, 0);
|
||||
assert_eq!(p.warnings, 0);
|
||||
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
|
||||
assert!(p.status.contains(Status::SERVER_SESSION_STATE_CHANGED));
|
||||
assert!(p.info.is_empty());
|
||||
}
|
||||
}
|
||||
89
sqlx-core/src/mysql/protocol/packet.rs
Normal file
89
sqlx-core/src/mysql/protocol/packet.rs
Normal file
@@ -0,0 +1,89 @@
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use bytes::Bytes;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::{Decode, Encode};
|
||||
use crate::mysql::protocol::response::{EofPacket, OkPacket};
|
||||
use crate::mysql::protocol::Capabilities;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Packet<T>(pub(crate) T);
|
||||
|
||||
impl<'en, 'stream, T> Encode<'stream, (Capabilities, &'stream mut u8)> for Packet<T>
|
||||
where
|
||||
T: Encode<'en, Capabilities>,
|
||||
{
|
||||
fn encode_with(
|
||||
&self,
|
||||
buf: &mut Vec<u8>,
|
||||
(capabilities, sequence_id): (Capabilities, &'stream mut u8),
|
||||
) {
|
||||
// reserve space to write the prefixed length
|
||||
let offset = buf.len();
|
||||
buf.extend(&[0_u8; 4]);
|
||||
|
||||
// encode the payload
|
||||
self.0.encode_with(buf, capabilities);
|
||||
|
||||
// determine the length of the encoded payload
|
||||
// and write to our reserved space
|
||||
let len = buf.len() - offset - 4;
|
||||
let header = &mut buf[offset..];
|
||||
|
||||
// FIXME: Support larger packets
|
||||
assert!(len < 0xFF_FF_FF);
|
||||
|
||||
header[..4].copy_from_slice(&(len as u32).to_le_bytes());
|
||||
header[3] = *sequence_id;
|
||||
|
||||
*sequence_id = sequence_id.wrapping_add(1);
|
||||
}
|
||||
}
|
||||
|
||||
impl Packet<Bytes> {
|
||||
pub(crate) fn decode<'de, T>(self) -> Result<T, Error>
|
||||
where
|
||||
T: Decode<'de, ()>,
|
||||
{
|
||||
self.decode_with(())
|
||||
}
|
||||
|
||||
pub(crate) fn decode_with<'de, T, C>(self, context: C) -> Result<T, Error>
|
||||
where
|
||||
T: Decode<'de, C>,
|
||||
{
|
||||
T::decode_with(self.0, context)
|
||||
}
|
||||
|
||||
pub(crate) fn ok(self) -> Result<OkPacket, Error> {
|
||||
self.decode()
|
||||
}
|
||||
|
||||
pub(crate) fn eof(self, capabilities: Capabilities) -> Result<EofPacket, Error> {
|
||||
if capabilities.contains(Capabilities::DEPRECATE_EOF) {
|
||||
let ok = self.ok()?;
|
||||
|
||||
Ok(EofPacket {
|
||||
warnings: ok.warnings,
|
||||
status: ok.status,
|
||||
})
|
||||
} else {
|
||||
self.decode_with(capabilities)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Packet<Bytes> {
|
||||
type Target = Bytes;
|
||||
|
||||
fn deref(&self) -> &Bytes {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for Packet<Bytes> {
|
||||
fn deref_mut(&mut self) -> &mut Bytes {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
35
sqlx-core/src/mysql/protocol/response/eof.rs
Normal file
35
sqlx-core/src/mysql/protocol/response/eof.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
use bytes::{Buf, Bytes};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::Decode;
|
||||
use crate::mysql::protocol::response::Status;
|
||||
use crate::mysql::protocol::Capabilities;
|
||||
|
||||
/// Marks the end of a result set, returning status and warnings.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// The EOF packet is deprecated as of MySQL 5.7.5. SQLx only uses this packet for MySQL
|
||||
/// prior MySQL versions.
|
||||
#[derive(Debug)]
|
||||
pub struct EofPacket {
|
||||
pub warnings: u16,
|
||||
pub status: Status,
|
||||
}
|
||||
|
||||
impl Decode<'_, Capabilities> for EofPacket {
|
||||
fn decode_with(mut buf: Bytes, _: Capabilities) -> Result<Self, Error> {
|
||||
let header = buf.get_u8();
|
||||
if header != 0xfe {
|
||||
return Err(err_protocol!(
|
||||
"expected 0xfe (EOF_Packet) but found 0x{:x}",
|
||||
header
|
||||
));
|
||||
}
|
||||
|
||||
let warnings = buf.get_u16_le();
|
||||
let status = Status::from_bits_truncate(buf.get_u16_le());
|
||||
|
||||
Ok(Self { status, warnings })
|
||||
}
|
||||
}
|
||||
71
sqlx-core/src/mysql/protocol/response/err.rs
Normal file
71
sqlx-core/src/mysql/protocol/response/err.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
use bytes::{Buf, Bytes};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::{BufExt, Decode};
|
||||
use crate::mysql::protocol::Capabilities;
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_err_packet.html
|
||||
// https://mariadb.com/kb/en/err_packet/
|
||||
|
||||
/// Indicates that an error occurred.
|
||||
#[derive(Debug)]
|
||||
pub struct ErrPacket {
|
||||
pub error_code: u16,
|
||||
pub sql_state: Option<String>,
|
||||
pub error_message: String,
|
||||
}
|
||||
|
||||
impl Decode<'_, Capabilities> for ErrPacket {
|
||||
fn decode_with(mut buf: Bytes, capabilities: Capabilities) -> Result<Self, Error> {
|
||||
let header = buf.get_u8();
|
||||
if header != 0xff {
|
||||
return Err(err_protocol!(
|
||||
"expected 0xff (ERR_Packet) but found 0x{:x}",
|
||||
header
|
||||
));
|
||||
}
|
||||
|
||||
let error_code = buf.get_u16_le();
|
||||
let mut sql_state = None;
|
||||
|
||||
if capabilities.contains(Capabilities::PROTOCOL_41) {
|
||||
// If the next byte is '#' then we have a SQL STATE
|
||||
if buf.get(0) == Some(&0x23) {
|
||||
buf.advance(1);
|
||||
sql_state = Some(buf.get_str(5)?.to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
let error_message = buf.get_str(buf.len())?.to_owned();
|
||||
|
||||
Ok(Self {
|
||||
error_code,
|
||||
sql_state,
|
||||
error_message,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_err_packet_out_of_order() {
|
||||
const ERR_PACKETS_OUT_OF_ORDER: &[u8] = b"\xff\x84\x04Got packets out of order";
|
||||
|
||||
let p =
|
||||
ErrPacket::decode_with(ERR_PACKETS_OUT_OF_ORDER.into(), Capabilities::PROTOCOL_41).unwrap();
|
||||
|
||||
assert_eq!(&p.error_message, "Got packets out of order");
|
||||
assert_eq!(p.error_code, 1156);
|
||||
assert_eq!(p.sql_state, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_err_packet_unknown_database() {
|
||||
const ERR_HANDSHAKE_UNKNOWN_DB: &[u8] = b"\xff\x19\x04#42000Unknown database \'unknown\'";
|
||||
|
||||
let p =
|
||||
ErrPacket::decode_with(ERR_HANDSHAKE_UNKNOWN_DB.into(), Capabilities::PROTOCOL_41).unwrap();
|
||||
|
||||
assert_eq!(p.error_code, 1049);
|
||||
assert_eq!(p.sql_state.as_deref(), Some("42000"));
|
||||
assert_eq!(&p.error_message, "Unknown database \'unknown\'");
|
||||
}
|
||||
14
sqlx-core/src/mysql/protocol/response/mod.rs
Normal file
14
sqlx-core/src/mysql/protocol/response/mod.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
//! Generic Response Packets
|
||||
//!
|
||||
//! <https://dev.mysql.com/doc/internals/en/generic-response-packets.html>
|
||||
//! <https://mariadb.com/kb/en/4-server-response-packets/>
|
||||
|
||||
mod eof;
|
||||
mod err;
|
||||
mod ok;
|
||||
mod status;
|
||||
|
||||
pub use eof::EofPacket;
|
||||
pub use err::ErrPacket;
|
||||
pub use ok::OkPacket;
|
||||
pub use status::Status;
|
||||
52
sqlx-core/src/mysql/protocol/response/ok.rs
Normal file
52
sqlx-core/src/mysql/protocol/response/ok.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use bytes::{Buf, Bytes};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::Decode;
|
||||
use crate::mysql::io::MySqlBufExt;
|
||||
use crate::mysql::protocol::response::Status;
|
||||
|
||||
/// Indicates successful completion of a previous command sent by the client.
|
||||
#[derive(Debug)]
|
||||
pub struct OkPacket {
|
||||
pub affected_rows: u64,
|
||||
pub last_insert_id: u64,
|
||||
pub status: Status,
|
||||
pub warnings: u16,
|
||||
}
|
||||
|
||||
impl Decode<'_> for OkPacket {
|
||||
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
|
||||
let header = buf.get_u8();
|
||||
if header != 0 && header != 0xfe {
|
||||
return Err(err_protocol!(
|
||||
"expected 0x00 or 0xfe (OK_Packet) but found 0x{:02x}",
|
||||
header
|
||||
));
|
||||
}
|
||||
|
||||
let affected_rows = buf.get_uint_lenenc();
|
||||
let last_insert_id = buf.get_uint_lenenc();
|
||||
let status = Status::from_bits_truncate(buf.get_u16_le());
|
||||
let warnings = buf.get_u16_le();
|
||||
|
||||
Ok(Self {
|
||||
affected_rows,
|
||||
last_insert_id,
|
||||
status,
|
||||
warnings,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_ok_packet() {
|
||||
const DATA: &[u8] = b"\x00\x00\x00\x02@\x00\x00";
|
||||
|
||||
let p = OkPacket::decode(DATA.into()).unwrap();
|
||||
|
||||
assert_eq!(p.affected_rows, 0);
|
||||
assert_eq!(p.last_insert_id, 0);
|
||||
assert_eq!(p.warnings, 0);
|
||||
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
|
||||
assert!(p.status.contains(Status::SERVER_SESSION_STATE_CHANGED));
|
||||
}
|
||||
@@ -1,323 +1,21 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
use bytes::Bytes;
|
||||
|
||||
use crate::io::Buf;
|
||||
use crate::mysql::protocol::TypeId;
|
||||
use crate::mysql::MySqlTypeInfo;
|
||||
|
||||
pub(crate) struct Row<'c> {
|
||||
buffer: &'c [u8],
|
||||
values: &'c [Option<Range<usize>>],
|
||||
pub(crate) columns: &'c [MySqlTypeInfo],
|
||||
pub(crate) binary: bool,
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Row {
|
||||
pub(crate) storage: Bytes,
|
||||
pub(crate) values: Vec<Option<Range<usize>>>,
|
||||
}
|
||||
|
||||
impl<'c> Row<'c> {
|
||||
impl Row {
|
||||
pub(crate) fn len(&self) -> usize {
|
||||
self.values.len()
|
||||
}
|
||||
|
||||
pub(crate) fn get(&self, index: usize) -> Option<&'c [u8]> {
|
||||
let range = self.values[index].as_ref()?;
|
||||
|
||||
Some(&self.buffer[(range.start as usize)..(range.end as usize)])
|
||||
pub(crate) fn get(&self, index: usize) -> Option<&[u8]> {
|
||||
self.values[index]
|
||||
.as_ref()
|
||||
.map(|col| &self.storage[(col.start as usize)..(col.end as usize)])
|
||||
}
|
||||
}
|
||||
|
||||
fn get_lenenc(buf: &[u8]) -> (usize, Option<usize>) {
|
||||
match buf[0] {
|
||||
0xFB => (1, None),
|
||||
|
||||
0xFC => {
|
||||
let len_size = 1 + 2;
|
||||
let len = LittleEndian::read_u16(&buf[1..]);
|
||||
|
||||
(len_size, Some(len as usize))
|
||||
}
|
||||
|
||||
0xFD => {
|
||||
let len_size = 1 + 3;
|
||||
let len = LittleEndian::read_u24(&buf[1..]);
|
||||
|
||||
(len_size, Some(len as usize))
|
||||
}
|
||||
|
||||
0xFE => {
|
||||
let len_size = 1 + 8;
|
||||
let len = LittleEndian::read_u64(&buf[1..]);
|
||||
|
||||
(len_size, Some(len as usize))
|
||||
}
|
||||
|
||||
len => (1, Some(len as usize)),
|
||||
}
|
||||
}
|
||||
|
||||
impl<'c> Row<'c> {
|
||||
pub(crate) fn read(
|
||||
mut buf: &'c [u8],
|
||||
columns: &'c [MySqlTypeInfo],
|
||||
values: &'c mut Vec<Option<Range<usize>>>,
|
||||
binary: bool,
|
||||
) -> crate::Result<Self> {
|
||||
let buffer = &*buf;
|
||||
|
||||
values.clear();
|
||||
values.reserve(columns.len());
|
||||
|
||||
if !binary {
|
||||
let mut index = 0;
|
||||
|
||||
for _ in 0..columns.len() {
|
||||
let (len_size, size) = get_lenenc(&buf[index..]);
|
||||
|
||||
if let Some(size) = size {
|
||||
values.push(Some((index + len_size)..(index + len_size + size)));
|
||||
} else {
|
||||
values.push(None);
|
||||
}
|
||||
|
||||
index += len_size + size.unwrap_or_default();
|
||||
}
|
||||
|
||||
return Ok(Self {
|
||||
buffer,
|
||||
columns,
|
||||
values: &*values,
|
||||
binary: false,
|
||||
});
|
||||
}
|
||||
|
||||
// 0x00 header : byte<1>
|
||||
let header = buf.get_u8()?;
|
||||
if header != 0 {
|
||||
return Err(protocol_err!("expected ROW (0x00), got: {:#04X}", header).into());
|
||||
}
|
||||
|
||||
// NULL-Bitmap : byte<(number_of_columns + 9) / 8>
|
||||
let null_len = (columns.len() + 9) / 8;
|
||||
let null_bitmap = &buf[..];
|
||||
buf.advance(null_len);
|
||||
|
||||
let buffer: Box<[u8]> = buf.into();
|
||||
let mut index = 0;
|
||||
|
||||
for column_idx in 0..columns.len() {
|
||||
// the null index for a column starts at the 3rd bit in the null bitmap
|
||||
// for no reason at all besides mysql probably
|
||||
let column_null_idx = column_idx + 2;
|
||||
let is_null =
|
||||
null_bitmap[column_null_idx / 8] & (1 << (column_null_idx % 8) as u8) != 0;
|
||||
|
||||
if is_null {
|
||||
values.push(None);
|
||||
} else {
|
||||
let (offset, size) = match columns[column_idx].id {
|
||||
TypeId::TINY_INT => (0, 1),
|
||||
TypeId::SMALL_INT => (0, 2),
|
||||
TypeId::INT | TypeId::FLOAT => (0, 4),
|
||||
TypeId::BIG_INT | TypeId::DOUBLE => (0, 8),
|
||||
|
||||
TypeId::DATE => (0, 5),
|
||||
TypeId::TIME => (0, 1 + buffer[index] as usize),
|
||||
|
||||
TypeId::TIMESTAMP | TypeId::DATETIME => (0, 1 + buffer[index] as usize),
|
||||
|
||||
TypeId::TINY_BLOB
|
||||
| TypeId::MEDIUM_BLOB
|
||||
| TypeId::LONG_BLOB
|
||||
| TypeId::CHAR
|
||||
| TypeId::TEXT
|
||||
| TypeId::ENUM
|
||||
| TypeId::VAR_CHAR => {
|
||||
let (len_size, len) = get_lenenc(&buffer[index..]);
|
||||
|
||||
(len_size, len.unwrap_or_default())
|
||||
}
|
||||
|
||||
TypeId::NEWDECIMAL => (0, 1 + buffer[index] as usize),
|
||||
|
||||
id => {
|
||||
unimplemented!("encountered unknown field type id: {:?}", id);
|
||||
}
|
||||
};
|
||||
|
||||
values.push(Some((index + offset)..(index + offset + size)));
|
||||
index += size + offset;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
buffer: buf,
|
||||
values: &*values,
|
||||
columns,
|
||||
binary,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// #[cfg(test)]
|
||||
// mod test {
|
||||
// use super::super::column_count::ColumnCount;
|
||||
// use super::super::column_def::ColumnDefinition;
|
||||
// use super::super::eof::EofPacket;
|
||||
// use super::*;
|
||||
//
|
||||
// #[test]
|
||||
// fn null_bitmap_test() -> crate::Result<()> {
|
||||
// let column_len = ColumnCount::decode(&[26])?;
|
||||
// assert_eq!(column_len.columns, 26);
|
||||
//
|
||||
// let types: Vec<TypeId> = vec![
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 2, 105, 100, 2, 105, 100, 12, 63, 0, 11, 0, 0,
|
||||
// 0, 3, 11, 66, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 50, 6, 102, 105,
|
||||
// 101, 108, 100, 50, 12, 224, 0, 120, 0, 0, 0, 253, 5, 64, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 51, 6, 102, 105,
|
||||
// 101, 108, 100, 51, 12, 224, 0, 252, 3, 0, 0, 253, 1, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 52, 6, 102, 105,
|
||||
// 101, 108, 100, 52, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 53, 6, 102, 105,
|
||||
// 101, 108, 100, 53, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 54, 6, 102, 105,
|
||||
// 101, 108, 100, 54, 12, 63, 0, 19, 0, 0, 0, 7, 128, 4, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 55, 6, 102, 105,
|
||||
// 101, 108, 100, 55, 12, 63, 0, 4, 0, 0, 0, 1, 1, 64, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 56, 6, 102, 105,
|
||||
// 101, 108, 100, 56, 12, 224, 0, 252, 255, 3, 0, 252, 16, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 57, 6, 102, 105,
|
||||
// 101, 108, 100, 57, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 48, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 48, 12, 224, 0, 252, 3, 0, 0, 252, 16, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 49, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 49, 12, 224, 0, 252, 3, 0, 0, 252, 16, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 50, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 50, 12, 63, 0, 19, 0, 0, 0, 7, 129, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 51, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 51, 12, 63, 0, 4, 0, 0, 0, 1, 0, 64, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 52, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 52, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 53, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 53, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 54, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 54, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 55, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 55, 12, 224, 0, 0, 1, 0, 0, 253, 0, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 56, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 56, 12, 63, 0, 11, 0, 0, 0, 3, 1, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 57, 7, 102,
|
||||
// 105, 101, 108, 100, 49, 57, 12, 63, 0, 11, 0, 0, 0, 3, 1, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 48, 7, 102,
|
||||
// 105, 101, 108, 100, 50, 48, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 49, 7, 102,
|
||||
// 105, 101, 108, 100, 50, 49, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 50, 7, 102,
|
||||
// 105, 101, 108, 100, 50, 50, 12, 63, 0, 3, 0, 0, 0, 3, 0, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 51, 7, 102,
|
||||
// 105, 101, 108, 100, 50, 51, 12, 63, 0, 6, 0, 0, 0, 3, 1, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 52, 7, 102,
|
||||
// 105, 101, 108, 100, 50, 52, 12, 63, 0, 6, 0, 0, 0, 3, 1, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 53, 7, 102,
|
||||
// 105, 101, 108, 100, 50, 53, 12, 63, 0, 20, 0, 0, 0, 8, 1, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ColumnDefinition::decode(&[
|
||||
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
|
||||
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 54, 7, 102,
|
||||
// 105, 101, 108, 100, 50, 54, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0,
|
||||
// ])?,
|
||||
// ]
|
||||
// .into_iter()
|
||||
// .map(|def| def.type_id)
|
||||
// .collect();
|
||||
//
|
||||
// EofPacket::decode(&[254, 0, 0, 34, 0])?;
|
||||
//
|
||||
// Row::read(
|
||||
// &[
|
||||
// 0, 64, 90, 229, 0, 4, 0, 0, 0, 4, 114, 117, 115, 116, 0, 0, 7, 228, 7, 1, 16, 8,
|
||||
// 10, 17, 0, 0, 4, 208, 7, 1, 1, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
// 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
// ],
|
||||
// &types,
|
||||
// true,
|
||||
// )?;
|
||||
//
|
||||
// EofPacket::decode(&[254, 0, 0, 34, 0])?;
|
||||
// Ok(())
|
||||
// }
|
||||
// }
|
||||
|
||||
@@ -2,14 +2,16 @@ use digest::Digest;
|
||||
use num_bigint::BigUint;
|
||||
use rand::{thread_rng, Rng};
|
||||
|
||||
use crate::error::Error;
|
||||
|
||||
// This is mostly taken from https://github.com/RustCrypto/RSA/pull/18
|
||||
// For the love of crypto, please delete as much of this as possible and use the RSA crate
|
||||
// directly when that PR is merged
|
||||
|
||||
pub fn encrypt<D: Digest>(key: &[u8], message: &[u8]) -> crate::Result<Vec<u8>> {
|
||||
pub fn encrypt<D: Digest>(key: &[u8], message: &[u8]) -> Result<Vec<u8>, Error> {
|
||||
let key = std::str::from_utf8(key).map_err(|_err| {
|
||||
// TODO(@abonander): protocol_err doesn't like referring to [err]
|
||||
protocol_err!("unexpected error decoding what should be UTF-8")
|
||||
// TODO(@abonander): err_protocol doesn't like referring to [err]
|
||||
err_protocol!("unexpected error decoding what should be UTF-8")
|
||||
})?;
|
||||
|
||||
let key = parse(key)?;
|
||||
@@ -96,7 +98,7 @@ fn oaep_encrypt<R: Rng, D: Digest>(
|
||||
rng: &mut R,
|
||||
pub_key: &PublicKey,
|
||||
msg: &[u8],
|
||||
) -> crate::Result<Vec<u8>> {
|
||||
) -> Result<Vec<u8>, Error> {
|
||||
// size of [n] in bytes
|
||||
let k = (pub_key.n.bits() + 7) / 8;
|
||||
|
||||
@@ -104,7 +106,7 @@ fn oaep_encrypt<R: Rng, D: Digest>(
|
||||
let h_size = D::output_size();
|
||||
|
||||
if msg.len() > k - 2 * h_size - 2 {
|
||||
return Err(protocol_err!("mysql: password too long").into());
|
||||
return Err(err_protocol!("mysql: password too long"));
|
||||
}
|
||||
|
||||
let mut em = vec![0u8; k];
|
||||
@@ -140,13 +142,13 @@ struct PublicKey {
|
||||
e: BigUint,
|
||||
}
|
||||
|
||||
fn parse(key: &str) -> crate::Result<PublicKey> {
|
||||
fn parse(key: &str) -> Result<PublicKey, Error> {
|
||||
// This takes advantage of the knowledge that we know
|
||||
// we are receiving a PKCS#8 RSA Public Key at all
|
||||
// times from MySQL
|
||||
|
||||
if !key.starts_with("-----BEGIN PUBLIC KEY-----\n") {
|
||||
return Err(protocol_err!(
|
||||
return Err(err_protocol!(
|
||||
"unexpected format for RSA Public Key from MySQL (expected PKCS#8); first line: {:?}",
|
||||
key.splitn(1, '\n').next()
|
||||
)
|
||||
@@ -158,8 +160,8 @@ fn parse(key: &str) -> crate::Result<PublicKey> {
|
||||
let inner_key = key_with_trailer[..trailer_pos].replace('\n', "");
|
||||
|
||||
let inner = base64::decode(&inner_key).map_err(|_err| {
|
||||
// TODO(@abonander): protocol_err doesn't like referring to [err]
|
||||
protocol_err!("unexpected error decoding what should be base64-encoded data")
|
||||
// TODO(@abonander): err_protocol doesn't like referring to [err]
|
||||
err_protocol!("unexpected error decoding what should be base64-encoded data")
|
||||
})?;
|
||||
|
||||
let len = inner.len();
|
||||
@@ -1,34 +0,0 @@
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
use crate::io::BufMut;
|
||||
use crate::mysql::protocol::{Capabilities, Encode};
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_response.html
|
||||
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
|
||||
#[derive(Debug)]
|
||||
pub struct SslRequest {
|
||||
pub max_packet_size: u32,
|
||||
pub client_collation: u8,
|
||||
}
|
||||
|
||||
impl Encode for SslRequest {
|
||||
fn encode(&self, buf: &mut Vec<u8>, capabilities: Capabilities) {
|
||||
// SSL must be set or else it makes no sense to ask for an upgrade
|
||||
assert!(
|
||||
capabilities.contains(Capabilities::SSL),
|
||||
"SSL bit must be set for Capabilities"
|
||||
);
|
||||
|
||||
// client capabilities : int<4>
|
||||
buf.put_u32::<LittleEndian>(capabilities.bits() as u32);
|
||||
|
||||
// max packet size : int<4>
|
||||
buf.put_u32::<LittleEndian>(self.max_packet_size);
|
||||
|
||||
// client character collation : int<1>
|
||||
buf.put_u8(self.client_collation);
|
||||
|
||||
// reserved : string<23>
|
||||
buf.advance(23);
|
||||
}
|
||||
}
|
||||
38
sqlx-core/src/mysql/protocol/statement/execute.rs
Normal file
38
sqlx-core/src/mysql/protocol/statement/execute.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
use crate::io::Encode;
|
||||
use crate::mysql::protocol::text::ColumnFlags;
|
||||
use crate::mysql::protocol::Capabilities;
|
||||
use crate::mysql::MySqlArguments;
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_execute.html
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Execute<'q> {
|
||||
pub statement: u32,
|
||||
pub arguments: &'q MySqlArguments,
|
||||
}
|
||||
|
||||
impl<'q> Encode<'_, Capabilities> for Execute<'q> {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
|
||||
buf.push(0x17); // COM_STMT_EXECUTE
|
||||
buf.extend(&self.statement.to_le_bytes());
|
||||
buf.push(0); // NO_CURSOR
|
||||
buf.extend(&0_u32.to_le_bytes()); // iterations (always 1): int<4>
|
||||
|
||||
if !self.arguments.is_empty() {
|
||||
buf.extend(&*self.arguments.null_bitmap);
|
||||
buf.push(1); // send type to server
|
||||
|
||||
for ty in &self.arguments.types {
|
||||
buf.push(ty.r#type as u8);
|
||||
|
||||
buf.push(if ty.flags.contains(ColumnFlags::UNSIGNED) {
|
||||
0x80
|
||||
} else {
|
||||
0
|
||||
});
|
||||
}
|
||||
|
||||
buf.extend(&*self.arguments.values);
|
||||
}
|
||||
}
|
||||
}
|
||||
9
sqlx-core/src/mysql/protocol/statement/mod.rs
Normal file
9
sqlx-core/src/mysql/protocol/statement/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
mod execute;
|
||||
mod prepare;
|
||||
mod prepare_ok;
|
||||
mod row;
|
||||
|
||||
pub(crate) use execute::Execute;
|
||||
pub(crate) use prepare::Prepare;
|
||||
pub(crate) use prepare_ok::PrepareOk;
|
||||
pub(crate) use row::BinaryRow;
|
||||
15
sqlx-core/src/mysql/protocol/statement/prepare.rs
Normal file
15
sqlx-core/src/mysql/protocol/statement/prepare.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
use crate::io::Encode;
|
||||
use crate::mysql::protocol::Capabilities;
|
||||
|
||||
// https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html#packet-COM_STMT_PREPARE
|
||||
|
||||
pub struct Prepare<'a> {
|
||||
pub query: &'a str,
|
||||
}
|
||||
|
||||
impl Encode<'_, Capabilities> for Prepare<'_> {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
|
||||
buf.push(0x16); // COM_STMT_PREPARE
|
||||
buf.extend(self.query.as_bytes());
|
||||
}
|
||||
}
|
||||
42
sqlx-core/src/mysql/protocol/statement/prepare_ok.rs
Normal file
42
sqlx-core/src/mysql/protocol/statement/prepare_ok.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
use bytes::{Buf, Bytes};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::Decode;
|
||||
use crate::mysql::protocol::Capabilities;
|
||||
|
||||
// https://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html#packet-COM_STMT_PREPARE_OK
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct PrepareOk {
|
||||
pub(crate) statement_id: u32,
|
||||
pub(crate) columns: u16,
|
||||
pub(crate) params: u16,
|
||||
pub(crate) warnings: u16,
|
||||
}
|
||||
|
||||
impl Decode<'_, Capabilities> for PrepareOk {
|
||||
fn decode_with(mut buf: Bytes, _: Capabilities) -> Result<Self, Error> {
|
||||
let status = buf.get_u8();
|
||||
if status != 0x00 {
|
||||
return Err(err_protocol!(
|
||||
"expected 0x00 (COM_STMT_PREPARE_OK) but found 0x{:02x}",
|
||||
status
|
||||
));
|
||||
}
|
||||
|
||||
let statement_id = buf.get_u32_le();
|
||||
let columns = buf.get_u16_le();
|
||||
let params = buf.get_u16_le();
|
||||
|
||||
buf.advance(1); // reserved: string<1>
|
||||
|
||||
let warnings = buf.get_u16_le();
|
||||
|
||||
Ok(Self {
|
||||
statement_id,
|
||||
columns,
|
||||
params,
|
||||
warnings,
|
||||
})
|
||||
}
|
||||
}
|
||||
92
sqlx-core/src/mysql/protocol/statement/row.rs
Normal file
92
sqlx-core/src/mysql/protocol/statement/row.rs
Normal file
@@ -0,0 +1,92 @@
|
||||
use bytes::{Buf, Bytes};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::{BufExt, Decode};
|
||||
use crate::mysql::io::MySqlBufExt;
|
||||
use crate::mysql::protocol::text::ColumnType;
|
||||
use crate::mysql::protocol::Row;
|
||||
use crate::mysql::row::MySqlColumn;
|
||||
|
||||
// https://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html#packet-ProtocolBinary::ResultsetRow
|
||||
// https://dev.mysql.com/doc/internals/en/binary-protocol-value.html
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct BinaryRow(pub(crate) Row);
|
||||
|
||||
impl<'de> Decode<'de, &'de [MySqlColumn]> for BinaryRow {
|
||||
fn decode_with(mut buf: Bytes, columns: &'de [MySqlColumn]) -> Result<Self, Error> {
|
||||
let header = buf.get_u8();
|
||||
if header != 0 {
|
||||
return Err(err_protocol!(
|
||||
"exepcted 0x00 (ROW) but found 0x{:02x}",
|
||||
header
|
||||
));
|
||||
}
|
||||
|
||||
let storage = buf.clone();
|
||||
let offset = buf.len();
|
||||
|
||||
let null_bitmap_len = (columns.len() + 9) / 8;
|
||||
let null_bitmap = buf.get_bytes(null_bitmap_len);
|
||||
|
||||
let mut values = Vec::with_capacity(columns.len());
|
||||
|
||||
for (column_idx, column) in columns.iter().enumerate() {
|
||||
// NOTE: the column index starts at the 3rd bit
|
||||
let column_null_idx = column_idx + 2;
|
||||
let is_null =
|
||||
null_bitmap[column_null_idx / 8] & (1 << (column_null_idx % 8) as u8) != 0;
|
||||
|
||||
if is_null {
|
||||
values.push(None);
|
||||
continue;
|
||||
}
|
||||
|
||||
// NOTE: MySQL will never generate NULL types for non-NULL values
|
||||
let type_info = column.type_info.as_ref().unwrap();
|
||||
|
||||
let size: usize = match type_info.r#type {
|
||||
ColumnType::String
|
||||
| ColumnType::VarChar
|
||||
| ColumnType::VarString
|
||||
| ColumnType::Enum
|
||||
| ColumnType::Set
|
||||
| ColumnType::LongBlob
|
||||
| ColumnType::MediumBlob
|
||||
| ColumnType::Blob
|
||||
| ColumnType::TinyBlob
|
||||
| ColumnType::Geometry
|
||||
| ColumnType::Bit
|
||||
| ColumnType::Decimal
|
||||
| ColumnType::Json
|
||||
| ColumnType::NewDecimal => buf.get_uint_lenenc() as usize,
|
||||
|
||||
ColumnType::LongLong => 8,
|
||||
ColumnType::Long | ColumnType::Int24 => 4,
|
||||
ColumnType::Short | ColumnType::Year => 2,
|
||||
ColumnType::Tiny => 1,
|
||||
ColumnType::Float => 4,
|
||||
ColumnType::Double => 8,
|
||||
|
||||
ColumnType::Time
|
||||
| ColumnType::Timestamp
|
||||
| ColumnType::Date
|
||||
| ColumnType::Datetime => {
|
||||
// The size of this type is important for decoding
|
||||
buf[0] as usize + 1
|
||||
}
|
||||
|
||||
// NOTE: MySQL will never generate NULL types for non-NULL values
|
||||
ColumnType::Null => unreachable!(),
|
||||
};
|
||||
|
||||
let offset = offset - buf.len();
|
||||
|
||||
values.push(Some(offset..(offset + size)));
|
||||
|
||||
buf.advance(size);
|
||||
}
|
||||
|
||||
Ok(BinaryRow(Row { values, storage }))
|
||||
}
|
||||
}
|
||||
244
sqlx-core/src/mysql/protocol/text/column.rs
Normal file
244
sqlx-core/src/mysql/protocol/text/column.rs
Normal file
@@ -0,0 +1,244 @@
|
||||
use std::str::from_utf8;
|
||||
|
||||
use bitflags::bitflags;
|
||||
use bytes::{Buf, Bytes};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::Decode;
|
||||
use crate::mysql::io::MySqlBufExt;
|
||||
use crate::mysql::protocol::Capabilities;
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__column__definition__flags.html
|
||||
|
||||
bitflags! {
|
||||
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub(crate) struct ColumnFlags: u16 {
|
||||
/// Field can't be `NULL`.
|
||||
const NOT_NULL = 1;
|
||||
|
||||
/// Field is part of a primary key.
|
||||
const PRIMARY_KEY = 2;
|
||||
|
||||
/// Field is part of a unique key.
|
||||
const UNIQUE_KEY = 4;
|
||||
|
||||
/// Field is part of a multi-part unique or primary key.
|
||||
const MULTIPLE_KEY = 8;
|
||||
|
||||
/// Field is a blob.
|
||||
const BLOB = 16;
|
||||
|
||||
/// Field is unsigned.
|
||||
const UNSIGNED = 32;
|
||||
|
||||
/// Field is zero filled.
|
||||
const ZEROFILL = 64;
|
||||
|
||||
/// Field is binary.
|
||||
const BINARY = 128;
|
||||
|
||||
/// Field is an enumeration.
|
||||
const ENUM = 256;
|
||||
|
||||
/// Field is an auto-incement field.
|
||||
const AUTO_INCREMENT = 512;
|
||||
|
||||
/// Field is a timestamp.
|
||||
const TIMESTAMP = 1024;
|
||||
|
||||
/// Field is a set.
|
||||
const SET = 2048;
|
||||
|
||||
/// Field does not have a default value.
|
||||
const NO_DEFAULT_VALUE = 4096;
|
||||
|
||||
/// Field is set to NOW on UPDATE.
|
||||
const ON_UPDATE_NOW = 8192;
|
||||
|
||||
/// Field is a number.
|
||||
const NUM = 32768;
|
||||
}
|
||||
}
|
||||
|
||||
// https://dev.mysql.com/doc/internals/en/com-query-response.html#column-type
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq)]
|
||||
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[repr(u8)]
|
||||
pub enum ColumnType {
|
||||
Decimal = 0x00,
|
||||
Tiny = 0x01,
|
||||
Short = 0x02,
|
||||
Long = 0x03,
|
||||
Float = 0x04,
|
||||
Double = 0x05,
|
||||
Null = 0x06,
|
||||
Timestamp = 0x07,
|
||||
LongLong = 0x08,
|
||||
Int24 = 0x09,
|
||||
Date = 0x0a,
|
||||
Time = 0x0b,
|
||||
Datetime = 0x0c,
|
||||
Year = 0x0d,
|
||||
VarChar = 0x0f,
|
||||
Bit = 0x10,
|
||||
Json = 0xf5,
|
||||
NewDecimal = 0xf6,
|
||||
Enum = 0xf7,
|
||||
Set = 0xf8,
|
||||
TinyBlob = 0xf9,
|
||||
MediumBlob = 0xfa,
|
||||
LongBlob = 0xfb,
|
||||
Blob = 0xfc,
|
||||
VarString = 0xfd,
|
||||
String = 0xfe,
|
||||
Geometry = 0xff,
|
||||
}
|
||||
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_query_response_text_resultset_column_definition.html
|
||||
// https://mariadb.com/kb/en/resultset/#column-definition-packet
|
||||
// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ColumnDefinition {
|
||||
catalog: Bytes,
|
||||
schema: Bytes,
|
||||
table_alias: Bytes,
|
||||
table: Bytes,
|
||||
alias: Bytes,
|
||||
name: Bytes,
|
||||
pub(crate) char_set: u16,
|
||||
max_size: u32,
|
||||
pub(crate) r#type: ColumnType,
|
||||
pub(crate) flags: ColumnFlags,
|
||||
decimals: u8,
|
||||
}
|
||||
|
||||
impl ColumnDefinition {
|
||||
// NOTE: strings in-protocol are transmitted according to the client character set
|
||||
// as this is UTF-8, all these strings should be UTF-8
|
||||
|
||||
pub(crate) fn name(&self) -> Result<&str, Error> {
|
||||
from_utf8(&self.name).map_err(Error::protocol)
|
||||
}
|
||||
|
||||
pub(crate) fn alias(&self) -> Result<&str, Error> {
|
||||
from_utf8(&self.alias).map_err(Error::protocol)
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<'_, Capabilities> for ColumnDefinition {
|
||||
fn decode_with(mut buf: Bytes, _: Capabilities) -> Result<Self, Error> {
|
||||
let catalog = buf.get_bytes_lenenc();
|
||||
let schema = buf.get_bytes_lenenc();
|
||||
let table_alias = buf.get_bytes_lenenc();
|
||||
let table = buf.get_bytes_lenenc();
|
||||
let alias = buf.get_bytes_lenenc();
|
||||
let name = buf.get_bytes_lenenc();
|
||||
let _next_len = buf.get_uint_lenenc(); // always 0x0c
|
||||
let char_set = buf.get_u16_le();
|
||||
let max_size = buf.get_u32_le();
|
||||
let type_id = buf.get_u8();
|
||||
let flags = buf.get_u16_le();
|
||||
let decimals = buf.get_u8();
|
||||
|
||||
Ok(Self {
|
||||
catalog,
|
||||
schema,
|
||||
table_alias,
|
||||
table,
|
||||
alias,
|
||||
name,
|
||||
char_set,
|
||||
max_size,
|
||||
r#type: ColumnType::try_from_u16(type_id)?,
|
||||
flags: ColumnFlags::from_bits_truncate(flags),
|
||||
decimals,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ColumnType {
|
||||
pub(crate) fn name(self, char_set: u16) -> &'static str {
|
||||
let is_binary = char_set == 63;
|
||||
match self {
|
||||
ColumnType::Tiny => "TINYINT",
|
||||
ColumnType::Short => "SMALLINT",
|
||||
ColumnType::Long => "INT",
|
||||
ColumnType::Float => "FLOAT",
|
||||
ColumnType::Double => "DOUBLE",
|
||||
ColumnType::Null => "NULL",
|
||||
ColumnType::Timestamp => "TIMESTAMP",
|
||||
ColumnType::LongLong => "BIGINT",
|
||||
ColumnType::Int24 => "MEDIUMINT",
|
||||
ColumnType::Date => "DATE",
|
||||
ColumnType::Time => "TIME",
|
||||
ColumnType::Datetime => "DATETIME",
|
||||
ColumnType::Year => "YEAR",
|
||||
ColumnType::Bit => "BIT",
|
||||
ColumnType::Enum => "ENUM",
|
||||
ColumnType::Set => "SET",
|
||||
ColumnType::Decimal | ColumnType::NewDecimal => "DECIMAL",
|
||||
ColumnType::Geometry => "GEOMETRY",
|
||||
ColumnType::Json => "JSON",
|
||||
|
||||
ColumnType::String if is_binary => "BINARY",
|
||||
ColumnType::VarChar | ColumnType::VarString if is_binary => "VARBINARY",
|
||||
|
||||
ColumnType::String => "CHAR",
|
||||
ColumnType::VarChar | ColumnType::VarString => "VARCHAR",
|
||||
|
||||
ColumnType::TinyBlob if is_binary => "TINYBLOB",
|
||||
ColumnType::TinyBlob => "TINYTEXT",
|
||||
|
||||
ColumnType::Blob if is_binary => "BLOB",
|
||||
ColumnType::Blob => "TEXT",
|
||||
|
||||
ColumnType::MediumBlob if is_binary => "MEDIUMBLOB",
|
||||
ColumnType::MediumBlob => "MEDIUMTEXT",
|
||||
|
||||
ColumnType::LongBlob if is_binary => "LONGBLOB",
|
||||
ColumnType::LongBlob => "LONGTEXT",
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn try_from_u16(id: u8) -> Result<Self, Error> {
|
||||
Ok(match id {
|
||||
0x00 => ColumnType::Decimal,
|
||||
0x01 => ColumnType::Tiny,
|
||||
0x02 => ColumnType::Short,
|
||||
0x03 => ColumnType::Long,
|
||||
0x04 => ColumnType::Float,
|
||||
0x05 => ColumnType::Double,
|
||||
0x06 => ColumnType::Null,
|
||||
0x07 => ColumnType::Timestamp,
|
||||
0x08 => ColumnType::LongLong,
|
||||
0x09 => ColumnType::Int24,
|
||||
0x0a => ColumnType::Date,
|
||||
0x0b => ColumnType::Time,
|
||||
0x0c => ColumnType::Datetime,
|
||||
0x0d => ColumnType::Year,
|
||||
// [internal] 0x0e => ColumnType::NewDate,
|
||||
0x0f => ColumnType::VarChar,
|
||||
0x10 => ColumnType::Bit,
|
||||
// [internal] 0x11 => ColumnType::Timestamp2,
|
||||
// [internal] 0x12 => ColumnType::Datetime2,
|
||||
// [internal] 0x13 => ColumnType::Time2,
|
||||
0xf5 => ColumnType::Json,
|
||||
0xf6 => ColumnType::NewDecimal,
|
||||
0xf7 => ColumnType::Enum,
|
||||
0xf8 => ColumnType::Set,
|
||||
0xf9 => ColumnType::TinyBlob,
|
||||
0xfa => ColumnType::MediumBlob,
|
||||
0xfb => ColumnType::LongBlob,
|
||||
0xfc => ColumnType::Blob,
|
||||
0xfd => ColumnType::VarString,
|
||||
0xfe => ColumnType::String,
|
||||
0xff => ColumnType::Geometry,
|
||||
|
||||
_ => {
|
||||
return Err(err_protocol!("unknown column type 0x{:02x}", id));
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
11
sqlx-core/src/mysql/protocol/text/mod.rs
Normal file
11
sqlx-core/src/mysql/protocol/text/mod.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
mod column;
|
||||
mod ping;
|
||||
mod query;
|
||||
mod quit;
|
||||
mod row;
|
||||
|
||||
pub(crate) use column::{ColumnDefinition, ColumnFlags, ColumnType};
|
||||
pub(crate) use ping::Ping;
|
||||
pub(crate) use query::Query;
|
||||
pub(crate) use quit::Quit;
|
||||
pub(crate) use row::TextRow;
|
||||
13
sqlx-core/src/mysql/protocol/text/ping.rs
Normal file
13
sqlx-core/src/mysql/protocol/text/ping.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
use crate::io::Encode;
|
||||
use crate::mysql::protocol::Capabilities;
|
||||
|
||||
// https://dev.mysql.com/doc/internals/en/com-ping.html
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Ping;
|
||||
|
||||
impl Encode<'_, Capabilities> for Ping {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
|
||||
buf.push(0x0e); // COM_PING
|
||||
}
|
||||
}
|
||||
14
sqlx-core/src/mysql/protocol/text/query.rs
Normal file
14
sqlx-core/src/mysql/protocol/text/query.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
use crate::io::Encode;
|
||||
use crate::mysql::protocol::Capabilities;
|
||||
|
||||
// https://dev.mysql.com/doc/internals/en/com-query.html
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Query<'q>(pub(crate) &'q str);
|
||||
|
||||
impl Encode<'_, Capabilities> for Query<'_> {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
|
||||
buf.push(0x03); // COM_QUERY
|
||||
buf.extend(self.0.as_bytes())
|
||||
}
|
||||
}
|
||||
13
sqlx-core/src/mysql/protocol/text/quit.rs
Normal file
13
sqlx-core/src/mysql/protocol/text/quit.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
use crate::io::Encode;
|
||||
use crate::mysql::protocol::Capabilities;
|
||||
|
||||
// https://dev.mysql.com/doc/internals/en/com-quit.html
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Quit;
|
||||
|
||||
impl Encode<'_, Capabilities> for Quit {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
|
||||
buf.push(0x01); // COM_QUIT
|
||||
}
|
||||
}
|
||||
36
sqlx-core/src/mysql/protocol/text/row.rs
Normal file
36
sqlx-core/src/mysql/protocol/text/row.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
use bytes::{Buf, Bytes};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::Decode;
|
||||
use crate::mysql::io::MySqlBufExt;
|
||||
use crate::mysql::protocol::Row;
|
||||
use crate::mysql::row::MySqlColumn;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct TextRow(pub(crate) Row);
|
||||
|
||||
impl<'de> Decode<'de, &'de [MySqlColumn]> for TextRow {
|
||||
fn decode_with(mut buf: Bytes, columns: &'de [MySqlColumn]) -> Result<Self, Error> {
|
||||
let storage = buf.clone();
|
||||
let offset = buf.len();
|
||||
|
||||
let mut values = Vec::with_capacity(columns.len());
|
||||
|
||||
for _ in columns {
|
||||
if buf[0] == 0xfb {
|
||||
// NULL is sent as 0xfb
|
||||
values.push(None);
|
||||
buf.advance(1);
|
||||
} else {
|
||||
let size = buf.get_uint_lenenc() as usize;
|
||||
let offset = offset - buf.len();
|
||||
|
||||
values.push(Some(offset..(offset + size)));
|
||||
|
||||
buf.advance(size);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(TextRow(Row { values, storage }))
|
||||
}
|
||||
}
|
||||
@@ -1,48 +1 @@
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/binary__log__types_8h.html
|
||||
// https://mariadb.com/kb/en/library/resultset/#field-types
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct TypeId(pub u8);
|
||||
|
||||
// https://github.com/google/mysql/blob/c01fc2134d439282a21a2ddf687566e198ddee28/include/mysql_com.h#L429
|
||||
impl TypeId {
|
||||
pub const NULL: TypeId = TypeId(6);
|
||||
|
||||
// String: CHAR, VARCHAR, TEXT
|
||||
// Bytes: BINARY, VARBINARY, BLOB
|
||||
pub const CHAR: TypeId = TypeId(254); // or BINARY
|
||||
pub const VAR_CHAR: TypeId = TypeId(253); // or VAR_BINARY
|
||||
pub const TEXT: TypeId = TypeId(252); // or BLOB
|
||||
|
||||
// Enum
|
||||
pub const ENUM: TypeId = TypeId(247);
|
||||
|
||||
// More Bytes
|
||||
pub const TINY_BLOB: TypeId = TypeId(249);
|
||||
pub const MEDIUM_BLOB: TypeId = TypeId(250);
|
||||
pub const LONG_BLOB: TypeId = TypeId(251);
|
||||
|
||||
// Numeric: TINYINT, SMALLINT, INT, BIGINT
|
||||
pub const TINY_INT: TypeId = TypeId(1);
|
||||
pub const SMALL_INT: TypeId = TypeId(2);
|
||||
pub const INT: TypeId = TypeId(3);
|
||||
pub const BIG_INT: TypeId = TypeId(8);
|
||||
// pub const MEDIUM_INT: TypeId = TypeId(9);
|
||||
|
||||
// Numeric: FLOAT, DOUBLE
|
||||
pub const FLOAT: TypeId = TypeId(4);
|
||||
pub const DOUBLE: TypeId = TypeId(5);
|
||||
pub const NEWDECIMAL: TypeId = TypeId(246);
|
||||
|
||||
// Date/Time: DATE, TIME, DATETIME, TIMESTAMP
|
||||
pub const DATE: TypeId = TypeId(10);
|
||||
pub const TIME: TypeId = TypeId(11);
|
||||
pub const DATETIME: TypeId = TypeId(12);
|
||||
pub const TIMESTAMP: TypeId = TypeId(7);
|
||||
}
|
||||
|
||||
impl Default for TypeId {
|
||||
fn default() -> TypeId {
|
||||
TypeId::NULL
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,59 +1,60 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::mysql::protocol;
|
||||
use crate::mysql::{MySql, MySqlValue};
|
||||
use hashbrown::HashMap;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::ext::ustr::UStr;
|
||||
use crate::mysql::{protocol, MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef};
|
||||
use crate::row::{ColumnIndex, Row};
|
||||
|
||||
pub struct MySqlRow<'c> {
|
||||
pub(super) row: protocol::Row<'c>,
|
||||
pub(super) names: Arc<HashMap<Box<str>, u16>>,
|
||||
// TODO: Merge with the other XXColumn types
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct MySqlColumn {
|
||||
pub(crate) name: Option<UStr>,
|
||||
pub(crate) type_info: Option<MySqlTypeInfo>,
|
||||
}
|
||||
|
||||
impl crate::row::private_row::Sealed for MySqlRow<'_> {}
|
||||
/// Implementation of [`Row`] for MySQL.
|
||||
#[derive(Debug)]
|
||||
pub struct MySqlRow {
|
||||
pub(crate) row: protocol::Row,
|
||||
pub(crate) columns: Arc<Vec<MySqlColumn>>,
|
||||
pub(crate) column_names: Arc<HashMap<UStr, usize>>,
|
||||
pub(crate) format: MySqlValueFormat,
|
||||
}
|
||||
|
||||
impl<'c> Row<'c> for MySqlRow<'c> {
|
||||
impl crate::row::private_row::Sealed for MySqlRow {}
|
||||
|
||||
impl Row for MySqlRow {
|
||||
type Database = MySql;
|
||||
|
||||
#[inline]
|
||||
fn len(&self) -> usize {
|
||||
self.row.len()
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
fn try_get_raw<I>(&self, index: I) -> crate::Result<MySqlValue<'c>>
|
||||
fn try_get_raw<I>(&self, index: I) -> Result<MySqlValueRef, Error>
|
||||
where
|
||||
I: ColumnIndex<'c, Self>,
|
||||
I: ColumnIndex<Self>,
|
||||
{
|
||||
let index = index.index(self)?;
|
||||
let column_ty = self.row.columns[index].clone();
|
||||
let buffer = self.row.get(index);
|
||||
let value = match (self.row.binary, buffer) {
|
||||
(_, None) => MySqlValue::null(),
|
||||
(true, Some(buf)) => MySqlValue::binary(column_ty, buf),
|
||||
(false, Some(buf)) => MySqlValue::text(column_ty, buf),
|
||||
};
|
||||
let column = &self.columns[index];
|
||||
let value = self.row.get(index);
|
||||
|
||||
Ok(value)
|
||||
Ok(MySqlValueRef {
|
||||
format: self.format,
|
||||
row: Some(&self.row.storage),
|
||||
type_info: column.type_info.clone(),
|
||||
value,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'c> ColumnIndex<'c, MySqlRow<'c>> for usize {
|
||||
fn index(&self, row: &MySqlRow<'c>) -> crate::Result<usize> {
|
||||
let len = Row::len(row);
|
||||
|
||||
if *self >= len {
|
||||
return Err(crate::Error::ColumnIndexOutOfBounds { len, index: *self });
|
||||
}
|
||||
|
||||
Ok(*self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'c> ColumnIndex<'c, MySqlRow<'c>> for str {
|
||||
fn index(&self, row: &MySqlRow<'c>) -> crate::Result<usize> {
|
||||
row.names
|
||||
.get(self)
|
||||
.ok_or_else(|| crate::Error::ColumnNotFound((*self).into()))
|
||||
.map(|&index| index as usize)
|
||||
impl ColumnIndex<MySqlRow> for &'_ str {
|
||||
fn index(&self, row: &MySqlRow) -> Result<usize, Error> {
|
||||
row.column_names
|
||||
.get(*self)
|
||||
.ok_or_else(|| Error::ColumnNotFound((*self).into()))
|
||||
.map(|v| *v)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,228 +0,0 @@
|
||||
use std::net::Shutdown;
|
||||
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
|
||||
use crate::io::{Buf, BufMut, BufStream, MaybeTlsStream};
|
||||
use crate::mysql::protocol::{Capabilities, Encode, EofPacket, ErrPacket, OkPacket};
|
||||
|
||||
use crate::mysql::MySqlError;
|
||||
use crate::url::Url;
|
||||
|
||||
// Size before a packet is split
|
||||
const MAX_PACKET_SIZE: u32 = 1024;
|
||||
|
||||
pub(crate) struct MySqlStream {
|
||||
pub(super) stream: BufStream<MaybeTlsStream>,
|
||||
|
||||
// Is the stream ready to send commands
|
||||
// Put another way, are we still expecting an EOF or OK packet to terminate
|
||||
pub(super) is_ready: bool,
|
||||
|
||||
// Active capabilities
|
||||
pub(super) capabilities: Capabilities,
|
||||
|
||||
// Packets in a command sequence have an incrementing sequence number
|
||||
// This number must be 0 at the start of each command
|
||||
pub(super) seq_no: u8,
|
||||
|
||||
// Packets are buffered into a second buffer from the stream
|
||||
// as we may have compressed or split packets to figure out before
|
||||
// decoding
|
||||
packet_buf: Vec<u8>,
|
||||
packet_len: usize,
|
||||
}
|
||||
|
||||
impl MySqlStream {
|
||||
pub(super) async fn new(url: &Url) -> crate::Result<Self> {
|
||||
let host = url.host().unwrap_or("localhost");
|
||||
let port = url.port(3306);
|
||||
let stream = MaybeTlsStream::connect(host, port).await?;
|
||||
|
||||
let mut capabilities = Capabilities::PROTOCOL_41
|
||||
| Capabilities::IGNORE_SPACE
|
||||
| Capabilities::DEPRECATE_EOF
|
||||
| Capabilities::FOUND_ROWS
|
||||
| Capabilities::TRANSACTIONS
|
||||
| Capabilities::SECURE_CONNECTION
|
||||
| Capabilities::PLUGIN_AUTH_LENENC_DATA
|
||||
| Capabilities::MULTI_STATEMENTS
|
||||
| Capabilities::MULTI_RESULTS
|
||||
| Capabilities::PLUGIN_AUTH;
|
||||
|
||||
if url.database().is_some() {
|
||||
capabilities |= Capabilities::CONNECT_WITH_DB;
|
||||
}
|
||||
|
||||
if cfg!(feature = "tls") {
|
||||
capabilities |= Capabilities::SSL;
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
capabilities,
|
||||
stream: BufStream::new(stream),
|
||||
packet_buf: Vec::with_capacity(MAX_PACKET_SIZE as usize),
|
||||
packet_len: 0,
|
||||
seq_no: 0,
|
||||
is_ready: true,
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn is_tls(&self) -> bool {
|
||||
self.stream.is_tls()
|
||||
}
|
||||
|
||||
pub(super) fn shutdown(&self) -> crate::Result<()> {
|
||||
Ok(self.stream.shutdown(Shutdown::Both)?)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(super) async fn send<T>(&mut self, packet: T, initial: bool) -> crate::Result<()>
|
||||
where
|
||||
T: Encode + std::fmt::Debug,
|
||||
{
|
||||
if initial {
|
||||
self.seq_no = 0;
|
||||
}
|
||||
|
||||
self.write(packet);
|
||||
self.flush().await
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(super) async fn flush(&mut self) -> crate::Result<()> {
|
||||
Ok(self.stream.flush().await?)
|
||||
}
|
||||
|
||||
/// Write the packet to the buffered stream ( do not send to the server )
|
||||
pub(super) fn write<T>(&mut self, packet: T)
|
||||
where
|
||||
T: Encode,
|
||||
{
|
||||
let buf = self.stream.buffer_mut();
|
||||
|
||||
// Allocate room for the header that we write after the packet;
|
||||
// so, we can get an accurate and cheap measure of packet length
|
||||
|
||||
let header_offset = buf.len();
|
||||
buf.advance(4);
|
||||
|
||||
packet.encode(buf, self.capabilities);
|
||||
|
||||
// Determine length of encoded packet
|
||||
// and write to allocated header
|
||||
|
||||
let len = buf.len() - header_offset - 4;
|
||||
let mut header = &mut buf[header_offset..];
|
||||
|
||||
LittleEndian::write_u32(&mut header, len as u32);
|
||||
|
||||
// Take the last sequence number received, if any, and increment by 1
|
||||
// If there was no sequence number, we only increment if we split packets
|
||||
header[3] = self.seq_no;
|
||||
self.seq_no = self.seq_no.wrapping_add(1);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(super) async fn receive(&mut self) -> crate::Result<&[u8]> {
|
||||
self.read().await?;
|
||||
|
||||
Ok(self.packet())
|
||||
}
|
||||
|
||||
pub(super) async fn read(&mut self) -> crate::Result<()> {
|
||||
self.packet_buf.clear();
|
||||
self.packet_len = 0;
|
||||
|
||||
// Read the packet header which contains the length and the sequence number
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html
|
||||
// https://mariadb.com/kb/en/library/0-packet/#standard-packet
|
||||
let mut header = self.stream.peek(4_usize).await?;
|
||||
|
||||
self.packet_len = header.get_uint::<LittleEndian>(3)? as usize;
|
||||
self.seq_no = header.get_u8()?.wrapping_add(1);
|
||||
|
||||
self.stream.consume(4);
|
||||
|
||||
// Read the packet body and copy it into our internal buf
|
||||
// We must have a separate buffer around the stream as we can't operate directly
|
||||
// on bytes returned from the stream. We have various kinds of payload manipulation
|
||||
// that must be handled before decoding.
|
||||
let payload = self.stream.peek(self.packet_len).await?;
|
||||
|
||||
self.packet_buf.reserve(payload.len());
|
||||
self.packet_buf.extend_from_slice(payload);
|
||||
|
||||
self.stream.consume(self.packet_len);
|
||||
|
||||
// TODO: Implement packet compression
|
||||
// TODO: Implement packet joining
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns a reference to the most recently received packet data.
|
||||
/// A call to `read` invalidates this buffer.
|
||||
#[inline]
|
||||
pub(super) fn packet(&self) -> &[u8] {
|
||||
&self.packet_buf[..self.packet_len]
|
||||
}
|
||||
}
|
||||
|
||||
impl MySqlStream {
|
||||
pub(crate) async fn maybe_receive_eof(&mut self) -> crate::Result<()> {
|
||||
if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
|
||||
let _eof = EofPacket::read(self.receive().await?)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn maybe_handle_eof(&mut self) -> crate::Result<Option<EofPacket>> {
|
||||
if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) && self.packet()[0] == 0xFE {
|
||||
Ok(Some(EofPacket::read(self.packet())?))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn handle_unexpected<T>(&mut self) -> crate::Result<T> {
|
||||
Err(protocol_err!("unexpected packet identifier 0x{:X?}", self.packet()[0]).into())
|
||||
}
|
||||
|
||||
pub(crate) fn handle_err<T>(&mut self) -> crate::Result<T> {
|
||||
self.is_ready = true;
|
||||
Err(MySqlError(ErrPacket::read(self.packet(), self.capabilities)?).into())
|
||||
}
|
||||
|
||||
pub(crate) fn handle_ok(&mut self) -> crate::Result<OkPacket> {
|
||||
self.is_ready = true;
|
||||
OkPacket::read(self.packet())
|
||||
}
|
||||
|
||||
pub(crate) async fn wait_until_ready(&mut self) -> crate::Result<()> {
|
||||
if !self.is_ready {
|
||||
loop {
|
||||
let packet_id = self.receive().await?[0];
|
||||
match packet_id {
|
||||
0xFE if self.packet().len() < 0xFF_FF_FF => {
|
||||
// OK or EOF packet
|
||||
self.is_ready = true;
|
||||
break;
|
||||
}
|
||||
|
||||
0xFF => {
|
||||
// ERR packet
|
||||
self.is_ready = true;
|
||||
return self.handle_err();
|
||||
}
|
||||
|
||||
_ => {
|
||||
// Something else; skip
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,123 +0,0 @@
|
||||
use crate::mysql::stream::MySqlStream;
|
||||
use crate::url::Url;
|
||||
|
||||
#[cfg_attr(not(feature = "tls"), allow(unused_variables))]
|
||||
pub(super) async fn upgrade_if_needed(stream: &mut MySqlStream, url: &Url) -> crate::Result<()> {
|
||||
#[cfg_attr(not(feature = "tls"), allow(unused_imports))]
|
||||
use crate::mysql::protocol::Capabilities;
|
||||
|
||||
let ca_file = url.param("ssl-ca");
|
||||
let ssl_mode = url.param("ssl-mode");
|
||||
|
||||
// https://dev.mysql.com/doc/refman/5.7/en/connection-options.html#option_general_ssl-mode
|
||||
match ssl_mode.as_deref() {
|
||||
Some("DISABLED") => {}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
Some("PREFERRED") | None if !stream.capabilities.contains(Capabilities::SSL) => {}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
Some("PREFERRED") => {
|
||||
if let Err(_error) = try_upgrade(stream, &url, None, true).await {
|
||||
// TLS upgrade failed; fall back to a normal connection
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
None => {
|
||||
if let Err(_error) = try_upgrade(stream, &url, ca_file.as_deref(), true).await {
|
||||
// TLS upgrade failed; fall back to a normal connection
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
Some("REQUIRED") | Some("VERIFY_CA") | Some("VERIFY_IDENTITY")
|
||||
if !stream.capabilities.contains(Capabilities::SSL) =>
|
||||
{
|
||||
return Err(tls_err!("server does not support TLS").into());
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
Some("VERIFY_CA") | Some("VERIFY_IDENTITY") if ca_file.is_none() => {
|
||||
return Err(
|
||||
tls_err!("`ssl-mode` of {:?} requires `ssl-ca` to be set", ssl_mode).into(),
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
Some(mode @ "REQUIRED") | Some(mode @ "VERIFY_CA") | Some(mode @ "VERIFY_IDENTITY") => {
|
||||
try_upgrade(
|
||||
stream,
|
||||
url,
|
||||
// false for both verify-ca and verify-full
|
||||
ca_file.as_deref(),
|
||||
// false for only verify-full
|
||||
mode != "VERIFY_IDENTITY",
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "tls"))]
|
||||
None => {
|
||||
// The user neither explicitly enabled TLS in the connection string
|
||||
// nor did they turn the `tls` feature on
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "tls"))]
|
||||
Some(mode @ "PREFERRED")
|
||||
| Some(mode @ "REQUIRED")
|
||||
| Some(mode @ "VERIFY_CA")
|
||||
| Some(mode @ "VERIFY_IDENTITY") => {
|
||||
return Err(tls_err!(
|
||||
"ssl-mode {:?} unsupported; SQLx was compiled without `tls` feature",
|
||||
mode
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
Some(mode) => {
|
||||
return Err(tls_err!("unknown `ssl-mode` value: {:?}", mode).into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
async fn try_upgrade(
|
||||
stream: &mut MySqlStream,
|
||||
url: &Url,
|
||||
ca_file: Option<&str>,
|
||||
accept_invalid_hostnames: bool,
|
||||
) -> crate::Result<()> {
|
||||
use crate::mysql::protocol::SslRequest;
|
||||
use crate::runtime::fs;
|
||||
|
||||
use async_native_tls::{Certificate, TlsConnector};
|
||||
|
||||
let mut connector = TlsConnector::new()
|
||||
.danger_accept_invalid_certs(ca_file.is_none())
|
||||
.danger_accept_invalid_hostnames(accept_invalid_hostnames);
|
||||
|
||||
if let Some(ca_file) = ca_file {
|
||||
let root_cert = fs::read(ca_file).await?;
|
||||
|
||||
connector = connector.add_root_certificate(Certificate::from_pem(&root_cert)?);
|
||||
}
|
||||
|
||||
// send upgrade request and then immediately try TLS handshake
|
||||
stream
|
||||
.send(
|
||||
SslRequest {
|
||||
client_collation: super::connection::COLLATE_UTF8MB4_UNICODE_CI,
|
||||
max_packet_size: super::connection::MAX_PACKET_SIZE,
|
||||
},
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
|
||||
stream
|
||||
.stream
|
||||
.upgrade(url.host().unwrap_or("localhost"), connector)
|
||||
.await
|
||||
}
|
||||
@@ -1,144 +1,67 @@
|
||||
use std::fmt::{self, Display};
|
||||
use std::fmt::{self, Display, Formatter};
|
||||
|
||||
use crate::mysql::protocol::{ColumnDefinition, FieldFlags, TypeId};
|
||||
use crate::types::TypeInfo;
|
||||
use crate::mysql::protocol::text::{ColumnDefinition, ColumnFlags, ColumnType};
|
||||
use crate::type_info::TypeInfo;
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
/// Type information for a MySql type.
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct MySqlTypeInfo {
|
||||
pub(crate) id: TypeId,
|
||||
pub(crate) is_unsigned: bool,
|
||||
pub(crate) is_binary: bool,
|
||||
pub(crate) r#type: ColumnType,
|
||||
pub(crate) flags: ColumnFlags,
|
||||
pub(crate) char_set: u16,
|
||||
}
|
||||
|
||||
impl MySqlTypeInfo {
|
||||
pub(crate) const fn new(id: TypeId) -> Self {
|
||||
pub(crate) const fn binary(ty: ColumnType) -> Self {
|
||||
Self {
|
||||
id,
|
||||
is_unsigned: false,
|
||||
is_binary: true,
|
||||
char_set: 0,
|
||||
r#type: ty,
|
||||
flags: ColumnFlags::BINARY,
|
||||
char_set: 63,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) const fn unsigned(id: TypeId) -> Self {
|
||||
Self {
|
||||
id,
|
||||
is_unsigned: true,
|
||||
is_binary: false,
|
||||
char_set: 0,
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub const fn r#enum() -> Self {
|
||||
Self {
|
||||
id: TypeId::ENUM,
|
||||
is_unsigned: false,
|
||||
is_binary: false,
|
||||
char_set: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn from_nullable_column_def(def: &ColumnDefinition) -> Self {
|
||||
Self {
|
||||
id: def.type_id,
|
||||
is_unsigned: def.flags.contains(FieldFlags::UNSIGNED),
|
||||
is_binary: def.flags.contains(FieldFlags::BINARY),
|
||||
char_set: def.char_set,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn from_column_def(def: &ColumnDefinition) -> Option<Self> {
|
||||
if def.type_id == TypeId::NULL {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(Self::from_nullable_column_def(def))
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub fn type_feature_gate(&self) -> Option<&'static str> {
|
||||
match self.id {
|
||||
TypeId::DATE | TypeId::TIME | TypeId::DATETIME | TypeId::TIMESTAMP => Some("chrono"),
|
||||
_ => None,
|
||||
pub(crate) fn from_column(column: &ColumnDefinition) -> Option<Self> {
|
||||
if column.r#type == ColumnType::Null {
|
||||
None
|
||||
} else {
|
||||
Some(Self {
|
||||
r#type: column.r#type,
|
||||
flags: column.flags,
|
||||
char_set: column.char_set,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for MySqlTypeInfo {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self.id {
|
||||
TypeId::NULL => f.write_str("NULL"),
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(self.r#type.name(self.char_set))?;
|
||||
|
||||
TypeId::TINY_INT if self.is_unsigned => f.write_str("TINYINT UNSIGNED"),
|
||||
TypeId::SMALL_INT if self.is_unsigned => f.write_str("SMALLINT UNSIGNED"),
|
||||
TypeId::INT if self.is_unsigned => f.write_str("INT UNSIGNED"),
|
||||
TypeId::BIG_INT if self.is_unsigned => f.write_str("BIGINT UNSIGNED"),
|
||||
|
||||
TypeId::TINY_INT => f.write_str("TINYINT"),
|
||||
TypeId::SMALL_INT => f.write_str("SMALLINT"),
|
||||
TypeId::INT => f.write_str("INT"),
|
||||
TypeId::BIG_INT => f.write_str("BIGINT"),
|
||||
|
||||
TypeId::FLOAT => f.write_str("FLOAT"),
|
||||
TypeId::DOUBLE => f.write_str("DOUBLE"),
|
||||
|
||||
TypeId::CHAR if self.is_binary => f.write_str("BINARY"),
|
||||
TypeId::VAR_CHAR if self.is_binary => f.write_str("VARBINARY"),
|
||||
TypeId::TEXT if self.is_binary => f.write_str("BLOB"),
|
||||
|
||||
TypeId::CHAR => f.write_str("CHAR"),
|
||||
TypeId::VAR_CHAR => f.write_str("VARCHAR"),
|
||||
TypeId::TEXT => f.write_str("TEXT"),
|
||||
|
||||
TypeId::DATE => f.write_str("DATE"),
|
||||
TypeId::TIME => f.write_str("TIME"),
|
||||
TypeId::DATETIME => f.write_str("DATETIME"),
|
||||
TypeId::TIMESTAMP => f.write_str("TIMESTAMP"),
|
||||
|
||||
id => write!(f, "<{:#x}>", id.0),
|
||||
if self.flags.contains(ColumnFlags::UNSIGNED) {
|
||||
f.write_str(" UNSIGNED")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl TypeInfo for MySqlTypeInfo {}
|
||||
|
||||
impl PartialEq<MySqlTypeInfo> for MySqlTypeInfo {
|
||||
fn eq(&self, other: &MySqlTypeInfo) -> bool {
|
||||
match self.id {
|
||||
TypeId::VAR_CHAR
|
||||
| TypeId::TEXT
|
||||
| TypeId::CHAR
|
||||
| TypeId::TINY_BLOB
|
||||
| TypeId::MEDIUM_BLOB
|
||||
| TypeId::LONG_BLOB
|
||||
| TypeId::ENUM
|
||||
if (self.is_binary == other.is_binary)
|
||||
&& match other.id {
|
||||
TypeId::VAR_CHAR
|
||||
| TypeId::TEXT
|
||||
| TypeId::CHAR
|
||||
| TypeId::TINY_BLOB
|
||||
| TypeId::MEDIUM_BLOB
|
||||
| TypeId::LONG_BLOB
|
||||
| TypeId::ENUM => true,
|
||||
|
||||
_ => false,
|
||||
} =>
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
|
||||
if self.id.0 != other.id.0 {
|
||||
if self.r#type != other.r#type {
|
||||
return false;
|
||||
}
|
||||
|
||||
match self.id {
|
||||
TypeId::TINY_INT | TypeId::SMALL_INT | TypeId::INT | TypeId::BIG_INT => {
|
||||
return self.is_unsigned == other.is_unsigned;
|
||||
match self.r#type {
|
||||
ColumnType::Tiny
|
||||
| ColumnType::Short
|
||||
| ColumnType::Long
|
||||
| ColumnType::Int24
|
||||
| ColumnType::LongLong => {
|
||||
return self.flags.contains(ColumnFlags::UNSIGNED)
|
||||
== other.flags.contains(ColumnFlags::UNSIGNED);
|
||||
}
|
||||
|
||||
_ => {}
|
||||
@@ -148,103 +71,4 @@ impl PartialEq<MySqlTypeInfo> for MySqlTypeInfo {
|
||||
}
|
||||
}
|
||||
|
||||
impl TypeInfo for MySqlTypeInfo {
|
||||
fn compatible(&self, other: &Self) -> bool {
|
||||
// NOTE: MySQL is weakly typed so much of this may be surprising to a Rust developer.
|
||||
|
||||
if self.id == TypeId::NULL || other.id == TypeId::NULL {
|
||||
// NULL is the "bottom" type
|
||||
// If the user is trying to select into a non-Option, we catch this soon with an
|
||||
// UnexpectedNull error message
|
||||
return true;
|
||||
}
|
||||
|
||||
match self.id {
|
||||
// All integer types should be considered compatible
|
||||
TypeId::TINY_INT | TypeId::SMALL_INT | TypeId::INT | TypeId::BIG_INT
|
||||
if (self.is_unsigned == other.is_unsigned)
|
||||
&& match other.id {
|
||||
TypeId::TINY_INT | TypeId::SMALL_INT | TypeId::INT | TypeId::BIG_INT => {
|
||||
true
|
||||
}
|
||||
|
||||
_ => false,
|
||||
} =>
|
||||
{
|
||||
true
|
||||
}
|
||||
|
||||
// All textual types should be considered compatible
|
||||
TypeId::VAR_CHAR
|
||||
| TypeId::TEXT
|
||||
| TypeId::CHAR
|
||||
| TypeId::TINY_BLOB
|
||||
| TypeId::MEDIUM_BLOB
|
||||
| TypeId::LONG_BLOB
|
||||
if match other.id {
|
||||
TypeId::VAR_CHAR
|
||||
| TypeId::TEXT
|
||||
| TypeId::CHAR
|
||||
| TypeId::TINY_BLOB
|
||||
| TypeId::MEDIUM_BLOB
|
||||
| TypeId::LONG_BLOB => true,
|
||||
|
||||
_ => false,
|
||||
} =>
|
||||
{
|
||||
true
|
||||
}
|
||||
|
||||
// Enums are considered compatible with other text/binary types
|
||||
TypeId::ENUM
|
||||
if match other.id {
|
||||
TypeId::VAR_CHAR
|
||||
| TypeId::TEXT
|
||||
| TypeId::CHAR
|
||||
| TypeId::TINY_BLOB
|
||||
| TypeId::MEDIUM_BLOB
|
||||
| TypeId::LONG_BLOB
|
||||
| TypeId::ENUM => true,
|
||||
|
||||
_ => false,
|
||||
} =>
|
||||
{
|
||||
true
|
||||
}
|
||||
|
||||
TypeId::VAR_CHAR
|
||||
| TypeId::TEXT
|
||||
| TypeId::CHAR
|
||||
| TypeId::TINY_BLOB
|
||||
| TypeId::MEDIUM_BLOB
|
||||
| TypeId::LONG_BLOB
|
||||
| TypeId::ENUM
|
||||
if other.id == TypeId::ENUM =>
|
||||
{
|
||||
true
|
||||
}
|
||||
|
||||
// FLOAT is compatible with DOUBLE
|
||||
TypeId::FLOAT | TypeId::DOUBLE
|
||||
if match other.id {
|
||||
TypeId::FLOAT | TypeId::DOUBLE => true,
|
||||
_ => false,
|
||||
} =>
|
||||
{
|
||||
true
|
||||
}
|
||||
|
||||
// DATETIME is compatible with TIMESTAMP
|
||||
TypeId::DATETIME | TypeId::TIMESTAMP
|
||||
if match other.id {
|
||||
TypeId::DATETIME | TypeId::TIMESTAMP => true,
|
||||
_ => false,
|
||||
} =>
|
||||
{
|
||||
true
|
||||
}
|
||||
|
||||
_ => self.eq(other),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl Eq for MySqlTypeInfo {}
|
||||
|
||||
@@ -1,92 +1,30 @@
|
||||
use bigdecimal::BigDecimal;
|
||||
|
||||
use crate::database::{Database, HasArguments};
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::Encode;
|
||||
use crate::io::Buf;
|
||||
use crate::mysql::protocol::TypeId;
|
||||
use crate::mysql::{MySql, MySqlData, MySqlTypeInfo, MySqlValue};
|
||||
use crate::encode::{Encode, IsNull};
|
||||
use crate::error::BoxDynError;
|
||||
use crate::mysql::io::MySqlBufMutExt;
|
||||
use crate::mysql::protocol::text::{ColumnFlags, ColumnType};
|
||||
use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef};
|
||||
use crate::types::Type;
|
||||
use crate::Error;
|
||||
use std::str::FromStr;
|
||||
|
||||
impl Type<MySql> for BigDecimal {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::NEWDECIMAL)
|
||||
MySqlTypeInfo::binary(ColumnType::NewDecimal)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for BigDecimal {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
let size = Encode::<MySql>::size_hint(self) - 1;
|
||||
assert!(size <= std::u8::MAX as usize, "Too large size");
|
||||
buf.push(size as u8);
|
||||
let s = self.to_string();
|
||||
buf.extend_from_slice(s.as_bytes());
|
||||
}
|
||||
impl Encode<'_, MySql> for BigDecimal {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
buf.put_str_lenenc(&self.to_string());
|
||||
|
||||
fn size_hint(&self) -> usize {
|
||||
let s = self.to_string();
|
||||
s.as_bytes().len() + 1
|
||||
IsNull::No
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<'_, MySql> for BigDecimal {
|
||||
fn decode(value: MySqlValue) -> crate::Result<Self> {
|
||||
match value.try_get()? {
|
||||
MySqlData::Binary(mut binary) => {
|
||||
let _len = binary.get_u8()?;
|
||||
let s = std::str::from_utf8(binary).map_err(Error::decode)?;
|
||||
Ok(BigDecimal::from_str(s).map_err(Error::decode)?)
|
||||
}
|
||||
MySqlData::Text(s) => {
|
||||
let s = std::str::from_utf8(s).map_err(Error::decode)?;
|
||||
Ok(BigDecimal::from_str(s).map_err(Error::decode)?)
|
||||
}
|
||||
}
|
||||
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
|
||||
Ok(value.as_str()?.parse()?)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_decimal() {
|
||||
let v: BigDecimal = BigDecimal::from_str("-1.05").unwrap();
|
||||
let mut buf: Vec<u8> = vec![];
|
||||
<BigDecimal as Encode<MySql>>::encode(&v, &mut buf);
|
||||
assert_eq!(buf, vec![0x05, b'-', b'1', b'.', b'0', b'5']);
|
||||
|
||||
let v: BigDecimal = BigDecimal::from_str("-105000").unwrap();
|
||||
let mut buf: Vec<u8> = vec![];
|
||||
<BigDecimal as Encode<MySql>>::encode(&v, &mut buf);
|
||||
assert_eq!(buf, vec![0x07, b'-', b'1', b'0', b'5', b'0', b'0', b'0']);
|
||||
|
||||
let v: BigDecimal = BigDecimal::from_str("0.00105").unwrap();
|
||||
let mut buf: Vec<u8> = vec![];
|
||||
<BigDecimal as Encode<MySql>>::encode(&v, &mut buf);
|
||||
assert_eq!(buf, vec![0x07, b'0', b'.', b'0', b'0', b'1', b'0', b'5']);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_decimal() {
|
||||
let buf: Vec<u8> = vec![0x05, b'-', b'1', b'.', b'0', b'5'];
|
||||
let v = <BigDecimal as Decode<'_, MySql>>::decode(MySqlValue::binary(
|
||||
MySqlTypeInfo::new(TypeId::NEWDECIMAL),
|
||||
buf.as_slice(),
|
||||
))
|
||||
.unwrap();
|
||||
assert_eq!(v.to_string(), "-1.05");
|
||||
|
||||
let buf: Vec<u8> = vec![0x04, b'0', b'.', b'0', b'5'];
|
||||
let v = <BigDecimal as Decode<'_, MySql>>::decode(MySqlValue::binary(
|
||||
MySqlTypeInfo::new(TypeId::NEWDECIMAL),
|
||||
buf.as_slice(),
|
||||
))
|
||||
.unwrap();
|
||||
assert_eq!(v.to_string(), "0.05");
|
||||
|
||||
let buf: Vec<u8> = vec![0x06, b'-', b'9', b'0', b'0', b'0', b'0'];
|
||||
let v = <BigDecimal as Decode<'_, MySql>>::decode(MySqlValue::binary(
|
||||
MySqlTypeInfo::new(TypeId::NEWDECIMAL),
|
||||
buf.as_slice(),
|
||||
))
|
||||
.unwrap();
|
||||
assert_eq!(v.to_string(), "-90000");
|
||||
}
|
||||
|
||||
@@ -1,34 +1,28 @@
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::Encode;
|
||||
use crate::mysql::protocol::TypeId;
|
||||
use crate::mysql::type_info::MySqlTypeInfo;
|
||||
use crate::mysql::{MySql, MySqlData, MySqlValue};
|
||||
use crate::encode::{Encode, IsNull};
|
||||
use crate::error::BoxDynError;
|
||||
use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef};
|
||||
use crate::types::Type;
|
||||
|
||||
impl Type<MySql> for bool {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::TINY_INT)
|
||||
// MySQL has no actual `BOOLEAN` type, the type is an alias of `TINYINT(1)`
|
||||
<i8 as Type<MySql>>::type_info()
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for bool {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
buf.push(*self as u8);
|
||||
impl Encode<'_, MySql> for bool {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
<i8 as Encode<MySql>>::encode(*self as i8, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for bool {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match value.try_get()? {
|
||||
MySqlData::Binary(buf) => Ok(buf.get(0).map(|&b| b != 0).unwrap_or_default()),
|
||||
impl Decode<'_, MySql> for bool {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
<i8 as Decode<MySql>>::accepts(ty)
|
||||
}
|
||||
|
||||
MySqlData::Text(b"0") => Ok(false),
|
||||
|
||||
MySqlData::Text(b"1") => Ok(true),
|
||||
|
||||
MySqlData::Text(s) => Err(crate::Error::Decode(
|
||||
format!("unexpected value {:?} for boolean", s).into(),
|
||||
)),
|
||||
}
|
||||
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
|
||||
Ok(<i8 as Decode<MySql>>::decode(value)? != 0)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,21 +1,42 @@
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::Encode;
|
||||
use crate::mysql::io::BufMutExt;
|
||||
use crate::mysql::protocol::TypeId;
|
||||
use crate::mysql::type_info::MySqlTypeInfo;
|
||||
use crate::mysql::{MySql, MySqlData, MySqlValue};
|
||||
use crate::encode::{Encode, IsNull};
|
||||
use crate::error::BoxDynError;
|
||||
use crate::mysql::io::MySqlBufMutExt;
|
||||
use crate::mysql::protocol::text::ColumnType;
|
||||
use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef};
|
||||
use crate::types::Type;
|
||||
|
||||
impl Type<MySql> for [u8] {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo {
|
||||
id: TypeId::TEXT,
|
||||
is_binary: true,
|
||||
is_unsigned: false,
|
||||
char_set: 63, // binary
|
||||
}
|
||||
MySqlTypeInfo::binary(ColumnType::Blob)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<'_, MySql> for &'_ [u8] {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
buf.put_bytes_lenenc(self);
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r> Decode<'r, MySql> for &'r [u8] {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
matches!(
|
||||
ty.r#type,
|
||||
ColumnType::VarChar
|
||||
| ColumnType::Blob
|
||||
| ColumnType::TinyBlob
|
||||
| ColumnType::MediumBlob
|
||||
| ColumnType::LongBlob
|
||||
| ColumnType::String
|
||||
| ColumnType::VarString
|
||||
| ColumnType::Enum
|
||||
)
|
||||
}
|
||||
|
||||
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
|
||||
value.as_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,30 +46,18 @@ impl Type<MySql> for Vec<u8> {
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for [u8] {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
buf.put_bytes_lenenc::<LittleEndian>(self);
|
||||
impl Encode<'_, MySql> for Vec<u8> {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
<&[u8] as Encode<MySql>>::encode(&**self, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for Vec<u8> {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
<[u8] as Encode<MySql>>::encode(self, buf);
|
||||
impl Decode<'_, MySql> for Vec<u8> {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
<&[u8] as Decode<MySql>>::accepts(ty)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for Vec<u8> {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match value.try_get()? {
|
||||
MySqlData::Binary(buf) | MySqlData::Text(buf) => Ok(buf.to_vec()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for &'de [u8] {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match value.try_get()? {
|
||||
MySqlData::Binary(buf) | MySqlData::Text(buf) => Ok(buf),
|
||||
}
|
||||
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
|
||||
<&[u8] as Decode<MySql>>::decode(value).map(ToOwned::to_owned)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,32 +1,35 @@
|
||||
use std::convert::TryFrom;
|
||||
use std::str::from_utf8;
|
||||
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
use bytes::Buf;
|
||||
use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc};
|
||||
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::Encode;
|
||||
use crate::io::{Buf, BufMut};
|
||||
use crate::mysql::protocol::TypeId;
|
||||
use crate::encode::{Encode, IsNull};
|
||||
use crate::error::{BoxDynError, Error};
|
||||
use crate::mysql::protocol::text::ColumnType;
|
||||
use crate::mysql::type_info::MySqlTypeInfo;
|
||||
use crate::mysql::{MySql, MySqlData, MySqlValue};
|
||||
use crate::mysql::{MySql, MySqlValue, MySqlValueFormat, MySqlValueRef};
|
||||
use crate::types::Type;
|
||||
use crate::Error;
|
||||
use std::str::from_utf8;
|
||||
|
||||
impl Type<MySql> for DateTime<Utc> {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::TIMESTAMP)
|
||||
MySqlTypeInfo::binary(ColumnType::Timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for DateTime<Utc> {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
Encode::<MySql>::encode(&self.naive_utc(), buf);
|
||||
impl Encode<'_, MySql> for DateTime<Utc> {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
Encode::<MySql>::encode(&self.naive_utc(), buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for DateTime<Utc> {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
impl<'r> Decode<'r, MySql> for DateTime<Utc> {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp)
|
||||
}
|
||||
|
||||
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
|
||||
let naive: NaiveDateTime = Decode::<MySql>::decode(value)?;
|
||||
|
||||
Ok(DateTime::from_utc(naive, Utc))
|
||||
@@ -35,12 +38,12 @@ impl<'de> Decode<'de, MySql> for DateTime<Utc> {
|
||||
|
||||
impl Type<MySql> for NaiveTime {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::TIME)
|
||||
MySqlTypeInfo::binary(ColumnType::Time)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for NaiveTime {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
impl Encode<'_, MySql> for NaiveTime {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
let len = Encode::<MySql>::size_hint(self) - 1;
|
||||
buf.push(len as u8);
|
||||
|
||||
@@ -49,9 +52,11 @@ impl Encode<MySql> for NaiveTime {
|
||||
|
||||
// "date on 4 bytes little-endian format" (?)
|
||||
// https://mariadb.com/kb/en/resultset-row/#teimstamp-binary-encoding
|
||||
buf.advance(4);
|
||||
buf.extend_from_slice(&[0_u8; 4]);
|
||||
|
||||
encode_time(self, len > 9, buf);
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> usize {
|
||||
@@ -65,27 +70,29 @@ impl Encode<MySql> for NaiveTime {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for NaiveTime {
|
||||
fn decode(buf: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match buf.try_get()? {
|
||||
MySqlData::Binary(mut buf) => {
|
||||
impl<'r> Decode<'r, MySql> for NaiveTime {
|
||||
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
|
||||
match value.format() {
|
||||
MySqlValueFormat::Binary => {
|
||||
let mut buf = value.as_bytes()?;
|
||||
|
||||
// data length, expecting 8 or 12 (fractional seconds)
|
||||
let len = buf.get_u8()?;
|
||||
let len = buf.get_u8();
|
||||
|
||||
// is negative : int<1>
|
||||
let is_negative = buf.get_u8()?;
|
||||
assert_eq!(is_negative, 0, "Negative dates/times are not supported");
|
||||
let is_negative = buf.get_u8();
|
||||
debug_assert_eq!(is_negative, 0, "Negative dates/times are not supported");
|
||||
|
||||
// "date on 4 bytes little-endian format" (?)
|
||||
// https://mariadb.com/kb/en/resultset-row/#timestamp-binary-encoding
|
||||
buf.advance(4);
|
||||
|
||||
decode_time(len - 5, buf)
|
||||
Ok(decode_time(len - 5, buf))
|
||||
}
|
||||
|
||||
MySqlData::Text(buf) => {
|
||||
let s = from_utf8(buf).map_err(Error::decode)?;
|
||||
NaiveTime::parse_from_str(s, "%H:%M:%S%.f").map_err(Error::decode)
|
||||
MySqlValueFormat::Text => {
|
||||
let s = value.as_str()?;
|
||||
NaiveTime::parse_from_str(s, "%H:%M:%S%.f").map_err(Into::into)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -93,15 +100,17 @@ impl<'de> Decode<'de, MySql> for NaiveTime {
|
||||
|
||||
impl Type<MySql> for NaiveDate {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::DATE)
|
||||
MySqlTypeInfo::binary(ColumnType::Date)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for NaiveDate {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
impl Encode<'_, MySql> for NaiveDate {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
buf.push(4);
|
||||
|
||||
encode_date(self, buf);
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> usize {
|
||||
@@ -109,14 +118,14 @@ impl Encode<MySql> for NaiveDate {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for NaiveDate {
|
||||
fn decode(buf: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match buf.try_get()? {
|
||||
MySqlData::Binary(buf) => Ok(decode_date(&buf[1..])),
|
||||
impl<'r> Decode<'r, MySql> for NaiveDate {
|
||||
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
|
||||
match value.format() {
|
||||
MySqlValueFormat::Binary => Ok(decode_date(&value.as_bytes()?[1..])),
|
||||
|
||||
MySqlData::Text(buf) => {
|
||||
let s = from_utf8(buf).map_err(Error::decode)?;
|
||||
NaiveDate::parse_from_str(s, "%Y-%m-%d").map_err(Error::decode)
|
||||
MySqlValueFormat::Text => {
|
||||
let s = value.as_str()?;
|
||||
NaiveDate::parse_from_str(s, "%Y-%m-%d").map_err(Into::into)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -124,12 +133,12 @@ impl<'de> Decode<'de, MySql> for NaiveDate {
|
||||
|
||||
impl Type<MySql> for NaiveDateTime {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::DATETIME)
|
||||
MySqlTypeInfo::binary(ColumnType::Datetime)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for NaiveDateTime {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
impl Encode<'_, MySql> for NaiveDateTime {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
let len = Encode::<MySql>::size_hint(self) - 1;
|
||||
buf.push(len as u8);
|
||||
|
||||
@@ -138,6 +147,8 @@ impl Encode<MySql> for NaiveDateTime {
|
||||
if len > 4 {
|
||||
encode_time(&self.time(), len > 8, buf);
|
||||
}
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> usize {
|
||||
@@ -162,15 +173,21 @@ impl Encode<MySql> for NaiveDateTime {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for NaiveDateTime {
|
||||
fn decode(buf: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match buf.try_get()? {
|
||||
MySqlData::Binary(buf) => {
|
||||
impl<'r> Decode<'r, MySql> for NaiveDateTime {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp)
|
||||
}
|
||||
|
||||
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
|
||||
match value.format() {
|
||||
MySqlValueFormat::Binary => {
|
||||
let mut buf = value.as_bytes()?;
|
||||
|
||||
let len = buf[0];
|
||||
let date = decode_date(&buf[1..]);
|
||||
|
||||
let dt = if len > 4 {
|
||||
date.and_time(decode_time(len - 4, &buf[5..])?)
|
||||
date.and_time(decode_time(len - 4, &buf[5..]))
|
||||
} else {
|
||||
date.and_hms(0, 0, 0)
|
||||
};
|
||||
@@ -178,9 +195,9 @@ impl<'de> Decode<'de, MySql> for NaiveDateTime {
|
||||
Ok(dt)
|
||||
}
|
||||
|
||||
MySqlData::Text(buf) => {
|
||||
let s = from_utf8(buf).map_err(Error::decode)?;
|
||||
NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f").map_err(Error::decode)
|
||||
MySqlValueFormat::Text => {
|
||||
let s = value.as_str()?;
|
||||
NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f").map_err(Into::into)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -196,12 +213,9 @@ fn encode_date(date: &NaiveDate, buf: &mut Vec<u8>) {
|
||||
buf.push(date.day() as u8);
|
||||
}
|
||||
|
||||
fn decode_date(buf: &[u8]) -> NaiveDate {
|
||||
NaiveDate::from_ymd(
|
||||
LittleEndian::read_u16(buf) as i32,
|
||||
buf[2] as u32,
|
||||
buf[3] as u32,
|
||||
)
|
||||
fn decode_date(mut buf: &[u8]) -> NaiveDate {
|
||||
let year = buf.get_u16_le();
|
||||
NaiveDate::from_ymd(year as i32, buf[0] as u32, buf[1] as u32)
|
||||
}
|
||||
|
||||
fn encode_time(time: &NaiveTime, include_micros: bool, buf: &mut Vec<u8>) {
|
||||
@@ -210,93 +224,21 @@ fn encode_time(time: &NaiveTime, include_micros: bool, buf: &mut Vec<u8>) {
|
||||
buf.push(time.second() as u8);
|
||||
|
||||
if include_micros {
|
||||
buf.put_u32::<LittleEndian>((time.nanosecond() / 1000) as u32);
|
||||
buf.extend(&((time.nanosecond() / 1000) as u32).to_le_bytes());
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_time(len: u8, mut buf: &[u8]) -> crate::Result<NaiveTime> {
|
||||
let hour = buf.get_u8()?;
|
||||
let minute = buf.get_u8()?;
|
||||
let seconds = buf.get_u8()?;
|
||||
fn decode_time(len: u8, mut buf: &[u8]) -> NaiveTime {
|
||||
let hour = buf.get_u8();
|
||||
let minute = buf.get_u8();
|
||||
let seconds = buf.get_u8();
|
||||
|
||||
let micros = if len > 3 {
|
||||
// microseconds : int<EOF>
|
||||
buf.get_uint::<LittleEndian>(buf.len())?
|
||||
buf.get_uint_le(buf.len())
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
Ok(NaiveTime::from_hms_micro(
|
||||
hour as u32,
|
||||
minute as u32,
|
||||
seconds as u32,
|
||||
micros as u32,
|
||||
))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_date_time() {
|
||||
let mut buf = Vec::new();
|
||||
|
||||
// test values from https://dev.mysql.com/doc/internals/en/binary-protocol-value.html
|
||||
let date1: NaiveDateTime = "2010-10-17T19:27:30.000001".parse().unwrap();
|
||||
Encode::<MySql>::encode(&date1, &mut buf);
|
||||
assert_eq!(*buf, [11, 218, 7, 10, 17, 19, 27, 30, 1, 0, 0, 0]);
|
||||
|
||||
buf.clear();
|
||||
|
||||
let date2: NaiveDateTime = "2010-10-17T19:27:30".parse().unwrap();
|
||||
Encode::<MySql>::encode(&date2, &mut buf);
|
||||
assert_eq!(*buf, [7, 218, 7, 10, 17, 19, 27, 30]);
|
||||
|
||||
buf.clear();
|
||||
|
||||
let date3: NaiveDateTime = "2010-10-17T00:00:00".parse().unwrap();
|
||||
Encode::<MySql>::encode(&date3, &mut buf);
|
||||
assert_eq!(*buf, [4, 218, 7, 10, 17]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_date_time() {
|
||||
// test values from https://dev.mysql.com/doc/internals/en/binary-protocol-value.html
|
||||
let buf = [11, 218, 7, 10, 17, 19, 27, 30, 1, 0, 0, 0];
|
||||
let date1 = <NaiveDateTime as Decode<MySql>>::decode(MySqlValue::binary(
|
||||
MySqlTypeInfo::default(),
|
||||
&buf,
|
||||
))
|
||||
.unwrap();
|
||||
assert_eq!(date1.to_string(), "2010-10-17 19:27:30.000001");
|
||||
|
||||
let buf = [7, 218, 7, 10, 17, 19, 27, 30];
|
||||
let date2 = <NaiveDateTime as Decode<MySql>>::decode(MySqlValue::binary(
|
||||
MySqlTypeInfo::default(),
|
||||
&buf,
|
||||
))
|
||||
.unwrap();
|
||||
assert_eq!(date2.to_string(), "2010-10-17 19:27:30");
|
||||
|
||||
let buf = [4, 218, 7, 10, 17];
|
||||
let date3 = <NaiveDateTime as Decode<MySql>>::decode(MySqlValue::binary(
|
||||
MySqlTypeInfo::default(),
|
||||
&buf,
|
||||
))
|
||||
.unwrap();
|
||||
assert_eq!(date3.to_string(), "2010-10-17 00:00:00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_date() {
|
||||
let mut buf = Vec::new();
|
||||
let date: NaiveDate = "2010-10-17".parse().unwrap();
|
||||
Encode::<MySql>::encode(&date, &mut buf);
|
||||
assert_eq!(*buf, [4, 218, 7, 10, 17]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_date() {
|
||||
let buf = [4, 218, 7, 10, 17];
|
||||
let date =
|
||||
<NaiveDate as Decode<MySql>>::decode(MySqlValue::binary(MySqlTypeInfo::default(), &buf))
|
||||
.unwrap();
|
||||
assert_eq!(date.to_string(), "2010-10-17");
|
||||
NaiveTime::from_hms_micro(hour as u32, minute as u32, seconds as u32, micros as u32)
|
||||
}
|
||||
|
||||
@@ -1,83 +1,77 @@
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::Encode;
|
||||
use crate::mysql::protocol::TypeId;
|
||||
use crate::mysql::type_info::MySqlTypeInfo;
|
||||
use crate::mysql::{MySql, MySqlData, MySqlValue};
|
||||
use crate::encode::{Encode, IsNull};
|
||||
use crate::error::BoxDynError;
|
||||
use crate::mysql::protocol::text::ColumnType;
|
||||
use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef};
|
||||
use crate::types::Type;
|
||||
use crate::Error;
|
||||
use std::str::from_utf8;
|
||||
|
||||
/// The equivalent MySQL type for `f32` is `FLOAT`.
|
||||
///
|
||||
/// ### Note
|
||||
/// While we added support for `f32` as `FLOAT` for completeness, we don't recommend using
|
||||
/// it for any real-life applications as it cannot precisely represent some fractional values,
|
||||
/// and may be implicitly widened to `DOUBLE` in some cases, resulting in a slightly different
|
||||
/// value:
|
||||
///
|
||||
/// ```rust
|
||||
/// // Widening changes the equivalent decimal value, these two expressions are not equal
|
||||
/// // (This is expected behavior for floating points and happens both in Rust and in MySQL)
|
||||
/// assert_ne!(10.2f32 as f64, 10.2f64);
|
||||
/// ```
|
||||
fn real_accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
matches!(ty.r#type, ColumnType::Float | ColumnType::Double)
|
||||
}
|
||||
|
||||
impl Type<MySql> for f32 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::FLOAT)
|
||||
MySqlTypeInfo::binary(ColumnType::Float)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for f32 {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
<i32 as Encode<MySql>>::encode(&(self.to_bits() as i32), buf);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for f32 {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match value.try_get()? {
|
||||
MySqlData::Binary(mut buf) => buf
|
||||
.read_i32::<LittleEndian>()
|
||||
.map_err(crate::Error::decode)
|
||||
.map(|value| f32::from_bits(value as u32)),
|
||||
|
||||
MySqlData::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The equivalent MySQL type for `f64` is `DOUBLE`.
|
||||
///
|
||||
/// Note that `DOUBLE` is a floating-point type and cannot represent some fractional values
|
||||
/// exactly.
|
||||
impl Type<MySql> for f64 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::DOUBLE)
|
||||
MySqlTypeInfo::binary(ColumnType::Double)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for f64 {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
<i64 as Encode<MySql>>::encode(&(self.to_bits() as i64), buf);
|
||||
impl Encode<'_, MySql> for f32 {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
buf.extend(&self.to_le_bytes());
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for f64 {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match value.try_get()? {
|
||||
MySqlData::Binary(mut buf) => buf
|
||||
.read_i64::<LittleEndian>()
|
||||
.map_err(crate::Error::decode)
|
||||
.map(|value| f64::from_bits(value as u64)),
|
||||
impl Encode<'_, MySql> for f64 {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
buf.extend(&self.to_le_bytes());
|
||||
|
||||
MySqlData::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
IsNull::No
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<'_, MySql> for f32 {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
real_accepts(ty)
|
||||
}
|
||||
|
||||
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
|
||||
Ok(match value.format() {
|
||||
MySqlValueFormat::Binary => {
|
||||
let buf = value.as_bytes()?;
|
||||
|
||||
if buf.len() == 8 {
|
||||
// MySQL can return 8-byte DOUBLE values for a FLOAT
|
||||
// We take and truncate to f32 as that's the same behavior as *in* MySQL
|
||||
LittleEndian::read_f64(buf) as f32
|
||||
} else {
|
||||
LittleEndian::read_f32(buf)
|
||||
}
|
||||
}
|
||||
|
||||
MySqlValueFormat::Text => value.as_str()?.parse()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<'_, MySql> for f64 {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
real_accepts(ty)
|
||||
}
|
||||
|
||||
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
|
||||
Ok(match value.format() {
|
||||
MySqlValueFormat::Binary => LittleEndian::read_f64(value.as_bytes()?),
|
||||
MySqlValueFormat::Text => value.as_str()?.parse()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,111 +1,127 @@
|
||||
use std::str::from_utf8;
|
||||
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::Encode;
|
||||
use crate::mysql::protocol::TypeId;
|
||||
use crate::mysql::type_info::MySqlTypeInfo;
|
||||
use crate::mysql::{MySql, MySqlData, MySqlValue};
|
||||
use crate::encode::{Encode, IsNull};
|
||||
use crate::error::BoxDynError;
|
||||
use crate::mysql::protocol::text::{ColumnFlags, ColumnType};
|
||||
use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef};
|
||||
use crate::types::Type;
|
||||
use crate::Error;
|
||||
|
||||
fn int_accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
matches!(
|
||||
ty.r#type,
|
||||
ColumnType::Tiny
|
||||
| ColumnType::Short
|
||||
| ColumnType::Long
|
||||
| ColumnType::Int24
|
||||
| ColumnType::LongLong
|
||||
) && !ty.flags.contains(ColumnFlags::UNSIGNED)
|
||||
}
|
||||
|
||||
impl Type<MySql> for i8 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::TINY_INT)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for i8 {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
let _ = buf.write_i8(*self);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for i8 {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match value.try_get()? {
|
||||
MySqlData::Binary(mut buf) => buf.read_i8().map_err(Into::into),
|
||||
|
||||
MySqlData::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
MySqlTypeInfo::binary(ColumnType::Tiny)
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<MySql> for i16 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::SMALL_INT)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for i16 {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
let _ = buf.write_i16::<LittleEndian>(*self);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for i16 {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match value.try_get()? {
|
||||
MySqlData::Binary(mut buf) => buf.read_i16::<LittleEndian>().map_err(Into::into),
|
||||
|
||||
MySqlData::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
MySqlTypeInfo::binary(ColumnType::Short)
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<MySql> for i32 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::INT)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for i32 {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
let _ = buf.write_i32::<LittleEndian>(*self);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for i32 {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match value.try_get()? {
|
||||
MySqlData::Binary(mut buf) => buf.read_i32::<LittleEndian>().map_err(Into::into),
|
||||
|
||||
MySqlData::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
MySqlTypeInfo::binary(ColumnType::Long)
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<MySql> for i64 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::BIG_INT)
|
||||
MySqlTypeInfo::binary(ColumnType::LongLong)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for i64 {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
let _ = buf.write_i64::<LittleEndian>(*self);
|
||||
impl Encode<'_, MySql> for i8 {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
buf.extend(&self.to_le_bytes());
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for i64 {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match value.try_get()? {
|
||||
MySqlData::Binary(mut buf) => buf.read_i64::<LittleEndian>().map_err(Into::into),
|
||||
impl Encode<'_, MySql> for i16 {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
buf.extend(&self.to_le_bytes());
|
||||
|
||||
MySqlData::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
IsNull::No
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<'_, MySql> for i32 {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
buf.extend(&self.to_le_bytes());
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<'_, MySql> for i64 {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
buf.extend(&self.to_le_bytes());
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<'_, MySql> for i8 {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
int_accepts(ty)
|
||||
}
|
||||
|
||||
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
|
||||
Ok(match value.format() {
|
||||
MySqlValueFormat::Binary => value.as_bytes()?[0] as i8,
|
||||
MySqlValueFormat::Text => value.as_str()?.parse()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<'_, MySql> for i16 {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
int_accepts(ty)
|
||||
}
|
||||
|
||||
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
|
||||
Ok(match value.format() {
|
||||
MySqlValueFormat::Binary => LittleEndian::read_i16(value.as_bytes()?),
|
||||
MySqlValueFormat::Text => value.as_str()?.parse()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<'_, MySql> for i32 {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
int_accepts(ty)
|
||||
}
|
||||
|
||||
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
|
||||
Ok(match value.format() {
|
||||
MySqlValueFormat::Binary => LittleEndian::read_i32(value.as_bytes()?),
|
||||
MySqlValueFormat::Text => value.as_str()?.parse()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<'_, MySql> for i64 {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
int_accepts(ty)
|
||||
}
|
||||
|
||||
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
|
||||
Ok(match value.format() {
|
||||
MySqlValueFormat::Binary => LittleEndian::read_i64(value.as_bytes()?),
|
||||
MySqlValueFormat::Text => value.as_str()?.parse()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,46 +1,47 @@
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::Encode;
|
||||
use crate::mysql::database::MySql;
|
||||
use crate::mysql::protocol::TypeId;
|
||||
use crate::mysql::types::*;
|
||||
use crate::mysql::{MySqlTypeInfo, MySqlValue};
|
||||
use crate::types::{Json, Type};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value as JsonValue;
|
||||
|
||||
impl Type<MySql> for JsonValue {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
<Json<Self> as Type<MySql>>::type_info()
|
||||
}
|
||||
}
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::{Encode, IsNull};
|
||||
use crate::error::BoxDynError;
|
||||
use crate::mysql::protocol::text::ColumnType;
|
||||
use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef};
|
||||
use crate::types::{Json, Type};
|
||||
|
||||
impl<T> Type<MySql> for Json<T> {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
// MySql uses the CHAR type to pass JSON data from and to the client
|
||||
MySqlTypeInfo::new(TypeId::CHAR)
|
||||
// MySql uses the `CHAR` type to pass JSON data from and to the client
|
||||
// NOTE: This is forwards-compatible with MySQL v8+ as CHAR is a common transmission format
|
||||
// and has nothing to do with the native storage ability of MySQL v8+
|
||||
MySqlTypeInfo::binary(ColumnType::String)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Encode<MySql> for Json<T>
|
||||
impl<T> Encode<'_, MySql> for Json<T>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
let json_string_value =
|
||||
serde_json::to_string(&self.0).expect("serde_json failed to convert to string");
|
||||
<str as Encode<MySql>>::encode(json_string_value.as_str(), buf);
|
||||
|
||||
<&str as Encode<MySql>>::encode(json_string_value.as_str(), buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, T> Decode<'de, MySql> for Json<T>
|
||||
impl<'r, T> Decode<'r, MySql> for Json<T>
|
||||
where
|
||||
T: 'de,
|
||||
T: for<'de1> Deserialize<'de1>,
|
||||
T: 'r + DeserializeOwned,
|
||||
{
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
let string_value = <&'de str as Decode<MySql>>::decode(value).unwrap();
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
ty.r#type == ColumnType::Json || <&str as Decode<MySql>>::accepts(ty)
|
||||
}
|
||||
|
||||
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
|
||||
let string_value = <&str as Decode<MySql>>::decode(value)?;
|
||||
|
||||
serde_json::from_str(&string_value)
|
||||
.map(Json)
|
||||
.map_err(crate::Error::decode)
|
||||
.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
//!
|
||||
//! | Rust type | MySQL type(s) |
|
||||
//! |---------------------------------------|------------------------------------------------------|
|
||||
//! | `bool` | TINYINT(1) |
|
||||
//! | `bool` | TINYINT(1), BOOLEAN |
|
||||
//! | `i8` | TINYINT |
|
||||
//! | `i16` | SMALLINT |
|
||||
//! | `i32` | INT |
|
||||
@@ -47,6 +47,7 @@
|
||||
//! | Rust type | MySQL type(s) |
|
||||
//! |---------------------------------------|------------------------------------------------------|
|
||||
//! | `bigdecimal::BigDecimal` | DECIMAL |
|
||||
//!
|
||||
//! ### [`json`](https://crates.io/crates/json)
|
||||
//!
|
||||
//! Requires the `json` Cargo feature flag.
|
||||
@@ -79,19 +80,3 @@ mod time;
|
||||
|
||||
#[cfg(feature = "json")]
|
||||
mod json;
|
||||
|
||||
use crate::decode::Decode;
|
||||
use crate::mysql::{MySql, MySqlValue};
|
||||
|
||||
impl<'de, T> Decode<'de, MySql> for Option<T>
|
||||
where
|
||||
T: Decode<'de, MySql>,
|
||||
{
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
Ok(if value.get().is_some() {
|
||||
Some(<T as Decode<MySql>>::decode(value)?)
|
||||
} else {
|
||||
None
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,30 +1,46 @@
|
||||
use std::str;
|
||||
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::Encode;
|
||||
use crate::mysql::io::BufMutExt;
|
||||
use crate::mysql::protocol::TypeId;
|
||||
use crate::mysql::type_info::MySqlTypeInfo;
|
||||
use crate::mysql::{MySql, MySqlData, MySqlValue};
|
||||
use crate::encode::{Encode, IsNull};
|
||||
use crate::error::BoxDynError;
|
||||
use crate::mysql::io::MySqlBufMutExt;
|
||||
use crate::mysql::protocol::text::{ColumnFlags, ColumnType};
|
||||
use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef};
|
||||
use crate::types::Type;
|
||||
use std::str::from_utf8;
|
||||
|
||||
impl Type<MySql> for str {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo {
|
||||
id: TypeId::TEXT,
|
||||
is_binary: false,
|
||||
is_unsigned: false,
|
||||
char_set: 224, // utf8mb4_unicode_ci
|
||||
r#type: ColumnType::Blob, // TEXT
|
||||
char_set: 224, // utf8mb4_unicode_ci
|
||||
flags: ColumnFlags::empty(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for str {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
buf.put_str_lenenc::<LittleEndian>(self);
|
||||
impl Encode<'_, MySql> for &'_ str {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
buf.put_str_lenenc(self);
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r> Decode<'r, MySql> for &'r str {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
matches!(
|
||||
ty.r#type,
|
||||
ColumnType::VarChar
|
||||
| ColumnType::Blob
|
||||
| ColumnType::TinyBlob
|
||||
| ColumnType::MediumBlob
|
||||
| ColumnType::LongBlob
|
||||
| ColumnType::String
|
||||
| ColumnType::VarString
|
||||
| ColumnType::Enum
|
||||
)
|
||||
}
|
||||
|
||||
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
|
||||
value.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,24 +50,18 @@ impl Type<MySql> for String {
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for String {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
<str as Encode<MySql>>::encode(self.as_str(), buf)
|
||||
impl Encode<'_, MySql> for String {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
<&str as Encode<MySql>>::encode(&**self, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for &'de str {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match value.try_get()? {
|
||||
MySqlData::Binary(buf) | MySqlData::Text(buf) => {
|
||||
from_utf8(buf).map_err(crate::Error::decode)
|
||||
}
|
||||
}
|
||||
impl Decode<'_, MySql> for String {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
<&str as Decode<MySql>>::accepts(ty)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for String {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
<&'de str as Decode<MySql>>::decode(value).map(ToOwned::to_owned)
|
||||
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
|
||||
<&str as Decode<MySql>>::decode(value).map(ToOwned::to_owned)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,33 +2,38 @@ use std::borrow::Cow;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
use bytes::Buf;
|
||||
use time::{Date, OffsetDateTime, PrimitiveDateTime, Time, UtcOffset};
|
||||
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::Encode;
|
||||
use crate::io::{Buf, BufMut};
|
||||
use crate::mysql::protocol::TypeId;
|
||||
use crate::encode::{Encode, IsNull};
|
||||
use crate::error::BoxDynError;
|
||||
use crate::mysql::protocol::text::ColumnType;
|
||||
use crate::mysql::type_info::MySqlTypeInfo;
|
||||
use crate::mysql::{MySql, MySqlData, MySqlValue};
|
||||
use crate::mysql::{MySql, MySqlValueFormat, MySqlValueRef};
|
||||
use crate::types::Type;
|
||||
|
||||
impl Type<MySql> for OffsetDateTime {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::TIMESTAMP)
|
||||
MySqlTypeInfo::binary(ColumnType::Timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for OffsetDateTime {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
impl Encode<'_, MySql> for OffsetDateTime {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
let utc_dt = self.to_offset(UtcOffset::UTC);
|
||||
let primitive_dt = PrimitiveDateTime::new(utc_dt.date(), utc_dt.time());
|
||||
|
||||
Encode::<MySql>::encode(&primitive_dt, buf);
|
||||
Encode::<MySql>::encode(&primitive_dt, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for OffsetDateTime {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
impl<'r> Decode<'r, MySql> for OffsetDateTime {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp)
|
||||
}
|
||||
|
||||
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
|
||||
let primitive: PrimitiveDateTime = Decode::<MySql>::decode(value)?;
|
||||
|
||||
Ok(primitive.assume_utc())
|
||||
@@ -37,12 +42,12 @@ impl<'de> Decode<'de, MySql> for OffsetDateTime {
|
||||
|
||||
impl Type<MySql> for Time {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::TIME)
|
||||
MySqlTypeInfo::binary(ColumnType::Time)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for Time {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
impl Encode<'_, MySql> for Time {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
let len = Encode::<MySql>::size_hint(self) - 1;
|
||||
buf.push(len as u8);
|
||||
|
||||
@@ -51,9 +56,11 @@ impl Encode<MySql> for Time {
|
||||
|
||||
// "date on 4 bytes little-endian format" (?)
|
||||
// https://mariadb.com/kb/en/resultset-row/#teimstamp-binary-encoding
|
||||
buf.advance(4);
|
||||
buf.extend_from_slice(&[0_u8; 4]);
|
||||
|
||||
encode_time(self, len > 9, buf);
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> usize {
|
||||
@@ -67,15 +74,17 @@ impl Encode<MySql> for Time {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for Time {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match value.try_get()? {
|
||||
MySqlData::Binary(mut buf) => {
|
||||
impl<'r> Decode<'r, MySql> for Time {
|
||||
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
|
||||
match value.format() {
|
||||
MySqlValueFormat::Binary => {
|
||||
let mut buf = value.as_bytes()?;
|
||||
|
||||
// data length, expecting 8 or 12 (fractional seconds)
|
||||
let len = buf.get_u8()?;
|
||||
let len = buf.get_u8();
|
||||
|
||||
// is negative : int<1>
|
||||
let is_negative = buf.get_u8()?;
|
||||
let is_negative = buf.get_u8();
|
||||
assert_eq!(is_negative, 0, "Negative dates/times are not supported");
|
||||
|
||||
// "date on 4 bytes little-endian format" (?)
|
||||
@@ -85,8 +94,8 @@ impl<'de> Decode<'de, MySql> for Time {
|
||||
decode_time(len - 5, buf)
|
||||
}
|
||||
|
||||
MySqlData::Text(buf) => {
|
||||
let s = from_utf8(buf).map_err(crate::Error::decode)?;
|
||||
MySqlValueFormat::Text => {
|
||||
let s = value.as_str()?;
|
||||
|
||||
// If there are less than 9 digits after the decimal point
|
||||
// We need to zero-pad
|
||||
@@ -98,7 +107,7 @@ impl<'de> Decode<'de, MySql> for Time {
|
||||
Cow::Borrowed(s)
|
||||
};
|
||||
|
||||
Time::parse(&*s, "%H:%M:%S.%N").map_err(crate::Error::decode)
|
||||
Time::parse(&*s, "%H:%M:%S.%N").map_err(Into::into)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -106,15 +115,17 @@ impl<'de> Decode<'de, MySql> for Time {
|
||||
|
||||
impl Type<MySql> for Date {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::DATE)
|
||||
MySqlTypeInfo::binary(ColumnType::Date)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for Date {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
impl Encode<'_, MySql> for Date {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
buf.push(4);
|
||||
|
||||
encode_date(self, buf);
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> usize {
|
||||
@@ -122,13 +133,13 @@ impl Encode<MySql> for Date {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for Date {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match value.try_get()? {
|
||||
MySqlData::Binary(buf) => decode_date(&buf[1..]),
|
||||
MySqlData::Text(buf) => {
|
||||
let s = from_utf8(buf).map_err(crate::Error::decode)?;
|
||||
Date::parse(s, "%Y-%m-%d").map_err(crate::Error::decode)
|
||||
impl<'r> Decode<'r, MySql> for Date {
|
||||
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
|
||||
match value.format() {
|
||||
MySqlValueFormat::Binary => decode_date(&value.as_bytes()?[1..]),
|
||||
MySqlValueFormat::Text => {
|
||||
let s = value.as_str()?;
|
||||
Date::parse(s, "%Y-%m-%d").map_err(Into::into)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -136,12 +147,12 @@ impl<'de> Decode<'de, MySql> for Date {
|
||||
|
||||
impl Type<MySql> for PrimitiveDateTime {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::new(TypeId::DATETIME)
|
||||
MySqlTypeInfo::binary(ColumnType::Datetime)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for PrimitiveDateTime {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
impl Encode<'_, MySql> for PrimitiveDateTime {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
let len = Encode::<MySql>::size_hint(self) - 1;
|
||||
buf.push(len as u8);
|
||||
|
||||
@@ -150,6 +161,8 @@ impl Encode<MySql> for PrimitiveDateTime {
|
||||
if len > 4 {
|
||||
encode_time(&self.time(), len > 8, buf);
|
||||
}
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> usize {
|
||||
@@ -169,10 +182,15 @@ impl Encode<MySql> for PrimitiveDateTime {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for PrimitiveDateTime {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match value.try_get()? {
|
||||
MySqlData::Binary(buf) => {
|
||||
impl<'r> Decode<'r, MySql> for PrimitiveDateTime {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp)
|
||||
}
|
||||
|
||||
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
|
||||
match value.format() {
|
||||
MySqlValueFormat::Binary => {
|
||||
let mut buf = value.as_bytes()?;
|
||||
let len = buf[0];
|
||||
let date = decode_date(&buf[1..])?;
|
||||
|
||||
@@ -185,8 +203,8 @@ impl<'de> Decode<'de, MySql> for PrimitiveDateTime {
|
||||
Ok(dt)
|
||||
}
|
||||
|
||||
MySqlData::Text(buf) => {
|
||||
let s = from_utf8(buf).map_err(crate::Error::decode)?;
|
||||
MySqlValueFormat::Text => {
|
||||
let s = value.as_str()?;
|
||||
|
||||
// If there are less than 9 digits after the decimal point
|
||||
// We need to zero-pad
|
||||
@@ -202,7 +220,7 @@ impl<'de> Decode<'de, MySql> for PrimitiveDateTime {
|
||||
Cow::Borrowed(s)
|
||||
};
|
||||
|
||||
PrimitiveDateTime::parse(&*s, "%Y-%m-%d %H:%M:%S.%N").map_err(crate::Error::decode)
|
||||
PrimitiveDateTime::parse(&*s, "%Y-%m-%d %H:%M:%S.%N").map_err(Into::into)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -218,13 +236,13 @@ fn encode_date(date: &Date, buf: &mut Vec<u8>) {
|
||||
buf.push(date.day());
|
||||
}
|
||||
|
||||
fn decode_date(buf: &[u8]) -> crate::Result<Date> {
|
||||
fn decode_date(buf: &[u8]) -> Result<Date, BoxDynError> {
|
||||
Date::try_from_ymd(
|
||||
LittleEndian::read_u16(buf) as i32,
|
||||
buf[2] as u8,
|
||||
buf[3] as u8,
|
||||
)
|
||||
.map_err(|e| decode_err!("Error while decoding Date: {}", e))
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
fn encode_time(time: &Time, include_micros: bool, buf: &mut Vec<u8>) {
|
||||
@@ -233,92 +251,22 @@ fn encode_time(time: &Time, include_micros: bool, buf: &mut Vec<u8>) {
|
||||
buf.push(time.second());
|
||||
|
||||
if include_micros {
|
||||
buf.put_u32::<LittleEndian>((time.nanosecond() / 1000) as u32);
|
||||
buf.extend(&((time.nanosecond() / 1000) as u32).to_le_bytes());
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_time(len: u8, mut buf: &[u8]) -> crate::Result<Time> {
|
||||
let hour = buf.get_u8()?;
|
||||
let minute = buf.get_u8()?;
|
||||
let seconds = buf.get_u8()?;
|
||||
fn decode_time(len: u8, mut buf: &[u8]) -> Result<Time, BoxDynError> {
|
||||
let hour = buf.get_u8();
|
||||
let minute = buf.get_u8();
|
||||
let seconds = buf.get_u8();
|
||||
|
||||
let micros = if len > 3 {
|
||||
// microseconds : int<EOF>
|
||||
buf.get_uint::<LittleEndian>(buf.len())?
|
||||
buf.get_uint_le(buf.len())
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
Time::try_from_hms_micro(hour, minute, seconds, micros as u32)
|
||||
.map_err(|e| decode_err!("Time out of range for MySQL: {}", e))
|
||||
}
|
||||
|
||||
use std::str::from_utf8;
|
||||
#[cfg(test)]
|
||||
use time::{date, time};
|
||||
|
||||
#[test]
|
||||
fn test_encode_date_time() {
|
||||
let mut buf = Vec::new();
|
||||
|
||||
// test values from https://dev.mysql.com/doc/internals/en/binary-protocol-value.html
|
||||
let date = PrimitiveDateTime::new(date!(2010 - 10 - 17), time!(19:27:30.000001));
|
||||
Encode::<MySql>::encode(&date, &mut buf);
|
||||
assert_eq!(*buf, [11, 218, 7, 10, 17, 19, 27, 30, 1, 0, 0, 0]);
|
||||
|
||||
buf.clear();
|
||||
|
||||
let date = PrimitiveDateTime::new(date!(2010 - 10 - 17), time!(19:27:30));
|
||||
Encode::<MySql>::encode(&date, &mut buf);
|
||||
assert_eq!(*buf, [7, 218, 7, 10, 17, 19, 27, 30]);
|
||||
|
||||
buf.clear();
|
||||
|
||||
let date = PrimitiveDateTime::new(date!(2010 - 10 - 17), time!(00:00:00));
|
||||
Encode::<MySql>::encode(&date, &mut buf);
|
||||
assert_eq!(*buf, [4, 218, 7, 10, 17]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_date_time() {
|
||||
// test values from https://dev.mysql.com/doc/internals/en/binary-protocol-value.html
|
||||
let buf = [11, 218, 7, 10, 17, 19, 27, 30, 1, 0, 0, 0];
|
||||
let date1 = <PrimitiveDateTime as Decode<MySql>>::decode(MySqlValue::binary(
|
||||
MySqlTypeInfo::default(),
|
||||
&buf,
|
||||
))
|
||||
.unwrap();
|
||||
assert_eq!(date1.to_string(), "2010-10-17 19:27:30.000001");
|
||||
|
||||
let buf = [7, 218, 7, 10, 17, 19, 27, 30];
|
||||
let date2 = <PrimitiveDateTime as Decode<MySql>>::decode(MySqlValue::binary(
|
||||
MySqlTypeInfo::default(),
|
||||
&buf,
|
||||
))
|
||||
.unwrap();
|
||||
assert_eq!(date2.to_string(), "2010-10-17 19:27:30");
|
||||
|
||||
let buf = [4, 218, 7, 10, 17];
|
||||
let date3 = <PrimitiveDateTime as Decode<MySql>>::decode(MySqlValue::binary(
|
||||
MySqlTypeInfo::default(),
|
||||
&buf,
|
||||
))
|
||||
.unwrap();
|
||||
assert_eq!(date3.to_string(), "2010-10-17 0:00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_date() {
|
||||
let mut buf = Vec::new();
|
||||
let date: Date = date!(2010 - 10 - 17);
|
||||
Encode::<MySql>::encode(&date, &mut buf);
|
||||
assert_eq!(*buf, [4, 218, 7, 10, 17]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_date() {
|
||||
let buf = [4, 218, 7, 10, 17];
|
||||
let date = <Date as Decode<MySql>>::decode(MySqlValue::binary(MySqlTypeInfo::default(), &buf))
|
||||
.unwrap();
|
||||
assert_eq!(date, date!(2010 - 10 - 17));
|
||||
.map_err(|e| format!("Time out of range for MySQL: {}", e).into())
|
||||
}
|
||||
|
||||
@@ -1,111 +1,135 @@
|
||||
use std::str::from_utf8;
|
||||
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::Encode;
|
||||
use crate::mysql::protocol::TypeId;
|
||||
use crate::mysql::type_info::MySqlTypeInfo;
|
||||
use crate::mysql::{MySql, MySqlData, MySqlValue};
|
||||
use crate::encode::{Encode, IsNull};
|
||||
use crate::error::BoxDynError;
|
||||
use crate::mysql::protocol::text::{ColumnFlags, ColumnType};
|
||||
use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef};
|
||||
use crate::types::Type;
|
||||
use crate::Error;
|
||||
|
||||
fn uint_type_info(ty: ColumnType) -> MySqlTypeInfo {
|
||||
MySqlTypeInfo {
|
||||
r#type: ty,
|
||||
flags: ColumnFlags::BINARY | ColumnFlags::UNSIGNED,
|
||||
char_set: 63,
|
||||
}
|
||||
}
|
||||
|
||||
fn uint_accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
matches!(
|
||||
ty.r#type,
|
||||
ColumnType::Tiny
|
||||
| ColumnType::Short
|
||||
| ColumnType::Long
|
||||
| ColumnType::Int24
|
||||
| ColumnType::LongLong
|
||||
) && ty.flags.contains(ColumnFlags::UNSIGNED)
|
||||
}
|
||||
|
||||
impl Type<MySql> for u8 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::unsigned(TypeId::TINY_INT)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for u8 {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
let _ = buf.write_u8(*self);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for u8 {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match value.try_get()? {
|
||||
MySqlData::Binary(mut buf) => buf.read_u8().map_err(Into::into),
|
||||
|
||||
MySqlData::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
uint_type_info(ColumnType::Tiny)
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<MySql> for u16 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::unsigned(TypeId::SMALL_INT)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for u16 {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
let _ = buf.write_u16::<LittleEndian>(*self);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for u16 {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match value.try_get()? {
|
||||
MySqlData::Binary(mut buf) => buf.read_u16::<LittleEndian>().map_err(Into::into),
|
||||
|
||||
MySqlData::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
uint_type_info(ColumnType::Short)
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<MySql> for u32 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::unsigned(TypeId::INT)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for u32 {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
let _ = buf.write_u32::<LittleEndian>(*self);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for u32 {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match value.try_get()? {
|
||||
MySqlData::Binary(mut buf) => buf.read_u32::<LittleEndian>().map_err(Into::into),
|
||||
|
||||
MySqlData::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
uint_type_info(ColumnType::Long)
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<MySql> for u64 {
|
||||
fn type_info() -> MySqlTypeInfo {
|
||||
MySqlTypeInfo::unsigned(TypeId::BIG_INT)
|
||||
uint_type_info(ColumnType::LongLong)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<MySql> for u64 {
|
||||
fn encode(&self, buf: &mut Vec<u8>) {
|
||||
let _ = buf.write_u64::<LittleEndian>(*self);
|
||||
impl Encode<'_, MySql> for u8 {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
buf.extend(&self.to_le_bytes());
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de, MySql> for u64 {
|
||||
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
|
||||
match value.try_get()? {
|
||||
MySqlData::Binary(mut buf) => buf.read_u64::<LittleEndian>().map_err(Into::into),
|
||||
impl Encode<'_, MySql> for u16 {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
buf.extend(&self.to_le_bytes());
|
||||
|
||||
MySqlData::Text(s) => from_utf8(s)
|
||||
.map_err(Error::decode)?
|
||||
.parse()
|
||||
.map_err(Error::decode),
|
||||
}
|
||||
IsNull::No
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<'_, MySql> for u32 {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
buf.extend(&self.to_le_bytes());
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<'_, MySql> for u64 {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
buf.extend(&self.to_le_bytes());
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<'_, MySql> for u8 {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
uint_accepts(ty)
|
||||
}
|
||||
|
||||
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
|
||||
Ok(match value.format() {
|
||||
MySqlValueFormat::Binary => value.as_bytes()?[0] as u8,
|
||||
MySqlValueFormat::Text => value.as_str()?.parse()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<'_, MySql> for u16 {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
uint_accepts(ty)
|
||||
}
|
||||
|
||||
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
|
||||
Ok(match value.format() {
|
||||
MySqlValueFormat::Binary => LittleEndian::read_u16(value.as_bytes()?),
|
||||
MySqlValueFormat::Text => value.as_str()?.parse()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<'_, MySql> for u32 {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
uint_accepts(ty)
|
||||
}
|
||||
|
||||
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
|
||||
Ok(match value.format() {
|
||||
MySqlValueFormat::Binary => LittleEndian::read_u32(value.as_bytes()?),
|
||||
MySqlValueFormat::Text => value.as_str()?.parse()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<'_, MySql> for u64 {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
uint_accepts(ty)
|
||||
}
|
||||
|
||||
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
|
||||
Ok(match value.format() {
|
||||
MySqlValueFormat::Binary => LittleEndian::read_u64(value.as_bytes()?),
|
||||
MySqlValueFormat::Text => value.as_str()?.parse()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
// XOR(x, y)
|
||||
// If len(y) < len(x), wrap around inside y
|
||||
pub fn xor_eq(x: &mut [u8], y: &[u8]) {
|
||||
let y_len = y.len();
|
||||
|
||||
for i in 0..x.len() {
|
||||
x[i] ^= y[i % y_len];
|
||||
}
|
||||
}
|
||||
@@ -1,62 +1,98 @@
|
||||
use crate::error::UnexpectedNullError;
|
||||
use std::borrow::Cow;
|
||||
use std::str::from_utf8;
|
||||
|
||||
use bytes::Bytes;
|
||||
|
||||
use crate::error::{BoxDynError, UnexpectedNullError};
|
||||
use crate::mysql::{MySql, MySqlTypeInfo};
|
||||
use crate::value::RawValue;
|
||||
use crate::value::{Value, ValueRef};
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub enum MySqlData<'c> {
|
||||
Binary(&'c [u8]),
|
||||
Text(&'c [u8]),
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[repr(u8)]
|
||||
pub enum MySqlValueFormat {
|
||||
Text,
|
||||
Binary,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MySqlValue<'c> {
|
||||
/// Implementation of [`Value`] for MySQL.
|
||||
#[derive(Clone)]
|
||||
pub struct MySqlValue {
|
||||
value: Option<Bytes>,
|
||||
type_info: Option<MySqlTypeInfo>,
|
||||
data: Option<MySqlData<'c>>,
|
||||
format: MySqlValueFormat,
|
||||
}
|
||||
|
||||
impl<'c> MySqlValue<'c> {
|
||||
/// Gets the binary or text data for this value; or, `UnexpectedNullError` if this
|
||||
/// is a `NULL` value.
|
||||
pub(crate) fn try_get(&self) -> crate::Result<MySqlData<'c>> {
|
||||
match self.data {
|
||||
Some(data) => Ok(data),
|
||||
None => Err(crate::Error::decode(UnexpectedNullError)),
|
||||
/// Implementation of [`ValueRef`] for MySQL.
|
||||
#[derive(Clone)]
|
||||
pub struct MySqlValueRef<'r> {
|
||||
pub(crate) value: Option<&'r [u8]>,
|
||||
pub(crate) row: Option<&'r Bytes>,
|
||||
pub(crate) type_info: Option<MySqlTypeInfo>,
|
||||
pub(crate) format: MySqlValueFormat,
|
||||
}
|
||||
|
||||
impl<'r> MySqlValueRef<'r> {
|
||||
pub(crate) fn format(&self) -> MySqlValueFormat {
|
||||
self.format
|
||||
}
|
||||
|
||||
pub(crate) fn as_bytes(&self) -> Result<&'r [u8], BoxDynError> {
|
||||
match &self.value {
|
||||
Some(v) => Ok(v),
|
||||
None => Err(UnexpectedNullError.into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the binary or text data for this value; or, `None` if this
|
||||
/// is a `NULL` value.
|
||||
#[inline]
|
||||
pub fn get(&self) -> Option<MySqlData<'c>> {
|
||||
self.data
|
||||
}
|
||||
|
||||
pub(crate) fn null() -> Self {
|
||||
Self {
|
||||
type_info: None,
|
||||
data: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn binary(type_info: MySqlTypeInfo, buf: &'c [u8]) -> Self {
|
||||
Self {
|
||||
type_info: Some(type_info),
|
||||
data: Some(MySqlData::Binary(buf)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn text(type_info: MySqlTypeInfo, buf: &'c [u8]) -> Self {
|
||||
Self {
|
||||
type_info: Some(type_info),
|
||||
data: Some(MySqlData::Text(buf)),
|
||||
}
|
||||
pub(crate) fn as_str(&self) -> Result<&'r str, BoxDynError> {
|
||||
Ok(from_utf8(self.as_bytes()?)?)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'c> RawValue<'c> for MySqlValue<'c> {
|
||||
impl Value for MySqlValue {
|
||||
type Database = MySql;
|
||||
|
||||
fn type_info(&self) -> Option<MySqlTypeInfo> {
|
||||
self.type_info.clone()
|
||||
fn as_ref(&self) -> MySqlValueRef<'_> {
|
||||
MySqlValueRef {
|
||||
value: self.value.as_deref(),
|
||||
row: None,
|
||||
type_info: self.type_info.clone(),
|
||||
format: self.format,
|
||||
}
|
||||
}
|
||||
|
||||
fn type_info(&self) -> Option<Cow<'_, MySqlTypeInfo>> {
|
||||
self.type_info.as_ref().map(Cow::Borrowed)
|
||||
}
|
||||
|
||||
fn is_null(&self) -> bool {
|
||||
self.value.is_none()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r> ValueRef<'r> for MySqlValueRef<'r> {
|
||||
type Database = MySql;
|
||||
|
||||
fn to_owned(&self) -> MySqlValue {
|
||||
let value = match (self.row, self.value) {
|
||||
(Some(row), Some(value)) => Some(row.slice_ref(value)),
|
||||
|
||||
(None, Some(value)) => Some(Bytes::copy_from_slice(value)),
|
||||
|
||||
_ => None,
|
||||
};
|
||||
|
||||
MySqlValue {
|
||||
value,
|
||||
format: self.format,
|
||||
type_info: self.type_info.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn type_info(&self) -> Option<Cow<'_, MySqlTypeInfo>> {
|
||||
self.type_info.as_ref().map(Cow::Borrowed)
|
||||
}
|
||||
|
||||
fn is_null(&self) -> bool {
|
||||
self.value.is_none()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user