More work towards mariadb protocol refactor

This commit is contained in:
Ryan Leckey 2019-09-05 13:37:43 -07:00
parent b0707bba7b
commit 94fe35c264
22 changed files with 778 additions and 600 deletions

View File

@ -89,7 +89,9 @@ where
}
}
// TODO: Find a nicer way to do this
// Return `Ok(None)` immediately from a function if the wrapped value is `None`
#[allow(unused)]
macro_rules! ret_if_none {
($val:expr) => {
match $val {

View File

@ -71,3 +71,179 @@ impl BufMutExt for Vec<u8> {
self.extend_from_slice(val);
}
}
#[cfg(test)]
mod tests {
use super::BufMutExt;
use crate::io::BufMut;
use byteorder::LittleEndian;
// [X] it_encodes_int_lenenc_u64
// [X] it_encodes_int_lenenc_u32
// [X] it_encodes_int_lenenc_u24
// [X] it_encodes_int_lenenc_u16
// [X] it_encodes_int_lenenc_u8
// [X] it_encodes_int_u64
// [X] it_encodes_int_u32
// [X] it_encodes_int_u24
// [X] it_encodes_int_u16
// [X] it_encodes_int_u8
// [X] it_encodes_string_lenenc
// [X] it_encodes_string_fix
// [X] it_encodes_string_null
// [X] it_encodes_string_eof
// [X] it_encodes_byte_lenenc
// [X] it_encodes_byte_fix
// [X] it_encodes_byte_eof
#[test]
fn it_encodes_int_lenenc_none() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(None);
assert_eq!(&buf[..], b"\xFB");
}
#[test]
fn it_encodes_int_lenenc_u8() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(0xFA as u64);
assert_eq!(&buf[..], b"\xFA");
}
#[test]
fn it_encodes_int_lenenc_u16() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(std::u16::MAX as u64);
assert_eq!(&buf[..], b"\xFC\xFF\xFF");
}
#[test]
fn it_encodes_int_lenenc_u24() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(0xFF_FF_FF as u64);
assert_eq!(&buf[..], b"\xFD\xFF\xFF\xFF");
}
#[test]
fn it_encodes_int_lenenc_u64() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(std::u64::MAX);
assert_eq!(&buf[..], b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
}
#[test]
fn it_encodes_int_lenenc_fb() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(0xFB as u64);
assert_eq!(&buf[..], b"\xFC\xFB\x00");
}
#[test]
fn it_encodes_int_lenenc_fc() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(0xFC as u64);
assert_eq!(&buf[..], b"\xFC\xFC\x00");
}
#[test]
fn it_encodes_int_lenenc_fd() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(0xFD as u64);
assert_eq!(&buf[..], b"\xFC\xFD\x00");
}
#[test]
fn it_encodes_int_lenenc_fe() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(0xFE as u64);
assert_eq!(&buf[..], b"\xFC\xFE\x00");
}
fn it_encodes_int_lenenc_ff() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian, _>(0xFF as u64);
assert_eq!(&buf[..], b"\xFC\xFF\x00");
}
#[test]
fn it_encodes_int_u64() {
let mut buf = Vec::with_capacity(1024);
buf.put_u64::<LittleEndian>(std::u64::MAX);
assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
}
#[test]
fn it_encodes_int_u32() {
let mut buf = Vec::with_capacity(1024);
buf.put_u32::<LittleEndian>(std::u32::MAX);
assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF");
}
#[test]
fn it_encodes_int_u24() {
let mut buf = Vec::with_capacity(1024);
buf.put_u24::<LittleEndian>(0xFF_FF_FF as u32);
assert_eq!(&buf[..], b"\xFF\xFF\xFF");
}
#[test]
fn it_encodes_int_u16() {
let mut buf = Vec::with_capacity(1024);
buf.put_u16::<LittleEndian>(std::u16::MAX);
assert_eq!(&buf[..], b"\xFF\xFF");
}
#[test]
fn it_encodes_int_u8() {
let mut buf = Vec::with_capacity(1024);
buf.put_u8(std::u8::MAX);
assert_eq!(&buf[..], b"\xFF");
}
#[test]
fn it_encodes_string_lenenc() {
let mut buf = Vec::with_capacity(1024);
buf.put_str_lenenc::<LittleEndian>("random_string");
assert_eq!(&buf[..], b"\x0Drandom_string");
}
#[test]
fn it_encodes_string_fix() {
let mut buf = Vec::with_capacity(1024);
buf.put_str("random_string");
assert_eq!(&buf[..], b"random_string");
}
#[test]
fn it_encodes_string_null() {
let mut buf = Vec::with_capacity(1024);
buf.put_str_nul("random_string");
assert_eq!(&buf[..], b"random_string\0");
}
#[test]
fn it_encodes_byte_lenenc() {
let mut buf = Vec::with_capacity(1024);
buf.put_bytes_lenenc::<LittleEndian>(b"random_string");
assert_eq!(&buf[..], b"\x0Drandom_string");
}
}

View File

