diff --git a/mason-mariadb/src/protocol/client.rs b/mason-mariadb/src/protocol/client.rs index 40e17241..6e353d07 100644 --- a/mason-mariadb/src/protocol/client.rs +++ b/mason-mariadb/src/protocol/client.rs @@ -55,31 +55,27 @@ impl Serialize for SSLRequestPacket { fn serialize(&self, buf: &mut Vec) { // Temporary storage for length: 3 bytes buf.write_u24::(0); + // Sequence Number + serialize_int_1(buf, self.sequence_number); - // Sequence Numer - buf.push(self.sequence_number); + // Packet body + serialize_int_4(buf, self.capabilities.bits() as u32); + serialize_int_4(buf, self.max_packet_size); + serialize_int_1(buf, self.collation); - LittleEndian::write_u32(buf, self.capabilities.bits() as u32); + // Filler + serialize_byte_fix(buf, &Bytes::from_static(&[0u8; 19]), 19); - LittleEndian::write_u32(buf, self.max_packet_size); - - buf.push(self.collation); - - buf.extend_from_slice(&[0u8;19]); - - if !(self.server_capabilities & Capabilities::CLIENT_MYSQL).is_empty() { + if !(self.server_capabilities & Capabilities::CLIENT_MYSQL).is_empty() && + !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty() { if let Some(capabilities) = self.extended_capabilities { - LittleEndian::write_u32(buf, capabilities.bits() as u32); + serialize_int_4(buf, capabilities.bits() as u32); } } else { - buf.extend_from_slice(&[0u8;4]); + serialize_byte_fix(buf, &Bytes::from_static(&[0u8;4]), 4); } - - // Get length in little endian bytes - // packet length = byte[0] + (byte[1]<<8) + (byte[2]<<16) - buf[0] = buf.len().to_le_bytes()[0]; - buf[1] = buf.len().to_le_bytes()[1]; - buf[2] = buf.len().to_le_bytes()[2]; + + // Set packet length serialize_length(buf); } } @@ -87,121 +83,87 @@ impl Serialize for SSLRequestPacket { impl Serialize for HandshakeResponsePacket { fn serialize(&self, buf: &mut Vec) { // Temporary storage for length: 3 bytes - buf.push(0); - buf.push(0); - buf.push(0); + buf.write_u24::(0); + // Sequence Number + serialize_int_1(buf, self.sequence_number); - // Sequence Numer - buf.push(self.sequence_number); + // Packet body + serialize_int_4(buf, self.capabilities.bits() as u32); + serialize_int_4(buf, self.max_packet_size); + serialize_int_1(buf, self.collation); - LittleEndian::write_u32(buf, self.capabilities.bits() as u32); + // Filler + serialize_byte_fix(buf, &Bytes::from_static(&[0u8; 19]), 19); - LittleEndian::write_u32(buf, self.max_packet_size); - - buf.push(self.collation); - - buf.extend_from_slice(&[0u8;19]); - - if !(self.server_capabilities & Capabilities::CLIENT_MYSQL).is_empty() { + if !(self.server_capabilities & Capabilities::CLIENT_MYSQL).is_empty() && + !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty() { if let Some(capabilities) = self.extended_capabilities { - LittleEndian::write_u32(buf, capabilities.bits() as u32); + serialize_int_4(buf, capabilities.bits() as u32); } } else { - buf.extend_from_slice(&[0u8;4]); + serialize_byte_fix(buf, &Bytes::from_static(&[0u8;4]), 4); } - // Username: string - buf.extend_from_slice(&self.username); - buf.push(0); + serialize_string_null(buf, &self.username); if !(self.server_capabilities & Capabilities::PLUGIN_AUTH_LENENC_CLIENT_DATA).is_empty() { if let Some(auth_data) = &self.auth_data { - // string - buf.push(auth_data.len().to_le_bytes()[0]); - buf.push(auth_data.len().to_le_bytes()[1]); - buf.push(auth_data.len().to_le_bytes()[2]); - buf.extend_from_slice(&auth_data); + serialize_string_lenenc(buf, &auth_data); } } else if !(self.server_capabilities & Capabilities::SECURE_CONNECTION).is_empty() { if let Some(auth_response) = &self.auth_response { - buf.push(self.auth_response_len.unwrap()); - buf.extend_from_slice(&auth_response); + serialize_int_1(buf, self.auth_response_len.unwrap()); + serialize_string_fix(buf, &auth_response, self.auth_response_len.unwrap() as usize); } } else { - buf.push(0); + serialize_int_1(buf, 0); } if !(self.server_capabilities & Capabilities::CONNECT_WITH_DB).is_empty() { if let Some(database) = &self.database { // string - buf.extend_from_slice(&database); - buf.push(0); + serialize_string_null(buf, &database); } } if !(self.server_capabilities & Capabilities::PLUGIN_AUTH).is_empty() { if let Some(auth_plugin_name) = &self.auth_plugin_name { // string - buf.extend_from_slice(&auth_plugin_name); - buf.push(0); + serialize_string_null(buf, &auth_plugin_name); } } if !(self.server_capabilities & Capabilities::CONNECT_ATTRS).is_empty() { if let (Some(conn_attr_len), Some(conn_attr)) = (&self.conn_attr_len, &self.conn_attr) { // int - buf.push(conn_attr_len.to_le_bytes().len().to_le_bytes()[0]); - buf.extend_from_slice(&conn_attr_len.to_le_bytes()); + serialize_int_lenenc(buf, Some(conn_attr_len)); // Loop for (key, value) in conn_attr { - // string - buf.push(key.len().to_le_bytes()[0]); - buf.push(key.len().to_le_bytes()[1]); - buf.push(key.len().to_le_bytes()[2]); - buf.extend_from_slice(&key); - - // string - buf.push(value.len().to_le_bytes()[0]); - buf.push(value.len().to_le_bytes()[1]); - buf.push(value.len().to_le_bytes()[2]); - buf.extend_from_slice(&value); + serialize_string_lenenc(buf, &key); + serialize_string_lenenc(buf, &value); } } } - // Get length in little endian bytes - // packet length = byte[0] + (byte[1]<<8) + (byte[2]<<16) - buf[0] = buf.len().to_le_bytes()[0]; - buf[1] = buf.len().to_le_bytes()[1]; - buf[2] = buf.len().to_le_bytes()[2]; + // Set packet length + serialize_length(buf); } } impl Serialize for AuthenticationSwitchRequestPacket { fn serialize(&self, buf: &mut Vec) { // Temporary storage for length: 3 bytes - buf.push(0); - buf.push(0); - buf.push(0); + buf.write_u24::(0); + // Sequence Number + serialize_int_1(buf, self.sequence_number); - // Sequence Numer - buf.push(self.sequence_number); + // Packet body + serialize_int_1(buf, 0xFE); + serialize_string_null(buf, &self.auth_plugin_name); + serialize_byte_eof(buf, &self.auth_plugin_data); - // Authentication Switch Request Header - // int<1> - buf.push(0xFE); - - // string - buf.extend_from_slice(&self.auth_plugin_name); - buf.push(0); - - buf.extend_from_slice(&self.auth_plugin_data); - - // Get length in little endian bytes - // packet length = byte[0] + (byte[1]<<8) + (byte[2]<<16) - buf[0] = buf.len().to_le_bytes()[0]; - buf[1] = buf.len().to_le_bytes()[1]; - buf[2] = buf.len().to_le_bytes()[2]; + // Set packet length + serialize_length(buf); } } diff --git a/mason-mariadb/src/protocol/serialize.rs b/mason-mariadb/src/protocol/serialize.rs index 15e04d08..85bcf06f 100644 --- a/mason-mariadb/src/protocol/serialize.rs +++ b/mason-mariadb/src/protocol/serialize.rs @@ -49,20 +49,20 @@ pub fn serialize_int_1(buf: &mut Vec, value: u8) { } #[inline] -pub fn serialize_int_lenenc(buf: &mut Vec, value: Option) { +pub fn serialize_int_lenenc(buf: &mut Vec, value: Option<&usize>) { if let Some(value) = value { - if value > U24_MAX && value <= std::u64::MAX as usize{ + if *value > U24_MAX && *value <= std::u64::MAX as usize{ buf.write_u8(0xFE); - serialize_int_8(buf, value as u64); - } else if value > std::u16::MAX as usize && value <= U24_MAX { + serialize_int_8(buf, *value as u64); + } else if *value > std::u16::MAX as usize && *value <= U24_MAX { buf.write_u8(0xFD); - serialize_int_3(buf, value as u32); - } else if value > std::u8::MAX as usize && value <= std::u16::MAX as usize{ + serialize_int_3(buf, *value as u32); + } else if *value > std::u8::MAX as usize && *value <= std::u16::MAX as usize{ buf.write_u8(0xFC); - serialize_int_2(buf, value as u16); - } else if value >= 0 && value <= std::u8::MAX as usize { + serialize_int_2(buf, *value as u16); + } else if *value >= 0 && *value <= std::u8::MAX as usize { buf.write_u8(0xFA); - serialize_int_1(buf, value as u8); + serialize_int_1(buf, *value as u8); } else { panic!("Value is too long"); } @@ -72,35 +72,35 @@ pub fn serialize_int_lenenc(buf: &mut Vec, value: Option) { } #[inline] -pub fn serialize_string_lenenc(buf: &mut Vec, string: &'static str) { +pub fn serialize_string_lenenc(buf: &mut Vec, string: &Bytes) { if string.len() > 0xFFF { panic!("String inside string lenenc serialization is too long"); } serialize_int_3(buf, string.len() as u32); if string.len() > 0 { - buf.extend_from_slice(string.as_bytes()); + buf.extend_from_slice(string); } } #[inline] -pub fn serialize_string_fix(buf: &mut Vec, string: &'static str, size: usize) { - if size != string.len() { - panic!("Sizes do not match"); - } - buf.extend_from_slice(string.as_bytes()); -} - -#[inline] -pub fn serialize_string_null(buf: &mut Vec, string: &'static str) { - buf.extend_from_slice(string.as_bytes()); +pub fn serialize_string_null(buf: &mut Vec, string: &Bytes) { + buf.extend_from_slice(string); buf.write_u8(0); } #[inline] -pub fn serialize_string_eof(buf: &mut Vec, string: &'static str) { - // Ignore the null terminator - buf.extend_from_slice(string.as_bytes()); +pub fn serialize_string_fix(buf: &mut Vec, bytes: &Bytes, size: usize) { + if size != bytes.len() { + panic!("Sizes do not match"); + } + + buf.extend_from_slice(bytes); +} + +#[inline] +pub fn serialize_string_eof(buf: &mut Vec, bytes: &Bytes) { + buf.extend_from_slice(bytes); } #[inline] @@ -175,7 +175,7 @@ mod tests { #[test] fn it_encodes_int_lenenc_u8() { let mut buf: Vec = Vec::new(); - serialize_int_lenenc(&mut buf, Some(std::u8::MAX as usize)); + serialize_int_lenenc(&mut buf, Some(&(std::u8::MAX as usize))); assert_eq!(buf, b"\xFA\xFF".to_vec()); } @@ -183,7 +183,7 @@ mod tests { #[test] fn it_encodes_int_lenenc_u16() { let mut buf: Vec = Vec::new(); - serialize_int_lenenc(&mut buf, Some(std::u16::MAX as usize)); + serialize_int_lenenc(&mut buf, Some(&(std::u16::MAX as usize))); assert_eq!(buf, b"\xFC\xFF\xFF".to_vec()); } @@ -191,7 +191,7 @@ mod tests { #[test] fn it_encodes_int_lenenc_u24() { let mut buf: Vec = Vec::new(); - serialize_int_lenenc(&mut buf, Some(U24_MAX)); + serialize_int_lenenc(&mut buf, Some(&U24_MAX)); assert_eq!(buf, b"\xFD\xFF\xFF\xFF".to_vec()); } @@ -199,7 +199,7 @@ mod tests { #[test] fn it_encodes_int_lenenc_u64() { let mut buf: Vec = Vec::new(); - serialize_int_lenenc(&mut buf, Some(std::u64::MAX as usize)); + serialize_int_lenenc(&mut buf, Some(&(std::u64::MAX as usize))); assert_eq!(buf, b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF".to_vec()); } @@ -251,7 +251,7 @@ mod tests { #[test] fn it_encodes_string_lenenc() { let mut buf: Vec = Vec::new(); - serialize_string_lenenc(&mut buf, "random_string"); + serialize_string_lenenc(&mut buf, &Bytes::from_static(b"random_string")); assert_eq!(buf, b"\x0D\x00\x00random_string".to_vec()); } @@ -259,7 +259,7 @@ mod tests { #[test] fn it_encodes_string_fix() { let mut buf: Vec = Vec::new(); - serialize_string_fix(&mut buf, "random_string", 13); + serialize_string_fix(&mut buf, &Bytes::from_static(b"random_string"), 13); assert_eq!(buf, b"random_string".to_vec()); } @@ -267,7 +267,7 @@ mod tests { #[test] fn it_encodes_string_null() { let mut buf: Vec = Vec::new(); - serialize_string_null(&mut buf, "random_string"); + serialize_string_null(&mut buf, &Bytes::from_static(b"random_string")); assert_eq!(buf, b"random_string\0".to_vec()); } @@ -276,7 +276,7 @@ mod tests { #[test] fn it_encodes_string_eof() { let mut buf: Vec = Vec::new(); - serialize_string_eof(&mut buf, "random_string"); + serialize_string_eof(&mut buf, &Bytes::from_static(b"random_string")); assert_eq!(buf, b"random_string".to_vec()); } diff --git a/mason-mariadb/src/protocol/server.rs b/mason-mariadb/src/protocol/server.rs index 9b0c3e73..0009d3e8 100644 --- a/mason-mariadb/src/protocol/server.rs +++ b/mason-mariadb/src/protocol/server.rs @@ -19,11 +19,11 @@ pub enum Message { bitflags! { pub struct Capabilities: u128 { const CLIENT_MYSQL = 1; - const FOUND_ROWS = 2; - const CONNECT_WITH_DB = 8; - const COMPRESS = 32; - const LOCAL_FILES = 128; - const IGNORE_SPACE = 256; + const FOUND_ROWS = 1 << 1; + const CONNECT_WITH_DB = 1 << 3; + const COMPRESS = 1 << 5; + const LOCAL_FILES = 1 << 7; + const IGNORE_SPACE = 1 << 8; const CLIENT_PROTOCOL_41 = 1 << 9; const CLIENT_INTERACTIVE = 1 << 10; const SSL = 1 << 11;