diff --git a/mason-mariadb/src/protocol/client.rs b/mason-mariadb/src/protocol/client.rs index 78691e80..ebbc8893 100644 --- a/mason-mariadb/src/protocol/client.rs +++ b/mason-mariadb/src/protocol/client.rs @@ -10,6 +10,7 @@ pub trait Serialize { #[derive(Default, Debug)] pub struct SSLRequestPacket { + pub sequence_number: u8, pub capabilities: Capabilities, pub max_packet_size: u32, pub collation: u8, @@ -18,13 +19,24 @@ pub struct SSLRequestPacket { impl Serialize for SSLRequestPacket { fn serialize(&self, buf: &mut Vec) { - // FIXME: Prepend length of packet in standard packet form // https://mariadb.com/kb/en/library/0-packet - // buf.push(32); + + // Temporary storage for length: 3 bytes + buf.push(0); + buf.push(0); + buf.push(0); + + // Sequence Numer + buf.push(0); + LittleEndian::write_u32(buf, self.capabilities.bits() as u32); + LittleEndian::write_u32(buf, self.max_packet_size); + buf.push(self.collation); + buf.extend_from_slice(&[0u8;19]); + if !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty() { if let Some(capabilities) = self.extended_capabilities { LittleEndian::write_u32(buf, capabilities.bits() as u32); @@ -32,5 +44,11 @@ impl Serialize for SSLRequestPacket { } else { buf.extend_from_slice(&[0u8;4]); } + + // Get length in little endian bytes + // packet length = byte[0] + (byte[1]<<8) + (byte[2]<<16) + buf[0] = buf.len().to_le_bytes()[0]; + buf[1] = buf.len().to_le_bytes()[1]; + buf[2] = buf.len().to_le_bytes()[2]; } } diff --git a/mason-mariadb/src/protocol/server.rs b/mason-mariadb/src/protocol/server.rs index ba6e7ca4..d4e075ee 100644 --- a/mason-mariadb/src/protocol/server.rs +++ b/mason-mariadb/src/protocol/server.rs @@ -1,8 +1,7 @@ // Reference: https://mariadb.com/kb/en/library/connection use byteorder::{ByteOrder, LittleEndian}; -use failure::Error; -use std::iter::FromIterator; +use failure::{Error, err_msg}; use bytes::Bytes; pub trait Deserialize: Sized { @@ -51,6 +50,8 @@ impl Default for Capabilities { #[derive(Default, Debug)] pub struct InitialHandshakePacket { + pub length: u32, + pub sequence_number: u8, pub protocol_version: u8, pub server_version: Bytes, pub connection_id: u32, @@ -66,7 +67,22 @@ pub struct InitialHandshakePacket { impl Deserialize for InitialHandshakePacket { fn deserialize(buf: &mut Vec) -> Result { let mut index = 0; - let protocol_version = buf[0] as u8; + + let length = (buf[0] + (buf[1]<<8) + (buf[2]<<16)) as u32; + index += 3; + + if buf.len() != length as usize { + return Err(err_msg("Lengths to do not match")); + } + + let sequence_number = buf[index]; + index += 1; + + if sequence_number != 0 { + return Err(err_msg("Squence Number of Initial Handshake Packet is not 0")); + } + + let protocol_version = buf[index] as u8; index += 1; let null_index = memchr::memchr(b'\0', &buf[index..]).unwrap(); @@ -119,6 +135,8 @@ impl Deserialize for InitialHandshakePacket { } Ok(InitialHandshakePacket { + length, + sequence_number, protocol_version, server_version, connection_id,