@ -0,0 +1,65 @@
// https://mariadb.com/kb/en/library/connection/#capabilities
bitflags::bitflags! {
pub struct Capabilities: u128 {
const CLIENT_MYSQL = 1;
const FOUND_ROWS = 2;
// One can specify db on connect
const CONNECT_WITH_DB = 8;
// Can use compression protocol
const COMPRESS = 32;
// Can use LOAD DATA LOCAL
const LOCAL_FILES = 128;
// Ignore spaces before '('
const IGNORE_SPACE = 256;
// 4.1+ protocol
const CLIENT_PROTOCOL_41 = 1 << 9;
const CLIENT_INTERACTIVE = 1 << 10;
// Can use SSL
const SSL = 1 << 11;
const TRANSACTIONS = 1 << 12;
// 4.1+ authentication
const SECURE_CONNECTION = 1 << 13;
// Enable/disable multi-stmt support
const MULTI_STATEMENTS = 1 << 16;
// Enable/disable multi-results
const MULTI_RESULTS = 1 << 17;
// Enable/disable multi-results for PrepareStatement
const PS_MULTI_RESULTS = 1 << 18;
// Client supports plugin authentication
const PLUGIN_AUTH = 1 << 19;
// Client send connection attributes
const CONNECT_ATTRS = 1 << 20;
// Enable authentication response packet to be larger than 255 bytes
const PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21;
// Enable/disable session tracking in OK_Packet
const CLIENT_SESSION_TRACK = 1 << 23;
// EOF_Packet deprecation
const CLIENT_DEPRECATE_EOF = 1 << 24;
// Client support progress indicator (since 10.2)
const MARIA_DB_CLIENT_PROGRESS = 1 << 32;
// Permit COM_MULTI protocol
const MARIA_DB_CLIENT_COM_MULTI = 1 << 33;
// Permit bulk insert
const MARIA_CLIENT_STMT_BULK_OPERATIONS = 1 << 34;
}
}

View File

@ -0,0 +1,21 @@
use crate::{
io::BufMut,
mariadb::{
io::BufMutExt,
protocol::{Capabilities, Encode},
},
};
#[derive(Default, Debug)]
pub struct AuthenticationSwitchRequest<'a> {
pub auth_plugin_name: &'a str,
pub auth_plugin_data: &'a [u8],
}
impl Encode for AuthenticationSwitchRequest<'_> {
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
buf.put_u8(0xFE);
buf.put_str_nul(&self.auth_plugin_name);
buf.put_bytes(&self.auth_plugin_data);
}
}

View File

@ -0,0 +1,164 @@
use crate::{
io::Buf,
mariadb::{
io::BufExt,
protocol::{Capabilities, ServerStatusFlag},
},
};
use byteorder::LittleEndian;
use std::io;
#[derive(Debug)]
pub struct InitialHandshakePacket {
pub protocol_version: u8,
pub server_version: String,
pub server_status: ServerStatusFlag,
pub server_default_collation: u8,
pub connection_id: u32,
pub scramble: Box<[u8]>,
pub capabilities: Capabilities,
pub auth_plugin_name: Option<String>,
}
impl InitialHandshakePacket {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
let protocol_version = buf.get_u8()?;
let server_version = buf.get_str_nul()?.to_owned();
let connection_id = buf.get_u32::<LittleEndian>()?;
let mut scramble = Vec::with_capacity(8);
// scramble 1st part (authentication seed) : string<8>
scramble.extend_from_slice(&buf[..8]);
buf.advance(8);
// reserved : string<1>
buf.advance(1);
// server capabilities (1st part) : int<2>
let capabilities_1 = buf.get_u16::<LittleEndian>()?;
let mut capabilities = Capabilities::from_bits_truncate(capabilities_1.into());
// server default collation : int<1>
let server_default_collation = buf.get_u8()?;
// status flags : int<2>
let server_status = buf.get_u16::<LittleEndian>()?;
// server capabilities (2nd part) : int<2>
let capabilities_2 = buf.get_u16::<LittleEndian>()?;
capabilities |= Capabilities::from_bits_truncate(((capabilities_2 as u32) << 16).into());
// if (server_capabilities & PLUGIN_AUTH)
let plugin_data_length = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
// plugin data length : int<1>
buf.get_u8()?
} else {
// 0x00 : int<1>
buf.advance(0);
0
};
// filler : string<6>
buf.advance(6);
// if (server_capabilities & CLIENT_MYSQL)
if capabilities.contains(Capabilities::CLIENT_MYSQL) {
// filler : string<4>
buf.advance(4);
} else {
// server capabilities 3rd part . MariaDB specific flags : int<4>
let capabilities_3 = buf.get_u32::<LittleEndian>()?;
capabilities |= Capabilities::from_bits_truncate((capabilities_2 as u128) << 32);
}
// if (server_capabilities & CLIENT_SECURE_CONNECTION)
if capabilities.contains(Capabilities::SECURE_CONNECTION) {
// scramble 2nd part . Length = max(12, plugin data length - 9) : string<N>
let len = ((plugin_data_length as isize) - 9).max(12) as usize;
scramble.extend_from_slice(&buf[..len]);
buf.advance(len);
// reserved byte : string<1>
buf.advance(1);
}
// if (server_capabilities & PLUGIN_AUTH)
let auth_plugin_name = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
Some(buf.get_str_nul()?.to_owned())
} else {
None
};
Ok(Self {
protocol_version,
server_version,
server_default_collation,
server_status: ServerStatusFlag::from_bits_truncate(server_status),
connection_id,
scramble: scramble.into_boxed_slice(),
capabilities,
auth_plugin_name,
})
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::__bytes_builder;
#[test]
fn it_decodes_initial_handshake_packet() -> io::Result<()> {
#[rustfmt::skip]
let buf = __bytes_builder!(
// int<3> length
1u8, 0u8, 0u8,
// int<1> seq_no
0u8,
//int<1> protocol version
10u8,
//string<NUL> server version (MariaDB server version is by default prefixed by "5.5.5-")
b"5.5.5-10.4.6-MariaDB-1:10.4.6+maria~bionic\0",
//int<4> connection id
13u8, 0u8, 0u8, 0u8,
//string<8> scramble 1st part (authentication seed)
b"?~~|vZAu",
//string<1> reserved byte
0u8,
//int<2> server capabilities (1st part)
0xFEu8, 0xF7u8,
//int<1> server default collation
8u8,
//int<2> status flags
2u8, 0u8,
//int<2> server capabilities (2nd part)
0xFF_u8, 0x81_u8,
//if (server_capabilities & PLUGIN_AUTH)
// int<1> plugin data length
15u8,
//else
// int<1> 0x00
//string<6> filler
0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
//if (server_capabilities & CLIENT_MYSQL)
// string<4> filler
//else
// int<4> server capabilities 3rd part . MariaDB specific flags /* MariaDB 10.2 or later */
7u8, 0u8, 0u8, 0u8,
//if (server_capabilities & CLIENT_SECURE_CONNECTION)
// string<n> scramble 2nd part . Length = max(12, plugin data length - 9)
b"JQ8cihP4Q}Dx",
// string<1> reserved byte
0u8,
//if (server_capabilities & PLUGIN_AUTH)
// string<NUL> authentication plugin name
b"mysql_native_password\0"
);
let _message = InitialHandshakePacket::decode(&buf)?;
Ok(())
}
}

