From 6949b35eeccba3dc60ebc3588f17396b6087061c Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Thu, 5 Sep 2019 15:20:41 -0700 Subject: [PATCH] More work towards mariadb protocol refactor --- src/mariadb/io/buf_ext.rs | 8 ++ src/mariadb/protocol/field.rs | 21 +++ src/mariadb/protocol/mod.rs | 10 +- src/mariadb/protocol/response/column_count.rs | 47 +++++++ src/mariadb/protocol/response/column_def.rs | 127 ++++++++++++++++++ src/mariadb/protocol/response/err.rs | 111 +++++++++++++++ src/mariadb/protocol/response/mod.rs | 12 +- src/mariadb/protocol/response/row.rs | 68 ++++++++++ 8 files changed, 395 insertions(+), 9 deletions(-) create mode 100644 src/mariadb/protocol/response/column_count.rs create mode 100644 src/mariadb/protocol/response/column_def.rs diff --git a/src/mariadb/io/buf_ext.rs b/src/mariadb/io/buf_ext.rs index f3364826..c0b33d86 100644 --- a/src/mariadb/io/buf_ext.rs +++ b/src/mariadb/io/buf_ext.rs @@ -3,6 +3,7 @@ use byteorder::ByteOrder; use std::io; pub trait BufExt { + fn get_uint(&mut self, n: usize) -> io::Result; fn get_uint_lenenc(&mut self) -> io::Result>; fn get_str_eof(&mut self) -> io::Result<&str>; fn get_str_lenenc(&mut self) -> io::Result>; @@ -10,6 +11,13 @@ pub trait BufExt { } impl<'a> BufExt for &'a [u8] { + fn get_uint(&mut self, n: usize) -> io::Result { + let val = T::read_uint(*self, n); + self.advance(n); + + Ok(val) + } + fn get_uint_lenenc(&mut self) -> io::Result> { Ok(match self.get_u8()? { 0xFB => None, diff --git a/src/mariadb/protocol/field.rs b/src/mariadb/protocol/field.rs index 680ca60f..f5311e54 100644 --- a/src/mariadb/protocol/field.rs +++ b/src/mariadb/protocol/field.rs @@ -42,3 +42,24 @@ bitflags::bitflags! { const UNSIGNED = 128; } } + +// https://mariadb.com/kb/en/library/resultset/#field-detail-flag +bitflags::bitflags! { + pub struct FieldDetailFlag: u16 { + const NOT_NULL = 1; + const PRIMARY_KEY = 2; + const UNIQUE_KEY = 4; + const MULTIPLE_KEY = 8; + const BLOB = 16; + const UNSIGNED = 32; + const ZEROFILL_FLAG = 64; + const BINARY_COLLATION = 128; + const ENUM = 256; + const AUTO_INCREMENT = 512; + const TIMESTAMP = 1024; + const SET = 2048; + const NO_DEFAULT_VALUE_FLAG = 4096; + const ON_UPDATE_NOW_FLAG = 8192; + const NUM_FLAG = 32768; + } +} diff --git a/src/mariadb/protocol/mod.rs b/src/mariadb/protocol/mod.rs index 5703ad1a..24126a9e 100644 --- a/src/mariadb/protocol/mod.rs +++ b/src/mariadb/protocol/mod.rs @@ -6,17 +6,17 @@ mod connect; mod encode; mod error_code; mod field; -mod server_status; mod response; +mod server_status; pub use capabilities::Capabilities; pub use connect::{ AuthenticationSwitchRequest, HandshakeResponsePacket, InitialHandshakePacket, SslRequest, }; -pub use response::{ - OkPacket, EofPacket, ErrPacket, ResultRow, -}; pub use encode::Encode; pub use error_code::ErrorCode; -pub use field::{FieldType, ParameterFlag}; +pub use field::{FieldDetailFlag, FieldType, ParameterFlag}; +pub use response::{ + ColumnCountPacket, ColumnDefinitionPacket, EofPacket, ErrPacket, OkPacket, ResultRow, +}; pub use server_status::ServerStatusFlag; diff --git a/src/mariadb/protocol/response/column_count.rs b/src/mariadb/protocol/response/column_count.rs new file mode 100644 index 00000000..54535783 --- /dev/null +++ b/src/mariadb/protocol/response/column_count.rs @@ -0,0 +1,47 @@ +use crate::mariadb::io::BufExt; +use byteorder::LittleEndian; +use std::io; + +// The column packet is the first packet of a result set. +// Inside of it it contains the number of columns in the result set +// encoded as an int. +// https://mariadb.com/kb/en/library/resultset/#column-count-packet +#[derive(Debug)] +pub struct ColumnCountPacket { + pub columns: u64, +} + +impl ColumnCountPacket { + fn decode(mut buf: &[u8]) -> io::Result { + let columns = buf.get_uint_lenenc::()?.unwrap_or(0); + + Ok(Self { columns }) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::__bytes_builder; + + #[test] + fn it_decodes_column_packet_0x_fb() -> io::Result<()> { + #[rustfmt::skip] + let buf = __bytes_builder!( + // int<3> length + 0u8, 0u8, 0u8, + // int<1> seq_no + 0u8, + // int tag code: Some(3 bytes) + 0xFD_u8, + // value: 3 bytes + 0x01_u8, 0x01_u8, 0x01_u8 + ); + + let message = ColumnCountPacket::decode(&buf)?; + + assert_eq!(message.columns, Some(0x010101)); + + Ok(()) + } +} diff --git a/src/mariadb/protocol/response/column_def.rs b/src/mariadb/protocol/response/column_def.rs new file mode 100644 index 00000000..0b415c9d --- /dev/null +++ b/src/mariadb/protocol/response/column_def.rs @@ -0,0 +1,127 @@ +use crate::{ + io::Buf, + mariadb::{ + io::BufExt, + protocol::{FieldDetailFlag, FieldType}, + }, +}; +use byteorder::LittleEndian; +use std::io; + +#[derive(Debug)] +// ColumnDefinitionPacket doesn't have a packet header because +// it's nested inside a result set packet +pub struct ColumnDefinitionPacket { + pub schema: Option, + pub table_alias: Option, + pub table: Option, + pub column_alias: Option, + pub column: Option, + pub char_set: u16, + pub max_columns: i32, + pub field_type: FieldType, + pub field_details: FieldDetailFlag, + pub decimals: u8, +} + +impl ColumnDefinitionPacket { + fn decode(mut buf: &[u8]) -> io::Result { + // string catalog (always 'def') + let _catalog = buf.get_str_lenenc::()?; + // TODO: Assert that this is always DEF + + // string schema + let schema = buf.get_str_lenenc::()?.map(ToOwned::to_owned); + // string table alias + let table_alias = buf.get_str_lenenc::()?.map(ToOwned::to_owned); + // string table + let table = buf.get_str_lenenc::()?.map(ToOwned::to_owned); + // string column alias + let column_alias = buf.get_str_lenenc::()?.map(ToOwned::to_owned); + // string column + let column = buf.get_str_lenenc::()?.map(ToOwned::to_owned); + + // int length of fixed fields (=0xC) + let _length_of_fixed_fields = buf.get_uint_lenenc::()?; + // TODO: Assert that this is always 0xC + + // int<2> character set number + let char_set = buf.get_u16::()?; + // int<4> max. column size + let max_columns = buf.get_i32::()?; + // int<1> Field types + let field_type = FieldType(buf.get_u8()?); + // int<2> Field detail flag + let field_details = FieldDetailFlag::from_bits_truncate(buf.get_u16::()?); + // int<1> decimals + let decimals = buf.get_u8()?; + // int<2> - unused - + buf.advance(2); + + Ok(Self { + schema, + table_alias, + table, + column_alias, + column, + char_set, + max_columns, + field_type, + field_details, + decimals, + }) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::__bytes_builder; + + #[test] + fn it_decodes_column_def_packet() -> io::Result<()> { + #[rustfmt::skip] + let buf = __bytes_builder!( + // length + 1u8, 0u8, 0u8, + // seq_no + 0u8, + // string catalog (always 'def') + 1u8, b'a', + // string schema + 1u8, b'b', + // string table alias + 1u8, b'c', + // string table + 1u8, b'd', + // string column alias + 1u8, b'e', + // string column + 1u8, b'f', + // int length of fixed fields (=0xC) + 0xFC_u8, 1u8, 1u8, + // int<2> character set number + 1u8, 1u8, + // int<4> max. column size + 1u8, 1u8, 1u8, 1u8, + // int<1> Field types + 1u8, + // int<2> Field detail flag + 1u8, 0u8, + // int<1> decimals + 1u8, + // int<2> - unused - + 0u8, 0u8 + ); + + let message = ColumnDefinitionPacket::decode(&buf)?; + + assert_eq!(message.schema, Some(b"b")); + assert_eq!(message.table_alias, Some(b"c")); + assert_eq!(message.table, Some(b"d")); + assert_eq!(message.column_alias, Some(b"e")); + assert_eq!(message.column, Some(b"f")); + + Ok(()) + } +} diff --git a/src/mariadb/protocol/response/err.rs b/src/mariadb/protocol/response/err.rs index e69de29b..607a5b2b 100644 --- a/src/mariadb/protocol/response/err.rs +++ b/src/mariadb/protocol/response/err.rs @@ -0,0 +1,111 @@ +use crate::{ + io::Buf, + mariadb::{io::BufExt, protocol::ErrorCode}, +}; +use byteorder::LittleEndian; +use std::io; + +#[derive(Debug)] +pub enum ErrPacket { + Progress { + stage: u8, + max_stage: u8, + progress: u32, + info: Box, + }, + + Error { + code: ErrorCode, + sql_state: Option>, + message: Box, + }, +} + +impl ErrPacket { + fn decode(mut buf: &[u8]) -> io::Result { + let header = buf.get_u8()?; + debug_assert_eq!(header, 0xFF); + + // error code : int<2> + let code = buf.get_u16::()?; + + // if (errorcode == 0xFFFF) /* progress reporting */ + if code == 0xFF_FF { + let stage = buf.get_u8()?; + let max_stage = buf.get_u8()?; + let progress = buf.get_u24::()?; + let info = buf + .get_str_lenenc::()? + .unwrap_or_default() + .into(); + + Ok(Self::Progress { + stage, + max_stage, + progress, + info, + }) + } else { + // if (next byte = '#') + let sql_state = if buf[0] == b'#' { + // '#' : string<1> + buf.advance(1); + + // sql state : string<5> + Some(buf.get_str(5)?.into()) + } else { + None + }; + + let message = buf.get_str_eof()?.into(); + + Ok(Self::Error { + code: ErrorCode(code), + sql_state, + message, + }) + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::__bytes_builder; + + #[test] + fn it_decodes_err_packet() -> Result<(), Error> { + #[rustfmt::skip] + let buf = __bytes_builder!( + // int<3> length + 1u8, 0u8, 0u8, + // int<1> seq_no + 1u8, + // int<1> 0xfe : EOF header + 0xFF_u8, + // int<2> error code + 0x84_u8, 0x04_u8, + // if (errorcode == 0xFFFF) /* progress reporting */ { + // int<1> stage + // int<1> max_stage + // int<3> progress + // string progress_info + // } else { + // if (next byte = '#') { + // string<1> sql state marker '#' + b"#", + // string<5>sql state + b"08S01", + // string error message + b"Got packets out of order" + // } else { + // string error message + // } + // } + ); + + let _message = ErrPacket::decode(&buf)?; + + Ok(()) + } +} diff --git a/src/mariadb/protocol/response/mod.rs b/src/mariadb/protocol/response/mod.rs index ade72c2a..fd2be74c 100644 --- a/src/mariadb/protocol/response/mod.rs +++ b/src/mariadb/protocol/response/mod.rs @@ -1,9 +1,13 @@ -mod ok; -mod err; +mod column_count; +mod column_def; mod eof; +mod err; +mod ok; mod row; -pub use ok::OkPacket; -pub use err::ErrPacket; +pub use column_count::ColumnCountPacket; +pub use column_def::ColumnDefinitionPacket; pub use eof::EofPacket; +pub use err::ErrPacket; +pub use ok::OkPacket; pub use row::ResultRow; diff --git a/src/mariadb/protocol/response/row.rs b/src/mariadb/protocol/response/row.rs index e69de29b..a8a7c5df 100644 --- a/src/mariadb/protocol/response/row.rs +++ b/src/mariadb/protocol/response/row.rs @@ -0,0 +1,68 @@ +use crate::{ + io::Buf, + mariadb::{ + io::BufExt, + protocol::{ColumnDefinitionPacket, FieldType}, + }, +}; +use byteorder::LittleEndian; +use std::{io, pin::Pin, ptr::NonNull}; + +/// A resultset row represents a database resultset unit, which is usually generated by +/// executing a statement that queries the database. +#[derive(Debug)] +pub struct ResultRow { + #[used] + buffer: Pin>, + pub values: Box<[Option>]>, +} + +// SAFE: Raw pointers point to pinned memory inside the struct +unsafe impl Send for ResultRow {} +unsafe impl Sync for ResultRow {} + +impl ResultRow { + pub fn decode(mut buf: &[u8], columns: &[ColumnDefinitionPacket]) -> io::Result { + // 0x00 header : byte<1> + let header = buf.get_u8()?; + // TODO: Replace with InvalidData err + debug_assert_eq!(header, 0); + + // NULL-Bitmap : byte<(number_of_columns + 9) / 8> + let null = buf.get_uint::((columns.len() + 9) / 8)?; + + let buffer: Pin> = Pin::new(buf.into()); + let mut buf = &*buffer; + + let mut values = Vec::with_capacity(columns.len()); + + for column_idx in 0..columns.len() { + if (null & (1 << column_idx)) != 0 { + values.push(None); + } else { + match columns[column_idx].field_type { + FieldType::MYSQL_TYPE_LONG => { + values.push(Some(buf[..(4 as usize)].into())); + buf.advance(4); + } + + FieldType::MYSQL_TYPE_VAR_STRING => { + let len = buf.get_uint_lenenc::()?.unwrap_or_default(); + + values.push(Some(buf[..(len as usize)].into())); + buf.advance(len as usize); + } + + type_ => { + unimplemented!("encountered unknown field type: {:?}", type_); + } + } + } + } + + Ok(Self { + buffer, + values: values.into_boxed_slice(), + }) + } +}