refactor(mysql): adapt to the 0.4.x core refactor

This commit is contained in:
Ryan Leckey
2020-05-26 04:31:38 -07:00
parent 37a69e0ac3
commit 2966b655fc
80 changed files with 3479 additions and 3929 deletions

View File

@@ -1,43 +1,51 @@
use std::ops::{Deref, DerefMut};
use crate::arguments::Arguments;
use crate::encode::{Encode, IsNull};
use crate::mysql::type_info::MySqlTypeInfo;
use crate::mysql::MySql;
use crate::types::Type;
use crate::mysql::{MySql, MySqlTypeInfo};
#[derive(Default)]
/// Implementation of [`Arguments`] for MySQL.
#[derive(Debug, Default)]
pub struct MySqlArguments {
pub(crate) param_types: Vec<MySqlTypeInfo>,
pub(crate) params: Vec<u8>,
pub(crate) values: Vec<u8>,
pub(crate) types: Vec<MySqlTypeInfo>,
pub(crate) null_bitmap: Vec<u8>,
}
impl Arguments for MySqlArguments {
impl<'q> Arguments<'q> for MySqlArguments {
type Database = MySql;
fn reserve(&mut self, len: usize, size: usize) {
self.param_types.reserve(len);
self.params.reserve(size);
// ensure we have enough size in the bitmap to hold at least `len` extra bits
// the second `& 7` gives us 0 spare bits when param_types.len() is a multiple of 8
let spare_bits = (8 - (self.param_types.len()) & 7) & 7;
// ensure that if there are no spare bits left, `len = 1` reserves another byte
self.null_bitmap.reserve((len + 7 - spare_bits) / 8);
self.types.reserve(len);
self.values.reserve(size);
}
fn add<T>(&mut self, value: T)
where
T: Type<Self::Database>,
T: Encode<Self::Database>,
T: Encode<'q, Self::Database>,
{
let type_id = <T as Type<MySql>>::type_info();
let index = self.param_types.len();
let ty = value.produces();
let index = self.types.len();
self.param_types.push(type_id);
self.types.push(ty);
self.null_bitmap.resize((index / 8) + 1, 0);
if let IsNull::Yes = value.encode_nullable(&mut self.params) {
if let IsNull::Yes = value.encode(self) {
self.null_bitmap[index / 8] |= (1 << index % 8) as u8;
}
}
}
impl Deref for MySqlArguments {
type Target = Vec<u8>;
fn deref(&self) -> &Self::Target {
&self.values
}
}
impl DerefMut for MySqlArguments {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.values
}
}

View File

@@ -1,345 +0,0 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::convert::TryInto;
use std::ops::Range;
use futures_core::future::BoxFuture;
use sha1::Sha1;
use crate::connection::{Connect, Connection};
use crate::executor::Executor;
use crate::mysql::protocol::{
AuthPlugin, AuthSwitch, Capabilities, ComPing, Handshake, HandshakeResponse,
};
use crate::mysql::stream::MySqlStream;
use crate::mysql::util::xor_eq;
use crate::mysql::{rsa, tls};
use crate::url::Url;
// Size before a packet is split
pub(super) const MAX_PACKET_SIZE: u32 = 1024;
pub(super) const COLLATE_UTF8MB4_UNICODE_CI: u8 = 224;
/// An asynchronous connection to a [`MySql`] database.
///
/// The connection string expected by `MySqlConnection` should be a MySQL connection
/// string, as documented at
/// <https://dev.mysql.com/doc/refman/8.0/en/connecting-using-uri-or-key-value-pairs.html#connecting-using-uri>
///
/// ### TLS Support (requires `tls` feature)
/// This connection type supports some of the same flags as the `mysql` CLI application for SSL
/// connections, but they must be specified via the query segment of the connection string
/// rather than as program arguments.
///
/// The same options for `--ssl-mode` are supported as the `ssl-mode` query parameter:
/// <https://dev.mysql.com/doc/refman/8.0/en/connection-options.html#option_general_ssl-mode>
///
/// ```text
/// mysql://<user>[:<password>]@<host>[:<port>]/<database>[?ssl-mode=<ssl-mode>[&ssl-ca=<path>]]
/// ```
/// where
/// ```text
/// ssl-mode = DISABLED | PREFERRED | REQUIRED | VERIFY_CA | VERIFY_IDENTITY
/// path = percent (URL) encoded path on the local machine
/// ```
///
/// If the `tls` feature is not enabled, `ssl-mode=DISABLED` and `ssl-mode=PREFERRED` are no-ops and
/// `ssl-mode=REQUIRED`, `ssl-mode=VERIFY_CA` and `ssl-mode=VERIFY_IDENTITY` are forbidden
/// (attempting to connect with these will return an error).
///
/// If the `tls` feature is enabled, an upgrade to TLS is attempted on every connection by default
/// (equivalent to `ssl-mode=PREFERRED`). If the server does not support TLS (because `--ssl=0` was
/// passed to the server or an invalid certificate or key was used:
/// <https://dev.mysql.com/doc/refman/8.0/en/using-encrypted-connections.html>)
/// then it falls back to an unsecured connection and logs a warning.
///
/// Add `ssl-mode=REQUIRED` to your connection string to emit an error if the TLS upgrade fails.
///
/// However, like with `mysql` the server certificate is **not** checked for validity by default.
///
/// Specifying `ssl-mode=VERIFY_CA` will cause the TLS upgrade to verify the server's SSL
/// certificate against a local CA root certificate; this is not the system root certificate
/// but is instead expected to be specified as a local path with the `ssl-ca` query parameter
/// (percent-encoded so the URL remains valid).
///
/// If you're running MySQL locally it might look something like this (for `VERIFY_CA`):
/// ```text
/// mysql://root:password@localhost/my_database?ssl-mode=VERIFY_CA&ssl-ca=%2Fvar%2Flib%2Fmysql%2Fca.pem
/// ```
///
/// `%2F` is the percent-encoding for forward slash (`/`). In the example we give `/var/lib/mysql/ca.pem`
/// as the CA certificate path, which is generated by the MySQL server automatically if
/// no certificate is manually specified. Note that the path may vary based on the default `my.cnf`
/// packaged with MySQL for your Linux distribution. Also note that unlike MySQL, MariaDB does *not*
/// generate certificates automatically and they must always be passed in to enable TLS.
///
/// If `ssl-ca` is not specified or the file cannot be read, then an error is returned.
/// `ssl-ca` implies `ssl-mode=VERIFY_CA` so you only actually need to specify the former
/// but you may prefer having both to be more explicit.
///
/// If `ssl-mode=VERIFY_IDENTITY` is specified, in addition to checking the certificate as with
/// `ssl-mode=VERIFY_CA`, the hostname in the connection string will be verified
/// against the hostname in the server certificate, so they must be the same for the TLS
/// upgrade to succeed. `ssl-ca` must still be specified.
pub struct MySqlConnection {
pub(super) stream: MySqlStream,
pub(super) is_ready: bool,
pub(super) cache_statement: HashMap<Box<str>, u32>,
// Work buffer for the value ranges of the current row
// This is used as the backing memory for each Row's value indexes
pub(super) current_row_values: Vec<Option<Range<usize>>>,
}
fn to_asciz(s: &str) -> Vec<u8> {
let mut z = String::with_capacity(s.len() + 1);
z.push_str(s);
z.push('\0');
z.into_bytes()
}
async fn rsa_encrypt_with_nonce(
stream: &mut MySqlStream,
public_key_request_id: u8,
password: &str,
nonce: &[u8],
) -> crate::Result<Vec<u8>> {
// https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/
if stream.is_tls() {
// If in a TLS stream, send the password directly in clear text
return Ok(to_asciz(password));
}
// client sends a public key request
stream.send(&[public_key_request_id][..], false).await?;
// server sends a public key response
let packet = stream.receive().await?;
let rsa_pub_key = &packet[1..];
// xor the password with the given nonce
let mut pass = to_asciz(password);
xor_eq(&mut pass, nonce);
// client sends an RSA encrypted password
rsa::encrypt::<Sha1>(rsa_pub_key, &pass)
}
async fn make_auth_response(
stream: &mut MySqlStream,
plugin: &AuthPlugin,
password: &str,
nonce: &[u8],
) -> crate::Result<Vec<u8>> {
if password.is_empty() {
// Empty password should not be sent
return Ok(vec![]);
}
match plugin {
AuthPlugin::CachingSha2Password | AuthPlugin::MySqlNativePassword => {
Ok(plugin.scramble(password, nonce))
}
AuthPlugin::Sha256Password => rsa_encrypt_with_nonce(stream, 0x01, password, nonce).await,
}
}
async fn establish(stream: &mut MySqlStream, url: &Url) -> crate::Result<()> {
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase.html
// https://mariadb.com/kb/en/connection/
// Read a [Handshake] packet. When connecting to the database server, this is immediately
// received from the database server.
let handshake = Handshake::read(stream.receive().await?)?;
let mut auth_plugin = handshake.auth_plugin;
let mut auth_plugin_data = handshake.auth_plugin_data;
stream.capabilities &= handshake.server_capabilities;
stream.capabilities |= Capabilities::PROTOCOL_41;
log::trace!("using capability flags: {:?}", stream.capabilities);
// Depending on the ssl-mode and capabilities we should upgrade
// our connection to TLS
tls::upgrade_if_needed(stream, url).await?;
// Send a [HandshakeResponse] packet. This is returned in response to the [Handshake] packet
// that is immediately received.
let password = &*url.password().unwrap_or_default();
let auth_response =
make_auth_response(stream, &auth_plugin, password, &auth_plugin_data).await?;
stream
.send(
HandshakeResponse {
client_collation: COLLATE_UTF8MB4_UNICODE_CI,
max_packet_size: MAX_PACKET_SIZE,
username: &url.username().unwrap_or(Cow::Borrowed("root")),
database: url.database(),
auth_plugin: &auth_plugin,
auth_response: &auth_response,
},
false,
)
.await?;
loop {
// After sending the handshake response with our assumed auth method the server
// will send OK, fail, or tell us to change auth methods
let packet = stream.receive().await?;
match packet[0] {
// OK
0x00 => {
break;
}
// ERROR
0xFF => {
return stream.handle_err();
}
// AUTH_SWITCH
0xFE => {
let auth = AuthSwitch::read(packet)?;
auth_plugin = auth.auth_plugin;
auth_plugin_data = auth.auth_plugin_data;
let auth_response =
make_auth_response(stream, &auth_plugin, password, &auth_plugin_data).await?;
stream.send(&*auth_response, false).await?;
}
0x01 if auth_plugin == AuthPlugin::CachingSha2Password => {
match packet[1] {
// AUTH_OK
0x03 => {}
// AUTH_CONTINUE
0x04 => {
// The specific password is _not_ cached on the server
// We need to send a normal RSA-encrypted password for this
let enc = rsa_encrypt_with_nonce(stream, 0x02, password, &auth_plugin_data)
.await?;
stream.send(&*enc, false).await?;
}
unk => {
return Err(protocol_err!("unexpected result from 'fast' authentication 0x{:x} when expecting OK (0x03) or CONTINUE (0x04)", unk).into());
}
}
}
_ => {
return stream.handle_unexpected();
}
}
}
Ok(())
}
async fn close(mut stream: MySqlStream) -> crate::Result<()> {
// TODO: Actually tell MySQL that we're closing
stream.flush().await?;
stream.shutdown()?;
Ok(())
}
async fn ping(stream: &mut MySqlStream) -> crate::Result<()> {
stream.wait_until_ready().await?;
stream.is_ready = false;
stream.send(ComPing, true).await?;
match stream.receive().await?[0] {
0x00 | 0xFE => stream.handle_ok().map(drop),
0xFF => stream.handle_err(),
_ => stream.handle_unexpected(),
}
}
impl MySqlConnection {
pub(super) async fn new(url: std::result::Result<Url, url::ParseError>) -> crate::Result<Self> {
let url = url?;
let mut stream = MySqlStream::new(&url).await?;
establish(&mut stream, &url).await?;
let mut self_ = Self {
stream,
current_row_values: Vec::with_capacity(10),
is_ready: true,
cache_statement: HashMap::new(),
};
// After the connection is established, we initialize by configuring a few
// connection parameters
// https://mariadb.com/kb/en/sql-mode/
// PIPES_AS_CONCAT - Allows using the pipe character (ASCII 124) as string concatenation operator.
// This means that "A" || "B" can be used in place of CONCAT("A", "B").
// NO_ENGINE_SUBSTITUTION - If not set, if the available storage engine specified by a CREATE TABLE is
// not available, a warning is given and the default storage
// engine is used instead.
// NO_ZERO_DATE - Don't allow '0000-00-00'. This is invalid in Rust.
// NO_ZERO_IN_DATE - Don't allow 'YYYY-00-00'. This is invalid in Rust.
// --
// Setting the time zone allows us to assume that the output
// from a TIMESTAMP field is UTC
// --
// https://mathiasbynens.be/notes/mysql-utf8mb4
self_.execute(r#"
SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE'));
SET time_zone = '+00:00';
SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci;
"#).await?;
Ok(self_)
}
}
impl Connect for MySqlConnection {
fn connect<T>(url: T) -> BoxFuture<'static, crate::Result<MySqlConnection>>
where
T: TryInto<Url, Error = url::ParseError>,
Self: Sized,
{
Box::pin(MySqlConnection::new(url.try_into()))
}
}
impl Connection for MySqlConnection {
#[inline]
fn close(self) -> BoxFuture<'static, crate::Result<()>> {
Box::pin(close(self.stream))
}
#[inline]
fn ping(&mut self) -> BoxFuture<crate::Result<()>> {
Box::pin(ping(&mut self.stream))
}
}

View File

@@ -0,0 +1,176 @@
use bytes::buf::ext::Chain;
use bytes::Bytes;
use digest::{Digest, FixedOutput};
use generic_array::GenericArray;
use sha1::Sha1;
use sha2::Sha256;
use crate::error::Error;
use crate::mysql::connection::stream::MySqlStream;
use crate::mysql::protocol::auth::AuthPlugin;
use crate::mysql::protocol::rsa;
use crate::mysql::protocol::Packet;
impl AuthPlugin {
pub(super) async fn scramble(
self,
stream: &mut MySqlStream,
password: &str,
nonce: &Chain<Bytes, Bytes>,
) -> Result<Vec<u8>, Error> {
match self {
// https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/
AuthPlugin::CachingSha2Password => Ok(scramble_sha256(password, nonce).to_vec()),
AuthPlugin::MySqlNativePassword => Ok(scramble_sha1(password, nonce).to_vec()),
// https://mariadb.com/kb/en/sha256_password-plugin/
AuthPlugin::Sha256Password => encrypt_rsa(stream, 0x01, password, nonce).await,
}
}
pub(super) async fn handle(
self,
stream: &mut MySqlStream,
packet: Packet<Bytes>,
password: &str,
nonce: &Chain<Bytes, Bytes>,
) -> Result<bool, Error> {
match self {
AuthPlugin::CachingSha2Password if packet[0] == 0x01 => {
match packet[1] {
// AUTH_OK
0x03 => Ok(true),
// AUTH_CONTINUE
0x04 => {
let payload = encrypt_rsa(stream, 0x02, password, nonce).await?;
stream.write_packet(&*payload);
stream.flush().await?;
Ok(false)
}
v => {
Err(err_protocol!("unexpected result from fast authentication 0x{:x} when expecting 0x03 (AUTH_OK) or 0x04 (AUTH_CONTINUE)", v))
}
}
}
_ => Err(err_protocol!(
"unexpected packet 0x{:02x} for auth plugin '{}' during authentication",
packet[0],
self.name()
)),
}
}
}
fn scramble_sha1(
password: &str,
nonce: &Chain<Bytes, Bytes>,
) -> GenericArray<u8, <Sha1 as FixedOutput>::OutputSize> {
// SHA1( password ) ^ SHA1( seed + SHA1( SHA1( password ) ) )
// https://mariadb.com/kb/en/connection/#mysql_native_password-plugin
let mut ctx = Sha1::new();
ctx.input(password);
let mut pw_hash = ctx.result_reset();
ctx.input(&pw_hash);
let pw_hash_hash = ctx.result_reset();
ctx.input(nonce.first_ref());
ctx.input(nonce.last_ref());
ctx.input(pw_hash_hash);
let pw_seed_hash_hash = ctx.result();
xor_eq(&mut pw_hash, &pw_seed_hash_hash);
pw_hash
}
fn scramble_sha256(
password: &str,
nonce: &Chain<Bytes, Bytes>,
) -> GenericArray<u8, <Sha256 as FixedOutput>::OutputSize> {
// XOR(SHA256(password), SHA256(seed, SHA256(SHA256(password))))
// https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/#sha-2-encrypted-password
let mut ctx = Sha256::new();
ctx.input(password);
let mut pw_hash = ctx.result_reset();
ctx.input(&pw_hash);
let pw_hash_hash = ctx.result_reset();
ctx.input(nonce.first_ref());
ctx.input(nonce.last_ref());
ctx.input(pw_hash_hash);
let pw_seed_hash_hash = ctx.result();
xor_eq(&mut pw_hash, &pw_seed_hash_hash);
pw_hash
}
async fn encrypt_rsa<'s>(
stream: &'s mut MySqlStream,
public_key_request_id: u8,
password: &'s str,
nonce: &'s Chain<Bytes, Bytes>,
) -> Result<Vec<u8>, Error> {
// https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/
if stream.is_tls() {
// If in a TLS stream, send the password directly in clear text
return Ok(to_asciz(password));
}
// client sends a public key request
stream.write_packet(&[public_key_request_id][..]);
stream.flush().await?;
// server sends a public key response
let packet = stream.recv_packet().await?;
let rsa_pub_key = &packet[1..];
// xor the password with the given nonce
let mut pass = to_asciz(password);
let (a, b) = (nonce.first_ref(), nonce.last_ref());
let mut nonce = Vec::with_capacity(a.len() + b.len());
nonce.extend_from_slice(&*a);
nonce.extend_from_slice(&*b);
xor_eq(&mut pass, &*nonce);
// client sends an RSA encrypted password
rsa::encrypt::<Sha1>(rsa_pub_key, &pass)
}
// XOR(x, y)
// If len(y) < len(x), wrap around inside y
fn xor_eq(x: &mut [u8], y: &[u8]) {
let y_len = y.len();
for i in 0..x.len() {
x[i] ^= y[i % y_len];
}
}
fn to_asciz(s: &str) -> Vec<u8> {
let mut z = String::with_capacity(s.len() + 1);
z.push_str(s);
z.push('\0');
z.into_bytes()
}

View File

@@ -0,0 +1,101 @@
use bytes::Bytes;
use hashbrown::HashMap;
use crate::error::Error;
use crate::mysql::connection::{tls, MySqlStream, COLLATE_UTF8MB4_UNICODE_CI, MAX_PACKET_SIZE};
use crate::mysql::protocol::connect::{
AuthSwitchRequest, AuthSwitchResponse, Handshake, HandshakeResponse,
};
use crate::mysql::protocol::Capabilities;
use crate::mysql::{MySqlConnectOptions, MySqlConnection};
use bytes::buf::BufExt;
impl MySqlConnection {
pub(crate) async fn establish(options: &MySqlConnectOptions) -> Result<Self, Error> {
let mut stream: MySqlStream = MySqlStream::connect(options).await?;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase.html
// https://mariadb.com/kb/en/connection/
let handshake: Handshake = stream.recv_packet().await?.decode()?;
let mut plugin = handshake.auth_plugin;
let mut nonce = handshake.auth_plugin_data;
stream.capabilities &= handshake.server_capabilities;
stream.capabilities |= Capabilities::PROTOCOL_41;
// Upgrade to TLS if we were asked to and the server supports it
tls::maybe_upgrade(&mut stream, options).await?;
let auth_response = if let (Some(plugin), Some(password)) = (plugin, &options.password) {
Some(plugin.scramble(&mut stream, password, &nonce).await?)
} else {
None
};
stream.write_packet(HandshakeResponse {
char_set: COLLATE_UTF8MB4_UNICODE_CI,
max_packet_size: MAX_PACKET_SIZE,
username: &options.username,
database: options.database.as_deref(),
auth_plugin: plugin,
auth_response: auth_response.as_deref(),
});
stream.flush().await?;
loop {
let packet = stream.recv_packet().await?;
match packet[0] {
0x00 => {
let _ok = packet.ok()?;
break;
}
0xfe => {
let switch: AuthSwitchRequest = packet.decode()?;
plugin = Some(switch.plugin);
nonce = switch.data.chain(Bytes::new());
let response = switch
.plugin
.scramble(
&mut stream,
options.password.as_deref().unwrap_or_default(),
&nonce,
)
.await?;
stream.write_packet(AuthSwitchResponse(response));
stream.flush().await?;
}
id => {
if let (Some(plugin), Some(password)) = (plugin, &options.password) {
if plugin.handle(&mut stream, packet, password, &nonce).await? {
// plugin signaled authentication is ok
break;
}
// plugin signaled to continue authentication
} else {
return Err(err_protocol!(
"unexpected packet 0x{:02x} during authentication",
id
));
}
}
}
}
Ok(Self {
stream,
cache_statement: HashMap::new(),
scratch_row_columns: Default::default(),
scratch_row_column_names: Default::default(),
})
}
}

View File

@@ -0,0 +1,285 @@
use std::sync::Arc;
use async_stream::try_stream;
use bytes::Bytes;
use either::Either;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use futures_core::Stream;
use futures_util::{pin_mut, TryStreamExt};
use crate::describe::{Column, Describe};
use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::ext::ustr::UStr;
use crate::mysql::io::MySqlBufExt;
use crate::mysql::protocol::response::Status;
use crate::mysql::protocol::statement::{
BinaryRow, Execute as StatementExecute, Prepare, PrepareOk,
};
use crate::mysql::protocol::text::{ColumnDefinition, ColumnFlags, Query, TextRow};
use crate::mysql::protocol::Packet;
use crate::mysql::row::MySqlColumn;
use crate::mysql::{
MySql, MySqlArguments, MySqlConnection, MySqlRow, MySqlTypeInfo, MySqlValueFormat,
};
impl MySqlConnection {
async fn prepare(&mut self, query: &str) -> Result<u32, Error> {
if let Some(&statement) = self.cache_statement.get(query) {
return Ok(statement);
}
// https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html
// https://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html#packet-COM_STMT_PREPARE_OK
self.stream.send_packet(Prepare { query }).await?;
let ok: PrepareOk = self.stream.recv().await?;
// the parameter definitions are very unreliable so we skip over them
// as we have little use
if ok.params > 0 {
for _ in 0..ok.params {
let _def: ColumnDefinition = self.stream.recv().await?;
}
self.stream.maybe_recv_eof().await?;
}
// the column definitions are berefit the type information from the
// to-be-bound parameters; we will receive the output column definitions
// once more on execute so we wait for that
if ok.columns > 0 {
for _ in 0..(ok.columns as usize) {
let _def: ColumnDefinition = self.stream.recv().await?;
}
self.stream.maybe_recv_eof().await?;
}
self.cache_statement
.insert(query.to_owned(), ok.statement_id);
Ok(ok.statement_id)
}
async fn recv_result_metadata(&mut self, mut packet: Packet<Bytes>) -> Result<(), Error> {
let num_columns: u64 = packet.get_uint_lenenc(); // column count
// the result-set metadata is primarily a listing of each output
// column in the result-set
let column_names = Arc::make_mut(&mut self.scratch_row_column_names);
let columns = Arc::make_mut(&mut self.scratch_row_columns);
columns.clear();
column_names.clear();
for i in 0..num_columns {
let def: ColumnDefinition = self.stream.recv().await?;
let name = (match (def.name()?, def.alias()?) {
(_, alias) if !alias.is_empty() => Some(alias),
(name, _) if !name.is_empty() => Some(name),
_ => None,
})
.map(UStr::new);
if let Some(name) = &name {
column_names.insert(name.clone(), i as usize);
}
let type_info = MySqlTypeInfo::from_column(&def);
columns.push(MySqlColumn { name, type_info });
}
self.stream.maybe_recv_eof().await?;
Ok(())
}
async fn run<'c>(
&'c mut self,
query: &str,
arguments: Option<MySqlArguments>,
) -> Result<impl Stream<Item = Result<Either<u64, MySqlRow>, Error>> + 'c, Error> {
self.stream.wait_until_ready().await?;
self.stream.busy = true;
let format = if let Some(arguments) = arguments {
let statement = self.prepare(query).await?;
// https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
self.stream
.send_packet(StatementExecute {
statement,
arguments: &arguments,
})
.await?;
MySqlValueFormat::Binary
} else {
// https://dev.mysql.com/doc/internals/en/com-query.html
self.stream.send_packet(Query(query)).await?;
MySqlValueFormat::Text
};
Ok(try_stream! {
loop {
// query response is a meta-packet which may be one of:
// Ok, Err, ResultSet, or (unhandled) LocalInfileRequest
let mut packet = self.stream.recv_packet().await?;
if packet[0] == 0x00 || packet[0] == 0xff {
// first packet in a query response is OK or ERR
// this indicates either a successful query with no rows at all or a failed query
let ok = packet.ok()?;
let v = Either::Left(ok.affected_rows);
yield v;
if ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
// more result sets exist, continue to the next one
continue;
}
self.stream.busy = false;
return;
}
// otherwise, this first packet is the start of the result-set metadata,
self.recv_result_metadata(packet).await?;
// finally, there will be none or many result-rows
loop {
let packet = self.stream.recv_packet().await?;
if packet[0] == 0xfe && packet.len() < 9 {
let eof = packet.eof(self.stream.capabilities)?;
let v = Either::Left(0);
yield v;
if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
// more result sets exist, continue to the next one
break;
}
self.stream.busy = false;
return;
}
let row = match format {
MySqlValueFormat::Binary => packet.decode_with::<BinaryRow, _>(&self.scratch_row_columns)?.0,
MySqlValueFormat::Text => packet.decode_with::<TextRow, _>(&self.scratch_row_columns)?.0,
};
let v = Either::Right(MySqlRow {
row,
format,
columns: Arc::clone(&self.scratch_row_columns),
column_names: Arc::clone(&self.scratch_row_column_names),
});
yield v;
}
}
})
}
}
impl<'c> Executor<'c> for &'c mut MySqlConnection {
type Database = MySql;
fn fetch_many<'q: 'c, E>(
self,
mut query: E,
) -> BoxStream<'c, Result<Either<u64, MySqlRow>, Error>>
where
E: Execute<'q, Self::Database>,
{
let s = query.query();
let arguments = query.take_arguments();
Box::pin(try_stream! {
let s = self.run(s, arguments).await?;
pin_mut!(s);
while let Some(v) = s.try_next().await? {
yield v;
}
})
}
fn fetch_optional<'q: 'c, E>(self, query: E) -> BoxFuture<'c, Result<Option<MySqlRow>, Error>>
where
E: Execute<'q, Self::Database>,
{
let mut s = self.fetch_many(query);
Box::pin(async move {
while let Some(v) = s.try_next().await? {
if let Either::Right(r) = v {
return Ok(Some(r));
}
}
Ok(None)
})
}
#[doc(hidden)]
fn describe<'q: 'c, E>(self, query: E) -> BoxFuture<'c, Result<Describe<MySql>, Error>>
where
E: Execute<'q, Self::Database>,
{
let query = query.query();
Box::pin(async move {
self.stream.send_packet(Prepare { query }).await?;
let ok: PrepareOk = self.stream.recv().await?;
let mut params = Vec::with_capacity(ok.params as usize);
let mut columns = Vec::with_capacity(ok.columns as usize);
if ok.params > 0 {
for _ in 0..ok.params {
let def: ColumnDefinition = self.stream.recv().await?;
params.push(MySqlTypeInfo::from_column(&def));
}
self.stream.maybe_recv_eof().await?;
}
// the column definitions are berefit the type information from the
// to-be-bound parameters; we will receive the output column definitions
// once more on execute so we wait for that
if ok.columns > 0 {
for _ in 0..(ok.columns as usize) {
let def: ColumnDefinition = self.stream.recv().await?;
let ty = MySqlTypeInfo::from_column(&def);
columns.push(Column {
name: def.name()?.to_owned(),
type_info: ty,
not_null: Some(def.flags.contains(ColumnFlags::NOT_NULL)),
})
}
self.stream.maybe_recv_eof().await?;
}
Ok(Describe { params, columns })
})
}
}