View File

@ -0,0 +1,9 @@
mod auth_switch_request;
mod initial;
mod response;
mod ssl_request;
pub use auth_switch_request::AuthenticationSwitchRequest;
pub use initial::InitialHandshakePacket;
pub use response::HandshakeResponsePacket;
pub use ssl_request::SslRequest;

View File

@ -7,7 +7,7 @@ use crate::{
};
use byteorder::LittleEndian;
#[derive(Default, Debug)]
#[derive(Debug)]
pub struct HandshakeResponsePacket<'a> {
pub capabilities: Capabilities,
pub max_packet_size: u32,

View File

@ -0,0 +1,40 @@
use crate::{
io::BufMut,
mariadb::{
io::BufMutExt,
protocol::{Capabilities, Encode},
},
};
use byteorder::LittleEndian;
#[derive(Debug)]
pub struct SslRequest {
pub capabilities: Capabilities,
pub max_packet_size: u32,
pub client_collation: u8,
}
impl Encode for SslRequest {
fn encode(&self, buf: &mut Vec<u8>, capabilities: Capabilities) {
// client capabilities : int<4>
buf.put_u32::<LittleEndian>(self.capabilities.bits() as u32);
// max packet size : int<4>
buf.put_u32::<LittleEndian>(self.max_packet_size);
// client character collation : int<1>
buf.put_u8(self.client_collation);
// reserved : string<19>
buf.advance(19);
// if not (capabilities & CLIENT_MYSQL)
if !capabilities.contains(Capabilities::CLIENT_MYSQL) {
// extended client capabilities : int<4>
buf.put_u32::<LittleEndian>((self.capabilities.bits() >> 32) as u32);
} else {
// reserved : int<4>
buf.advance(4);
}
}
}

View File

