diff --git a/mason-mariadb/src/protocol/client.rs b/mason-mariadb/src/protocol/client.rs index 985f58c2..e82dd149 100644 --- a/mason-mariadb/src/protocol/client.rs +++ b/mason-mariadb/src/protocol/client.rs @@ -8,6 +8,5 @@ pub struct StartupMessage<'a> { } impl<'a> Serialize for StartupMessage<'a> { - fn serialize(&self, buf: &mut Vec) { - } + fn serialize(&self, buf: &mut Vec) {} } diff --git a/mason-mariadb/src/protocol/server.rs b/mason-mariadb/src/protocol/server.rs index 9beac954..0a254236 100644 --- a/mason-mariadb/src/protocol/server.rs +++ b/mason-mariadb/src/protocol/server.rs @@ -1,7 +1,6 @@ +use byteorder::{ByteOrder, LittleEndian}; use failure::Error; use std::iter::FromIterator; -use byteorder::LittleEndian; -use byteorder::ByteOrder; pub trait Deserialize: Sized { fn deserialize(buf: &mut Vec) -> Result; @@ -14,44 +13,44 @@ pub enum Message { } pub enum Capabilities { - ClientMysql = 1, - FoundRows = 2, - ConnectWithDb = 8, - Compress = 32, - LocalFiles = 128, - IgnroeSpace = 256, - ClientProtocol41 = 1 << 9, - ClientInteractive = 1 << 10, - SSL = 1 << 11, - Transactions = 1 << 12, - SecureConnection = 1 << 13, - MultiStatements = 1 << 16, - MultiResults = 1 << 17, - PsMultiResults = 1 << 18, - PluginAuth = 1 << 19, - ConnectAttrs = 1 << 20, - PluginAuthLenencClientData = 1 << 21, - ClientSessionTrack = 1 << 23, - ClientDeprecateEof = 1 << 24, - MariaDbClientProgress = 1 << 32, - MariaDbClientComMulti = 1 << 33, + ClientMysql = 1, + FoundRows = 2, + ConnectWithDb = 8, + Compress = 32, + LocalFiles = 128, + IgnroeSpace = 256, + ClientProtocol41 = 1 << 9, + ClientInteractive = 1 << 10, + SSL = 1 << 11, + Transactions = 1 << 12, + SecureConnection = 1 << 13, + MultiStatements = 1 << 16, + MultiResults = 1 << 17, + PsMultiResults = 1 << 18, + PluginAuth = 1 << 19, + ConnectAttrs = 1 << 20, + PluginAuthLenencClientData = 1 << 21, + ClientSessionTrack = 1 << 23, + ClientDeprecateEof = 1 << 24, + MariaDbClientProgress = 1 << 32, + MariaDbClientComMulti = 1 << 33, MariaClientStmtBulkOperations = 1 << 34, } #[derive(Default, Debug)] pub struct InitialHandshakePacket { - pub protocol_version: u8, - pub server_version: String, - pub connection_id: u32, - pub auth_seed: String, - pub reserved: u8, - pub capabilities1: u16, - pub collation: u8, - pub status: u16, + pub protocol_version: u8, + pub server_version: String, + pub connection_id: u32, + pub auth_seed: String, + pub reserved: u8, + pub capabilities1: u16, + pub collation: u8, + pub status: u16, pub plugin_data_length: u8, - pub scramble2: Option, - pub reserved2: Option, - pub auth_plugin_name: Option, + pub scramble2: Option, + pub reserved2: Option, + pub auth_plugin_name: Option, } impl Deserialize for InitialHandshakePacket { @@ -69,12 +68,13 @@ impl Deserialize for InitialHandshakePacket { } null_index += 1; } - let server_version = String::from_iter(buf[index..null_index] - .iter() - .map(|b| char::from(b.clone())) - .collect::>() - .into_iter() - ); + let server_version = String::from_iter( + buf[index..null_index] + .iter() + .map(|b| char::from(b.clone())) + .collect::>() + .into_iter(), + ); // Script null character index = null_index + 1; @@ -83,12 +83,13 @@ impl Deserialize for InitialHandshakePacket { // Increment by index by 4 bytes since we read a u32 index += 4; - let auth_seed = String::from_iter(buf[index..index+8] - .iter() - .map(|b| char::from(b.clone())) - .collect::>() - .into_iter() - ); + let auth_seed = String::from_iter( + buf[index..index + 8] + .iter() + .map(|b| char::from(b.clone())) + .collect::>() + .into_iter(), + ); index += 8; // Skip reserved byte @@ -102,13 +103,13 @@ impl Deserialize for InitialHandshakePacket { let status = LittleEndian::read_u16(&buf[index..]); index += 2; - + capabilities |= LittleEndian::read_u16(&buf[index..]) as u32; index += 2; - let mut plugin_data_length = None; - if capabilities as u128 & Capabilities::PluginAuth as u128> 0 { - plugin_data_length = Some(buf[index] as u8); + let mut plugin_data_length = 0; + if capabilities as u128 & Capabilities::PluginAuth as u128 > 0 { + plugin_data_length = buf[index] as u8; } index += 1; @@ -124,8 +125,35 @@ impl Deserialize for InitialHandshakePacket { let mut auth_plugin_name: Option = None; if capabilities as u128 & Capabilities::SecureConnection as u128 > 0 { // TODO: scramble 2nd part. Length = max(12, plugin_data_length - 9) + let len = max(12, plugin_data_length - 9); + scramble2 = Some(String::from_iter( + buf[index..index + len] + .iter() + .map(|b| char::from(b.clone())) + .collect::>() + .into_iter(), + )); + // Skip length characters + the reserved byte + index += len + 1; } else { // TODO: auth_plugin_name null temrinated string + // Find index of null character + null_index = index; + loop { + if buf[null_index] == b'\0' { + break; + } + null_index += 1; + } + auth_plugin_name = Some(String::from_iter( + buf[index..null_index] + .iter() + .map(|b| char::from(b.clone())) + .collect::>() + .into_iter(), + )); + // Script null character + index = null_index + 1; } Ok(InitialHandshakePacket::default())