From 05024dbd0a18cd91db7b69f996c632d091814f0b Mon Sep 17 00:00:00 2001 From: Daniel Akhterov Date: Fri, 26 Jul 2019 13:28:12 -0700 Subject: [PATCH] Fix result set decoding OkPacket instead of EofPacket at the end of the transaction --- mason-mariadb/src/protocol/decode.rs | 20 ++- mason-mariadb/src/protocol/packets/err.rs | 4 +- mason-mariadb/src/protocol/packets/ok.rs | 10 +- .../src/protocol/packets/result_set.rs | 133 +++++++++++++----- 4 files changed, 121 insertions(+), 46 deletions(-) diff --git a/mason-mariadb/src/protocol/decode.rs b/mason-mariadb/src/protocol/decode.rs index 8352d1ed2..16466eaf7 100644 --- a/mason-mariadb/src/protocol/decode.rs +++ b/mason-mariadb/src/protocol/decode.rs @@ -144,8 +144,12 @@ impl<'a> Decoder<'a> { } #[inline] - pub fn decode_string_eof(&mut self) -> Bytes { - let value = self.buf.slice(self.index, self.buf.len()); + pub fn decode_string_eof(&mut self, length: usize) -> Bytes { + let value = self.buf.slice(self.index, if length >= self.index { + length + } else { + self.buf.len() + }); self.index = self.buf.len(); value } @@ -177,8 +181,12 @@ impl<'a> Decoder<'a> { } #[inline] - pub fn decode_byte_eof(&mut self) -> Bytes { - let value = self.buf.slice(self.index, self.buf.len()); + pub fn decode_byte_eof(&mut self, length: usize) -> Bytes { + let value = self.buf.slice(self.index, if length >= self.index { + length + } else { + self.buf.len() + }); self.index = self.buf.len(); value } @@ -330,7 +338,7 @@ mod tests { fn it_decodes_string_eof() { let buf = Bytes::from(b"\x01".to_vec()); let mut decoder = Decoder::new(&buf); - let string: Bytes = decoder.decode_string_eof(); + let string: Bytes = decoder.decode_string_eof(0); assert_eq!(string[0], b'\x01'); assert_eq!(string.len(), 1); @@ -367,7 +375,7 @@ mod tests { fn it_decodes_byte_eof() { let buf = Bytes::from(b"\x01".to_vec()); let mut decoder = Decoder::new(&buf); - let string: Bytes = decoder.decode_byte_eof(); + let string: Bytes = decoder.decode_byte_eof(0); assert_eq!(string[0], b'\x01'); assert_eq!(string.len(), 1); diff --git a/mason-mariadb/src/protocol/packets/err.rs b/mason-mariadb/src/protocol/packets/err.rs index 4fe12b05d..1e61b03e5 100644 --- a/mason-mariadb/src/protocol/packets/err.rs +++ b/mason-mariadb/src/protocol/packets/err.rs @@ -54,9 +54,9 @@ impl Deserialize for ErrPacket { if decoder.buf[decoder.index] == b'#' { sql_state_marker = Some(decoder.decode_string_fix(1)); sql_state = Some(decoder.decode_string_fix(5)); - error_message = Some(decoder.decode_string_eof()); + error_message = Some(decoder.decode_string_eof(length as usize)); } else { - error_message = Some(decoder.decode_string_eof()); + error_message = Some(decoder.decode_string_eof(length as usize)); } } diff --git a/mason-mariadb/src/protocol/packets/ok.rs b/mason-mariadb/src/protocol/packets/ok.rs index 0558a8b68..eeb213a6f 100644 --- a/mason-mariadb/src/protocol/packets/ok.rs +++ b/mason-mariadb/src/protocol/packets/ok.rs @@ -22,10 +22,14 @@ pub struct OkPacket { impl Deserialize for OkPacket { fn deserialize(ctx: &mut DeContext) -> Result { let decoder = &mut ctx.decoder; + // Packet header let length = decoder.decode_length()?; let seq_no = decoder.decode_int_1(); + // Used later for the byte_eof decoding + let index = decoder.index; + // Packet body let packet_header = decoder.decode_int_1(); if packet_header != 0 && packet_header != 0xFE { @@ -41,7 +45,7 @@ impl Deserialize for OkPacket { let session_state_info = None; let value = None; - let info = decoder.decode_byte_eof(); + let info = decoder.decode_byte_eof(index + length as usize); Ok(OkPacket { length, @@ -78,9 +82,9 @@ mod test { #[rustfmt::skip] let buf = __bytes_builder!( - // length + // int<3> length 0u8, 0u8, 0u8, - // seq_no + // // int<1> seq_no 1u8, // 0x00 : OK_Packet header or (0xFE if CLIENT_DEPRECATE_EOF is set) 0u8, diff --git a/mason-mariadb/src/protocol/packets/result_set.rs b/mason-mariadb/src/protocol/packets/result_set.rs index 3502e1ac0..7f8acf015 100644 --- a/mason-mariadb/src/protocol/packets/result_set.rs +++ b/mason-mariadb/src/protocol/packets/result_set.rs @@ -63,9 +63,9 @@ impl Deserialize for ResultSet { } if (ctx.conn.capabilities & Capabilities::CLIENT_DEPRECATE_EOF).is_empty() { - EofPacket::deserialize(ctx)?; - } else { OkPacket::deserialize(ctx)?; + } else { + EofPacket::deserialize(ctx)?; } Ok(ResultSet { @@ -97,156 +97,219 @@ mod test { .await?; // TODO: Use byte string as input for test; this is a valid return from a mariadb. - // Reference: b"\x01\0\0\x01\x04(\0\0\x02\x03def\x04test\x05users\x05users\x02id\x02id\x0c\x08\0\x80\0\0\0\xfd\x03@\0\0\04\0\0\x03\x03def\x04test\x05users\x05users\x08username\x08username\x0c\x08\0\xff\xff\0\0\xfc\x11\x10\0\0\04\0\0\x04\x03def\x04test\x05users\x05users\x08password\x08password\x0c\x08\0\xff\xff\0\0\xfc\x11\x10\0\0\0<\0\0\x05\x03def\x04test\x05users\x05users\x0caccess_level\x0caccess_level\x0c\x08\0\x07\0\0\0\xfe\x01\x11\0\0\0\x05\0\0\x06\xfe\0\0\"\0>\0\0\x07$044d3f34-af65-11e9-a2e5-0242ac110003\x04josh\x0bpassword123\x07regular4\0\0\x08$d83dd1c4-ada9-11e9-96bc-0242ac110003\x06daniel\x01f\x05admin\x05\0\0\t\xfe\0\0\"\0\0 #[rustfmt::skip] let buf = __bytes_builder!( // ------------------- // // Column Count packet // // ------------------- // + // int<3> length 1u8, 0u8, 0u8, + // int<1> seq_no 1u8, + // int tag code or length 4u8, // ------------------------ // // Column Definition packet // // ------------------------ // + // int<3> length 40u8, 0u8, 0u8, + // int<1> seq_no 2u8, + // string catalog (always 'def') 3u8, b"def", + // string schema 4u8, b"test", + // string table alias 5u8, b"users", + // string table 5u8, b"users", + // string column alias 2u8, b"id", + // string column 2u8, b"id", + // int length of fixed fields (=0xC) 0x0C_u8, + // int<2> character set number 8u8, 0u8, + // int<4> max. column size 0x80_u8, 0u8, 0u8, 0u8, + // int<1> Field types 0xFD_u8, + // int<2> Field detail flag 3u8, 64u8, + // int<1> decimals 0u8, + // int<2> - unused - 0u8, 0u8, // ------------------------ // // Column Definition packet // // ------------------------ // + // int<3> length 52u8, 0u8, 0u8, + // int<1> seq_no 3u8, + // string catalog (always 'def') 3u8, b"def", + // string schema 4u8, b"test", + // string table alias 5u8, b"users", + // string table 5u8, b"users", + // string column alias 8u8, b"username", + // string column 8u8, b"username", + // int length of fixed fields (=0xC) 0x0C_u8, + // int<2> character set number 8u8, 0u8, + // int<4> max. column size 0xFF_u8, 0xFF_u8, 0u8, 0u8, + // int<1> Field types 0xFC_u8, + // int<2> Field detail flag 0x11_u8, 0x10_u8, + // int<1> decimals 0u8, + // int<2> - unused - 0u8, 0u8, // ------------------------ // // Column Definition packet // // ------------------------ // + // int<3> length 52u8, 0u8, 0u8, + // int<1> seq_no 4u8, + // string catalog (always 'def') 3u8, b"def", + // string schema 4u8, b"test", + // string table alias 5u8, b"users", + // string table 5u8, b"users", + // string column alias 8u8, b"password", + // string column 8u8, b"password", + // int length of fixed fields (=0xC) 0x0C_u8, + // int<2> character set number 8u8, 0u8, + // int<4> max. column size 0xFF_u8, 0xFF_u8, 0u8, 0u8, + // int<1> Field types 0xFC_u8, + // int<2> Field detail flag 0x11_u8, 0x10_u8, + // int<1> decimals 0u8, + // int<2> - unused - 0u8, 0u8, // ------------------------ // // Column Definition packet // // ------------------------ // + // int<3> length 60u8, 0u8, 0u8, + // int<1> seq_no 5u8, + // string catalog (always 'def') 3u8, b"def", + // string schema 4u8, b"test", + // string table alias 5u8, b"users", + // string table 5u8, b"users", - 12u8, b"access_level", - 12u8, b"access_level", - 12u8, + // string column alias + 0x0C_u8, b"access_level", + // string column + 0x0C_u8, b"access_level", + // int length of fixed fields (=0xC) + 0x0C_u8, + // int<2> character set number 8u8, 0u8, + // int<4> max. column size 7u8, 0u8, 0u8, 0u8, + // int<1> Field types 0xFE_u8, + // int<2> Field detail flag 1u8, 0x11_u8, + // int<1> decimals 0u8, + // int<2> - unused - 0u8, 0u8, - // ---------- // // EOF Packet // // ---------- // + // int<3> length 5u8, 0u8, 0u8, + // int<1> seq_no 6u8, + // int<1> 0xfe : EOF header 0xFE_u8, + // int<2> warning count 0u8, 0u8, + // int<2> server status 34u8, 0u8, // ----------------- // // Result Row Packet // // ----------------- // + // int<3> length 62u8, 0u8, 0u8, + // int<1> seq_no 7u8, - 32u8, b"044d3f34-af65-11e9-a2e5-0242ac110003", + // string column data + 36u8, b"044d3f34-af65-11e9-a2e5-0242ac110003", + // string column data 4u8, b"josh", - 11u8, b"password123", + // string column data + 0x0B_u8, b"password123", + // string column data 7u8, b"regular", // ----------------- // // Result Row Packet // // ----------------- // + // int<3> length 52u8, 0u8, 0u8, + // int<1> seq_no 8u8, - 32u8, b"d83dd1c4-ada9-11e9-96bc-0242ac110003", + // string column data + 36u8, b"d83dd1c4-ada9-11e9-96bc-0242ac110003", + // string column data 6u8, b"daniel", + // string column data 1u8, b"f", + // string column data 5u8, b"admin", + + // ------------- // + // OK/EOF Packet // + // ------------- // + // int<3> length 5u8, 0u8, 0u8, - 9u8, + // int<1> seq_no + 1u8, + // 0xFE: Required header for last packet of result set 0xFE_u8, + // int<2> warning count 0u8, 0u8, + // int<2> server status 34u8, 0u8 ); - conn.select_db("test").await?; - - conn.query("SELECT * FROM users").await?; - - let buf = conn.stream.next_bytes().await?; - println!("{:?}", buf); let mut ctx = DeContext::new(&mut conn.context, &buf); ResultSet::deserialize(&mut ctx)?; - - - // ------------------------ // - // Column Definition packet // - // ------------------------ // - - // ---------- // - // EOF Packet // - // ---------- // - - // ------------------- // - // N Result Row Packet // - // ------------------- // - - // ---------- // - // EOF Packet // - // ---------- // - Ok(()) } }