@ -3,219 +3,3 @@ use super::Capabilities;
pub trait Encode {
fn encode(&self, buf: &mut Vec<u8>, capabilities: Capabilities);
}
pub const U24_MAX: usize = 0xFF_FF_FF;
// #[inline]
// fn put_param(&mut self, bytes: &Bytes, ty: FieldType) {
// match ty {
// FieldType::MYSQL_TYPE_DECIMAL => self.put_string_lenenc(bytes),
// FieldType::MYSQL_TYPE_TINY => self.put_int_1(bytes),
// FieldType::MYSQL_TYPE_SHORT => self.put_int_2(bytes),
// FieldType::MYSQL_TYPE_LONG => self.put_int_4(bytes),
// FieldType::MYSQL_TYPE_FLOAT => self.put_int_4(bytes),
// FieldType::MYSQL_TYPE_DOUBLE => self.put_int_8(bytes),
// FieldType::MYSQL_TYPE_NULL => panic!("Type cannot be FieldType::MysqlTypeNull"),
// FieldType::MYSQL_TYPE_TIMESTAMP => unimplemented!(),
// FieldType::MYSQL_TYPE_LONGLONG => self.put_int_8(bytes),
// FieldType::MYSQL_TYPE_INT24 => self.put_int_4(bytes),
// FieldType::MYSQL_TYPE_DATE => unimplemented!(),
// FieldType::MYSQL_TYPE_TIME => unimplemented!(),
// FieldType::MYSQL_TYPE_DATETIME => unimplemented!(),
// FieldType::MYSQL_TYPE_YEAR => self.put_int_4(bytes),
// FieldType::MYSQL_TYPE_NEWDATE => unimplemented!(),
// FieldType::MYSQL_TYPE_VARCHAR => self.put_string_lenenc(bytes),
// FieldType::MYSQL_TYPE_BIT => self.put_string_lenenc(bytes),
// FieldType::MYSQL_TYPE_TIMESTAMP2 => unimplemented!(),
// FieldType::MYSQL_TYPE_DATETIME2 => unimplemented!(),
// FieldType::MYSQL_TYPE_TIME2 => unimplemented!(),
// FieldType::MYSQL_TYPE_JSON => self.put_byte_lenenc(bytes),
// FieldType::MYSQL_TYPE_NEWDECIMAL => self.put_byte_lenenc(bytes),
// FieldType::MYSQL_TYPE_ENUM => self.put_byte_lenenc(bytes),
// FieldType::MYSQL_TYPE_SET => self.put_byte_lenenc(bytes),
// FieldType::MYSQL_TYPE_TINY_BLOB => self.put_byte_lenenc(bytes),
// FieldType::MYSQL_TYPE_MEDIUM_BLOB => self.put_byte_lenenc(bytes),
// FieldType::MYSQL_TYPE_LONG_BLOB => self.put_byte_lenenc(bytes),
// FieldType::MYSQL_TYPE_BLOB => self.put_byte_lenenc(bytes),
// FieldType::MYSQL_TYPE_VAR_STRING => self.put_byte_lenenc(bytes),
// FieldType::MYSQL_TYPE_STRING => self.put_byte_lenenc(bytes),
// FieldType::MYSQL_TYPE_GEOMETRY => self.put_byte_lenenc(bytes),
// _ => panic!("Unrecognized field type"),
// }
// }
#[cfg(test)]
mod tests {
use super::*;
use crate::{io::BufMut, mariadb::io::BufMutExt};
use byteorder::LittleEndian;
// [X] it_encodes_int_lenenc_u64
// [X] it_encodes_int_lenenc_u32
// [X] it_encodes_int_lenenc_u24
// [X] it_encodes_int_lenenc_u16
// [X] it_encodes_int_lenenc_u8
// [X] it_encodes_int_u64
// [X] it_encodes_int_u32
// [X] it_encodes_int_u24
// [X] it_encodes_int_u16
// [X] it_encodes_int_u8
// [X] it_encodes_string_lenenc
// [X] it_encodes_string_fix
// [X] it_encodes_string_null
// [X] it_encodes_string_eof
// [X] it_encodes_byte_lenenc
// [X] it_encodes_byte_fix
// [X] it_encodes_byte_eof
#[test]
fn it_encodes_int_lenenc_none() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian>(Some(0u64));
assert_eq!(&buf[..], b"\xFB");
}
#[test]
fn it_encodes_int_lenenc_u8() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian>(Some(0xFA as u64));
assert_eq!(&buf[..], b"\xFA");
}
#[test]
fn it_encodes_int_lenenc_u16() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian>(Some(std::u16::MAX as u64));
assert_eq!(&buf[..], b"\xFC\xFF\xFF");
}
#[test]
fn it_encodes_int_lenenc_u24() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian>(Some(U24_MAX as u64));
assert_eq!(&buf[..], b"\xFD\xFF\xFF\xFF");
}
#[test]
fn it_encodes_int_lenenc_u64() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian>(Some(std::u64::MAX));
assert_eq!(&buf[..], b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
}
#[test]
fn it_encodes_int_lenenc_fb() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian>(Some(0xFB as u64));
assert_eq!(&buf[..], b"\xFC\xFB\x00");
}
#[test]
fn it_encodes_int_lenenc_fc() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian>(Some(0xFC as u64));
assert_eq!(&buf[..], b"\xFC\xFC\x00");
}
#[test]
fn it_encodes_int_lenenc_fd() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian>(Some(0xFD as u64));
assert_eq!(&buf[..], b"\xFC\xFD\x00");
}
#[test]
fn it_encodes_int_lenenc_fe() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian>(Some(0xFE as u64));
assert_eq!(&buf[..], b"\xFC\xFE\x00");
}
fn it_encodes_int_lenenc_ff() {
let mut buf = Vec::with_capacity(1024);
buf.put_uint_lenenc::<LittleEndian>(Some(0xFF as u64));
assert_eq!(&buf[..], b"\xFC\xFF\x00");
}
#[test]
fn it_encodes_int_u64() {
let mut buf = Vec::with_capacity(1024);
buf.put_u64::<LittleEndian>(std::u64::MAX);
assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
}
#[test]
fn it_encodes_int_u32() {
let mut buf = Vec::with_capacity(1024);
buf.put_u32::<LittleEndian>(std::u32::MAX);
assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF");
}
#[test]
fn it_encodes_int_u24() {
let mut buf = Vec::with_capacity(1024);
buf.put_u24::<LittleEndian>(U24_MAX as u32);
assert_eq!(&buf[..], b"\xFF\xFF\xFF");
}
#[test]
fn it_encodes_int_u16() {
let mut buf = Vec::with_capacity(1024);
buf.put_u16::<LittleEndian>(std::u16::MAX);
assert_eq!(&buf[..], b"\xFF\xFF");
}
#[test]
fn it_encodes_int_u8() {
let mut buf = Vec::with_capacity(1024);
buf.put_u8(std::u8::MAX);
assert_eq!(&buf[..], b"\xFF");
}
#[test]
fn it_encodes_string_lenenc() {
let mut buf = Vec::with_capacity(1024);
buf.put_str_lenenc::<LittleEndian>("random_string");
assert_eq!(&buf[..], b"\x0Drandom_string");
}
#[test]
fn it_encodes_string_fix() {
let mut buf = Vec::with_capacity(1024);
buf.put_str("random_string");
assert_eq!(&buf[..], b"random_string");
}
#[test]
fn it_encodes_string_null() {
let mut buf = Vec::with_capacity(1024);
buf.put_str_nul("random_string");
assert_eq!(&buf[..], b"random_string\0");
}
#[test]
fn it_encodes_byte_lenenc() {
let mut buf = Vec::with_capacity(1024);
buf.put_byte_lenenc::<LittleEndian>(b"random_string");
assert_eq!(&buf[..], b"\x0Drandom_string");
}
}

View File