View File

@@ -0,0 +1,116 @@
use std::fmt::{self, Debug, Formatter};
use std::net::Shutdown;
use std::sync::Arc;
use futures_core::future::BoxFuture;
use hashbrown::HashMap;
use crate::connection::{Connect, Connection};
use crate::error::Error;
use crate::executor::Executor;
use crate::ext::ustr::UStr;
use crate::mysql::protocol::text::{Ping, Quit};
use crate::mysql::row::MySqlColumn;
use crate::mysql::{MySql, MySqlConnectOptions};
mod auth;
mod establish;
mod executor;
mod stream;
mod tls;
pub(crate) use stream::MySqlStream;
const COLLATE_UTF8MB4_UNICODE_CI: u8 = 224;
const MAX_PACKET_SIZE: u32 = 1024;
/// A connection to a MySQL database.
pub struct MySqlConnection {
// underlying TCP stream,
// wrapped in a potentially TLS stream,
// wrapped in a buffered stream
stream: MySqlStream,
// cache by query string to the statement id
cache_statement: HashMap<String, u32>,
// working memory for the active row's column information
// this allows us to re-use these allocations unless the user is persisting the
// Row type past a stream iteration (clone-on-write)
scratch_row_columns: Arc<Vec<MySqlColumn>>,
scratch_row_column_names: Arc<HashMap<UStr, usize>>,
}
impl Debug for MySqlConnection {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("MySqlConnection").finish()
}
}
impl Connection for MySqlConnection {
type Database = MySql;
fn close(mut self) -> BoxFuture<'static, Result<(), Error>> {
Box::pin(async move {
self.stream.send_packet(Quit).await?;
self.stream.shutdown(Shutdown::Both)?;
Ok(())
})
}
fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> {
Box::pin(async move {
self.stream.wait_until_ready().await?;
self.stream.send_packet(Ping).await?;
self.stream.recv_ok().await?;
Ok(())
})
}
}
impl Connect for MySqlConnection {
type Options = MySqlConnectOptions;
#[inline]
fn connect_with(options: &Self::Options) -> BoxFuture<'_, Result<Self, Error>> {
Box::pin(async move {
let mut conn = MySqlConnection::establish(options).await?;
// After the connection is established, we initialize by configuring a few
// connection parameters
// https://mariadb.com/kb/en/sql-mode/
// PIPES_AS_CONCAT - Allows using the pipe character (ASCII 124) as string concatenation operator.
// This means that "A" || "B" can be used in place of CONCAT("A", "B").
// NO_ENGINE_SUBSTITUTION - If not set, if the available storage engine specified by a CREATE TABLE is
// not available, a warning is given and the default storage
// engine is used instead.
// NO_ZERO_DATE - Don't allow '0000-00-00'. This is invalid in Rust.
// NO_ZERO_IN_DATE - Don't allow 'YYYY-00-00'. This is invalid in Rust.
// --
// Setting the time zone allows us to assume that the output
// from a TIMESTAMP field is UTC
// --
// https://mathiasbynens.be/notes/mysql-utf8mb4
conn.execute(r#"
SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE'));
SET time_zone = '+00:00';
SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci;
"#).await?;
Ok(conn)
})
}
}

View File

@@ -0,0 +1,150 @@
use std::ops::{Deref, DerefMut};
use bytes::{Buf, Bytes};
use sqlx_rt::TcpStream;
use crate::error::Error;
use crate::io::{BufStream, Decode, Encode};
use crate::mysql::protocol::response::{EofPacket, ErrPacket, OkPacket};
use crate::mysql::protocol::{Capabilities, Packet};
use crate::mysql::{MySqlConnectOptions, MySqlDatabaseError};
use crate::net::MaybeTlsStream;
pub struct MySqlStream {
stream: BufStream<MaybeTlsStream<TcpStream>>,
pub(super) capabilities: Capabilities,
pub(super) sequence_id: u8,
pub(crate) busy: bool,
}
impl MySqlStream {
pub(super) async fn connect(options: &MySqlConnectOptions) -> Result<Self, Error> {
let stream = TcpStream::connect((&*options.host, options.port)).await?;
let mut capabilities = Capabilities::PROTOCOL_41
| Capabilities::IGNORE_SPACE
| Capabilities::DEPRECATE_EOF
| Capabilities::FOUND_ROWS
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::PS_MULTI_RESULTS
| Capabilities::SSL;
if options.database.is_some() {
capabilities |= Capabilities::CONNECT_WITH_DB;
}
Ok(Self {
busy: false,
capabilities,
sequence_id: 0,
stream: BufStream::new(MaybeTlsStream::Raw(stream)),
})
}
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
if self.busy {
loop {
let packet = self.recv_packet().await?;
match packet[0] {
0xfe if packet.len() < 9 => {
// OK or EOF packet
self.busy = false;
break;
}
_ => {
// Something else; skip
}
}
}
}
Ok(())
}
pub(crate) async fn send_packet<'en, T>(&mut self, payload: T) -> Result<(), Error>
where
T: Encode<'en, Capabilities>,
{
self.sequence_id = 0;
self.write_packet(payload);
self.flush().await
}
pub(crate) fn write_packet<'en, T>(&mut self, payload: T)
where
T: Encode<'en, Capabilities>,
{
self.stream
.write_with(Packet(payload), (self.capabilities, &mut self.sequence_id));
}
// receive the next packet from the database server
// may block (async) on more data from the server
pub(crate) async fn recv_packet(&mut self) -> Result<Packet<Bytes>, Error> {
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html
// https://mariadb.com/kb/en/library/0-packet/#standard-packet
let mut header: Bytes = self.stream.read(4).await?;
let packet_size = header.get_uint_le(3) as usize;
let sequence_id = header.get_u8();
self.sequence_id = sequence_id.wrapping_add(1);
let payload: Bytes = self.stream.read(packet_size).await?;
// TODO: packet compression
// TODO: packet joining
if payload[0] == 0xff {
self.busy = false;
// instead of letting this packet be looked at everywhere, we check here
// and emit a proper Error
return Err(
MySqlDatabaseError(ErrPacket::decode_with(payload, self.capabilities)?).into(),
);
}
Ok(Packet(payload))
}
pub(crate) async fn recv<'de, T>(&mut self) -> Result<T, Error>
where
T: Decode<'de, Capabilities>,
{
self.recv_packet().await?.decode_with(self.capabilities)
}
pub(crate) async fn recv_ok(&mut self) -> Result<OkPacket, Error> {
self.recv_packet().await?.ok()
}
pub(crate) async fn maybe_recv_eof(&mut self) -> Result<Option<EofPacket>, Error> {
if self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
Ok(None)
} else {
self.recv().await.map(Some)
}
}
}
impl Deref for MySqlStream {
type Target = BufStream<MaybeTlsStream<TcpStream>>;
fn deref(&self) -> &Self::Target {
&self.stream
}
}
impl DerefMut for MySqlStream {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.stream
}
}

View File

@@ -0,0 +1,79 @@
use sqlx_rt::{
fs,
native_tls::{Certificate, TlsConnector},
};
use crate::error::Error;
use crate::mysql::connection::MySqlStream;
use crate::mysql::protocol::connect::SslRequest;
use crate::mysql::protocol::Capabilities;
use crate::mysql::{MySqlConnectOptions, MySqlSslMode};
pub(super) async fn maybe_upgrade(
stream: &mut MySqlStream,
options: &MySqlConnectOptions,
) -> Result<(), Error> {
// https://www.postgresql.org/docs/12/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS
match options.ssl_mode {
MySqlSslMode::Disabled => {}
MySqlSslMode::Preferred => {
// try upgrade, but its okay if we fail
upgrade(stream, options).await?;
}
MySqlSslMode::Required | MySqlSslMode::VerifyIdentity | MySqlSslMode::VerifyCa => {
if !upgrade(stream, options).await? {
// upgrade failed, die
return Err(Error::Tls("server does not support TLS".into()));
}
}
}
Ok(())
}
async fn upgrade(stream: &mut MySqlStream, options: &MySqlConnectOptions) -> Result<bool, Error> {
if !stream.capabilities.contains(Capabilities::SSL) {
// server does not support TLS
return Ok(false);
}
stream.write_packet(SslRequest {
max_packet_size: super::MAX_PACKET_SIZE,
char_set: super::COLLATE_UTF8MB4_UNICODE_CI,
});
stream.flush().await?;
// FIXME: de-duplicate with postgres/connection/tls.rs
let accept_invalid_certs = !matches!(
options.ssl_mode,
MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
);
let mut builder = TlsConnector::builder();
builder
.danger_accept_invalid_certs(accept_invalid_certs)
.danger_accept_invalid_hostnames(!matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity));
if !accept_invalid_certs {
if let Some(ca) = &options.ssl_ca {
let data = fs::read(ca).await?;
let cert = Certificate::from_pem(&data).map_err(Error::tls)?;
builder.add_root_certificate(cert);
}
}
#[cfg(not(feature = "runtime-async-std"))]
let connector = builder.build().map_err(Error::tls)?;
#[cfg(feature = "runtime-async-std")]
let connector = builder;
stream.upgrade(&options.host, connector.into()).await?;
Ok(true)
}

View File

@@ -1,164 +0,0 @@
use std::collections::HashMap;
use std::sync::Arc;
use futures_core::future::BoxFuture;
use crate::connection::ConnectionSource;
use crate::cursor::Cursor;
use crate::executor::Execute;
use crate::mysql::protocol::{ColumnCount, ColumnDefinition, Row, Status};
use crate::mysql::{MySql, MySqlArguments, MySqlConnection, MySqlRow, MySqlTypeInfo};
use crate::pool::Pool;
pub struct MySqlCursor<'c, 'q> {
source: ConnectionSource<'c, MySqlConnection>,
query: Option<(&'q str, Option<MySqlArguments>)>,
column_names: Arc<HashMap<Box<str>, u16>>,
column_types: Vec<MySqlTypeInfo>,
binary: bool,
}
impl crate::cursor::private::Sealed for MySqlCursor<'_, '_> {}
impl<'c, 'q> Cursor<'c, 'q> for MySqlCursor<'c, 'q> {
type Database = MySql;
#[doc(hidden)]
fn from_pool<E>(pool: &Pool<MySqlConnection>, query: E) -> Self
where
Self: Sized,
E: Execute<'q, MySql>,
{
Self {
source: ConnectionSource::Pool(pool.clone()),
column_names: Arc::default(),
column_types: Vec::new(),
binary: true,
query: Some(query.into_parts()),
}
}
#[doc(hidden)]
fn from_connection<E>(conn: &'c mut MySqlConnection, query: E) -> Self
where
Self: Sized,
E: Execute<'q, MySql>,
{
Self {
source: ConnectionSource::ConnectionRef(conn),
column_names: Arc::default(),
column_types: Vec::new(),
binary: true,
query: Some(query.into_parts()),
}
}
fn next(&mut self) -> BoxFuture<crate::Result<Option<MySqlRow<'_>>>> {
Box::pin(next(self))
}
}
async fn next<'a, 'c: 'a, 'q: 'a>(
cursor: &'a mut MySqlCursor<'c, 'q>,
) -> crate::Result<Option<MySqlRow<'a>>> {
let mut conn = cursor.source.resolve().await?;
// The first time [next] is called we need to actually execute our
// contained query. We guard against this happening on _all_ next calls
// by using [Option::take] which replaces the potential value in the Option with `None
let mut initial = if let Some((query, arguments)) = cursor.query.take() {
let statement = conn.run(query, arguments).await?;
// No statement ID = TEXT mode
cursor.binary = statement.is_some();
true
} else {
false
};
loop {
let packet_id = conn.stream.receive().await?[0];
match packet_id {
// OK or EOF packet
0x00 | 0xFE
if conn.stream.packet().len() < 0xFF_FF_FF && (packet_id != 0x00 || initial) =>
{
let status = if let Some(eof) = conn.stream.maybe_handle_eof()? {
eof.status
} else {
conn.stream.handle_ok()?.status
};
if status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
// There is more to this query
initial = true;
} else {
conn.is_ready = true;
return Ok(None);
}
}
// ERR packet
0xFF => {
conn.is_ready = true;
return conn.stream.handle_err();
}
_ if initial => {
// At the start of the results we expect to see a
// COLUMN_COUNT followed by N COLUMN_DEF
let cc = ColumnCount::read(conn.stream.packet())?;
// We use these definitions to get the actual column types that is critical
// in parsing the rows coming back soon
cursor.column_types.clear();
cursor.column_types.reserve(cc.columns as usize);
let mut column_names = HashMap::with_capacity(cc.columns as usize);
for i in 0..cc.columns {
let column = ColumnDefinition::read(conn.stream.receive().await?)?;
cursor
.column_types
.push(MySqlTypeInfo::from_nullable_column_def(&column));
if let Some(name) = column.name() {
column_names.insert(name.to_owned().into_boxed_str(), i as u16);
}
}
if cc.columns > 0 {
conn.stream.maybe_receive_eof().await?;
}
cursor.column_names = Arc::new(column_names);
initial = false;
}
_ if !cursor.binary || packet_id == 0x00 => {
let row = Row::read(
conn.stream.packet(),
&cursor.column_types,
&mut conn.current_row_values,
cursor.binary,
)?;
let row = MySqlRow {
row,
names: Arc::clone(&cursor.column_names),
};
return Ok(Some(row));
}
_ => {
return conn.stream.handle_unexpected();
}
}
}
}

View File

@@ -1,41 +1,31 @@
use crate::cursor::HasCursor;
use crate::database::Database;
use crate::mysql::error::MySqlError;
use crate::row::HasRow;
use crate::value::HasRawValue;
use crate::database::{Database, HasArguments, HasValueRef};
use crate::mysql::value::{MySqlValue, MySqlValueRef};
use crate::mysql::{MySqlArguments, MySqlConnection, MySqlRow, MySqlTypeInfo};
/// **MySQL** database driver.
/// PostgreSQL database driver.
#[derive(Debug)]
pub struct MySql;
impl Database for MySql {
type Connection = super::MySqlConnection;
type Connection = MySqlConnection;
type Arguments = super::MySqlArguments;
type Row = MySqlRow;
type TypeInfo = super::MySqlTypeInfo;
type TypeInfo = MySqlTypeInfo;
type TableId = Box<str>;
type RawBuffer = Vec<u8>;
type Error = MySqlError;
type Value = MySqlValue;
}
impl<'c> HasRow<'c> for MySql {
impl<'r> HasValueRef<'r> for MySql {
type Database = MySql;
type Row = super::MySqlRow<'c>;
type ValueRef = MySqlValueRef<'r>;
}
impl<'c, 'q> HasCursor<'c, 'q> for MySql {
impl HasArguments<'_> for MySql {
type Database = MySql;
type Cursor = super::MySqlCursor<'c, 'q>;
}
type Arguments = MySqlArguments;
impl<'c> HasRawValue<'c> for MySql {
type Database = MySql;
type RawValue = super::MySqlValue<'c>;
type ArgumentBuffer = Vec<u8>;
}

View File

@@ -1,63 +1,64 @@
use std::error::Error as StdError;
use std::fmt::{self, Display};
use std::error::Error;
use std::fmt::{self, Debug, Display, Formatter};
use crate::error::DatabaseError;
use crate::mysql::protocol::ErrPacket;
use crate::mysql::protocol::response::ErrPacket;
use smallvec::alloc::borrow::Cow;
#[derive(Debug)]
pub struct MySqlError(pub(super) ErrPacket);
/// An error returned from the MySQL database.
pub struct MySqlDatabaseError(pub(super) ErrPacket);
impl Display for MySqlError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.pad(self.message())
}
}
impl DatabaseError for MySqlError {
fn message(&self) -> &str {
&*self.0.error_message
}
fn code(&self) -> Option<&str> {
impl MySqlDatabaseError {
/// The [SQLSTATE](https://dev.mysql.com/doc/refman/8.0/en/server-error-reference.html) code for this error.
pub fn code(&self) -> Option<&str> {
self.0.sql_state.as_deref()
}
fn as_ref_err(&self) -> &(dyn StdError + Send + Sync + 'static) {
self
/// The [number](https://dev.mysql.com/doc/refman/8.0/en/server-error-reference.html)
/// for this error.
///
/// MySQL tends to use SQLSTATE as a general error category, and the error number as a more
/// granular indication of the error.
pub fn number(&self) -> u16 {
self.0.error_code
}
fn as_mut_err(&mut self) -> &mut (dyn StdError + Send + Sync + 'static) {
self
}
fn into_box_err(self: Box<Self>) -> Box<dyn StdError + Send + Sync + 'static> {
self
/// The human-readable error message.
pub fn message(&self) -> &str {
&self.0.error_message
}
}
impl StdError for MySqlError {}
impl From<MySqlError> for crate::Error {
fn from(err: MySqlError) -> Self {
crate::Error::Database(Box::new(err))
impl Debug for MySqlDatabaseError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("MySqlDatabaseError")
.field("code", &self.code())
.field("number", &self.number())
.field("message", &self.message())
.finish()
}
}
#[test]
fn test_error_downcasting() {
let error = MySqlError(ErrPacket {
error_code: 0xABCD,
sql_state: None,
error_message: "".into(),
});
let error = crate::Error::from(error);
let db_err = match error {
crate::Error::Database(db_err) => db_err,
e => panic!("expected crate::Error::Database, got {:?}", e),
};
assert_eq!(db_err.downcast_ref::<MySqlError>().0.error_code, 0xABCD);
assert_eq!(db_err.downcast::<MySqlError>().0.error_code, 0xABCD);
impl Display for MySqlDatabaseError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
if let Some(code) = &self.code() {
write!(f, "{} ({}): {}", self.number(), code, self.message())
} else {
write!(f, "{}: {}", self.number(), self.message())
}
}
}
impl Error for MySqlDatabaseError {}
impl DatabaseError for MySqlDatabaseError {
#[inline]
fn message(&self) -> &str {
self.message()
}
#[inline]
fn code(&self) -> Option<Cow<str>> {
self.code().map(Cow::Borrowed)
}
}

View File

@@ -1,223 +0,0 @@
use futures_core::future::BoxFuture;
use crate::cursor::Cursor;
use crate::describe::{Column, Describe};
use crate::executor::{Execute, Executor, RefExecutor};
use crate::mysql::protocol::{
self, ColumnDefinition, ComQuery, ComStmtExecute, ComStmtPrepare, ComStmtPrepareOk, FieldFlags,
Status,
};
use crate::mysql::{MySql, MySqlArguments, MySqlCursor, MySqlTypeInfo};
impl super::MySqlConnection {
// Creates a prepared statement for the passed query string
async fn prepare(&mut self, query: &str) -> crate::Result<ComStmtPrepareOk> {
// https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_com_stmt_prepare.html
self.stream.send(ComStmtPrepare { query }, true).await?;
// Should receive a COM_STMT_PREPARE_OK or ERR_PACKET
let packet = self.stream.receive().await?;
if packet[0] == 0xFF {
return self.stream.handle_err();
}
ComStmtPrepareOk::read(packet)
}
async fn drop_column_defs(&mut self, count: usize) -> crate::Result<()> {
for _ in 0..count {
let _column = ColumnDefinition::read(self.stream.receive().await?)?;
}
if count > 0 {
self.stream.maybe_receive_eof().await?;
}
Ok(())
}
// Gets a cached prepared statement ID _or_ prepares the statement if not in the cache
// At the end we should have [cache_statement] and [cache_statement_columns] filled
async fn get_or_prepare(&mut self, query: &str) -> crate::Result<u32> {
if let Some(&id) = self.cache_statement.get(query) {
Ok(id)
} else {
let stmt = self.prepare(query).await?;
self.cache_statement.insert(query.into(), stmt.statement_id);
// COM_STMT_PREPARE returns the input columns
// We make no use of that data, so cycle through and drop them
self.drop_column_defs(stmt.params as usize).await?;
// COM_STMT_PREPARE next returns the output columns
// We just drop these as we get these when we execute the query
self.drop_column_defs(stmt.columns as usize).await?;
Ok(stmt.statement_id)
}
}
pub(crate) async fn run(
&mut self,
query: &str,
arguments: Option<MySqlArguments>,
) -> crate::Result<Option<u32>> {
self.stream.wait_until_ready().await?;
self.stream.is_ready = false;
if let Some(arguments) = arguments {
let statement_id = self.get_or_prepare(query).await?;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_com_stmt_execute.html
self.stream
.send(
ComStmtExecute {
cursor: protocol::Cursor::NO_CURSOR,
statement_id,
params: &arguments.params,
null_bitmap: &arguments.null_bitmap,
param_types: &arguments.param_types,
},
true,
)
.await?;
Ok(Some(statement_id))
} else {
// https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_com_query.html
self.stream.send(ComQuery { query }, true).await?;
Ok(None)
}
}
async fn affected_rows(&mut self) -> crate::Result<u64> {
let mut rows = 0;
loop {
let id = self.stream.receive().await?[0];
match id {
0x00 | 0xFE if self.stream.packet().len() < 0xFF_FF_FF => {
// ResultSet row can begin with 0xfe byte (when using text protocol
// with a field length > 0xffffff)
let status = if let Some(eof) = self.stream.maybe_handle_eof()? {
eof.status
} else {
let ok = self.stream.handle_ok()?;
rows += ok.affected_rows;
ok.status
};
if !status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
self.is_ready = true;
break;
}
}
0xFF => {
return self.stream.handle_err();
}
_ => {}
}
}
Ok(rows)
}
// method is not named describe to work around an intellijrust bug
// otherwise it marks someone trying to describe the connection as "method is private"
async fn do_describe(&mut self, query: &str) -> crate::Result<Describe<MySql>> {
self.stream.wait_until_ready().await?;
let stmt = self.prepare(query).await?;
let mut param_types = Vec::with_capacity(stmt.params as usize);
let mut result_columns = Vec::with_capacity(stmt.columns as usize);
for _ in 0..stmt.params {
let param = ColumnDefinition::read(self.stream.receive().await?)?;
param_types.push(MySqlTypeInfo::from_column_def(&param));
}
if stmt.params > 0 {
self.stream.maybe_receive_eof().await?;
}
for _ in 0..stmt.columns {
let column = ColumnDefinition::read(self.stream.receive().await?)?;
result_columns.push(Column::<MySql> {
type_info: MySqlTypeInfo::from_column_def(&column),
name: column.column_alias.or(column.column),
table_id: column.table_alias.or(column.table),
// TODO(@abonander): Should this be None in some cases?
non_null: Some(column.flags.contains(FieldFlags::NOT_NULL)),
});
}
if stmt.columns > 0 {
self.stream.maybe_receive_eof().await?;
}
Ok(Describe {
param_types: param_types.into_boxed_slice(),
result_columns: result_columns.into_boxed_slice(),
})
}
}
impl Executor for super::MySqlConnection {
type Database = MySql;
fn execute<'e, 'q: 'e, 'c: 'e, E: 'e>(
&'c mut self,
query: E,
) -> BoxFuture<'e, crate::Result<u64>>
where
E: Execute<'q, Self::Database>,
{
log_execution!(query, {
Box::pin(async move {
let (query, arguments) = query.into_parts();
self.run(query, arguments).await?;
self.affected_rows().await
})
})
}
fn fetch<'q, E>(&mut self, query: E) -> MySqlCursor<'_, 'q>
where
E: Execute<'q, Self::Database>,
{
log_execution!(query, { MySqlCursor::from_connection(self, query) })
}
#[doc(hidden)]
fn describe<'e, 'q, E: 'e>(
&'e mut self,
query: E,
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>>
where
E: Execute<'q, Self::Database>,
{
Box::pin(async move { self.do_describe(query.into_parts().0).await })
}
}
impl<'c> RefExecutor<'c> for &'c mut super::MySqlConnection {
type Database = MySql;
fn fetch_by_ref<'q, E>(self, query: E) -> MySqlCursor<'c, 'q>
where
E: Execute<'q, Self::Database>,
{
log_execution!(query, { MySqlCursor::from_connection(self, query) })
}
}

View File

@@ -0,0 +1,40 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::BufExt;
pub trait MySqlBufExt: Buf {
// Read a length-encoded integer.
// NOTE: 0xfb or NULL is only returned for binary value encoding to indicate NULL.
// NOTE: 0xff is only returned during a result set to indicate ERR.
// <https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger>
fn get_uint_lenenc(&mut self) -> u64;
// Read a length-encoded string.
fn get_str_lenenc(&mut self) -> Result<String, Error>;
// Read a length-encoded byte sequence.
fn get_bytes_lenenc(&mut self) -> Bytes;
}
impl MySqlBufExt for Bytes {
fn get_uint_lenenc(&mut self) -> u64 {
match self.get_u8() {
0xfc => u64::from(self.get_u16_le()),
0xfd => u64::from(self.get_uint_le(3)),
0xfe => u64::from(self.get_u64_le()),
v => u64::from(v),
}
}
fn get_str_lenenc(&mut self) -> Result<String, Error> {
let size = self.get_uint_lenenc();
self.get_str(size as usize)
}
fn get_bytes_lenenc(&mut self) -> Bytes {
let size = self.get_uint_lenenc();
self.split_to(size as usize)
}
}

View File

@@ -1,38 +0,0 @@
use std::io;
use byteorder::ByteOrder;
use crate::io::Buf;
pub trait BufExt {
fn get_uint_lenenc<T: ByteOrder>(&mut self) -> io::Result<Option<u64>>;
fn get_str_lenenc<T: ByteOrder>(&mut self) -> io::Result<Option<&str>>;
fn get_bytes_lenenc<T: ByteOrder>(&mut self) -> io::Result<Option<&[u8]>>;
}
impl BufExt for &'_ [u8] {
fn get_uint_lenenc<T: ByteOrder>(&mut self) -> io::Result<Option<u64>> {
Ok(match self.get_u8()? {
0xFB => None,
0xFC => Some(u64::from(self.get_u16::<T>()?)),
0xFD => Some(u64::from(self.get_u24::<T>()?)),
0xFE => Some(self.get_u64::<T>()?),
value => Some(u64::from(value)),
})
}
fn get_str_lenenc<T: ByteOrder>(&mut self) -> io::Result<Option<&str>> {
self.get_uint_lenenc::<T>()?
.map(move |len| self.get_str(len as usize))
.transpose()
}
fn get_bytes_lenenc<T: ByteOrder>(&mut self) -> io::Result<Option<&[u8]>> {
self.get_uint_lenenc::<T>()?
.map(move |len| self.get_bytes(len as usize))
.transpose()
}
}

View File

