refactor(mysql): split off MySqlStream from MySqlConnection

This commit is contained in:
Ryan Leckey 2021-01-26 01:16:13 -08:00
parent d279c6b978
commit ce2fba7b8d
No known key found for this signature in database
GPG Key ID: F8AA68C235AB08C9
8 changed files with 337 additions and 282 deletions

View File

@ -1,16 +1,18 @@
use std::collections::VecDeque;
use std::fmt::{self, Debug, Formatter};
use sqlx_core::io::BufStream;
use sqlx_core::net::Stream as NetStream;
use sqlx_core::{Close, Connect, Connection, Runtime};
use crate::protocol::Capabilities;
use crate::stream::MySqlStream;
use crate::{MySql, MySqlConnectOptions};
mod close;
mod command;
mod connect;
mod executor;
mod ping;
mod stream;
/// A single connection (also known as a session) to a MySQL database server.
#[allow(clippy::module_name_repetitions)]
@ -18,16 +20,17 @@ pub struct MySqlConnection<Rt>
where
Rt: Runtime,
{
stream: BufStream<Rt, NetStream<Rt>>,
stream: MySqlStream<Rt>,
connection_id: u32,
// the capability flags are used by the client and server to indicate which
// features they support and want to use.
capabilities: Capabilities,
// the sequence-id is incremented with each packet and may wrap around. It starts at 0 and is
// reset to 0 when a new command begins in the Command Phase.
sequence_id: u8,
// queue of commands that are being processed
// this is what we expect to receive from the server
// in the case of a future or stream being dropped
commands: VecDeque<command::Command>,
}
impl<Rt> MySqlConnection<Rt>
@ -36,21 +39,22 @@ where
{
pub(crate) fn new(stream: NetStream<Rt>) -> Self {
Self {
stream: BufStream::with_capacity(stream, 4096, 1024),
stream: MySqlStream::new(stream),
connection_id: 0,
sequence_id: 0,
capabilities: Capabilities::PROTOCOL_41 | Capabilities::LONG_PASSWORD
commands: VecDeque::with_capacity(2),
capabilities: Capabilities::PROTOCOL_41
| Capabilities::LONG_PASSWORD
| Capabilities::LONG_FLAG
| Capabilities::IGNORE_SPACE
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
// | Capabilities::MULTI_STATEMENTS
// | Capabilities::MULTI_RESULTS
// | Capabilities::PS_MULTI_RESULTS
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PS_MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::PLUGIN_AUTH_LENENC_DATA
// | Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
// | Capabilities::SESSION_TRACK
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
| Capabilities::SESSION_TRACK
| Capabilities::DEPRECATE_EOF,
}
}
@ -115,7 +119,7 @@ mod blocking {
impl<Rt: Runtime> Connection<Rt> for MySqlConnection<Rt> {
#[inline]
fn ping(&mut self) -> sqlx_core::Result<()> {
self.ping()
self.ping_blocking()
}
}
@ -125,14 +129,14 @@ mod blocking {
where
Self: Sized,
{
Self::connect(&url.parse::<MySqlConnectOptions<Rt>>()?)
Self::connect_blocking(&url.parse::<MySqlConnectOptions<Rt>>()?)
}
}
impl<Rt: Runtime> Close<Rt> for MySqlConnection<Rt> {
#[inline]
fn close(self) -> sqlx_core::Result<()> {
self.close()
self.close_blocking()
}
}
}

View File

@ -2,16 +2,13 @@ use sqlx_core::{io::Stream, Result, Runtime};
use crate::protocol::Quit;
impl<Rt> super::MySqlConnection<Rt>
where
Rt: Runtime,
{
impl<Rt: Runtime> super::MySqlConnection<Rt> {
#[cfg(feature = "async")]
pub(crate) async fn close_async(mut self) -> Result<()>
where
Rt: sqlx_core::Async,
{
self.write_packet(&Quit)?;
self.stream.write_packet(&Quit)?;
self.stream.flush_async().await?;
self.stream.shutdown_async().await?;
@ -19,11 +16,11 @@ where
}
#[cfg(feature = "blocking")]
pub(crate) fn close(mut self) -> Result<()>
pub(crate) fn close_blocking(mut self) -> Result<()>
where
Rt: sqlx_core::blocking::Runtime,
{
self.write_packet(&Quit)?;
self.stream.write_packet(&Quit)?;
self.stream.flush()?;
self.stream.shutdown()?;

View File

@ -13,100 +13,115 @@
//!
use sqlx_core::net::Stream as NetStream;
use sqlx_core::Result;
use sqlx_core::Runtime;
use crate::protocol::{Auth, AuthResponse, Handshake, HandshakeResponse};
use crate::protocol::{AuthResponse, Handshake, HandshakeResponse};
use crate::{MySqlConnectOptions, MySqlConnection};
macro_rules! connect {
(@blocking @tcp $options:ident) => {
impl<Rt: Runtime> MySqlConnection<Rt> {
fn recv_handshake(
&mut self,
options: &MySqlConnectOptions<Rt>,
handshake: &Handshake,
) -> Result<()> {
// & the declared server capabilities with our capabilities to find
// what rules the client should operate under
self.capabilities &= handshake.capabilities;
// store the connection ID, mainly for debugging
self.connection_id = handshake.connection_id;
// create the initial auth response
// this may just be a request for an RSA public key
let initial_auth_response = handshake
.auth_plugin
.invoke(&handshake.auth_plugin_data, options.get_password().unwrap_or_default());
// the <HandshakeResponse> contains an initial guess at the correct encoding of
// the password and some other metadata like "which database", "which user", etc.
self.stream.write_packet(&HandshakeResponse {
capabilities: self.capabilities,
auth_plugin_name: handshake.auth_plugin.name(),
auth_response: initial_auth_response,
charset: 45, // [utf8mb4]
database: options.get_database(),
max_packet_size: 1024,
username: options.get_username(),
})?;
Ok(())
}
fn recv_auth_response(
&mut self,
options: &MySqlConnectOptions<Rt>,
handshake: &mut Handshake,
response: AuthResponse,
) -> Result<bool> {
match response {
AuthResponse::Ok(_) => {
// successful, simple authentication; good to go
return Ok(true);
}
AuthResponse::MoreData(data) => {
if let Some(data) = handshake.auth_plugin.handle(
data,
&handshake.auth_plugin_data,
options.get_password().unwrap_or_default(),
)? {
// write the response from the plugin
self.stream.write_packet(&&*data)?;
}
}
AuthResponse::Switch(sw) => {
// switch to the new plugin
handshake.auth_plugin = sw.plugin;
handshake.auth_plugin_data = sw.plugin_data;
// generate an initial response from this plugin
let data = handshake.auth_plugin.invoke(
&handshake.auth_plugin_data,
options.get_password().unwrap_or_default(),
);
// write the response from the plugin
self.stream.write_packet(&&*data)?;
}
}
Ok(false)
}
}
macro_rules! impl_connect {
(@blocking @new $options:ident) => {
NetStream::connect($options.address.as_ref())?;
};
(@tcp $options:ident) => {
(@new $options:ident) => {
NetStream::connect_async($options.address.as_ref()).await?;
};
(@blocking @packet $self:ident) => {
$self.read_packet()?;
};
(@packet $self:ident) => {
$self.read_packet_async().await?;
};
($(@$blocking:ident)? $options:ident) => {{
// open a network stream to the database server
let stream = connect!($(@$blocking)? @tcp $options);
let stream = impl_connect!($(@$blocking)? @new $options);
// construct a <MySqlConnection> around the network stream
// wraps the stream in a <BufStream> to buffer read and write
let mut self_ = Self::new(stream);
// immediately the server should emit a <Handshake> packet
let handshake: Handshake = connect!($(@$blocking)? @packet self_);
// & the declared server capabilities with our capabilities to find
// what rules the client should operate under
self_.capabilities &= handshake.capabilities;
// store the connection ID, mainly for debugging
self_.connection_id = handshake.connection_id;
// extract the auth plugin and data from the handshake
// this can get overwritten by an auth switch
let mut auth_plugin = handshake.auth_plugin;
let mut auth_plugin_data = handshake.auth_plugin_data;
let password = $options.get_password().unwrap_or_default();
// create the initial auth response
// this may just be a request for an RSA public key
let initial_auth_response = auth_plugin.invoke(&auth_plugin_data, password);
// the <HandshakeResponse> contains an initial guess at the correct encoding of
// the password and some other metadata like "which database", "which user", etc.
self_.write_packet(&HandshakeResponse {
auth_plugin_name: auth_plugin.name(),
auth_response: initial_auth_response,
charset: 45, // [utf8mb4]
database: $options.get_database(),
max_packet_size: 1024,
username: $options.get_username(),
})?;
// we need to handle that and reply with a <HandshakeResponse>
let mut handshake = read_packet!($(@$blocking)? self_.stream).deserialize()?;
self_.recv_handshake($options, &handshake)?;
loop {
match connect!($(@$blocking)? @packet self_) {
Auth::Ok(_) => {
// successful, simple authentication; good to go
break;
}
Auth::MoreData(data) => {
if let Some(data) = auth_plugin.handle(data, &auth_plugin_data, password)? {
// write the response from the plugin
self_.write_packet(&AuthResponse { data })?;
// let's try again
continue;
}
// all done, the plugin says we check out
break;
}
Auth::Switch(sw) => {
// switch to the new plugin
auth_plugin = sw.plugin;
auth_plugin_data = sw.plugin_data;
// generate an initial response from this plugin
let data = auth_plugin.invoke(&auth_plugin_data, password);
// write the response from the plugin
self_.write_packet(&AuthResponse { data })?;
// let's try again
continue;
}
let response = read_packet!($(@$blocking)? self_.stream).deserialize_with(self_.capabilities)?;
if self_.recv_auth_response($options, &mut handshake, response)? {
// complete, successful authentication
break;
}
}
@ -114,24 +129,21 @@ macro_rules! connect {
}};
}
impl<Rt> MySqlConnection<Rt>
where
Rt: sqlx_core::Runtime,
{
impl<Rt: Runtime> MySqlConnection<Rt> {
#[cfg(feature = "async")]
pub(crate) async fn connect_async(options: &MySqlConnectOptions<Rt>) -> Result<Self>
where
Rt: sqlx_core::Async,
{
connect!(options)
impl_connect!(options)
}
#[cfg(feature = "blocking")]
pub(crate) fn connect(options: &MySqlConnectOptions<Rt>) -> Result<Self>
pub(crate) fn connect_blocking(options: &MySqlConnectOptions<Rt>) -> Result<Self>
where
Rt: sqlx_core::blocking::Runtime,
{
connect!(@blocking options)
impl_connect!(@blocking options)
}
}

View File

@ -6,31 +6,37 @@ use crate::protocol::{OkPacket, Ping};
// send the COM_PING packet
// should receive an OK
impl<Rt> super::MySqlConnection<Rt>
where
Rt: Runtime,
{
macro_rules! impl_ping {
($(@$blocking:ident)? $self:ident) => {{
$self.stream.write_packet(&Ping)?;
// STATE: remember that we are expecting an OK packet
$self.begin_simple_command();
let _ok: OkPacket = read_packet!($(@$blocking)? $self.stream)
.deserialize_with($self.capabilities)?;
// STATE: received OK packet
$self.end_command();
Ok(())
}};
}
impl<Rt: Runtime> super::MySqlConnection<Rt> {
#[cfg(feature = "async")]
pub(crate) async fn ping_async(&mut self) -> Result<()>
where
Rt: sqlx_core::Async,
{
self.write_packet(&Ping)?;
let _ok: OkPacket = self.read_packet_async().await?;
Ok(())
impl_ping!(self)
}
#[cfg(feature = "blocking")]
pub(crate) fn ping(&mut self) -> Result<()>
pub(crate) fn ping_blocking(&mut self) -> Result<()>
where
Rt: sqlx_core::blocking::Runtime,
{
self.write_packet(&Ping)?;
let _ok: OkPacket = self.read_packet()?;
Ok(())
impl_ping!(@blocking self)
}
}

View File

@ -1,160 +0,0 @@
//! Reads and writes packets to and from the MySQL database server.
//!
//! The logic for serializing data structures into the packets is found
//! mostly in `protocol/`.
//!
//! Packets in MySQL are prefixed by 4 bytes.
//! 3 for length (in LE) and a sequence id.
//!
//! Packets may only be as large as the communicated size in the initial
//! `HandshakeResponse`. By default, SQLx configures its chunk size to 16M. Sending
//! a larger payload is simply sending completely "full" packets, one after the
//! other, with an increasing sequence id.
//!
//! In other words, when we sent data, we:
//!
//! - Split the data into "packets" of size `2 ** 24 - 1` bytes.
//!
//! - Prepend each packet with a **packet header**, consisting of the length of that packet,
//! and the sequence number.
//!
//! https://dev.mysql.com/doc/internals/en/mysql-packet.html
//!
use std::fmt::Debug;
use bytes::{Buf, BufMut};
use sqlx_core::io::{Deserialize, Serialize};
use sqlx_core::{Error, Result, Runtime};
use crate::protocol::{Capabilities, ErrPacket, MaybeCommand};
use crate::{MySqlConnection, MySqlDatabaseError};
impl<Rt> MySqlConnection<Rt>
where
Rt: Runtime,
{
pub(super) fn write_packet<'ser, T>(&'ser mut self, packet: &T) -> Result<()>
where
T: Serialize<'ser, Capabilities> + Debug + MaybeCommand,
{
log::trace!("write > {:?}", packet);
// the sequence-id is incremented with each packet and may
// wrap around. it starts at 0 and is reset to 0 when a new command
// begins in the Command Phase
self.sequence_id = if T::is_command() { 0 } else { self.sequence_id.wrapping_add(1) };
// optimize for <16M packet sizes, in the case of >= 16M we would
// swap out the write buffer for a fresh buffer and then split it into
// 16M chunks separated by packet headers
let buf = self.stream.buffer();
let pos = buf.len();
// leave room for the length of the packet header at the start
buf.reserve(4);
buf.extend_from_slice(&[0_u8; 3]);
buf.push(self.sequence_id);
// serialize the passed packet structure directly into the write buffer
packet.serialize_with(buf, self.capabilities)?;
let payload_len = buf.len() - pos - 4;
// FIXME: handle split packets
assert!(payload_len < 0xFF_FF_FF);
// write back the length of the packet
#[allow(clippy::cast_possible_truncation)]
(&mut buf[pos..]).put_uint_le(payload_len as u64, 3);
Ok(())
}
fn recv_packet<'de, T>(&'de mut self, len: usize) -> Result<T>
where
T: Deserialize<'de, Capabilities> + Debug,
{
// FIXME: handle split packets
assert_ne!(len, 0xFF_FF_FF);
// We store the sequence id here. To respond to a packet, it should use a
// sequence id of n+1. It only "resets" at the start of a new command.
self.sequence_id = self.stream.get(3, 1).get_u8();
// tell the stream that we are done with the 4-byte header
self.stream.consume(4);
// and remove the remainder of the packet from the stream, the payload
let payload = self.stream.take(len);
if payload[0] == 0xff {
// if the first byte of the payload is 0xFF and the payload is an ERR packet
let err = ErrPacket::deserialize_with(payload, self.capabilities)?;
log::trace!("read > {:?}", err);
return Err(Error::connect(MySqlDatabaseError(err)));
}
let packet = T::deserialize_with(payload, self.capabilities)?;
log::trace!("read > {:?}", packet);
Ok(packet)
}
}
macro_rules! read_packet {
($(@$blocking:ident)? $self:ident) => {{
// reads at least 4 bytes from the IO stream into the read buffer
read_packet!($(@$blocking)? @stream $self, 0, 4);
// the first 3 bytes will be the payload length of the packet (in LE)
// ALLOW: the max this len will be is 16M
#[allow(clippy::cast_possible_truncation)]
let payload_len: usize = $self.stream.get(0, 3).get_uint_le(3) as usize;
// read <payload_len> bytes _after_ the 4 byte packet header
// note that we have not yet told the stream we are done with any of
// these bytes yet. if this next read invocation were to never return (eg., the
// outer future was dropped), then the next time read_packet_async was called
// it will re-read the parsed-above packet header. Note that we have NOT
// mutated `self` _yet_. This is important.
read_packet!($(@$blocking)? @stream $self, 4, payload_len);
$self.recv_packet(payload_len)
}};
(@blocking @stream $self:ident, $offset:expr, $n:expr) => {
$self.stream.read($offset, $n)?;
};
(@stream $self:ident, $offset:expr, $n:expr) => {
$self.stream.read_async($offset, $n).await?;
};
}
impl<Rt> MySqlConnection<Rt>
where
Rt: Runtime,
{
#[cfg(feature = "async")]
pub(super) async fn read_packet_async<'de, T>(&'de mut self) -> Result<T>
where
T: Deserialize<'de, Capabilities> + Debug,
Rt: sqlx_core::Async,
{
read_packet!(self)
}
#[cfg(feature = "blocking")]
pub(super) fn read_packet<'de, T>(&'de mut self) -> Result<T>
where
T: Deserialize<'de, Capabilities> + Debug,
Rt: sqlx_core::blocking::Runtime,
{
read_packet!(@blocking self)
}
}

View File

@ -18,6 +18,9 @@
#![warn(clippy::useless_let_if_seq)]
#![allow(clippy::doc_markdown)]
#[macro_use]
mod stream;
mod connection;
mod database;
mod error;

View File

@ -97,7 +97,7 @@ mod blocking {
where
Self::Connection: Sized,
{
<MySqlConnection<Rt>>::connect(self)
<MySqlConnection<Rt>>::connect_blocking(self)
}
}
}

193
sqlx-mysql/src/stream.rs Normal file
View File

@ -0,0 +1,193 @@
use std::fmt::Debug;
use std::ops::{Deref, DerefMut};
use bytes::{Buf, BufMut};
use sqlx_core::io::{BufStream, Serialize};
use sqlx_core::net::Stream as NetStream;
use sqlx_core::{Error, Result, Runtime};
use crate::protocol::{MaybeCommand, Packet};
use crate::MySqlDatabaseError;
/// Reads and writes packets to and from the MySQL database server.
///
/// The logic for serializing data structures into the packets is found
/// mostly in `protocol/`.
///
/// Packets in MySQL are prefixed by 4 bytes.
/// 3 for length (in LE) and a sequence id.
///
/// Packets may only be as large as the communicated size in the initial
/// `HandshakeResponse`. By default, SQLx configures its chunk size to 16M. Sending
/// a larger payload is simply sending completely "full" packets, one after the
/// other, with an increasing sequence id.
///
/// In other words, when we sent data, we:
///
/// - Split the data into "packets" of size `2 ** 24 - 1` bytes.
///
/// - Prepend each packet with a **packet header**, consisting of the length of that packet,
/// and the sequence number.
///
/// <https://dev.mysql.com/doc/internals/en/mysql-packet.html>
///
#[allow(clippy::module_name_repetitions)]
pub(crate) struct MySqlStream<Rt: Runtime> {
stream: BufStream<Rt, NetStream<Rt>>,
// the sequence-id is incremented with each packet and may wrap around. It starts at 0 and is
// reset to 0 when a new command begins in the Command Phase.
sequence_id: u8,
}
impl<Rt: Runtime> MySqlStream<Rt> {
pub(crate) fn new(stream: NetStream<Rt>) -> Self {
Self { stream: BufStream::with_capacity(stream, 4096, 1024), sequence_id: 0 }
}
pub(crate) fn write_packet<'ser, T>(&'ser mut self, packet: &T) -> Result<()>
where
T: Serialize<'ser> + Debug + MaybeCommand,
{
log::trace!("write > {:?}", packet);
// the sequence-id is incremented with each packet and may
// wrap around. it starts at 0 and is reset to 0 when a new command
// begins in the Command Phase
self.sequence_id = if T::is_command() { 0 } else { self.sequence_id.wrapping_add(1) };
// optimize for <16M packet sizes, in the case of >= 16M we would
// swap out the write buffer for a fresh buffer and then split it into
// 16M chunks separated by packet headers
let buf = self.stream.buffer();
let pos = buf.len();
// leave room for the length of the packet header at the start
buf.reserve(4);
buf.extend_from_slice(&[0_u8; 3]);
buf.push(self.sequence_id);
// serialize the passed packet structure directly into the write buffer
packet.serialize(buf)?;
let payload_len = buf.len() - pos - 4;
// FIXME: handle split packets
assert!(payload_len < 0xFF_FF_FF);
// write back the length of the packet
#[allow(clippy::cast_possible_truncation)]
(&mut buf[pos..]).put_uint_le(payload_len as u64, 3);
Ok(())
}
// read and consumes a packet from the stream _buffer_
// assumes there is a packet on the stream
// is called by [read_packet_blocking] or [read_packet_async]
fn read_packet(&mut self, len: usize) -> Result<Packet> {
// We store the sequence id here. To respond to a packet, it should use a
// sequence id of n+1. It only "resets" at the start of a new command.
self.sequence_id = self.stream.get(3, 1).get_u8();
// tell the stream that we are done with the 4-byte header
self.stream.consume(4);
// and remove the remainder of the packet from the stream, the payload
let packet = Packet { bytes: self.stream.take(len) };
if packet.bytes.len() != len {
// BUG: something is very wrong somewhere if this branch is executed
// either in the SQLx MySQL driver or in the MySQL server
return Err(Error::connect(MySqlDatabaseError::malformed_packet(&format!(
"Received {} bytes for packet but expecting {} bytes",
packet.bytes.len(),
len
))));
}
if packet.bytes[0] == 0xff {
// if the first byte of the payload is 0xFF and the payload is an ERR packet
return Err(Error::connect(MySqlDatabaseError(packet.deserialize()?)));
}
Ok(packet)
}
}
macro_rules! impl_read_packet {
($(@$blocking:ident)? $self:ident) => {{
// reads at least 4 bytes from the IO stream into the read buffer
impl_read_packet!($(@$blocking)? @stream $self, 0, 4);
// the first 3 bytes will be the payload length of the packet (in LE)
// ALLOW: the max this len will be is 16M
#[allow(clippy::cast_possible_truncation)]
let payload_len: usize = $self.stream.get(0, 3).get_uint_le(3) as usize;
// read <payload_len> bytes _after_ the 4 byte packet header
// note that we have not yet told the stream we are done with any of
// these bytes yet. if this next read invocation were to never return (eg., the
// outer future was dropped), then the next time read_packet was called
// it will re-read the parsed-above packet header. Note that we have NOT
// mutated `self` _yet_. This is important.
impl_read_packet!($(@$blocking)? @stream $self, 4, payload_len);
// FIXME: handle split packets
assert_ne!(payload_len, 0xFF_FF_FF);
$self.read_packet(payload_len)
}};
(@blocking @stream $self:ident, $offset:expr, $n:expr) => {
$self.stream.read($offset, $n)?;
};
(@stream $self:ident, $offset:expr, $n:expr) => {
$self.stream.read_async($offset, $n).await?;
};
}
impl<Rt: Runtime> MySqlStream<Rt> {
#[cfg(feature = "async")]
pub(crate) async fn read_packet_async(&mut self) -> Result<Packet>
where
Rt: sqlx_core::Async,
{
impl_read_packet!(self)
}
#[cfg(feature = "blocking")]
pub(crate) fn read_packet_blocking(&mut self) -> Result<Packet>
where
Rt: sqlx_core::blocking::Runtime,
{
impl_read_packet!(@blocking self)
}
}
impl<Rt: Runtime> Deref for MySqlStream<Rt> {
type Target = BufStream<Rt, NetStream<Rt>>;
fn deref(&self) -> &Self::Target {
&self.stream
}
}
impl<Rt: Runtime> DerefMut for MySqlStream<Rt> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.stream
}
}
macro_rules! read_packet {
(@blocking $stream:expr) => {
$stream.read_packet_blocking()?
};
($stream:expr) => {
$stream.read_packet_async().await?
};
}