@ -1,6 +1,9 @@
#[derive(Default, Debug)]
pub struct ErrorCode(pub(crate) u16);
// TODO: It would be nice to figure out a clean way to go from 1152 to "ER_ABORTING_CONNECTION (1152)" in Debug.
// Values from https://mariadb.com/kb/en/library/mariadb-error-codes/
impl ErrorCode {
const ER_ABORTING_CONNECTION: ErrorCode = ErrorCode(1152);
const ER_ACCESS_DENIED_CHANGE_USER_ERROR: ErrorCode = ErrorCode(1873);

View File

@ -0,0 +1,44 @@
// https://mariadb.com/kb/en/library/resultset/#field-types
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct FieldType(pub u8);
impl FieldType {
pub const MYSQL_TYPE_BIT: FieldType = FieldType(16);
pub const MYSQL_TYPE_BLOB: FieldType = FieldType(252);
pub const MYSQL_TYPE_DATE: FieldType = FieldType(10);
pub const MYSQL_TYPE_DATETIME: FieldType = FieldType(12);
pub const MYSQL_TYPE_DATETIME2: FieldType = FieldType(18);
pub const MYSQL_TYPE_DECIMAL: FieldType = FieldType(0);
pub const MYSQL_TYPE_DOUBLE: FieldType = FieldType(5);
pub const MYSQL_TYPE_ENUM: FieldType = FieldType(247);
pub const MYSQL_TYPE_FLOAT: FieldType = FieldType(4);
pub const MYSQL_TYPE_GEOMETRY: FieldType = FieldType(255);
pub const MYSQL_TYPE_INT24: FieldType = FieldType(9);
pub const MYSQL_TYPE_JSON: FieldType = FieldType(245);
pub const MYSQL_TYPE_LONG: FieldType = FieldType(3);
pub const MYSQL_TYPE_LONGLONG: FieldType = FieldType(8);
pub const MYSQL_TYPE_LONG_BLOB: FieldType = FieldType(251);
pub const MYSQL_TYPE_MEDIUM_BLOB: FieldType = FieldType(250);
pub const MYSQL_TYPE_NEWDATE: FieldType = FieldType(14);
pub const MYSQL_TYPE_NEWDECIMAL: FieldType = FieldType(246);
pub const MYSQL_TYPE_NULL: FieldType = FieldType(6);
pub const MYSQL_TYPE_SET: FieldType = FieldType(248);
pub const MYSQL_TYPE_SHORT: FieldType = FieldType(2);
pub const MYSQL_TYPE_STRING: FieldType = FieldType(254);
pub const MYSQL_TYPE_TIME: FieldType = FieldType(11);
pub const MYSQL_TYPE_TIME2: FieldType = FieldType(19);
pub const MYSQL_TYPE_TIMESTAMP: FieldType = FieldType(7);
pub const MYSQL_TYPE_TIMESTAMP2: FieldType = FieldType(17);
pub const MYSQL_TYPE_TINY: FieldType = FieldType(1);
pub const MYSQL_TYPE_TINY_BLOB: FieldType = FieldType(249);
pub const MYSQL_TYPE_VARCHAR: FieldType = FieldType(15);
pub const MYSQL_TYPE_VAR_STRING: FieldType = FieldType(253);
pub const MYSQL_TYPE_YEAR: FieldType = FieldType(13);
}
// https://mariadb.com/kb/en/library/com_stmt_execute/#parameter-flag
bitflags::bitflags! {
pub struct ParameterFlag: u8 {
const UNSIGNED = 128;
}
}

View File

@ -1,36 +1,22 @@
// Reference: https://mariadb.com/kb/en/library/connection
// Packets: https://mariadb.com/kb/en/library/0-packet
// TODO: Handle lengths which are greater than 3 bytes
// Either break the packet into several smaller ones, or
// return error
mod capabilities;
mod connect;
mod encode;
mod error_code;
mod field;
mod server_status;
mod response;
// TODO: Handle different Capabilities for server and client
// TODO: Handle when capability is set, but field is None
pub mod encode;
pub mod error_codes;
pub mod packets;
pub mod types;
// Re-export all the things
// pub use packets::{
// AuthenticationSwitchRequestPacket, ColumnDefPacket, ColumnPacket, ComDebug, ComInitDb, ComPing,
// ComProcessKill, ComQuery, ComQuit, ComResetConnection, ComSetOption, ComShutdown, ComSleep,
// ComStatistics, ComStmtClose, ComStmtExec, ComStmtFetch, ComStmtPrepare, ComStmtPrepareOk,
// ComStmtPrepareResp, ComStmtReset, EofPacket, ErrPacket, HandshakeResponsePacket,
// InitialHandshakePacket, OkPacket, PacketHeader, ResultRow, ResultRowBinary, ResultRowText,
// ResultSet, SSLRequestPacket, SetOptionOptions, ShutdownOptions,
// };
pub use packets::{ColumnCountPacket, ColumnDefinitionPacket};
pub use encode::Encode;
pub use error_codes::ErrorCode;
pub use types::{
Capabilities, FieldDetailFlag, FieldType, ProtocolType, ServerStatusFlag, SessionChangeType,
StmtExecFlag,
pub use capabilities::Capabilities;
pub use connect::{
AuthenticationSwitchRequest, HandshakeResponsePacket, InitialHandshakePacket, SslRequest,
};
pub use response::{
OkPacket, EofPacket, ErrPacket, ResultRow,
};
pub use encode::Encode;
pub use error_code::ErrorCode;
pub use field::{FieldType, ParameterFlag};
pub use server_status::ServerStatusFlag;

View File

@ -1,19 +0,0 @@
use crate::mariadb::{BufMut, ConnContext, Encode, MariaDbRawConnection};
use bytes::Bytes;
use failure::Error;
#[derive(Default, Debug)]
pub struct AuthenticationSwitchRequestPacket {
pub auth_plugin_name: Bytes,
pub auth_plugin_data: Bytes,
}
impl Encode for AuthenticationSwitchRequestPacket {
fn encode(&self, buf: &mut Vec<u8>, ctx: &mut ConnContext) -> Result<(), Error> {
buf.put_int_u8(0xFE);
buf.put_string_null(&self.auth_plugin_name);
buf.put_byte_eof(&self.auth_plugin_data);
Ok(())
}
}

View File

@ -1,166 +0,0 @@
use crate::mariadb::{Capabilities, DeContext, Decode, ServerStatusFlag};
use bytes::Bytes;
use failure::{err_msg, Error};
#[derive(Default, Debug)]
pub struct InitialHandshakePacket {
pub length: u32,
pub seq_no: u8,
pub protocol_version: u8,
pub server_version: Bytes,
pub connection_id: i32,
pub auth_seed: Bytes,
pub capabilities: Capabilities,
pub collation: u8,
pub status: ServerStatusFlag,
pub plugin_data_length: u8,
pub scramble: Option<Bytes>,
pub auth_plugin_name: Option<Bytes>,
}
impl Decode for InitialHandshakePacket {
fn decode(ctx: &mut DeContext) -> Result<Self, Error> {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_u8();
if seq_no != 0 {
return Err(err_msg(
"Sequence Number of Initial Handshake Packet is not 0",
));
}
let protocol_version = decoder.decode_int_u8();
let server_version = decoder.decode_string_null()?;
let connection_id = decoder.decode_int_i32();
let auth_seed = decoder.decode_string_fix(8);
// Skip reserved byte
decoder.skip_bytes(1);
let mut capabilities = Capabilities::from_bits_truncate(decoder.decode_int_u16().into());
let collation = decoder.decode_int_u8();
let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_u16().into());
capabilities |=
Capabilities::from_bits_truncate(((decoder.decode_int_i16() as u32) << 16).into());
let mut plugin_data_length = 0;
if !(capabilities & Capabilities::PLUGIN_AUTH).is_empty() {
plugin_data_length = decoder.decode_int_u8();
} else {
// Skip reserve byte
decoder.skip_bytes(1);
}
// Skip filler
decoder.skip_bytes(6);
if (capabilities & Capabilities::CLIENT_MYSQL).is_empty() {
capabilities |=
Capabilities::from_bits_truncate(((decoder.decode_int_u32() as u128) << 32).into());
} else {
// Skip filler
decoder.skip_bytes(4);
}
let mut scramble: Option<Bytes> = None;
if !(capabilities & Capabilities::SECURE_CONNECTION).is_empty() {
let len = std::cmp::max(12, plugin_data_length as usize - 9);
scramble = Some(decoder.decode_string_fix(len as usize));
// Skip reserve byte
decoder.skip_bytes(1);
}
let mut auth_plugin_name: Option<Bytes> = None;
if !(capabilities & Capabilities::PLUGIN_AUTH).is_empty() {
auth_plugin_name = Some(decoder.decode_string_null()?);
}
ctx.ctx.last_seq_no = seq_no;
Ok(InitialHandshakePacket {
length,
seq_no,
protocol_version,
server_version,
connection_id,
auth_seed,
capabilities,
collation,
status,
plugin_data_length,
scramble,
auth_plugin_name,
})
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::{
__bytes_builder,
mariadb::{ConnContext, Decoder},
};
use bytes::BytesMut;
#[test]
fn it_decodes_initial_handshake_packet() -> Result<(), Error> {
#[rustfmt::skip]
let buf = __bytes_builder!(
// int<3> length
1u8, 0u8, 0u8,
// int<1> seq_no
0u8,
//int<1> protocol version
10u8,
//string<NUL> server version (MariaDB server version is by default prefixed by "5.5.5-")
b"5.5.5-10.4.6-MariaDB-1:10.4.6+maria~bionic\0",
//int<4> connection id
13u8, 0u8, 0u8, 0u8,
//string<8> scramble 1st part (authentication seed)
b"?~~|vZAu",
//string<1> reserved byte
0u8,
//int<2> server capabilities (1st part)
0xFEu8, 0xF7u8,
//int<1> server default collation
8u8,
//int<2> status flags
2u8, 0u8,
//int<2> server capabilities (2nd part)
0xFF_u8, 0x81_u8,
//if (server_capabilities & PLUGIN_AUTH)
// int<1> plugin data length
15u8,
//else
// int<1> 0x00
//string<6> filler
0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
//if (server_capabilities & CLIENT_MYSQL)
// string<4> filler
//else
// int<4> server capabilities 3rd part . MariaDB specific flags /* MariaDB 10.2 or later */
7u8, 0u8, 0u8, 0u8,
//if (server_capabilities & CLIENT_SECURE_CONNECTION)
// string<n> scramble 2nd part . Length = max(12, plugin data length - 9)
b"JQ8cihP4Q}Dx",
// string<1> reserved byte
0u8,
//if (server_capabilities & PLUGIN_AUTH)
// string<NUL> authentication plugin name
b"mysql_native_password\0"
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, buf);
let _message = InitialHandshakePacket::decode(&mut ctx)?;
Ok(())
}
}

View File

@ -1,40 +0,0 @@
use bytes::Bytes;
use failure::Error;
use crate::mariadb::{BufMut, Capabilities, ConnContext, Encode, MariaDbRawConnection};
#[derive(Default, Debug)]
pub struct SSLRequestPacket {
pub capabilities: Capabilities,
pub max_packet_size: u32,
pub collation: u8,
pub extended_capabilities: Option<Capabilities>,
}
impl Encode for SSLRequestPacket {
fn encode(&self, buf: &mut Vec<u8>, ctx: &mut ConnContext) -> Result<(), Error> {
buf.alloc_packet_header();
buf.seq_no(0);
buf.put_int_u32(self.capabilities.bits() as u32);
buf.put_int_u32(self.max_packet_size);
buf.put_int_u8(self.collation);
// Filler
buf.put_byte_fix(&Bytes::from_static(&[0u8; 19]), 19);
if !(ctx.capabilities & Capabilities::CLIENT_MYSQL).is_empty()
&& !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty()
{
if let Some(capabilities) = self.extended_capabilities {
buf.put_int_u32(capabilities.bits() as u32);
}
} else {
buf.put_byte_fix(&Bytes::from_static(&[0u8; 4]), 4);
}
buf.put_length();
Ok(())
}
}

View File

@ -0,0 +1,65 @@
use crate::{
io::Buf,
mariadb::{
io::BufExt,
protocol::{ErrorCode, ServerStatusFlag},
},
};
use byteorder::LittleEndian;
use std::io;
#[derive(Debug)]
pub struct EofPacket {
pub warning_count: u16,
pub status: ServerStatusFlag,
}
impl EofPacket {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
let header = buf.get_u8()?;
if header != 0xFE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("expected 0xFE; received {}", header),
));
}
let warning_count = buf.get_u16::<LittleEndian>()?;
let status = ServerStatusFlag::from_bits_truncate(buf.get_u16::<LittleEndian>()?);
Ok(Self {
warning_count,
status,
})
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::{__bytes_builder, mariadb::ConnContext};
use bytes::Bytes;
#[test]
fn it_decodes_eof_packet() -> Result<(), Error> {
#[rustfmt::skip]
let buf = __bytes_builder!(
// int<3> length
1u8, 0u8, 0u8,
// int<1> seq_no
1u8,
// int<1> 0xfe : EOF header
0xFE_u8,
// int<2> warning count
0u8, 0u8,
// int<2> server status
1u8, 1u8
);
let _message = EofPacket::decode(&buf)?;
// TODO: Assert fields?
Ok(())
}
}

