diff --git a/src/mariadb/protocol/decode.rs b/src/mariadb/protocol/decode.rs index 7ea7c1d9..f82df65c 100644 --- a/src/mariadb/protocol/decode.rs +++ b/src/mariadb/protocol/decode.rs @@ -7,6 +7,57 @@ use super::packets::packet_header::PacketHeader; // This is a simple wrapper around Bytes to make decoding easier // since the index is always tracked +// The decoder is used to decode mysql protocol data-types +// into the appropriate Rust type or bytes::Bytes otherwise +// There are two types of protocols: Text and Binary. +// Text protocol is used for most things, and binary is used +// only for the results of prepared statements. +// MySql Text protocol data-types: +// - byte : Fixed-length bytes +// - byte : Length-encoded bytes +// - byte : End-of-file length bytes +// - int : Fixed-length integers +// - int : Length-encoded integers +// - string : Fixed-length strings +// - string : Null-terminated strings +// - string : Length-encoded strings +// - string : End-of-file length strings +// The decoder will decode all of the Text Protocol types, and if the data-type +// is of type int<*> then the decoder will convert that into the +// appropriate Rust type. +// The second protocol (Binary) protocol data-types (these rely on knowing the type from the column definition packet): +// - DECIMAL : DECIMAL has no fixed size, so will be encoded as string. +// - DOUBLE : DOUBLE is the IEEE 754 floating-point value in Little-endian format on 8 bytes. +// - BIGINT : BIGINT is the value in Little-endian format on 8 bytes. Signed is defined by the Column field detail flag. +// - INTEGER: INTEGER is the value in Little-endian format on 4 bytes. Signed is defined by the Column field detail flag. +// - MEDIUMINT : MEDIUMINT is similar to INTEGER binary encoding, even if MEDIUM int is 3-bytes encoded server side. (Last byte will always be 0x00). +// - FLOAT : FLOAT is the IEEE 754 floating-point value in Little-endian format on 4 bytes. +// - SMALLINT : SMALLINT is the value in Little-endian format on 2 bytes. Signed is defined by the Column field detail flag. +// - YEAR : YEAR uses the same format as SMALLINT. +// - TINYINT : TINYINT is the value of 1 byte. Signed is defined by the Column field detail flag. +// - DATE : Data is encoded in 5 bytes. +// - First byte is the date length which must be 4 +// - Bytes 2-3 are the year on 2 bytes little-endian format +// - Byte 4 is the month (1=january - 12=december) +// - Byte 5 is the day of the month (0 - 31) +// - TIMESTAMP: Data is encoded in 8 bytes without fractional seconds, 12 bytes with fractional seconds. +// - Byte 1 is data length; 7 without fractional seconds, 11 with fractional seconds +// - Bytes 2-3 are the year on 2 bytes little-endian format +// - Byte 4 is the month (1=january - 12=december) +// - Byte 5 is the day of the month (0 - 31) +// - Byte 6 is the hour of day (0 if DATE type) (0-23) +// - Byte 7 is the minutes (0 if DATE type) (0-59) +// - Byte 8 is the seconds (0 if DATE type) (0-59) +// - Bytes 9-12 is the micro-second on 4 bytes little-endian format (only if data-length is > 7) (0-9999) +// - DATETIME : DATETIME uses the same format as TIMESTAMP binary encoding +// - TIME : Data is encoded in 9 bytes without fractional seconds, 13 bytes with fractional seconds. +// - Byte 1 is the data length; 8 without fractional seconds, 12 with fractional seconds +// - Byte 2 determines negativity +// - Bytes 3-6 are the date on 4 bytes little-endian format +// - Byte 6 is the hour of day (0 if DATE type) (0-23) +// - Byte 7 is the minutes (0 if DATE type) (0-59) +// - Byte 8 is the seconds (0 if DATE type) (0-59) +// - Bytes 10-13 are the micro-seconds on 4 bytes little-endian format (only if data-length is > 7) pub struct Decoder<'a> { pub buf: &'a Bytes, pub index: usize, @@ -76,24 +127,56 @@ impl<'a> Decoder<'a> { // 0xFF then there was an error. // If the first byte is not in the previous list then that byte is the int value. #[inline] - pub fn decode_int_lenenc(&mut self) -> Option { + pub fn decode_int_lenenc_signed(&mut self) -> Option { match self.buf[self.index] { 0xFB => { self.index += 1; None } 0xFC => { - let value = Some(LittleEndian::read_i16(&self.buf[self.index + 1..]) as u64); + let value = Some(LittleEndian::read_i16(&self.buf[self.index + 1..]) as i64); self.index += 3; value } 0xFD => { - let value = Some(LittleEndian::read_i24(&self.buf[self.index + 1..]) as u64); + let value = Some(LittleEndian::read_i24(&self.buf[self.index + 1..]) as i64); self.index += 4; value } 0xFE => { - let value = Some(LittleEndian::read_i64(&self.buf[self.index + 1..]) as u64); + let value = Some(LittleEndian::read_i64(&self.buf[self.index + 1..]) as i64); + self.index += 9; + value + } + 0xFF => panic!("int unprocessable first byte 0xFF"), + _ => { + let value = Some(self.buf[self.index] as i64); + self.index += 1; + value + } + } + } + + // This is functionally identical to the previous method, but this one returns an u64 instead + #[inline] + pub fn decode_int_lenenc_unsigned(&mut self) -> Option { + match self.buf[self.index] { + 0xFB => { + self.index += 1; + None + } + 0xFC => { + let value = Some(LittleEndian::read_u16(&self.buf[self.index + 1..]) as u64); + self.index += 3; + value + } + 0xFD => { + let value = Some(LittleEndian::read_u24(&self.buf[self.index + 1..]) as u64); + self.index += 4; + value + } + 0xFE => { + let value = Some(LittleEndian::read_u64(&self.buf[self.index + 1..]) as u64); self.index += 9; value } @@ -168,7 +251,7 @@ impl<'a> Decoder<'a> { // the length of the string, and the the following n bytes are the contents. #[inline] pub fn decode_string_lenenc(&mut self) -> Bytes { - let length = self.decode_int_lenenc().unwrap_or(0); + let length = self.decode_int_lenenc_unsigned().unwrap_or(0); let value = self.buf.slice(self.index, self.index + length as usize); self.index = self.index + length as usize; value @@ -327,12 +410,8 @@ impl<'a> Decoder<'a> { #[inline] pub fn decode_binary_time(&mut self) -> Bytes { - let length = self.decode_int_u8(); - if length != 8 && length != 12 { - panic!("Date length is not 8 or 12 (the only two possible values)"); - } - let value = self.buf.slice(self.index, self.index + length as usize); - self.index += length as usize; + let value = self.buf.slice(self.index, self.index + 13); + self.index += 13; value } @@ -363,7 +442,7 @@ mod tests { fn it_decodes_int_lenenc_0x_fb() { let buf = __bytes_builder!(0xFB_u8); let mut decoder = Decoder::new(&buf); - let int: Option = decoder.decode_int_lenenc(); + let int = decoder.decode_int_lenenc_unsigned(); assert_eq!(int, None); assert_eq!(decoder.index, 1); @@ -373,7 +452,7 @@ mod tests { fn it_decodes_int_lenenc_0x_fc() { let buf =__bytes_builder!(0xFCu8, 1u8, 1u8); let mut decoder = Decoder::new(&buf); - let int: Option = decoder.decode_int_lenenc(); + let int = decoder.decode_int_lenenc_unsigned(); assert_eq!(int, Some(0x0101)); assert_eq!(decoder.index, 3); @@ -383,7 +462,7 @@ mod tests { fn it_decodes_int_lenenc_0x_fd() { let buf = __bytes_builder!(0xFDu8, 1u8, 1u8, 1u8); let mut decoder = Decoder::new(&buf); - let int: Option = decoder.decode_int_lenenc(); + let int = decoder.decode_int_lenenc_unsigned(); assert_eq!(int, Some(0x010101)); assert_eq!(decoder.index, 4); @@ -393,7 +472,7 @@ mod tests { fn it_decodes_int_lenenc_0x_fe() { let buf = __bytes_builder!(0xFE_u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8); let mut decoder = Decoder::new(&buf); - let int: Option = decoder.decode_int_lenenc(); + let int = decoder.decode_int_lenenc_unsigned(); assert_eq!(int, Some(0x0101010101010101)); assert_eq!(decoder.index, 9); @@ -403,7 +482,7 @@ mod tests { fn it_decodes_int_lenenc_0x_fa() { let buf = __bytes_builder!(0xFA_u8); let mut decoder = Decoder::new(&buf); - let int: Option = decoder.decode_int_lenenc(); + let int = decoder.decode_int_lenenc_unsigned(); assert_eq!(int, Some(0xFA)); assert_eq!(decoder.index, 1); diff --git a/src/mariadb/protocol/packets/binary/com_stmt_exec.rs b/src/mariadb/protocol/packets/binary/com_stmt_exec.rs index 99f33573..98b721ad 100644 --- a/src/mariadb/protocol/packets/binary/com_stmt_exec.rs +++ b/src/mariadb/protocol/packets/binary/com_stmt_exec.rs @@ -96,7 +96,7 @@ mod tests { table: Bytes::from_static(b"users"), column_alias: Bytes::from_static(b"username"), column: Bytes::from_static(b"username"), - length_of_fixed_fields: Some(0x0Ci64), + length_of_fixed_fields: Some(0x0Cu64), char_set: 1, max_columns: 1, field_type: FieldType::MysqlTypeString, diff --git a/src/mariadb/protocol/packets/binary/result_row.rs b/src/mariadb/protocol/packets/binary/result_row.rs index 5cad8b85..759ce720 100644 --- a/src/mariadb/protocol/packets/binary/result_row.rs +++ b/src/mariadb/protocol/packets/binary/result_row.rs @@ -1,4 +1,5 @@ use bytes::Bytes; +use crate::mariadb::{FieldType}; #[derive(Debug, Default)] pub struct ResultRow { @@ -16,24 +17,72 @@ impl crate::mariadb::Deserialize for ResultRow { let bitmap = if let Some(columns) = ctx.columns { let size = (columns + 9) / 8; - decoder.decode_byte_fix(size as usize) + Ok(decoder.decode_byte_fix(size as usize)) } else { - Bytes::new() - }; + Err(failure::err_msg("Columns were not provided; cannot deserialize binary result row")) + }?; - let row = if let Some(columns) = ctx.columns { - (0..columns).map(|index| { - if (1 << index) & (bitmap[index/8] << (index % 8)) == 0 { - None - } else { - match ctx.column_defs[index] { + let row = match (&ctx.columns, &ctx.column_defs) { + (Some(columns), Some(column_defs)) => { + (0..*columns as usize).map(|index| { + if (1 << (index % 8)) & bitmap[index / 8] as usize == 0 { + None + } else { + match column_defs[index].field_type { + // Ordered by https://mariadb.com/kb/en/library/resultset-row/#binary-resultset-row + FieldType::MysqlTypeDouble => Some(decoder.decode_binary_double()), + FieldType::MysqlTypeLonglong => Some(decoder.decode_binary_bigint()), + // Is this MYSQL_TYPE_INTEGER? + FieldType::MysqlTypeLong => Some(decoder.decode_binary_int()), + + // Is this MYSQL_TYPE_MEDIUMINTEGER? + FieldType::MysqlTypeInt24 => Some(decoder.decode_binary_mediumint()), + + FieldType::MysqlTypeFloat => Some(decoder.decode_binary_float()), + + // Is this MYSQL_TYPE_SMALLINT? + FieldType::MysqlTypeShort => Some(decoder.decode_binary_smallint()), + + FieldType::MysqlTypeYear => Some(decoder.decode_binary_year()), + FieldType::MysqlTypeTiny => Some(decoder.decode_binary_tinyint()), + FieldType::MysqlTypeDate => Some(decoder.decode_binary_date()), + FieldType::MysqlTypeTimestamp => Some(decoder.decode_binary_timestamp()), + FieldType::MysqlTypeDatetime => Some(decoder.decode_binary_datetime()), + FieldType::MysqlTypeTime => Some(decoder.decode_binary_time()), + FieldType::MysqlTypeNewdecimal => Some(decoder.decode_binary_decimal()), + + // This group of types are all encoded as byte + FieldType::MysqlTypeTinyBlob => Some(decoder.decode_byte_lenenc()), + FieldType::MysqlTypeMediumBlob => Some(decoder.decode_byte_lenenc()), + FieldType::MysqlTypeLongBlob => Some(decoder.decode_byte_lenenc()), + FieldType::MysqlTypeBlob => Some(decoder.decode_byte_lenenc()), + FieldType::MysqlTypeVarchar => Some(decoder.decode_byte_lenenc()), + FieldType::MysqlTypeVarString => Some(decoder.decode_byte_lenenc()), + FieldType::MysqlTypeString => Some(decoder.decode_byte_lenenc()), + FieldType::MysqlTypeGeometry => Some(decoder.decode_byte_lenenc()), + + // The following did not have defined binary encoding, so I guessed. + // Perhaps you cannot get these types back from the server if you're using + // prepared statements? In that case we should error out here instead of + // proceeding to decode. + FieldType::MysqlTypeDecimal => Some(decoder.decode_binary_decimal()), + FieldType::MysqlTypeNull => panic!("Cannot decode MysqlTypeNull"), + FieldType::MysqlTypeNewdate => Some(decoder.decode_binary_date()), + FieldType::MysqlTypeBit => Some(decoder.decode_byte_fix(1)), + FieldType::MysqlTypeTimestamp2 => Some(decoder.decode_binary_timestamp()), + FieldType::MysqlTypeDatetime2 => Some(decoder.decode_binary_datetime()), + FieldType::MysqlTypeTime2 => Some(decoder.decode_binary_time()), + FieldType::MysqlTypeJson => Some(decoder.decode_byte_lenenc()), + FieldType::MysqlTypeEnum => Some(decoder.decode_byte_lenenc()), + FieldType::MysqlTypeSet => Some(decoder.decode_byte_lenenc()), + + + } } - decoder.decode_binary_column(&ctx.column_defs) - } - }).collect::>() - } else { - Vec::new() + }).collect::>>() + }, + _ => Vec::new(), }; Ok(ResultRow::default()) diff --git a/src/mariadb/protocol/packets/column.rs b/src/mariadb/protocol/packets/column.rs index 763b9617..fedd1e98 100644 --- a/src/mariadb/protocol/packets/column.rs +++ b/src/mariadb/protocol/packets/column.rs @@ -19,7 +19,7 @@ impl Deserialize for ColumnPacket { let length = decoder.decode_length()?; let seq_no = decoder.decode_int_u8(); - let columns = decoder.decode_int_lenenc(); + let columns = decoder.decode_int_lenenc_unsigned(); Ok(ColumnPacket { length, seq_no, columns }) } diff --git a/src/mariadb/protocol/packets/column_def.rs b/src/mariadb/protocol/packets/column_def.rs index 695f287f..095bb300 100644 --- a/src/mariadb/protocol/packets/column_def.rs +++ b/src/mariadb/protocol/packets/column_def.rs @@ -40,7 +40,7 @@ impl Deserialize for ColumnDefPacket { // string column let column = decoder.decode_string_lenenc(); // int length of fixed fields (=0xC) - let length_of_fixed_fields = decoder.decode_int_lenenc(); + let length_of_fixed_fields = decoder.decode_int_lenenc_unsigned(); // int<2> character set number let char_set = decoder.decode_int_i16(); // int<4> max. column size diff --git a/src/mariadb/protocol/packets/ok.rs b/src/mariadb/protocol/packets/ok.rs index f804fa0e..b9b35e2a 100644 --- a/src/mariadb/protocol/packets/ok.rs +++ b/src/mariadb/protocol/packets/ok.rs @@ -9,7 +9,7 @@ pub struct OkPacket { pub length: u32, pub seq_no: u8, pub affected_rows: Option, - pub last_insert_id: Option, + pub last_insert_id: Option, pub server_status: ServerStatusFlag, pub warning_count: i16, pub info: Bytes, @@ -34,8 +34,8 @@ impl Deserialize for OkPacket { return Err(err_msg("Packet header is not 0 or 0xFE for OkPacket")); } - let affected_rows = decoder.decode_int_lenenc(); - let last_insert_id = decoder.decode_int_lenenc(); + let affected_rows = decoder.decode_int_lenenc_unsigned(); + let last_insert_id = decoder.decode_int_lenenc_signed(); let server_status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_u16().into()); let warning_count = decoder.decode_int_i16();