mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-12-29 21:00:54 +00:00
More work towards mariadb protocol refactor
This commit is contained in:
parent
94fe35c264
commit
6949b35eec
@ -3,6 +3,7 @@ use byteorder::ByteOrder;
|
||||
use std::io;
|
||||
|
||||
pub trait BufExt {
|
||||
fn get_uint<T: ByteOrder>(&mut self, n: usize) -> io::Result<u64>;
|
||||
fn get_uint_lenenc<T: ByteOrder>(&mut self) -> io::Result<Option<u64>>;
|
||||
fn get_str_eof(&mut self) -> io::Result<&str>;
|
||||
fn get_str_lenenc<T: ByteOrder>(&mut self) -> io::Result<Option<&str>>;
|
||||
@ -10,6 +11,13 @@ pub trait BufExt {
|
||||
}
|
||||
|
||||
impl<'a> BufExt for &'a [u8] {
|
||||
fn get_uint<T: ByteOrder>(&mut self, n: usize) -> io::Result<u64> {
|
||||
let val = T::read_uint(*self, n);
|
||||
self.advance(n);
|
||||
|
||||
Ok(val)
|
||||
}
|
||||
|
||||
fn get_uint_lenenc<T: ByteOrder>(&mut self) -> io::Result<Option<u64>> {
|
||||
Ok(match self.get_u8()? {
|
||||
0xFB => None,
|
||||
|
||||
@ -42,3 +42,24 @@ bitflags::bitflags! {
|
||||
const UNSIGNED = 128;
|
||||
}
|
||||
}
|
||||
|
||||
// https://mariadb.com/kb/en/library/resultset/#field-detail-flag
|
||||
bitflags::bitflags! {
|
||||
pub struct FieldDetailFlag: u16 {
|
||||
const NOT_NULL = 1;
|
||||
const PRIMARY_KEY = 2;
|
||||
const UNIQUE_KEY = 4;
|
||||
const MULTIPLE_KEY = 8;
|
||||
const BLOB = 16;
|
||||
const UNSIGNED = 32;
|
||||
const ZEROFILL_FLAG = 64;
|
||||
const BINARY_COLLATION = 128;
|
||||
const ENUM = 256;
|
||||
const AUTO_INCREMENT = 512;
|
||||
const TIMESTAMP = 1024;
|
||||
const SET = 2048;
|
||||
const NO_DEFAULT_VALUE_FLAG = 4096;
|
||||
const ON_UPDATE_NOW_FLAG = 8192;
|
||||
const NUM_FLAG = 32768;
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,17 +6,17 @@ mod connect;
|
||||
mod encode;
|
||||
mod error_code;
|
||||
mod field;
|
||||
mod server_status;
|
||||
mod response;
|
||||
mod server_status;
|
||||
|
||||
pub use capabilities::Capabilities;
|
||||
pub use connect::{
|
||||
AuthenticationSwitchRequest, HandshakeResponsePacket, InitialHandshakePacket, SslRequest,
|
||||
};
|
||||
pub use response::{
|
||||
OkPacket, EofPacket, ErrPacket, ResultRow,
|
||||
};
|
||||
pub use encode::Encode;
|
||||
pub use error_code::ErrorCode;
|
||||
pub use field::{FieldType, ParameterFlag};
|
||||
pub use field::{FieldDetailFlag, FieldType, ParameterFlag};
|
||||
pub use response::{
|
||||
ColumnCountPacket, ColumnDefinitionPacket, EofPacket, ErrPacket, OkPacket, ResultRow,
|
||||
};
|
||||
pub use server_status::ServerStatusFlag;
|
||||
|
||||
47
src/mariadb/protocol/response/column_count.rs
Normal file
47
src/mariadb/protocol/response/column_count.rs
Normal file
@ -0,0 +1,47 @@
|
||||
use crate::mariadb::io::BufExt;
|
||||
use byteorder::LittleEndian;
|
||||
use std::io;
|
||||
|
||||
// The column packet is the first packet of a result set.
|
||||
// Inside of it it contains the number of columns in the result set
|
||||
// encoded as an int<lenenc>.
|
||||
// https://mariadb.com/kb/en/library/resultset/#column-count-packet
|
||||
#[derive(Debug)]
|
||||
pub struct ColumnCountPacket {
|
||||
pub columns: u64,
|
||||
}
|
||||
|
||||
impl ColumnCountPacket {
|
||||
fn decode(mut buf: &[u8]) -> io::Result<Self> {
|
||||
let columns = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or(0);
|
||||
|
||||
Ok(Self { columns })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::__bytes_builder;
|
||||
|
||||
#[test]
|
||||
fn it_decodes_column_packet_0x_fb() -> io::Result<()> {
|
||||
#[rustfmt::skip]
|
||||
let buf = __bytes_builder!(
|
||||
// int<3> length
|
||||
0u8, 0u8, 0u8,
|
||||
// int<1> seq_no
|
||||
0u8,
|
||||
// int<lenenc> tag code: Some(3 bytes)
|
||||
0xFD_u8,
|
||||
// value: 3 bytes
|
||||
0x01_u8, 0x01_u8, 0x01_u8
|
||||
);
|
||||
|
||||
let message = ColumnCountPacket::decode(&buf)?;
|
||||
|
||||
assert_eq!(message.columns, Some(0x010101));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
127
src/mariadb/protocol/response/column_def.rs
Normal file
127
src/mariadb/protocol/response/column_def.rs
Normal file
@ -0,0 +1,127 @@
|
||||
use crate::{
|
||||
io::Buf,
|
||||
mariadb::{
|
||||
io::BufExt,
|
||||
protocol::{FieldDetailFlag, FieldType},
|
||||
},
|
||||
};
|
||||
use byteorder::LittleEndian;
|
||||
use std::io;
|
||||
|
||||
#[derive(Debug)]
|
||||
// ColumnDefinitionPacket doesn't have a packet header because
|
||||
// it's nested inside a result set packet
|
||||
pub struct ColumnDefinitionPacket {
|
||||
pub schema: Option<String>,
|
||||
pub table_alias: Option<String>,
|
||||
pub table: Option<String>,
|
||||
pub column_alias: Option<String>,
|
||||
pub column: Option<String>,
|
||||
pub char_set: u16,
|
||||
pub max_columns: i32,
|
||||
pub field_type: FieldType,
|
||||
pub field_details: FieldDetailFlag,
|
||||
pub decimals: u8,
|
||||
}
|
||||
|
||||
impl ColumnDefinitionPacket {
|
||||
fn decode(mut buf: &[u8]) -> io::Result<Self> {
|
||||
// string<lenenc> catalog (always 'def')
|
||||
let _catalog = buf.get_str_lenenc::<LittleEndian>()?;
|
||||
// TODO: Assert that this is always DEF
|
||||
|
||||
// string<lenenc> schema
|
||||
let schema = buf.get_str_lenenc::<LittleEndian>()?.map(ToOwned::to_owned);
|
||||
// string<lenenc> table alias
|
||||
let table_alias = buf.get_str_lenenc::<LittleEndian>()?.map(ToOwned::to_owned);
|
||||
// string<lenenc> table
|
||||
let table = buf.get_str_lenenc::<LittleEndian>()?.map(ToOwned::to_owned);
|
||||
// string<lenenc> column alias
|
||||
let column_alias = buf.get_str_lenenc::<LittleEndian>()?.map(ToOwned::to_owned);
|
||||
// string<lenenc> column
|
||||
let column = buf.get_str_lenenc::<LittleEndian>()?.map(ToOwned::to_owned);
|
||||
|
||||
// int<lenenc> length of fixed fields (=0xC)
|
||||
let _length_of_fixed_fields = buf.get_uint_lenenc::<LittleEndian>()?;
|
||||
// TODO: Assert that this is always 0xC
|
||||
|
||||
// int<2> character set number
|
||||
let char_set = buf.get_u16::<LittleEndian>()?;
|
||||
// int<4> max. column size
|
||||
let max_columns = buf.get_i32::<LittleEndian>()?;
|
||||
// int<1> Field types
|
||||
let field_type = FieldType(buf.get_u8()?);
|
||||
// int<2> Field detail flag
|
||||
let field_details = FieldDetailFlag::from_bits_truncate(buf.get_u16::<LittleEndian>()?);
|
||||
// int<1> decimals
|
||||
let decimals = buf.get_u8()?;
|
||||
// int<2> - unused -
|
||||
buf.advance(2);
|
||||
|
||||
Ok(Self {
|
||||
schema,
|
||||
table_alias,
|
||||
table,
|
||||
column_alias,
|
||||
column,
|
||||
char_set,
|
||||
max_columns,
|
||||
field_type,
|
||||
field_details,
|
||||
decimals,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::__bytes_builder;
|
||||
|
||||
#[test]
|
||||
fn it_decodes_column_def_packet() -> io::Result<()> {
|
||||
#[rustfmt::skip]
|
||||
let buf = __bytes_builder!(
|
||||
// length
|
||||
1u8, 0u8, 0u8,
|
||||
// seq_no
|
||||
0u8,
|
||||
// string<lenenc> catalog (always 'def')
|
||||
1u8, b'a',
|
||||
// string<lenenc> schema
|
||||
1u8, b'b',
|
||||
// string<lenenc> table alias
|
||||
1u8, b'c',
|
||||
// string<lenenc> table
|
||||
1u8, b'd',
|
||||
// string<lenenc> column alias
|
||||
1u8, b'e',
|
||||
// string<lenenc> column
|
||||
1u8, b'f',
|
||||
// int<lenenc> length of fixed fields (=0xC)
|
||||
0xFC_u8, 1u8, 1u8,
|
||||
// int<2> character set number
|
||||
1u8, 1u8,
|
||||
// int<4> max. column size
|
||||
1u8, 1u8, 1u8, 1u8,
|
||||
// int<1> Field types
|
||||
1u8,
|
||||
// int<2> Field detail flag
|
||||
1u8, 0u8,
|
||||
// int<1> decimals
|
||||
1u8,
|
||||
// int<2> - unused -
|
||||
0u8, 0u8
|
||||
);
|
||||
|
||||
let message = ColumnDefinitionPacket::decode(&buf)?;
|
||||
|
||||
assert_eq!(message.schema, Some(b"b"));
|
||||
assert_eq!(message.table_alias, Some(b"c"));
|
||||
assert_eq!(message.table, Some(b"d"));
|
||||
assert_eq!(message.column_alias, Some(b"e"));
|
||||
assert_eq!(message.column, Some(b"f"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,111 @@
|
||||
use crate::{
|
||||
io::Buf,
|
||||
mariadb::{io::BufExt, protocol::ErrorCode},
|
||||
};
|
||||
use byteorder::LittleEndian;
|
||||
use std::io;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ErrPacket {
|
||||
Progress {
|
||||
stage: u8,
|
||||
max_stage: u8,
|
||||
progress: u32,
|
||||
info: Box<str>,
|
||||
},
|
||||
|
||||
Error {
|
||||
code: ErrorCode,
|
||||
sql_state: Option<Box<str>>,
|
||||
message: Box<str>,
|
||||
},
|
||||
}
|
||||
|
||||
impl ErrPacket {
|
||||
fn decode(mut buf: &[u8]) -> io::Result<Self> {
|
||||
let header = buf.get_u8()?;
|
||||
debug_assert_eq!(header, 0xFF);
|
||||
|
||||
// error code : int<2>
|
||||
let code = buf.get_u16::<LittleEndian>()?;
|
||||
|
||||
// if (errorcode == 0xFFFF) /* progress reporting */
|
||||
if code == 0xFF_FF {
|
||||
let stage = buf.get_u8()?;
|
||||
let max_stage = buf.get_u8()?;
|
||||
let progress = buf.get_u24::<LittleEndian>()?;
|
||||
let info = buf
|
||||
.get_str_lenenc::<LittleEndian>()?
|
||||
.unwrap_or_default()
|
||||
.into();
|
||||
|
||||
Ok(Self::Progress {
|
||||
stage,
|
||||
max_stage,
|
||||
progress,
|
||||
info,
|
||||
})
|
||||
} else {
|
||||
// if (next byte = '#')
|
||||
let sql_state = if buf[0] == b'#' {
|
||||
// '#' : string<1>
|
||||
buf.advance(1);
|
||||
|
||||
// sql state : string<5>
|
||||
Some(buf.get_str(5)?.into())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let message = buf.get_str_eof()?.into();
|
||||
|
||||
Ok(Self::Error {
|
||||
code: ErrorCode(code),
|
||||
sql_state,
|
||||
message,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::__bytes_builder;
|
||||
|
||||
#[test]
|
||||
fn it_decodes_err_packet() -> Result<(), Error> {
|
||||
#[rustfmt::skip]
|
||||
let buf = __bytes_builder!(
|
||||
// int<3> length
|
||||
1u8, 0u8, 0u8,
|
||||
// int<1> seq_no
|
||||
1u8,
|
||||
// int<1> 0xfe : EOF header
|
||||
0xFF_u8,
|
||||
// int<2> error code
|
||||
0x84_u8, 0x04_u8,
|
||||
// if (errorcode == 0xFFFF) /* progress reporting */ {
|
||||
// int<1> stage
|
||||
// int<1> max_stage
|
||||
// int<3> progress
|
||||
// string<lenenc> progress_info
|
||||
// } else {
|
||||
// if (next byte = '#') {
|
||||
// string<1> sql state marker '#'
|
||||
b"#",
|
||||
// string<5>sql state
|
||||
b"08S01",
|
||||
// string<EOF> error message
|
||||
b"Got packets out of order"
|
||||
// } else {
|
||||
// string<EOF> error message
|
||||
// }
|
||||
// }
|
||||
);
|
||||
|
||||
let _message = ErrPacket::decode(&buf)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -1,9 +1,13 @@
|
||||
mod ok;
|
||||
mod err;
|
||||
mod column_count;
|
||||
mod column_def;
|
||||
mod eof;
|
||||
mod err;
|
||||
mod ok;
|
||||
mod row;
|
||||
|
||||
pub use ok::OkPacket;
|
||||
pub use err::ErrPacket;
|
||||
pub use column_count::ColumnCountPacket;
|
||||
pub use column_def::ColumnDefinitionPacket;
|
||||
pub use eof::EofPacket;
|
||||
pub use err::ErrPacket;
|
||||
pub use ok::OkPacket;
|
||||
pub use row::ResultRow;
|
||||
|
||||
@ -0,0 +1,68 @@
|
||||
use crate::{
|
||||
io::Buf,
|
||||
mariadb::{
|
||||
io::BufExt,
|
||||
protocol::{ColumnDefinitionPacket, FieldType},
|
||||
},
|
||||
};
|
||||
use byteorder::LittleEndian;
|
||||
use std::{io, pin::Pin, ptr::NonNull};
|
||||
|
||||
/// A resultset row represents a database resultset unit, which is usually generated by
|
||||
/// executing a statement that queries the database.
|
||||
#[derive(Debug)]
|
||||
pub struct ResultRow {
|
||||
#[used]
|
||||
buffer: Pin<Box<[u8]>>,
|
||||
pub values: Box<[Option<NonNull<[u8]>>]>,
|
||||
}
|
||||
|
||||
// SAFE: Raw pointers point to pinned memory inside the struct
|
||||
unsafe impl Send for ResultRow {}
|
||||
unsafe impl Sync for ResultRow {}
|
||||
|
||||
impl ResultRow {
|
||||
pub fn decode(mut buf: &[u8], columns: &[ColumnDefinitionPacket]) -> io::Result<Self> {
|
||||
// 0x00 header : byte<1>
|
||||
let header = buf.get_u8()?;
|
||||
// TODO: Replace with InvalidData err
|
||||
debug_assert_eq!(header, 0);
|
||||
|
||||
// NULL-Bitmap : byte<(number_of_columns + 9) / 8>
|
||||
let null = buf.get_uint::<LittleEndian>((columns.len() + 9) / 8)?;
|
||||
|
||||
let buffer: Pin<Box<[u8]>> = Pin::new(buf.into());
|
||||
let mut buf = &*buffer;
|
||||
|
||||
let mut values = Vec::with_capacity(columns.len());
|
||||
|
||||
for column_idx in 0..columns.len() {
|
||||
if (null & (1 << column_idx)) != 0 {
|
||||
values.push(None);
|
||||
} else {
|
||||
match columns[column_idx].field_type {
|
||||
FieldType::MYSQL_TYPE_LONG => {
|
||||
values.push(Some(buf[..(4 as usize)].into()));
|
||||
buf.advance(4);
|
||||
}
|
||||
|
||||
FieldType::MYSQL_TYPE_VAR_STRING => {
|
||||
let len = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or_default();
|
||||
|
||||
values.push(Some(buf[..(len as usize)].into()));
|
||||
buf.advance(len as usize);
|
||||
}
|
||||
|
||||
type_ => {
|
||||
unimplemented!("encountered unknown field type: {:?}", type_);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
values: values.into_boxed_slice(),
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user