@@ -0,0 +1,134 @@
use bytes::BufMut;
pub trait MySqlBufMutExt: BufMut {
fn put_uint_lenenc(&mut self, v: u64);
fn put_str_lenenc(&mut self, v: &str);
fn put_bytes_lenenc(&mut self, v: &[u8]);
}
impl MySqlBufMutExt for Vec<u8> {
fn put_uint_lenenc(&mut self, v: u64) {
// https://dev.mysql.com/doc/internals/en/integer.html
// https://mariadb.com/kb/en/library/protocol-data-types/#length-encoded-integers
if v < 251 {
self.push(v as u8);
} else if v < 0x1_00_00 {
self.push(0xfc);
self.extend(&(v as u16).to_le_bytes());
} else if v < 0x1_00_00_00 {
self.push(0xfd);
self.extend(&(v as u32).to_le_bytes()[..3]);
} else {
self.push(0xfe);
self.extend(&v.to_le_bytes());
}
}
fn put_str_lenenc(&mut self, v: &str) {
self.put_bytes_lenenc(v.as_bytes());
}
fn put_bytes_lenenc(&mut self, v: &[u8]) {
self.put_uint_lenenc(v.len() as u64);
self.extend(v);
}
}
#[test]
fn test_encodes_int_lenenc_u8() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFA as u64);
assert_eq!(&buf[..], b"\xFA");
}
#[test]
fn test_encodes_int_lenenc_u16() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(std::u16::MAX as u64);
assert_eq!(&buf[..], b"\xFC\xFF\xFF");
}
#[test]
fn test_encodes_int_lenenc_u24() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFF_FF_FF as u64);
assert_eq!(&buf[..], b"\xFD\xFF\xFF\xFF");
}
#[test]
fn test_encodes_int_lenenc_u64() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(std::u64::MAX);
assert_eq!(&buf[..], b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
}
#[test]
fn test_encodes_int_lenenc_fb() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFB as u64);
assert_eq!(&buf[..], b"\xFC\xFB\x00");
}
#[test]
fn test_encodes_int_lenenc_fc() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFC as u64);
assert_eq!(&buf[..], b"\xFC\xFC\x00");
}
#[test]
fn test_encodes_int_lenenc_fd() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFD as u64);
assert_eq!(&buf[..], b"\xFC\xFD\x00");
}
#[test]
fn test_encodes_int_lenenc_fe() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFE as u64);
assert_eq!(&buf[..], b"\xFC\xFE\x00");
}
#[test]
fn test_encodes_int_lenenc_ff() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc(0xFF as u64);
assert_eq!(&buf[..], b"\xFC\xFF\x00");
}
#[test]
fn test_encodes_string_lenenc() {
let mut buf = Vec::with_capacity(1024);
buf.put_str_lenenc("random_string");
assert_eq!(&buf[..], b"\x0Drandom_string");
}
#[test]
fn test_encodes_string_null() {
let mut buf = Vec::with_capacity(1024);
buf.put_str_nul("random_string");
assert_eq!(&buf[..], b"random_string\0");
}
#[test]
fn test_encodes_byte_lenenc() {
let mut buf = Vec::with_capacity(1024);
buf.put_bytes_lenenc(b"random_string");
assert_eq!(&buf[..], b"\x0Drandom_string");
}

View File

@@ -1,217 +0,0 @@
use std::{u16, u32, u64, u8};
use byteorder::ByteOrder;
use crate::io::BufMut;
pub trait BufMutExt {
fn put_uint_lenenc<T: ByteOrder, U: Into<Option<u64>>>(&mut self, val: U);
fn put_str_lenenc<T: ByteOrder>(&mut self, val: &str);
fn put_bytes_lenenc<T: ByteOrder>(&mut self, val: &[u8]);
}
impl BufMutExt for Vec<u8> {
fn put_uint_lenenc<T: ByteOrder, U: Into<Option<u64>>>(&mut self, value: U) {
if let Some(value) = value.into() {
// https://mariadb.com/kb/en/library/protocol-data-types/#length-encoded-integers
if value > 0xFF_FF_FF {
// Integer value is encoded in the next 8 bytes (9 bytes total)
self.push(0xFE);
self.put_u64::<T>(value);
} else if value > u64::from(u16::MAX) {
// Integer value is encoded in the next 3 bytes (4 bytes total)
self.push(0xFD);
self.put_u24::<T>(value as u32);
} else if value > u64::from(u8::MAX) {
// Integer value is encoded in the next 2 bytes (3 bytes total)
self.push(0xFC);
self.put_u16::<T>(value as u16);
} else {
match value {
// If the value is of size u8 and one of the key bytes used in length encoding
// we must put that single byte as a u16
0xFB | 0xFC | 0xFD | 0xFE | 0xFF => {
self.push(0xFC);
self.put_u16::<T>(value as u16);
}
_ => {
self.push(value as u8);
}
}
}
} else {
self.push(0xFB);
}
}
fn put_str_lenenc<T: ByteOrder>(&mut self, val: &str) {
self.put_uint_lenenc::<T, _>(val.len() as u64);
self.extend_from_slice(val.as_bytes());
}
fn put_bytes_lenenc<T: ByteOrder>(&mut self, val: &[u8]) {
self.put_uint_lenenc::<T, _>(val.len() as u64);
self.extend_from_slice(val);
}
}
#[cfg(test)]
mod tests {
use super::{BufMut, BufMutExt};
use byteorder::LittleEndian;
#[test]
fn it_encodes_int_lenenc_none() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(None);
assert_eq!(&buf[..], b"\xFB");
}
#[test]
fn it_encodes_int_lenenc_u8() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(0xFA as u64);
assert_eq!(&buf[..], b"\xFA");
}
#[test]
fn it_encodes_int_lenenc_u16() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(std::u16::MAX as u64);
assert_eq!(&buf[..], b"\xFC\xFF\xFF");
}
#[test]
fn it_encodes_int_lenenc_u24() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(0xFF_FF_FF as u64);
assert_eq!(&buf[..], b"\xFD\xFF\xFF\xFF");
}
#[test]
fn it_encodes_int_lenenc_u64() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(std::u64::MAX);
assert_eq!(&buf[..], b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
}
#[test]
fn it_encodes_int_lenenc_fb() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(0xFB as u64);
assert_eq!(&buf[..], b"\xFC\xFB\x00");
}
#[test]
fn it_encodes_int_lenenc_fc() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(0xFC as u64);
assert_eq!(&buf[..], b"\xFC\xFC\x00");
}
#[test]
fn it_encodes_int_lenenc_fd() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(0xFD as u64);
assert_eq!(&buf[..], b"\xFC\xFD\x00");
}
#[test]
fn it_encodes_int_lenenc_fe() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(0xFE as u64);
assert_eq!(&buf[..], b"\xFC\xFE\x00");
}
#[test]
fn it_encodes_int_lenenc_ff() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(0xFF as u64);
assert_eq!(&buf[..], b"\xFC\xFF\x00");
}
#[test]
fn it_encodes_int_u64() {
let mut buf = Vec::with_capacity(1024);
buf.put_u64::<LittleEndian>(std::u64::MAX);
assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
}
#[test]
fn it_encodes_int_u32() {
let mut buf = Vec::with_capacity(1024);
buf.put_u32::<LittleEndian>(std::u32::MAX);
assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF");
}
#[test]
fn it_encodes_int_u24() {
let mut buf = Vec::with_capacity(1024);
buf.put_u24::<LittleEndian>(0xFF_FF_FF as u32);
assert_eq!(&buf[..], b"\xFF\xFF\xFF");
}
#[test]
fn it_encodes_int_u16() {
let mut buf = Vec::with_capacity(1024);
buf.put_u16::<LittleEndian>(std::u16::MAX);
assert_eq!(&buf[..], b"\xFF\xFF");
}
#[test]
fn it_encodes_int_u8() {
let mut buf = Vec::with_capacity(1024);
buf.put_u8(std::u8::MAX);
assert_eq!(&buf[..], b"\xFF");
}
#[test]
fn it_encodes_string_lenenc() {
let mut buf = Vec::with_capacity(1024);
buf.put_str_lenenc::<LittleEndian>("random_string");
assert_eq!(&buf[..], b"\x0Drandom_string");
}
#[test]
fn it_encodes_string_fix() {
let mut buf = Vec::with_capacity(1024);
buf.put_str("random_string");
assert_eq!(&buf[..], b"random_string");
}
#[test]
fn it_encodes_string_null() {
let mut buf = Vec::with_capacity(1024);
buf.put_str_nul("random_string");
assert_eq!(&buf[..], b"random_string\0");
}
#[test]
fn it_encodes_byte_lenenc() {
let mut buf = Vec::with_capacity(1024);
buf.put_bytes_lenenc::<LittleEndian>(b"random_string");
assert_eq!(&buf[..], b"\x0Drandom_string");
}
}

View File

@@ -1,5 +1,5 @@
mod buf_ext;
mod buf_mut_ext;
mod buf;
mod buf_mut;
pub use buf_ext::BufExt;
pub use buf_mut_ext::BufMutExt;
pub use buf::MySqlBufExt;
pub use buf_mut::MySqlBufMutExt;

View File

@@ -1,35 +1,25 @@
//! **MySQL** database and connection types.
pub use arguments::MySqlArguments;
pub use connection::MySqlConnection;
pub use cursor::MySqlCursor;
pub use database::MySql;
pub use error::MySqlError;
pub use row::MySqlRow;
pub use type_info::MySqlTypeInfo;
pub use value::{MySqlData, MySqlValue};
//! **MySQL** database driver.
mod arguments;
mod connection;
mod cursor;
mod database;
mod error;
mod executor;
mod io;
mod options;
mod protocol;
mod row;
mod rsa;
mod stream;
mod tls;
mod type_info;
pub mod types;
mod util;
mod value;
/// An alias for [`crate::pool::Pool`], specialized for **MySQL**.
#[cfg_attr(docsrs, doc(cfg(feature = "mysql")))]
pub type MySqlPool = crate::pool::Pool<MySqlConnection>;
pub use arguments::MySqlArguments;
pub use connection::MySqlConnection;
pub use database::MySql;
pub use error::MySqlDatabaseError;
pub use options::{MySqlConnectOptions, MySqlSslMode};
pub use row::MySqlRow;
pub use type_info::MySqlTypeInfo;
pub use value::{MySqlValue, MySqlValueFormat, MySqlValueRef};
make_query_as!(MySqlQueryAs, MySql, MySqlRow);
impl_map_row_for_row!(MySql, MySqlRow);
impl_from_row_for_tuples!(MySql, MySqlRow);
/// An alias for [`Pool`][crate::pool::Pool], specialized for MySQL.
pub type MySqlPool = crate::pool::Pool<MySqlConnection>;

View File

@@ -0,0 +1,232 @@
use std::path::{Path, PathBuf};
use std::str::FromStr;
use url::Url;
use crate::error::{BoxDynError, Error};
/// Options for controlling the desired security state of the connection to the MySQL server.
///
/// It is used by the [`ssl_mode`](MySqlConnectOptions::ssl_mode) method.
#[derive(Debug, Clone, Copy)]
pub enum MySqlSslMode {
/// Establish an unencrypted connection.
Disabled,
/// Establish an encrypted connection if the server supports encrypted connections, falling
/// back to an unencrypted connection if an encrypted connection cannot be established.
///
/// This is the default if `ssl_mode` is not specified.
Preferred,
/// Establish an encrypted connection if the server supports encrypted connections.
/// The connection attempt fails if an encrypted connection cannot be established.
Required,
/// Like `Required`, but additionally verify the server Certificate Authority (CA)
/// certificate against the configured CA certificates. The connection attempt fails
/// if no valid matching CA certificates are found.
VerifyCa,
/// Like `VerifyCa`, but additionally perform host name identity verification by
/// checking the host name the client uses for connecting to the server against the
/// identity in the certificate that the server sends to the client.
VerifyIdentity,
}
impl Default for MySqlSslMode {
fn default() -> Self {
MySqlSslMode::Preferred
}
}
impl FromStr for MySqlSslMode {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Error> {
Ok(match s {
"DISABLED" => MySqlSslMode::Disabled,
"PREFERRED" => MySqlSslMode::Preferred,
"REQUIRED" => MySqlSslMode::Required,
"VERIFY_CA" => MySqlSslMode::VerifyCa,
"VERIFY_IDENTITY" => MySqlSslMode::VerifyIdentity,
_ => {
return Err(err_protocol!("unknown SSL mode value: {:?}", s));
}
})
}
}
/// Options and flags which can be used to configure a MySQL connection.
///
/// A value of `PgConnectOptions` can be parsed from a connection URI,
/// as described by [MySQL](https://dev.mysql.com/doc/connector-j/8.0/en/connector-j-reference-jdbc-url-format.html).
///
/// The generic format of the connection URL:
///
/// ```text
/// mysql://[host][/database][?properties]
/// ```
///
/// # Example
///
/// ```rust,no_run
/// # use sqlx_core::error::Error;
/// # use sqlx_core::connection::Connect;
/// # use sqlx_core::mysql::{MySqlConnectOptions, MySqlConnection, MySqlSslMode};
/// #
/// # #[sqlx_rt::main]
/// # async fn main() -> Result<(), Error> {
/// // URI connection string
/// let conn = MySqlConnection::connect("mysql://root:password@localhost/db").await?;
///
/// // Manually-constructed options
/// let conn = MySqlConnection::connect_with(&MySqlConnectOptions::new()
/// .host("localhost")
/// .username("root")
/// .password("password")
/// .database("db")
/// ).await?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct MySqlConnectOptions {
pub(crate) host: String,
pub(crate) port: u16,
pub(crate) username: String,
pub(crate) password: Option<String>,
pub(crate) database: Option<String>,
pub(crate) ssl_mode: MySqlSslMode,
pub(crate) ssl_ca: Option<PathBuf>,
}
impl MySqlConnectOptions {
/// Creates a new, default set of options ready for configuration
pub fn new() -> Self {
Self {
port: 3306,
host: String::from("localhost"),
username: String::from("root"),
password: None,
database: None,
ssl_mode: MySqlSslMode::Preferred,
ssl_ca: None,
}
}
/// Sets the name of the host to connect to.
///
/// The default behavior when the host is not specified,
/// is to connect to localhost.
pub fn host(mut self, host: &str) -> Self {
self.host = host.to_owned();
self
}
/// Sets the port to connect to at the server host.
///
/// The default port for MySQL is `3306`.
pub fn port(mut self, port: u16) -> Self {
self.port = port;
self
}
/// Sets the username to connect as.
pub fn username(mut self, username: &str) -> Self {
self.username = username.to_owned();
self
}
/// Sets the password to connect with.
pub fn password(mut self, password: &str) -> Self {
self.password = Some(password.to_owned());
self
}
/// Sets the database name.
pub fn database(mut self, database: &str) -> Self {
self.database = Some(database.to_owned());
self
}
/// Sets whether or with what priority a secure SSL TCP/IP connection will be negotiated
/// with the server.
///
/// By default, the SSL mode is [`Preferred`](MySqlSslMode::Preferred), and the client will
/// first attempt an SSL connection but fallback to a non-SSL connection on failure.
///
/// # Example
///
/// ```rust
/// # use sqlx_core::mysql::{MySqlSslMode, MySqlConnectOptions};
/// let options = MySqlConnectOptions::new()
/// .ssl_mode(MySqlSslMode::Required);
/// ```
pub fn ssl_mode(mut self, mode: MySqlSslMode) -> Self {
self.ssl_mode = mode;
self
}
/// Sets the name of a file containing a list of trusted SSL Certificate Authorities.
///
/// # Example
///
/// ```rust
/// # use sqlx_core::mysql::{MySqlSslMode, MySqlConnectOptions};
/// let options = MySqlConnectOptions::new()
/// .ssl_mode(MySqlSslMode::VerifyCa)
/// .ssl_ca("path/to/ca.crt");
/// ```
pub fn ssl_ca(mut self, file_name: impl AsRef<Path>) -> Self {
self.ssl_ca = Some(file_name.as_ref().to_owned());
self
}
}
impl FromStr for MySqlConnectOptions {
type Err = BoxDynError;
fn from_str(s: &str) -> Result<Self, BoxDynError> {
let url: Url = s.parse()?;
let mut options = Self::new();
if let Some(host) = url.host_str() {
options = options.host(host);
}
if let Some(port) = url.port() {
options = options.port(port);
}
let username = url.username();
if !username.is_empty() {
options = options.username(username);
}
if let Some(password) = url.password() {
options = options.password(password);
}
let path = url.path().trim_start_matches('/');
if !path.is_empty() {
options = options.database(path);
}
for (key, value) in url.query_pairs().into_iter() {
match &*key {
"ssl-mode" => {
options = options.ssl_mode(value.parse()?);
}
"ssl-ca" => {
options = options.ssl_ca(&*value);
}
_ => {}
}
}
Ok(options)
}
}

View File

@@ -0,0 +1,34 @@
use std::str::FromStr;
use crate::error::Error;
#[derive(Debug, Copy, Clone)]
pub enum AuthPlugin {
MySqlNativePassword,
CachingSha2Password,
Sha256Password,
}
impl AuthPlugin {
pub(crate) fn name(self) -> &'static str {
match self {
AuthPlugin::MySqlNativePassword => "mysql_native_password",
AuthPlugin::CachingSha2Password => "caching_sha2_password",
AuthPlugin::Sha256Password => "sha256_password",
}
}
}
impl FromStr for AuthPlugin {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"mysql_native_password" => Ok(AuthPlugin::MySqlNativePassword),
"caching_sha2_password" => Ok(AuthPlugin::CachingSha2Password),
"sha256_password" => Ok(AuthPlugin::Sha256Password),
_ => Err(err_protocol!("unknown authentication plugin: {}", s)),
}
}
}

View File

@@ -1,103 +0,0 @@
use digest::{Digest, FixedOutput};
use generic_array::GenericArray;
use memchr::memchr;
use sha1::Sha1;
use sha2::Sha256;
use crate::mysql::util::xor_eq;
#[derive(Debug, PartialEq)]
pub enum AuthPlugin {
MySqlNativePassword,
CachingSha2Password,
Sha256Password,
}
impl AuthPlugin {
pub(crate) fn from_opt_str(s: Option<&str>) -> crate::Result<AuthPlugin> {
match s {
Some("mysql_native_password") | None => Ok(AuthPlugin::MySqlNativePassword),
Some("caching_sha2_password") => Ok(AuthPlugin::CachingSha2Password),
Some("sha256_password") => Ok(AuthPlugin::Sha256Password),
Some(s) => {
Err(protocol_err!("requires unimplemented authentication plugin: {}", s).into())
}
}
}
pub(crate) fn as_str(&self) -> &'static str {
match self {
AuthPlugin::MySqlNativePassword => "mysql_native_password",
AuthPlugin::CachingSha2Password => "caching_sha2_password",
AuthPlugin::Sha256Password => "sha256_password",
}
}
pub(crate) fn scramble(&self, password: &str, nonce: &[u8]) -> Vec<u8> {
match self {
AuthPlugin::MySqlNativePassword => {
// The [nonce] for mysql_native_password is (optionally) nul terminated
let end = memchr(b'\0', nonce).unwrap_or(nonce.len());
scramble_sha1(password, &nonce[..end]).to_vec()
}
AuthPlugin::CachingSha2Password => scramble_sha256(password, nonce).to_vec(),
_ => unimplemented!(),
}
}
}
fn scramble_sha1(
password: &str,
seed: &[u8],
) -> GenericArray<u8, <Sha1 as FixedOutput>::OutputSize> {
// SHA1( password ) ^ SHA1( seed + SHA1( SHA1( password ) ) )
// https://mariadb.com/kb/en/connection/#mysql_native_password-plugin
let mut ctx = Sha1::new();
ctx.input(password);
let mut pw_hash = ctx.result_reset();
ctx.input(&pw_hash);
let pw_hash_hash = ctx.result_reset();
ctx.input(seed);
ctx.input(pw_hash_hash);
let pw_seed_hash_hash = ctx.result();
xor_eq(&mut pw_hash, &pw_seed_hash_hash);
pw_hash
}
fn scramble_sha256(
password: &str,
seed: &[u8],
) -> GenericArray<u8, <Sha256 as FixedOutput>::OutputSize> {
// XOR(SHA256(password), SHA256(seed, SHA256(SHA256(password))))
// https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/#sha-2-encrypted-password
let mut ctx = Sha256::new();
ctx.input(password);
let mut pw_hash = ctx.result_reset();
ctx.input(&pw_hash);
let pw_hash_hash = ctx.result_reset();
ctx.input(seed);
ctx.input(pw_hash_hash);
let pw_seed_hash_hash = ctx.result();
xor_eq(&mut pw_hash, &pw_seed_hash_hash);
pw_hash
}

View File

@@ -1,32 +0,0 @@
use crate::io::Buf;
use crate::mysql::protocol::AuthPlugin;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_auth_switch_request.html
#[derive(Debug)]
pub(crate) struct AuthSwitch {
pub(crate) auth_plugin: AuthPlugin,
pub(crate) auth_plugin_data: Box<[u8]>,
}
impl AuthSwitch {
pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self>
where
Self: Sized,
{
let header = buf.get_u8()?;
if header != 0xFE {
return Err(protocol_err!(
"expected AUTH SWITCH (0xFE); received 0x{:X}",
header
))?;
}
let auth_plugin = AuthPlugin::from_opt_str(Some(buf.get_str_nul()?))?;
let auth_plugin_data = buf.get_bytes(buf.len())?.to_owned().into_boxed_slice();
Ok(Self {
auth_plugin_data,
auth_plugin,
})
}
}

View File

@@ -1,16 +0,0 @@
use byteorder::LittleEndian;
use crate::mysql::io::BufExt;
#[derive(Debug)]
pub struct ColumnCount {
pub columns: u64,
}
impl ColumnCount {
pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self> {
let columns = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or(0);
Ok(Self { columns })
}
}

View File

@@ -1,83 +0,0 @@
use byteorder::LittleEndian;
use crate::io::Buf;
use crate::mysql::io::BufExt;
use crate::mysql::protocol::{FieldFlags, TypeId};
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_query_response_text_resultset_column_definition.html
// https://mariadb.com/kb/en/resultset/#column-definition-packet
#[derive(Debug)]
pub struct ColumnDefinition {
pub schema: Option<Box<str>>,
pub table_alias: Option<Box<str>>,
pub table: Option<Box<str>>,
pub column_alias: Option<Box<str>>,
pub column: Option<Box<str>>,
pub char_set: u16,
pub max_size: u32,
pub type_id: TypeId,
pub flags: FieldFlags,
pub decimals: u8,
}
impl ColumnDefinition {
pub fn name(&self) -> Option<&str> {
self.column_alias.as_deref().or(self.column.as_deref())
}
}
impl ColumnDefinition {
pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self> {
// catalog : string<lenenc>
let catalog = buf.get_str_lenenc::<LittleEndian>()?;
if catalog != Some("def") {
return Err(protocol_err!(
"expected ColumnDefinition (\"def\"); received {:?}",
catalog
))?;
}
let schema = buf.get_str_lenenc::<LittleEndian>()?.map(Into::into);
let table_alias = buf.get_str_lenenc::<LittleEndian>()?.map(Into::into);
let table = buf.get_str_lenenc::<LittleEndian>()?.map(Into::into);
let column_alias = buf.get_str_lenenc::<LittleEndian>()?.map(Into::into);
let column = buf.get_str_lenenc::<LittleEndian>()?.map(Into::into);
let len_fixed_fields = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or(0);
if len_fixed_fields != 0x0c {
return Err(protocol_err!(
"expected ColumnDefinition (0x0c); received {:?}",
len_fixed_fields
))?;
}
let char_set = buf.get_u16::<LittleEndian>()?;
let max_size = buf.get_u32::<LittleEndian>()?;
let type_id = buf.get_u8()?;
let flags = buf.get_u16::<LittleEndian>()?;
let decimals = buf.get_u8()?;
Ok(Self {
schema,
table,
table_alias,
column,
column_alias,
char_set,
max_size,
type_id: TypeId(type_id),
flags: FieldFlags::from_bits_truncate(flags),
decimals,
})
}
}

View File

@@ -1,13 +0,0 @@
use crate::io::BufMut;
use crate::mysql::protocol::{Capabilities, Encode};
// https://dev.mysql.com/doc/internals/en/com-ping.html
#[derive(Debug)]
pub struct ComPing;
impl Encode for ComPing {
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
// COM_PING : int<1>
buf.put_u8(0x0e);
}
}

View File

@@ -1,18 +0,0 @@
use crate::io::BufMut;
use crate::mysql::protocol::{Capabilities, Encode};
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_query.html
#[derive(Debug)]
pub struct ComQuery<'a> {
pub query: &'a str,
}
impl Encode for ComQuery<'_> {
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
// COM_QUERY : int<1>
buf.put_u8(0x03);
// query : string<EOF>
buf.put_str(self.query);
}
}

View File

@@ -1,61 +0,0 @@
use byteorder::LittleEndian;
use crate::io::BufMut;
use crate::mysql::protocol::{Capabilities, Encode};
use crate::mysql::type_info::MySqlTypeInfo;
bitflags::bitflags! {
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/mysql__com_8h.html#a3e5e9e744ff6f7b989a604fd669977da
// https://mariadb.com/kb/en/library/com_stmt_execute/#flag
pub struct Cursor: u8 {
const NO_CURSOR = 0;
const READ_ONLY = 1;
const FOR_UPDATE = 2;
const SCROLLABLE = 4;
}
}
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_execute.html
#[derive(Debug)]
pub struct ComStmtExecute<'a> {
pub statement_id: u32,
pub cursor: Cursor,
pub params: &'a [u8],
pub null_bitmap: &'a [u8],
pub param_types: &'a [MySqlTypeInfo],
}
impl Encode for ComStmtExecute<'_> {
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
// COM_STMT_EXECUTE : int<1>
buf.put_u8(0x17);
// statement_id : int<4>
buf.put_u32::<LittleEndian>(self.statement_id);
// cursor : int<1>
buf.put_u8(self.cursor.bits());
// iterations (always 1) : int<4>
buf.put_u32::<LittleEndian>(1);
if !self.param_types.is_empty() {
// null bitmap : byte<(param_count + 7)/8>
buf.put_bytes(self.null_bitmap);
// send type to server (0 / 1) : byte<1>
buf.put_u8(1);
for ty in self.param_types {
// field type : byte<1>
buf.put_u8(ty.id.0);
// parameter flag : byte<1>
buf.put_u8(if ty.is_unsigned { 0x80 } else { 0 });
}
// byte<n> binary parameter value
buf.put_bytes(self.params);
}
}
}

View File

@@ -1,18 +0,0 @@
use crate::io::BufMut;
use crate::mysql::protocol::{Capabilities, Encode};
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_prepare.html
#[derive(Debug)]
pub struct ComStmtPrepare<'a> {
pub query: &'a str,
}
impl Encode for ComStmtPrepare<'_> {
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
// COM_STMT_PREPARE : int<1>
buf.put_u8(0x16);
// query : string<EOF>
buf.put_str(self.query);
}
}

View File