View File

View File

@ -0,0 +1,9 @@
mod ok;
mod err;
mod eof;
mod row;
pub use ok::OkPacket;
pub use err::ErrPacket;
pub use eof::EofPacket;
pub use row::ResultRow;

View File

@ -0,0 +1,117 @@
use crate::{
io::Buf,
mariadb::{
io::BufExt,
protocol::{Capabilities, ServerStatusFlag},
},
};
use byteorder::LittleEndian;
use std::io;
// https://mariadb.com/kb/en/library/ok_packet/
#[derive(Debug)]
pub struct OkPacket {
pub affected_rows: u64,
pub last_insert_id: u64,
pub server_status: ServerStatusFlag,
pub warning_count: u16,
pub info: Box<str>,
pub session_state_info: Option<Box<[u8]>>,
pub value_of_variable: Option<Box<str>>,
}
impl OkPacket {
fn decode(mut buf: &[u8], capabilities: Capabilities) -> io::Result<Self> {
let header = buf.get_u8()?;
if header != 0 && header != 0xFE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("expected 0x00 or 0xFE; received 0x{:X}", header),
));
}
let affected_rows = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or(0);
let last_insert_id = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or(0);
let server_status = ServerStatusFlag::from_bits_truncate(buf.get_u16::<LittleEndian>()?);
let warning_count = buf.get_u16::<LittleEndian>()?;
let info;
let mut session_state_info = None;
let mut value_of_variable = None;
if capabilities.contains(Capabilities::CLIENT_SESSION_TRACK) {
info = buf
.get_str_lenenc::<LittleEndian>()?
.unwrap_or_default()
.to_owned()
.into();
session_state_info = buf.get_byte_lenenc::<LittleEndian>()?.map(Into::into);
value_of_variable = buf.get_str_lenenc::<LittleEndian>()?.map(Into::into);
} else {
info = buf.get_str_eof()?.to_owned().into();
}
Ok(Self {
affected_rows,
last_insert_id,
server_status,
warning_count,
info,
session_state_info,
value_of_variable,
})
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::{
__bytes_builder,
mariadb::{ConnContext, Decoder},
};
#[test]
fn it_decodes_ok_packet() -> Result<(), Error> {
#[rustfmt::skip]
let buf = __bytes_builder!(
// int<3> length
0u8, 0u8, 0u8,
// // int<1> seq_no
1u8,
// 0x00 : OK_Packet header or (0xFE if CLIENT_DEPRECATE_EOF is set)
0u8,
// int<lenenc> affected rows
0xFB_u8,
// int<lenenc> last insert id
0xFB_u8,
// int<2> server status
1u8, 1u8,
// int<2> warning count
0u8, 0u8,
// if session_tracking_supported (see CLIENT_SESSION_TRACK) {
// string<lenenc> info
// if (status flags & SERVER_SESSION_STATE_CHANGED) {
// string<lenenc> session state info
// string<lenenc> value of variable
// }
// } else {
// string<EOF> info
b"info"
// }
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, buf);
let message = OkPacket::decode(&mut ctx)?;
assert_eq!(message.affected_rows, None);
assert_eq!(message.last_insert_id, None);
assert!(!(message.server_status & ServerStatusFlag::SERVER_STATUS_IN_TRANS).is_empty());
assert_eq!(message.warning_count, 0);
assert_eq!(message.info, b"info".to_vec());
Ok(())
}
}

