mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-12-30 05:11:13 +00:00
refactor(mysql): split off MySqlStream from MySqlConnection
This commit is contained in:
parent
d279c6b978
commit
ce2fba7b8d
@ -1,16 +1,18 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::fmt::{self, Debug, Formatter};
|
||||
|
||||
use sqlx_core::io::BufStream;
|
||||
use sqlx_core::net::Stream as NetStream;
|
||||
use sqlx_core::{Close, Connect, Connection, Runtime};
|
||||
|
||||
use crate::protocol::Capabilities;
|
||||
use crate::stream::MySqlStream;
|
||||
use crate::{MySql, MySqlConnectOptions};
|
||||
|
||||
mod close;
|
||||
mod command;
|
||||
mod connect;
|
||||
mod executor;
|
||||
mod ping;
|
||||
mod stream;
|
||||
|
||||
/// A single connection (also known as a session) to a MySQL database server.
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
@ -18,16 +20,17 @@ pub struct MySqlConnection<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
stream: BufStream<Rt, NetStream<Rt>>,
|
||||
stream: MySqlStream<Rt>,
|
||||
connection_id: u32,
|
||||
|
||||
// the capability flags are used by the client and server to indicate which
|
||||
// features they support and want to use.
|
||||
capabilities: Capabilities,
|
||||
|
||||
// the sequence-id is incremented with each packet and may wrap around. It starts at 0 and is
|
||||
// reset to 0 when a new command begins in the Command Phase.
|
||||
sequence_id: u8,
|
||||
// queue of commands that are being processed
|
||||
// this is what we expect to receive from the server
|
||||
// in the case of a future or stream being dropped
|
||||
commands: VecDeque<command::Command>,
|
||||
}
|
||||
|
||||
impl<Rt> MySqlConnection<Rt>
|
||||
@ -36,21 +39,22 @@ where
|
||||
{
|
||||
pub(crate) fn new(stream: NetStream<Rt>) -> Self {
|
||||
Self {
|
||||
stream: BufStream::with_capacity(stream, 4096, 1024),
|
||||
stream: MySqlStream::new(stream),
|
||||
connection_id: 0,
|
||||
sequence_id: 0,
|
||||
capabilities: Capabilities::PROTOCOL_41 | Capabilities::LONG_PASSWORD
|
||||
commands: VecDeque::with_capacity(2),
|
||||
capabilities: Capabilities::PROTOCOL_41
|
||||
| Capabilities::LONG_PASSWORD
|
||||
| Capabilities::LONG_FLAG
|
||||
| Capabilities::IGNORE_SPACE
|
||||
| Capabilities::TRANSACTIONS
|
||||
| Capabilities::SECURE_CONNECTION
|
||||
// | Capabilities::MULTI_STATEMENTS
|
||||
// | Capabilities::MULTI_RESULTS
|
||||
// | Capabilities::PS_MULTI_RESULTS
|
||||
| Capabilities::MULTI_STATEMENTS
|
||||
| Capabilities::MULTI_RESULTS
|
||||
| Capabilities::PS_MULTI_RESULTS
|
||||
| Capabilities::PLUGIN_AUTH
|
||||
| Capabilities::PLUGIN_AUTH_LENENC_DATA
|
||||
// | Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
|
||||
// | Capabilities::SESSION_TRACK
|
||||
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
|
||||
| Capabilities::SESSION_TRACK
|
||||
| Capabilities::DEPRECATE_EOF,
|
||||
}
|
||||
}
|
||||
@ -115,7 +119,7 @@ mod blocking {
|
||||
impl<Rt: Runtime> Connection<Rt> for MySqlConnection<Rt> {
|
||||
#[inline]
|
||||
fn ping(&mut self) -> sqlx_core::Result<()> {
|
||||
self.ping()
|
||||
self.ping_blocking()
|
||||
}
|
||||
}
|
||||
|
||||
@ -125,14 +129,14 @@ mod blocking {
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Self::connect(&url.parse::<MySqlConnectOptions<Rt>>()?)
|
||||
Self::connect_blocking(&url.parse::<MySqlConnectOptions<Rt>>()?)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Rt: Runtime> Close<Rt> for MySqlConnection<Rt> {
|
||||
#[inline]
|
||||
fn close(self) -> sqlx_core::Result<()> {
|
||||
self.close()
|
||||
self.close_blocking()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,16 +2,13 @@ use sqlx_core::{io::Stream, Result, Runtime};
|
||||
|
||||
use crate::protocol::Quit;
|
||||
|
||||
impl<Rt> super::MySqlConnection<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
impl<Rt: Runtime> super::MySqlConnection<Rt> {
|
||||
#[cfg(feature = "async")]
|
||||
pub(crate) async fn close_async(mut self) -> Result<()>
|
||||
where
|
||||
Rt: sqlx_core::Async,
|
||||
{
|
||||
self.write_packet(&Quit)?;
|
||||
self.stream.write_packet(&Quit)?;
|
||||
self.stream.flush_async().await?;
|
||||
self.stream.shutdown_async().await?;
|
||||
|
||||
@ -19,11 +16,11 @@ where
|
||||
}
|
||||
|
||||
#[cfg(feature = "blocking")]
|
||||
pub(crate) fn close(mut self) -> Result<()>
|
||||
pub(crate) fn close_blocking(mut self) -> Result<()>
|
||||
where
|
||||
Rt: sqlx_core::blocking::Runtime,
|
||||
{
|
||||
self.write_packet(&Quit)?;
|
||||
self.stream.write_packet(&Quit)?;
|
||||
self.stream.flush()?;
|
||||
self.stream.shutdown()?;
|
||||
|
||||
|
||||
@ -13,100 +13,115 @@
|
||||
//!
|
||||
use sqlx_core::net::Stream as NetStream;
|
||||
use sqlx_core::Result;
|
||||
use sqlx_core::Runtime;
|
||||
|
||||
use crate::protocol::{Auth, AuthResponse, Handshake, HandshakeResponse};
|
||||
use crate::protocol::{AuthResponse, Handshake, HandshakeResponse};
|
||||
use crate::{MySqlConnectOptions, MySqlConnection};
|
||||
|
||||
macro_rules! connect {
|
||||
(@blocking @tcp $options:ident) => {
|
||||
impl<Rt: Runtime> MySqlConnection<Rt> {
|
||||
fn recv_handshake(
|
||||
&mut self,
|
||||
options: &MySqlConnectOptions<Rt>,
|
||||
handshake: &Handshake,
|
||||
) -> Result<()> {
|
||||
// & the declared server capabilities with our capabilities to find
|
||||
// what rules the client should operate under
|
||||
self.capabilities &= handshake.capabilities;
|
||||
|
||||
// store the connection ID, mainly for debugging
|
||||
self.connection_id = handshake.connection_id;
|
||||
|
||||
// create the initial auth response
|
||||
// this may just be a request for an RSA public key
|
||||
let initial_auth_response = handshake
|
||||
.auth_plugin
|
||||
.invoke(&handshake.auth_plugin_data, options.get_password().unwrap_or_default());
|
||||
|
||||
// the <HandshakeResponse> contains an initial guess at the correct encoding of
|
||||
// the password and some other metadata like "which database", "which user", etc.
|
||||
self.stream.write_packet(&HandshakeResponse {
|
||||
capabilities: self.capabilities,
|
||||
auth_plugin_name: handshake.auth_plugin.name(),
|
||||
auth_response: initial_auth_response,
|
||||
charset: 45, // [utf8mb4]
|
||||
database: options.get_database(),
|
||||
max_packet_size: 1024,
|
||||
username: options.get_username(),
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn recv_auth_response(
|
||||
&mut self,
|
||||
options: &MySqlConnectOptions<Rt>,
|
||||
handshake: &mut Handshake,
|
||||
response: AuthResponse,
|
||||
) -> Result<bool> {
|
||||
match response {
|
||||
AuthResponse::Ok(_) => {
|
||||
// successful, simple authentication; good to go
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
AuthResponse::MoreData(data) => {
|
||||
if let Some(data) = handshake.auth_plugin.handle(
|
||||
data,
|
||||
&handshake.auth_plugin_data,
|
||||
options.get_password().unwrap_or_default(),
|
||||
)? {
|
||||
// write the response from the plugin
|
||||
self.stream.write_packet(&&*data)?;
|
||||
}
|
||||
}
|
||||
|
||||
AuthResponse::Switch(sw) => {
|
||||
// switch to the new plugin
|
||||
handshake.auth_plugin = sw.plugin;
|
||||
handshake.auth_plugin_data = sw.plugin_data;
|
||||
|
||||
// generate an initial response from this plugin
|
||||
let data = handshake.auth_plugin.invoke(
|
||||
&handshake.auth_plugin_data,
|
||||
options.get_password().unwrap_or_default(),
|
||||
);
|
||||
|
||||
// write the response from the plugin
|
||||
self.stream.write_packet(&&*data)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_connect {
|
||||
(@blocking @new $options:ident) => {
|
||||
NetStream::connect($options.address.as_ref())?;
|
||||
};
|
||||
|
||||
(@tcp $options:ident) => {
|
||||
(@new $options:ident) => {
|
||||
NetStream::connect_async($options.address.as_ref()).await?;
|
||||
};
|
||||
|
||||
(@blocking @packet $self:ident) => {
|
||||
$self.read_packet()?;
|
||||
};
|
||||
|
||||
(@packet $self:ident) => {
|
||||
$self.read_packet_async().await?;
|
||||
};
|
||||
|
||||
($(@$blocking:ident)? $options:ident) => {{
|
||||
// open a network stream to the database server
|
||||
let stream = connect!($(@$blocking)? @tcp $options);
|
||||
let stream = impl_connect!($(@$blocking)? @new $options);
|
||||
|
||||
// construct a <MySqlConnection> around the network stream
|
||||
// wraps the stream in a <BufStream> to buffer read and write
|
||||
let mut self_ = Self::new(stream);
|
||||
|
||||
// immediately the server should emit a <Handshake> packet
|
||||
let handshake: Handshake = connect!($(@$blocking)? @packet self_);
|
||||
|
||||
// & the declared server capabilities with our capabilities to find
|
||||
// what rules the client should operate under
|
||||
self_.capabilities &= handshake.capabilities;
|
||||
|
||||
// store the connection ID, mainly for debugging
|
||||
self_.connection_id = handshake.connection_id;
|
||||
|
||||
// extract the auth plugin and data from the handshake
|
||||
// this can get overwritten by an auth switch
|
||||
let mut auth_plugin = handshake.auth_plugin;
|
||||
let mut auth_plugin_data = handshake.auth_plugin_data;
|
||||
let password = $options.get_password().unwrap_or_default();
|
||||
|
||||
// create the initial auth response
|
||||
// this may just be a request for an RSA public key
|
||||
let initial_auth_response = auth_plugin.invoke(&auth_plugin_data, password);
|
||||
|
||||
// the <HandshakeResponse> contains an initial guess at the correct encoding of
|
||||
// the password and some other metadata like "which database", "which user", etc.
|
||||
self_.write_packet(&HandshakeResponse {
|
||||
auth_plugin_name: auth_plugin.name(),
|
||||
auth_response: initial_auth_response,
|
||||
charset: 45, // [utf8mb4]
|
||||
database: $options.get_database(),
|
||||
max_packet_size: 1024,
|
||||
username: $options.get_username(),
|
||||
})?;
|
||||
// we need to handle that and reply with a <HandshakeResponse>
|
||||
let mut handshake = read_packet!($(@$blocking)? self_.stream).deserialize()?;
|
||||
self_.recv_handshake($options, &handshake)?;
|
||||
|
||||
loop {
|
||||
match connect!($(@$blocking)? @packet self_) {
|
||||
Auth::Ok(_) => {
|
||||
// successful, simple authentication; good to go
|
||||
break;
|
||||
}
|
||||
|
||||
Auth::MoreData(data) => {
|
||||
if let Some(data) = auth_plugin.handle(data, &auth_plugin_data, password)? {
|
||||
// write the response from the plugin
|
||||
self_.write_packet(&AuthResponse { data })?;
|
||||
|
||||
// let's try again
|
||||
continue;
|
||||
}
|
||||
|
||||
// all done, the plugin says we check out
|
||||
break;
|
||||
}
|
||||
|
||||
Auth::Switch(sw) => {
|
||||
// switch to the new plugin
|
||||
auth_plugin = sw.plugin;
|
||||
auth_plugin_data = sw.plugin_data;
|
||||
|
||||
// generate an initial response from this plugin
|
||||
let data = auth_plugin.invoke(&auth_plugin_data, password);
|
||||
|
||||
// write the response from the plugin
|
||||
self_.write_packet(&AuthResponse { data })?;
|
||||
|
||||
// let's try again
|
||||
continue;
|
||||
}
|
||||
let response = read_packet!($(@$blocking)? self_.stream).deserialize_with(self_.capabilities)?;
|
||||
if self_.recv_auth_response($options, &mut handshake, response)? {
|
||||
// complete, successful authentication
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@ -114,24 +129,21 @@ macro_rules! connect {
|
||||
}};
|
||||
}
|
||||
|
||||
impl<Rt> MySqlConnection<Rt>
|
||||
where
|
||||
Rt: sqlx_core::Runtime,
|
||||
{
|
||||
impl<Rt: Runtime> MySqlConnection<Rt> {
|
||||
#[cfg(feature = "async")]
|
||||
pub(crate) async fn connect_async(options: &MySqlConnectOptions<Rt>) -> Result<Self>
|
||||
where
|
||||
Rt: sqlx_core::Async,
|
||||
{
|
||||
connect!(options)
|
||||
impl_connect!(options)
|
||||
}
|
||||
|
||||
#[cfg(feature = "blocking")]
|
||||
pub(crate) fn connect(options: &MySqlConnectOptions<Rt>) -> Result<Self>
|
||||
pub(crate) fn connect_blocking(options: &MySqlConnectOptions<Rt>) -> Result<Self>
|
||||
where
|
||||
Rt: sqlx_core::blocking::Runtime,
|
||||
{
|
||||
connect!(@blocking options)
|
||||
impl_connect!(@blocking options)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -6,31 +6,37 @@ use crate::protocol::{OkPacket, Ping};
|
||||
// send the COM_PING packet
|
||||
// should receive an OK
|
||||
|
||||
impl<Rt> super::MySqlConnection<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
macro_rules! impl_ping {
|
||||
($(@$blocking:ident)? $self:ident) => {{
|
||||
$self.stream.write_packet(&Ping)?;
|
||||
|
||||
// STATE: remember that we are expecting an OK packet
|
||||
$self.begin_simple_command();
|
||||
|
||||
let _ok: OkPacket = read_packet!($(@$blocking)? $self.stream)
|
||||
.deserialize_with($self.capabilities)?;
|
||||
|
||||
// STATE: received OK packet
|
||||
$self.end_command();
|
||||
|
||||
Ok(())
|
||||
}};
|
||||
}
|
||||
|
||||
impl<Rt: Runtime> super::MySqlConnection<Rt> {
|
||||
#[cfg(feature = "async")]
|
||||
pub(crate) async fn ping_async(&mut self) -> Result<()>
|
||||
where
|
||||
Rt: sqlx_core::Async,
|
||||
{
|
||||
self.write_packet(&Ping)?;
|
||||
|
||||
let _ok: OkPacket = self.read_packet_async().await?;
|
||||
|
||||
Ok(())
|
||||
impl_ping!(self)
|
||||
}
|
||||
|
||||
#[cfg(feature = "blocking")]
|
||||
pub(crate) fn ping(&mut self) -> Result<()>
|
||||
pub(crate) fn ping_blocking(&mut self) -> Result<()>
|
||||
where
|
||||
Rt: sqlx_core::blocking::Runtime,
|
||||
{
|
||||
self.write_packet(&Ping)?;
|
||||
|
||||
let _ok: OkPacket = self.read_packet()?;
|
||||
|
||||
Ok(())
|
||||
impl_ping!(@blocking self)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,160 +0,0 @@
|
||||
//! Reads and writes packets to and from the MySQL database server.
|
||||
//!
|
||||
//! The logic for serializing data structures into the packets is found
|
||||
//! mostly in `protocol/`.
|
||||
//!
|
||||
//! Packets in MySQL are prefixed by 4 bytes.
|
||||
//! 3 for length (in LE) and a sequence id.
|
||||
//!
|
||||
//! Packets may only be as large as the communicated size in the initial
|
||||
//! `HandshakeResponse`. By default, SQLx configures its chunk size to 16M. Sending
|
||||
//! a larger payload is simply sending completely "full" packets, one after the
|
||||
//! other, with an increasing sequence id.
|
||||
//!
|
||||
//! In other words, when we sent data, we:
|
||||
//!
|
||||
//! - Split the data into "packets" of size `2 ** 24 - 1` bytes.
|
||||
//!
|
||||
//! - Prepend each packet with a **packet header**, consisting of the length of that packet,
|
||||
//! and the sequence number.
|
||||
//!
|
||||
//! https://dev.mysql.com/doc/internals/en/mysql-packet.html
|
||||
//!
|
||||
use std::fmt::Debug;
|
||||
|
||||
use bytes::{Buf, BufMut};
|
||||
use sqlx_core::io::{Deserialize, Serialize};
|
||||
use sqlx_core::{Error, Result, Runtime};
|
||||
|
||||
use crate::protocol::{Capabilities, ErrPacket, MaybeCommand};
|
||||
use crate::{MySqlConnection, MySqlDatabaseError};
|
||||
|
||||
impl<Rt> MySqlConnection<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
pub(super) fn write_packet<'ser, T>(&'ser mut self, packet: &T) -> Result<()>
|
||||
where
|
||||
T: Serialize<'ser, Capabilities> + Debug + MaybeCommand,
|
||||
{
|
||||
log::trace!("write > {:?}", packet);
|
||||
|
||||
// the sequence-id is incremented with each packet and may
|
||||
// wrap around. it starts at 0 and is reset to 0 when a new command
|
||||
// begins in the Command Phase
|
||||
|
||||
self.sequence_id = if T::is_command() { 0 } else { self.sequence_id.wrapping_add(1) };
|
||||
|
||||
// optimize for <16M packet sizes, in the case of >= 16M we would
|
||||
// swap out the write buffer for a fresh buffer and then split it into
|
||||
// 16M chunks separated by packet headers
|
||||
|
||||
let buf = self.stream.buffer();
|
||||
let pos = buf.len();
|
||||
|
||||
// leave room for the length of the packet header at the start
|
||||
buf.reserve(4);
|
||||
buf.extend_from_slice(&[0_u8; 3]);
|
||||
buf.push(self.sequence_id);
|
||||
|
||||
// serialize the passed packet structure directly into the write buffer
|
||||
packet.serialize_with(buf, self.capabilities)?;
|
||||
|
||||
let payload_len = buf.len() - pos - 4;
|
||||
|
||||
// FIXME: handle split packets
|
||||
assert!(payload_len < 0xFF_FF_FF);
|
||||
|
||||
// write back the length of the packet
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
(&mut buf[pos..]).put_uint_le(payload_len as u64, 3);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn recv_packet<'de, T>(&'de mut self, len: usize) -> Result<T>
|
||||
where
|
||||
T: Deserialize<'de, Capabilities> + Debug,
|
||||
{
|
||||
// FIXME: handle split packets
|
||||
assert_ne!(len, 0xFF_FF_FF);
|
||||
|
||||
// We store the sequence id here. To respond to a packet, it should use a
|
||||
// sequence id of n+1. It only "resets" at the start of a new command.
|
||||
self.sequence_id = self.stream.get(3, 1).get_u8();
|
||||
|
||||
// tell the stream that we are done with the 4-byte header
|
||||
self.stream.consume(4);
|
||||
|
||||
// and remove the remainder of the packet from the stream, the payload
|
||||
let payload = self.stream.take(len);
|
||||
|
||||
if payload[0] == 0xff {
|
||||
// if the first byte of the payload is 0xFF and the payload is an ERR packet
|
||||
let err = ErrPacket::deserialize_with(payload, self.capabilities)?;
|
||||
log::trace!("read > {:?}", err);
|
||||
|
||||
return Err(Error::connect(MySqlDatabaseError(err)));
|
||||
}
|
||||
|
||||
let packet = T::deserialize_with(payload, self.capabilities)?;
|
||||
log::trace!("read > {:?}", packet);
|
||||
|
||||
Ok(packet)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! read_packet {
|
||||
($(@$blocking:ident)? $self:ident) => {{
|
||||
// reads at least 4 bytes from the IO stream into the read buffer
|
||||
read_packet!($(@$blocking)? @stream $self, 0, 4);
|
||||
|
||||
// the first 3 bytes will be the payload length of the packet (in LE)
|
||||
// ALLOW: the max this len will be is 16M
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let payload_len: usize = $self.stream.get(0, 3).get_uint_le(3) as usize;
|
||||
|
||||
// read <payload_len> bytes _after_ the 4 byte packet header
|
||||
// note that we have not yet told the stream we are done with any of
|
||||
// these bytes yet. if this next read invocation were to never return (eg., the
|
||||
// outer future was dropped), then the next time read_packet_async was called
|
||||
// it will re-read the parsed-above packet header. Note that we have NOT
|
||||
// mutated `self` _yet_. This is important.
|
||||
read_packet!($(@$blocking)? @stream $self, 4, payload_len);
|
||||
|
||||
$self.recv_packet(payload_len)
|
||||
}};
|
||||
|
||||
(@blocking @stream $self:ident, $offset:expr, $n:expr) => {
|
||||
$self.stream.read($offset, $n)?;
|
||||
};
|
||||
|
||||
(@stream $self:ident, $offset:expr, $n:expr) => {
|
||||
$self.stream.read_async($offset, $n).await?;
|
||||
};
|
||||
}
|
||||
|
||||
impl<Rt> MySqlConnection<Rt>
|
||||
where
|
||||
Rt: Runtime,
|
||||
{
|
||||
#[cfg(feature = "async")]
|
||||
|
||||
pub(super) async fn read_packet_async<'de, T>(&'de mut self) -> Result<T>
|
||||
where
|
||||
T: Deserialize<'de, Capabilities> + Debug,
|
||||
Rt: sqlx_core::Async,
|
||||
{
|
||||
read_packet!(self)
|
||||
}
|
||||
|
||||
#[cfg(feature = "blocking")]
|
||||
|
||||
pub(super) fn read_packet<'de, T>(&'de mut self) -> Result<T>
|
||||
where
|
||||
T: Deserialize<'de, Capabilities> + Debug,
|
||||
Rt: sqlx_core::blocking::Runtime,
|
||||
{
|
||||
read_packet!(@blocking self)
|
||||
}
|
||||
}
|
||||
@ -18,6 +18,9 @@
|
||||
#![warn(clippy::useless_let_if_seq)]
|
||||
#![allow(clippy::doc_markdown)]
|
||||
|
||||
#[macro_use]
|
||||
mod stream;
|
||||
|
||||
mod connection;
|
||||
mod database;
|
||||
mod error;
|
||||
|
||||
@ -97,7 +97,7 @@ mod blocking {
|
||||
where
|
||||
Self::Connection: Sized,
|
||||
{
|
||||
<MySqlConnection<Rt>>::connect(self)
|
||||
<MySqlConnection<Rt>>::connect_blocking(self)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
193
sqlx-mysql/src/stream.rs
Normal file
193
sqlx-mysql/src/stream.rs
Normal file
@ -0,0 +1,193 @@
|
||||
use std::fmt::Debug;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use bytes::{Buf, BufMut};
|
||||
use sqlx_core::io::{BufStream, Serialize};
|
||||
use sqlx_core::net::Stream as NetStream;
|
||||
use sqlx_core::{Error, Result, Runtime};
|
||||
|
||||
use crate::protocol::{MaybeCommand, Packet};
|
||||
use crate::MySqlDatabaseError;
|
||||
|
||||
/// Reads and writes packets to and from the MySQL database server.
|
||||
///
|
||||
/// The logic for serializing data structures into the packets is found
|
||||
/// mostly in `protocol/`.
|
||||
///
|
||||
/// Packets in MySQL are prefixed by 4 bytes.
|
||||
/// 3 for length (in LE) and a sequence id.
|
||||
///
|
||||
/// Packets may only be as large as the communicated size in the initial
|
||||
/// `HandshakeResponse`. By default, SQLx configures its chunk size to 16M. Sending
|
||||
/// a larger payload is simply sending completely "full" packets, one after the
|
||||
/// other, with an increasing sequence id.
|
||||
///
|
||||
/// In other words, when we sent data, we:
|
||||
///
|
||||
/// - Split the data into "packets" of size `2 ** 24 - 1` bytes.
|
||||
///
|
||||
/// - Prepend each packet with a **packet header**, consisting of the length of that packet,
|
||||
/// and the sequence number.
|
||||
///
|
||||
/// <https://dev.mysql.com/doc/internals/en/mysql-packet.html>
|
||||
///
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub(crate) struct MySqlStream<Rt: Runtime> {
|
||||
stream: BufStream<Rt, NetStream<Rt>>,
|
||||
|
||||
// the sequence-id is incremented with each packet and may wrap around. It starts at 0 and is
|
||||
// reset to 0 when a new command begins in the Command Phase.
|
||||
sequence_id: u8,
|
||||
}
|
||||
|
||||
impl<Rt: Runtime> MySqlStream<Rt> {
|
||||
pub(crate) fn new(stream: NetStream<Rt>) -> Self {
|
||||
Self { stream: BufStream::with_capacity(stream, 4096, 1024), sequence_id: 0 }
|
||||
}
|
||||
|
||||
pub(crate) fn write_packet<'ser, T>(&'ser mut self, packet: &T) -> Result<()>
|
||||
where
|
||||
T: Serialize<'ser> + Debug + MaybeCommand,
|
||||
{
|
||||
log::trace!("write > {:?}", packet);
|
||||
|
||||
// the sequence-id is incremented with each packet and may
|
||||
// wrap around. it starts at 0 and is reset to 0 when a new command
|
||||
// begins in the Command Phase
|
||||
|
||||
self.sequence_id = if T::is_command() { 0 } else { self.sequence_id.wrapping_add(1) };
|
||||
|
||||
// optimize for <16M packet sizes, in the case of >= 16M we would
|
||||
// swap out the write buffer for a fresh buffer and then split it into
|
||||
// 16M chunks separated by packet headers
|
||||
|
||||
let buf = self.stream.buffer();
|
||||
let pos = buf.len();
|
||||
|
||||
// leave room for the length of the packet header at the start
|
||||
buf.reserve(4);
|
||||
buf.extend_from_slice(&[0_u8; 3]);
|
||||
buf.push(self.sequence_id);
|
||||
|
||||
// serialize the passed packet structure directly into the write buffer
|
||||
packet.serialize(buf)?;
|
||||
|
||||
let payload_len = buf.len() - pos - 4;
|
||||
|
||||
// FIXME: handle split packets
|
||||
assert!(payload_len < 0xFF_FF_FF);
|
||||
|
||||
// write back the length of the packet
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
(&mut buf[pos..]).put_uint_le(payload_len as u64, 3);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// read and consumes a packet from the stream _buffer_
|
||||
// assumes there is a packet on the stream
|
||||
// is called by [read_packet_blocking] or [read_packet_async]
|
||||
fn read_packet(&mut self, len: usize) -> Result<Packet> {
|
||||
// We store the sequence id here. To respond to a packet, it should use a
|
||||
// sequence id of n+1. It only "resets" at the start of a new command.
|
||||
self.sequence_id = self.stream.get(3, 1).get_u8();
|
||||
|
||||
// tell the stream that we are done with the 4-byte header
|
||||
self.stream.consume(4);
|
||||
|
||||
// and remove the remainder of the packet from the stream, the payload
|
||||
let packet = Packet { bytes: self.stream.take(len) };
|
||||
|
||||
if packet.bytes.len() != len {
|
||||
// BUG: something is very wrong somewhere if this branch is executed
|
||||
// either in the SQLx MySQL driver or in the MySQL server
|
||||
return Err(Error::connect(MySqlDatabaseError::malformed_packet(&format!(
|
||||
"Received {} bytes for packet but expecting {} bytes",
|
||||
packet.bytes.len(),
|
||||
len
|
||||
))));
|
||||
}
|
||||
|
||||
if packet.bytes[0] == 0xff {
|
||||
// if the first byte of the payload is 0xFF and the payload is an ERR packet
|
||||
return Err(Error::connect(MySqlDatabaseError(packet.deserialize()?)));
|
||||
}
|
||||
|
||||
Ok(packet)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_read_packet {
|
||||
($(@$blocking:ident)? $self:ident) => {{
|
||||
// reads at least 4 bytes from the IO stream into the read buffer
|
||||
impl_read_packet!($(@$blocking)? @stream $self, 0, 4);
|
||||
|
||||
// the first 3 bytes will be the payload length of the packet (in LE)
|
||||
// ALLOW: the max this len will be is 16M
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let payload_len: usize = $self.stream.get(0, 3).get_uint_le(3) as usize;
|
||||
|
||||
// read <payload_len> bytes _after_ the 4 byte packet header
|
||||
// note that we have not yet told the stream we are done with any of
|
||||
// these bytes yet. if this next read invocation were to never return (eg., the
|
||||
// outer future was dropped), then the next time read_packet was called
|
||||
// it will re-read the parsed-above packet header. Note that we have NOT
|
||||
// mutated `self` _yet_. This is important.
|
||||
impl_read_packet!($(@$blocking)? @stream $self, 4, payload_len);
|
||||
|
||||
// FIXME: handle split packets
|
||||
assert_ne!(payload_len, 0xFF_FF_FF);
|
||||
|
||||
$self.read_packet(payload_len)
|
||||
}};
|
||||
|
||||
(@blocking @stream $self:ident, $offset:expr, $n:expr) => {
|
||||
$self.stream.read($offset, $n)?;
|
||||
};
|
||||
|
||||
(@stream $self:ident, $offset:expr, $n:expr) => {
|
||||
$self.stream.read_async($offset, $n).await?;
|
||||
};
|
||||
}
|
||||
|
||||
impl<Rt: Runtime> MySqlStream<Rt> {
|
||||
#[cfg(feature = "async")]
|
||||
pub(crate) async fn read_packet_async(&mut self) -> Result<Packet>
|
||||
where
|
||||
Rt: sqlx_core::Async,
|
||||
{
|
||||
impl_read_packet!(self)
|
||||
}
|
||||
|
||||
#[cfg(feature = "blocking")]
|
||||
pub(crate) fn read_packet_blocking(&mut self) -> Result<Packet>
|
||||
where
|
||||
Rt: sqlx_core::blocking::Runtime,
|
||||
{
|
||||
impl_read_packet!(@blocking self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Rt: Runtime> Deref for MySqlStream<Rt> {
|
||||
type Target = BufStream<Rt, NetStream<Rt>>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.stream
|
||||
}
|
||||
}
|
||||
|
||||
impl<Rt: Runtime> DerefMut for MySqlStream<Rt> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.stream
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! read_packet {
|
||||
(@blocking $stream:expr) => {
|
||||
$stream.read_packet_blocking()?
|
||||
};
|
||||
|
||||
($stream:expr) => {
|
||||
$stream.read_packet_async().await?
|
||||
};
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user