@@ -1,48 +0,0 @@
use byteorder::LittleEndian;
use crate::io::Buf;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_prepare.html#sect_protocol_com_stmt_prepare_response_ok
#[derive(Debug)]
pub(crate) struct ComStmtPrepareOk {
pub(crate) statement_id: u32,
/// Number of columns in the returned result set (or 0 if statement
/// does not return result set).
pub(crate) columns: u16,
/// Number of prepared statement parameters ('?' placeholders).
pub(crate) params: u16,
/// Number of warnings.
pub(crate) warnings: u16,
}
impl ComStmtPrepareOk {
pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self> {
let header = buf.get_u8()?;
if header != 0x00 {
return Err(protocol_err!(
"expected COM_STMT_PREPARE_OK (0x00); received 0x{:X}",
header
))?;
}
let statement_id = buf.get_u32::<LittleEndian>()?;
let columns = buf.get_u16::<LittleEndian>()?;
let params = buf.get_u16::<LittleEndian>()?;
// -not used- : string<1>
buf.advance(1);
let warnings = buf.get_u16::<LittleEndian>()?;
Ok(Self {
statement_id,
columns,
params,
warnings,
})
}
}

View File

@@ -0,0 +1,41 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::Encode;
use crate::io::{BufExt, Decode};
use crate::mysql::protocol::auth::AuthPlugin;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_auth_switch_request.html
#[derive(Debug)]
pub struct AuthSwitchRequest {
pub plugin: AuthPlugin,
pub data: Bytes,
}
impl Decode<'_> for AuthSwitchRequest {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0xfe {
return Err(err_protocol!(
"expected 0xfe (AUTH_SWITCH) but found 0x{:x}",
header
));
}
let plugin = buf.get_str_nul()?.parse()?;
let data = buf.get_bytes(buf.len());
Ok(Self { plugin, data })
}
}
#[derive(Debug)]
pub struct AuthSwitchResponse(pub Vec<u8>);
impl Encode<'_, Capabilities> for AuthSwitchResponse {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
buf.extend_from_slice(&self.0);
}
}

View File

@@ -0,0 +1,194 @@
use bytes::buf::ext::Chain;
use bytes::buf::BufExt as _;
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::{BufExt, Decode};
use crate::mysql::protocol::auth::AuthPlugin;
use crate::mysql::protocol::response::Status;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
// https://mariadb.com/kb/en/connection/#initial-handshake-packet
#[derive(Debug)]
pub(crate) struct Handshake {
pub(crate) protocol_version: u8,
pub(crate) server_version: String,
pub(crate) connection_id: u32,
pub(crate) server_capabilities: Capabilities,
pub(crate) server_default_collation: u8,
pub(crate) status: Status,
pub(crate) auth_plugin: Option<AuthPlugin>,
pub(crate) auth_plugin_data: Chain<Bytes, Bytes>,
}
impl Decode<'_> for Handshake {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
let protocol_version = buf.get_u8(); // int<1>
let server_version = buf.get_str_nul()?; // string<NUL>
let connection_id = buf.get_u32_le(); // int<4>
let auth_plugin_data_1 = buf.get_bytes(8); // string<8>
buf.advance(1); // reserved: string<1>
let capabilities_1 = buf.get_u16_le(); // int<2>
let mut capabilities = Capabilities::from_bits_truncate(capabilities_1.into());
let collation = buf.get_u8(); // int<1>
let status = Status::from_bits_truncate(buf.get_u16_le());
let capabilities_2 = buf.get_u16_le(); // int<2>
capabilities |= Capabilities::from_bits_truncate(((capabilities_2 as u32) << 16).into());
let auth_plugin_data_len = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
buf.get_u8()
} else {
buf.advance(1); // int<1>
0
};
buf.advance(6); // reserved: string<6>
if capabilities.contains(Capabilities::MYSQL) {
buf.advance(4); // reserved: string<4>
} else {
let capabilities_3 = buf.get_u32_le(); // int<4>
capabilities |= Capabilities::from_bits_truncate((capabilities_3 as u64) << 32);
}
let auth_plugin_data_2 = if capabilities.contains(Capabilities::SECURE_CONNECTION) {
let len = ((auth_plugin_data_len as isize) - 9).max(12) as usize;
let v = buf.get_bytes(len);
buf.advance(1); // NUL-terminator
v
} else {
Bytes::new()
};
let auth_plugin = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
Some(buf.get_str_nul()?.parse()?)
} else {
None
};
Ok(Self {
protocol_version,
server_version,
connection_id,
server_default_collation: collation,
status,
server_capabilities: capabilities,
auth_plugin,
auth_plugin_data: auth_plugin_data_1.chain(auth_plugin_data_2),
})
}
}
#[test]
fn test_decode_handshake_mysql_8_0_18() {
const HANDSHAKE_MYSQL_8_0_18: &[u8] = b"\n8.0.18\x00\x19\x00\x00\x00\x114aB0c\x06g\x00\xff\xff\xff\x02\x00\xff\xc7\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00tL\x03s\x0f[4\rl4. \x00caching_sha2_password\x00";
let mut p = Handshake::decode(HANDSHAKE_MYSQL_8_0_18.into()).unwrap();
assert_eq!(p.protocol_version, 10);
p.server_capabilities.toggle(
Capabilities::MYSQL
| Capabilities::FOUND_ROWS
| Capabilities::LONG_FLAG
| Capabilities::CONNECT_WITH_DB
| Capabilities::NO_SCHEMA
| Capabilities::COMPRESS
| Capabilities::ODBC
| Capabilities::LOCAL_FILES
| Capabilities::IGNORE_SPACE
| Capabilities::PROTOCOL_41
| Capabilities::INTERACTIVE
| Capabilities::SSL
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PS_MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::CONNECT_ATTRS
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
| Capabilities::SESSION_TRACK
| Capabilities::DEPRECATE_EOF
| Capabilities::ZSTD_COMPRESSION_ALGORITHM
| Capabilities::SSL_VERIFY_SERVER_CERT
| Capabilities::OPTIONAL_RESULTSET_METADATA
| Capabilities::REMEMBER_OPTIONS,
);
assert!(p.server_capabilities.is_empty());
assert_eq!(p.server_default_collation, 255);
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
assert!(matches!(
p.auth_plugin,
Some(AuthPlugin::CachingSha2Password)
));
assert_eq!(
&*p.auth_plugin_data.to_bytes(),
&[17, 52, 97, 66, 48, 99, 6, 103, 116, 76, 3, 115, 15, 91, 52, 13, 108, 52, 46, 32,]
);
}
#[test]
fn test_decode_handshake_mariadb_10_4_7() {
const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\"<H5n\x00mysql_native_password\x00";
let mut p = Handshake::decode(HANDSHAKE_MARIA_DB_10_4_7.into()).unwrap();
assert_eq!(p.protocol_version, 10);
assert_eq!(
&*p.server_version,
"5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic"
);
p.server_capabilities.toggle(
Capabilities::FOUND_ROWS
| Capabilities::LONG_FLAG
| Capabilities::CONNECT_WITH_DB
| Capabilities::NO_SCHEMA
| Capabilities::COMPRESS
| Capabilities::ODBC
| Capabilities::LOCAL_FILES
| Capabilities::IGNORE_SPACE
| Capabilities::PROTOCOL_41
| Capabilities::INTERACTIVE
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PS_MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::CONNECT_ATTRS
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
| Capabilities::SESSION_TRACK
| Capabilities::DEPRECATE_EOF
| Capabilities::REMEMBER_OPTIONS,
);
assert!(p.server_capabilities.is_empty());
assert_eq!(p.server_default_collation, 8);
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
assert!(matches!(
p.auth_plugin,
Some(AuthPlugin::MySqlNativePassword)
));
assert_eq!(
&*p.auth_plugin_data.to_bytes(),
&[116, 54, 76, 92, 106, 34, 100, 83, 85, 49, 52, 79, 112, 104, 57, 34, 60, 72, 53, 110,]
);
}

View File

@@ -0,0 +1,73 @@
use crate::io::{BufMutExt, Encode};
use crate::mysql::io::MySqlBufMutExt;
use crate::mysql::protocol::auth::AuthPlugin;
use crate::mysql::protocol::connect::ssl_request::SslRequest;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
// https://mariadb.com/kb/en/connection/#client-handshake-response
#[derive(Debug)]
pub struct HandshakeResponse<'a> {
pub database: Option<&'a str>,
/// Max size of a command packet that the client wants to send to the server
pub max_packet_size: u32,
/// Default character set for the connection
pub char_set: u8,
/// Name of the SQL account which client wants to log in
pub username: &'a str,
/// Authentication method used by the client
pub auth_plugin: Option<AuthPlugin>,
/// Opaque authentication response
pub auth_response: Option<&'a [u8]>,
}
impl Encode<'_, Capabilities> for HandshakeResponse<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, mut capabilities: Capabilities) {
if self.auth_plugin.is_none() {
// ensure PLUGIN_AUTH is set *only* if we have a defined plugin
capabilities.remove(Capabilities::PLUGIN_AUTH);
}
// NOTE: Half of this packet is identical to the SSL Request packet
SslRequest {
max_packet_size: self.max_packet_size,
char_set: self.char_set,
}
.encode_with(buf, capabilities);
buf.put_str_nul(self.username);
if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) {
buf.put_bytes_lenenc(self.auth_response.unwrap_or_default());
} else if capabilities.contains(Capabilities::SECURE_CONNECTION) {
let response = self.auth_response.unwrap_or_default();
buf.push(response.len() as u8);
buf.extend(response);
} else {
buf.push(0);
}
if capabilities.contains(Capabilities::CONNECT_WITH_DB) {
if let Some(database) = &self.database {
buf.put_str_nul(database);
} else {
buf.push(0);
}
}
if capabilities.contains(Capabilities::PLUGIN_AUTH) {
if let Some(plugin) = &self.auth_plugin {
buf.put_str_nul(plugin.name());
} else {
buf.push(0);
}
}
}
}

View File

@@ -0,0 +1,13 @@
//! Connection Phase
//!
//! <https://dev.mysql.com/doc/internals/en/connection-phase.html>
mod auth_switch;
mod handshake;
mod handshake_response;
mod ssl_request;
pub(crate) use auth_switch::{AuthSwitchRequest, AuthSwitchResponse};
pub(crate) use handshake::Handshake;
pub(crate) use handshake_response::HandshakeResponse;
pub(crate) use ssl_request::SslRequest;

View File

@@ -0,0 +1,30 @@
use crate::io::Encode;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_response.html
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
#[derive(Debug)]
pub struct SslRequest {
pub max_packet_size: u32,
pub char_set: u8,
}
impl Encode<'_, Capabilities> for SslRequest {
fn encode_with(&self, buf: &mut Vec<u8>, capabilities: Capabilities) {
buf.extend(&(capabilities.bits() as u32).to_le_bytes());
buf.extend(&self.max_packet_size.to_le_bytes());
buf.push(self.char_set);
// reserved: string<19>
buf.extend(&[0_u8; 19]);
if capabilities.contains(Capabilities::MYSQL) {
// reserved: string<4>
buf.extend(&[0_u8; 4]);
} else {
// extended client capabilities (MariaDB-specified): int<4>
buf.extend(&((capabilities.bits() >> 32) as u32).to_le_bytes());
}
}
}

View File

@@ -1,35 +0,0 @@
use byteorder::LittleEndian;
use crate::io::Buf;
use crate::mysql::protocol::Status;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_eof_packet.html
// https://mariadb.com/kb/en/eof_packet/
#[derive(Debug)]
pub struct EofPacket {
pub warnings: u16,
pub status: Status,
}
impl EofPacket {
pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self>
where
Self: Sized,
{
let header = buf.get_u8()?;
if header != 0xFE {
return Err(protocol_err!(
"expected EOF (0xFE); received 0x{:X}",
header
))?;
}
let warnings = buf.get_u16::<LittleEndian>()?;
let status = buf.get_u16::<LittleEndian>()?;
Ok(Self {
warnings,
status: Status::from_bits_truncate(status),
})
}
}

View File

@@ -1,75 +0,0 @@
use byteorder::LittleEndian;
use crate::io::Buf;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_err_packet.html
// https://mariadb.com/kb/en/err_packet/
#[derive(Debug)]
pub struct ErrPacket {
pub error_code: u16,
pub sql_state: Option<Box<str>>,
pub error_message: Box<str>,
}
impl ErrPacket {
pub(crate) fn read(mut buf: &[u8], capabilities: Capabilities) -> crate::Result<Self>
where
Self: Sized,
{
let header = buf.get_u8()?;
if header != 0xFF {
return Err(protocol_err!(
"expected 0xFF for ERR_PACKET; received 0x{:X}",
header
))?;
}
let error_code = buf.get_u16::<LittleEndian>()?;
let mut sql_state = None;
if capabilities.contains(Capabilities::PROTOCOL_41) {
// If the next byte is '#' then we have a SQL STATE
if buf.get(0) == Some(&0x23) {
buf.advance(1);
sql_state = Some(buf.get_str(5)?.into())
}
}
let error_message = buf.get_str(buf.len())?.into();
Ok(Self {
error_code,
sql_state,
error_message,
})
}
}
#[cfg(test)]
mod tests {
use super::{Capabilities, ErrPacket};
const ERR_PACKETS_OUT_OF_ORDER: &[u8] = b"\xff\x84\x04Got packets out of order";
const ERR_HANDSHAKE_UNKNOWN_DB: &[u8] = b"\xff\x19\x04#42000Unknown database \'unknown\'";
#[test]
fn it_decodes_packets_out_of_order() {
let p = ErrPacket::read(ERR_PACKETS_OUT_OF_ORDER, Capabilities::PROTOCOL_41).unwrap();
assert_eq!(&*p.error_message, "Got packets out of order");
assert_eq!(p.error_code, 1156);
assert_eq!(p.sql_state, None);
}
#[test]
fn it_decodes_ok_handshake() {
let p = ErrPacket::read(ERR_HANDSHAKE_UNKNOWN_DB, Capabilities::PROTOCOL_41).unwrap();
assert_eq!(p.error_code, 1049);
assert_eq!(p.sql_state.as_deref(), Some("42000"));
assert_eq!(&*p.error_message, "Unknown database \'unknown\'");
}
}

View File

@@ -1,50 +0,0 @@
// https://mariadb.com/kb/en/library/resultset/#field-detail-flag
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__column__definition__flags.html
bitflags::bitflags! {
pub struct FieldFlags: u16 {
/// Field cannot be NULL
const NOT_NULL = 1;
/// Field is **part of** a primary key
const PRIMARY_KEY = 2;
/// Field is **part of** a unique key/constraint
const UNIQUE_KEY = 4;
/// Field is **part of** a unique or primary key
const MULTIPLE_KEY = 8;
/// Field is a blob.
const BLOB = 16;
/// Field is unsigned
const UNSIGNED = 32;
/// Field is zero filled.
const ZEROFILL = 64;
/// Field is binary (set for strings)
const BINARY = 128;
/// Field is an enumeration
const ENUM = 256;
/// Field is an auto-increment field
const AUTO_INCREMENT = 512;
/// Field is a timestamp
const TIMESTAMP = 1024;
/// Field is a set
const SET = 2048;
/// Field does not have a default value
const NO_DEFAULT_VALUE = 4096;
/// Field is set to NOW on UPDATE
const ON_UPDATE_NOW = 8192;
/// Field is a number
const NUM = 32768;
}
}

View File

@@ -1,208 +0,0 @@
use byteorder::LittleEndian;
use crate::io::Buf;
use crate::mysql::protocol::{AuthPlugin, Capabilities, Status};
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_v10.html
// https://mariadb.com/kb/en/connection/#initial-handshake-packet
#[derive(Debug)]
pub(crate) struct Handshake {
pub(crate) protocol_version: u8,
pub(crate) server_version: Box<str>,
pub(crate) connection_id: u32,
pub(crate) server_capabilities: Capabilities,
pub(crate) server_default_collation: u8,
pub(crate) status: Status,
pub(crate) auth_plugin: AuthPlugin,
pub(crate) auth_plugin_data: Box<[u8]>,
}
impl Handshake {
pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self>
where
Self: Sized,
{
let protocol_version = buf.get_u8()?;
let server_version = buf.get_str_nul()?.into();
let connection_id = buf.get_u32::<LittleEndian>()?;
let mut scramble = Vec::with_capacity(8);
// scramble first part : string<8>
scramble.extend_from_slice(&buf[..8]);
buf.advance(8);
// reserved : string<1>
buf.advance(1);
// capability_flags_1 : int<2>
let capabilities_1 = buf.get_u16::<LittleEndian>()?;
let mut capabilities = Capabilities::from_bits_truncate(capabilities_1.into());
// character_set : int<1>
let char_set = buf.get_u8()?;
// status_flags : int<2>
let status = buf.get_u16::<LittleEndian>()?;
let status = Status::from_bits_truncate(status);
// capability_flags_2 : int<2>
let capabilities_2 = buf.get_u16::<LittleEndian>()?;
capabilities |= Capabilities::from_bits_truncate(((capabilities_2 as u32) << 16).into());
let auth_plugin_data_len = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
// plugin data length : int<1>
buf.get_u8()?
} else {
// 0x00 : int<1>
buf.advance(1);
0
};
// reserved: string<6>
buf.advance(6);
if capabilities.contains(Capabilities::MYSQL) {
// reserved: string<4>
buf.advance(4);
} else {
// capability_flags_3 : int<4>
let capabilities_3 = buf.get_u32::<LittleEndian>()?;
capabilities |= Capabilities::from_bits_truncate((capabilities_3 as u64) << 32);
}
if capabilities.contains(Capabilities::SECURE_CONNECTION) {
// scramble 2nd part : string<n> ( Length = max(12, plugin data length - 9) )
let len = ((auth_plugin_data_len as isize) - 9).max(12) as usize;
scramble.extend_from_slice(&buf[..len]);
buf.advance(len);
// reserved : string<1>
buf.advance(1);
}
let auth_plugin = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
AuthPlugin::from_opt_str(Some(buf.get_str_nul()?))?
} else {
AuthPlugin::from_opt_str(None)?
};
Ok(Self {
protocol_version,
server_capabilities: capabilities,
server_version,
server_default_collation: char_set,
connection_id,
auth_plugin_data: scramble.into_boxed_slice(),
auth_plugin,
status,
})
}
}
#[cfg(test)]
mod tests {
use super::{AuthPlugin, Capabilities, Handshake, Status};
const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\"<H5n\x00mysql_native_password\x00";
const HANDSHAKE_MYSQL_8_0_18: &[u8] = b"\n8.0.18\x00\x19\x00\x00\x00\x114aB0c\x06g\x00\xff\xff\xff\x02\x00\xff\xc7\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00tL\x03s\x0f[4\rl4. \x00caching_sha2_password\x00";
#[test]
fn it_reads_handshake_mysql_8_0_18() {
let mut p = Handshake::read(HANDSHAKE_MYSQL_8_0_18).unwrap();
assert_eq!(p.protocol_version, 10);
p.server_capabilities.toggle(
Capabilities::MYSQL
| Capabilities::FOUND_ROWS
| Capabilities::LONG_FLAG
| Capabilities::CONNECT_WITH_DB
| Capabilities::NO_SCHEMA
| Capabilities::COMPRESS
| Capabilities::ODBC
| Capabilities::LOCAL_FILES
| Capabilities::IGNORE_SPACE
| Capabilities::PROTOCOL_41
| Capabilities::INTERACTIVE
| Capabilities::SSL
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PS_MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::CONNECT_ATTRS
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
| Capabilities::SESSION_TRACK
| Capabilities::DEPRECATE_EOF
| Capabilities::ZSTD_COMPRESSION_ALGORITHM
| Capabilities::SSL_VERIFY_SERVER_CERT
| Capabilities::OPTIONAL_RESULTSET_METADATA
| Capabilities::REMEMBER_OPTIONS,
);
assert!(p.server_capabilities.is_empty());
assert_eq!(p.server_default_collation, 255);
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
assert!(matches!(p.auth_plugin, AuthPlugin::CachingSha2Password));
assert_eq!(
&*p.auth_plugin_data,
&[17, 52, 97, 66, 48, 99, 6, 103, 116, 76, 3, 115, 15, 91, 52, 13, 108, 52, 46, 32,]
);
}
#[test]
fn it_reads_handshake_mariadb_10_4_7() {
let mut p = Handshake::read(HANDSHAKE_MARIA_DB_10_4_7).unwrap();
assert_eq!(p.protocol_version, 10);
assert_eq!(
&*p.server_version,
"5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic"
);
p.server_capabilities.toggle(
Capabilities::FOUND_ROWS
| Capabilities::LONG_FLAG
| Capabilities::CONNECT_WITH_DB
| Capabilities::NO_SCHEMA
| Capabilities::COMPRESS
| Capabilities::ODBC
| Capabilities::LOCAL_FILES
| Capabilities::IGNORE_SPACE
| Capabilities::PROTOCOL_41
| Capabilities::INTERACTIVE
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PS_MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::CONNECT_ATTRS
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
| Capabilities::SESSION_TRACK
| Capabilities::DEPRECATE_EOF
| Capabilities::REMEMBER_OPTIONS,
);
assert!(p.server_capabilities.is_empty());
assert_eq!(p.server_default_collation, 8);
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
assert!(matches!(p.auth_plugin, AuthPlugin::MySqlNativePassword));
assert_eq!(
&*p.auth_plugin_data,
&[
116, 54, 76, 92, 106, 34, 100, 83, 85, 49, 52, 79, 112, 104, 57, 34, 60, 72, 53,
110,
]
);
}
}

View File

@@ -1,72 +0,0 @@
use byteorder::LittleEndian;
use crate::io::BufMut;
use crate::mysql::io::BufMutExt;
use crate::mysql::protocol::{AuthPlugin, Capabilities, Encode};
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_response.html
// https://mariadb.com/kb/en/connection/#handshake-response-packet
#[derive(Debug)]
pub struct HandshakeResponse<'a> {
pub max_packet_size: u32,
pub client_collation: u8,
pub username: &'a str,
pub database: Option<&'a str>,
pub auth_plugin: &'a AuthPlugin,
pub auth_response: &'a [u8],
}
impl Encode for HandshakeResponse<'_> {
fn encode(&self, buf: &mut Vec<u8>, capabilities: Capabilities) {
// client capabilities : int<4>
buf.put_u32::<LittleEndian>(capabilities.bits() as u32);
// max packet size : int<4>
buf.put_u32::<LittleEndian>(self.max_packet_size);
// client character collation : int<1>
buf.put_u8(self.client_collation);
// reserved : string<19>
buf.advance(19);
if capabilities.contains(Capabilities::MYSQL) {
// reserved : string<4>
buf.advance(4);
} else {
// extended client capabilities : int<4>
buf.put_u32::<LittleEndian>((capabilities.bits() >> 32) as u32);
}
// username : string<NUL>
buf.put_str_nul(self.username);
if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) {
// auth_response : string<lenenc>
buf.put_bytes_lenenc::<LittleEndian>(self.auth_response);
} else if capabilities.contains(Capabilities::SECURE_CONNECTION) {
let auth_response = self.auth_response;
// auth_response_length : int<1>
buf.put_u8(auth_response.len() as u8);
// auth_response : string<{auth_response_length}>
buf.put_bytes(auth_response);
} else {
// no auth : int<1>
buf.put_u8(0);
}
if capabilities.contains(Capabilities::CONNECT_WITH_DB) {
if let Some(database) = self.database {
// database : string<NUL>
buf.put_str_nul(database);
}
}
if capabilities.contains(Capabilities::PLUGIN_AUTH) {
// client_plugin_name : string<NUL>
buf.put_str_nul(self.auth_plugin.as_str());
}
}
}

View File

@@ -1,59 +1,13 @@
mod auth_plugin;
pub(crate) mod auth;
mod capabilities;
mod field;
mod status;
mod r#type;
pub(crate) use auth_plugin::AuthPlugin;
pub(crate) use capabilities::Capabilities;
pub(crate) use field::FieldFlags;
pub(crate) use r#type::TypeId;
pub(crate) use status::Status;
mod com_ping;
mod com_query;
mod com_stmt_execute;
mod com_stmt_prepare;
mod handshake;
pub(crate) use com_ping::ComPing;
pub(crate) use com_query::ComQuery;
pub(crate) use com_stmt_execute::{ComStmtExecute, Cursor};
pub(crate) use com_stmt_prepare::ComStmtPrepare;
pub(crate) use handshake::Handshake;
mod auth_switch;
mod column_count;
mod column_def;
mod com_stmt_prepare_ok;
mod eof;
mod err;
mod handshake_response;
mod ok;
pub(crate) mod connect;
mod packet;
pub(crate) mod response;
mod row;
#[cfg_attr(not(feature = "tls"), allow(unused_imports, dead_code))]
mod ssl_request;
pub(crate) mod rsa;
pub(crate) mod statement;
pub(crate) mod text;
pub(crate) use auth_switch::AuthSwitch;
pub(crate) use column_count::ColumnCount;
pub(crate) use column_def::ColumnDefinition;
pub(crate) use com_stmt_prepare_ok::ComStmtPrepareOk;
pub(crate) use eof::EofPacket;
pub(crate) use err::ErrPacket;
pub(crate) use handshake_response::HandshakeResponse;
pub(crate) use ok::OkPacket;
pub(crate) use capabilities::Capabilities;
pub(crate) use packet::Packet;
pub(crate) use row::Row;
#[cfg_attr(not(feature = "tls"), allow(unused_imports, dead_code))]
pub(crate) use ssl_request::SslRequest;
pub(crate) trait Encode {
fn encode(&self, buf: &mut Vec<u8>, capabilities: Capabilities);
}
impl Encode for &'_ [u8] {
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
use crate::io::BufMut;
buf.put_bytes(self);
}
}

View File

@@ -1,64 +0,0 @@
use byteorder::LittleEndian;
use crate::io::Buf;
use crate::mysql::io::BufExt;
use crate::mysql::protocol::Status;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_ok_packet.html
// https://mariadb.com/kb/en/ok_packet/
#[derive(Debug)]
pub(crate) struct OkPacket {
pub(crate) affected_rows: u64,
pub(crate) last_insert_id: u64,
pub(crate) status: Status,
pub(crate) warnings: u16,
pub(crate) info: Box<str>,
}
impl OkPacket {
pub(crate) fn read(mut buf: &[u8]) -> crate::Result<Self>
where
Self: Sized,
{
let header = buf.get_u8()?;
if header != 0 && header != 0xFE {
return Err(protocol_err!(
"expected 0x00 or 0xFE; received 0x{:X}",
header
))?;
}
let affected_rows = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or(0); // 0
let last_insert_id = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or(0); // 2
let status = Status::from_bits_truncate(buf.get_u16::<LittleEndian>()?); //
let warnings = buf.get_u16::<LittleEndian>()?;
let info = buf.get_str(buf.len())?.into();
Ok(Self {
affected_rows,
last_insert_id,
status,
warnings,
info,
})
}
}
#[cfg(test)]
mod tests {
use super::{OkPacket, Status};
const OK_HANDSHAKE: &[u8] = b"\x00\x00\x00\x02@\x00\x00";
#[test]
fn it_decodes_ok_handshake() {
let p = OkPacket::read(OK_HANDSHAKE).unwrap();
assert_eq!(p.affected_rows, 0);
assert_eq!(p.last_insert_id, 0);
assert_eq!(p.warnings, 0);
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
assert!(p.status.contains(Status::SERVER_SESSION_STATE_CHANGED));
assert!(p.info.is_empty());
}
}

View File