View File

View File

@ -0,0 +1,45 @@
// https://mariadb.com/kb/en/library/mariadb-connectorc-types-and-definitions/#server-status
bitflags::bitflags! {
pub struct ServerStatusFlag: u16 {
// A transaction is currently active
const SERVER_STATUS_IN_TRANS = 1;
// Autocommit mode is set
const SERVER_STATUS_AUTOCOMMIT = 2;
// more results exists (more packet follow)
const SERVER_MORE_RESULTS_EXISTS = 8;
const SERVER_QUERY_NO_GOOD_INDEX_USED = 16;
const SERVER_QUERY_NO_INDEX_USED = 32;
// when using COM_STMT_FETCH, indicate that current cursor still has result
const SERVER_STATUS_CURSOR_EXISTS = 64;
// when using COM_STMT_FETCH, indicate that current cursor has finished to send results
const SERVER_STATUS_LAST_ROW_SENT = 128;
// database has been dropped
const SERVER_STATUS_DB_DROPPED = 1 << 8;
// current escape mode is "no backslash escape"
const SERVER_STATUS_NO_BACKSLASH_ESAPES = 1 << 9;
// A DDL change did have an impact on an existing PREPARE (an
// automatic reprepare has been executed)
const SERVER_STATUS_METADATA_CHANGED = 1 << 10;
// Last statement took more than the time value specified in
// server variable long_query_time.
const SERVER_QUERY_WAS_SLOW = 1 << 11;
// this resultset contain stored procedure output parameter
const SERVER_PS_OUT_PARAMS = 1 << 12;
// current transaction is a read-only transaction
const SERVER_STATUS_IN_TRANS_READONLY = 1 << 13;
// session state change. see Session change type for more information
const SERVER_SESSION_STATE_CHANGED = 1 << 14;
}
}

View File

