fix(mysql): fallout from ec5326e5

This commit is contained in:
Austin Bonander 2024-08-17 04:54:59 -07:00
parent 53766e4659
commit 1f3db8201d
20 changed files with 91 additions and 66 deletions

View File

@ -15,6 +15,13 @@ any = ["sqlx-core/any"]
offline = ["sqlx-core/offline", "serde/derive"]
migrate = ["sqlx-core/migrate"]
# Type Integration features
bigdecimal = ["dep:bigdecimal", "sqlx-core/bigdecimal"]
chrono = ["dep:chrono", "sqlx-core/chrono"]
rust_decimal = ["dep:rust_decimal", "rust_decimal/maths", "sqlx-core/rust_decimal"]
time = ["dep:time", "sqlx-core/time"]
uuid = ["dep:uuid", "sqlx-core/uuid"]
[dependencies]
sqlx-core = { workspace = true }

View File

@ -6,7 +6,7 @@ use bytes::{Buf, Bytes, BytesMut};
use crate::collation::{CharSet, Collation};
use crate::error::Error;
use crate::io::MySqlBufExt;
use crate::io::{Decode, Encode};
use crate::io::{ProtocolDecode, ProtocolEncode};
use crate::net::{BufferedSocket, Socket};
use crate::protocol::response::{EofPacket, ErrPacket, OkPacket, Status};
use crate::protocol::{Capabilities, Packet};
@ -110,7 +110,7 @@ impl<S: Socket> MySqlStream<S> {
pub(crate) async fn send_packet<'en, T>(&mut self, payload: T) -> Result<(), Error>
where
T: Encode<'en, Capabilities>,
T: ProtocolEncode<'en, Capabilities>,
{
self.sequence_id = 0;
self.write_packet(payload);
@ -120,7 +120,7 @@ impl<S: Socket> MySqlStream<S> {
pub(crate) fn write_packet<'en, T>(&mut self, payload: T)
where
T: Encode<'en, Capabilities>,
T: ProtocolEncode<'en, Capabilities>,
{
self.socket
.write_with(Packet(payload), (self.capabilities, &mut self.sequence_id));
@ -184,7 +184,7 @@ impl<S: Socket> MySqlStream<S> {
pub(crate) async fn recv<'de, T>(&mut self) -> Result<T, Error>
where
T: Decode<'de, Capabilities>,
T: ProtocolDecode<'de, Capabilities>,
{
self.recv_packet().await?.decode_with(self.capabilities)
}

View File

@ -1,8 +1,8 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::Encode;
use crate::io::{BufExt, Decode};
use crate::io::ProtocolEncode;
use crate::io::{BufExt, ProtocolDecode};
use crate::protocol::auth::AuthPlugin;
use crate::protocol::Capabilities;
@ -14,7 +14,7 @@ pub struct AuthSwitchRequest {
pub data: Bytes,
}
impl Decode<'_, bool> for AuthSwitchRequest {
impl ProtocolDecode<'_, bool> for AuthSwitchRequest {
fn decode_with(mut buf: Bytes, enable_cleartext_plugin: bool) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0xfe {
@ -58,9 +58,10 @@ impl Decode<'_, bool> for AuthSwitchRequest {
#[derive(Debug)]
pub struct AuthSwitchResponse(pub Vec<u8>);
impl Encode<'_, Capabilities> for AuthSwitchResponse {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
impl ProtocolEncode<'_, Capabilities> for AuthSwitchResponse {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) -> Result<(), Error> {
buf.extend_from_slice(&self.0);
Ok(())
}
}

View File

@ -3,7 +3,7 @@ use bytes::{Buf, Bytes};
use std::cmp;
use crate::error::Error;
use crate::io::{BufExt, Decode};
use crate::io::{BufExt, ProtocolDecode};
use crate::protocol::auth::AuthPlugin;
use crate::protocol::response::Status;
use crate::protocol::Capabilities;
@ -27,7 +27,7 @@ pub(crate) struct Handshake {
pub(crate) auth_plugin_data: Chain<Bytes, Bytes>,
}
impl Decode<'_> for Handshake {
impl ProtocolDecode<'_> for Handshake {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
let protocol_version = buf.get_u8(); // int<1>
let server_version = buf.get_str_nul()?; // string<NUL>

View File

@ -1,5 +1,5 @@
use crate::io::MySqlBufMutExt;
use crate::io::{BufMutExt, Encode};
use crate::io::{BufMutExt, ProtocolEncode};
use crate::protocol::auth::AuthPlugin;
use crate::protocol::connect::ssl_request::SslRequest;
use crate::protocol::Capabilities;
@ -27,11 +27,15 @@ pub struct HandshakeResponse<'a> {
pub auth_response: Option<&'a [u8]>,
}
impl Encode<'_, Capabilities> for HandshakeResponse<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, mut capabilities: Capabilities) {
impl ProtocolEncode<'_, Capabilities> for HandshakeResponse<'_> {
fn encode_with(
&self,
buf: &mut Vec<u8>,
mut context: Capabilities,
) -> Result<(), crate::Error> {
if self.auth_plugin.is_none() {
// ensure PLUGIN_AUTH is set *only* if we have a defined plugin
capabilities.remove(Capabilities::PLUGIN_AUTH);
context.remove(Capabilities::PLUGIN_AUTH);
}
// NOTE: Half of this packet is identical to the SSL Request packet
@ -39,13 +43,13 @@ impl Encode<'_, Capabilities> for HandshakeResponse<'_> {
max_packet_size: self.max_packet_size,
collation: self.collation,
}
.encode_with(buf, capabilities);
.encode_with(buf, context)?;
buf.put_str_nul(self.username);
if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) {
if context.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) {
buf.put_bytes_lenenc(self.auth_response.unwrap_or_default());
} else if capabilities.contains(Capabilities::SECURE_CONNECTION) {
} else if context.contains(Capabilities::SECURE_CONNECTION) {
let response = self.auth_response.unwrap_or_default();
buf.push(response.len() as u8);
@ -54,7 +58,7 @@ impl Encode<'_, Capabilities> for HandshakeResponse<'_> {
buf.push(0);
}
if capabilities.contains(Capabilities::CONNECT_WITH_DB) {
if context.contains(Capabilities::CONNECT_WITH_DB) {
if let Some(database) = &self.database {
buf.put_str_nul(database);
} else {
@ -62,12 +66,14 @@ impl Encode<'_, Capabilities> for HandshakeResponse<'_> {
}
}
if capabilities.contains(Capabilities::PLUGIN_AUTH) {
if context.contains(Capabilities::PLUGIN_AUTH) {
if let Some(plugin) = &self.auth_plugin {
buf.put_str_nul(plugin.name());
} else {
buf.push(0);
}
}
Ok(())
}
}

View File

@ -1,4 +1,4 @@
use crate::io::Encode;
use crate::io::ProtocolEncode;
use crate::protocol::Capabilities;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_response.html
@ -10,21 +10,23 @@ pub struct SslRequest {
pub collation: u8,
}
impl Encode<'_, Capabilities> for SslRequest {
fn encode_with(&self, buf: &mut Vec<u8>, capabilities: Capabilities) {
buf.extend(&(capabilities.bits() as u32).to_le_bytes());
impl ProtocolEncode<'_, Capabilities> for SslRequest {
fn encode_with(&self, buf: &mut Vec<u8>, context: Capabilities) -> Result<(), crate::Error> {
buf.extend(&(context.bits() as u32).to_le_bytes());
buf.extend(&self.max_packet_size.to_le_bytes());
buf.push(self.collation);
// reserved: string<19>
buf.extend(&[0_u8; 19]);
if capabilities.contains(Capabilities::MYSQL) {
if context.contains(Capabilities::MYSQL) {
// reserved: string<4>
buf.extend(&[0_u8; 4]);
} else {
// extended client capabilities (MariaDB-specified): int<4>
buf.extend(&((capabilities.bits() >> 32) as u32).to_le_bytes());
buf.extend(&((context.bits() >> 32) as u32).to_le_bytes());
}
Ok(())
}
}

View File

@ -4,22 +4,22 @@ use std::ops::{Deref, DerefMut};
use bytes::Bytes;
use crate::error::Error;
use crate::io::{Decode, Encode};
use crate::io::{ProtocolDecode, ProtocolEncode};
use crate::protocol::response::{EofPacket, OkPacket};
use crate::protocol::Capabilities;
#[derive(Debug)]
pub struct Packet<T>(pub(crate) T);
impl<'en, 'stream, T> Encode<'stream, (Capabilities, &'stream mut u8)> for Packet<T>
impl<'en, 'stream, T> ProtocolEncode<'stream, (Capabilities, &'stream mut u8)> for Packet<T>
where
T: Encode<'en, Capabilities>,
T: ProtocolEncode<'en, Capabilities>,
{
fn encode_with(
&self,
buf: &mut Vec<u8>,
(capabilities, sequence_id): (Capabilities, &'stream mut u8),
) {
) -> Result<(), Error> {
let mut next_header = |len: u32| {
let mut buf = len.to_le_bytes();
buf[3] = *sequence_id;
@ -33,7 +33,7 @@ where
buf.extend(&[0_u8; 4]);
// encode the payload
self.0.encode_with(buf, capabilities);
self.0.encode_with(buf, capabilities)?;
// determine the length of the encoded payload
// and write to our reserved space
@ -59,20 +59,22 @@ where
buf.extend(&next_header(remainder.len() as u32));
buf.extend(remainder);
}
Ok(())
}
}
impl Packet<Bytes> {
pub(crate) fn decode<'de, T>(self) -> Result<T, Error>
where
T: Decode<'de, ()>,
T: ProtocolDecode<'de, ()>,
{
self.decode_with(())
}
pub(crate) fn decode_with<'de, T, C>(self, context: C) -> Result<T, Error>
where
T: Decode<'de, C>,
T: ProtocolDecode<'de, C>,
{
T::decode_with(self.0, context)
}

View File

@ -1,7 +1,7 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::Decode;
use crate::io::ProtocolDecode;
use crate::protocol::response::Status;
use crate::protocol::Capabilities;
@ -18,7 +18,7 @@ pub struct EofPacket {
pub status: Status,
}
impl Decode<'_, Capabilities> for EofPacket {
impl ProtocolDecode<'_, Capabilities> for EofPacket {
fn decode_with(mut buf: Bytes, _: Capabilities) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0xfe {

View File

@ -1,7 +1,7 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::{BufExt, Decode};
use crate::io::{BufExt, ProtocolDecode};
use crate::protocol::Capabilities;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_err_packet.html
@ -15,7 +15,7 @@ pub struct ErrPacket {
pub error_message: String,
}
impl Decode<'_, Capabilities> for ErrPacket {
impl ProtocolDecode<'_, Capabilities> for ErrPacket {
fn decode_with(mut buf: Bytes, capabilities: Capabilities) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0xff {

View File

@ -1,8 +1,8 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::Decode;
use crate::io::MySqlBufExt;
use crate::io::ProtocolDecode;
use crate::protocol::response::Status;
/// Indicates successful completion of a previous command sent by the client.
@ -14,7 +14,7 @@ pub struct OkPacket {
pub warnings: u16,
}
impl Decode<'_> for OkPacket {
impl ProtocolDecode<'_> for OkPacket {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0 && header != 0xfe {

View File

@ -1,4 +1,4 @@
use crate::io::Encode;
use crate::io::ProtocolEncode;
use crate::protocol::text::ColumnFlags;
use crate::protocol::Capabilities;
use crate::MySqlArguments;
@ -11,8 +11,8 @@ pub struct Execute<'q> {
pub arguments: &'q MySqlArguments,
}
impl<'q> Encode<'_, Capabilities> for Execute<'q> {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
impl<'q> ProtocolEncode<'_, Capabilities> for Execute<'q> {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) -> Result<(), crate::Error> {
buf.push(0x17); // COM_STMT_EXECUTE
buf.extend(&self.statement.to_le_bytes());
buf.push(0); // NO_CURSOR
@ -34,5 +34,7 @@ impl<'q> Encode<'_, Capabilities> for Execute<'q> {
buf.extend(&*self.arguments.values);
}
Ok(())
}
}

View File

@ -1,4 +1,4 @@
use crate::io::Encode;
use crate::io::ProtocolEncode;
use crate::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html#packet-COM_STMT_PREPARE
@ -7,9 +7,10 @@ pub struct Prepare<'a> {
pub query: &'a str,
}
impl Encode<'_, Capabilities> for Prepare<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
impl ProtocolEncode<'_, Capabilities> for Prepare<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) -> Result<(), crate::Error> {
buf.push(0x16); // COM_STMT_PREPARE
buf.extend(self.query.as_bytes());
Ok(())
}
}

View File

@ -1,7 +1,7 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::Decode;
use crate::io::ProtocolDecode;
use crate::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html#packet-COM_STMT_PREPARE_OK
@ -15,7 +15,7 @@ pub(crate) struct PrepareOk {
pub(crate) warnings: u16,
}
impl Decode<'_, Capabilities> for PrepareOk {
impl ProtocolDecode<'_, Capabilities> for PrepareOk {
fn decode_with(buf: Bytes, _: Capabilities) -> Result<Self, Error> {
const SIZE: usize = 12;

View File

@ -2,7 +2,7 @@ use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::MySqlBufExt;
use crate::io::{BufExt, Decode};
use crate::io::{BufExt, ProtocolDecode};
use crate::protocol::text::ColumnType;
use crate::protocol::Row;
use crate::MySqlColumn;
@ -13,7 +13,7 @@ use crate::MySqlColumn;
#[derive(Debug)]
pub(crate) struct BinaryRow(pub(crate) Row);
impl<'de> Decode<'de, &'de [MySqlColumn]> for BinaryRow {
impl<'de> ProtocolDecode<'de, &'de [MySqlColumn]> for BinaryRow {
fn decode_with(mut buf: Bytes, columns: &'de [MySqlColumn]) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0 {

View File

@ -1,4 +1,4 @@
use crate::io::Encode;
use crate::io::ProtocolEncode;
use crate::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/com-stmt-close.html
@ -8,9 +8,10 @@ pub struct StmtClose {
pub statement: u32,
}
impl Encode<'_, Capabilities> for StmtClose {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
impl ProtocolEncode<'_, Capabilities> for StmtClose {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) -> Result<(), crate::Error> {
buf.push(0x19); // COM_STMT_CLOSE
buf.extend(&self.statement.to_le_bytes());
Ok(())
}
}

View File

@ -4,8 +4,8 @@ use bitflags::bitflags;
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::Decode;
use crate::io::MySqlBufExt;
use crate::io::ProtocolDecode;
use crate::protocol::Capabilities;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__column__definition__flags.html
@ -134,7 +134,7 @@ impl ColumnDefinition {
}
}
impl Decode<'_, Capabilities> for ColumnDefinition {
impl ProtocolDecode<'_, Capabilities> for ColumnDefinition {
fn decode_with(mut buf: Bytes, _: Capabilities) -> Result<Self, Error> {
let catalog = buf.get_bytes_lenenc();
let schema = buf.get_bytes_lenenc();

View File

@ -1,4 +1,4 @@
use crate::io::Encode;
use crate::io::ProtocolEncode;
use crate::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/com-ping.html
@ -6,8 +6,9 @@ use crate::protocol::Capabilities;
#[derive(Debug)]
pub(crate) struct Ping;
impl Encode<'_, Capabilities> for Ping {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
impl ProtocolEncode<'_, Capabilities> for Ping {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) -> Result<(), crate::Error> {
buf.push(0x0e); // COM_PING
Ok(())
}
}