@@ -0,0 +1,89 @@
use std::ops::{Deref, DerefMut};
use bytes::Bytes;
use crate::error::Error;
use crate::io::{Decode, Encode};
use crate::mysql::protocol::response::{EofPacket, OkPacket};
use crate::mysql::protocol::Capabilities;
#[derive(Debug)]
pub struct Packet<T>(pub(crate) T);
impl<'en, 'stream, T> Encode<'stream, (Capabilities, &'stream mut u8)> for Packet<T>
where
T: Encode<'en, Capabilities>,
{
fn encode_with(
&self,
buf: &mut Vec<u8>,
(capabilities, sequence_id): (Capabilities, &'stream mut u8),
) {
// reserve space to write the prefixed length
let offset = buf.len();
buf.extend(&[0_u8; 4]);
// encode the payload
self.0.encode_with(buf, capabilities);
// determine the length of the encoded payload
// and write to our reserved space
let len = buf.len() - offset - 4;
let header = &mut buf[offset..];
// FIXME: Support larger packets
assert!(len < 0xFF_FF_FF);
header[..4].copy_from_slice(&(len as u32).to_le_bytes());
header[3] = *sequence_id;
*sequence_id = sequence_id.wrapping_add(1);
}
}
impl Packet<Bytes> {
pub(crate) fn decode<'de, T>(self) -> Result<T, Error>
where
T: Decode<'de, ()>,
{
self.decode_with(())
}
pub(crate) fn decode_with<'de, T, C>(self, context: C) -> Result<T, Error>
where
T: Decode<'de, C>,
{
T::decode_with(self.0, context)
}
pub(crate) fn ok(self) -> Result<OkPacket, Error> {
self.decode()
}
pub(crate) fn eof(self, capabilities: Capabilities) -> Result<EofPacket, Error> {
if capabilities.contains(Capabilities::DEPRECATE_EOF) {
let ok = self.ok()?;
Ok(EofPacket {
warnings: ok.warnings,
status: ok.status,
})
} else {
self.decode_with(capabilities)
}
}
}
impl Deref for Packet<Bytes> {
type Target = Bytes;
fn deref(&self) -> &Bytes {
&self.0
}
}
impl DerefMut for Packet<Bytes> {
fn deref_mut(&mut self) -> &mut Bytes {
&mut self.0
}
}

View File

@@ -0,0 +1,35 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::Decode;
use crate::mysql::protocol::response::Status;
use crate::mysql::protocol::Capabilities;
/// Marks the end of a result set, returning status and warnings.
///
/// # Note
///
/// The EOF packet is deprecated as of MySQL 5.7.5. SQLx only uses this packet for MySQL
/// prior MySQL versions.
#[derive(Debug)]
pub struct EofPacket {
pub warnings: u16,
pub status: Status,
}
impl Decode<'_, Capabilities> for EofPacket {
fn decode_with(mut buf: Bytes, _: Capabilities) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0xfe {
return Err(err_protocol!(
"expected 0xfe (EOF_Packet) but found 0x{:x}",
header
));
}
let warnings = buf.get_u16_le();
let status = Status::from_bits_truncate(buf.get_u16_le());
Ok(Self { status, warnings })
}
}

View File

@@ -0,0 +1,71 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::{BufExt, Decode};
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_err_packet.html
// https://mariadb.com/kb/en/err_packet/
/// Indicates that an error occurred.
#[derive(Debug)]
pub struct ErrPacket {
pub error_code: u16,
pub sql_state: Option<String>,
pub error_message: String,
}
impl Decode<'_, Capabilities> for ErrPacket {
fn decode_with(mut buf: Bytes, capabilities: Capabilities) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0xff {
return Err(err_protocol!(
"expected 0xff (ERR_Packet) but found 0x{:x}",
header
));
}
let error_code = buf.get_u16_le();
let mut sql_state = None;
if capabilities.contains(Capabilities::PROTOCOL_41) {
// If the next byte is '#' then we have a SQL STATE
if buf.get(0) == Some(&0x23) {
buf.advance(1);
sql_state = Some(buf.get_str(5)?.to_owned());
}
}
let error_message = buf.get_str(buf.len())?.to_owned();
Ok(Self {
error_code,
sql_state,
error_message,
})
}
}
#[test]
fn test_decode_err_packet_out_of_order() {
const ERR_PACKETS_OUT_OF_ORDER: &[u8] = b"\xff\x84\x04Got packets out of order";
let p =
ErrPacket::decode_with(ERR_PACKETS_OUT_OF_ORDER.into(), Capabilities::PROTOCOL_41).unwrap();
assert_eq!(&p.error_message, "Got packets out of order");
assert_eq!(p.error_code, 1156);
assert_eq!(p.sql_state, None);
}
#[test]
fn test_decode_err_packet_unknown_database() {
const ERR_HANDSHAKE_UNKNOWN_DB: &[u8] = b"\xff\x19\x04#42000Unknown database \'unknown\'";
let p =
ErrPacket::decode_with(ERR_HANDSHAKE_UNKNOWN_DB.into(), Capabilities::PROTOCOL_41).unwrap();
assert_eq!(p.error_code, 1049);
assert_eq!(p.sql_state.as_deref(), Some("42000"));
assert_eq!(&p.error_message, "Unknown database \'unknown\'");
}

View File

@@ -0,0 +1,14 @@
//! Generic Response Packets
//!
//! <https://dev.mysql.com/doc/internals/en/generic-response-packets.html>
//! <https://mariadb.com/kb/en/4-server-response-packets/>
mod eof;
mod err;
mod ok;
mod status;
pub use eof::EofPacket;
pub use err::ErrPacket;
pub use ok::OkPacket;
pub use status::Status;

View File

@@ -0,0 +1,52 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::Decode;
use crate::mysql::io::MySqlBufExt;
use crate::mysql::protocol::response::Status;
/// Indicates successful completion of a previous command sent by the client.
#[derive(Debug)]
pub struct OkPacket {
pub affected_rows: u64,
pub last_insert_id: u64,
pub status: Status,
pub warnings: u16,
}
impl Decode<'_> for OkPacket {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0 && header != 0xfe {
return Err(err_protocol!(
"expected 0x00 or 0xfe (OK_Packet) but found 0x{:02x}",
header
));
}
let affected_rows = buf.get_uint_lenenc();
let last_insert_id = buf.get_uint_lenenc();
let status = Status::from_bits_truncate(buf.get_u16_le());
let warnings = buf.get_u16_le();
Ok(Self {
affected_rows,
last_insert_id,
status,
warnings,
})
}
}
#[test]
fn test_decode_ok_packet() {
const DATA: &[u8] = b"\x00\x00\x00\x02@\x00\x00";
let p = OkPacket::decode(DATA.into()).unwrap();
assert_eq!(p.affected_rows, 0);
assert_eq!(p.last_insert_id, 0);
assert_eq!(p.warnings, 0);
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
assert!(p.status.contains(Status::SERVER_SESSION_STATE_CHANGED));
}

View File

@@ -1,323 +1,21 @@
use std::ops::Range;
use byteorder::{ByteOrder, LittleEndian};
use bytes::Bytes;
use crate::io::Buf;
use crate::mysql::protocol::TypeId;
use crate::mysql::MySqlTypeInfo;
pub(crate) struct Row<'c> {
buffer: &'c [u8],
values: &'c [Option<Range<usize>>],
pub(crate) columns: &'c [MySqlTypeInfo],
pub(crate) binary: bool,
#[derive(Debug)]
pub(crate) struct Row {
pub(crate) storage: Bytes,
pub(crate) values: Vec<Option<Range<usize>>>,
}
impl<'c> Row<'c> {
impl Row {
pub(crate) fn len(&self) -> usize {
self.values.len()
}
pub(crate) fn get(&self, index: usize) -> Option<&'c [u8]> {
let range = self.values[index].as_ref()?;
Some(&self.buffer[(range.start as usize)..(range.end as usize)])
pub(crate) fn get(&self, index: usize) -> Option<&[u8]> {
self.values[index]
.as_ref()
.map(|col| &self.storage[(col.start as usize)..(col.end as usize)])
}
}
fn get_lenenc(buf: &[u8]) -> (usize, Option<usize>) {
match buf[0] {
0xFB => (1, None),
0xFC => {
let len_size = 1 + 2;
let len = LittleEndian::read_u16(&buf[1..]);
(len_size, Some(len as usize))
}
0xFD => {
let len_size = 1 + 3;
let len = LittleEndian::read_u24(&buf[1..]);
(len_size, Some(len as usize))
}
0xFE => {
let len_size = 1 + 8;
let len = LittleEndian::read_u64(&buf[1..]);
(len_size, Some(len as usize))
}
len => (1, Some(len as usize)),
}
}
impl<'c> Row<'c> {
pub(crate) fn read(
mut buf: &'c [u8],
columns: &'c [MySqlTypeInfo],
values: &'c mut Vec<Option<Range<usize>>>,
binary: bool,
) -> crate::Result<Self> {
let buffer = &*buf;
values.clear();
values.reserve(columns.len());
if !binary {
let mut index = 0;
for _ in 0..columns.len() {
let (len_size, size) = get_lenenc(&buf[index..]);
if let Some(size) = size {
values.push(Some((index + len_size)..(index + len_size + size)));
} else {
values.push(None);
}
index += len_size + size.unwrap_or_default();
}
return Ok(Self {
buffer,
columns,
values: &*values,
binary: false,
});
}
// 0x00 header : byte<1>
let header = buf.get_u8()?;
if header != 0 {
return Err(protocol_err!("expected ROW (0x00), got: {:#04X}", header).into());
}
// NULL-Bitmap : byte<(number_of_columns + 9) / 8>
let null_len = (columns.len() + 9) / 8;
let null_bitmap = &buf[..];
buf.advance(null_len);
let buffer: Box<[u8]> = buf.into();
let mut index = 0;
for column_idx in 0..columns.len() {
// the null index for a column starts at the 3rd bit in the null bitmap
// for no reason at all besides mysql probably
let column_null_idx = column_idx + 2;
let is_null =
null_bitmap[column_null_idx / 8] & (1 << (column_null_idx % 8) as u8) != 0;
if is_null {
values.push(None);
} else {
let (offset, size) = match columns[column_idx].id {
TypeId::TINY_INT => (0, 1),
TypeId::SMALL_INT => (0, 2),
TypeId::INT | TypeId::FLOAT => (0, 4),
TypeId::BIG_INT | TypeId::DOUBLE => (0, 8),
TypeId::DATE => (0, 5),
TypeId::TIME => (0, 1 + buffer[index] as usize),
TypeId::TIMESTAMP | TypeId::DATETIME => (0, 1 + buffer[index] as usize),
TypeId::TINY_BLOB
| TypeId::MEDIUM_BLOB
| TypeId::LONG_BLOB
| TypeId::CHAR
| TypeId::TEXT
| TypeId::ENUM
| TypeId::VAR_CHAR => {
let (len_size, len) = get_lenenc(&buffer[index..]);
(len_size, len.unwrap_or_default())
}
TypeId::NEWDECIMAL => (0, 1 + buffer[index] as usize),
id => {
unimplemented!("encountered unknown field type id: {:?}", id);
}
};
values.push(Some((index + offset)..(index + offset + size)));
index += size + offset;
}
}
Ok(Self {
buffer: buf,
values: &*values,
columns,
binary,
})
}
}
// #[cfg(test)]
// mod test {
// use super::super::column_count::ColumnCount;
// use super::super::column_def::ColumnDefinition;
// use super::super::eof::EofPacket;
// use super::*;
//
// #[test]
// fn null_bitmap_test() -> crate::Result<()> {
// let column_len = ColumnCount::decode(&[26])?;
// assert_eq!(column_len.columns, 26);
//
// let types: Vec<TypeId> = vec![
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 2, 105, 100, 2, 105, 100, 12, 63, 0, 11, 0, 0,
// 0, 3, 11, 66, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 50, 6, 102, 105,
// 101, 108, 100, 50, 12, 224, 0, 120, 0, 0, 0, 253, 5, 64, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 51, 6, 102, 105,
// 101, 108, 100, 51, 12, 224, 0, 252, 3, 0, 0, 253, 1, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 52, 6, 102, 105,
// 101, 108, 100, 52, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 53, 6, 102, 105,
// 101, 108, 100, 53, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 54, 6, 102, 105,
// 101, 108, 100, 54, 12, 63, 0, 19, 0, 0, 0, 7, 128, 4, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 55, 6, 102, 105,
// 101, 108, 100, 55, 12, 63, 0, 4, 0, 0, 0, 1, 1, 64, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 56, 6, 102, 105,
// 101, 108, 100, 56, 12, 224, 0, 252, 255, 3, 0, 252, 16, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 57, 6, 102, 105,
// 101, 108, 100, 57, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 48, 7, 102,
// 105, 101, 108, 100, 49, 48, 12, 224, 0, 252, 3, 0, 0, 252, 16, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 49, 7, 102,
// 105, 101, 108, 100, 49, 49, 12, 224, 0, 252, 3, 0, 0, 252, 16, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 50, 7, 102,
// 105, 101, 108, 100, 49, 50, 12, 63, 0, 19, 0, 0, 0, 7, 129, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 51, 7, 102,
// 105, 101, 108, 100, 49, 51, 12, 63, 0, 4, 0, 0, 0, 1, 0, 64, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 52, 7, 102,
// 105, 101, 108, 100, 49, 52, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 53, 7, 102,
// 105, 101, 108, 100, 49, 53, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 54, 7, 102,
// 105, 101, 108, 100, 49, 54, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 55, 7, 102,
// 105, 101, 108, 100, 49, 55, 12, 224, 0, 0, 1, 0, 0, 253, 0, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 56, 7, 102,
// 105, 101, 108, 100, 49, 56, 12, 63, 0, 11, 0, 0, 0, 3, 1, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 57, 7, 102,
// 105, 101, 108, 100, 49, 57, 12, 63, 0, 11, 0, 0, 0, 3, 1, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 48, 7, 102,
// 105, 101, 108, 100, 50, 48, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 49, 7, 102,
// 105, 101, 108, 100, 50, 49, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 50, 7, 102,
// 105, 101, 108, 100, 50, 50, 12, 63, 0, 3, 0, 0, 0, 3, 0, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 51, 7, 102,
// 105, 101, 108, 100, 50, 51, 12, 63, 0, 6, 0, 0, 0, 3, 1, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 52, 7, 102,
// 105, 101, 108, 100, 50, 52, 12, 63, 0, 6, 0, 0, 0, 3, 1, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 53, 7, 102,
// 105, 101, 108, 100, 50, 53, 12, 63, 0, 20, 0, 0, 0, 8, 1, 0, 0, 0, 0,
// ])?,
// ColumnDefinition::decode(&[
// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8,
// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 54, 7, 102,
// 105, 101, 108, 100, 50, 54, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0,
// ])?,
// ]
// .into_iter()
// .map(|def| def.type_id)
// .collect();
//
// EofPacket::decode(&[254, 0, 0, 34, 0])?;
//
// Row::read(
// &[
// 0, 64, 90, 229, 0, 4, 0, 0, 0, 4, 114, 117, 115, 116, 0, 0, 7, 228, 7, 1, 16, 8,
// 10, 17, 0, 0, 4, 208, 7, 1, 1, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
// 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
// ],
// &types,
// true,
// )?;
//
// EofPacket::decode(&[254, 0, 0, 34, 0])?;
// Ok(())
// }
// }

View File

@@ -2,14 +2,16 @@ use digest::Digest;
use num_bigint::BigUint;
use rand::{thread_rng, Rng};
use crate::error::Error;
// This is mostly taken from https://github.com/RustCrypto/RSA/pull/18
// For the love of crypto, please delete as much of this as possible and use the RSA crate
// directly when that PR is merged
pub fn encrypt<D: Digest>(key: &[u8], message: &[u8]) -> crate::Result<Vec<u8>> {
pub fn encrypt<D: Digest>(key: &[u8], message: &[u8]) -> Result<Vec<u8>, Error> {
let key = std::str::from_utf8(key).map_err(|_err| {
// TODO(@abonander): protocol_err doesn't like referring to [err]
protocol_err!("unexpected error decoding what should be UTF-8")
// TODO(@abonander): err_protocol doesn't like referring to [err]
err_protocol!("unexpected error decoding what should be UTF-8")
})?;
let key = parse(key)?;
@@ -96,7 +98,7 @@ fn oaep_encrypt<R: Rng, D: Digest>(
rng: &mut R,
pub_key: &PublicKey,
msg: &[u8],
) -> crate::Result<Vec<u8>> {
) -> Result<Vec<u8>, Error> {
// size of [n] in bytes
let k = (pub_key.n.bits() + 7) / 8;
@@ -104,7 +106,7 @@ fn oaep_encrypt<R: Rng, D: Digest>(
let h_size = D::output_size();
if msg.len() > k - 2 * h_size - 2 {
return Err(protocol_err!("mysql: password too long").into());
return Err(err_protocol!("mysql: password too long"));
}
let mut em = vec![0u8; k];
@@ -140,13 +142,13 @@ struct PublicKey {
e: BigUint,
}
fn parse(key: &str) -> crate::Result<PublicKey> {
fn parse(key: &str) -> Result<PublicKey, Error> {
// This takes advantage of the knowledge that we know
// we are receiving a PKCS#8 RSA Public Key at all
// times from MySQL
if !key.starts_with("-----BEGIN PUBLIC KEY-----\n") {
return Err(protocol_err!(
return Err(err_protocol!(
"unexpected format for RSA Public Key from MySQL (expected PKCS#8); first line: {:?}",
key.splitn(1, '\n').next()
)
@@ -158,8 +160,8 @@ fn parse(key: &str) -> crate::Result<PublicKey> {
let inner_key = key_with_trailer[..trailer_pos].replace('\n', "");
let inner = base64::decode(&inner_key).map_err(|_err| {
// TODO(@abonander): protocol_err doesn't like referring to [err]
protocol_err!("unexpected error decoding what should be base64-encoded data")
// TODO(@abonander): err_protocol doesn't like referring to [err]
err_protocol!("unexpected error decoding what should be base64-encoded data")
})?;
let len = inner.len();

View File

@@ -1,34 +0,0 @@
use byteorder::LittleEndian;
use crate::io::BufMut;
use crate::mysql::protocol::{Capabilities, Encode};
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_response.html
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
#[derive(Debug)]
pub struct SslRequest {
pub max_packet_size: u32,
pub client_collation: u8,
}
impl Encode for SslRequest {
fn encode(&self, buf: &mut Vec<u8>, capabilities: Capabilities) {
// SSL must be set or else it makes no sense to ask for an upgrade
assert!(
capabilities.contains(Capabilities::SSL),
"SSL bit must be set for Capabilities"
);
// client capabilities : int<4>
buf.put_u32::<LittleEndian>(capabilities.bits() as u32);
// max packet size : int<4>
buf.put_u32::<LittleEndian>(self.max_packet_size);
// client character collation : int<1>
buf.put_u8(self.client_collation);
// reserved : string<23>
buf.advance(23);
}
}

View File

@@ -0,0 +1,38 @@
use crate::io::Encode;
use crate::mysql::protocol::text::ColumnFlags;
use crate::mysql::protocol::Capabilities;
use crate::mysql::MySqlArguments;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_execute.html
#[derive(Debug)]
pub struct Execute<'q> {
pub statement: u32,
pub arguments: &'q MySqlArguments,
}
impl<'q> Encode<'_, Capabilities> for Execute<'q> {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
buf.push(0x17); // COM_STMT_EXECUTE
buf.extend(&self.statement.to_le_bytes());
buf.push(0); // NO_CURSOR
buf.extend(&0_u32.to_le_bytes()); // iterations (always 1): int<4>
if !self.arguments.is_empty() {
buf.extend(&*self.arguments.null_bitmap);
buf.push(1); // send type to server
for ty in &self.arguments.types {
buf.push(ty.r#type as u8);
buf.push(if ty.flags.contains(ColumnFlags::UNSIGNED) {
0x80
} else {
0
});
}
buf.extend(&*self.arguments.values);
}
}
}

View File

@@ -0,0 +1,9 @@
mod execute;
mod prepare;
mod prepare_ok;
mod row;
pub(crate) use execute::Execute;
pub(crate) use prepare::Prepare;
pub(crate) use prepare_ok::PrepareOk;
pub(crate) use row::BinaryRow;

View File

@@ -0,0 +1,15 @@
use crate::io::Encode;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html#packet-COM_STMT_PREPARE
pub struct Prepare<'a> {
pub query: &'a str,
}
impl Encode<'_, Capabilities> for Prepare<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
buf.push(0x16); // COM_STMT_PREPARE
buf.extend(self.query.as_bytes());
}
}

View File

@@ -0,0 +1,42 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::Decode;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html#packet-COM_STMT_PREPARE_OK
#[derive(Debug)]
pub(crate) struct PrepareOk {
pub(crate) statement_id: u32,
pub(crate) columns: u16,
pub(crate) params: u16,
pub(crate) warnings: u16,
}
impl Decode<'_, Capabilities> for PrepareOk {
fn decode_with(mut buf: Bytes, _: Capabilities) -> Result<Self, Error> {
let status = buf.get_u8();
if status != 0x00 {
return Err(err_protocol!(
"expected 0x00 (COM_STMT_PREPARE_OK) but found 0x{:02x}",
status
));
}
let statement_id = buf.get_u32_le();
let columns = buf.get_u16_le();
let params = buf.get_u16_le();
buf.advance(1); // reserved: string<1>
let warnings = buf.get_u16_le();
Ok(Self {
statement_id,
columns,
params,
warnings,
})
}
}

View File

@@ -0,0 +1,92 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::{BufExt, Decode};
use crate::mysql::io::MySqlBufExt;
use crate::mysql::protocol::text::ColumnType;
use crate::mysql::protocol::Row;
use crate::mysql::row::MySqlColumn;
// https://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html#packet-ProtocolBinary::ResultsetRow
// https://dev.mysql.com/doc/internals/en/binary-protocol-value.html
#[derive(Debug)]
pub(crate) struct BinaryRow(pub(crate) Row);
impl<'de> Decode<'de, &'de [MySqlColumn]> for BinaryRow {
fn decode_with(mut buf: Bytes, columns: &'de [MySqlColumn]) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0 {
return Err(err_protocol!(
"exepcted 0x00 (ROW) but found 0x{:02x}",
header
));
}
let storage = buf.clone();
let offset = buf.len();
let null_bitmap_len = (columns.len() + 9) / 8;
let null_bitmap = buf.get_bytes(null_bitmap_len);
let mut values = Vec::with_capacity(columns.len());
for (column_idx, column) in columns.iter().enumerate() {
// NOTE: the column index starts at the 3rd bit
let column_null_idx = column_idx + 2;
let is_null =
null_bitmap[column_null_idx / 8] & (1 << (column_null_idx % 8) as u8) != 0;
if is_null {
values.push(None);
continue;
}
// NOTE: MySQL will never generate NULL types for non-NULL values
let type_info = column.type_info.as_ref().unwrap();
let size: usize = match type_info.r#type {
ColumnType::String
| ColumnType::VarChar
| ColumnType::VarString
| ColumnType::Enum
| ColumnType::Set
| ColumnType::LongBlob
| ColumnType::MediumBlob
| ColumnType::Blob
| ColumnType::TinyBlob
| ColumnType::Geometry
| ColumnType::Bit
| ColumnType::Decimal
| ColumnType::Json
| ColumnType::NewDecimal => buf.get_uint_lenenc() as usize,
ColumnType::LongLong => 8,
ColumnType::Long | ColumnType::Int24 => 4,
ColumnType::Short | ColumnType::Year => 2,
ColumnType::Tiny => 1,
ColumnType::Float => 4,
ColumnType::Double => 8,
ColumnType::Time
| ColumnType::Timestamp
| ColumnType::Date
| ColumnType::Datetime => {
// The size of this type is important for decoding
buf[0] as usize + 1
}
// NOTE: MySQL will never generate NULL types for non-NULL values
ColumnType::Null => unreachable!(),
};
let offset = offset - buf.len();
values.push(Some(offset..(offset + size)));
buf.advance(size);
}
Ok(BinaryRow(Row { values, storage }))
}
}

View File

@@ -0,0 +1,244 @@
use std::str::from_utf8;
use bitflags::bitflags;
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::Decode;
use crate::mysql::io::MySqlBufExt;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__column__definition__flags.html
bitflags! {
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct ColumnFlags: u16 {
/// Field can't be `NULL`.
const NOT_NULL = 1;
/// Field is part of a primary key.
const PRIMARY_KEY = 2;
/// Field is part of a unique key.
const UNIQUE_KEY = 4;
/// Field is part of a multi-part unique or primary key.
const MULTIPLE_KEY = 8;
/// Field is a blob.
const BLOB = 16;
/// Field is unsigned.
const UNSIGNED = 32;
/// Field is zero filled.
const ZEROFILL = 64;
/// Field is binary.
const BINARY = 128;
/// Field is an enumeration.
const ENUM = 256;
/// Field is an auto-incement field.
const AUTO_INCREMENT = 512;
/// Field is a timestamp.
const TIMESTAMP = 1024;
/// Field is a set.
const SET = 2048;
/// Field does not have a default value.
const NO_DEFAULT_VALUE = 4096;
/// Field is set to NOW on UPDATE.
const ON_UPDATE_NOW = 8192;
/// Field is a number.
const NUM = 32768;
}
}
// https://dev.mysql.com/doc/internals/en/com-query-response.html#column-type
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum ColumnType {
Decimal = 0x00,
Tiny = 0x01,
Short = 0x02,
Long = 0x03,
Float = 0x04,
Double = 0x05,
Null = 0x06,
Timestamp = 0x07,
LongLong = 0x08,
Int24 = 0x09,
Date = 0x0a,
Time = 0x0b,
Datetime = 0x0c,
Year = 0x0d,
VarChar = 0x0f,
Bit = 0x10,
Json = 0xf5,
NewDecimal = 0xf6,
Enum = 0xf7,
Set = 0xf8,
TinyBlob = 0xf9,
MediumBlob = 0xfa,
LongBlob = 0xfb,
Blob = 0xfc,
VarString = 0xfd,
String = 0xfe,
Geometry = 0xff,
}
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_query_response_text_resultset_column_definition.html
// https://mariadb.com/kb/en/resultset/#column-definition-packet
// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
#[derive(Debug)]
pub(crate) struct ColumnDefinition {
catalog: Bytes,
schema: Bytes,
table_alias: Bytes,
table: Bytes,
alias: Bytes,
name: Bytes,
pub(crate) char_set: u16,
max_size: u32,
pub(crate) r#type: ColumnType,
pub(crate) flags: ColumnFlags,
decimals: u8,
}
impl ColumnDefinition {
// NOTE: strings in-protocol are transmitted according to the client character set
// as this is UTF-8, all these strings should be UTF-8
pub(crate) fn name(&self) -> Result<&str, Error> {
from_utf8(&self.name).map_err(Error::protocol)
}
pub(crate) fn alias(&self) -> Result<&str, Error> {
from_utf8(&self.alias).map_err(Error::protocol)
}
}
impl Decode<'_, Capabilities> for ColumnDefinition {
fn decode_with(mut buf: Bytes, _: Capabilities) -> Result<Self, Error> {
let catalog = buf.get_bytes_lenenc();
let schema = buf.get_bytes_lenenc();
let table_alias = buf.get_bytes_lenenc();
let table = buf.get_bytes_lenenc();
let alias = buf.get_bytes_lenenc();
let name = buf.get_bytes_lenenc();
let _next_len = buf.get_uint_lenenc(); // always 0x0c
let char_set = buf.get_u16_le();
let max_size = buf.get_u32_le();
let type_id = buf.get_u8();
let flags = buf.get_u16_le();
let decimals = buf.get_u8();
Ok(Self {
catalog,
schema,
table_alias,
table,
alias,
name,
char_set,
max_size,
r#type: ColumnType::try_from_u16(type_id)?,
flags: ColumnFlags::from_bits_truncate(flags),
decimals,
})
}
}
impl ColumnType {
pub(crate) fn name(self, char_set: u16) -> &'static str {
let is_binary = char_set == 63;
match self {
ColumnType::Tiny => "TINYINT",
ColumnType::Short => "SMALLINT",
ColumnType::Long => "INT",
ColumnType::Float => "FLOAT",
ColumnType::Double => "DOUBLE",
ColumnType::Null => "NULL",
ColumnType::Timestamp => "TIMESTAMP",
ColumnType::LongLong => "BIGINT",
ColumnType::Int24 => "MEDIUMINT",
ColumnType::Date => "DATE",
ColumnType::Time => "TIME",
ColumnType::Datetime => "DATETIME",
ColumnType::Year => "YEAR",
ColumnType::Bit => "BIT",
ColumnType::Enum => "ENUM",
ColumnType::Set => "SET",
ColumnType::Decimal | ColumnType::NewDecimal => "DECIMAL",
ColumnType::Geometry => "GEOMETRY",
ColumnType::Json => "JSON",
ColumnType::String if is_binary => "BINARY",
ColumnType::VarChar | ColumnType::VarString if is_binary => "VARBINARY",
ColumnType::String => "CHAR",
ColumnType::VarChar | ColumnType::VarString => "VARCHAR",
ColumnType::TinyBlob if is_binary => "TINYBLOB",
ColumnType::TinyBlob => "TINYTEXT",
ColumnType::Blob if is_binary => "BLOB",
ColumnType::Blob => "TEXT",
ColumnType::MediumBlob if is_binary => "MEDIUMBLOB",
ColumnType::MediumBlob => "MEDIUMTEXT",
ColumnType::LongBlob if is_binary => "LONGBLOB",
ColumnType::LongBlob => "LONGTEXT",
}
}
pub(crate) fn try_from_u16(id: u8) -> Result<Self, Error> {
Ok(match id {
0x00 => ColumnType::Decimal,
0x01 => ColumnType::Tiny,
0x02 => ColumnType::Short,
0x03 => ColumnType::Long,
0x04 => ColumnType::Float,
0x05 => ColumnType::Double,
0x06 => ColumnType::Null,
0x07 => ColumnType::Timestamp,
0x08 => ColumnType::LongLong,
0x09 => ColumnType::Int24,
0x0a => ColumnType::Date,
0x0b => ColumnType::Time,
0x0c => ColumnType::Datetime,
0x0d => ColumnType::Year,
// [internal] 0x0e => ColumnType::NewDate,
0x0f => ColumnType::VarChar,
0x10 => ColumnType::Bit,
// [internal] 0x11 => ColumnType::Timestamp2,
// [internal] 0x12 => ColumnType::Datetime2,
// [internal] 0x13 => ColumnType::Time2,
0xf5 => ColumnType::Json,
0xf6 => ColumnType::NewDecimal,
0xf7 => ColumnType::Enum,
0xf8 => ColumnType::Set,
0xf9 => ColumnType::TinyBlob,
0xfa => ColumnType::MediumBlob,
0xfb => ColumnType::LongBlob,
0xfc => ColumnType::Blob,
0xfd => ColumnType::VarString,
0xfe => ColumnType::String,
0xff => ColumnType::Geometry,
_ => {
return Err(err_protocol!("unknown column type 0x{:02x}", id));
}
})
}
}

View File

@@ -0,0 +1,11 @@
mod column;
mod ping;
mod query;
mod quit;
mod row;
pub(crate) use column::{ColumnDefinition, ColumnFlags, ColumnType};
pub(crate) use ping::Ping;
pub(crate) use query::Query;
pub(crate) use quit::Quit;
pub(crate) use row::TextRow;

View File

@@ -0,0 +1,13 @@
use crate::io::Encode;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/com-ping.html
#[derive(Debug)]
pub(crate) struct Ping;
impl Encode<'_, Capabilities> for Ping {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
buf.push(0x0e); // COM_PING
}
}

