From 3ddf4508af027ec58d5a8fcc73175dc57a198053 Mon Sep 17 00:00:00 2001 From: Daniel Akhterov Date: Mon, 29 Jul 2019 19:16:46 -0700 Subject: [PATCH] Update decode to decode into signed types --- src/mariadb/protocol/decode.rs | 77 ++++++++++++++-------- src/mariadb/protocol/deserialize.rs | 2 +- src/mariadb/protocol/error_codes.rs | 2 +- src/mariadb/protocol/packets/column.rs | 2 +- src/mariadb/protocol/packets/column_def.rs | 8 +-- src/mariadb/protocol/packets/eof.rs | 4 +- src/mariadb/protocol/packets/err.rs | 2 +- src/mariadb/protocol/packets/initial.rs | 8 +-- src/mariadb/protocol/packets/ok.rs | 8 +-- src/mariadb/protocol/types.rs | 2 +- 10 files changed, 67 insertions(+), 48 deletions(-) diff --git a/src/mariadb/protocol/decode.rs b/src/mariadb/protocol/decode.rs index be8a630e..c8751bcd 100644 --- a/src/mariadb/protocol/decode.rs +++ b/src/mariadb/protocol/decode.rs @@ -22,7 +22,8 @@ impl<'a> Decoder<'a> { // Length is the first 3 bytes of the packet in little endian format #[inline] pub fn decode_length(&mut self) -> Result { - let length = self.decode_int_3(); + let length: u32 = (self.buf[self.index] as u32) + ((self.buf[self.index + 1] as u32) << 8) + ((self.buf[self.index + 2] as u32) << 16); + self.index += 3; if self.buf.len() - self.index < length as usize { return Err(err_msg("Lengths to do not match when decoding length")); @@ -69,69 +70,87 @@ impl<'a> Decoder<'a> { // The first byte of the int determines the length of the int. // If the first byte is // 0xFB then the int is "NULL" or None in Rust terms. - // 0xFC then the following 2 bytes are the int value u16. - // 0xFD then the following 3 bytes are the int value u24. - // 0xFE then the following 8 bytes are teh int value u64. + // 0xFC then the following 2 bytes are the int value i16. + // 0xFD then the following 3 bytes are the int value i24. + // 0xFE then the following 8 bytes are teh int value i64. // 0xFF then there was an error. // If the first byte is not in the previous list then that byte is the int value. #[inline] - pub fn decode_int_lenenc(&mut self) -> Option { + pub fn decode_int_lenenc(&mut self) -> Option { match self.buf[self.index] { 0xFB => { self.index += 1; None } 0xFC => { - let value = Some(LittleEndian::read_u16(&self.buf[self.index + 1..]) as usize); + let value = Some(LittleEndian::read_i16(&self.buf[self.index + 1..]) as i64); self.index += 3; value } 0xFD => { - let value = Some(LittleEndian::read_u24(&self.buf[self.index + 1..]) as usize); + let value = Some(LittleEndian::read_i24(&self.buf[self.index + 1..]) as i64); self.index += 4; value } 0xFE => { - let value = Some(LittleEndian::read_u64(&self.buf[self.index + 1..]) as usize); + let value = Some(LittleEndian::read_i64(&self.buf[self.index + 1..]) as i64); self.index += 9; value } 0xFF => panic!("int unprocessable first byte 0xFF"), _ => { - let value = Some(self.buf[self.index] as usize); + let value = Some(self.buf[self.index] as i64); self.index += 1; value } } } - // Decode an int<8> which is a u64 + // Decode an int<8> which is a i64 #[inline] - pub fn decode_int_8(&mut self) -> u64 { - let value = LittleEndian::read_u64(&self.buf[self.index..]); + pub fn decode_int_8(&mut self) -> i64 { + let value = LittleEndian::read_i64(&self.buf[self.index..]); self.index += 8; value } - // Decode an int<4> which is a u32 + // Decode an int<4> which is a i32 #[inline] - pub fn decode_int_4(&mut self) -> u32 { + pub fn decode_int_4(&mut self) -> i32 { + let value = LittleEndian::read_i32(&self.buf[self.index..]); + self.index += 4; + value + } + + // Decode an int<4> which is a i32 + // This is a helper method for decoding flags. + #[inline] + pub fn decode_int_4_unsigned(&mut self) -> u32 { let value = LittleEndian::read_u32(&self.buf[self.index..]); self.index += 4; value } - // Decode an int<3> which is a u24 + // Decode an int<3> which is a i24 #[inline] - pub fn decode_int_3(&mut self) -> u32 { - let value = LittleEndian::read_u24(&self.buf[self.index..]); + pub fn decode_int_3(&mut self) -> i32 { + let value = LittleEndian::read_i24(&self.buf[self.index..]); self.index += 3; value } - // Decode an int<2> which is a u16 + // Decode an int<2> which is a i16 #[inline] - pub fn decode_int_2(&mut self) -> u16 { + pub fn decode_int_2(&mut self) -> i16 { + let value = LittleEndian::read_i16(&self.buf[self.index..]); + self.index += 2; + value + } + + // Decode an int<2> as an u16 + // This is a helper method for decoding flags. + #[inline] + pub fn decode_int_2_unsigned(&mut self) -> u16 { let value = LittleEndian::read_u16(&self.buf[self.index..]); self.index += 2; value @@ -149,7 +168,7 @@ impl<'a> Decoder<'a> { // the length of the string, and the the following n bytes are the contents. #[inline] pub fn decode_string_lenenc(&mut self) -> Bytes { - let length = self.decode_int_lenenc().unwrap_or(0usize); + let length = self.decode_int_lenenc().unwrap_or(0); let value = self.buf.slice(self.index, self.index + length as usize); self.index = self.index + length as usize; value @@ -250,7 +269,7 @@ mod tests { fn it_decodes_int_lenenc_0x_fb() { let buf = __bytes_builder!(0xFB_u8); let mut decoder = Decoder::new(&buf); - let int: Option = decoder.decode_int_lenenc(); + let int: Option = decoder.decode_int_lenenc(); assert_eq!(int, None); assert_eq!(decoder.index, 1); @@ -260,7 +279,7 @@ mod tests { fn it_decodes_int_lenenc_0x_fc() { let buf =__bytes_builder!(0xFCu8, 1u8, 1u8); let mut decoder = Decoder::new(&buf); - let int: Option = decoder.decode_int_lenenc(); + let int: Option = decoder.decode_int_lenenc(); assert_eq!(int, Some(0x0101)); assert_eq!(decoder.index, 3); @@ -270,7 +289,7 @@ mod tests { fn it_decodes_int_lenenc_0x_fd() { let buf = __bytes_builder!(0xFDu8, 1u8, 1u8, 1u8); let mut decoder = Decoder::new(&buf); - let int: Option = decoder.decode_int_lenenc(); + let int: Option = decoder.decode_int_lenenc(); assert_eq!(int, Some(0x010101)); assert_eq!(decoder.index, 4); @@ -280,7 +299,7 @@ mod tests { fn it_decodes_int_lenenc_0x_fe() { let buf = __bytes_builder!(0xFE_u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8); let mut decoder = Decoder::new(&buf); - let int: Option = decoder.decode_int_lenenc(); + let int: Option = decoder.decode_int_lenenc(); assert_eq!(int, Some(0x0101010101010101)); assert_eq!(decoder.index, 9); @@ -290,7 +309,7 @@ mod tests { fn it_decodes_int_lenenc_0x_fa() { let buf = __bytes_builder!(0xFA_u8); let mut decoder = Decoder::new(&buf); - let int: Option = decoder.decode_int_lenenc(); + let int: Option = decoder.decode_int_lenenc(); assert_eq!(int, Some(0xFA)); assert_eq!(decoder.index, 1); @@ -300,7 +319,7 @@ mod tests { fn it_decodes_int_8() { let buf = __bytes_builder!(1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8); let mut decoder = Decoder::new(&buf); - let int: u64 = decoder.decode_int_8(); + let int: i64 = decoder.decode_int_8(); assert_eq!(int, 0x0101010101010101); assert_eq!(decoder.index, 8); @@ -310,7 +329,7 @@ mod tests { fn it_decodes_int_4() { let buf = __bytes_builder!(1u8, 1u8, 1u8, 1u8); let mut decoder = Decoder::new(&buf); - let int: u32 = decoder.decode_int_4(); + let int: i32 = decoder.decode_int_4(); assert_eq!(int, 0x01010101); assert_eq!(decoder.index, 4); @@ -320,7 +339,7 @@ mod tests { fn it_decodes_int_3() { let buf = __bytes_builder!(1u8, 1u8, 1u8); let mut decoder = Decoder::new(&buf); - let int: u32 = decoder.decode_int_3(); + let int: i32 = decoder.decode_int_3(); assert_eq!(int, 0x010101); assert_eq!(decoder.index, 3); @@ -330,7 +349,7 @@ mod tests { fn it_decodes_int_2() { let buf = __bytes_builder!(1u8, 1u8); let mut decoder = Decoder::new(&buf); - let int: u16 = decoder.decode_int_2(); + let int: i16 = decoder.decode_int_2(); assert_eq!(int, 0x0101); assert_eq!(decoder.index, 2); diff --git a/src/mariadb/protocol/deserialize.rs b/src/mariadb/protocol/deserialize.rs index 22a48177..1ddb0bd6 100644 --- a/src/mariadb/protocol/deserialize.rs +++ b/src/mariadb/protocol/deserialize.rs @@ -10,7 +10,7 @@ use failure::Error; pub struct DeContext<'a> { pub conn: &'a mut ConnContext, pub decoder: Decoder<'a>, - pub columns: Option, + pub columns: Option, } impl<'a> DeContext<'a> { diff --git a/src/mariadb/protocol/error_codes.rs b/src/mariadb/protocol/error_codes.rs index a135130e..819ffd78 100644 --- a/src/mariadb/protocol/error_codes.rs +++ b/src/mariadb/protocol/error_codes.rs @@ -1,7 +1,7 @@ use std::convert::TryFrom; #[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive)] -#[TryFromPrimitiveType = "u16"] +#[TryFromPrimitiveType = "i16"] pub enum ErrorCode { ErDefault = 0, ErHashchk = 1000, diff --git a/src/mariadb/protocol/packets/column.rs b/src/mariadb/protocol/packets/column.rs index 3efa698b..85f32c67 100644 --- a/src/mariadb/protocol/packets/column.rs +++ b/src/mariadb/protocol/packets/column.rs @@ -9,7 +9,7 @@ use crate::mariadb::{DeContext, Deserialize}; pub struct ColumnPacket { pub length: u32, pub seq_no: u8, - pub columns: Option, + pub columns: Option, } impl Deserialize for ColumnPacket { diff --git a/src/mariadb/protocol/packets/column_def.rs b/src/mariadb/protocol/packets/column_def.rs index a8524a5e..5283be6a 100644 --- a/src/mariadb/protocol/packets/column_def.rs +++ b/src/mariadb/protocol/packets/column_def.rs @@ -13,9 +13,9 @@ pub struct ColumnDefPacket { pub table: Bytes, pub column_alias: Bytes, pub column: Bytes, - pub length_of_fixed_fields: Option, - pub char_set: u16, - pub max_columns: u32, + pub length_of_fixed_fields: Option, + pub char_set: i16, + pub max_columns: i32, pub field_type: FieldType, pub field_details: FieldDetailFlag, pub decimals: u8, @@ -48,7 +48,7 @@ impl Deserialize for ColumnDefPacket { // int<1> Field types let field_type = FieldType::try_from(decoder.decode_int_1())?; // int<2> Field detail flag - let field_details = FieldDetailFlag::from_bits_truncate(decoder.decode_int_2()); + let field_details = FieldDetailFlag::from_bits_truncate(decoder.decode_int_2_unsigned()); // int<1> decimals let decimals = decoder.decode_int_1(); // int<2> - unused - diff --git a/src/mariadb/protocol/packets/eof.rs b/src/mariadb/protocol/packets/eof.rs index 94c6840f..2a7330e6 100644 --- a/src/mariadb/protocol/packets/eof.rs +++ b/src/mariadb/protocol/packets/eof.rs @@ -9,7 +9,7 @@ use std::convert::TryFrom; pub struct EofPacket { pub length: u32, pub seq_no: u8, - pub warning_count: u16, + pub warning_count: i16, pub status: ServerStatusFlag, } @@ -27,7 +27,7 @@ impl Deserialize for EofPacket { } let warning_count = decoder.decode_int_2(); - let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_2()); + let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_2_unsigned()); Ok(EofPacket { length, seq_no, warning_count, status }) } diff --git a/src/mariadb/protocol/packets/err.rs b/src/mariadb/protocol/packets/err.rs index 424d1d67..24eab35c 100644 --- a/src/mariadb/protocol/packets/err.rs +++ b/src/mariadb/protocol/packets/err.rs @@ -13,7 +13,7 @@ pub struct ErrPacket { pub error_code: ErrorCode, pub stage: Option, pub max_stage: Option, - pub progress: Option, + pub progress: Option, pub progress_info: Option, pub sql_state_marker: Option, pub sql_state: Option, diff --git a/src/mariadb/protocol/packets/initial.rs b/src/mariadb/protocol/packets/initial.rs index 953140ec..4d0b5bd6 100644 --- a/src/mariadb/protocol/packets/initial.rs +++ b/src/mariadb/protocol/packets/initial.rs @@ -9,7 +9,7 @@ pub struct InitialHandshakePacket { pub seq_no: u8, pub protocol_version: u8, pub server_version: Bytes, - pub connection_id: u32, + pub connection_id: i32, pub auth_seed: Bytes, pub capabilities: Capabilities, pub collation: u8, @@ -37,10 +37,10 @@ impl Deserialize for InitialHandshakePacket { // Skip reserved byte decoder.skip_bytes(1); - let mut capabilities = Capabilities::from_bits_truncate(decoder.decode_int_2().into()); + let mut capabilities = Capabilities::from_bits_truncate(decoder.decode_int_2_unsigned().into()); let collation = decoder.decode_int_1(); - let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_2().into()); + let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_2_unsigned().into()); capabilities |= Capabilities::from_bits_truncate(((decoder.decode_int_2() as u32) << 16).into()); @@ -58,7 +58,7 @@ impl Deserialize for InitialHandshakePacket { if (capabilities & Capabilities::CLIENT_MYSQL).is_empty() { capabilities |= - Capabilities::from_bits_truncate(((decoder.decode_int_4() as u128) << 32).into()); + Capabilities::from_bits_truncate(((decoder.decode_int_4_unsigned() as u128) << 32).into()); } else { // Skip filler decoder.skip_bytes(4); diff --git a/src/mariadb/protocol/packets/ok.rs b/src/mariadb/protocol/packets/ok.rs index 15cfcea2..26341397 100644 --- a/src/mariadb/protocol/packets/ok.rs +++ b/src/mariadb/protocol/packets/ok.rs @@ -8,10 +8,10 @@ use crate::mariadb::{DeContext, Deserialize, ServerStatusFlag, pub struct OkPacket { pub length: u32, pub seq_no: u8, - pub affected_rows: Option, - pub last_insert_id: Option, + pub affected_rows: Option, + pub last_insert_id: Option, pub server_status: ServerStatusFlag, - pub warning_count: u16, + pub warning_count: i16, pub info: Bytes, pub session_state_info: Option, pub value: Option, @@ -36,7 +36,7 @@ impl Deserialize for OkPacket { let affected_rows = decoder.decode_int_lenenc(); let last_insert_id = decoder.decode_int_lenenc(); - let server_status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_2().into()); + let server_status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_2_unsigned().into()); let warning_count = decoder.decode_int_2(); // Assuming CLIENT_SESSION_TRACK is unsupported diff --git a/src/mariadb/protocol/types.rs b/src/mariadb/protocol/types.rs index c5d51f31..6cb4fd8d 100644 --- a/src/mariadb/protocol/types.rs +++ b/src/mariadb/protocol/types.rs @@ -144,6 +144,6 @@ mod test { fn it_decodes_capabilities() { let buf = Bytes::from(b"\xfe\xf7".to_vec()); let mut decoder = Decoder::new(&buf); - Capabilities::from_bits_truncate(decoder.decode_int_2().into()); + Capabilities::from_bits_truncate(decoder.decode_int_2_unsigned().into()); } }