View File

@ -1,4 +1,4 @@
use crate::io::Encode;
use crate::io::ProtocolEncode;
use crate::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/com-query.html
@ -6,9 +6,10 @@ use crate::protocol::Capabilities;
#[derive(Debug)]
pub(crate) struct Query<'q>(pub(crate) &'q str);
impl Encode<'_, Capabilities> for Query<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
impl ProtocolEncode<'_, Capabilities> for Query<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) -> Result<(), crate::Error> {
buf.push(0x03); // COM_QUERY
buf.extend(self.0.as_bytes())
buf.extend(self.0.as_bytes());
Ok(())
}
}

View File

@ -1,4 +1,4 @@
use crate::io::Encode;
use crate::io::ProtocolEncode;
use crate::protocol::Capabilities;
// https://dev.mysql.com/doc/internals/en/com-quit.html
@ -6,8 +6,9 @@ use crate::protocol::Capabilities;
#[derive(Debug)]
pub(crate) struct Quit;
impl Encode<'_, Capabilities> for Quit {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) {
impl ProtocolEncode<'_, Capabilities> for Quit {
fn encode_with(&self, buf: &mut Vec<u8>, _: Capabilities) -> Result<(), crate::Error> {
buf.push(0x01); // COM_QUIT
Ok(())
}
}

View File

@ -2,14 +2,14 @@ use bytes::{Buf, Bytes};
use crate::column::MySqlColumn;
use crate::error::Error;
use crate::io::Decode;
use crate::io::MySqlBufExt;
use crate::io::ProtocolDecode;
use crate::protocol::Row;
#[derive(Debug)]
pub(crate) struct TextRow(pub(crate) Row);
impl<'de> Decode<'de, &'de [MySqlColumn]> for TextRow {
impl<'de> ProtocolDecode<'de, &'de [MySqlColumn]> for TextRow {
fn decode_with(mut buf: Bytes, columns: &'de [MySqlColumn]) -> Result<Self, Error> {
let storage = buf.clone();
let offset = buf.len();