View File

@@ -0,0 +1,14 @@
use crate::io::Encode;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/com-query.html
#[derive(Debug)]
pub(crate) struct Query<'q>(pub(crate) &'q str);
impl Encode<'_, Capabilities> for Query<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
buf.push(0x03); // COM_QUERY
buf.extend(self.0.as_bytes())
}
}

View File

@@ -0,0 +1,13 @@
use crate::io::Encode;
use crate::mysql::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/com-quit.html
#[derive(Debug)]
pub(crate) struct Quit;
impl Encode<'_, Capabilities> for Quit {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
buf.push(0x01); // COM_QUIT
}
}

View File

@@ -0,0 +1,36 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::Decode;
use crate::mysql::io::MySqlBufExt;
use crate::mysql::protocol::Row;
use crate::mysql::row::MySqlColumn;
#[derive(Debug)]
pub(crate) struct TextRow(pub(crate) Row);
impl<'de> Decode<'de, &'de [MySqlColumn]> for TextRow {
fn decode_with(mut buf: Bytes, columns: &'de [MySqlColumn]) -> Result<Self, Error> {
let storage = buf.clone();
let offset = buf.len();
let mut values = Vec::with_capacity(columns.len());
for _ in columns {
if buf[0] == 0xfb {
// NULL is sent as 0xfb
values.push(None);
buf.advance(1);
} else {
let size = buf.get_uint_lenenc() as usize;
let offset = offset - buf.len();
values.push(Some(offset..(offset + size)));
buf.advance(size);
}
}
Ok(TextRow(Row { values, storage }))
}
}

View File

@@ -1,48 +1 @@
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/binary__log__types_8h.html
// https://mariadb.com/kb/en/library/resultset/#field-types
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub struct TypeId(pub u8);
// https://github.com/google/mysql/blob/c01fc2134d439282a21a2ddf687566e198ddee28/include/mysql_com.h#L429
impl TypeId {
pub const NULL: TypeId = TypeId(6);
// String: CHAR, VARCHAR, TEXT
// Bytes: BINARY, VARBINARY, BLOB
pub const CHAR: TypeId = TypeId(254); // or BINARY
pub const VAR_CHAR: TypeId = TypeId(253); // or VAR_BINARY
pub const TEXT: TypeId = TypeId(252); // or BLOB
// Enum
pub const ENUM: TypeId = TypeId(247);
// More Bytes
pub const TINY_BLOB: TypeId = TypeId(249);
pub const MEDIUM_BLOB: TypeId = TypeId(250);
pub const LONG_BLOB: TypeId = TypeId(251);
// Numeric: TINYINT, SMALLINT, INT, BIGINT
pub const TINY_INT: TypeId = TypeId(1);
pub const SMALL_INT: TypeId = TypeId(2);
pub const INT: TypeId = TypeId(3);
pub const BIG_INT: TypeId = TypeId(8);
// pub const MEDIUM_INT: TypeId = TypeId(9);
// Numeric: FLOAT, DOUBLE
pub const FLOAT: TypeId = TypeId(4);
pub const DOUBLE: TypeId = TypeId(5);
pub const NEWDECIMAL: TypeId = TypeId(246);
// Date/Time: DATE, TIME, DATETIME, TIMESTAMP
pub const DATE: TypeId = TypeId(10);
pub const TIME: TypeId = TypeId(11);
pub const DATETIME: TypeId = TypeId(12);
pub const TIMESTAMP: TypeId = TypeId(7);
}
impl Default for TypeId {
fn default() -> TypeId {
TypeId::NULL
}
}

View File

@@ -1,59 +1,60 @@
use std::collections::HashMap;
use std::sync::Arc;
use crate::mysql::protocol;
use crate::mysql::{MySql, MySqlValue};
use hashbrown::HashMap;
use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::mysql::{protocol, MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef};
use crate::row::{ColumnIndex, Row};
pub struct MySqlRow<'c> {
pub(super) row: protocol::Row<'c>,
pub(super) names: Arc<HashMap<Box<str>, u16>>,
// TODO: Merge with the other XXColumn types
#[derive(Debug, Clone)]
pub(crate) struct MySqlColumn {
pub(crate) name: Option<UStr>,
pub(crate) type_info: Option<MySqlTypeInfo>,
}
impl crate::row::private_row::Sealed for MySqlRow<'_> {}
/// Implementation of [`Row`] for MySQL.
#[derive(Debug)]
pub struct MySqlRow {
pub(crate) row: protocol::Row,
pub(crate) columns: Arc<Vec<MySqlColumn>>,
pub(crate) column_names: Arc<HashMap<UStr, usize>>,
pub(crate) format: MySqlValueFormat,
}
impl<'c> Row<'c> for MySqlRow<'c> {
impl crate::row::private_row::Sealed for MySqlRow {}
impl Row for MySqlRow {
type Database = MySql;
#[inline]
fn len(&self) -> usize {
self.row.len()
}
#[doc(hidden)]
fn try_get_raw<I>(&self, index: I) -> crate::Result<MySqlValue<'c>>
fn try_get_raw<I>(&self, index: I) -> Result<MySqlValueRef, Error>
where
I: ColumnIndex<'c, Self>,
I: ColumnIndex<Self>,
{
let index = index.index(self)?;
let column_ty = self.row.columns[index].clone();
let buffer = self.row.get(index);
let value = match (self.row.binary, buffer) {
(_, None) => MySqlValue::null(),
(true, Some(buf)) => MySqlValue::binary(column_ty, buf),
(false, Some(buf)) => MySqlValue::text(column_ty, buf),
};
let column = &self.columns[index];
let value = self.row.get(index);
Ok(value)
Ok(MySqlValueRef {
format: self.format,
row: Some(&self.row.storage),
type_info: column.type_info.clone(),
value,
})
}
}
impl<'c> ColumnIndex<'c, MySqlRow<'c>> for usize {
fn index(&self, row: &MySqlRow<'c>) -> crate::Result<usize> {
let len = Row::len(row);
if *self >= len {
return Err(crate::Error::ColumnIndexOutOfBounds { len, index: *self });
}
Ok(*self)
}
}
impl<'c> ColumnIndex<'c, MySqlRow<'c>> for str {
fn index(&self, row: &MySqlRow<'c>) -> crate::Result<usize> {
row.names
.get(self)
.ok_or_else(|| crate::Error::ColumnNotFound((*self).into()))
.map(|&index| index as usize)
impl ColumnIndex<MySqlRow> for &'_ str {
fn index(&self, row: &MySqlRow) -> Result<usize, Error> {
row.column_names
.get(*self)
.ok_or_else(|| Error::ColumnNotFound((*self).into()))
.map(|v| *v)
}
}

View File

@@ -1,228 +0,0 @@
use std::net::Shutdown;
use byteorder::{ByteOrder, LittleEndian};
use crate::io::{Buf, BufMut, BufStream, MaybeTlsStream};
use crate::mysql::protocol::{Capabilities, Encode, EofPacket, ErrPacket, OkPacket};
use crate::mysql::MySqlError;
use crate::url::Url;
// Size before a packet is split
const MAX_PACKET_SIZE: u32 = 1024;
pub(crate) struct MySqlStream {
pub(super) stream: BufStream<MaybeTlsStream>,
// Is the stream ready to send commands
// Put another way, are we still expecting an EOF or OK packet to terminate
pub(super) is_ready: bool,
// Active capabilities
pub(super) capabilities: Capabilities,
// Packets in a command sequence have an incrementing sequence number
// This number must be 0 at the start of each command
pub(super) seq_no: u8,
// Packets are buffered into a second buffer from the stream
// as we may have compressed or split packets to figure out before
// decoding
packet_buf: Vec<u8>,
packet_len: usize,
}
impl MySqlStream {
pub(super) async fn new(url: &Url) -> crate::Result<Self> {
let host = url.host().unwrap_or("localhost");
let port = url.port(3306);
let stream = MaybeTlsStream::connect(host, port).await?;
let mut capabilities = Capabilities::PROTOCOL_41
| Capabilities::IGNORE_SPACE
| Capabilities::DEPRECATE_EOF
| Capabilities::FOUND_ROWS
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PLUGIN_AUTH;
if url.database().is_some() {
capabilities |= Capabilities::CONNECT_WITH_DB;
}
if cfg!(feature = "tls") {
capabilities |= Capabilities::SSL;
}
Ok(Self {
capabilities,
stream: BufStream::new(stream),
packet_buf: Vec::with_capacity(MAX_PACKET_SIZE as usize),
packet_len: 0,
seq_no: 0,
is_ready: true,
})
}
pub(super) fn is_tls(&self) -> bool {
self.stream.is_tls()
}
pub(super) fn shutdown(&self) -> crate::Result<()> {
Ok(self.stream.shutdown(Shutdown::Both)?)
}
#[inline]
pub(super) async fn send<T>(&mut self, packet: T, initial: bool) -> crate::Result<()>
where
T: Encode + std::fmt::Debug,
{
if initial {
self.seq_no = 0;
}
self.write(packet);
self.flush().await
}
#[inline]
pub(super) async fn flush(&mut self) -> crate::Result<()> {
Ok(self.stream.flush().await?)
}
/// Write the packet to the buffered stream ( do not send to the server )
pub(super) fn write<T>(&mut self, packet: T)
where
T: Encode,
{
let buf = self.stream.buffer_mut();
// Allocate room for the header that we write after the packet;
// so, we can get an accurate and cheap measure of packet length
let header_offset = buf.len();
buf.advance(4);
packet.encode(buf, self.capabilities);
// Determine length of encoded packet
// and write to allocated header
let len = buf.len() - header_offset - 4;
let mut header = &mut buf[header_offset..];
LittleEndian::write_u32(&mut header, len as u32);
// Take the last sequence number received, if any, and increment by 1
// If there was no sequence number, we only increment if we split packets
header[3] = self.seq_no;
self.seq_no = self.seq_no.wrapping_add(1);
}
#[inline]
pub(super) async fn receive(&mut self) -> crate::Result<&[u8]> {
self.read().await?;
Ok(self.packet())
}
pub(super) async fn read(&mut self) -> crate::Result<()> {
self.packet_buf.clear();
self.packet_len = 0;
// Read the packet header which contains the length and the sequence number
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html
// https://mariadb.com/kb/en/library/0-packet/#standard-packet
let mut header = self.stream.peek(4_usize).await?;
self.packet_len = header.get_uint::<LittleEndian>(3)? as usize;
self.seq_no = header.get_u8()?.wrapping_add(1);
self.stream.consume(4);
// Read the packet body and copy it into our internal buf
// We must have a separate buffer around the stream as we can't operate directly
// on bytes returned from the stream. We have various kinds of payload manipulation
// that must be handled before decoding.
let payload = self.stream.peek(self.packet_len).await?;
self.packet_buf.reserve(payload.len());
self.packet_buf.extend_from_slice(payload);
self.stream.consume(self.packet_len);
// TODO: Implement packet compression
// TODO: Implement packet joining
Ok(())
}
/// Returns a reference to the most recently received packet data.
/// A call to `read` invalidates this buffer.
#[inline]
pub(super) fn packet(&self) -> &[u8] {
&self.packet_buf[..self.packet_len]
}
}
impl MySqlStream {
pub(crate) async fn maybe_receive_eof(&mut self) -> crate::Result<()> {
if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
let _eof = EofPacket::read(self.receive().await?)?;
}
Ok(())
}
pub(crate) fn maybe_handle_eof(&mut self) -> crate::Result<Option<EofPacket>> {
if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) && self.packet()[0] == 0xFE {
Ok(Some(EofPacket::read(self.packet())?))
} else {
Ok(None)
}
}
pub(crate) fn handle_unexpected<T>(&mut self) -> crate::Result<T> {
Err(protocol_err!("unexpected packet identifier 0x{:X?}", self.packet()[0]).into())
}
pub(crate) fn handle_err<T>(&mut self) -> crate::Result<T> {
self.is_ready = true;
Err(MySqlError(ErrPacket::read(self.packet(), self.capabilities)?).into())
}
pub(crate) fn handle_ok(&mut self) -> crate::Result<OkPacket> {
self.is_ready = true;
OkPacket::read(self.packet())
}
pub(crate) async fn wait_until_ready(&mut self) -> crate::Result<()> {
if !self.is_ready {
loop {
let packet_id = self.receive().await?[0];
match packet_id {
0xFE if self.packet().len() < 0xFF_FF_FF => {
// OK or EOF packet
self.is_ready = true;
break;
}
0xFF => {
// ERR packet
self.is_ready = true;
return self.handle_err();
}
_ => {
// Something else; skip
}
}
}
}
Ok(())
}
}

View File

@@ -1,123 +0,0 @@
use crate::mysql::stream::MySqlStream;
use crate::url::Url;
#[cfg_attr(not(feature = "tls"), allow(unused_variables))]
pub(super) async fn upgrade_if_needed(stream: &mut MySqlStream, url: &Url) -> crate::Result<()> {
#[cfg_attr(not(feature = "tls"), allow(unused_imports))]
use crate::mysql::protocol::Capabilities;
let ca_file = url.param("ssl-ca");
let ssl_mode = url.param("ssl-mode");
// https://dev.mysql.com/doc/refman/5.7/en/connection-options.html#option_general_ssl-mode
match ssl_mode.as_deref() {
Some("DISABLED") => {}
#[cfg(feature = "tls")]
Some("PREFERRED") | None if !stream.capabilities.contains(Capabilities::SSL) => {}
#[cfg(feature = "tls")]
Some("PREFERRED") => {
if let Err(_error) = try_upgrade(stream, &url, None, true).await {
// TLS upgrade failed; fall back to a normal connection
}
}
#[cfg(feature = "tls")]
None => {
if let Err(_error) = try_upgrade(stream, &url, ca_file.as_deref(), true).await {
// TLS upgrade failed; fall back to a normal connection
}
}
#[cfg(feature = "tls")]
Some("REQUIRED") | Some("VERIFY_CA") | Some("VERIFY_IDENTITY")
if !stream.capabilities.contains(Capabilities::SSL) =>
{
return Err(tls_err!("server does not support TLS").into());
}
#[cfg(feature = "tls")]
Some("VERIFY_CA") | Some("VERIFY_IDENTITY") if ca_file.is_none() => {
return Err(
tls_err!("`ssl-mode` of {:?} requires `ssl-ca` to be set", ssl_mode).into(),
);
}
#[cfg(feature = "tls")]
Some(mode @ "REQUIRED") | Some(mode @ "VERIFY_CA") | Some(mode @ "VERIFY_IDENTITY") => {
try_upgrade(
stream,
url,
// false for both verify-ca and verify-full
ca_file.as_deref(),
// false for only verify-full
mode != "VERIFY_IDENTITY",
)
.await?;
}
#[cfg(not(feature = "tls"))]
None => {
// The user neither explicitly enabled TLS in the connection string
// nor did they turn the `tls` feature on
}
#[cfg(not(feature = "tls"))]
Some(mode @ "PREFERRED")
| Some(mode @ "REQUIRED")
| Some(mode @ "VERIFY_CA")
| Some(mode @ "VERIFY_IDENTITY") => {
return Err(tls_err!(
"ssl-mode {:?} unsupported; SQLx was compiled without `tls` feature",
mode
)
.into());
}
Some(mode) => {
return Err(tls_err!("unknown `ssl-mode` value: {:?}", mode).into());
}
}
Ok(())
}
#[cfg(feature = "tls")]
async fn try_upgrade(
stream: &mut MySqlStream,
url: &Url,
ca_file: Option<&str>,
accept_invalid_hostnames: bool,
) -> crate::Result<()> {
use crate::mysql::protocol::SslRequest;
use crate::runtime::fs;
use async_native_tls::{Certificate, TlsConnector};
let mut connector = TlsConnector::new()
.danger_accept_invalid_certs(ca_file.is_none())
.danger_accept_invalid_hostnames(accept_invalid_hostnames);
if let Some(ca_file) = ca_file {
let root_cert = fs::read(ca_file).await?;
connector = connector.add_root_certificate(Certificate::from_pem(&root_cert)?);
}
// send upgrade request and then immediately try TLS handshake
stream
.send(
SslRequest {
client_collation: super::connection::COLLATE_UTF8MB4_UNICODE_CI,
max_packet_size: super::connection::MAX_PACKET_SIZE,
},
false,
)
.await?;
stream
.stream
.upgrade(url.host().unwrap_or("localhost"), connector)
.await
}

View File

@@ -1,144 +1,67 @@
use std::fmt::{self, Display};
use std::fmt::{self, Display, Formatter};
use crate::mysql::protocol::{ColumnDefinition, FieldFlags, TypeId};
use crate::types::TypeInfo;
use crate::mysql::protocol::text::{ColumnDefinition, ColumnFlags, ColumnType};
use crate::type_info::TypeInfo;
#[derive(Clone, Debug, Default)]
/// Type information for a MySql type.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub struct MySqlTypeInfo {
pub(crate) id: TypeId,
pub(crate) is_unsigned: bool,
pub(crate) is_binary: bool,
pub(crate) r#type: ColumnType,
pub(crate) flags: ColumnFlags,
pub(crate) char_set: u16,
}
impl MySqlTypeInfo {
pub(crate) const fn new(id: TypeId) -> Self {
pub(crate) const fn binary(ty: ColumnType) -> Self {
Self {
id,
is_unsigned: false,
is_binary: true,
char_set: 0,
r#type: ty,
flags: ColumnFlags::BINARY,
char_set: 63,
}
}
pub(crate) const fn unsigned(id: TypeId) -> Self {
Self {
id,
is_unsigned: true,
is_binary: false,
char_set: 0,
}
}
#[doc(hidden)]
pub const fn r#enum() -> Self {
Self {
id: TypeId::ENUM,
is_unsigned: false,
is_binary: false,
char_set: 0,
}
}
pub(crate) fn from_nullable_column_def(def: &ColumnDefinition) -> Self {
Self {
id: def.type_id,
is_unsigned: def.flags.contains(FieldFlags::UNSIGNED),
is_binary: def.flags.contains(FieldFlags::BINARY),
char_set: def.char_set,
}
}
pub(crate) fn from_column_def(def: &ColumnDefinition) -> Option<Self> {
if def.type_id == TypeId::NULL {
return None;
}
Some(Self::from_nullable_column_def(def))
}
#[doc(hidden)]
pub fn type_feature_gate(&self) -> Option<&'static str> {
match self.id {
TypeId::DATE | TypeId::TIME | TypeId::DATETIME | TypeId::TIMESTAMP => Some("chrono"),
_ => None,
pub(crate) fn from_column(column: &ColumnDefinition) -> Option<Self> {
if column.r#type == ColumnType::Null {
None
} else {
Some(Self {
r#type: column.r#type,
flags: column.flags,
char_set: column.char_set,
})
}
}
}
impl Display for MySqlTypeInfo {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.id {
TypeId::NULL => f.write_str("NULL"),
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.write_str(self.r#type.name(self.char_set))?;
TypeId::TINY_INT if self.is_unsigned => f.write_str("TINYINT UNSIGNED"),
TypeId::SMALL_INT if self.is_unsigned => f.write_str("SMALLINT UNSIGNED"),
TypeId::INT if self.is_unsigned => f.write_str("INT UNSIGNED"),
TypeId::BIG_INT if self.is_unsigned => f.write_str("BIGINT UNSIGNED"),
TypeId::TINY_INT => f.write_str("TINYINT"),
TypeId::SMALL_INT => f.write_str("SMALLINT"),
TypeId::INT => f.write_str("INT"),
TypeId::BIG_INT => f.write_str("BIGINT"),
TypeId::FLOAT => f.write_str("FLOAT"),
TypeId::DOUBLE => f.write_str("DOUBLE"),
TypeId::CHAR if self.is_binary => f.write_str("BINARY"),
TypeId::VAR_CHAR if self.is_binary => f.write_str("VARBINARY"),
TypeId::TEXT if self.is_binary => f.write_str("BLOB"),
TypeId::CHAR => f.write_str("CHAR"),
TypeId::VAR_CHAR => f.write_str("VARCHAR"),
TypeId::TEXT => f.write_str("TEXT"),
TypeId::DATE => f.write_str("DATE"),
TypeId::TIME => f.write_str("TIME"),
TypeId::DATETIME => f.write_str("DATETIME"),
TypeId::TIMESTAMP => f.write_str("TIMESTAMP"),
id => write!(f, "<{:#x}>", id.0),
if self.flags.contains(ColumnFlags::UNSIGNED) {
f.write_str(" UNSIGNED")?;
}
Ok(())
}
}
impl TypeInfo for MySqlTypeInfo {}
impl PartialEq<MySqlTypeInfo> for MySqlTypeInfo {
fn eq(&self, other: &MySqlTypeInfo) -> bool {
match self.id {
TypeId::VAR_CHAR
| TypeId::TEXT
| TypeId::CHAR
| TypeId::TINY_BLOB
| TypeId::MEDIUM_BLOB
| TypeId::LONG_BLOB
| TypeId::ENUM
if (self.is_binary == other.is_binary)
&& match other.id {
TypeId::VAR_CHAR
| TypeId::TEXT
| TypeId::CHAR
| TypeId::TINY_BLOB
| TypeId::MEDIUM_BLOB
| TypeId::LONG_BLOB
| TypeId::ENUM => true,
_ => false,
} =>
{
return true;
}
_ => {}
}
if self.id.0 != other.id.0 {
if self.r#type != other.r#type {
return false;
}
match self.id {
TypeId::TINY_INT | TypeId::SMALL_INT | TypeId::INT | TypeId::BIG_INT => {
return self.is_unsigned == other.is_unsigned;
match self.r#type {
ColumnType::Tiny
| ColumnType::Short
| ColumnType::Long
| ColumnType::Int24
| ColumnType::LongLong => {
return self.flags.contains(ColumnFlags::UNSIGNED)
== other.flags.contains(ColumnFlags::UNSIGNED);
}
_ => {}
@@ -148,103 +71,4 @@ impl PartialEq<MySqlTypeInfo> for MySqlTypeInfo {
}
}
impl TypeInfo for MySqlTypeInfo {
fn compatible(&self, other: &Self) -> bool {
// NOTE: MySQL is weakly typed so much of this may be surprising to a Rust developer.
if self.id == TypeId::NULL || other.id == TypeId::NULL {
// NULL is the "bottom" type
// If the user is trying to select into a non-Option, we catch this soon with an
// UnexpectedNull error message
return true;
}
match self.id {
// All integer types should be considered compatible
TypeId::TINY_INT | TypeId::SMALL_INT | TypeId::INT | TypeId::BIG_INT
if (self.is_unsigned == other.is_unsigned)
&& match other.id {
TypeId::TINY_INT | TypeId::SMALL_INT | TypeId::INT | TypeId::BIG_INT => {
true
}
_ => false,
} =>
{
true
}
// All textual types should be considered compatible
TypeId::VAR_CHAR
| TypeId::TEXT
| TypeId::CHAR
| TypeId::TINY_BLOB
| TypeId::MEDIUM_BLOB
| TypeId::LONG_BLOB
if match other.id {
TypeId::VAR_CHAR
| TypeId::TEXT
| TypeId::CHAR
| TypeId::TINY_BLOB
| TypeId::MEDIUM_BLOB
| TypeId::LONG_BLOB => true,
_ => false,
} =>
{
true
}
// Enums are considered compatible with other text/binary types
TypeId::ENUM
if match other.id {
TypeId::VAR_CHAR
| TypeId::TEXT
| TypeId::CHAR
| TypeId::TINY_BLOB
| TypeId::MEDIUM_BLOB
| TypeId::LONG_BLOB
| TypeId::ENUM => true,
_ => false,
} =>
{
true
}
TypeId::VAR_CHAR
| TypeId::TEXT
| TypeId::CHAR
| TypeId::TINY_BLOB
| TypeId::MEDIUM_BLOB
| TypeId::LONG_BLOB
| TypeId::ENUM
if other.id == TypeId::ENUM =>
{
true
}
// FLOAT is compatible with DOUBLE
TypeId::FLOAT | TypeId::DOUBLE
if match other.id {
TypeId::FLOAT | TypeId::DOUBLE => true,
_ => false,
} =>
{
true
}
// DATETIME is compatible with TIMESTAMP
TypeId::DATETIME | TypeId::TIMESTAMP
if match other.id {
TypeId::DATETIME | TypeId::TIMESTAMP => true,
_ => false,
} =>
{
true
}
_ => self.eq(other),
}
}
}
impl Eq for MySqlTypeInfo {}

View File

@@ -1,92 +1,30 @@
use bigdecimal::BigDecimal;
use crate::database::{Database, HasArguments};
use crate::decode::Decode;
use crate::encode::Encode;
use crate::io::Buf;
use crate::mysql::protocol::TypeId;
use crate::mysql::{MySql, MySqlData, MySqlTypeInfo, MySqlValue};
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::mysql::io::MySqlBufMutExt;
use crate::mysql::protocol::text::{ColumnFlags, ColumnType};
use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef};
use crate::types::Type;
use crate::Error;
use std::str::FromStr;
impl Type<MySql> for BigDecimal {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::NEWDECIMAL)
MySqlTypeInfo::binary(ColumnType::NewDecimal)
}
}
impl Encode<MySql> for BigDecimal {
fn encode(&self, buf: &mut Vec<u8>) {
let size = Encode::<MySql>::size_hint(self) - 1;
assert!(size <= std::u8::MAX as usize, "Too large size");
buf.push(size as u8);
let s = self.to_string();
buf.extend_from_slice(s.as_bytes());
}
impl Encode<'_, MySql> for BigDecimal {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
buf.put_str_lenenc(&self.to_string());
fn size_hint(&self) -> usize {
let s = self.to_string();
s.as_bytes().len() + 1
IsNull::No
}
}
impl Decode<'_, MySql> for BigDecimal {
fn decode(value: MySqlValue) -> crate::Result<Self> {
match value.try_get()? {
MySqlData::Binary(mut binary) => {
let _len = binary.get_u8()?;
let s = std::str::from_utf8(binary).map_err(Error::decode)?;
Ok(BigDecimal::from_str(s).map_err(Error::decode)?)
}
MySqlData::Text(s) => {
let s = std::str::from_utf8(s).map_err(Error::decode)?;
Ok(BigDecimal::from_str(s).map_err(Error::decode)?)
}
}
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(value.as_str()?.parse()?)
}
}
#[test]
fn test_encode_decimal() {
let v: BigDecimal = BigDecimal::from_str("-1.05").unwrap();
let mut buf: Vec<u8> = vec![];
<BigDecimal as Encode<MySql>>::encode(&v, &mut buf);
assert_eq!(buf, vec![0x05, b'-', b'1', b'.', b'0', b'5']);
let v: BigDecimal = BigDecimal::from_str("-105000").unwrap();
let mut buf: Vec<u8> = vec![];
<BigDecimal as Encode<MySql>>::encode(&v, &mut buf);
assert_eq!(buf, vec![0x07, b'-', b'1', b'0', b'5', b'0', b'0', b'0']);
let v: BigDecimal = BigDecimal::from_str("0.00105").unwrap();
let mut buf: Vec<u8> = vec![];
<BigDecimal as Encode<MySql>>::encode(&v, &mut buf);
assert_eq!(buf, vec![0x07, b'0', b'.', b'0', b'0', b'1', b'0', b'5']);
}
#[test]
fn test_decode_decimal() {
let buf: Vec<u8> = vec![0x05, b'-', b'1', b'.', b'0', b'5'];
let v = <BigDecimal as Decode<'_, MySql>>::decode(MySqlValue::binary(
MySqlTypeInfo::new(TypeId::NEWDECIMAL),
buf.as_slice(),
))
.unwrap();
assert_eq!(v.to_string(), "-1.05");
let buf: Vec<u8> = vec![0x04, b'0', b'.', b'0', b'5'];
let v = <BigDecimal as Decode<'_, MySql>>::decode(MySqlValue::binary(
MySqlTypeInfo::new(TypeId::NEWDECIMAL),
buf.as_slice(),
))
.unwrap();
assert_eq!(v.to_string(), "0.05");
let buf: Vec<u8> = vec![0x06, b'-', b'9', b'0', b'0', b'0', b'0'];
let v = <BigDecimal as Decode<'_, MySql>>::decode(MySqlValue::binary(
MySqlTypeInfo::new(TypeId::NEWDECIMAL),
buf.as_slice(),
))
.unwrap();
assert_eq!(v.to_string(), "-90000");
}

