diff --git a/mason-mariadb/src/protocol/decode.rs b/mason-mariadb/src/protocol/decode.rs index 16466eaf..19acd61c 100644 --- a/mason-mariadb/src/protocol/decode.rs +++ b/mason-mariadb/src/protocol/decode.rs @@ -144,9 +144,13 @@ impl<'a> Decoder<'a> { } #[inline] - pub fn decode_string_eof(&mut self, length: usize) -> Bytes { - let value = self.buf.slice(self.index, if length >= self.index { - length + pub fn decode_string_eof(&mut self, length: Option) -> Bytes { + let value = self.buf.slice(self.index, if let Some(len) = length { + if len >= self.index { + len + } else { + self.buf.len() + } } else { self.buf.len() }); @@ -181,9 +185,13 @@ impl<'a> Decoder<'a> { } #[inline] - pub fn decode_byte_eof(&mut self, length: usize) -> Bytes { - let value = self.buf.slice(self.index, if length >= self.index { - length + pub fn decode_byte_eof(&mut self, length: Option) -> Bytes { + let value = self.buf.slice(self.index, if let Some(len) = length { + if len >= self.index { + len + } else { + self.buf.len() + } } else { self.buf.len() }); @@ -199,7 +207,7 @@ mod tests { use super::*; -// [X] it_decodes_int_lenenc + // [X] it_decodes_int_lenenc // [X] it_decodes_int_8 // [X] it_decodes_int_4 // [X] it_decodes_int_3 @@ -338,7 +346,7 @@ mod tests { fn it_decodes_string_eof() { let buf = Bytes::from(b"\x01".to_vec()); let mut decoder = Decoder::new(&buf); - let string: Bytes = decoder.decode_string_eof(0); + let string: Bytes = decoder.decode_string_eof(None); assert_eq!(string[0], b'\x01'); assert_eq!(string.len(), 1); @@ -375,9 +383,9 @@ mod tests { fn it_decodes_byte_eof() { let buf = Bytes::from(b"\x01".to_vec()); let mut decoder = Decoder::new(&buf); - let string: Bytes = decoder.decode_byte_eof(0); + let string: Bytes = decoder.decode_byte_eof(None); - assert_eq!(string[0], b'\x01'); + assert_eq!(&string[..], b"\x01"); assert_eq!(string.len(), 1); assert_eq!(decoder.index, 1); } diff --git a/mason-mariadb/src/protocol/packets/column.rs b/mason-mariadb/src/protocol/packets/column.rs index 06901c91..6bb40d0c 100644 --- a/mason-mariadb/src/protocol/packets/column.rs +++ b/mason-mariadb/src/protocol/packets/column.rs @@ -45,9 +45,14 @@ mod test { }) .await?; + #[rustfmt::skip] let buf = __bytes_builder!( - // int tag code: None - 0xFB_u8 + // int<3> length + 0u8, 0u8, 0u8, + // int<1> seq_no + 0u8, + // int tag code: None + 0xFB_u8 ); let message = ColumnPacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?; @@ -68,10 +73,16 @@ mod test { }) .await?; + #[rustfmt::skip] let buf = __bytes_builder!( - // int tag code: Some(3 bytes) - 0xFD_u8, // value: 3 bytes - 0x01_u8, 0x01_u8, 0x01_u8 + // int<3> length + 0u8, 0u8, 0u8, + // int<1> seq_no + 0u8, + // int tag code: Some(3 bytes) + 0xFD_u8, + // value: 3 bytes + 0x01_u8, 0x01_u8, 0x01_u8 ); let message = ColumnPacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?; @@ -94,10 +105,14 @@ mod test { #[rustfmt::skip] let buf = __bytes_builder!( - // int tag code: Some(3 bytes) - 0xFC_u8, - // value: 2 bytes - 0x01_u8, 0x01_u8 + // int<3> length + 0u8, 0u8, 0u8, + // int<1> seq_no + 0u8, + // int tag code: Some(3 bytes) + 0xFC_u8, + // value: 2 bytes + 0x01_u8, 0x01_u8 ); let message = ColumnPacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?; diff --git a/mason-mariadb/src/protocol/packets/eof.rs b/mason-mariadb/src/protocol/packets/eof.rs index 08de8ce9..f8c9c813 100644 --- a/mason-mariadb/src/protocol/packets/eof.rs +++ b/mason-mariadb/src/protocol/packets/eof.rs @@ -25,9 +25,9 @@ impl Deserialize for EofPacket { let packet_header = decoder.decode_int_1(); -// if packet_header != 0xFE { -// panic!("Packet header is not 0xFE for ErrPacket"); -// } + if packet_header != 0xFE { + panic!("Packet header is not 0xFE for ErrPacket"); + } let warning_count = decoder.decode_int_2(); let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_2()); @@ -56,16 +56,16 @@ mod test { #[rustfmt::skip] let buf = __bytes_builder!( - // int<3> length - 1u8, 0u8, 0u8, - // int<1> seq_no - 1u8, - // int<1> 0xfe : EOF header - 0xFE_u8, - // int<2> warning count - 0u8, 0u8, - // int<2> server status - 1u8, 1u8 + // int<3> length + 1u8, 0u8, 0u8, + // int<1> seq_no + 1u8, + // int<1> 0xfe : EOF header + 0xFE_u8, + // int<2> warning count + 0u8, 0u8, + // int<2> server status + 1u8, 1u8 ); let buf = Bytes::from_static(b"\x01\0\0\x01\xFE\x00\x00\x01\x00"); diff --git a/mason-mariadb/src/protocol/packets/err.rs b/mason-mariadb/src/protocol/packets/err.rs index 1e61b03e..b6bfce4a 100644 --- a/mason-mariadb/src/protocol/packets/err.rs +++ b/mason-mariadb/src/protocol/packets/err.rs @@ -54,9 +54,9 @@ impl Deserialize for ErrPacket { if decoder.buf[decoder.index] == b'#' { sql_state_marker = Some(decoder.decode_string_fix(1)); sql_state = Some(decoder.decode_string_fix(5)); - error_message = Some(decoder.decode_string_eof(length as usize)); + error_message = Some(decoder.decode_string_eof(Some(length as usize))); } else { - error_message = Some(decoder.decode_string_eof(length as usize)); + error_message = Some(decoder.decode_string_eof(Some(length as usize))); } } diff --git a/mason-mariadb/src/protocol/packets/ok.rs b/mason-mariadb/src/protocol/packets/ok.rs index eeb213a6..224d1bad 100644 --- a/mason-mariadb/src/protocol/packets/ok.rs +++ b/mason-mariadb/src/protocol/packets/ok.rs @@ -45,7 +45,7 @@ impl Deserialize for OkPacket { let session_state_info = None; let value = None; - let info = decoder.decode_byte_eof(index + length as usize); + let info = decoder.decode_byte_eof(Some(index + length as usize)); Ok(OkPacket { length, @@ -82,30 +82,30 @@ mod test { #[rustfmt::skip] let buf = __bytes_builder!( - // int<3> length - 0u8, 0u8, 0u8, - // // int<1> seq_no - 1u8, - // 0x00 : OK_Packet header or (0xFE if CLIENT_DEPRECATE_EOF is set) - 0u8, - // int affected rows - 0xFB_u8, - // int last insert id - 0xFB_u8, - // int<2> server status - 1u8, 1u8, - // int<2> warning count - 0u8, 0u8, - // if session_tracking_supported (see CLIENT_SESSION_TRACK) { - // string info - // if (status flags & SERVER_SESSION_STATE_CHANGED) { - // string session state info - // string value of variable - // } - // } else { - // string info - b"info" - // } + // int<3> length + 0u8, 0u8, 0u8, + // // int<1> seq_no + 1u8, + // 0x00 : OK_Packet header or (0xFE if CLIENT_DEPRECATE_EOF is set) + 0u8, + // int affected rows + 0xFB_u8, + // int last insert id + 0xFB_u8, + // int<2> server status + 1u8, 1u8, + // int<2> warning count + 0u8, 0u8, + // if session_tracking_supported (see CLIENT_SESSION_TRACK) { + // string info + // if (status flags & SERVER_SESSION_STATE_CHANGED) { + // string session state info + // string value of variable + // } + // } else { + // string info + b"info" + // } ); let message = OkPacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?; diff --git a/mason-mariadb/src/protocol/packets/result_set.rs b/mason-mariadb/src/protocol/packets/result_set.rs index 7f8acf01..0a7c74ab 100644 --- a/mason-mariadb/src/protocol/packets/result_set.rs +++ b/mason-mariadb/src/protocol/packets/result_set.rs @@ -98,7 +98,7 @@ mod test { // TODO: Use byte string as input for test; this is a valid return from a mariadb. #[rustfmt::skip] - let buf = __bytes_builder!( + let buf = __bytes_builder!( // ------------------- // // Column Count packet // // ------------------- //