More work towards mariadb protocol refactor

This commit is contained in:
Ryan Leckey 2019-09-05 15:20:41 -07:00
parent 94fe35c264
commit 6949b35eec
8 changed files with 395 additions and 9 deletions

View File

@ -3,6 +3,7 @@ use byteorder::ByteOrder;
use std::io;
pub trait BufExt {
fn get_uint<T: ByteOrder>(&mut self, n: usize) -> io::Result<u64>;
fn get_uint_lenenc<T: ByteOrder>(&mut self) -> io::Result<Option<u64>>;
fn get_str_eof(&mut self) -> io::Result<&str>;
fn get_str_lenenc<T: ByteOrder>(&mut self) -> io::Result<Option<&str>>;
@ -10,6 +11,13 @@ pub trait BufExt {
}
impl<'a> BufExt for &'a [u8] {
fn get_uint<T: ByteOrder>(&mut self, n: usize) -> io::Result<u64> {
let val = T::read_uint(*self, n);
self.advance(n);
Ok(val)
}
fn get_uint_lenenc<T: ByteOrder>(&mut self) -> io::Result<Option<u64>> {
Ok(match self.get_u8()? {
0xFB => None,

View File

@ -42,3 +42,24 @@ bitflags::bitflags! {
const UNSIGNED = 128;
}
}
// https://mariadb.com/kb/en/library/resultset/#field-detail-flag
bitflags::bitflags! {
pub struct FieldDetailFlag: u16 {
const NOT_NULL = 1;
const PRIMARY_KEY = 2;
const UNIQUE_KEY = 4;
const MULTIPLE_KEY = 8;
const BLOB = 16;
const UNSIGNED = 32;
const ZEROFILL_FLAG = 64;
const BINARY_COLLATION = 128;
const ENUM = 256;
const AUTO_INCREMENT = 512;
const TIMESTAMP = 1024;
const SET = 2048;
const NO_DEFAULT_VALUE_FLAG = 4096;
const ON_UPDATE_NOW_FLAG = 8192;
const NUM_FLAG = 32768;
}
}

View File

@ -6,17 +6,17 @@ mod connect;
mod encode;
mod error_code;
mod field;
mod server_status;
mod response;
mod server_status;
pub use capabilities::Capabilities;
pub use connect::{
AuthenticationSwitchRequest, HandshakeResponsePacket, InitialHandshakePacket, SslRequest,
};
pub use response::{
OkPacket, EofPacket, ErrPacket, ResultRow,
};
pub use encode::Encode;
pub use error_code::ErrorCode;
pub use field::{FieldType, ParameterFlag};
pub use field::{FieldDetailFlag, FieldType, ParameterFlag};
pub use response::{
ColumnCountPacket, ColumnDefinitionPacket, EofPacket, ErrPacket, OkPacket, ResultRow,
};
pub use server_status::ServerStatusFlag;

View File

@ -0,0 +1,47 @@
use crate::mariadb::io::BufExt;
use byteorder::LittleEndian;
use std::io;
// The column packet is the first packet of a result set.
// Inside of it it contains the number of columns in the result set
// encoded as an int<lenenc>.
// https://mariadb.com/kb/en/library/resultset/#column-count-packet
#[derive(Debug)]
pub struct ColumnCountPacket {
pub columns: u64,
}
impl ColumnCountPacket {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
let columns = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or(0);
Ok(Self { columns })
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::__bytes_builder;
#[test]
fn it_decodes_column_packet_0x_fb() -> io::Result<()> {
#[rustfmt::skip]
let buf = __bytes_builder!(
// int<3> length
0u8, 0u8, 0u8,
// int<1> seq_no
0u8,
// int<lenenc> tag code: Some(3 bytes)
0xFD_u8,
// value: 3 bytes
0x01_u8, 0x01_u8, 0x01_u8
);
let message = ColumnCountPacket::decode(&buf)?;
assert_eq!(message.columns, Some(0x010101));
Ok(())
}
}

View File

@ -0,0 +1,127 @@
use crate::{
io::Buf,
mariadb::{
io::BufExt,
protocol::{FieldDetailFlag, FieldType},
},
};
use byteorder::LittleEndian;
use std::io;
#[derive(Debug)]
// ColumnDefinitionPacket doesn't have a packet header because
// it's nested inside a result set packet
pub struct ColumnDefinitionPacket {
pub schema: Option<String>,
pub table_alias: Option<String>,
pub table: Option<String>,
pub column_alias: Option<String>,
pub column: Option<String>,
pub char_set: u16,
pub max_columns: i32,
pub field_type: FieldType,
pub field_details: FieldDetailFlag,
pub decimals: u8,
}
impl ColumnDefinitionPacket {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
// string<lenenc> catalog (always 'def')
let _catalog = buf.get_str_lenenc::<LittleEndian>()?;
// TODO: Assert that this is always DEF
// string<lenenc> schema
let schema = buf.get_str_lenenc::<LittleEndian>()?.map(ToOwned::to_owned);
// string<lenenc> table alias
let table_alias = buf.get_str_lenenc::<LittleEndian>()?.map(ToOwned::to_owned);
// string<lenenc> table
let table = buf.get_str_lenenc::<LittleEndian>()?.map(ToOwned::to_owned);
// string<lenenc> column alias
let column_alias = buf.get_str_lenenc::<LittleEndian>()?.map(ToOwned::to_owned);
// string<lenenc> column
let column = buf.get_str_lenenc::<LittleEndian>()?.map(ToOwned::to_owned);
// int<lenenc> length of fixed fields (=0xC)
let _length_of_fixed_fields = buf.get_uint_lenenc::<LittleEndian>()?;
// TODO: Assert that this is always 0xC
// int<2> character set number
let char_set = buf.get_u16::<LittleEndian>()?;
// int<4> max. column size
let max_columns = buf.get_i32::<LittleEndian>()?;
// int<1> Field types
let field_type = FieldType(buf.get_u8()?);
// int<2> Field detail flag
let field_details = FieldDetailFlag::from_bits_truncate(buf.get_u16::<LittleEndian>()?);
// int<1> decimals
let decimals = buf.get_u8()?;
// int<2> - unused -
buf.advance(2);
Ok(Self {
schema,
table_alias,
table,
column_alias,
column,
char_set,
max_columns,
field_type,
field_details,
decimals,
})
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::__bytes_builder;
#[test]
fn it_decodes_column_def_packet() -> io::Result<()> {
#[rustfmt::skip]
let buf = __bytes_builder!(
// length
1u8, 0u8, 0u8,
// seq_no
0u8,
// string<lenenc> catalog (always 'def')
1u8, b'a',
// string<lenenc> schema
1u8, b'b',
// string<lenenc> table alias
1u8, b'c',
// string<lenenc> table
1u8, b'd',
// string<lenenc> column alias
1u8, b'e',
// string<lenenc> column
1u8, b'f',
// int<lenenc> length of fixed fields (=0xC)
0xFC_u8, 1u8, 1u8,
// int<2> character set number
1u8, 1u8,
// int<4> max. column size
1u8, 1u8, 1u8, 1u8,
// int<1> Field types
1u8,
// int<2> Field detail flag
1u8, 0u8,
// int<1> decimals
1u8,
// int<2> - unused -
0u8, 0u8
);
let message = ColumnDefinitionPacket::decode(&buf)?;
assert_eq!(message.schema, Some(b"b"));
assert_eq!(message.table_alias, Some(b"c"));
assert_eq!(message.table, Some(b"d"));
assert_eq!(message.column_alias, Some(b"e"));
assert_eq!(message.column, Some(b"f"));
Ok(())
}
}

View File

@ -0,0 +1,111 @@
use crate::{
io::Buf,
mariadb::{io::BufExt, protocol::ErrorCode},
};
use byteorder::LittleEndian;
use std::io;
#[derive(Debug)]
pub enum ErrPacket {
Progress {
stage: u8,
max_stage: u8,
progress: u32,
info: Box<str>,
},
Error {
code: ErrorCode,
sql_state: Option<Box<str>>,
message: Box<str>,
},
}
impl ErrPacket {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
let header = buf.get_u8()?;
debug_assert_eq!(header, 0xFF);
// error code : int<2>
let code = buf.get_u16::<LittleEndian>()?;
// if (errorcode == 0xFFFF) /* progress reporting */
if code == 0xFF_FF {
let stage = buf.get_u8()?;
let max_stage = buf.get_u8()?;
let progress = buf.get_u24::<LittleEndian>()?;
let info = buf
.get_str_lenenc::<LittleEndian>()?
.unwrap_or_default()
.into();
Ok(Self::Progress {
stage,
max_stage,
progress,
info,
})
} else {
// if (next byte = '#')
let sql_state = if buf[0] == b'#' {
// '#' : string<1>
buf.advance(1);
// sql state : string<5>
Some(buf.get_str(5)?.into())
} else {
None
};
let message = buf.get_str_eof()?.into();
Ok(Self::Error {
code: ErrorCode(code),
sql_state,
message,
})
}
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::__bytes_builder;
#[test]
fn it_decodes_err_packet() -> Result<(), Error> {
#[rustfmt::skip]
let buf = __bytes_builder!(
// int<3> length
1u8, 0u8, 0u8,
// int<1> seq_no
1u8,
// int<1> 0xfe : EOF header
0xFF_u8,
// int<2> error code
0x84_u8, 0x04_u8,
// if (errorcode == 0xFFFF) /* progress reporting */ {
// int<1> stage
// int<1> max_stage
// int<3> progress
// string<lenenc> progress_info
// } else {
// if (next byte = '#') {
// string<1> sql state marker '#'
b"#",
// string<5>sql state
b"08S01",
// string<EOF> error message
b"Got packets out of order"
// } else {
// string<EOF> error message
// }
// }
);
let _message = ErrPacket::decode(&buf)?;
Ok(())
}
}

View File

@ -1,9 +1,13 @@
mod ok;
mod err;
mod column_count;
mod column_def;
mod eof;
mod err;
mod ok;
mod row;
pub use ok::OkPacket;
pub use err::ErrPacket;
pub use column_count::ColumnCountPacket;
pub use column_def::ColumnDefinitionPacket;
pub use eof::EofPacket;
pub use err::ErrPacket;
pub use ok::OkPacket;
pub use row::ResultRow;

View File

@ -0,0 +1,68 @@
use crate::{
io::Buf,
mariadb::{
io::BufExt,
protocol::{ColumnDefinitionPacket, FieldType},
},
};
use byteorder::LittleEndian;
use std::{io, pin::Pin, ptr::NonNull};
/// A resultset row represents a database resultset unit, which is usually generated by
/// executing a statement that queries the database.
#[derive(Debug)]
pub struct ResultRow {
#[used]
buffer: Pin<Box<[u8]>>,
pub values: Box<[Option<NonNull<[u8]>>]>,
}
// SAFE: Raw pointers point to pinned memory inside the struct
unsafe impl Send for ResultRow {}
unsafe impl Sync for ResultRow {}
impl ResultRow {
pub fn decode(mut buf: &[u8], columns: &[ColumnDefinitionPacket]) -> io::Result<Self> {
// 0x00 header : byte<1>
let header = buf.get_u8()?;
// TODO: Replace with InvalidData err
debug_assert_eq!(header, 0);
// NULL-Bitmap : byte<(number_of_columns + 9) / 8>
let null = buf.get_uint::<LittleEndian>((columns.len() + 9) / 8)?;
let buffer: Pin<Box<[u8]>> = Pin::new(buf.into());
let mut buf = &*buffer;
let mut values = Vec::with_capacity(columns.len());
for column_idx in 0..columns.len() {
if (null & (1 << column_idx)) != 0 {
values.push(None);
} else {
match columns[column_idx].field_type {
FieldType::MYSQL_TYPE_LONG => {
values.push(Some(buf[..(4 as usize)].into()));
buf.advance(4);
}
FieldType::MYSQL_TYPE_VAR_STRING => {
let len = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or_default();
values.push(Some(buf[..(len as usize)].into()));
buf.advance(len as usize);
}
type_ => {
unimplemented!("encountered unknown field type: {:?}", type_);
}
}
}
}
Ok(Self {
buffer,
values: values.into_boxed_slice(),
})
}
}