View File

@@ -1,34 +1,28 @@
use crate::decode::Decode;
use crate::encode::Encode;
use crate::mysql::protocol::TypeId;
use crate::mysql::type_info::MySqlTypeInfo;
use crate::mysql::{MySql, MySqlData, MySqlValue};
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef};
use crate::types::Type;
impl Type<MySql> for bool {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::TINY_INT)
// MySQL has no actual `BOOLEAN` type, the type is an alias of `TINYINT(1)`
<i8 as Type<MySql>>::type_info()
}
}
impl Encode<MySql> for bool {
fn encode(&self, buf: &mut Vec<u8>) {
buf.push(*self as u8);
impl Encode<'_, MySql> for bool {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
<i8 as Encode<MySql>>::encode(*self as i8, buf)
}
}
impl<'de> Decode<'de, MySql> for bool {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
match value.try_get()? {
MySqlData::Binary(buf) => Ok(buf.get(0).map(|&b| b != 0).unwrap_or_default()),
impl Decode<'_, MySql> for bool {
fn accepts(ty: &MySqlTypeInfo) -> bool {
<i8 as Decode<MySql>>::accepts(ty)
}
MySqlData::Text(b"0") => Ok(false),
MySqlData::Text(b"1") => Ok(true),
MySqlData::Text(s) => Err(crate::Error::Decode(
format!("unexpected value {:?} for boolean", s).into(),
)),
}
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(<i8 as Decode<MySql>>::decode(value)? != 0)
}
}

View File

@@ -1,21 +1,42 @@
use byteorder::LittleEndian;
use crate::decode::Decode;
use crate::encode::Encode;
use crate::mysql::io::BufMutExt;
use crate::mysql::protocol::TypeId;
use crate::mysql::type_info::MySqlTypeInfo;
use crate::mysql::{MySql, MySqlData, MySqlValue};
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::mysql::io::MySqlBufMutExt;
use crate::mysql::protocol::text::ColumnType;
use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef};
use crate::types::Type;
impl Type<MySql> for [u8] {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo {
id: TypeId::TEXT,
is_binary: true,
is_unsigned: false,
char_set: 63, // binary
}
MySqlTypeInfo::binary(ColumnType::Blob)
}
}
impl Encode<'_, MySql> for &'_ [u8] {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
buf.put_bytes_lenenc(self);
IsNull::No
}
}
impl<'r> Decode<'r, MySql> for &'r [u8] {
fn accepts(ty: &MySqlTypeInfo) -> bool {
matches!(
ty.r#type,
ColumnType::VarChar
| ColumnType::Blob
| ColumnType::TinyBlob
| ColumnType::MediumBlob
| ColumnType::LongBlob
| ColumnType::String
| ColumnType::VarString
| ColumnType::Enum
)
}
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
value.as_bytes()
}
}
@@ -25,30 +46,18 @@ impl Type<MySql> for Vec<u8> {
}
}
impl Encode<MySql> for [u8] {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_bytes_lenenc::<LittleEndian>(self);
impl Encode<'_, MySql> for Vec<u8> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
<&[u8] as Encode<MySql>>::encode(&**self, buf)
}
}
impl Encode<MySql> for Vec<u8> {
fn encode(&self, buf: &mut Vec<u8>) {
<[u8] as Encode<MySql>>::encode(self, buf);
impl Decode<'_, MySql> for Vec<u8> {
fn accepts(ty: &MySqlTypeInfo) -> bool {
<&[u8] as Decode<MySql>>::accepts(ty)
}
}
impl<'de> Decode<'de, MySql> for Vec<u8> {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
match value.try_get()? {
MySqlData::Binary(buf) | MySqlData::Text(buf) => Ok(buf.to_vec()),
}
}
}
impl<'de> Decode<'de, MySql> for &'de [u8] {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
match value.try_get()? {
MySqlData::Binary(buf) | MySqlData::Text(buf) => Ok(buf),
}
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
<&[u8] as Decode<MySql>>::decode(value).map(ToOwned::to_owned)
}
}

View File

@@ -1,32 +1,35 @@
use std::convert::TryFrom;
use std::str::from_utf8;
use byteorder::{ByteOrder, LittleEndian};
use bytes::Buf;
use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc};
use crate::decode::Decode;
use crate::encode::Encode;
use crate::io::{Buf, BufMut};
use crate::mysql::protocol::TypeId;
use crate::encode::{Encode, IsNull};
use crate::error::{BoxDynError, Error};
use crate::mysql::protocol::text::ColumnType;
use crate::mysql::type_info::MySqlTypeInfo;
use crate::mysql::{MySql, MySqlData, MySqlValue};
use crate::mysql::{MySql, MySqlValue, MySqlValueFormat, MySqlValueRef};
use crate::types::Type;
use crate::Error;
use std::str::from_utf8;
impl Type<MySql> for DateTime<Utc> {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::TIMESTAMP)
MySqlTypeInfo::binary(ColumnType::Timestamp)
}
}
impl Encode<MySql> for DateTime<Utc> {
fn encode(&self, buf: &mut Vec<u8>) {
Encode::<MySql>::encode(&self.naive_utc(), buf);
impl Encode<'_, MySql> for DateTime<Utc> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
Encode::<MySql>::encode(&self.naive_utc(), buf)
}
}
impl<'de> Decode<'de, MySql> for DateTime<Utc> {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
impl<'r> Decode<'r, MySql> for DateTime<Utc> {
fn accepts(ty: &MySqlTypeInfo) -> bool {
matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp)
}
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
let naive: NaiveDateTime = Decode::<MySql>::decode(value)?;
Ok(DateTime::from_utc(naive, Utc))
@@ -35,12 +38,12 @@ impl<'de> Decode<'de, MySql> for DateTime<Utc> {
impl Type<MySql> for NaiveTime {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::TIME)
MySqlTypeInfo::binary(ColumnType::Time)
}
}
impl Encode<MySql> for NaiveTime {
fn encode(&self, buf: &mut Vec<u8>) {
impl Encode<'_, MySql> for NaiveTime {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
let len = Encode::<MySql>::size_hint(self) - 1;
buf.push(len as u8);
@@ -49,9 +52,11 @@ impl Encode<MySql> for NaiveTime {
// "date on 4 bytes little-endian format" (?)
// https://mariadb.com/kb/en/resultset-row/#teimstamp-binary-encoding
buf.advance(4);
buf.extend_from_slice(&[0_u8; 4]);
encode_time(self, len > 9, buf);
IsNull::No
}
fn size_hint(&self) -> usize {
@@ -65,27 +70,29 @@ impl Encode<MySql> for NaiveTime {
}
}
impl<'de> Decode<'de, MySql> for NaiveTime {
fn decode(buf: MySqlValue<'de>) -> crate::Result<Self> {
match buf.try_get()? {
MySqlData::Binary(mut buf) => {
impl<'r> Decode<'r, MySql> for NaiveTime {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
MySqlValueFormat::Binary => {
let mut buf = value.as_bytes()?;
// data length, expecting 8 or 12 (fractional seconds)
let len = buf.get_u8()?;
let len = buf.get_u8();
// is negative : int<1>
let is_negative = buf.get_u8()?;
assert_eq!(is_negative, 0, "Negative dates/times are not supported");
let is_negative = buf.get_u8();
debug_assert_eq!(is_negative, 0, "Negative dates/times are not supported");
// "date on 4 bytes little-endian format" (?)
// https://mariadb.com/kb/en/resultset-row/#timestamp-binary-encoding
buf.advance(4);
decode_time(len - 5, buf)
Ok(decode_time(len - 5, buf))
}
MySqlData::Text(buf) => {
let s = from_utf8(buf).map_err(Error::decode)?;
NaiveTime::parse_from_str(s, "%H:%M:%S%.f").map_err(Error::decode)
MySqlValueFormat::Text => {
let s = value.as_str()?;
NaiveTime::parse_from_str(s, "%H:%M:%S%.f").map_err(Into::into)
}
}
}
@@ -93,15 +100,17 @@ impl<'de> Decode<'de, MySql> for NaiveTime {
impl Type<MySql> for NaiveDate {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::DATE)
MySqlTypeInfo::binary(ColumnType::Date)
}
}
impl Encode<MySql> for NaiveDate {
fn encode(&self, buf: &mut Vec<u8>) {
impl Encode<'_, MySql> for NaiveDate {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
buf.push(4);
encode_date(self, buf);
IsNull::No
}
fn size_hint(&self) -> usize {
@@ -109,14 +118,14 @@ impl Encode<MySql> for NaiveDate {
}
}
impl<'de> Decode<'de, MySql> for NaiveDate {
fn decode(buf: MySqlValue<'de>) -> crate::Result<Self> {
match buf.try_get()? {
MySqlData::Binary(buf) => Ok(decode_date(&buf[1..])),
impl<'r> Decode<'r, MySql> for NaiveDate {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
MySqlValueFormat::Binary => Ok(decode_date(&value.as_bytes()?[1..])),
MySqlData::Text(buf) => {
let s = from_utf8(buf).map_err(Error::decode)?;
NaiveDate::parse_from_str(s, "%Y-%m-%d").map_err(Error::decode)
MySqlValueFormat::Text => {
let s = value.as_str()?;
NaiveDate::parse_from_str(s, "%Y-%m-%d").map_err(Into::into)
}
}
}
@@ -124,12 +133,12 @@ impl<'de> Decode<'de, MySql> for NaiveDate {
impl Type<MySql> for NaiveDateTime {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::DATETIME)
MySqlTypeInfo::binary(ColumnType::Datetime)
}
}
impl Encode<MySql> for NaiveDateTime {
fn encode(&self, buf: &mut Vec<u8>) {
impl Encode<'_, MySql> for NaiveDateTime {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
let len = Encode::<MySql>::size_hint(self) - 1;
buf.push(len as u8);
@@ -138,6 +147,8 @@ impl Encode<MySql> for NaiveDateTime {
if len > 4 {
encode_time(&self.time(), len > 8, buf);
}
IsNull::No
}
fn size_hint(&self) -> usize {
@@ -162,15 +173,21 @@ impl Encode<MySql> for NaiveDateTime {
}
}
impl<'de> Decode<'de, MySql> for NaiveDateTime {
fn decode(buf: MySqlValue<'de>) -> crate::Result<Self> {
match buf.try_get()? {
MySqlData::Binary(buf) => {
impl<'r> Decode<'r, MySql> for NaiveDateTime {
fn accepts(ty: &MySqlTypeInfo) -> bool {
matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp)
}
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
MySqlValueFormat::Binary => {
let mut buf = value.as_bytes()?;
let len = buf[0];
let date = decode_date(&buf[1..]);
let dt = if len > 4 {
date.and_time(decode_time(len - 4, &buf[5..])?)
date.and_time(decode_time(len - 4, &buf[5..]))
} else {
date.and_hms(0, 0, 0)
};
@@ -178,9 +195,9 @@ impl<'de> Decode<'de, MySql> for NaiveDateTime {
Ok(dt)
}
MySqlData::Text(buf) => {
let s = from_utf8(buf).map_err(Error::decode)?;
NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f").map_err(Error::decode)
MySqlValueFormat::Text => {
let s = value.as_str()?;
NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f").map_err(Into::into)
}
}
}
@@ -196,12 +213,9 @@ fn encode_date(date: &NaiveDate, buf: &mut Vec<u8>) {
buf.push(date.day() as u8);
}
fn decode_date(buf: &[u8]) -> NaiveDate {
NaiveDate::from_ymd(
LittleEndian::read_u16(buf) as i32,
buf[2] as u32,
buf[3] as u32,
)
fn decode_date(mut buf: &[u8]) -> NaiveDate {
let year = buf.get_u16_le();
NaiveDate::from_ymd(year as i32, buf[0] as u32, buf[1] as u32)
}
fn encode_time(time: &NaiveTime, include_micros: bool, buf: &mut Vec<u8>) {
@@ -210,93 +224,21 @@ fn encode_time(time: &NaiveTime, include_micros: bool, buf: &mut Vec<u8>) {
buf.push(time.second() as u8);
if include_micros {
buf.put_u32::<LittleEndian>((time.nanosecond() / 1000) as u32);
buf.extend(&((time.nanosecond() / 1000) as u32).to_le_bytes());
}
}
fn decode_time(len: u8, mut buf: &[u8]) -> crate::Result<NaiveTime> {
let hour = buf.get_u8()?;
let minute = buf.get_u8()?;
let seconds = buf.get_u8()?;
fn decode_time(len: u8, mut buf: &[u8]) -> NaiveTime {
let hour = buf.get_u8();
let minute = buf.get_u8();
let seconds = buf.get_u8();
let micros = if len > 3 {
// microseconds : int<EOF>
buf.get_uint::<LittleEndian>(buf.len())?
buf.get_uint_le(buf.len())
} else {
0
};
Ok(NaiveTime::from_hms_micro(
hour as u32,
minute as u32,
seconds as u32,
micros as u32,
))
}
#[test]
fn test_encode_date_time() {
let mut buf = Vec::new();
// test values from https://dev.mysql.com/doc/internals/en/binary-protocol-value.html
let date1: NaiveDateTime = "2010-10-17T19:27:30.000001".parse().unwrap();
Encode::<MySql>::encode(&date1, &mut buf);
assert_eq!(*buf, [11, 218, 7, 10, 17, 19, 27, 30, 1, 0, 0, 0]);
buf.clear();
let date2: NaiveDateTime = "2010-10-17T19:27:30".parse().unwrap();
Encode::<MySql>::encode(&date2, &mut buf);
assert_eq!(*buf, [7, 218, 7, 10, 17, 19, 27, 30]);
buf.clear();
let date3: NaiveDateTime = "2010-10-17T00:00:00".parse().unwrap();
Encode::<MySql>::encode(&date3, &mut buf);
assert_eq!(*buf, [4, 218, 7, 10, 17]);
}
#[test]
fn test_decode_date_time() {
// test values from https://dev.mysql.com/doc/internals/en/binary-protocol-value.html
let buf = [11, 218, 7, 10, 17, 19, 27, 30, 1, 0, 0, 0];
let date1 = <NaiveDateTime as Decode<MySql>>::decode(MySqlValue::binary(
MySqlTypeInfo::default(),
&buf,
))
.unwrap();
assert_eq!(date1.to_string(), "2010-10-17 19:27:30.000001");
let buf = [7, 218, 7, 10, 17, 19, 27, 30];
let date2 = <NaiveDateTime as Decode<MySql>>::decode(MySqlValue::binary(
MySqlTypeInfo::default(),
&buf,
))
.unwrap();
assert_eq!(date2.to_string(), "2010-10-17 19:27:30");
let buf = [4, 218, 7, 10, 17];
let date3 = <NaiveDateTime as Decode<MySql>>::decode(MySqlValue::binary(
MySqlTypeInfo::default(),
&buf,
))
.unwrap();
assert_eq!(date3.to_string(), "2010-10-17 00:00:00");
}
#[test]
fn test_encode_date() {
let mut buf = Vec::new();
let date: NaiveDate = "2010-10-17".parse().unwrap();
Encode::<MySql>::encode(&date, &mut buf);
assert_eq!(*buf, [4, 218, 7, 10, 17]);
}
#[test]
fn test_decode_date() {
let buf = [4, 218, 7, 10, 17];
let date =
<NaiveDate as Decode<MySql>>::decode(MySqlValue::binary(MySqlTypeInfo::default(), &buf))
.unwrap();
assert_eq!(date.to_string(), "2010-10-17");
NaiveTime::from_hms_micro(hour as u32, minute as u32, seconds as u32, micros as u32)
}

View File

@@ -1,83 +1,77 @@
use byteorder::{LittleEndian, ReadBytesExt};
use byteorder::{ByteOrder, LittleEndian};
use crate::decode::Decode;
use crate::encode::Encode;
use crate::mysql::protocol::TypeId;
use crate::mysql::type_info::MySqlTypeInfo;
use crate::mysql::{MySql, MySqlData, MySqlValue};
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::mysql::protocol::text::ColumnType;
use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef};
use crate::types::Type;
use crate::Error;
use std::str::from_utf8;
/// The equivalent MySQL type for `f32` is `FLOAT`.
///
/// ### Note
/// While we added support for `f32` as `FLOAT` for completeness, we don't recommend using
/// it for any real-life applications as it cannot precisely represent some fractional values,
/// and may be implicitly widened to `DOUBLE` in some cases, resulting in a slightly different
/// value:
///
/// ```rust
/// // Widening changes the equivalent decimal value, these two expressions are not equal
/// // (This is expected behavior for floating points and happens both in Rust and in MySQL)
/// assert_ne!(10.2f32 as f64, 10.2f64);
/// ```
fn real_accepts(ty: &MySqlTypeInfo) -> bool {
matches!(ty.r#type, ColumnType::Float | ColumnType::Double)
}
impl Type<MySql> for f32 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::FLOAT)
MySqlTypeInfo::binary(ColumnType::Float)
}
}
impl Encode<MySql> for f32 {
fn encode(&self, buf: &mut Vec<u8>) {
<i32 as Encode<MySql>>::encode(&(self.to_bits() as i32), buf);
}
}
impl<'de> Decode<'de, MySql> for f32 {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
match value.try_get()? {
MySqlData::Binary(mut buf) => buf
.read_i32::<LittleEndian>()
.map_err(crate::Error::decode)
.map(|value| f32::from_bits(value as u32)),
MySqlData::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
}
}
/// The equivalent MySQL type for `f64` is `DOUBLE`.
///
/// Note that `DOUBLE` is a floating-point type and cannot represent some fractional values
/// exactly.
impl Type<MySql> for f64 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::DOUBLE)
MySqlTypeInfo::binary(ColumnType::Double)
}
}
impl Encode<MySql> for f64 {
fn encode(&self, buf: &mut Vec<u8>) {
<i64 as Encode<MySql>>::encode(&(self.to_bits() as i64), buf);
impl Encode<'_, MySql> for f32 {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
buf.extend(&self.to_le_bytes());
IsNull::No
}
}
impl<'de> Decode<'de, MySql> for f64 {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
match value.try_get()? {
MySqlData::Binary(mut buf) => buf
.read_i64::<LittleEndian>()
.map_err(crate::Error::decode)
.map(|value| f64::from_bits(value as u64)),
impl Encode<'_, MySql> for f64 {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
buf.extend(&self.to_le_bytes());
MySqlData::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
IsNull::No
}
}
impl Decode<'_, MySql> for f32 {
fn accepts(ty: &MySqlTypeInfo) -> bool {
real_accepts(ty)
}
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
MySqlValueFormat::Binary => {
let buf = value.as_bytes()?;
if buf.len() == 8 {
// MySQL can return 8-byte DOUBLE values for a FLOAT
// We take and truncate to f32 as that's the same behavior as *in* MySQL
LittleEndian::read_f64(buf) as f32
} else {
LittleEndian::read_f32(buf)
}
}
MySqlValueFormat::Text => value.as_str()?.parse()?,
})
}
}
impl Decode<'_, MySql> for f64 {
fn accepts(ty: &MySqlTypeInfo) -> bool {
real_accepts(ty)
}
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
MySqlValueFormat::Binary => LittleEndian::read_f64(value.as_bytes()?),
MySqlValueFormat::Text => value.as_str()?.parse()?,
})
}
}

View File

@@ -1,111 +1,127 @@
use std::str::from_utf8;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use byteorder::{ByteOrder, LittleEndian};
use crate::decode::Decode;
use crate::encode::Encode;
use crate::mysql::protocol::TypeId;
use crate::mysql::type_info::MySqlTypeInfo;
use crate::mysql::{MySql, MySqlData, MySqlValue};
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::mysql::protocol::text::{ColumnFlags, ColumnType};
use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef};
use crate::types::Type;
use crate::Error;
fn int_accepts(ty: &MySqlTypeInfo) -> bool {
matches!(
ty.r#type,
ColumnType::Tiny
| ColumnType::Short
| ColumnType::Long
| ColumnType::Int24
| ColumnType::LongLong
) && !ty.flags.contains(ColumnFlags::UNSIGNED)
}
impl Type<MySql> for i8 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::TINY_INT)
}
}
impl Encode<MySql> for i8 {
fn encode(&self, buf: &mut Vec<u8>) {
let _ = buf.write_i8(*self);
}
}
impl<'de> Decode<'de, MySql> for i8 {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
match value.try_get()? {
MySqlData::Binary(mut buf) => buf.read_i8().map_err(Into::into),
MySqlData::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
MySqlTypeInfo::binary(ColumnType::Tiny)
}
}
impl Type<MySql> for i16 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::SMALL_INT)
}
}
impl Encode<MySql> for i16 {
fn encode(&self, buf: &mut Vec<u8>) {
let _ = buf.write_i16::<LittleEndian>(*self);
}
}
impl<'de> Decode<'de, MySql> for i16 {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
match value.try_get()? {
MySqlData::Binary(mut buf) => buf.read_i16::<LittleEndian>().map_err(Into::into),
MySqlData::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
MySqlTypeInfo::binary(ColumnType::Short)
}
}
impl Type<MySql> for i32 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::INT)
}
}
impl Encode<MySql> for i32 {
fn encode(&self, buf: &mut Vec<u8>) {
let _ = buf.write_i32::<LittleEndian>(*self);
}
}
impl<'de> Decode<'de, MySql> for i32 {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
match value.try_get()? {
MySqlData::Binary(mut buf) => buf.read_i32::<LittleEndian>().map_err(Into::into),
MySqlData::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
MySqlTypeInfo::binary(ColumnType::Long)
}
}
impl Type<MySql> for i64 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::BIG_INT)
MySqlTypeInfo::binary(ColumnType::LongLong)
}
}
impl Encode<MySql> for i64 {
fn encode(&self, buf: &mut Vec<u8>) {
let _ = buf.write_i64::<LittleEndian>(*self);
impl Encode<'_, MySql> for i8 {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
buf.extend(&self.to_le_bytes());
IsNull::No
}
}
impl<'de> Decode<'de, MySql> for i64 {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
match value.try_get()? {
MySqlData::Binary(mut buf) => buf.read_i64::<LittleEndian>().map_err(Into::into),
impl Encode<'_, MySql> for i16 {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
buf.extend(&self.to_le_bytes());
MySqlData::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
IsNull::No
}
}
impl Encode<'_, MySql> for i32 {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
buf.extend(&self.to_le_bytes());
IsNull::No
}
}
impl Encode<'_, MySql> for i64 {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
buf.extend(&self.to_le_bytes());
IsNull::No
}
}
impl Decode<'_, MySql> for i8 {
fn accepts(ty: &MySqlTypeInfo) -> bool {
int_accepts(ty)
}
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
MySqlValueFormat::Binary => value.as_bytes()?[0] as i8,
MySqlValueFormat::Text => value.as_str()?.parse()?,
})
}
}
impl Decode<'_, MySql> for i16 {
fn accepts(ty: &MySqlTypeInfo) -> bool {
int_accepts(ty)
}
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
MySqlValueFormat::Binary => LittleEndian::read_i16(value.as_bytes()?),
MySqlValueFormat::Text => value.as_str()?.parse()?,
})
}
}
impl Decode<'_, MySql> for i32 {
fn accepts(ty: &MySqlTypeInfo) -> bool {
int_accepts(ty)
}
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
MySqlValueFormat::Binary => LittleEndian::read_i32(value.as_bytes()?),
MySqlValueFormat::Text => value.as_str()?.parse()?,
})
}
}
impl Decode<'_, MySql> for i64 {
fn accepts(ty: &MySqlTypeInfo) -> bool {
int_accepts(ty)
}
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
MySqlValueFormat::Binary => LittleEndian::read_i64(value.as_bytes()?),
MySqlValueFormat::Text => value.as_str()?.parse()?,
})
}
}

