Update client messages to use serialization functions

This commit is contained in:
Daniel Akhterov 2019-06-23 20:57:21 -07:00 committed by Daniel Akhterov
parent eb3cd5f4ee
commit 3af0dc08d5
3 changed files with 85 additions and 123 deletions

View File

@ -55,31 +55,27 @@ impl Serialize for SSLRequestPacket {
fn serialize(&self, buf: &mut Vec<u8>) { fn serialize(&self, buf: &mut Vec<u8>) {
// Temporary storage for length: 3 bytes // Temporary storage for length: 3 bytes
buf.write_u24::<LittleEndian>(0); buf.write_u24::<LittleEndian>(0);
// Sequence Number
serialize_int_1(buf, self.sequence_number);
// Sequence Numer // Packet body
buf.push(self.sequence_number); serialize_int_4(buf, self.capabilities.bits() as u32);
serialize_int_4(buf, self.max_packet_size);
serialize_int_1(buf, self.collation);
LittleEndian::write_u32(buf, self.capabilities.bits() as u32); // Filler
serialize_byte_fix(buf, &Bytes::from_static(&[0u8; 19]), 19);
LittleEndian::write_u32(buf, self.max_packet_size); if !(self.server_capabilities & Capabilities::CLIENT_MYSQL).is_empty() &&
!(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty() {
buf.push(self.collation);
buf.extend_from_slice(&[0u8;19]);
if !(self.server_capabilities & Capabilities::CLIENT_MYSQL).is_empty() {
if let Some(capabilities) = self.extended_capabilities { if let Some(capabilities) = self.extended_capabilities {
LittleEndian::write_u32(buf, capabilities.bits() as u32); serialize_int_4(buf, capabilities.bits() as u32);
} }
} else { } else {
buf.extend_from_slice(&[0u8;4]); serialize_byte_fix(buf, &Bytes::from_static(&[0u8;4]), 4);
} }
// Get length in little endian bytes // Set packet length
// packet length = byte[0] + (byte[1]<<8) + (byte[2]<<16)
buf[0] = buf.len().to_le_bytes()[0];
buf[1] = buf.len().to_le_bytes()[1];
buf[2] = buf.len().to_le_bytes()[2];
serialize_length(buf); serialize_length(buf);
} }
} }
@ -87,121 +83,87 @@ impl Serialize for SSLRequestPacket {
impl Serialize for HandshakeResponsePacket { impl Serialize for HandshakeResponsePacket {
fn serialize(&self, buf: &mut Vec<u8>) { fn serialize(&self, buf: &mut Vec<u8>) {
// Temporary storage for length: 3 bytes // Temporary storage for length: 3 bytes
buf.push(0); buf.write_u24::<LittleEndian>(0);
buf.push(0); // Sequence Number
buf.push(0); serialize_int_1(buf, self.sequence_number);
// Sequence Numer // Packet body
buf.push(self.sequence_number); serialize_int_4(buf, self.capabilities.bits() as u32);
serialize_int_4(buf, self.max_packet_size);
serialize_int_1(buf, self.collation);
LittleEndian::write_u32(buf, self.capabilities.bits() as u32); // Filler
serialize_byte_fix(buf, &Bytes::from_static(&[0u8; 19]), 19);
LittleEndian::write_u32(buf, self.max_packet_size); if !(self.server_capabilities & Capabilities::CLIENT_MYSQL).is_empty() &&
!(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty() {
buf.push(self.collation);
buf.extend_from_slice(&[0u8;19]);
if !(self.server_capabilities & Capabilities::CLIENT_MYSQL).is_empty() {
if let Some(capabilities) = self.extended_capabilities { if let Some(capabilities) = self.extended_capabilities {
LittleEndian::write_u32(buf, capabilities.bits() as u32); serialize_int_4(buf, capabilities.bits() as u32);
} }
} else { } else {
buf.extend_from_slice(&[0u8;4]); serialize_byte_fix(buf, &Bytes::from_static(&[0u8;4]), 4);
} }
// Username: string<NUL> serialize_string_null(buf, &self.username);
buf.extend_from_slice(&self.username);
buf.push(0);
if !(self.server_capabilities & Capabilities::PLUGIN_AUTH_LENENC_CLIENT_DATA).is_empty() { if !(self.server_capabilities & Capabilities::PLUGIN_AUTH_LENENC_CLIENT_DATA).is_empty() {
if let Some(auth_data) = &self.auth_data { if let Some(auth_data) = &self.auth_data {
// string<lenenc> serialize_string_lenenc(buf, &auth_data);
buf.push(auth_data.len().to_le_bytes()[0]);
buf.push(auth_data.len().to_le_bytes()[1]);
buf.push(auth_data.len().to_le_bytes()[2]);
buf.extend_from_slice(&auth_data);
} }
} else if !(self.server_capabilities & Capabilities::SECURE_CONNECTION).is_empty() { } else if !(self.server_capabilities & Capabilities::SECURE_CONNECTION).is_empty() {
if let Some(auth_response) = &self.auth_response { if let Some(auth_response) = &self.auth_response {
buf.push(self.auth_response_len.unwrap()); serialize_int_1(buf, self.auth_response_len.unwrap());
buf.extend_from_slice(&auth_response); serialize_string_fix(buf, &auth_response, self.auth_response_len.unwrap() as usize);
} }
} else { } else {
buf.push(0); serialize_int_1(buf, 0);
} }
if !(self.server_capabilities & Capabilities::CONNECT_WITH_DB).is_empty() { if !(self.server_capabilities & Capabilities::CONNECT_WITH_DB).is_empty() {
if let Some(database) = &self.database { if let Some(database) = &self.database {
// string<NUL> // string<NUL>
buf.extend_from_slice(&database); serialize_string_null(buf, &database);
buf.push(0);
} }
} }
if !(self.server_capabilities & Capabilities::PLUGIN_AUTH).is_empty() { if !(self.server_capabilities & Capabilities::PLUGIN_AUTH).is_empty() {
if let Some(auth_plugin_name) = &self.auth_plugin_name { if let Some(auth_plugin_name) = &self.auth_plugin_name {
// string<NUL> // string<NUL>
buf.extend_from_slice(&auth_plugin_name); serialize_string_null(buf, &auth_plugin_name);
buf.push(0);
} }
} }
if !(self.server_capabilities & Capabilities::CONNECT_ATTRS).is_empty() { if !(self.server_capabilities & Capabilities::CONNECT_ATTRS).is_empty() {
if let (Some(conn_attr_len), Some(conn_attr)) = (&self.conn_attr_len, &self.conn_attr) { if let (Some(conn_attr_len), Some(conn_attr)) = (&self.conn_attr_len, &self.conn_attr) {
// int<lenenc> // int<lenenc>
buf.push(conn_attr_len.to_le_bytes().len().to_le_bytes()[0]); serialize_int_lenenc(buf, Some(conn_attr_len));
buf.extend_from_slice(&conn_attr_len.to_le_bytes());
// Loop // Loop
for (key, value) in conn_attr { for (key, value) in conn_attr {
// string<lenenc> serialize_string_lenenc(buf, &key);
buf.push(key.len().to_le_bytes()[0]); serialize_string_lenenc(buf, &value);
buf.push(key.len().to_le_bytes()[1]);
buf.push(key.len().to_le_bytes()[2]);
buf.extend_from_slice(&key);
// string<lenenc>
buf.push(value.len().to_le_bytes()[0]);
buf.push(value.len().to_le_bytes()[1]);
buf.push(value.len().to_le_bytes()[2]);
buf.extend_from_slice(&value);
} }
} }
} }
// Get length in little endian bytes // Set packet length
// packet length = byte[0] + (byte[1]<<8) + (byte[2]<<16) serialize_length(buf);
buf[0] = buf.len().to_le_bytes()[0];
buf[1] = buf.len().to_le_bytes()[1];
buf[2] = buf.len().to_le_bytes()[2];
} }
} }
impl Serialize for AuthenticationSwitchRequestPacket { impl Serialize for AuthenticationSwitchRequestPacket {
fn serialize(&self, buf: &mut Vec<u8>) { fn serialize(&self, buf: &mut Vec<u8>) {
// Temporary storage for length: 3 bytes // Temporary storage for length: 3 bytes
buf.push(0); buf.write_u24::<LittleEndian>(0);
buf.push(0); // Sequence Number
buf.push(0); serialize_int_1(buf, self.sequence_number);
// Sequence Numer // Packet body
buf.push(self.sequence_number); serialize_int_1(buf, 0xFE);
serialize_string_null(buf, &self.auth_plugin_name);
serialize_byte_eof(buf, &self.auth_plugin_data);
// Authentication Switch Request Header // Set packet length
// int<1> serialize_length(buf);
buf.push(0xFE);
// string<NUL>
buf.extend_from_slice(&self.auth_plugin_name);
buf.push(0);
buf.extend_from_slice(&self.auth_plugin_data);
// Get length in little endian bytes
// packet length = byte[0] + (byte[1]<<8) + (byte[2]<<16)
buf[0] = buf.len().to_le_bytes()[0];
buf[1] = buf.len().to_le_bytes()[1];
buf[2] = buf.len().to_le_bytes()[2];
} }
} }

View File

@ -49,20 +49,20 @@ pub fn serialize_int_1(buf: &mut Vec<u8>, value: u8) {
} }
#[inline] #[inline]
pub fn serialize_int_lenenc(buf: &mut Vec<u8>, value: Option<usize>) { pub fn serialize_int_lenenc(buf: &mut Vec<u8>, value: Option<&usize>) {
if let Some(value) = value { if let Some(value) = value {
if value > U24_MAX && value <= std::u64::MAX as usize{ if *value > U24_MAX && *value <= std::u64::MAX as usize{
buf.write_u8(0xFE); buf.write_u8(0xFE);
serialize_int_8(buf, value as u64); serialize_int_8(buf, *value as u64);
} else if value > std::u16::MAX as usize && value <= U24_MAX { } else if *value > std::u16::MAX as usize && *value <= U24_MAX {
buf.write_u8(0xFD); buf.write_u8(0xFD);
serialize_int_3(buf, value as u32); serialize_int_3(buf, *value as u32);
} else if value > std::u8::MAX as usize && value <= std::u16::MAX as usize{ } else if *value > std::u8::MAX as usize && *value <= std::u16::MAX as usize{
buf.write_u8(0xFC); buf.write_u8(0xFC);
serialize_int_2(buf, value as u16); serialize_int_2(buf, *value as u16);
} else if value >= 0 && value <= std::u8::MAX as usize { } else if *value >= 0 && *value <= std::u8::MAX as usize {
buf.write_u8(0xFA); buf.write_u8(0xFA);
serialize_int_1(buf, value as u8); serialize_int_1(buf, *value as u8);
} else { } else {
panic!("Value is too long"); panic!("Value is too long");
} }
@ -72,35 +72,35 @@ pub fn serialize_int_lenenc(buf: &mut Vec<u8>, value: Option<usize>) {
} }
#[inline] #[inline]
pub fn serialize_string_lenenc(buf: &mut Vec<u8>, string: &'static str) { pub fn serialize_string_lenenc(buf: &mut Vec<u8>, string: &Bytes) {
if string.len() > 0xFFF { if string.len() > 0xFFF {
panic!("String inside string lenenc serialization is too long"); panic!("String inside string lenenc serialization is too long");
} }
serialize_int_3(buf, string.len() as u32); serialize_int_3(buf, string.len() as u32);
if string.len() > 0 { if string.len() > 0 {
buf.extend_from_slice(string.as_bytes()); buf.extend_from_slice(string);
} }
} }
#[inline] #[inline]
pub fn serialize_string_fix(buf: &mut Vec<u8>, string: &'static str, size: usize) { pub fn serialize_string_null(buf: &mut Vec<u8>, string: &Bytes) {
if size != string.len() { buf.extend_from_slice(string);
panic!("Sizes do not match");
}
buf.extend_from_slice(string.as_bytes());
}
#[inline]
pub fn serialize_string_null(buf: &mut Vec<u8>, string: &'static str) {
buf.extend_from_slice(string.as_bytes());
buf.write_u8(0); buf.write_u8(0);
} }
#[inline] #[inline]
pub fn serialize_string_eof(buf: &mut Vec<u8>, string: &'static str) { pub fn serialize_string_fix(buf: &mut Vec<u8>, bytes: &Bytes, size: usize) {
// Ignore the null terminator if size != bytes.len() {
buf.extend_from_slice(string.as_bytes()); panic!("Sizes do not match");
}
buf.extend_from_slice(bytes);
}
#[inline]
pub fn serialize_string_eof(buf: &mut Vec<u8>, bytes: &Bytes) {
buf.extend_from_slice(bytes);
} }
#[inline] #[inline]
@ -175,7 +175,7 @@ mod tests {
#[test] #[test]
fn it_encodes_int_lenenc_u8() { fn it_encodes_int_lenenc_u8() {
let mut buf: Vec<u8> = Vec::new(); let mut buf: Vec<u8> = Vec::new();
serialize_int_lenenc(&mut buf, Some(std::u8::MAX as usize)); serialize_int_lenenc(&mut buf, Some(&(std::u8::MAX as usize)));
assert_eq!(buf, b"\xFA\xFF".to_vec()); assert_eq!(buf, b"\xFA\xFF".to_vec());
} }
@ -183,7 +183,7 @@ mod tests {
#[test] #[test]
fn it_encodes_int_lenenc_u16() { fn it_encodes_int_lenenc_u16() {
let mut buf: Vec<u8> = Vec::new(); let mut buf: Vec<u8> = Vec::new();
serialize_int_lenenc(&mut buf, Some(std::u16::MAX as usize)); serialize_int_lenenc(&mut buf, Some(&(std::u16::MAX as usize)));
assert_eq!(buf, b"\xFC\xFF\xFF".to_vec()); assert_eq!(buf, b"\xFC\xFF\xFF".to_vec());
} }
@ -191,7 +191,7 @@ mod tests {
#[test] #[test]
fn it_encodes_int_lenenc_u24() { fn it_encodes_int_lenenc_u24() {
let mut buf: Vec<u8> = Vec::new(); let mut buf: Vec<u8> = Vec::new();
serialize_int_lenenc(&mut buf, Some(U24_MAX)); serialize_int_lenenc(&mut buf, Some(&U24_MAX));
assert_eq!(buf, b"\xFD\xFF\xFF\xFF".to_vec()); assert_eq!(buf, b"\xFD\xFF\xFF\xFF".to_vec());
} }
@ -199,7 +199,7 @@ mod tests {
#[test] #[test]
fn it_encodes_int_lenenc_u64() { fn it_encodes_int_lenenc_u64() {
let mut buf: Vec<u8> = Vec::new(); let mut buf: Vec<u8> = Vec::new();
serialize_int_lenenc(&mut buf, Some(std::u64::MAX as usize)); serialize_int_lenenc(&mut buf, Some(&(std::u64::MAX as usize)));
assert_eq!(buf, b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF".to_vec()); assert_eq!(buf, b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF".to_vec());
} }
@ -251,7 +251,7 @@ mod tests {
#[test] #[test]
fn it_encodes_string_lenenc() { fn it_encodes_string_lenenc() {
let mut buf: Vec<u8> = Vec::new(); let mut buf: Vec<u8> = Vec::new();
serialize_string_lenenc(&mut buf, "random_string"); serialize_string_lenenc(&mut buf, &Bytes::from_static(b"random_string"));
assert_eq!(buf, b"\x0D\x00\x00random_string".to_vec()); assert_eq!(buf, b"\x0D\x00\x00random_string".to_vec());
} }
@ -259,7 +259,7 @@ mod tests {
#[test] #[test]
fn it_encodes_string_fix() { fn it_encodes_string_fix() {
let mut buf: Vec<u8> = Vec::new(); let mut buf: Vec<u8> = Vec::new();
serialize_string_fix(&mut buf, "random_string", 13); serialize_string_fix(&mut buf, &Bytes::from_static(b"random_string"), 13);
assert_eq!(buf, b"random_string".to_vec()); assert_eq!(buf, b"random_string".to_vec());
} }
@ -267,7 +267,7 @@ mod tests {
#[test] #[test]
fn it_encodes_string_null() { fn it_encodes_string_null() {
let mut buf: Vec<u8> = Vec::new(); let mut buf: Vec<u8> = Vec::new();
serialize_string_null(&mut buf, "random_string"); serialize_string_null(&mut buf, &Bytes::from_static(b"random_string"));
assert_eq!(buf, b"random_string\0".to_vec()); assert_eq!(buf, b"random_string\0".to_vec());
} }
@ -276,7 +276,7 @@ mod tests {
#[test] #[test]
fn it_encodes_string_eof() { fn it_encodes_string_eof() {
let mut buf: Vec<u8> = Vec::new(); let mut buf: Vec<u8> = Vec::new();
serialize_string_eof(&mut buf, "random_string"); serialize_string_eof(&mut buf, &Bytes::from_static(b"random_string"));
assert_eq!(buf, b"random_string".to_vec()); assert_eq!(buf, b"random_string".to_vec());
} }

View File

@ -19,11 +19,11 @@ pub enum Message {
bitflags! { bitflags! {
pub struct Capabilities: u128 { pub struct Capabilities: u128 {
const CLIENT_MYSQL = 1; const CLIENT_MYSQL = 1;
const FOUND_ROWS = 2; const FOUND_ROWS = 1 << 1;
const CONNECT_WITH_DB = 8; const CONNECT_WITH_DB = 1 << 3;
const COMPRESS = 32; const COMPRESS = 1 << 5;
const LOCAL_FILES = 128; const LOCAL_FILES = 1 << 7;
const IGNORE_SPACE = 256; const IGNORE_SPACE = 1 << 8;
const CLIENT_PROTOCOL_41 = 1 << 9; const CLIENT_PROTOCOL_41 = 1 << 9;
const CLIENT_INTERACTIVE = 1 << 10; const CLIENT_INTERACTIVE = 1 << 10;
const SSL = 1 << 11; const SSL = 1 << 11;