diff --git a/src/mariadb/connection/mod.rs b/src/mariadb/connection/mod.rs index 2c2644f6..b79c8b02 100644 --- a/src/mariadb/connection/mod.rs +++ b/src/mariadb/connection/mod.rs @@ -102,15 +102,10 @@ impl Connection { Ok(conn) } - pub async fn send(&mut self, message: S) -> Result<(), Error> - where - S: Serialize, - { + pub async fn send(&mut self, message: S) -> Result<(), Error> where S: Serialize { self.encoder.clear(); - self.encoder.alloc_packet_header(); - self.encoder.seq_no(self.context.seq_no); + message.serialize(&mut self.context, &mut self.encoder)?; - self.encoder.encode_length(); self.stream.inner.write_all(&self.encoder.buf).await?; self.stream.inner.flush().await?; diff --git a/src/mariadb/mod.rs b/src/mariadb/mod.rs index fb81ba1f..08b2d699 100644 --- a/src/mariadb/mod.rs +++ b/src/mariadb/mod.rs @@ -29,6 +29,11 @@ pub use protocol::PacketHeader; pub use protocol::ResultSet; pub use protocol::ResultRow; pub use protocol::SSLRequestPacket; +pub use protocol::ComStmtPrepare; +pub use protocol::ComStmtPrepareOk; +pub use protocol::ComStmtPrepareResp; +pub use protocol::ComStmtClose; +pub use protocol::ComStmtExec; pub use protocol::Decoder; pub use protocol::DeContext; pub use protocol::Deserialize; @@ -41,4 +46,6 @@ pub use protocol::ServerStatusFlag; pub use protocol::FieldType; pub use protocol::FieldDetailFlag; pub use protocol::SessionChangeType; +pub use protocol::StmtExecFlag; pub use protocol::TextProtocol; +pub use protocol::BinaryProtocol; diff --git a/src/mariadb/protocol/client.rs b/src/mariadb/protocol/client.rs index adfcada3..8f5fd34f 100644 --- a/src/mariadb/protocol/client.rs +++ b/src/mariadb/protocol/client.rs @@ -35,3 +35,16 @@ impl Into for TextProtocol { self as u8 } } + +pub enum BinaryProtocol { + ComStmtPrepare = 0x16, + ComStmtClose = 0x19, + ComStmtExec = 0x17, +} + +// Helper method to easily transform into u8 +impl Into for BinaryProtocol { + fn into(self) -> u8 { + self as u8 + } +} diff --git a/src/mariadb/protocol/decode.rs b/src/mariadb/protocol/decode.rs index c8751bcd..5a6709cc 100644 --- a/src/mariadb/protocol/decode.rs +++ b/src/mariadb/protocol/decode.rs @@ -108,7 +108,7 @@ impl<'a> Decoder<'a> { // Decode an int<8> which is a i64 #[inline] - pub fn decode_int_8(&mut self) -> i64 { + pub fn decode_int_i64(&mut self) -> i64 { let value = LittleEndian::read_i64(&self.buf[self.index..]); self.index += 8; value @@ -116,16 +116,16 @@ impl<'a> Decoder<'a> { // Decode an int<4> which is a i32 #[inline] - pub fn decode_int_4(&mut self) -> i32 { + pub fn decode_int_i32(&mut self) -> i32 { let value = LittleEndian::read_i32(&self.buf[self.index..]); self.index += 4; value } - // Decode an int<4> which is a i32 + // Decode an int<4> which is a u32 // This is a helper method for decoding flags. #[inline] - pub fn decode_int_4_unsigned(&mut self) -> u32 { + pub fn decode_int_u32(&mut self) -> u32 { let value = LittleEndian::read_u32(&self.buf[self.index..]); self.index += 4; value @@ -133,7 +133,7 @@ impl<'a> Decoder<'a> { // Decode an int<3> which is a i24 #[inline] - pub fn decode_int_3(&mut self) -> i32 { + pub fn decode_int_i24(&mut self) -> i32 { let value = LittleEndian::read_i24(&self.buf[self.index..]); self.index += 3; value @@ -141,7 +141,7 @@ impl<'a> Decoder<'a> { // Decode an int<2> which is a i16 #[inline] - pub fn decode_int_2(&mut self) -> i16 { + pub fn decode_int_i16(&mut self) -> i16 { let value = LittleEndian::read_i16(&self.buf[self.index..]); self.index += 2; value @@ -150,7 +150,7 @@ impl<'a> Decoder<'a> { // Decode an int<2> as an u16 // This is a helper method for decoding flags. #[inline] - pub fn decode_int_2_unsigned(&mut self) -> u16 { + pub fn decode_int_u16(&mut self) -> u16 { let value = LittleEndian::read_u16(&self.buf[self.index..]); self.index += 2; value @@ -158,7 +158,7 @@ impl<'a> Decoder<'a> { // Decode an int<1> which is a u8 #[inline] - pub fn decode_int_1(&mut self) -> u8 { + pub fn decode_int_u8(&mut self) -> u8 { let value = self.buf[self.index]; self.index += 1; value @@ -221,7 +221,7 @@ impl<'a> Decoder<'a> { // Same as the string counter part, but copied to maintain consistency with the spec. #[inline] pub fn decode_byte_lenenc(&mut self) -> Bytes { - let length = self.decode_int_1(); + let length = self.decode_int_u8(); let value = self.buf.slice(self.index, self.index + length as usize); self.index = self.index + length as usize; value @@ -319,7 +319,7 @@ mod tests { fn it_decodes_int_8() { let buf = __bytes_builder!(1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8); let mut decoder = Decoder::new(&buf); - let int: i64 = decoder.decode_int_8(); + let int: i64 = decoder.decode_int_i64(); assert_eq!(int, 0x0101010101010101); assert_eq!(decoder.index, 8); @@ -329,7 +329,7 @@ mod tests { fn it_decodes_int_4() { let buf = __bytes_builder!(1u8, 1u8, 1u8, 1u8); let mut decoder = Decoder::new(&buf); - let int: i32 = decoder.decode_int_4(); + let int: i32 = decoder.decode_int_i32(); assert_eq!(int, 0x01010101); assert_eq!(decoder.index, 4); @@ -339,7 +339,7 @@ mod tests { fn it_decodes_int_3() { let buf = __bytes_builder!(1u8, 1u8, 1u8); let mut decoder = Decoder::new(&buf); - let int: i32 = decoder.decode_int_3(); + let int: i32 = decoder.decode_int_i24(); assert_eq!(int, 0x010101); assert_eq!(decoder.index, 3); @@ -349,7 +349,7 @@ mod tests { fn it_decodes_int_2() { let buf = __bytes_builder!(1u8, 1u8); let mut decoder = Decoder::new(&buf); - let int: i16 = decoder.decode_int_2(); + let int: i16 = decoder.decode_int_i16(); assert_eq!(int, 0x0101); assert_eq!(decoder.index, 2); @@ -359,7 +359,7 @@ mod tests { fn it_decodes_int_1() { let buf = __bytes_builder!(1u8); let mut decoder = Decoder::new(&buf); - let int: u8 = decoder.decode_int_1(); + let int: u8 = decoder.decode_int_u8(); assert_eq!(int, 1u8); assert_eq!(decoder.index, 1); diff --git a/src/mariadb/protocol/encode.rs b/src/mariadb/protocol/encode.rs index a96346d6..ea455cdf 100644 --- a/src/mariadb/protocol/encode.rs +++ b/src/mariadb/protocol/encode.rs @@ -1,5 +1,6 @@ use byteorder::{ByteOrder, LittleEndian}; use bytes::{BufMut, Bytes, BytesMut}; +use crate::mariadb::FieldType; const U24_MAX: usize = 0xFF_FF_FF; @@ -37,7 +38,7 @@ impl Encoder { let mut length = [0; 3]; if self.buf.len() > U24_MAX { panic!("Buffer too long"); - } else if self.buf.len() <= 4 { + } else if self.buf.len() < 4 { panic!("Buffer too short. Only contains packet length and sequence number") } @@ -52,31 +53,80 @@ impl Encoder { // Encode a u64 as an int<8> #[inline] - pub fn encode_int_8(&mut self, value: u64) { + pub fn encode_int_u64(&mut self, value: u64) { self.buf.extend_from_slice(&value.to_le_bytes()); } + // Encode a i64 as an int<8> + #[inline] + pub fn encode_int_i64(&mut self, value: i64) { + self.buf.extend_from_slice(&value.to_le_bytes()); + } + + #[inline] + pub fn encode_int_8(&mut self, bytes: &Bytes) { + self.buf.extend_from_slice(bytes); + } + // Encode a u32 as an int<4> #[inline] - pub fn encode_int_4(&mut self, value: u32) { + pub fn encode_int_u32(&mut self, value: u32) { self.buf.extend_from_slice(&value.to_le_bytes()); } + // Encode a i32 as an int<4> + #[inline] + pub fn encode_int_i32(&mut self, value: i32) { + self.buf.extend_from_slice(&value.to_le_bytes()); + } + + #[inline] + pub fn encode_int_4(&mut self, bytes: &Bytes) { + self.buf.extend_from_slice(bytes); + } + // Encode a u32 (truncated to u24) as an int<3> #[inline] - pub fn encode_int_3(&mut self, value: u32) { + pub fn encode_int_u24(&mut self, value: u32) { + self.buf.extend_from_slice(&value.to_le_bytes()[0..3]); + } + // Encode a i32 (truncated to i24) as an int<3> + #[inline] + pub fn encode_int_i24(&mut self, value: i32) { self.buf.extend_from_slice(&value.to_le_bytes()[0..3]); } // Encode a u16 as an int<2> #[inline] - pub fn encode_int_2(&mut self, value: u16) { + pub fn encode_int_u16(&mut self, value: u16) { self.buf.extend_from_slice(&value.to_le_bytes()); } + // Encode a i16 as an int<2> + #[inline] + pub fn encode_int_i16(&mut self, value: i16) { + self.buf.extend_from_slice(&value.to_le_bytes()); + } + + #[inline] + pub fn encode_int_2(&mut self, bytes: &Bytes) { + self.buf.extend_from_slice(bytes); + } + // Encode a u8 as an int<1> #[inline] - pub fn encode_int_1(&mut self, value: u8) { + pub fn encode_int_u8(&mut self, value: u8) { + self.buf.extend_from_slice(&value.to_le_bytes()); + } + + #[inline] + pub fn encode_int_1(&mut self, bytes: &Bytes) { + self.buf.extend_from_slice(bytes); + } + + // Encode a i8 as an int<1> + #[inline] + pub fn encode_int_i8(&mut self, value: i8) { self.buf.extend_from_slice(&value.to_le_bytes()); } @@ -87,15 +137,15 @@ impl Encoder { if let Some(value) = value { if *value > U24_MAX && *value <= std::u64::MAX as usize { self.buf.put_u8(0xFE); - self.encode_int_8(*value as u64); + self.encode_int_u64(*value as u64); } else if *value > std::u16::MAX as usize && *value <= U24_MAX { self.buf.put_u8(0xFD); - self.encode_int_3(*value as u32); + self.encode_int_u24(*value as u32); } else if *value > std::u8::MAX as usize && *value <= std::u16::MAX as usize { self.buf.put_u8(0xFC); - self.encode_int_2(*value as u16); + self.encode_int_u16(*value as u16); } else if *value <= std::u8::MAX as usize { match *value { @@ -160,7 +210,7 @@ impl Encoder { panic!("String inside string lenenc serialization is too long"); } - self.encode_int_3(bytes.len() as u32); + self.encode_int_u24(bytes.len() as u32); self.buf.extend_from_slice(bytes); } @@ -179,6 +229,43 @@ impl Encoder { pub fn encode_byte_eof(&mut self, bytes: &Bytes) { self.buf.extend_from_slice(bytes); } + + #[inline] + pub fn encode_binary(&mut self, bytes: &Bytes, ty: &FieldType) { + match ty { + FieldType::MysqlTypeDecimal => self.encode_string_lenenc(bytes), + FieldType::MysqlTypeTiny => self.encode_int_1(bytes), + FieldType::MysqlTypeShort => self.encode_int_2(bytes), + FieldType::MysqlTypeLong => self.encode_int_4(bytes), + FieldType::MysqlTypeFloat => self.encode_int_4(bytes), + FieldType::MysqlTypeDouble => self.encode_int_8(bytes), + FieldType::MysqlTypeNull => panic!("Type cannot be FieldType::MysqlTypeNull"), + FieldType::MysqlTypeTimestamp => unimplemented!(), + FieldType::MysqlTypeLonglong => self.encode_int_8(bytes), + FieldType::MysqlTypeInt24 => self.encode_int_4(bytes), + FieldType::MysqlTypeDate => unimplemented!(), + FieldType::MysqlTypeTime => unimplemented!(), + FieldType::MysqlTypeDatetime => unimplemented!(), + FieldType::MysqlTypeYear => self.encode_int_4(bytes), + FieldType::MysqlTypeNewdate => unimplemented!(), + FieldType::MysqlTypeVarchar => self.encode_string_lenenc(bytes), + FieldType::MysqlTypeBit => self.encode_string_lenenc(bytes), + FieldType::MysqlTypeTimestamp2 => unimplemented!(), + FieldType::MysqlTypeDatetime2 => unimplemented!(), + FieldType::MysqlTypeTime2 =>unimplemented!(), + FieldType::MysqlTypeJson => self.encode_string_lenenc(bytes), + FieldType::MysqlTypeNewdecimal => self.encode_string_lenenc(bytes), + FieldType::MysqlTypeEnum => self.encode_string_lenenc(bytes), + FieldType::MysqlTypeSet => self.encode_string_lenenc(bytes), + FieldType::MysqlTypeTinyBlob => self.encode_string_lenenc(bytes), + FieldType::MysqlTypeMediumBlob => self.encode_string_lenenc(bytes), + FieldType::MysqlTypeLongBlob => self.encode_string_lenenc(bytes), + FieldType::MysqlTypeBlob => self.encode_string_lenenc(bytes), + FieldType::MysqlTypeVarString => self.encode_string_lenenc(bytes), + FieldType::MysqlTypeString => self.encode_string_lenenc(bytes), + FieldType::MysqlTypeGeometry => self.encode_string_lenenc(bytes), + } + } } impl From for Encoder { @@ -292,7 +379,7 @@ mod tests { #[test] fn it_encodes_int_u64() { let mut encoder = Encoder::new(128); - encoder.encode_int_8(std::u64::MAX); + encoder.encode_int_u64(std::u64::MAX); assert_eq!(&encoder.buf[..], b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"); } @@ -300,7 +387,7 @@ mod tests { #[test] fn it_encodes_int_u32() { let mut encoder = Encoder::new(128); - encoder.encode_int_4(std::u32::MAX); + encoder.encode_int_u32(std::u32::MAX); assert_eq!(&encoder.buf[..], b"\xFF\xFF\xFF\xFF"); } @@ -308,7 +395,7 @@ mod tests { #[test] fn it_encodes_int_u24() { let mut encoder = Encoder::new(128); - encoder.encode_int_3(U24_MAX as u32); + encoder.encode_int_u24(U24_MAX as u32); assert_eq!(&encoder.buf[..], b"\xFF\xFF\xFF"); } @@ -316,7 +403,7 @@ mod tests { #[test] fn it_encodes_int_u16() { let mut encoder = Encoder::new(128); - encoder.encode_int_2(std::u16::MAX); + encoder.encode_int_u16(std::u16::MAX); assert_eq!(&encoder.buf[..], b"\xFF\xFF"); } @@ -324,7 +411,7 @@ mod tests { #[test] fn it_encodes_int_u8() { let mut encoder = Encoder::new(128); - encoder.encode_int_1(std::u8::MAX); + encoder.encode_int_u8(std::u8::MAX); assert_eq!(&encoder.buf[..], b"\xFF"); } diff --git a/src/mariadb/protocol/mod.rs b/src/mariadb/protocol/mod.rs index 8264e67b..7bc32159 100644 --- a/src/mariadb/protocol/mod.rs +++ b/src/mariadb/protocol/mod.rs @@ -34,6 +34,11 @@ pub use packets::PacketHeader; pub use packets::ResultSet; pub use packets::ResultRow; pub use packets::SSLRequestPacket; +pub use packets::ComStmtPrepare; +pub use packets::ComStmtPrepareOk; +pub use packets::ComStmtPrepareResp; +pub use packets::ComStmtClose; +pub use packets::ComStmtExec; pub use decode::Decoder; @@ -53,5 +58,7 @@ pub use types::ServerStatusFlag; pub use types::FieldType; pub use types::FieldDetailFlag; pub use types::SessionChangeType; +pub use types::StmtExecFlag; pub use client::TextProtocol; +pub use client::BinaryProtocol; diff --git a/src/mariadb/protocol/packets/auth_switch_request.rs b/src/mariadb/protocol/packets/auth_switch_request.rs index d16beafa..a364c216 100644 --- a/src/mariadb/protocol/packets/auth_switch_request.rs +++ b/src/mariadb/protocol/packets/auth_switch_request.rs @@ -11,7 +11,7 @@ pub struct AuthenticationSwitchRequestPacket { impl Serialize for AuthenticationSwitchRequestPacket { fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> { - encoder.encode_int_1(0xFE); + encoder.encode_int_u8(0xFE); encoder.encode_string_null(&self.auth_plugin_name); encoder.encode_byte_eof(&self.auth_plugin_data); diff --git a/src/mariadb/protocol/packets/column.rs b/src/mariadb/protocol/packets/column.rs index 85f32c67..e728b3e9 100644 --- a/src/mariadb/protocol/packets/column.rs +++ b/src/mariadb/protocol/packets/column.rs @@ -17,7 +17,7 @@ impl Deserialize for ColumnPacket { let decoder = &mut ctx.decoder; let length = decoder.decode_length()?; - let seq_no = decoder.decode_int_1(); + let seq_no = decoder.decode_int_u8(); let columns = decoder.decode_int_lenenc(); diff --git a/src/mariadb/protocol/packets/column_def.rs b/src/mariadb/protocol/packets/column_def.rs index 5283be6a..68890ed6 100644 --- a/src/mariadb/protocol/packets/column_def.rs +++ b/src/mariadb/protocol/packets/column_def.rs @@ -25,7 +25,7 @@ impl Deserialize for ColumnDefPacket { fn deserialize(ctx: &mut DeContext) -> Result { let decoder = &mut ctx.decoder; let length = decoder.decode_length()?; - let seq_no = decoder.decode_int_1(); + let seq_no = decoder.decode_int_u8(); // string catalog (always 'def') let catalog = decoder.decode_string_lenenc(); @@ -42,15 +42,15 @@ impl Deserialize for ColumnDefPacket { // int length of fixed fields (=0xC) let length_of_fixed_fields = decoder.decode_int_lenenc(); // int<2> character set number - let char_set = decoder.decode_int_2(); + let char_set = decoder.decode_int_i16(); // int<4> max. column size - let max_columns = decoder.decode_int_4(); + let max_columns = decoder.decode_int_i32(); // int<1> Field types - let field_type = FieldType::try_from(decoder.decode_int_1())?; + let field_type = FieldType::try_from(decoder.decode_int_u8())?; // int<2> Field detail flag - let field_details = FieldDetailFlag::from_bits_truncate(decoder.decode_int_2_unsigned()); + let field_details = FieldDetailFlag::from_bits_truncate(decoder.decode_int_u16()); // int<1> decimals - let decimals = decoder.decode_int_1(); + let decimals = decoder.decode_int_u8(); // int<2> - unused - decoder.skip_bytes(2); diff --git a/src/mariadb/protocol/packets/com_debug.rs b/src/mariadb/protocol/packets/com_debug.rs index 81632bd9..022e863d 100644 --- a/src/mariadb/protocol/packets/com_debug.rs +++ b/src/mariadb/protocol/packets/com_debug.rs @@ -5,7 +5,12 @@ pub struct ComDebug(); impl Serialize for ComDebug { fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> { - encoder.encode_int_1(TextProtocol::ComDebug.into()); + encoder.alloc_packet_header(); + encoder.seq_no(0); + + encoder.encode_int_u8(TextProtocol::ComDebug.into()); + + encoder.encode_length(); Ok(()) } diff --git a/src/mariadb/protocol/packets/com_init_db.rs b/src/mariadb/protocol/packets/com_init_db.rs index 01bed29e..0bfafded 100644 --- a/src/mariadb/protocol/packets/com_init_db.rs +++ b/src/mariadb/protocol/packets/com_init_db.rs @@ -8,9 +8,14 @@ pub struct ComInitDb { impl Serialize for ComInitDb { fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> { - encoder.encode_int_1(TextProtocol::ComInitDb.into()); + encoder.alloc_packet_header(); + encoder.seq_no(0); + + encoder.encode_int_u8(TextProtocol::ComInitDb.into()); encoder.encode_string_null(&self.schema_name); + encoder.encode_length(); + Ok(()) } } diff --git a/src/mariadb/protocol/packets/com_ping.rs b/src/mariadb/protocol/packets/com_ping.rs index 159dcc08..7f834183 100644 --- a/src/mariadb/protocol/packets/com_ping.rs +++ b/src/mariadb/protocol/packets/com_ping.rs @@ -5,7 +5,12 @@ pub struct ComPing(); impl Serialize for ComPing { fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> { - encoder.encode_int_1(TextProtocol::ComPing.into()); + encoder.alloc_packet_header(); + encoder.seq_no(0); + + encoder.encode_int_u8(TextProtocol::ComPing.into()); + + encoder.encode_length(); Ok(()) } diff --git a/src/mariadb/protocol/packets/com_process_kill.rs b/src/mariadb/protocol/packets/com_process_kill.rs index 466c523d..ab68918f 100644 --- a/src/mariadb/protocol/packets/com_process_kill.rs +++ b/src/mariadb/protocol/packets/com_process_kill.rs @@ -7,8 +7,13 @@ pub struct ComProcessKill { impl Serialize for ComProcessKill { fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> { - encoder.encode_int_1(TextProtocol::ComProcessKill.into()); - encoder.encode_int_4(self.process_id); + encoder.alloc_packet_header(); + encoder.seq_no(0); + + encoder.encode_int_u8(TextProtocol::ComProcessKill.into()); + encoder.encode_int_u32(self.process_id); + + encoder.encode_length(); Ok(()) } diff --git a/src/mariadb/protocol/packets/com_query.rs b/src/mariadb/protocol/packets/com_query.rs index ec404373..512d2a69 100644 --- a/src/mariadb/protocol/packets/com_query.rs +++ b/src/mariadb/protocol/packets/com_query.rs @@ -8,9 +8,14 @@ pub struct ComQuery { impl Serialize for ComQuery { fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> { - encoder.encode_int_1(TextProtocol::ComQuery.into()); + encoder.alloc_packet_header(); + encoder.seq_no(0); + + encoder.encode_int_u8(TextProtocol::ComQuery.into()); encoder.encode_string_eof(&self.sql_statement); + encoder.encode_length(); + Ok(()) } } diff --git a/src/mariadb/protocol/packets/com_quit.rs b/src/mariadb/protocol/packets/com_quit.rs index 7ea97c08..2c8c8184 100644 --- a/src/mariadb/protocol/packets/com_quit.rs +++ b/src/mariadb/protocol/packets/com_quit.rs @@ -5,7 +5,12 @@ pub struct ComQuit(); impl Serialize for ComQuit { fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> { - encoder.encode_int_1(TextProtocol::ComQuit.into()); + encoder.alloc_packet_header(); + encoder.seq_no(0); + + encoder.encode_int_u8(TextProtocol::ComQuit.into()); + + encoder.encode_length(); Ok(()) } diff --git a/src/mariadb/protocol/packets/com_reset_conn.rs b/src/mariadb/protocol/packets/com_reset_conn.rs index 355af64e..f6692b84 100644 --- a/src/mariadb/protocol/packets/com_reset_conn.rs +++ b/src/mariadb/protocol/packets/com_reset_conn.rs @@ -5,7 +5,12 @@ pub struct ComResetConnection(); impl Serialize for ComResetConnection { fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> { - encoder.encode_int_1(TextProtocol::ComResetConnection.into()); + encoder.alloc_packet_header(); + encoder.seq_no(0); + + encoder.encode_int_u8(TextProtocol::ComResetConnection.into()); + + encoder.encode_length(); Ok(()) } diff --git a/src/mariadb/protocol/packets/com_set_option.rs b/src/mariadb/protocol/packets/com_set_option.rs index 0c4d4ecd..cc23ca81 100644 --- a/src/mariadb/protocol/packets/com_set_option.rs +++ b/src/mariadb/protocol/packets/com_set_option.rs @@ -13,8 +13,13 @@ pub struct ComSetOption { impl Serialize for ComSetOption { fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> { - encoder.encode_int_1(TextProtocol::ComSetOption.into()); - encoder.encode_int_2(self.option.into()); + encoder.alloc_packet_header(); + encoder.seq_no(0); + + encoder.encode_int_u8(TextProtocol::ComSetOption.into()); + encoder.encode_int_u16(self.option.into()); + + encoder.encode_length(); Ok(()) } diff --git a/src/mariadb/protocol/packets/com_shutdown.rs b/src/mariadb/protocol/packets/com_shutdown.rs index ac6bb3e8..8ef0e203 100644 --- a/src/mariadb/protocol/packets/com_shutdown.rs +++ b/src/mariadb/protocol/packets/com_shutdown.rs @@ -12,8 +12,13 @@ pub struct ComShutdown { impl Serialize for ComShutdown { fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> { - encoder.encode_int_1(TextProtocol::ComShutdown.into()); - encoder.encode_int_1(self.option.into()); + encoder.alloc_packet_header(); + encoder.seq_no(0); + + encoder.encode_int_u8(TextProtocol::ComShutdown.into()); + encoder.encode_int_u8(self.option.into()); + + encoder.encode_length(); Ok(()) } diff --git a/src/mariadb/protocol/packets/com_sleep.rs b/src/mariadb/protocol/packets/com_sleep.rs index 284feb68..30326816 100644 --- a/src/mariadb/protocol/packets/com_sleep.rs +++ b/src/mariadb/protocol/packets/com_sleep.rs @@ -5,7 +5,12 @@ pub struct ComSleep(); impl Serialize for ComSleep { fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> { - encoder.encode_int_1(TextProtocol::ComSleep.into()); + encoder.alloc_packet_header(); + encoder.seq_no(0); + + encoder.encode_int_u8(TextProtocol::ComSleep.into()); + + encoder.encode_length(); Ok(()) } diff --git a/src/mariadb/protocol/packets/com_statistics.rs b/src/mariadb/protocol/packets/com_statistics.rs index 009ccbee..722a8240 100644 --- a/src/mariadb/protocol/packets/com_statistics.rs +++ b/src/mariadb/protocol/packets/com_statistics.rs @@ -5,7 +5,12 @@ pub struct ComStatistics(); impl Serialize for ComStatistics { fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> { - encoder.encode_int_1(TextProtocol::ComStatistics.into()); + encoder.alloc_packet_header(); + encoder.seq_no(0); + + encoder.encode_int_u8(TextProtocol::ComStatistics.into()); + + encoder.encode_length(); Ok(()) } diff --git a/src/mariadb/protocol/packets/com_stmt_close.rs b/src/mariadb/protocol/packets/com_stmt_close.rs new file mode 100644 index 00000000..b816beee --- /dev/null +++ b/src/mariadb/protocol/packets/com_stmt_close.rs @@ -0,0 +1,20 @@ +use std::convert::TryInto; + +#[derive(Debug)] +pub struct ComStmtClose { + stmt_id: i32 +} + +impl crate::mariadb::Serialize for ComStmtClose { + fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), failure::Error> { + encoder.alloc_packet_header(); + encoder.seq_no(0); + + encoder.encode_int_u8(crate::mariadb::BinaryProtocol::ComStmtClose.into()); + encoder.encode_int_i32(self.stmt_id); + + encoder.encode_length(); + + Ok(()) + } +} diff --git a/src/mariadb/protocol/packets/com_stmt_exec.rs b/src/mariadb/protocol/packets/com_stmt_exec.rs new file mode 100644 index 00000000..797aee95 --- /dev/null +++ b/src/mariadb/protocol/packets/com_stmt_exec.rs @@ -0,0 +1,68 @@ +use crate::mariadb::{StmtExecFlag, ColumnDefPacket, FieldDetailFlag}; +use bytes::Bytes; + +#[derive(Debug)] +pub struct ComStmtExec { + pub stmt_id: i32, + pub flags: StmtExecFlag, + pub params: Option>>, + pub param_defs: Option>, +} + +impl crate::mariadb::Serialize for ComStmtExec { + fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), failure::Error> { + encoder.alloc_packet_header(); + encoder.seq_no(0); + + encoder.encode_int_u8(crate::mariadb::BinaryProtocol::ComStmtExec.into()); + encoder.encode_int_i32(self.stmt_id); + encoder.encode_int_u8(self.flags as u8); + encoder.encode_int_u8(0); + + if let Some(params) = &self.params { + let null_bitmap_size = (params.len() + 7) / 8; + let mut shift_amount = 0u8; + let mut bitmap = vec![0u8]; + + // Generate NULL-bitmap from params + for param in params { + if param.is_none() { + bitmap.push(bitmap.last().unwrap() & (1 << shift_amount)); + } + + shift_amount = (shift_amount + 1) % 8; + + if shift_amount % 8 == 0 { + bitmap.push(0u8); + } + } + + // Do not send the param types + encoder.encode_int_u8(if self.param_defs.is_some() { + 1u8 + } else { + 0u8 + }); + + if let Some(params) = &self.param_defs { + for param in params { + encoder.encode_int_u8(param.field_type as u8); + encoder.encode_int_u8(if (param.field_details & FieldDetailFlag::UNSIGNED).is_empty() { + 1u8 + } else { + 0u8 + }); + } + } + + // Encode params + for param in params { + + } + } + + encoder.encode_length(); + + Ok(()) + } +} diff --git a/src/mariadb/protocol/packets/com_stmt_prepare.rs b/src/mariadb/protocol/packets/com_stmt_prepare.rs new file mode 100644 index 00000000..56297830 --- /dev/null +++ b/src/mariadb/protocol/packets/com_stmt_prepare.rs @@ -0,0 +1,20 @@ +use bytes::Bytes; + +#[derive(Debug)] +pub struct ComStmtPrepare { + statement: Bytes +} + +impl crate::mariadb::Serialize for ComStmtPrepare { + fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), failure::Error> { + encoder.alloc_packet_header(); + encoder.seq_no(0); + + encoder.encode_int_u8(crate::mariadb::BinaryProtocol::ComStmtPrepare.into()); + encoder.encode_string_eof(&self.statement); + + encoder.encode_length(); + + Ok(()) + } +} diff --git a/src/mariadb/protocol/packets/com_stmt_prepare_ok.rs b/src/mariadb/protocol/packets/com_stmt_prepare_ok.rs new file mode 100644 index 00000000..f74a4b57 --- /dev/null +++ b/src/mariadb/protocol/packets/com_stmt_prepare_ok.rs @@ -0,0 +1,34 @@ +use std::convert::TryFrom; + +#[derive(Debug)] +pub struct ComStmtPrepareOk { + pub stmt_id: i32, + pub columns: i16, + pub params: i16, + pub warnings: i16, +} + +impl crate::mariadb::Deserialize for ComStmtPrepareOk { + fn deserialize(ctx: &mut crate::mariadb::DeContext) -> Result { + let decoder = &mut ctx.decoder; + let length = decoder.decode_length()?; + let seq_no = decoder.decode_int_u8(); + + let stmt_id = decoder.decode_int_i32(); + + let columns = decoder.decode_int_i16(); + let params = decoder.decode_int_i16(); + + // Skip 1 unused byte; + decoder.skip_bytes(1); + + let warnings = decoder.decode_int_i16(); + + Ok(ComStmtPrepareOk { + stmt_id, + columns, + params, + warnings + }) + } +} diff --git a/src/mariadb/protocol/packets/com_stmt_prepare_resp.rs b/src/mariadb/protocol/packets/com_stmt_prepare_resp.rs new file mode 100644 index 00000000..b02bbb5f --- /dev/null +++ b/src/mariadb/protocol/packets/com_stmt_prepare_resp.rs @@ -0,0 +1,60 @@ +use crate::mariadb::{ComStmtPrepareOk, ColumnDefPacket, Capabilities, EofPacket}; + +#[derive(Debug)] +pub struct ComStmtPrepareResp { + pub ok: ComStmtPrepareOk, + pub param_defs: Option>, + pub res_columns: Option>, +} + +//int<1> 0x00 COM_STMT_PREPARE_OK header +//int<4> statement id +//int<2> number of columns in the returned result set (or 0 if statement does not return result set) +//int<2> number of prepared statement parameters ('?' placeholders) +//string<1> -not used- +//int<2> number of warnings + +impl crate::mariadb::Deserialize for ComStmtPrepareResp { + fn deserialize(ctx: &mut crate::mariadb::DeContext) -> Result { + let decoder = &mut ctx.decoder; + let length = decoder.decode_length()?; + + let ok = ComStmtPrepareOk::deserialize(ctx)?; + + let param_defs = if ok.params > 0 { + let param_defs = (0..ok.params).map(|_| ColumnDefPacket::deserialize(ctx)) + .filter(Result::is_ok) + .map(Result::unwrap) + .collect::>(); + + if (ctx.conn.capabilities & Capabilities::CLIENT_DEPRECATE_EOF).is_empty() { + EofPacket::deserialize(ctx)?; + } + + Some(param_defs) + } else { + None + }; + + let res_columns = if ok.columns > 0 { + let param_defs = (0..ok.columns).map(|_| ColumnDefPacket::deserialize(ctx)) + .filter(Result::is_ok) + .map(Result::unwrap) + .collect::>(); + + if (ctx.conn.capabilities & Capabilities::CLIENT_DEPRECATE_EOF).is_empty() { + EofPacket::deserialize(ctx)?; + } + + Some(param_defs) + } else { + None + }; + + Ok(ComStmtPrepareResp { + ok, + param_defs, + res_columns + }) + } +} diff --git a/src/mariadb/protocol/packets/eof.rs b/src/mariadb/protocol/packets/eof.rs index 2a7330e6..c43ad6bf 100644 --- a/src/mariadb/protocol/packets/eof.rs +++ b/src/mariadb/protocol/packets/eof.rs @@ -18,16 +18,16 @@ impl Deserialize for EofPacket { let decoder = &mut ctx.decoder; let length = decoder.decode_length()?; - let seq_no = decoder.decode_int_1(); + let seq_no = decoder.decode_int_u8(); - let packet_header = decoder.decode_int_1(); + let packet_header = decoder.decode_int_u8(); if packet_header != 0xFE { panic!("Packet header is not 0xFE for ErrPacket"); } - let warning_count = decoder.decode_int_2(); - let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_2_unsigned()); + let warning_count = decoder.decode_int_i16(); + let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_u16()); Ok(EofPacket { length, seq_no, warning_count, status }) } diff --git a/src/mariadb/protocol/packets/err.rs b/src/mariadb/protocol/packets/err.rs index 24eab35c..138305eb 100644 --- a/src/mariadb/protocol/packets/err.rs +++ b/src/mariadb/protocol/packets/err.rs @@ -24,14 +24,14 @@ impl Deserialize for ErrPacket { fn deserialize(ctx: &mut DeContext) -> Result { let decoder = &mut ctx.decoder; let length = decoder.decode_length()?; - let seq_no = decoder.decode_int_1(); + let seq_no = decoder.decode_int_u8(); - let packet_header = decoder.decode_int_1(); + let packet_header = decoder.decode_int_u8(); if packet_header != 0xFF { panic!("Packet header is not 0xFF for ErrPacket"); } - let error_code = ErrorCode::try_from(decoder.decode_int_2())?; + let error_code = ErrorCode::try_from(decoder.decode_int_i16())?; let mut stage = None; let mut max_stage = None; @@ -44,9 +44,9 @@ impl Deserialize for ErrPacket { // Progress Reporting if error_code as u16 == 0xFFFF { - stage = Some(decoder.decode_int_1()); - max_stage = Some(decoder.decode_int_1()); - progress = Some(decoder.decode_int_3()); + stage = Some(decoder.decode_int_u8()); + max_stage = Some(decoder.decode_int_u8()); + progress = Some(decoder.decode_int_i24()); progress_info = Some(decoder.decode_string_lenenc()); } else { if decoder.buf[decoder.index] == b'#' { diff --git a/src/mariadb/protocol/packets/handshake_response.rs b/src/mariadb/protocol/packets/handshake_response.rs index 5fcbe8c8..96db1851 100644 --- a/src/mariadb/protocol/packets/handshake_response.rs +++ b/src/mariadb/protocol/packets/handshake_response.rs @@ -20,9 +20,12 @@ pub struct HandshakeResponsePacket { impl Serialize for HandshakeResponsePacket { fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> { - encoder.encode_int_4(self.capabilities.bits() as u32); - encoder.encode_int_4(self.max_packet_size); - encoder.encode_int_1(self.collation); + encoder.alloc_packet_header(); + encoder.seq_no(0); + + encoder.encode_int_u32(self.capabilities.bits() as u32); + encoder.encode_int_u32(self.max_packet_size); + encoder.encode_int_u8(self.collation); // Filler encoder.encode_byte_fix(&Bytes::from_static(&[0u8; 19]), 19); @@ -31,7 +34,7 @@ impl Serialize for HandshakeResponsePacket { && !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty() { if let Some(capabilities) = self.extended_capabilities { - encoder.encode_int_4(capabilities.bits() as u32); + encoder.encode_int_u32(capabilities.bits() as u32); } } else { encoder.encode_byte_fix(&Bytes::from_static(&[0u8; 4]), 4); @@ -45,12 +48,12 @@ impl Serialize for HandshakeResponsePacket { } } else if !(ctx.capabilities & Capabilities::SECURE_CONNECTION).is_empty() { if let Some(auth_response) = &self.auth_response { - encoder.encode_int_1(self.auth_response_len.unwrap()); + encoder.encode_int_u8(self.auth_response_len.unwrap()); encoder .encode_string_fix(&auth_response, self.auth_response_len.unwrap() as usize); } } else { - encoder.encode_int_1(0); + encoder.encode_int_u8(0); } if !(ctx.capabilities & Capabilities::CONNECT_WITH_DB).is_empty() { @@ -80,6 +83,8 @@ impl Serialize for HandshakeResponsePacket { } } + encoder.encode_length(); + Ok(()) } } diff --git a/src/mariadb/protocol/packets/initial.rs b/src/mariadb/protocol/packets/initial.rs index 4d0b5bd6..95aade5a 100644 --- a/src/mariadb/protocol/packets/initial.rs +++ b/src/mariadb/protocol/packets/initial.rs @@ -23,31 +23,31 @@ impl Deserialize for InitialHandshakePacket { fn deserialize(ctx: &mut DeContext) -> Result { let decoder = &mut ctx.decoder; let length = decoder.decode_length()?; - let seq_no = decoder.decode_int_1(); + let seq_no = decoder.decode_int_u8(); if seq_no != 0 { return Err(err_msg("Sequence Number of Initial Handshake Packet is not 0")); } - let protocol_version = decoder.decode_int_1(); + let protocol_version = decoder.decode_int_u8(); let server_version = decoder.decode_string_null()?; - let connection_id = decoder.decode_int_4(); + let connection_id = decoder.decode_int_i32(); let auth_seed = decoder.decode_string_fix(8); // Skip reserved byte decoder.skip_bytes(1); - let mut capabilities = Capabilities::from_bits_truncate(decoder.decode_int_2_unsigned().into()); + let mut capabilities = Capabilities::from_bits_truncate(decoder.decode_int_u16().into()); - let collation = decoder.decode_int_1(); - let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_2_unsigned().into()); + let collation = decoder.decode_int_u8(); + let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_u16().into()); capabilities |= - Capabilities::from_bits_truncate(((decoder.decode_int_2() as u32) << 16).into()); + Capabilities::from_bits_truncate(((decoder.decode_int_i16() as u32) << 16).into()); let mut plugin_data_length = 0; if !(capabilities & Capabilities::PLUGIN_AUTH).is_empty() { - plugin_data_length = decoder.decode_int_1(); + plugin_data_length = decoder.decode_int_u8(); } else { // Skip reserve byte decoder.skip_bytes(1); @@ -58,7 +58,7 @@ impl Deserialize for InitialHandshakePacket { if (capabilities & Capabilities::CLIENT_MYSQL).is_empty() { capabilities |= - Capabilities::from_bits_truncate(((decoder.decode_int_4_unsigned() as u128) << 32).into()); + Capabilities::from_bits_truncate(((decoder.decode_int_u32() as u128) << 32).into()); } else { // Skip filler decoder.skip_bytes(4); diff --git a/src/mariadb/protocol/packets/mod.rs b/src/mariadb/protocol/packets/mod.rs index e1998923..34ca06c9 100644 --- a/src/mariadb/protocol/packets/mod.rs +++ b/src/mariadb/protocol/packets/mod.rs @@ -21,6 +21,11 @@ pub mod packet_header; pub mod result_set; pub mod ssl_request; pub mod result_row; +pub mod com_stmt_prepare; +pub mod com_stmt_prepare_ok; +pub mod com_stmt_prepare_resp; +pub mod com_stmt_close; +pub mod com_stmt_exec; pub use auth_switch_request::AuthenticationSwitchRequestPacket; pub use column::ColumnPacket; @@ -47,3 +52,8 @@ pub use packet_header::PacketHeader; pub use result_set::ResultSet; pub use result_row::ResultRow; pub use ssl_request::SSLRequestPacket; +pub use com_stmt_prepare::ComStmtPrepare; +pub use com_stmt_prepare_ok::ComStmtPrepareOk; +pub use com_stmt_prepare_resp::ComStmtPrepareResp; +pub use com_stmt_close::ComStmtClose; +pub use com_stmt_exec::ComStmtExec; diff --git a/src/mariadb/protocol/packets/ok.rs b/src/mariadb/protocol/packets/ok.rs index 26341397..cdd47a66 100644 --- a/src/mariadb/protocol/packets/ok.rs +++ b/src/mariadb/protocol/packets/ok.rs @@ -23,21 +23,21 @@ impl Deserialize for OkPacket { // Packet header let length = decoder.decode_length()?; - let seq_no = decoder.decode_int_1(); + let seq_no = decoder.decode_int_u8(); // Used later for the byte_eof decoding let index = decoder.index; // Packet body - let packet_header = decoder.decode_int_1(); + let packet_header = decoder.decode_int_u8(); if packet_header != 0 && packet_header != 0xFE { return Err(err_msg("Packet header is not 0 or 0xFE for OkPacket")); } let affected_rows = decoder.decode_int_lenenc(); let last_insert_id = decoder.decode_int_lenenc(); - let server_status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_2_unsigned().into()); - let warning_count = decoder.decode_int_2(); + let server_status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_u16().into()); + let warning_count = decoder.decode_int_i16(); // Assuming CLIENT_SESSION_TRACK is unsupported let session_state_info = None; diff --git a/src/mariadb/protocol/packets/result_row.rs b/src/mariadb/protocol/packets/result_row.rs index 07d545cc..cccfdce2 100644 --- a/src/mariadb/protocol/packets/result_row.rs +++ b/src/mariadb/protocol/packets/result_row.rs @@ -17,7 +17,7 @@ impl Deserialize for ResultRow { let decoder = &mut ctx.decoder; let length = decoder.decode_length()?; - let seq_no = decoder.decode_int_1(); + let seq_no = decoder.decode_int_u8(); let row = if let Some(columns) = ctx.columns { (0..columns).map(|_| decoder.decode_string_lenenc()).collect::>() diff --git a/src/mariadb/protocol/packets/ssl_request.rs b/src/mariadb/protocol/packets/ssl_request.rs index 09e16feb..13eea486 100644 --- a/src/mariadb/protocol/packets/ssl_request.rs +++ b/src/mariadb/protocol/packets/ssl_request.rs @@ -13,9 +13,12 @@ pub struct SSLRequestPacket { impl Serialize for SSLRequestPacket { fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> { - encoder.encode_int_4(self.capabilities.bits() as u32); - encoder.encode_int_4(self.max_packet_size); - encoder.encode_int_1(self.collation); + encoder.alloc_packet_header(); + encoder.seq_no(0); + + encoder.encode_int_u32(self.capabilities.bits() as u32); + encoder.encode_int_u32(self.max_packet_size); + encoder.encode_int_u8(self.collation); // Filler encoder.encode_byte_fix(&Bytes::from_static(&[0u8; 19]), 19); @@ -24,12 +27,14 @@ impl Serialize for SSLRequestPacket { && !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty() { if let Some(capabilities) = self.extended_capabilities { - encoder.encode_int_4(capabilities.bits() as u32); + encoder.encode_int_u32(capabilities.bits() as u32); } } else { encoder.encode_byte_fix(&Bytes::from_static(&[0u8; 4]), 4); } + encoder.encode_length(); + Ok(()) } } diff --git a/src/mariadb/protocol/types.rs b/src/mariadb/protocol/types.rs index 6cb4fd8d..896dd84e 100644 --- a/src/mariadb/protocol/types.rs +++ b/src/mariadb/protocol/types.rs @@ -111,6 +111,21 @@ pub enum FieldType { MysqlTypeGeometry = 255, } +#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive)] +#[TryFromPrimitiveType = "u8"] +pub enum StmtExecFlag { + NoCursor = 0, + ReadOnly = 1, + CursorForUpdate = 2, + ScrollableCursor = 3, +} + +#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive)] +#[TryFromPrimitiveType = "u8"] +pub enum ParamFlag { + Unsigned = 128, +} + impl Default for Capabilities { fn default() -> Self { Capabilities::CLIENT_MYSQL @@ -135,6 +150,18 @@ impl Default for FieldType { } } +impl Default for StmtExecFlag { + fn default() -> Self { + StmtExecFlag::NoCursor + } +} + +impl Default for ParamFlag { + fn default() -> Self { + ParamFlag::Unsigned + } +} + #[cfg(test)] mod test { use super::super::{decode::Decoder, types::Capabilities}; @@ -144,6 +171,6 @@ mod test { fn it_decodes_capabilities() { let buf = Bytes::from(b"\xfe\xf7".to_vec()); let mut decoder = Decoder::new(&buf); - Capabilities::from_bits_truncate(decoder.decode_int_2_unsigned().into()); + Capabilities::from_bits_truncate(decoder.decode_int_u16().into()); } }