diff --git a/src/io/buf_stream.rs b/src/io/buf_stream.rs index 5d7775e8..fb8f3b83 100644 --- a/src/io/buf_stream.rs +++ b/src/io/buf_stream.rs @@ -89,7 +89,9 @@ where } } +// TODO: Find a nicer way to do this // Return `Ok(None)` immediately from a function if the wrapped value is `None` +#[allow(unused)] macro_rules! ret_if_none { ($val:expr) => { match $val { diff --git a/src/mariadb/io/buf_mut_ext.rs b/src/mariadb/io/buf_mut_ext.rs index 5f9f0a6f..654c426c 100644 --- a/src/mariadb/io/buf_mut_ext.rs +++ b/src/mariadb/io/buf_mut_ext.rs @@ -71,3 +71,179 @@ impl BufMutExt for Vec { self.extend_from_slice(val); } } + +#[cfg(test)] +mod tests { + use super::BufMutExt; + use crate::io::BufMut; + use byteorder::LittleEndian; + + // [X] it_encodes_int_lenenc_u64 + // [X] it_encodes_int_lenenc_u32 + // [X] it_encodes_int_lenenc_u24 + // [X] it_encodes_int_lenenc_u16 + // [X] it_encodes_int_lenenc_u8 + // [X] it_encodes_int_u64 + // [X] it_encodes_int_u32 + // [X] it_encodes_int_u24 + // [X] it_encodes_int_u16 + // [X] it_encodes_int_u8 + // [X] it_encodes_string_lenenc + // [X] it_encodes_string_fix + // [X] it_encodes_string_null + // [X] it_encodes_string_eof + // [X] it_encodes_byte_lenenc + // [X] it_encodes_byte_fix + // [X] it_encodes_byte_eof + + #[test] + fn it_encodes_int_lenenc_none() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc::(None); + + assert_eq!(&buf[..], b"\xFB"); + } + + #[test] + fn it_encodes_int_lenenc_u8() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc::(0xFA as u64); + + assert_eq!(&buf[..], b"\xFA"); + } + + #[test] + fn it_encodes_int_lenenc_u16() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc::(std::u16::MAX as u64); + + assert_eq!(&buf[..], b"\xFC\xFF\xFF"); + } + + #[test] + fn it_encodes_int_lenenc_u24() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc::(0xFF_FF_FF as u64); + + assert_eq!(&buf[..], b"\xFD\xFF\xFF\xFF"); + } + + #[test] + fn it_encodes_int_lenenc_u64() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc::(std::u64::MAX); + + assert_eq!(&buf[..], b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"); + } + + #[test] + fn it_encodes_int_lenenc_fb() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc::(0xFB as u64); + + assert_eq!(&buf[..], b"\xFC\xFB\x00"); + } + + #[test] + fn it_encodes_int_lenenc_fc() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc::(0xFC as u64); + + assert_eq!(&buf[..], b"\xFC\xFC\x00"); + } + + #[test] + fn it_encodes_int_lenenc_fd() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc::(0xFD as u64); + + assert_eq!(&buf[..], b"\xFC\xFD\x00"); + } + + #[test] + fn it_encodes_int_lenenc_fe() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc::(0xFE as u64); + + assert_eq!(&buf[..], b"\xFC\xFE\x00"); + } + + fn it_encodes_int_lenenc_ff() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc::(0xFF as u64); + + assert_eq!(&buf[..], b"\xFC\xFF\x00"); + } + + #[test] + fn it_encodes_int_u64() { + let mut buf = Vec::with_capacity(1024); + buf.put_u64::(std::u64::MAX); + + assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"); + } + + #[test] + fn it_encodes_int_u32() { + let mut buf = Vec::with_capacity(1024); + buf.put_u32::(std::u32::MAX); + + assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF"); + } + + #[test] + fn it_encodes_int_u24() { + let mut buf = Vec::with_capacity(1024); + buf.put_u24::(0xFF_FF_FF as u32); + + assert_eq!(&buf[..], b"\xFF\xFF\xFF"); + } + + #[test] + fn it_encodes_int_u16() { + let mut buf = Vec::with_capacity(1024); + buf.put_u16::(std::u16::MAX); + + assert_eq!(&buf[..], b"\xFF\xFF"); + } + + #[test] + fn it_encodes_int_u8() { + let mut buf = Vec::with_capacity(1024); + buf.put_u8(std::u8::MAX); + + assert_eq!(&buf[..], b"\xFF"); + } + + #[test] + fn it_encodes_string_lenenc() { + let mut buf = Vec::with_capacity(1024); + buf.put_str_lenenc::("random_string"); + + assert_eq!(&buf[..], b"\x0Drandom_string"); + } + + #[test] + fn it_encodes_string_fix() { + let mut buf = Vec::with_capacity(1024); + buf.put_str("random_string"); + + assert_eq!(&buf[..], b"random_string"); + } + + #[test] + fn it_encodes_string_null() { + let mut buf = Vec::with_capacity(1024); + buf.put_str_nul("random_string"); + + assert_eq!(&buf[..], b"random_string\0"); + } + + #[test] + fn it_encodes_byte_lenenc() { + let mut buf = Vec::with_capacity(1024); + buf.put_bytes_lenenc::(b"random_string"); + + assert_eq!(&buf[..], b"\x0Drandom_string"); + } +} diff --git a/src/mariadb/protocol/capabilities.rs b/src/mariadb/protocol/capabilities.rs new file mode 100644 index 00000000..e40da3de --- /dev/null +++ b/src/mariadb/protocol/capabilities.rs @@ -0,0 +1,65 @@ +// https://mariadb.com/kb/en/library/connection/#capabilities +bitflags::bitflags! { + pub struct Capabilities: u128 { + const CLIENT_MYSQL = 1; + const FOUND_ROWS = 2; + + // One can specify db on connect + const CONNECT_WITH_DB = 8; + + // Can use compression protocol + const COMPRESS = 32; + + // Can use LOAD DATA LOCAL + const LOCAL_FILES = 128; + + // Ignore spaces before '(' + const IGNORE_SPACE = 256; + + // 4.1+ protocol + const CLIENT_PROTOCOL_41 = 1 << 9; + + const CLIENT_INTERACTIVE = 1 << 10; + + // Can use SSL + const SSL = 1 << 11; + + const TRANSACTIONS = 1 << 12; + + // 4.1+ authentication + const SECURE_CONNECTION = 1 << 13; + + // Enable/disable multi-stmt support + const MULTI_STATEMENTS = 1 << 16; + + // Enable/disable multi-results + const MULTI_RESULTS = 1 << 17; + + // Enable/disable multi-results for PrepareStatement + const PS_MULTI_RESULTS = 1 << 18; + + // Client supports plugin authentication + const PLUGIN_AUTH = 1 << 19; + + // Client send connection attributes + const CONNECT_ATTRS = 1 << 20; + + // Enable authentication response packet to be larger than 255 bytes + const PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21; + + // Enable/disable session tracking in OK_Packet + const CLIENT_SESSION_TRACK = 1 << 23; + + // EOF_Packet deprecation + const CLIENT_DEPRECATE_EOF = 1 << 24; + + // Client support progress indicator (since 10.2) + const MARIA_DB_CLIENT_PROGRESS = 1 << 32; + + // Permit COM_MULTI protocol + const MARIA_DB_CLIENT_COM_MULTI = 1 << 33; + + // Permit bulk insert + const MARIA_CLIENT_STMT_BULK_OPERATIONS = 1 << 34; + } +} diff --git a/src/mariadb/protocol/connect/auth_switch_request.rs b/src/mariadb/protocol/connect/auth_switch_request.rs new file mode 100644 index 00000000..a386c4ed --- /dev/null +++ b/src/mariadb/protocol/connect/auth_switch_request.rs @@ -0,0 +1,21 @@ +use crate::{ + io::BufMut, + mariadb::{ + io::BufMutExt, + protocol::{Capabilities, Encode}, + }, +}; + +#[derive(Default, Debug)] +pub struct AuthenticationSwitchRequest<'a> { + pub auth_plugin_name: &'a str, + pub auth_plugin_data: &'a [u8], +} + +impl Encode for AuthenticationSwitchRequest<'_> { + fn encode(&self, buf: &mut Vec, _: Capabilities) { + buf.put_u8(0xFE); + buf.put_str_nul(&self.auth_plugin_name); + buf.put_bytes(&self.auth_plugin_data); + } +} diff --git a/src/mariadb/protocol/connect/initial.rs b/src/mariadb/protocol/connect/initial.rs new file mode 100644 index 00000000..9d78f7d6 --- /dev/null +++ b/src/mariadb/protocol/connect/initial.rs @@ -0,0 +1,164 @@ +use crate::{ + io::Buf, + mariadb::{ + io::BufExt, + protocol::{Capabilities, ServerStatusFlag}, + }, +}; +use byteorder::LittleEndian; +use std::io; + +#[derive(Debug)] +pub struct InitialHandshakePacket { + pub protocol_version: u8, + pub server_version: String, + pub server_status: ServerStatusFlag, + pub server_default_collation: u8, + pub connection_id: u32, + pub scramble: Box<[u8]>, + pub capabilities: Capabilities, + pub auth_plugin_name: Option, +} + +impl InitialHandshakePacket { + fn decode(mut buf: &[u8]) -> io::Result { + let protocol_version = buf.get_u8()?; + let server_version = buf.get_str_nul()?.to_owned(); + let connection_id = buf.get_u32::()?; + let mut scramble = Vec::with_capacity(8); + + // scramble 1st part (authentication seed) : string<8> + scramble.extend_from_slice(&buf[..8]); + buf.advance(8); + + // reserved : string<1> + buf.advance(1); + + // server capabilities (1st part) : int<2> + let capabilities_1 = buf.get_u16::()?; + let mut capabilities = Capabilities::from_bits_truncate(capabilities_1.into()); + + // server default collation : int<1> + let server_default_collation = buf.get_u8()?; + + // status flags : int<2> + let server_status = buf.get_u16::()?; + + // server capabilities (2nd part) : int<2> + let capabilities_2 = buf.get_u16::()?; + capabilities |= Capabilities::from_bits_truncate(((capabilities_2 as u32) << 16).into()); + + // if (server_capabilities & PLUGIN_AUTH) + let plugin_data_length = if capabilities.contains(Capabilities::PLUGIN_AUTH) { + // plugin data length : int<1> + buf.get_u8()? + } else { + // 0x00 : int<1> + buf.advance(0); + 0 + }; + + // filler : string<6> + buf.advance(6); + + // if (server_capabilities & CLIENT_MYSQL) + if capabilities.contains(Capabilities::CLIENT_MYSQL) { + // filler : string<4> + buf.advance(4); + } else { + // server capabilities 3rd part . MariaDB specific flags : int<4> + let capabilities_3 = buf.get_u32::()?; + capabilities |= Capabilities::from_bits_truncate((capabilities_2 as u128) << 32); + } + + // if (server_capabilities & CLIENT_SECURE_CONNECTION) + if capabilities.contains(Capabilities::SECURE_CONNECTION) { + // scramble 2nd part . Length = max(12, plugin data length - 9) : string + let len = ((plugin_data_length as isize) - 9).max(12) as usize; + scramble.extend_from_slice(&buf[..len]); + buf.advance(len); + + // reserved byte : string<1> + buf.advance(1); + } + + // if (server_capabilities & PLUGIN_AUTH) + let auth_plugin_name = if capabilities.contains(Capabilities::PLUGIN_AUTH) { + Some(buf.get_str_nul()?.to_owned()) + } else { + None + }; + + Ok(Self { + protocol_version, + server_version, + server_default_collation, + server_status: ServerStatusFlag::from_bits_truncate(server_status), + connection_id, + scramble: scramble.into_boxed_slice(), + capabilities, + auth_plugin_name, + }) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::__bytes_builder; + + #[test] + fn it_decodes_initial_handshake_packet() -> io::Result<()> { + #[rustfmt::skip] + let buf = __bytes_builder!( + // int<3> length + 1u8, 0u8, 0u8, + // int<1> seq_no + 0u8, + //int<1> protocol version + 10u8, + //string server version (MariaDB server version is by default prefixed by "5.5.5-") + b"5.5.5-10.4.6-MariaDB-1:10.4.6+maria~bionic\0", + //int<4> connection id + 13u8, 0u8, 0u8, 0u8, + //string<8> scramble 1st part (authentication seed) + b"?~~|vZAu", + //string<1> reserved byte + 0u8, + //int<2> server capabilities (1st part) + 0xFEu8, 0xF7u8, + //int<1> server default collation + 8u8, + //int<2> status flags + 2u8, 0u8, + //int<2> server capabilities (2nd part) + 0xFF_u8, 0x81_u8, + + //if (server_capabilities & PLUGIN_AUTH) + // int<1> plugin data length + 15u8, + //else + // int<1> 0x00 + + //string<6> filler + 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, + //if (server_capabilities & CLIENT_MYSQL) + // string<4> filler + //else + // int<4> server capabilities 3rd part . MariaDB specific flags /* MariaDB 10.2 or later */ + 7u8, 0u8, 0u8, 0u8, + //if (server_capabilities & CLIENT_SECURE_CONNECTION) + // string scramble 2nd part . Length = max(12, plugin data length - 9) + b"JQ8cihP4Q}Dx", + // string<1> reserved byte + 0u8, + //if (server_capabilities & PLUGIN_AUTH) + // string authentication plugin name + b"mysql_native_password\0" + ); + + let _message = InitialHandshakePacket::decode(&buf)?; + + Ok(()) + } +} diff --git a/src/mariadb/protocol/connect/mod.rs b/src/mariadb/protocol/connect/mod.rs new file mode 100644 index 00000000..80bd6e92 --- /dev/null +++ b/src/mariadb/protocol/connect/mod.rs @@ -0,0 +1,9 @@ +mod auth_switch_request; +mod initial; +mod response; +mod ssl_request; + +pub use auth_switch_request::AuthenticationSwitchRequest; +pub use initial::InitialHandshakePacket; +pub use response::HandshakeResponsePacket; +pub use ssl_request::SslRequest; diff --git a/src/mariadb/protocol/packets/handshake_response_packet.rs b/src/mariadb/protocol/connect/response.rs similarity index 99% rename from src/mariadb/protocol/packets/handshake_response_packet.rs rename to src/mariadb/protocol/connect/response.rs index e7752a16..360c242e 100644 --- a/src/mariadb/protocol/packets/handshake_response_packet.rs +++ b/src/mariadb/protocol/connect/response.rs @@ -7,7 +7,7 @@ use crate::{ }; use byteorder::LittleEndian; -#[derive(Default, Debug)] +#[derive(Debug)] pub struct HandshakeResponsePacket<'a> { pub capabilities: Capabilities, pub max_packet_size: u32, diff --git a/src/mariadb/protocol/connect/ssl_request.rs b/src/mariadb/protocol/connect/ssl_request.rs new file mode 100644 index 00000000..3cbc9c6d --- /dev/null +++ b/src/mariadb/protocol/connect/ssl_request.rs @@ -0,0 +1,40 @@ +use crate::{ + io::BufMut, + mariadb::{ + io::BufMutExt, + protocol::{Capabilities, Encode}, + }, +}; +use byteorder::LittleEndian; + +#[derive(Debug)] +pub struct SslRequest { + pub capabilities: Capabilities, + pub max_packet_size: u32, + pub client_collation: u8, +} + +impl Encode for SslRequest { + fn encode(&self, buf: &mut Vec, capabilities: Capabilities) { + // client capabilities : int<4> + buf.put_u32::(self.capabilities.bits() as u32); + + // max packet size : int<4> + buf.put_u32::(self.max_packet_size); + + // client character collation : int<1> + buf.put_u8(self.client_collation); + + // reserved : string<19> + buf.advance(19); + + // if not (capabilities & CLIENT_MYSQL) + if !capabilities.contains(Capabilities::CLIENT_MYSQL) { + // extended client capabilities : int<4> + buf.put_u32::((self.capabilities.bits() >> 32) as u32); + } else { + // reserved : int<4> + buf.advance(4); + } + } +} diff --git a/src/mariadb/protocol/encode.rs b/src/mariadb/protocol/encode.rs index 87835476..deb888cf 100644 --- a/src/mariadb/protocol/encode.rs +++ b/src/mariadb/protocol/encode.rs @@ -3,219 +3,3 @@ use super::Capabilities; pub trait Encode { fn encode(&self, buf: &mut Vec, capabilities: Capabilities); } - -pub const U24_MAX: usize = 0xFF_FF_FF; - -// #[inline] -// fn put_param(&mut self, bytes: &Bytes, ty: FieldType) { -// match ty { -// FieldType::MYSQL_TYPE_DECIMAL => self.put_string_lenenc(bytes), -// FieldType::MYSQL_TYPE_TINY => self.put_int_1(bytes), -// FieldType::MYSQL_TYPE_SHORT => self.put_int_2(bytes), -// FieldType::MYSQL_TYPE_LONG => self.put_int_4(bytes), -// FieldType::MYSQL_TYPE_FLOAT => self.put_int_4(bytes), -// FieldType::MYSQL_TYPE_DOUBLE => self.put_int_8(bytes), -// FieldType::MYSQL_TYPE_NULL => panic!("Type cannot be FieldType::MysqlTypeNull"), -// FieldType::MYSQL_TYPE_TIMESTAMP => unimplemented!(), -// FieldType::MYSQL_TYPE_LONGLONG => self.put_int_8(bytes), -// FieldType::MYSQL_TYPE_INT24 => self.put_int_4(bytes), -// FieldType::MYSQL_TYPE_DATE => unimplemented!(), -// FieldType::MYSQL_TYPE_TIME => unimplemented!(), -// FieldType::MYSQL_TYPE_DATETIME => unimplemented!(), -// FieldType::MYSQL_TYPE_YEAR => self.put_int_4(bytes), -// FieldType::MYSQL_TYPE_NEWDATE => unimplemented!(), -// FieldType::MYSQL_TYPE_VARCHAR => self.put_string_lenenc(bytes), -// FieldType::MYSQL_TYPE_BIT => self.put_string_lenenc(bytes), -// FieldType::MYSQL_TYPE_TIMESTAMP2 => unimplemented!(), -// FieldType::MYSQL_TYPE_DATETIME2 => unimplemented!(), -// FieldType::MYSQL_TYPE_TIME2 => unimplemented!(), -// FieldType::MYSQL_TYPE_JSON => self.put_byte_lenenc(bytes), -// FieldType::MYSQL_TYPE_NEWDECIMAL => self.put_byte_lenenc(bytes), -// FieldType::MYSQL_TYPE_ENUM => self.put_byte_lenenc(bytes), -// FieldType::MYSQL_TYPE_SET => self.put_byte_lenenc(bytes), -// FieldType::MYSQL_TYPE_TINY_BLOB => self.put_byte_lenenc(bytes), -// FieldType::MYSQL_TYPE_MEDIUM_BLOB => self.put_byte_lenenc(bytes), -// FieldType::MYSQL_TYPE_LONG_BLOB => self.put_byte_lenenc(bytes), -// FieldType::MYSQL_TYPE_BLOB => self.put_byte_lenenc(bytes), -// FieldType::MYSQL_TYPE_VAR_STRING => self.put_byte_lenenc(bytes), -// FieldType::MYSQL_TYPE_STRING => self.put_byte_lenenc(bytes), -// FieldType::MYSQL_TYPE_GEOMETRY => self.put_byte_lenenc(bytes), -// _ => panic!("Unrecognized field type"), -// } -// } - -#[cfg(test)] -mod tests { - use super::*; - use crate::{io::BufMut, mariadb::io::BufMutExt}; - use byteorder::LittleEndian; - - // [X] it_encodes_int_lenenc_u64 - // [X] it_encodes_int_lenenc_u32 - // [X] it_encodes_int_lenenc_u24 - // [X] it_encodes_int_lenenc_u16 - // [X] it_encodes_int_lenenc_u8 - // [X] it_encodes_int_u64 - // [X] it_encodes_int_u32 - // [X] it_encodes_int_u24 - // [X] it_encodes_int_u16 - // [X] it_encodes_int_u8 - // [X] it_encodes_string_lenenc - // [X] it_encodes_string_fix - // [X] it_encodes_string_null - // [X] it_encodes_string_eof - // [X] it_encodes_byte_lenenc - // [X] it_encodes_byte_fix - // [X] it_encodes_byte_eof - - #[test] - fn it_encodes_int_lenenc_none() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(Some(0u64)); - - assert_eq!(&buf[..], b"\xFB"); - } - - #[test] - fn it_encodes_int_lenenc_u8() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(Some(0xFA as u64)); - - assert_eq!(&buf[..], b"\xFA"); - } - - #[test] - fn it_encodes_int_lenenc_u16() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(Some(std::u16::MAX as u64)); - - assert_eq!(&buf[..], b"\xFC\xFF\xFF"); - } - - #[test] - fn it_encodes_int_lenenc_u24() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(Some(U24_MAX as u64)); - - assert_eq!(&buf[..], b"\xFD\xFF\xFF\xFF"); - } - - #[test] - fn it_encodes_int_lenenc_u64() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(Some(std::u64::MAX)); - - assert_eq!(&buf[..], b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"); - } - - #[test] - fn it_encodes_int_lenenc_fb() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(Some(0xFB as u64)); - - assert_eq!(&buf[..], b"\xFC\xFB\x00"); - } - - #[test] - fn it_encodes_int_lenenc_fc() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(Some(0xFC as u64)); - - assert_eq!(&buf[..], b"\xFC\xFC\x00"); - } - - #[test] - fn it_encodes_int_lenenc_fd() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(Some(0xFD as u64)); - - assert_eq!(&buf[..], b"\xFC\xFD\x00"); - } - - #[test] - fn it_encodes_int_lenenc_fe() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(Some(0xFE as u64)); - - assert_eq!(&buf[..], b"\xFC\xFE\x00"); - } - - fn it_encodes_int_lenenc_ff() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(Some(0xFF as u64)); - - assert_eq!(&buf[..], b"\xFC\xFF\x00"); - } - - #[test] - fn it_encodes_int_u64() { - let mut buf = Vec::with_capacity(1024); - buf.put_u64::(std::u64::MAX); - - assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"); - } - - #[test] - fn it_encodes_int_u32() { - let mut buf = Vec::with_capacity(1024); - buf.put_u32::(std::u32::MAX); - - assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF"); - } - - #[test] - fn it_encodes_int_u24() { - let mut buf = Vec::with_capacity(1024); - buf.put_u24::(U24_MAX as u32); - - assert_eq!(&buf[..], b"\xFF\xFF\xFF"); - } - - #[test] - fn it_encodes_int_u16() { - let mut buf = Vec::with_capacity(1024); - buf.put_u16::(std::u16::MAX); - - assert_eq!(&buf[..], b"\xFF\xFF"); - } - - #[test] - fn it_encodes_int_u8() { - let mut buf = Vec::with_capacity(1024); - buf.put_u8(std::u8::MAX); - - assert_eq!(&buf[..], b"\xFF"); - } - - #[test] - fn it_encodes_string_lenenc() { - let mut buf = Vec::with_capacity(1024); - buf.put_str_lenenc::("random_string"); - - assert_eq!(&buf[..], b"\x0Drandom_string"); - } - - #[test] - fn it_encodes_string_fix() { - let mut buf = Vec::with_capacity(1024); - buf.put_str("random_string"); - - assert_eq!(&buf[..], b"random_string"); - } - - #[test] - fn it_encodes_string_null() { - let mut buf = Vec::with_capacity(1024); - buf.put_str_nul("random_string"); - - assert_eq!(&buf[..], b"random_string\0"); - } - - #[test] - fn it_encodes_byte_lenenc() { - let mut buf = Vec::with_capacity(1024); - buf.put_byte_lenenc::(b"random_string"); - - assert_eq!(&buf[..], b"\x0Drandom_string"); - } -} diff --git a/src/mariadb/protocol/error_codes.rs b/src/mariadb/protocol/error_code.rs similarity index 99% rename from src/mariadb/protocol/error_codes.rs rename to src/mariadb/protocol/error_code.rs index 0a1080dc..f6096a6e 100644 --- a/src/mariadb/protocol/error_codes.rs +++ b/src/mariadb/protocol/error_code.rs @@ -1,6 +1,9 @@ #[derive(Default, Debug)] pub struct ErrorCode(pub(crate) u16); +// TODO: It would be nice to figure out a clean way to go from 1152 to "ER_ABORTING_CONNECTION (1152)" in Debug. + +// Values from https://mariadb.com/kb/en/library/mariadb-error-codes/ impl ErrorCode { const ER_ABORTING_CONNECTION: ErrorCode = ErrorCode(1152); const ER_ACCESS_DENIED_CHANGE_USER_ERROR: ErrorCode = ErrorCode(1873); diff --git a/src/mariadb/protocol/field.rs b/src/mariadb/protocol/field.rs new file mode 100644 index 00000000..680ca60f --- /dev/null +++ b/src/mariadb/protocol/field.rs @@ -0,0 +1,44 @@ +// https://mariadb.com/kb/en/library/resultset/#field-types +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct FieldType(pub u8); + +impl FieldType { + pub const MYSQL_TYPE_BIT: FieldType = FieldType(16); + pub const MYSQL_TYPE_BLOB: FieldType = FieldType(252); + pub const MYSQL_TYPE_DATE: FieldType = FieldType(10); + pub const MYSQL_TYPE_DATETIME: FieldType = FieldType(12); + pub const MYSQL_TYPE_DATETIME2: FieldType = FieldType(18); + pub const MYSQL_TYPE_DECIMAL: FieldType = FieldType(0); + pub const MYSQL_TYPE_DOUBLE: FieldType = FieldType(5); + pub const MYSQL_TYPE_ENUM: FieldType = FieldType(247); + pub const MYSQL_TYPE_FLOAT: FieldType = FieldType(4); + pub const MYSQL_TYPE_GEOMETRY: FieldType = FieldType(255); + pub const MYSQL_TYPE_INT24: FieldType = FieldType(9); + pub const MYSQL_TYPE_JSON: FieldType = FieldType(245); + pub const MYSQL_TYPE_LONG: FieldType = FieldType(3); + pub const MYSQL_TYPE_LONGLONG: FieldType = FieldType(8); + pub const MYSQL_TYPE_LONG_BLOB: FieldType = FieldType(251); + pub const MYSQL_TYPE_MEDIUM_BLOB: FieldType = FieldType(250); + pub const MYSQL_TYPE_NEWDATE: FieldType = FieldType(14); + pub const MYSQL_TYPE_NEWDECIMAL: FieldType = FieldType(246); + pub const MYSQL_TYPE_NULL: FieldType = FieldType(6); + pub const MYSQL_TYPE_SET: FieldType = FieldType(248); + pub const MYSQL_TYPE_SHORT: FieldType = FieldType(2); + pub const MYSQL_TYPE_STRING: FieldType = FieldType(254); + pub const MYSQL_TYPE_TIME: FieldType = FieldType(11); + pub const MYSQL_TYPE_TIME2: FieldType = FieldType(19); + pub const MYSQL_TYPE_TIMESTAMP: FieldType = FieldType(7); + pub const MYSQL_TYPE_TIMESTAMP2: FieldType = FieldType(17); + pub const MYSQL_TYPE_TINY: FieldType = FieldType(1); + pub const MYSQL_TYPE_TINY_BLOB: FieldType = FieldType(249); + pub const MYSQL_TYPE_VARCHAR: FieldType = FieldType(15); + pub const MYSQL_TYPE_VAR_STRING: FieldType = FieldType(253); + pub const MYSQL_TYPE_YEAR: FieldType = FieldType(13); +} + +// https://mariadb.com/kb/en/library/com_stmt_execute/#parameter-flag +bitflags::bitflags! { + pub struct ParameterFlag: u8 { + const UNSIGNED = 128; + } +} diff --git a/src/mariadb/protocol/mod.rs b/src/mariadb/protocol/mod.rs index 2c3a23d5..5703ad1a 100644 --- a/src/mariadb/protocol/mod.rs +++ b/src/mariadb/protocol/mod.rs @@ -1,36 +1,22 @@ // Reference: https://mariadb.com/kb/en/library/connection // Packets: https://mariadb.com/kb/en/library/0-packet -// TODO: Handle lengths which are greater than 3 bytes -// Either break the packet into several smaller ones, or -// return error +mod capabilities; +mod connect; +mod encode; +mod error_code; +mod field; +mod server_status; +mod response; -// TODO: Handle different Capabilities for server and client - -// TODO: Handle when capability is set, but field is None - -pub mod encode; -pub mod error_codes; -pub mod packets; -pub mod types; - -// Re-export all the things -// pub use packets::{ -// AuthenticationSwitchRequestPacket, ColumnDefPacket, ColumnPacket, ComDebug, ComInitDb, ComPing, -// ComProcessKill, ComQuery, ComQuit, ComResetConnection, ComSetOption, ComShutdown, ComSleep, -// ComStatistics, ComStmtClose, ComStmtExec, ComStmtFetch, ComStmtPrepare, ComStmtPrepareOk, -// ComStmtPrepareResp, ComStmtReset, EofPacket, ErrPacket, HandshakeResponsePacket, -// InitialHandshakePacket, OkPacket, PacketHeader, ResultRow, ResultRowBinary, ResultRowText, -// ResultSet, SSLRequestPacket, SetOptionOptions, ShutdownOptions, -// }; - -pub use packets::{ColumnCountPacket, ColumnDefinitionPacket}; - -pub use encode::Encode; - -pub use error_codes::ErrorCode; - -pub use types::{ - Capabilities, FieldDetailFlag, FieldType, ProtocolType, ServerStatusFlag, SessionChangeType, - StmtExecFlag, +pub use capabilities::Capabilities; +pub use connect::{ + AuthenticationSwitchRequest, HandshakeResponsePacket, InitialHandshakePacket, SslRequest, }; +pub use response::{ + OkPacket, EofPacket, ErrPacket, ResultRow, +}; +pub use encode::Encode; +pub use error_code::ErrorCode; +pub use field::{FieldType, ParameterFlag}; +pub use server_status::ServerStatusFlag; diff --git a/src/mariadb/protocol/packets/auth_switch_request.rs b/src/mariadb/protocol/packets/auth_switch_request.rs deleted file mode 100644 index d8d41a2c..00000000 --- a/src/mariadb/protocol/packets/auth_switch_request.rs +++ /dev/null @@ -1,19 +0,0 @@ -use crate::mariadb::{BufMut, ConnContext, Encode, MariaDbRawConnection}; -use bytes::Bytes; -use failure::Error; - -#[derive(Default, Debug)] -pub struct AuthenticationSwitchRequestPacket { - pub auth_plugin_name: Bytes, - pub auth_plugin_data: Bytes, -} - -impl Encode for AuthenticationSwitchRequestPacket { - fn encode(&self, buf: &mut Vec, ctx: &mut ConnContext) -> Result<(), Error> { - buf.put_int_u8(0xFE); - buf.put_string_null(&self.auth_plugin_name); - buf.put_byte_eof(&self.auth_plugin_data); - - Ok(()) - } -} diff --git a/src/mariadb/protocol/packets/initial.rs b/src/mariadb/protocol/packets/initial.rs deleted file mode 100644 index afe1f6a5..00000000 --- a/src/mariadb/protocol/packets/initial.rs +++ /dev/null @@ -1,166 +0,0 @@ -use crate::mariadb::{Capabilities, DeContext, Decode, ServerStatusFlag}; -use bytes::Bytes; -use failure::{err_msg, Error}; - -#[derive(Default, Debug)] -pub struct InitialHandshakePacket { - pub length: u32, - pub seq_no: u8, - pub protocol_version: u8, - pub server_version: Bytes, - pub connection_id: i32, - pub auth_seed: Bytes, - pub capabilities: Capabilities, - pub collation: u8, - pub status: ServerStatusFlag, - pub plugin_data_length: u8, - pub scramble: Option, - pub auth_plugin_name: Option, -} - -impl Decode for InitialHandshakePacket { - fn decode(ctx: &mut DeContext) -> Result { - let decoder = &mut ctx.decoder; - let length = decoder.decode_length()?; - let seq_no = decoder.decode_int_u8(); - - if seq_no != 0 { - return Err(err_msg( - "Sequence Number of Initial Handshake Packet is not 0", - )); - } - - let protocol_version = decoder.decode_int_u8(); - let server_version = decoder.decode_string_null()?; - let connection_id = decoder.decode_int_i32(); - let auth_seed = decoder.decode_string_fix(8); - - // Skip reserved byte - decoder.skip_bytes(1); - - let mut capabilities = Capabilities::from_bits_truncate(decoder.decode_int_u16().into()); - - let collation = decoder.decode_int_u8(); - let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_u16().into()); - - capabilities |= - Capabilities::from_bits_truncate(((decoder.decode_int_i16() as u32) << 16).into()); - - let mut plugin_data_length = 0; - if !(capabilities & Capabilities::PLUGIN_AUTH).is_empty() { - plugin_data_length = decoder.decode_int_u8(); - } else { - // Skip reserve byte - decoder.skip_bytes(1); - } - - // Skip filler - decoder.skip_bytes(6); - - if (capabilities & Capabilities::CLIENT_MYSQL).is_empty() { - capabilities |= - Capabilities::from_bits_truncate(((decoder.decode_int_u32() as u128) << 32).into()); - } else { - // Skip filler - decoder.skip_bytes(4); - } - - let mut scramble: Option = None; - if !(capabilities & Capabilities::SECURE_CONNECTION).is_empty() { - let len = std::cmp::max(12, plugin_data_length as usize - 9); - scramble = Some(decoder.decode_string_fix(len as usize)); - // Skip reserve byte - decoder.skip_bytes(1); - } - - let mut auth_plugin_name: Option = None; - if !(capabilities & Capabilities::PLUGIN_AUTH).is_empty() { - auth_plugin_name = Some(decoder.decode_string_null()?); - } - - ctx.ctx.last_seq_no = seq_no; - - Ok(InitialHandshakePacket { - length, - seq_no, - protocol_version, - server_version, - connection_id, - auth_seed, - capabilities, - collation, - status, - plugin_data_length, - scramble, - auth_plugin_name, - }) - } -} - -#[cfg(test)] -mod test { - use super::*; - use crate::{ - __bytes_builder, - mariadb::{ConnContext, Decoder}, - }; - use bytes::BytesMut; - - #[test] - fn it_decodes_initial_handshake_packet() -> Result<(), Error> { - #[rustfmt::skip] - let buf = __bytes_builder!( - // int<3> length - 1u8, 0u8, 0u8, - // int<1> seq_no - 0u8, - //int<1> protocol version - 10u8, - //string server version (MariaDB server version is by default prefixed by "5.5.5-") - b"5.5.5-10.4.6-MariaDB-1:10.4.6+maria~bionic\0", - //int<4> connection id - 13u8, 0u8, 0u8, 0u8, - //string<8> scramble 1st part (authentication seed) - b"?~~|vZAu", - //string<1> reserved byte - 0u8, - //int<2> server capabilities (1st part) - 0xFEu8, 0xF7u8, - //int<1> server default collation - 8u8, - //int<2> status flags - 2u8, 0u8, - //int<2> server capabilities (2nd part) - 0xFF_u8, 0x81_u8, - - //if (server_capabilities & PLUGIN_AUTH) - // int<1> plugin data length - 15u8, - //else - // int<1> 0x00 - - //string<6> filler - 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, - //if (server_capabilities & CLIENT_MYSQL) - // string<4> filler - //else - // int<4> server capabilities 3rd part . MariaDB specific flags /* MariaDB 10.2 or later */ - 7u8, 0u8, 0u8, 0u8, - //if (server_capabilities & CLIENT_SECURE_CONNECTION) - // string scramble 2nd part . Length = max(12, plugin data length - 9) - b"JQ8cihP4Q}Dx", - // string<1> reserved byte - 0u8, - //if (server_capabilities & PLUGIN_AUTH) - // string authentication plugin name - b"mysql_native_password\0" - ); - - let mut context = ConnContext::new(); - let mut ctx = DeContext::new(&mut context, buf); - - let _message = InitialHandshakePacket::decode(&mut ctx)?; - - Ok(()) - } -} diff --git a/src/mariadb/protocol/packets/ssl_request.rs b/src/mariadb/protocol/packets/ssl_request.rs deleted file mode 100644 index 08a2dbbf..00000000 --- a/src/mariadb/protocol/packets/ssl_request.rs +++ /dev/null @@ -1,40 +0,0 @@ -use bytes::Bytes; -use failure::Error; - -use crate::mariadb::{BufMut, Capabilities, ConnContext, Encode, MariaDbRawConnection}; - -#[derive(Default, Debug)] -pub struct SSLRequestPacket { - pub capabilities: Capabilities, - pub max_packet_size: u32, - pub collation: u8, - pub extended_capabilities: Option, -} - -impl Encode for SSLRequestPacket { - fn encode(&self, buf: &mut Vec, ctx: &mut ConnContext) -> Result<(), Error> { - buf.alloc_packet_header(); - buf.seq_no(0); - - buf.put_int_u32(self.capabilities.bits() as u32); - buf.put_int_u32(self.max_packet_size); - buf.put_int_u8(self.collation); - - // Filler - buf.put_byte_fix(&Bytes::from_static(&[0u8; 19]), 19); - - if !(ctx.capabilities & Capabilities::CLIENT_MYSQL).is_empty() - && !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty() - { - if let Some(capabilities) = self.extended_capabilities { - buf.put_int_u32(capabilities.bits() as u32); - } - } else { - buf.put_byte_fix(&Bytes::from_static(&[0u8; 4]), 4); - } - - buf.put_length(); - - Ok(()) - } -} diff --git a/src/mariadb/protocol/response/eof.rs b/src/mariadb/protocol/response/eof.rs new file mode 100644 index 00000000..b5c47996 --- /dev/null +++ b/src/mariadb/protocol/response/eof.rs @@ -0,0 +1,65 @@ +use crate::{ + io::Buf, + mariadb::{ + io::BufExt, + protocol::{ErrorCode, ServerStatusFlag}, + }, +}; +use byteorder::LittleEndian; +use std::io; + +#[derive(Debug)] +pub struct EofPacket { + pub warning_count: u16, + pub status: ServerStatusFlag, +} + +impl EofPacket { + fn decode(mut buf: &[u8]) -> io::Result { + let header = buf.get_u8()?; + if header != 0xFE { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("expected 0xFE; received {}", header), + )); + } + + let warning_count = buf.get_u16::()?; + let status = ServerStatusFlag::from_bits_truncate(buf.get_u16::()?); + + Ok(Self { + warning_count, + status, + }) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{__bytes_builder, mariadb::ConnContext}; + use bytes::Bytes; + + #[test] + fn it_decodes_eof_packet() -> Result<(), Error> { + #[rustfmt::skip] + let buf = __bytes_builder!( + // int<3> length + 1u8, 0u8, 0u8, + // int<1> seq_no + 1u8, + // int<1> 0xfe : EOF header + 0xFE_u8, + // int<2> warning count + 0u8, 0u8, + // int<2> server status + 1u8, 1u8 + ); + + let _message = EofPacket::decode(&buf)?; + + // TODO: Assert fields? + + Ok(()) + } +} diff --git a/src/mariadb/protocol/response/err.rs b/src/mariadb/protocol/response/err.rs new file mode 100644 index 00000000..e69de29b diff --git a/src/mariadb/protocol/response/mod.rs b/src/mariadb/protocol/response/mod.rs new file mode 100644 index 00000000..ade72c2a --- /dev/null +++ b/src/mariadb/protocol/response/mod.rs @@ -0,0 +1,9 @@ +mod ok; +mod err; +mod eof; +mod row; + +pub use ok::OkPacket; +pub use err::ErrPacket; +pub use eof::EofPacket; +pub use row::ResultRow; diff --git a/src/mariadb/protocol/response/ok.rs b/src/mariadb/protocol/response/ok.rs new file mode 100644 index 00000000..d99248b8 --- /dev/null +++ b/src/mariadb/protocol/response/ok.rs @@ -0,0 +1,117 @@ +use crate::{ + io::Buf, + mariadb::{ + io::BufExt, + protocol::{Capabilities, ServerStatusFlag}, + }, +}; +use byteorder::LittleEndian; +use std::io; + +// https://mariadb.com/kb/en/library/ok_packet/ +#[derive(Debug)] +pub struct OkPacket { + pub affected_rows: u64, + pub last_insert_id: u64, + pub server_status: ServerStatusFlag, + pub warning_count: u16, + pub info: Box, + pub session_state_info: Option>, + pub value_of_variable: Option>, +} + +impl OkPacket { + fn decode(mut buf: &[u8], capabilities: Capabilities) -> io::Result { + let header = buf.get_u8()?; + if header != 0 && header != 0xFE { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("expected 0x00 or 0xFE; received 0x{:X}", header), + )); + } + + let affected_rows = buf.get_uint_lenenc::()?.unwrap_or(0); + let last_insert_id = buf.get_uint_lenenc::()?.unwrap_or(0); + let server_status = ServerStatusFlag::from_bits_truncate(buf.get_u16::()?); + let warning_count = buf.get_u16::()?; + + let info; + let mut session_state_info = None; + let mut value_of_variable = None; + + if capabilities.contains(Capabilities::CLIENT_SESSION_TRACK) { + info = buf + .get_str_lenenc::()? + .unwrap_or_default() + .to_owned() + .into(); + session_state_info = buf.get_byte_lenenc::()?.map(Into::into); + value_of_variable = buf.get_str_lenenc::()?.map(Into::into); + } else { + info = buf.get_str_eof()?.to_owned().into(); + } + + Ok(Self { + affected_rows, + last_insert_id, + server_status, + warning_count, + info, + session_state_info, + value_of_variable, + }) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{ + __bytes_builder, + mariadb::{ConnContext, Decoder}, + }; + + #[test] + fn it_decodes_ok_packet() -> Result<(), Error> { + #[rustfmt::skip] + let buf = __bytes_builder!( + // int<3> length + 0u8, 0u8, 0u8, + // // int<1> seq_no + 1u8, + // 0x00 : OK_Packet header or (0xFE if CLIENT_DEPRECATE_EOF is set) + 0u8, + // int affected rows + 0xFB_u8, + // int last insert id + 0xFB_u8, + // int<2> server status + 1u8, 1u8, + // int<2> warning count + 0u8, 0u8, + // if session_tracking_supported (see CLIENT_SESSION_TRACK) { + // string info + // if (status flags & SERVER_SESSION_STATE_CHANGED) { + // string session state info + // string value of variable + // } + // } else { + // string info + b"info" + // } + ); + + let mut context = ConnContext::new(); + let mut ctx = DeContext::new(&mut context, buf); + + let message = OkPacket::decode(&mut ctx)?; + + assert_eq!(message.affected_rows, None); + assert_eq!(message.last_insert_id, None); + assert!(!(message.server_status & ServerStatusFlag::SERVER_STATUS_IN_TRANS).is_empty()); + assert_eq!(message.warning_count, 0); + assert_eq!(message.info, b"info".to_vec()); + + Ok(()) + } +} diff --git a/src/mariadb/protocol/response/row.rs b/src/mariadb/protocol/response/row.rs new file mode 100644 index 00000000..e69de29b diff --git a/src/mariadb/protocol/server_status.rs b/src/mariadb/protocol/server_status.rs new file mode 100644 index 00000000..5b973928 --- /dev/null +++ b/src/mariadb/protocol/server_status.rs @@ -0,0 +1,45 @@ +// https://mariadb.com/kb/en/library/mariadb-connectorc-types-and-definitions/#server-status +bitflags::bitflags! { + pub struct ServerStatusFlag: u16 { + // A transaction is currently active + const SERVER_STATUS_IN_TRANS = 1; + + // Autocommit mode is set + const SERVER_STATUS_AUTOCOMMIT = 2; + + // more results exists (more packet follow) + const SERVER_MORE_RESULTS_EXISTS = 8; + + const SERVER_QUERY_NO_GOOD_INDEX_USED = 16; + const SERVER_QUERY_NO_INDEX_USED = 32; + + // when using COM_STMT_FETCH, indicate that current cursor still has result + const SERVER_STATUS_CURSOR_EXISTS = 64; + + // when using COM_STMT_FETCH, indicate that current cursor has finished to send results + const SERVER_STATUS_LAST_ROW_SENT = 128; + + // database has been dropped + const SERVER_STATUS_DB_DROPPED = 1 << 8; + + // current escape mode is "no backslash escape" + const SERVER_STATUS_NO_BACKSLASH_ESAPES = 1 << 9; + + // A DDL change did have an impact on an existing PREPARE (an + // automatic reprepare has been executed) + const SERVER_STATUS_METADATA_CHANGED = 1 << 10; + + // Last statement took more than the time value specified in + // server variable long_query_time. + const SERVER_QUERY_WAS_SLOW = 1 << 11; + + // this resultset contain stored procedure output parameter + const SERVER_PS_OUT_PARAMS = 1 << 12; + + // current transaction is a read-only transaction + const SERVER_STATUS_IN_TRANS_READONLY = 1 << 13; + + // session state change. see Session change type for more information + const SERVER_SESSION_STATE_CHANGED = 1 << 14; + } +} diff --git a/src/mariadb/protocol/types.rs b/src/mariadb/protocol/types.rs index a40c3b33..b6d6899c 100644 --- a/src/mariadb/protocol/types.rs +++ b/src/mariadb/protocol/types.rs @@ -1,34 +1,3 @@ -pub enum ProtocolType { - Text, - Binary, -} - -bitflags::bitflags! { - pub struct Capabilities: u128 { - const CLIENT_MYSQL = 1; - const FOUND_ROWS = 1 << 1; - const CONNECT_WITH_DB = 1 << 3; - const COMPRESS = 1 << 5; - const LOCAL_FILES = 1 << 7; - const IGNORE_SPACE = 1 << 8; - const CLIENT_PROTOCOL_41 = 1 << 9; - const CLIENT_INTERACTIVE = 1 << 10; - const SSL = 1 << 11; - const TRANSACTIONS = 1 << 12; - const SECURE_CONNECTION = 1 << 13; - const MULTI_STATEMENTS = 1 << 16; - const MULTI_RESULTS = 1 << 17; - const PS_MULTI_RESULTS = 1 << 18; - const PLUGIN_AUTH = 1 << 19; - const CONNECT_ATTRS = 1 << 20; - const PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21; - const CLIENT_SESSION_TRACK = 1 << 23; - const CLIENT_DEPRECATE_EOF = 1 << 24; - const MARIA_DB_CLIENT_PROGRESS = 1 << 32; - const MARIA_DB_CLIENT_COM_MULTI = 1 << 33; - const MARIA_CLIENT_STMT_BULK_OPERATIONS = 1 << 34; - } -} bitflags::bitflags! { pub struct FieldDetailFlag: u16 { @@ -78,42 +47,6 @@ pub enum SessionChangeType { SessionTrackTransactionState = 5, } -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct FieldType(pub u8); -impl FieldType { - pub const MYSQL_TYPE_BIT: FieldType = FieldType(16); - pub const MYSQL_TYPE_BLOB: FieldType = FieldType(252); - pub const MYSQL_TYPE_DATE: FieldType = FieldType(10); - pub const MYSQL_TYPE_DATETIME: FieldType = FieldType(12); - pub const MYSQL_TYPE_DATETIME2: FieldType = FieldType(18); - pub const MYSQL_TYPE_DECIMAL: FieldType = FieldType(0); - pub const MYSQL_TYPE_DOUBLE: FieldType = FieldType(5); - pub const MYSQL_TYPE_ENUM: FieldType = FieldType(247); - pub const MYSQL_TYPE_FLOAT: FieldType = FieldType(4); - pub const MYSQL_TYPE_GEOMETRY: FieldType = FieldType(255); - pub const MYSQL_TYPE_INT24: FieldType = FieldType(9); - pub const MYSQL_TYPE_JSON: FieldType = FieldType(245); - pub const MYSQL_TYPE_LONG: FieldType = FieldType(3); - pub const MYSQL_TYPE_LONGLONG: FieldType = FieldType(8); - pub const MYSQL_TYPE_LONG_BLOB: FieldType = FieldType(251); - pub const MYSQL_TYPE_MEDIUM_BLOB: FieldType = FieldType(250); - pub const MYSQL_TYPE_NEWDATE: FieldType = FieldType(14); - pub const MYSQL_TYPE_NEWDECIMAL: FieldType = FieldType(246); - pub const MYSQL_TYPE_NULL: FieldType = FieldType(6); - pub const MYSQL_TYPE_SET: FieldType = FieldType(248); - pub const MYSQL_TYPE_SHORT: FieldType = FieldType(2); - pub const MYSQL_TYPE_STRING: FieldType = FieldType(254); - pub const MYSQL_TYPE_TIME: FieldType = FieldType(11); - pub const MYSQL_TYPE_TIME2: FieldType = FieldType(19); - pub const MYSQL_TYPE_TIMESTAMP: FieldType = FieldType(7); - pub const MYSQL_TYPE_TIMESTAMP2: FieldType = FieldType(17); - pub const MYSQL_TYPE_TINY: FieldType = FieldType(1); - pub const MYSQL_TYPE_TINY_BLOB: FieldType = FieldType(249); - pub const MYSQL_TYPE_VARCHAR: FieldType = FieldType(15); - pub const MYSQL_TYPE_VAR_STRING: FieldType = FieldType(253); - pub const MYSQL_TYPE_YEAR: FieldType = FieldType(13); -} - #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct StmtExecFlag(pub u8); impl StmtExecFlag { @@ -122,63 +55,3 @@ impl StmtExecFlag { pub const READ_ONLY: StmtExecFlag = StmtExecFlag(1); pub const SCROLLABLE_CURSOR: StmtExecFlag = StmtExecFlag(3); } - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct ParamFlag(pub u8); -impl ParamFlag { - pub const NONE: ParamFlag = ParamFlag(0); - pub const UNSIGNED: ParamFlag = ParamFlag(128); -} - -// TODO: Remove these Default impls - -impl Default for Capabilities { - fn default() -> Self { - Capabilities::CLIENT_PROTOCOL_41 - } -} - -impl Default for ServerStatusFlag { - fn default() -> Self { - ServerStatusFlag::SERVER_STATUS_IN_TRANS - } -} - -impl Default for FieldDetailFlag { - fn default() -> Self { - FieldDetailFlag::NOT_NULL - } -} - -impl Default for FieldType { - fn default() -> Self { - FieldType::MYSQL_TYPE_DECIMAL - } -} - -impl Default for StmtExecFlag { - fn default() -> Self { - StmtExecFlag::NO_CURSOR - } -} - -impl Default for ParamFlag { - fn default() -> Self { - ParamFlag::UNSIGNED - } -} - -#[cfg(test)] -mod test { - use super::Capabilities; - use crate::{__bytes_builder, io::Buf}; - use byteorder::LittleEndian; - - #[test] - fn it_decodes_capabilities() -> std::io::Result<()> { - let buf = &__bytes_builder!(b"\xfe\xf7")[..]; - Capabilities::from_bits_truncate(buf.get_u16::()? as u128); - - Ok(()) - } -}