diff --git a/src/mariadb/protocol/client.rs b/src/mariadb/protocol/client.rs index 6a1812a0..adfcada3 100644 --- a/src/mariadb/protocol/client.rs +++ b/src/mariadb/protocol/client.rs @@ -7,8 +7,13 @@ // TODO: Handle different Capabilities for server and client // TODO: Handle when capability is set, but field is None -use super::packets::{com_set_option::SetOptionOptions, com_shutdown::ShutdownOptions}; +use super::packets::{SetOptionOptions, ShutdownOptions}; +// This is an enum of text protocol packet tags. +// Tags are the 5th byte of the packet (1st byte of packet body) +// and are used to determine which type of query was sent. +// The name of the enum variant represents the type of query, and +// the value is the byte value required by the server. pub enum TextProtocol { ComChangeUser = 0x11, ComDebug = 0x0D, @@ -24,20 +29,9 @@ pub enum TextProtocol { ComStatistics = 0x09, } +// Helper method to easily transform into u8 impl Into for TextProtocol { fn into(self) -> u8 { self as u8 } } - -impl Into for SetOptionOptions { - fn into(self) -> u16 { - self as u16 - } -} - -impl Into for ShutdownOptions { - fn into(self) -> u8 { - self as u8 - } -} diff --git a/src/mariadb/protocol/decode.rs b/src/mariadb/protocol/decode.rs index 9dfd2446..be8a630e 100644 --- a/src/mariadb/protocol/decode.rs +++ b/src/mariadb/protocol/decode.rs @@ -5,16 +5,21 @@ use failure::{err_msg, Error}; // Deserializing bytes and string do the same thing. Except that string also has a null terminated deserialzer use super::packets::packet_header::PacketHeader; +// This is a simple wrapper around Bytes to make decoding easier +// since the index is always tracked pub struct Decoder<'a> { pub buf: &'a Bytes, pub index: usize, } impl<'a> Decoder<'a> { + // Create a new Decoder from an existing Bytes pub fn new(buf: &'a Bytes) -> Self { Decoder { buf, index: 0 } } + // Decode length from a packet + // Length is the first 3 bytes of the packet in little endian format #[inline] pub fn decode_length(&mut self) -> Result { let length = self.decode_int_3(); @@ -26,6 +31,9 @@ impl<'a> Decoder<'a> { Ok(length) } + // Helper method to get the tag of the packet. The tag is the 5th byte in the packet. It's not guaranteed + // to exist or to be used for each packet. NOTE: Peeking at a tag DOES NOT increment index. This is used + // to determine which type of packet was received before attempting to decode. #[inline] pub fn peek_tag(&self) -> Option<&u8> { if self.buf.len() < self.index + 4 { @@ -35,6 +43,9 @@ impl<'a> Decoder<'a> { } } + // Helper method to get the packet header. The packet header consist of the length (3 bytes) and + // sequence number (1 byte). NOTE: Peeking a packet_header DOES NOT increment index. This is used + // to determine if the packet is read to decode without starting the decoding process. #[inline] pub fn peek_packet_header(&self) -> Result { let length: u32 = (self.buf[self.index] as u32) + ((self.buf[self.index + 1] as u32) << 8) + ((self.buf[self.index + 2] as u32) << 16); @@ -47,21 +58,22 @@ impl<'a> Decoder<'a> { Ok(PacketHeader { length, seq_no }) } + // Helper method to skip bytes via incrementing index. This is used because some packets have + // "unused" bytes. #[inline] pub fn skip_bytes(&mut self, amount: usize) { self.index += amount; } - #[inline] - pub fn eof(&self) -> bool { - self.buf.len() == self.index - } - - #[inline] - pub fn eof_byte(&self) -> bool { - self.buf[self.index] == 0xFE - } - + // Deocde an int which is a length encoded int. + // The first byte of the int determines the length of the int. + // If the first byte is + // 0xFB then the int is "NULL" or None in Rust terms. + // 0xFC then the following 2 bytes are the int value u16. + // 0xFD then the following 3 bytes are the int value u24. + // 0xFE then the following 8 bytes are teh int value u64. + // 0xFF then there was an error. + // If the first byte is not in the previous list then that byte is the int value. #[inline] pub fn decode_int_lenenc(&mut self) -> Option { match self.buf[self.index] { @@ -93,6 +105,7 @@ impl<'a> Decoder<'a> { } } + // Decode an int<8> which is a u64 #[inline] pub fn decode_int_8(&mut self) -> u64 { let value = LittleEndian::read_u64(&self.buf[self.index..]); @@ -100,6 +113,7 @@ impl<'a> Decoder<'a> { value } + // Decode an int<4> which is a u32 #[inline] pub fn decode_int_4(&mut self) -> u32 { let value = LittleEndian::read_u32(&self.buf[self.index..]); @@ -107,6 +121,7 @@ impl<'a> Decoder<'a> { value } + // Decode an int<3> which is a u24 #[inline] pub fn decode_int_3(&mut self) -> u32 { let value = LittleEndian::read_u24(&self.buf[self.index..]); @@ -114,6 +129,7 @@ impl<'a> Decoder<'a> { value } + // Decode an int<2> which is a u16 #[inline] pub fn decode_int_2(&mut self) -> u16 { let value = LittleEndian::read_u16(&self.buf[self.index..]); @@ -121,6 +137,7 @@ impl<'a> Decoder<'a> { value } + // Decode an int<1> which is a u8 #[inline] pub fn decode_int_1(&mut self) -> u8 { let value = self.buf[self.index]; @@ -128,6 +145,8 @@ impl<'a> Decoder<'a> { value } + // Decode a string which is a length encoded string. First decode an int to get + // the length of the string, and the the following n bytes are the contents. #[inline] pub fn decode_string_lenenc(&mut self) -> Bytes { let length = self.decode_int_lenenc().unwrap_or(0usize); @@ -136,6 +155,7 @@ impl<'a> Decoder<'a> { value } + // Decode a string which is a string of fixed length. #[inline] pub fn decode_string_fix(&mut self, length: u32) -> Bytes { let value = self.buf.slice(self.index, self.index + length as usize); @@ -143,6 +163,7 @@ impl<'a> Decoder<'a> { value } + // Decode a string which is a string which is terminated byte the end of the packet. #[inline] pub fn decode_string_eof(&mut self, length: Option) -> Bytes { let value = self.buf.slice(self.index, if let Some(len) = length { @@ -158,6 +179,7 @@ impl<'a> Decoder<'a> { value } + // Decode a string which is a null terminated string (C style string). #[inline] pub fn decode_string_null(&mut self) -> Result { if let Some(null_index) = memchr::memchr(0, &self.buf[self.index..]) { @@ -169,6 +191,7 @@ impl<'a> Decoder<'a> { } } + // Same as the string counter part, but copied to maintain consistency with the spec. #[inline] pub fn decode_byte_fix(&mut self, length: u32) -> Bytes { let value = self.buf.slice(self.index, self.index + length as usize); @@ -176,6 +199,7 @@ impl<'a> Decoder<'a> { value } + // Same as the string counter part, but copied to maintain consistency with the spec. #[inline] pub fn decode_byte_lenenc(&mut self) -> Bytes { let length = self.decode_int_1(); @@ -184,6 +208,7 @@ impl<'a> Decoder<'a> { value } + // Same as the string counter part, but copied to maintain consistency with the spec. #[inline] pub fn decode_byte_eof(&mut self, length: Option) -> Bytes { let value = self.buf.slice(self.index, if let Some(len) = length { diff --git a/src/mariadb/protocol/deserialize.rs b/src/mariadb/protocol/deserialize.rs index d72c77b4..22a48177 100644 --- a/src/mariadb/protocol/deserialize.rs +++ b/src/mariadb/protocol/deserialize.rs @@ -3,6 +3,10 @@ use crate::mariadb::connection::{ConnContext, Connection}; use bytes::Bytes; use failure::Error; +// A wrapper around a connection context to prevent +// deserializers from touching the stream, yet still have +// access to the connection context. +// Mainly used to simply to simplify number of parameters for deserializing functions pub struct DeContext<'a> { pub conn: &'a mut ConnContext, pub decoder: Decoder<'a>, diff --git a/src/mariadb/protocol/encode.rs b/src/mariadb/protocol/encode.rs index 0810c1b5..69ca9088 100644 --- a/src/mariadb/protocol/encode.rs +++ b/src/mariadb/protocol/encode.rs @@ -3,15 +3,18 @@ use bytes::{BufMut, Bytes, BytesMut}; const U24_MAX: usize = 0xFF_FF_FF; +// A simple wrapper around a BytesMut to easily encode values pub struct Encoder { pub buf: BytesMut, } impl Encoder { + // Create a new Encoder with a given capacity pub fn new(capacity: usize) -> Self { Encoder { buf: BytesMut::with_capacity(capacity) } } + // Clears the encoding buffer pub fn clear(&mut self) { self.buf.clear(); } @@ -22,11 +25,13 @@ impl Encoder { self.buf.extend_from_slice(&[0; 4]); } + // Encode the sequence number; the 4th byte of the packet #[inline] pub fn seq_no(&mut self, seq_no: u8) { self.buf[3] = seq_no; } + // Encode the sequence number; the first 3 bytes of the packet in little endian format #[inline] pub fn encode_length(&mut self) { let mut length = [0; 3]; @@ -45,46 +50,56 @@ impl Encoder { self.buf[2] = length[2]; } + // Encode a u64 as an int<8> #[inline] pub fn encode_int_8(&mut self, value: u64) { self.buf.extend_from_slice(&value.to_le_bytes()); } + // Encode a u32 as an int<4> #[inline] pub fn encode_int_4(&mut self, value: u32) { self.buf.extend_from_slice(&value.to_le_bytes()); } + // Encode a u32 (truncated to u24) as an int<3> #[inline] pub fn encode_int_3(&mut self, value: u32) { self.buf.extend_from_slice(&value.to_le_bytes()[0..3]); } + // Encode a u16 as an int<2> #[inline] pub fn encode_int_2(&mut self, value: u16) { self.buf.extend_from_slice(&value.to_le_bytes()); } + // Encode a u8 as an int<1> #[inline] pub fn encode_int_1(&mut self, value: u8) { self.buf.extend_from_slice(&value.to_le_bytes()); } + // Encode an int; length encoded int + // See Decoder::decode_int_lenenc for explanation of how int is encoded #[inline] pub fn encode_int_lenenc(&mut self, value: Option<&usize>) { if let Some(value) = value { if *value > U24_MAX && *value <= std::u64::MAX as usize { self.buf.put_u8(0xFE); self.encode_int_8(*value as u64); + } else if *value > std::u16::MAX as usize && *value <= U24_MAX { self.buf.put_u8(0xFD); self.encode_int_3(*value as u32); + } else if *value > std::u8::MAX as usize && *value <= std::u16::MAX as usize { self.buf.put_u8(0xFC); self.encode_int_2(*value as u16); + } else if *value <= std::u8::MAX as usize { - self.buf.put_u8(0xFA); self.encode_int_1(*value as u8); + } else { panic!("Value is too long"); } @@ -93,24 +108,27 @@ impl Encoder { } } + // Encode a string; a length encoded string. #[inline] pub fn encode_string_lenenc(&mut self, string: &Bytes) { if string.len() > 0xFFF { panic!("String inside string lenenc serialization is too long"); } - self.encode_int_3(string.len() as u32); + self.encode_int_lenenc(Some(&string.len())); if string.len() > 0 { self.buf.extend_from_slice(string); } } + // Encode a string; a null termianted string (C style) #[inline] pub fn encode_string_null(&mut self, string: &Bytes) { self.buf.extend_from_slice(string); self.buf.put(0_u8); } + // Encode a string; a string of fixed length #[inline] pub fn encode_string_fix(&mut self, bytes: &Bytes, size: usize) { if size != bytes.len() { @@ -120,11 +138,13 @@ impl Encoder { self.buf.extend_from_slice(bytes); } + // Encode a string; a string that is terminated by the packet length #[inline] pub fn encode_string_eof(&mut self, bytes: &Bytes) { self.buf.extend_from_slice(bytes); } + // Same as the string counterpart copied to maintain consistency with the spec. #[inline] pub fn encode_byte_lenenc(&mut self, bytes: &Bytes) { if bytes.len() > 0xFFF { @@ -135,6 +155,7 @@ impl Encoder { self.buf.extend_from_slice(bytes); } + // Same as the string counterpart copied to maintain consistency with the spec. #[inline] pub fn encode_byte_fix(&mut self, bytes: &Bytes, size: usize) { if size != bytes.len() { @@ -144,6 +165,7 @@ impl Encoder { self.buf.extend_from_slice(bytes); } + // Same as the string counterpart copied to maintain consistency with the spec. #[inline] pub fn encode_byte_eof(&mut self, bytes: &Bytes) { self.buf.extend_from_slice(bytes); diff --git a/src/mariadb/protocol/packets/com_set_option.rs b/src/mariadb/protocol/packets/com_set_option.rs index 3eef465f..b4ada05a 100644 --- a/src/mariadb/protocol/packets/com_set_option.rs +++ b/src/mariadb/protocol/packets/com_set_option.rs @@ -20,3 +20,10 @@ impl Serialize for ComSetOption { Ok(()) } } + +// Helper method to easily transform into u16 +impl Into for SetOptionOptions { + fn into(self) -> u16 { + self as u16 + } +} diff --git a/src/mariadb/protocol/packets/com_shutdown.rs b/src/mariadb/protocol/packets/com_shutdown.rs index 4c5925e0..f52a3741 100644 --- a/src/mariadb/protocol/packets/com_shutdown.rs +++ b/src/mariadb/protocol/packets/com_shutdown.rs @@ -19,3 +19,10 @@ impl Serialize for ComShutdown { Ok(()) } } + +// Helper method to easily transform into u8 +impl Into for ShutdownOptions { + fn into(self) -> u8 { + self as u8 + } +} diff --git a/src/mariadb/protocol/packets/mod.rs b/src/mariadb/protocol/packets/mod.rs index addbe9b0..e1998923 100644 --- a/src/mariadb/protocol/packets/mod.rs +++ b/src/mariadb/protocol/packets/mod.rs @@ -21,3 +21,29 @@ pub mod packet_header; pub mod result_set; pub mod ssl_request; pub mod result_row; + +pub use auth_switch_request::AuthenticationSwitchRequestPacket; +pub use column::ColumnPacket; +pub use column_def::ColumnDefPacket; +pub use com_debug::ComDebug; +pub use com_init_db::ComInitDb; +pub use com_ping::ComPing; +pub use com_process_kill::ComProcessKill; +pub use com_query::ComQuery; +pub use com_quit::ComQuit; +pub use com_reset_conn::ComResetConnection; +pub use com_set_option::ComSetOption; +pub use com_set_option::SetOptionOptions; +pub use com_shutdown::ShutdownOptions; +pub use com_shutdown::ComShutdown; +pub use com_sleep::ComSleep; +pub use com_statistics::ComStatistics; +pub use eof::EofPacket; +pub use err::ErrPacket; +pub use handshake_response::HandshakeResponsePacket; +pub use initial::InitialHandshakePacket; +pub use ok::OkPacket; +pub use packet_header::PacketHeader; +pub use result_set::ResultSet; +pub use result_row::ResultRow; +pub use ssl_request::SSLRequestPacket;