View File

@@ -1,46 +1,47 @@
use crate::decode::Decode;
use crate::encode::Encode;
use crate::mysql::database::MySql;
use crate::mysql::protocol::TypeId;
use crate::mysql::types::*;
use crate::mysql::{MySqlTypeInfo, MySqlValue};
use crate::types::{Json, Type};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
impl Type<MySql> for JsonValue {
fn type_info() -> MySqlTypeInfo {
<Json<Self> as Type<MySql>>::type_info()
}
}
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::mysql::protocol::text::ColumnType;
use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef};
use crate::types::{Json, Type};
impl<T> Type<MySql> for Json<T> {
fn type_info() -> MySqlTypeInfo {
// MySql uses the CHAR type to pass JSON data from and to the client
MySqlTypeInfo::new(TypeId::CHAR)
// MySql uses the `CHAR` type to pass JSON data from and to the client
// NOTE: This is forwards-compatible with MySQL v8+ as CHAR is a common transmission format
// and has nothing to do with the native storage ability of MySQL v8+
MySqlTypeInfo::binary(ColumnType::String)
}
}
impl<T> Encode<MySql> for Json<T>
impl<T> Encode<'_, MySql> for Json<T>
where
T: Serialize,
{
fn encode(&self, buf: &mut Vec<u8>) {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
let json_string_value =
serde_json::to_string(&self.0).expect("serde_json failed to convert to string");
<str as Encode<MySql>>::encode(json_string_value.as_str(), buf);
<&str as Encode<MySql>>::encode(json_string_value.as_str(), buf)
}
}
impl<'de, T> Decode<'de, MySql> for Json<T>
impl<'r, T> Decode<'r, MySql> for Json<T>
where
T: 'de,
T: for<'de1> Deserialize<'de1>,
T: 'r + DeserializeOwned,
{
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
let string_value = <&'de str as Decode<MySql>>::decode(value).unwrap();
fn accepts(ty: &MySqlTypeInfo) -> bool {
ty.r#type == ColumnType::Json || <&str as Decode<MySql>>::accepts(ty)
}
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
let string_value = <&str as Decode<MySql>>::decode(value)?;
serde_json::from_str(&string_value)
.map(Json)
.map_err(crate::Error::decode)
.map_err(Into::into)
}
}

View File

@@ -4,7 +4,7 @@
//!
//! | Rust type | MySQL type(s) |
//! |---------------------------------------|------------------------------------------------------|
//! | `bool` | TINYINT(1) |
//! | `bool` | TINYINT(1), BOOLEAN |
//! | `i8` | TINYINT |
//! | `i16` | SMALLINT |
//! | `i32` | INT |
@@ -47,6 +47,7 @@
//! | Rust type | MySQL type(s) |
//! |---------------------------------------|------------------------------------------------------|
//! | `bigdecimal::BigDecimal` | DECIMAL |
//!
//! ### [`json`](https://crates.io/crates/json)
//!
//! Requires the `json` Cargo feature flag.
@@ -79,19 +80,3 @@ mod time;
#[cfg(feature = "json")]
mod json;
use crate::decode::Decode;
use crate::mysql::{MySql, MySqlValue};
impl<'de, T> Decode<'de, MySql> for Option<T>
where
T: Decode<'de, MySql>,
{
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
Ok(if value.get().is_some() {
Some(<T as Decode<MySql>>::decode(value)?)
} else {
None
})
}
}

View File

@@ -1,30 +1,46 @@
use std::str;
use byteorder::LittleEndian;
use crate::decode::Decode;
use crate::encode::Encode;
use crate::mysql::io::BufMutExt;
use crate::mysql::protocol::TypeId;
use crate::mysql::type_info::MySqlTypeInfo;
use crate::mysql::{MySql, MySqlData, MySqlValue};
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::mysql::io::MySqlBufMutExt;
use crate::mysql::protocol::text::{ColumnFlags, ColumnType};
use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef};
use crate::types::Type;
use std::str::from_utf8;
impl Type<MySql> for str {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo {
id: TypeId::TEXT,
is_binary: false,
is_unsigned: false,
char_set: 224, // utf8mb4_unicode_ci
r#type: ColumnType::Blob, // TEXT
char_set: 224, // utf8mb4_unicode_ci
flags: ColumnFlags::empty(),
}
}
}
impl Encode<MySql> for str {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_str_lenenc::<LittleEndian>(self);
impl Encode<'_, MySql> for &'_ str {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
buf.put_str_lenenc(self);
IsNull::No
}
}
impl<'r> Decode<'r, MySql> for &'r str {
fn accepts(ty: &MySqlTypeInfo) -> bool {
matches!(
ty.r#type,
ColumnType::VarChar
| ColumnType::Blob
| ColumnType::TinyBlob
| ColumnType::MediumBlob
| ColumnType::LongBlob
| ColumnType::String
| ColumnType::VarString
| ColumnType::Enum
)
}
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
value.as_str()
}
}
@@ -34,24 +50,18 @@ impl Type<MySql> for String {
}
}
impl Encode<MySql> for String {
fn encode(&self, buf: &mut Vec<u8>) {
<str as Encode<MySql>>::encode(self.as_str(), buf)
impl Encode<'_, MySql> for String {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
<&str as Encode<MySql>>::encode(&**self, buf)
}
}
impl<'de> Decode<'de, MySql> for &'de str {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
match value.try_get()? {
MySqlData::Binary(buf) | MySqlData::Text(buf) => {
from_utf8(buf).map_err(crate::Error::decode)
}
}
impl Decode<'_, MySql> for String {
fn accepts(ty: &MySqlTypeInfo) -> bool {
<&str as Decode<MySql>>::accepts(ty)
}
}
impl<'de> Decode<'de, MySql> for String {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
<&'de str as Decode<MySql>>::decode(value).map(ToOwned::to_owned)
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
<&str as Decode<MySql>>::decode(value).map(ToOwned::to_owned)
}
}

View File

@@ -2,33 +2,38 @@ use std::borrow::Cow;
use std::convert::TryFrom;
use byteorder::{ByteOrder, LittleEndian};
use bytes::Buf;
use time::{Date, OffsetDateTime, PrimitiveDateTime, Time, UtcOffset};
use crate::decode::Decode;
use crate::encode::Encode;
use crate::io::{Buf, BufMut};
use crate::mysql::protocol::TypeId;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::mysql::protocol::text::ColumnType;
use crate::mysql::type_info::MySqlTypeInfo;
use crate::mysql::{MySql, MySqlData, MySqlValue};
use crate::mysql::{MySql, MySqlValueFormat, MySqlValueRef};
use crate::types::Type;
impl Type<MySql> for OffsetDateTime {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::TIMESTAMP)
MySqlTypeInfo::binary(ColumnType::Timestamp)
}
}
impl Encode<MySql> for OffsetDateTime {
fn encode(&self, buf: &mut Vec<u8>) {
impl Encode<'_, MySql> for OffsetDateTime {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
let utc_dt = self.to_offset(UtcOffset::UTC);
let primitive_dt = PrimitiveDateTime::new(utc_dt.date(), utc_dt.time());
Encode::<MySql>::encode(&primitive_dt, buf);
Encode::<MySql>::encode(&primitive_dt, buf)
}
}
impl<'de> Decode<'de, MySql> for OffsetDateTime {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
impl<'r> Decode<'r, MySql> for OffsetDateTime {
fn accepts(ty: &MySqlTypeInfo) -> bool {
matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp)
}
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
let primitive: PrimitiveDateTime = Decode::<MySql>::decode(value)?;
Ok(primitive.assume_utc())
@@ -37,12 +42,12 @@ impl<'de> Decode<'de, MySql> for OffsetDateTime {
impl Type<MySql> for Time {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::TIME)
MySqlTypeInfo::binary(ColumnType::Time)
}
}
impl Encode<MySql> for Time {
fn encode(&self, buf: &mut Vec<u8>) {
impl Encode<'_, MySql> for Time {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
let len = Encode::<MySql>::size_hint(self) - 1;
buf.push(len as u8);
@@ -51,9 +56,11 @@ impl Encode<MySql> for Time {
// "date on 4 bytes little-endian format" (?)
// https://mariadb.com/kb/en/resultset-row/#teimstamp-binary-encoding
buf.advance(4);
buf.extend_from_slice(&[0_u8; 4]);
encode_time(self, len > 9, buf);
IsNull::No
}
fn size_hint(&self) -> usize {
@@ -67,15 +74,17 @@ impl Encode<MySql> for Time {
}
}
impl<'de> Decode<'de, MySql> for Time {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
match value.try_get()? {
MySqlData::Binary(mut buf) => {
impl<'r> Decode<'r, MySql> for Time {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
MySqlValueFormat::Binary => {
let mut buf = value.as_bytes()?;
// data length, expecting 8 or 12 (fractional seconds)
let len = buf.get_u8()?;
let len = buf.get_u8();
// is negative : int<1>
let is_negative = buf.get_u8()?;
let is_negative = buf.get_u8();
assert_eq!(is_negative, 0, "Negative dates/times are not supported");
// "date on 4 bytes little-endian format" (?)
@@ -85,8 +94,8 @@ impl<'de> Decode<'de, MySql> for Time {
decode_time(len - 5, buf)
}
MySqlData::Text(buf) => {
let s = from_utf8(buf).map_err(crate::Error::decode)?;
MySqlValueFormat::Text => {
let s = value.as_str()?;
// If there are less than 9 digits after the decimal point
// We need to zero-pad
@@ -98,7 +107,7 @@ impl<'de> Decode<'de, MySql> for Time {
Cow::Borrowed(s)
};
Time::parse(&*s, "%H:%M:%S.%N").map_err(crate::Error::decode)
Time::parse(&*s, "%H:%M:%S.%N").map_err(Into::into)
}
}
}
@@ -106,15 +115,17 @@ impl<'de> Decode<'de, MySql> for Time {
impl Type<MySql> for Date {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::DATE)
MySqlTypeInfo::binary(ColumnType::Date)
}
}
impl Encode<MySql> for Date {
fn encode(&self, buf: &mut Vec<u8>) {
impl Encode<'_, MySql> for Date {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
buf.push(4);
encode_date(self, buf);
IsNull::No
}
fn size_hint(&self) -> usize {
@@ -122,13 +133,13 @@ impl Encode<MySql> for Date {
}
}
impl<'de> Decode<'de, MySql> for Date {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
match value.try_get()? {
MySqlData::Binary(buf) => decode_date(&buf[1..]),
MySqlData::Text(buf) => {
let s = from_utf8(buf).map_err(crate::Error::decode)?;
Date::parse(s, "%Y-%m-%d").map_err(crate::Error::decode)
impl<'r> Decode<'r, MySql> for Date {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
MySqlValueFormat::Binary => decode_date(&value.as_bytes()?[1..]),
MySqlValueFormat::Text => {
let s = value.as_str()?;
Date::parse(s, "%Y-%m-%d").map_err(Into::into)
}
}
}
@@ -136,12 +147,12 @@ impl<'de> Decode<'de, MySql> for Date {
impl Type<MySql> for PrimitiveDateTime {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::new(TypeId::DATETIME)
MySqlTypeInfo::binary(ColumnType::Datetime)
}
}
impl Encode<MySql> for PrimitiveDateTime {
fn encode(&self, buf: &mut Vec<u8>) {
impl Encode<'_, MySql> for PrimitiveDateTime {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
let len = Encode::<MySql>::size_hint(self) - 1;
buf.push(len as u8);
@@ -150,6 +161,8 @@ impl Encode<MySql> for PrimitiveDateTime {
if len > 4 {
encode_time(&self.time(), len > 8, buf);
}
IsNull::No
}
fn size_hint(&self) -> usize {
@@ -169,10 +182,15 @@ impl Encode<MySql> for PrimitiveDateTime {
}
}
impl<'de> Decode<'de, MySql> for PrimitiveDateTime {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
match value.try_get()? {
MySqlData::Binary(buf) => {
impl<'r> Decode<'r, MySql> for PrimitiveDateTime {
fn accepts(ty: &MySqlTypeInfo) -> bool {
matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp)
}
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
MySqlValueFormat::Binary => {
let mut buf = value.as_bytes()?;
let len = buf[0];
let date = decode_date(&buf[1..])?;
@@ -185,8 +203,8 @@ impl<'de> Decode<'de, MySql> for PrimitiveDateTime {
Ok(dt)
}
MySqlData::Text(buf) => {
let s = from_utf8(buf).map_err(crate::Error::decode)?;
MySqlValueFormat::Text => {
let s = value.as_str()?;
// If there are less than 9 digits after the decimal point
// We need to zero-pad
@@ -202,7 +220,7 @@ impl<'de> Decode<'de, MySql> for PrimitiveDateTime {
Cow::Borrowed(s)
};
PrimitiveDateTime::parse(&*s, "%Y-%m-%d %H:%M:%S.%N").map_err(crate::Error::decode)
PrimitiveDateTime::parse(&*s, "%Y-%m-%d %H:%M:%S.%N").map_err(Into::into)
}
}
}
@@ -218,13 +236,13 @@ fn encode_date(date: &Date, buf: &mut Vec<u8>) {
buf.push(date.day());
}
fn decode_date(buf: &[u8]) -> crate::Result<Date> {
fn decode_date(buf: &[u8]) -> Result<Date, BoxDynError> {
Date::try_from_ymd(
LittleEndian::read_u16(buf) as i32,
buf[2] as u8,
buf[3] as u8,
)
.map_err(|e| decode_err!("Error while decoding Date: {}", e))
.map_err(Into::into)
}
fn encode_time(time: &Time, include_micros: bool, buf: &mut Vec<u8>) {
@@ -233,92 +251,22 @@ fn encode_time(time: &Time, include_micros: bool, buf: &mut Vec<u8>) {
buf.push(time.second());
if include_micros {
buf.put_u32::<LittleEndian>((time.nanosecond() / 1000) as u32);
buf.extend(&((time.nanosecond() / 1000) as u32).to_le_bytes());
}
}
fn decode_time(len: u8, mut buf: &[u8]) -> crate::Result<Time> {
let hour = buf.get_u8()?;
let minute = buf.get_u8()?;
let seconds = buf.get_u8()?;
fn decode_time(len: u8, mut buf: &[u8]) -> Result<Time, BoxDynError> {
let hour = buf.get_u8();
let minute = buf.get_u8();
let seconds = buf.get_u8();
let micros = if len > 3 {
// microseconds : int<EOF>
buf.get_uint::<LittleEndian>(buf.len())?
buf.get_uint_le(buf.len())
} else {
0
};
Time::try_from_hms_micro(hour, minute, seconds, micros as u32)
.map_err(|e| decode_err!("Time out of range for MySQL: {}", e))
}
use std::str::from_utf8;
#[cfg(test)]
use time::{date, time};
#[test]
fn test_encode_date_time() {
let mut buf = Vec::new();
// test values from https://dev.mysql.com/doc/internals/en/binary-protocol-value.html
let date = PrimitiveDateTime::new(date!(2010 - 10 - 17), time!(19:27:30.000001));
Encode::<MySql>::encode(&date, &mut buf);
assert_eq!(*buf, [11, 218, 7, 10, 17, 19, 27, 30, 1, 0, 0, 0]);
buf.clear();
let date = PrimitiveDateTime::new(date!(2010 - 10 - 17), time!(19:27:30));
Encode::<MySql>::encode(&date, &mut buf);
assert_eq!(*buf, [7, 218, 7, 10, 17, 19, 27, 30]);
buf.clear();
let date = PrimitiveDateTime::new(date!(2010 - 10 - 17), time!(00:00:00));
Encode::<MySql>::encode(&date, &mut buf);
assert_eq!(*buf, [4, 218, 7, 10, 17]);
}
#[test]
fn test_decode_date_time() {
// test values from https://dev.mysql.com/doc/internals/en/binary-protocol-value.html
let buf = [11, 218, 7, 10, 17, 19, 27, 30, 1, 0, 0, 0];
let date1 = <PrimitiveDateTime as Decode<MySql>>::decode(MySqlValue::binary(
MySqlTypeInfo::default(),
&buf,
))
.unwrap();
assert_eq!(date1.to_string(), "2010-10-17 19:27:30.000001");
let buf = [7, 218, 7, 10, 17, 19, 27, 30];
let date2 = <PrimitiveDateTime as Decode<MySql>>::decode(MySqlValue::binary(
MySqlTypeInfo::default(),
&buf,
))
.unwrap();
assert_eq!(date2.to_string(), "2010-10-17 19:27:30");
let buf = [4, 218, 7, 10, 17];
let date3 = <PrimitiveDateTime as Decode<MySql>>::decode(MySqlValue::binary(
MySqlTypeInfo::default(),
&buf,
))
.unwrap();
assert_eq!(date3.to_string(), "2010-10-17 0:00");
}
#[test]
fn test_encode_date() {
let mut buf = Vec::new();
let date: Date = date!(2010 - 10 - 17);
Encode::<MySql>::encode(&date, &mut buf);
assert_eq!(*buf, [4, 218, 7, 10, 17]);
}
#[test]
fn test_decode_date() {
let buf = [4, 218, 7, 10, 17];
let date = <Date as Decode<MySql>>::decode(MySqlValue::binary(MySqlTypeInfo::default(), &buf))
.unwrap();
assert_eq!(date, date!(2010 - 10 - 17));
.map_err(|e| format!("Time out of range for MySQL: {}", e).into())
}

View File

@@ -1,111 +1,135 @@
use std::str::from_utf8;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use byteorder::{ByteOrder, LittleEndian};
use crate::decode::Decode;
use crate::encode::Encode;
use crate::mysql::protocol::TypeId;
use crate::mysql::type_info::MySqlTypeInfo;
use crate::mysql::{MySql, MySqlData, MySqlValue};
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::mysql::protocol::text::{ColumnFlags, ColumnType};
use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef};
use crate::types::Type;
use crate::Error;
fn uint_type_info(ty: ColumnType) -> MySqlTypeInfo {
MySqlTypeInfo {
r#type: ty,
flags: ColumnFlags::BINARY | ColumnFlags::UNSIGNED,
char_set: 63,
}
}
fn uint_accepts(ty: &MySqlTypeInfo) -> bool {
matches!(
ty.r#type,
ColumnType::Tiny
| ColumnType::Short
| ColumnType::Long
| ColumnType::Int24
| ColumnType::LongLong
) && ty.flags.contains(ColumnFlags::UNSIGNED)
}
impl Type<MySql> for u8 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::unsigned(TypeId::TINY_INT)
}
}
impl Encode<MySql> for u8 {
fn encode(&self, buf: &mut Vec<u8>) {
let _ = buf.write_u8(*self);
}
}
impl<'de> Decode<'de, MySql> for u8 {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
match value.try_get()? {
MySqlData::Binary(mut buf) => buf.read_u8().map_err(Into::into),
MySqlData::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
uint_type_info(ColumnType::Tiny)
}
}
impl Type<MySql> for u16 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::unsigned(TypeId::SMALL_INT)
}
}
impl Encode<MySql> for u16 {
fn encode(&self, buf: &mut Vec<u8>) {
let _ = buf.write_u16::<LittleEndian>(*self);
}
}
impl<'de> Decode<'de, MySql> for u16 {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
match value.try_get()? {
MySqlData::Binary(mut buf) => buf.read_u16::<LittleEndian>().map_err(Into::into),
MySqlData::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
uint_type_info(ColumnType::Short)
}
}
impl Type<MySql> for u32 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::unsigned(TypeId::INT)
}
}
impl Encode<MySql> for u32 {
fn encode(&self, buf: &mut Vec<u8>) {
let _ = buf.write_u32::<LittleEndian>(*self);
}
}
impl<'de> Decode<'de, MySql> for u32 {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
match value.try_get()? {
MySqlData::Binary(mut buf) => buf.read_u32::<LittleEndian>().map_err(Into::into),
MySqlData::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
uint_type_info(ColumnType::Long)
}
}
impl Type<MySql> for u64 {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::unsigned(TypeId::BIG_INT)
uint_type_info(ColumnType::LongLong)
}
}
impl Encode<MySql> for u64 {
fn encode(&self, buf: &mut Vec<u8>) {
let _ = buf.write_u64::<LittleEndian>(*self);
impl Encode<'_, MySql> for u8 {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
buf.extend(&self.to_le_bytes());
IsNull::No
}
}
impl<'de> Decode<'de, MySql> for u64 {
fn decode(value: MySqlValue<'de>) -> crate::Result<Self> {
match value.try_get()? {
MySqlData::Binary(mut buf) => buf.read_u64::<LittleEndian>().map_err(Into::into),
impl Encode<'_, MySql> for u16 {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
buf.extend(&self.to_le_bytes());
MySqlData::Text(s) => from_utf8(s)
.map_err(Error::decode)?
.parse()
.map_err(Error::decode),
}
IsNull::No
}
}
impl Encode<'_, MySql> for u32 {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
buf.extend(&self.to_le_bytes());
IsNull::No
}
}
impl Encode<'_, MySql> for u64 {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
buf.extend(&self.to_le_bytes());
IsNull::No
}
}
impl Decode<'_, MySql> for u8 {
fn accepts(ty: &MySqlTypeInfo) -> bool {
uint_accepts(ty)
}
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
MySqlValueFormat::Binary => value.as_bytes()?[0] as u8,
MySqlValueFormat::Text => value.as_str()?.parse()?,
})
}
}
impl Decode<'_, MySql> for u16 {
fn accepts(ty: &MySqlTypeInfo) -> bool {
uint_accepts(ty)
}
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
MySqlValueFormat::Binary => LittleEndian::read_u16(value.as_bytes()?),
MySqlValueFormat::Text => value.as_str()?.parse()?,
})
}
}
impl Decode<'_, MySql> for u32 {
fn accepts(ty: &MySqlTypeInfo) -> bool {
uint_accepts(ty)
}
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
MySqlValueFormat::Binary => LittleEndian::read_u32(value.as_bytes()?),
MySqlValueFormat::Text => value.as_str()?.parse()?,
})
}
}
impl Decode<'_, MySql> for u64 {
fn accepts(ty: &MySqlTypeInfo) -> bool {
uint_accepts(ty)
}
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
MySqlValueFormat::Binary => LittleEndian::read_u64(value.as_bytes()?),
MySqlValueFormat::Text => value.as_str()?.parse()?,
})
}
}

View File

@@ -1,9 +0,0 @@
// XOR(x, y)
// If len(y) < len(x), wrap around inside y
pub fn xor_eq(x: &mut [u8], y: &[u8]) {
let y_len = y.len();
for i in 0..x.len() {
x[i] ^= y[i % y_len];
}
}

View File

@@ -1,62 +1,98 @@
use crate::error::UnexpectedNullError;
use std::borrow::Cow;
use std::str::from_utf8;
use bytes::Bytes;
use crate::error::{BoxDynError, UnexpectedNullError};
use crate::mysql::{MySql, MySqlTypeInfo};
use crate::value::RawValue;
use crate::value::{Value, ValueRef};
#[derive(Debug, Copy, Clone)]
pub enum MySqlData<'c> {
Binary(&'c [u8]),
Text(&'c [u8]),
#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum MySqlValueFormat {
Text,
Binary,
}
#[derive(Debug)]
pub struct MySqlValue<'c> {
/// Implementation of [`Value`] for MySQL.
#[derive(Clone)]
pub struct MySqlValue {
value: Option<Bytes>,
type_info: Option<MySqlTypeInfo>,
data: Option<MySqlData<'c>>,
format: MySqlValueFormat,
}
impl<'c> MySqlValue<'c> {
/// Gets the binary or text data for this value; or, `UnexpectedNullError` if this
/// is a `NULL` value.
pub(crate) fn try_get(&self) -> crate::Result<MySqlData<'c>> {
match self.data {
Some(data) => Ok(data),
None => Err(crate::Error::decode(UnexpectedNullError)),
/// Implementation of [`ValueRef`] for MySQL.
#[derive(Clone)]
pub struct MySqlValueRef<'r> {
pub(crate) value: Option<&'r [u8]>,
pub(crate) row: Option<&'r Bytes>,
pub(crate) type_info: Option<MySqlTypeInfo>,
pub(crate) format: MySqlValueFormat,
}
impl<'r> MySqlValueRef<'r> {
pub(crate) fn format(&self) -> MySqlValueFormat {
self.format
}
pub(crate) fn as_bytes(&self) -> Result<&'r [u8], BoxDynError> {
match &self.value {
Some(v) => Ok(v),
None => Err(UnexpectedNullError.into()),
}
}
/// Gets the binary or text data for this value; or, `None` if this
/// is a `NULL` value.
#[inline]
pub fn get(&self) -> Option<MySqlData<'c>> {
self.data
}
pub(crate) fn null() -> Self {
Self {
type_info: None,
data: None,
}
}
pub(crate) fn binary(type_info: MySqlTypeInfo, buf: &'c [u8]) -> Self {
Self {
type_info: Some(type_info),
data: Some(MySqlData::Binary(buf)),
}
}
pub(crate) fn text(type_info: MySqlTypeInfo, buf: &'c [u8]) -> Self {
Self {
type_info: Some(type_info),
data: Some(MySqlData::Text(buf)),
}
pub(crate) fn as_str(&self) -> Result<&'r str, BoxDynError> {
Ok(from_utf8(self.as_bytes()?)?)
}
}
impl<'c> RawValue<'c> for MySqlValue<'c> {
impl Value for MySqlValue {
type Database = MySql;
fn type_info(&self) -> Option<MySqlTypeInfo> {
self.type_info.clone()
fn as_ref(&self) -> MySqlValueRef<'_> {
MySqlValueRef {
value: self.value.as_deref(),
row: None,
type_info: self.type_info.clone(),
format: self.format,
}
}
fn type_info(&self) -> Option<Cow<'_, MySqlTypeInfo>> {
self.type_info.as_ref().map(Cow::Borrowed)
}
fn is_null(&self) -> bool {
self.value.is_none()
}
}
impl<'r> ValueRef<'r> for MySqlValueRef<'r> {
type Database = MySql;
fn to_owned(&self) -> MySqlValue {
let value = match (self.row, self.value) {
(Some(row), Some(value)) => Some(row.slice_ref(value)),
(None, Some(value)) => Some(Bytes::copy_from_slice(value)),
_ => None,
};
MySqlValue {
value,
format: self.format,
type_info: self.type_info.clone(),
}
}
fn type_info(&self) -> Option<Cow<'_, MySqlTypeInfo>> {
self.type_info.as_ref().map(Cow::Borrowed)
}
fn is_null(&self) -> bool {
self.value.is_none()
}
}