From 653ead0322e51b015fa3cd34793934c579d61cff Mon Sep 17 00:00:00 2001 From: Daniel Akhterov Date: Tue, 25 Jun 2019 14:41:22 -0700 Subject: [PATCH] WIP: Cleanup serializations --- mason-mariadb/src/connection/establish.rs | 25 +-- mason-mariadb/src/connection/mod.rs | 27 ++- mason-mariadb/src/protocol/client.rs | 220 ++++++++++++++++++---- 3 files changed, 218 insertions(+), 54 deletions(-) diff --git a/mason-mariadb/src/connection/establish.rs b/mason-mariadb/src/connection/establish.rs index aebdb948..9357ce6e 100644 --- a/mason-mariadb/src/connection/establish.rs +++ b/mason-mariadb/src/connection/establish.rs @@ -17,20 +17,21 @@ pub async fn establish<'a, 'b: 'a>( conn: &'a mut Connection, options: ConnectOptions<'b>, ) -> Result<(), Error> { - let init_packet = if let Some(message) = conn.incoming.next().await { - match message { - ServerMessage::InitialHandshakePacket(message) => { - Ok(message) - }, - _ => Err(failure::err_msg("Incorrect First Packet")), - } - } else { - Err(failure::err_msg("Failed to connect")) - }?; + let init_packet = if let Some(message) = conn.incoming.next().await { + match message { + ServerMessage::InitialHandshakePacket(message) => { + Ok(message) + }, + _ => Err(failure::err_msg("Incorrect First Packet")), + } + } else { + Err(failure::err_msg("Failed to connect")) + }?; + + conn.server_capabilities = init_packet.capabilities; let handshake = HandshakeResponsePacket { - server_capabilities: init_packet.capabilities, - sequence_number: 1, + // Minimum client capabilities required to establish connection capabilities: Capabilities::CLIENT_PROTOCOL_41, max_packet_size: 1024, collation: 0, diff --git a/mason-mariadb/src/connection/mod.rs b/mason-mariadb/src/connection/mod.rs index f46e86db..4aeefab4 100644 --- a/mason-mariadb/src/connection/mod.rs +++ b/mason-mariadb/src/connection/mod.rs @@ -1,6 +1,7 @@ use crate::protocol::{ client::Serialize, server::Message as ServerMessage, + server::Capabilities, server::InitialHandshakePacket, server::Deserialize }; @@ -15,7 +16,8 @@ use runtime::{net::TcpStream, task::JoinHandle}; use std::io; use failure::Error; use failure::err_msg; -use byteorder::{ByteOrder, LittleEndian}; +use byteorder::{ByteOrder, LittleEndian, WriteBytesExt}; +use crate::protocol::serialize::serialize_length; mod establish; // mod query; @@ -32,6 +34,12 @@ pub struct Connection { // MariaDB Connection ID connection_id: i32, + + // Sequence Number + sequence_number: u8, + + // Server Capabilities + server_capabilities: Capabilities, } impl Connection { @@ -46,6 +54,8 @@ impl Connection { receiver, incoming: rx, connection_id: -1, + sequence_number: 1, + server_capabilities: Capabilities::default(), }; establish::establish(&mut conn, options).await?; @@ -59,7 +69,20 @@ impl Connection { { self.wbuf.clear(); - message.serialize(&mut self.wbuf); + /* + Reserve space for packet header; Packet Body Length (3 bytes) and sequence number (1 byte) + `self.wbuf.write_u32::(0_u32);` + causes compiler to panic + self.wbuf.write + rustc 1.37.0-nightly (7cdaffd79 2019-06-05) running on x86_64-unknown-linux-gnu + https://github.com/rust-lang/rust/issues/62126 + */ + self.wbuf.extend_from_slice(&[0, 0, 0, 0]); + + message.serialize(&mut self.wbuf, &self.server_capabilities)?; + serialize_length(&mut self.wbuf); + self.wbuf[3] = self.sequence_number; + self.sequence_number += 1; self.writer.write_all(&self.wbuf).await?; self.writer.flush().await?; diff --git a/mason-mariadb/src/protocol/client.rs b/mason-mariadb/src/protocol/client.rs index 6e353d07..f57290c3 100644 --- a/mason-mariadb/src/protocol/client.rs +++ b/mason-mariadb/src/protocol/client.rs @@ -11,15 +11,58 @@ use super::server::Capabilities; use byteorder::{ByteOrder, LittleEndian, WriteBytesExt}; use bytes::Bytes; use crate::protocol::serialize::*; +use failure::Error; pub trait Serialize { - fn serialize(&self, buf: &mut Vec); + fn serialize(&self, buf: &mut Vec, server_capabilities: &Capabilities) -> Result<(), Error>; +} + +pub enum TextProtocol { + ComChangeUser = 0x11, + ComDebug = 0x0D, + ComInitDb = 0x02, + ComPing = 0x0e, + ComProcessKill = 0xC, + ComQuery = 0x03, + ComQuit = 0x01, + ComResetConnection = 0x1F, + ComSetOption = 0x1B, + ComShutdown = 0x0A, + ComSleep = 0x00, + ComStatistics = 0x09, +} + +#[derive(Clone, Copy)] +pub enum SetOptionOptions { + MySqlOptionMultiStatementsOn = 0x00, + MySqlOptionMultiStatementsOff = 0x01, +} + +#[derive(Clone, Copy)] +pub enum ShutdownOptions { + ShutdownDefault = 0x00 +} + +impl Into for TextProtocol { + fn into(self) -> u8 { + self as u8 + } +} + +impl Into for SetOptionOptions { + fn into(self) -> u16 { + self as u16 + } +} + +impl Into for ShutdownOptions { + fn into(self) -> u8 { + self as u8 + } } #[derive(Default, Debug)] pub struct SSLRequestPacket { - pub server_capabilities: Capabilities, - pub sequence_number: u8, pub capabilities: Capabilities, pub max_packet_size: u32, pub collation: u8, @@ -28,8 +71,6 @@ pub struct SSLRequestPacket { #[derive(Default, Debug)] pub struct HandshakeResponsePacket { - pub server_capabilities: Capabilities, - pub sequence_number: u8, pub capabilities: Capabilities, pub max_packet_size: u32, pub collation: u8, @@ -44,21 +85,135 @@ pub struct HandshakeResponsePacket { pub conn_attr: Option>, } +pub struct ComQuit(); +pub struct ComDebug(); +pub struct ComPing(); +pub struct ComResetConnection(); +pub struct ComStatistics(); +pub struct ComSleep(); + +pub struct ComInitDb { + pub schema_name: Bytes +} + +pub struct ComProcessKill { + pub process_id: u32 +} + +pub struct ComQuery { + pub sql_statement: Bytes +} + +pub struct ComSetOption { + pub option: SetOptionOptions +} + +pub struct ComShutdown { + pub option: ShutdownOptions +} + #[derive(Default, Debug)] pub struct AuthenticationSwitchRequestPacket { - pub sequence_number: u8, pub auth_plugin_name: Bytes, pub auth_plugin_data: Bytes, } -impl Serialize for SSLRequestPacket { - fn serialize(&self, buf: &mut Vec) { - // Temporary storage for length: 3 bytes - buf.write_u24::(0); - // Sequence Number - serialize_int_1(buf, self.sequence_number); +impl Serialize for ComQuit { + fn serialize(&self, buf: &mut Vec, _server_capabilities: &Capabilities) -> Result<(), Error> { + serialize_int_1(buf, TextProtocol::ComQuit.into()); - // Packet body + Ok(()) + } +} + +impl Serialize for ComInitDb { + fn serialize(&self, buf: &mut Vec, _server_capabilities: &Capabilities) -> Result<(), Error> { + serialize_int_1(buf, TextProtocol::ComInitDb.into()); + serialize_string_null(buf, &self.schema_name); + + Ok(()) + } +} + +impl Serialize for ComDebug { + fn serialize(&self, buf: &mut Vec, _server_capabilities: &Capabilities) -> Result<(), Error> { + serialize_int_1(buf, TextProtocol::ComDebug.into()); + + Ok(()) + } +} + +impl Serialize for ComPing { + fn serialize(&self, buf: &mut Vec, _server_capabilities: &Capabilities) -> Result<(), Error> { + serialize_int_1(buf, TextProtocol::ComPing.into()); + + Ok(()) + } +} + +impl Serialize for ComProcessKill { + fn serialize(&self, buf: &mut Vec, _server_capabilities: &Capabilities) -> Result<(), Error> { + serialize_int_1(buf, TextProtocol::ComProcessKill.into()); + serialize_int_4(buf, self.process_id); + + Ok(()) + } +} + +impl Serialize for ComQuery { + fn serialize(&self, buf: &mut Vec, _server_capabilities: &Capabilities) -> Result<(), Error> { + serialize_int_1(buf, TextProtocol::ComQuery.into()); + serialize_string_eof(buf, &self.sql_statement); + + Ok(()) + } +} + +impl Serialize for ComResetConnection { + fn serialize(&self, buf: &mut Vec, _server_capabilities: &Capabilities) -> Result<(), Error> { + serialize_int_1(buf, TextProtocol::ComResetConnection.into()); + + Ok(()) + } +} + +impl Serialize for ComSetOption { + fn serialize(&self, buf: &mut Vec, _server_capabilities: &Capabilities) -> Result<(), Error> { + serialize_int_1(buf, TextProtocol::ComSetOption.into()); + serialize_int_2(buf, self.option.into()); + + Ok(()) + } +} + +impl Serialize for ComShutdown { + fn serialize(&self, buf: &mut Vec, _server_capabilities: &Capabilities) -> Result<(), Error> { + serialize_int_1(buf, TextProtocol::ComShutdown.into()); + serialize_int_1(buf, self.option.into()); + + Ok(()) + } +} + +impl Serialize for ComSleep { + fn serialize(&self, buf: &mut Vec, _server_capabilities: &Capabilities) -> Result<(), Error> { + serialize_int_1(buf, TextProtocol::ComSleep.into()); + + Ok(()) + } +} + +impl Serialize for ComStatistics { + fn serialize(&self, buf: &mut Vec, _server_capabilities: &Capabilities) -> Result<(), Error> { + serialize_int_1(buf, TextProtocol::ComStatistics.into()); + + Ok(()) + } +} + + +impl Serialize for SSLRequestPacket { + fn serialize(&self, buf: &mut Vec, server_capabilities: &Capabilities) -> Result<(), Error> { serialize_int_4(buf, self.capabilities.bits() as u32); serialize_int_4(buf, self.max_packet_size); serialize_int_1(buf, self.collation); @@ -66,7 +221,7 @@ impl Serialize for SSLRequestPacket { // Filler serialize_byte_fix(buf, &Bytes::from_static(&[0u8; 19]), 19); - if !(self.server_capabilities & Capabilities::CLIENT_MYSQL).is_empty() && + if !(*server_capabilities & Capabilities::CLIENT_MYSQL).is_empty() && !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty() { if let Some(capabilities) = self.extended_capabilities { serialize_int_4(buf, capabilities.bits() as u32); @@ -75,19 +230,12 @@ impl Serialize for SSLRequestPacket { serialize_byte_fix(buf, &Bytes::from_static(&[0u8;4]), 4); } - // Set packet length - serialize_length(buf); + Ok(()) } } impl Serialize for HandshakeResponsePacket { - fn serialize(&self, buf: &mut Vec) { - // Temporary storage for length: 3 bytes - buf.write_u24::(0); - // Sequence Number - serialize_int_1(buf, self.sequence_number); - - // Packet body + fn serialize(&self, buf: &mut Vec, server_capabilities: &Capabilities) -> Result<(), Error> { serialize_int_4(buf, self.capabilities.bits() as u32); serialize_int_4(buf, self.max_packet_size); serialize_int_1(buf, self.collation); @@ -95,7 +243,7 @@ impl Serialize for HandshakeResponsePacket { // Filler serialize_byte_fix(buf, &Bytes::from_static(&[0u8; 19]), 19); - if !(self.server_capabilities & Capabilities::CLIENT_MYSQL).is_empty() && + if !(*server_capabilities & Capabilities::CLIENT_MYSQL).is_empty() && !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty() { if let Some(capabilities) = self.extended_capabilities { serialize_int_4(buf, capabilities.bits() as u32); @@ -106,11 +254,11 @@ impl Serialize for HandshakeResponsePacket { serialize_string_null(buf, &self.username); - if !(self.server_capabilities & Capabilities::PLUGIN_AUTH_LENENC_CLIENT_DATA).is_empty() { + if !(*server_capabilities & Capabilities::PLUGIN_AUTH_LENENC_CLIENT_DATA).is_empty() { if let Some(auth_data) = &self.auth_data { serialize_string_lenenc(buf, &auth_data); } - } else if !(self.server_capabilities & Capabilities::SECURE_CONNECTION).is_empty() { + } else if !(*server_capabilities & Capabilities::SECURE_CONNECTION).is_empty() { if let Some(auth_response) = &self.auth_response { serialize_int_1(buf, self.auth_response_len.unwrap()); serialize_string_fix(buf, &auth_response, self.auth_response_len.unwrap() as usize); @@ -119,21 +267,21 @@ impl Serialize for HandshakeResponsePacket { serialize_int_1(buf, 0); } - if !(self.server_capabilities & Capabilities::CONNECT_WITH_DB).is_empty() { + if !(*server_capabilities & Capabilities::CONNECT_WITH_DB).is_empty() { if let Some(database) = &self.database { // string serialize_string_null(buf, &database); } } - if !(self.server_capabilities & Capabilities::PLUGIN_AUTH).is_empty() { + if !(*server_capabilities & Capabilities::PLUGIN_AUTH).is_empty() { if let Some(auth_plugin_name) = &self.auth_plugin_name { // string serialize_string_null(buf, &auth_plugin_name); } } - if !(self.server_capabilities & Capabilities::CONNECT_ATTRS).is_empty() { + if !(*server_capabilities & Capabilities::CONNECT_ATTRS).is_empty() { if let (Some(conn_attr_len), Some(conn_attr)) = (&self.conn_attr_len, &self.conn_attr) { // int serialize_int_lenenc(buf, Some(conn_attr_len)); @@ -146,24 +294,16 @@ impl Serialize for HandshakeResponsePacket { } } - // Set packet length - serialize_length(buf); + Ok(()) } } impl Serialize for AuthenticationSwitchRequestPacket { - fn serialize(&self, buf: &mut Vec) { - // Temporary storage for length: 3 bytes - buf.write_u24::(0); - // Sequence Number - serialize_int_1(buf, self.sequence_number); - - // Packet body + fn serialize(&self, buf: &mut Vec, _server_capabilities: &Capabilities) -> Result<(), Error> { serialize_int_1(buf, 0xFE); serialize_string_null(buf, &self.auth_plugin_name); serialize_byte_eof(buf, &self.auth_plugin_data); - // Set packet length - serialize_length(buf); + Ok(()) } }