@ -1,34 +1,3 @@
pub enum ProtocolType {
Text,
Binary,
}
bitflags::bitflags! {
pub struct Capabilities: u128 {
const CLIENT_MYSQL = 1;
const FOUND_ROWS = 1 << 1;
const CONNECT_WITH_DB = 1 << 3;
const COMPRESS = 1 << 5;
const LOCAL_FILES = 1 << 7;
const IGNORE_SPACE = 1 << 8;
const CLIENT_PROTOCOL_41 = 1 << 9;
const CLIENT_INTERACTIVE = 1 << 10;
const SSL = 1 << 11;
const TRANSACTIONS = 1 << 12;
const SECURE_CONNECTION = 1 << 13;
const MULTI_STATEMENTS = 1 << 16;
const MULTI_RESULTS = 1 << 17;
const PS_MULTI_RESULTS = 1 << 18;
const PLUGIN_AUTH = 1 << 19;
const CONNECT_ATTRS = 1 << 20;
const PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21;
const CLIENT_SESSION_TRACK = 1 << 23;
const CLIENT_DEPRECATE_EOF = 1 << 24;
const MARIA_DB_CLIENT_PROGRESS = 1 << 32;
const MARIA_DB_CLIENT_COM_MULTI = 1 << 33;
const MARIA_CLIENT_STMT_BULK_OPERATIONS = 1 << 34;
}
}
bitflags::bitflags! {
pub struct FieldDetailFlag: u16 {
@ -78,42 +47,6 @@ pub enum SessionChangeType {
SessionTrackTransactionState = 5,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct FieldType(pub u8);
impl FieldType {
pub const MYSQL_TYPE_BIT: FieldType = FieldType(16);
pub const MYSQL_TYPE_BLOB: FieldType = FieldType(252);
pub const MYSQL_TYPE_DATE: FieldType = FieldType(10);
pub const MYSQL_TYPE_DATETIME: FieldType = FieldType(12);
pub const MYSQL_TYPE_DATETIME2: FieldType = FieldType(18);
pub const MYSQL_TYPE_DECIMAL: FieldType = FieldType(0);
pub const MYSQL_TYPE_DOUBLE: FieldType = FieldType(5);
pub const MYSQL_TYPE_ENUM: FieldType = FieldType(247);
pub const MYSQL_TYPE_FLOAT: FieldType = FieldType(4);
pub const MYSQL_TYPE_GEOMETRY: FieldType = FieldType(255);
pub const MYSQL_TYPE_INT24: FieldType = FieldType(9);
pub const MYSQL_TYPE_JSON: FieldType = FieldType(245);
pub const MYSQL_TYPE_LONG: FieldType = FieldType(3);
pub const MYSQL_TYPE_LONGLONG: FieldType = FieldType(8);
pub const MYSQL_TYPE_LONG_BLOB: FieldType = FieldType(251);
pub const MYSQL_TYPE_MEDIUM_BLOB: FieldType = FieldType(250);
pub const MYSQL_TYPE_NEWDATE: FieldType = FieldType(14);
pub const MYSQL_TYPE_NEWDECIMAL: FieldType = FieldType(246);
pub const MYSQL_TYPE_NULL: FieldType = FieldType(6);
pub const MYSQL_TYPE_SET: FieldType = FieldType(248);
pub const MYSQL_TYPE_SHORT: FieldType = FieldType(2);
pub const MYSQL_TYPE_STRING: FieldType = FieldType(254);
pub const MYSQL_TYPE_TIME: FieldType = FieldType(11);
pub const MYSQL_TYPE_TIME2: FieldType = FieldType(19);
pub const MYSQL_TYPE_TIMESTAMP: FieldType = FieldType(7);
pub const MYSQL_TYPE_TIMESTAMP2: FieldType = FieldType(17);
pub const MYSQL_TYPE_TINY: FieldType = FieldType(1);
pub const MYSQL_TYPE_TINY_BLOB: FieldType = FieldType(249);
pub const MYSQL_TYPE_VARCHAR: FieldType = FieldType(15);
pub const MYSQL_TYPE_VAR_STRING: FieldType = FieldType(253);
pub const MYSQL_TYPE_YEAR: FieldType = FieldType(13);
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct StmtExecFlag(pub u8);
impl StmtExecFlag {
@ -122,63 +55,3 @@ impl StmtExecFlag {
pub const READ_ONLY: StmtExecFlag = StmtExecFlag(1);
pub const SCROLLABLE_CURSOR: StmtExecFlag = StmtExecFlag(3);
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ParamFlag(pub u8);
impl ParamFlag {
pub const NONE: ParamFlag = ParamFlag(0);
pub const UNSIGNED: ParamFlag = ParamFlag(128);
}
// TODO: Remove these Default impls
impl Default for Capabilities {
fn default() -> Self {
Capabilities::CLIENT_PROTOCOL_41
}
}
impl Default for ServerStatusFlag {
fn default() -> Self {
ServerStatusFlag::SERVER_STATUS_IN_TRANS
}
}
impl Default for FieldDetailFlag {
fn default() -> Self {
FieldDetailFlag::NOT_NULL
}
}
impl Default for FieldType {
fn default() -> Self {
FieldType::MYSQL_TYPE_DECIMAL
}
}
impl Default for StmtExecFlag {
fn default() -> Self {
StmtExecFlag::NO_CURSOR
}
}
impl Default for ParamFlag {
fn default() -> Self {
ParamFlag::UNSIGNED
}
}
#[cfg(test)]
mod test {
use super::Capabilities;
use crate::{__bytes_builder, io::Buf};
use byteorder::LittleEndian;
#[test]
fn it_decodes_capabilities() -> std::io::Result<()> {
let buf = &__bytes_builder!(b"\xfe\xf7")[..];
Capabilities::from_bits_truncate(buf.get_u16::<LittleEndian>()? as u128);
Ok(())
}
}