mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-12-29 21:00:54 +00:00
More work towards mariadb protocol refactor
This commit is contained in:
parent
b0707bba7b
commit
94fe35c264
@ -89,7 +89,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Find a nicer way to do this
|
||||
// Return `Ok(None)` immediately from a function if the wrapped value is `None`
|
||||
#[allow(unused)]
|
||||
macro_rules! ret_if_none {
|
||||
($val:expr) => {
|
||||
match $val {
|
||||
|
||||
@ -71,3 +71,179 @@ impl BufMutExt for Vec<u8> {
|
||||
self.extend_from_slice(val);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::BufMutExt;
|
||||
use crate::io::BufMut;
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
// [X] it_encodes_int_lenenc_u64
|
||||
// [X] it_encodes_int_lenenc_u32
|
||||
// [X] it_encodes_int_lenenc_u24
|
||||
// [X] it_encodes_int_lenenc_u16
|
||||
// [X] it_encodes_int_lenenc_u8
|
||||
// [X] it_encodes_int_u64
|
||||
// [X] it_encodes_int_u32
|
||||
// [X] it_encodes_int_u24
|
||||
// [X] it_encodes_int_u16
|
||||
// [X] it_encodes_int_u8
|
||||
// [X] it_encodes_string_lenenc
|
||||
// [X] it_encodes_string_fix
|
||||
// [X] it_encodes_string_null
|
||||
// [X] it_encodes_string_eof
|
||||
// [X] it_encodes_byte_lenenc
|
||||
// [X] it_encodes_byte_fix
|
||||
// [X] it_encodes_byte_eof
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_none() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(None);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFB");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_u8() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(0xFA as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFA");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_u16() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(std::u16::MAX as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_u24() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(0xFF_FF_FF as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFD\xFF\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_u64() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(std::u64::MAX);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_fb() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(0xFB as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFB\x00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_fc() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(0xFC as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFC\x00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_fd() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(0xFD as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFD\x00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_fe() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(0xFE as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFE\x00");
|
||||
}
|
||||
|
||||
fn it_encodes_int_lenenc_ff() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian, _>(0xFF as u64);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFF\x00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_u64() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_u64::<LittleEndian>(std::u64::MAX);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_u32() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_u32::<LittleEndian>(std::u32::MAX);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_u24() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_u24::<LittleEndian>(0xFF_FF_FF as u32);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFF\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_u16() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_u16::<LittleEndian>(std::u16::MAX);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_u8() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_u8(std::u8::MAX);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_string_lenenc() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_str_lenenc::<LittleEndian>("random_string");
|
||||
|
||||
assert_eq!(&buf[..], b"\x0Drandom_string");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_string_fix() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_str("random_string");
|
||||
|
||||
assert_eq!(&buf[..], b"random_string");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_string_null() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_str_nul("random_string");
|
||||
|
||||
assert_eq!(&buf[..], b"random_string\0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_byte_lenenc() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_bytes_lenenc::<LittleEndian>(b"random_string");
|
||||
|
||||
assert_eq!(&buf[..], b"\x0Drandom_string");
|
||||
}
|
||||
}
|
||||
|
||||
65
src/mariadb/protocol/capabilities.rs
Normal file
65
src/mariadb/protocol/capabilities.rs
Normal file
@ -0,0 +1,65 @@
|
||||
// https://mariadb.com/kb/en/library/connection/#capabilities
|
||||
bitflags::bitflags! {
|
||||
pub struct Capabilities: u128 {
|
||||
const CLIENT_MYSQL = 1;
|
||||
const FOUND_ROWS = 2;
|
||||
|
||||
// One can specify db on connect
|
||||
const CONNECT_WITH_DB = 8;
|
||||
|
||||
// Can use compression protocol
|
||||
const COMPRESS = 32;
|
||||
|
||||
// Can use LOAD DATA LOCAL
|
||||
const LOCAL_FILES = 128;
|
||||
|
||||
// Ignore spaces before '('
|
||||
const IGNORE_SPACE = 256;
|
||||
|
||||
// 4.1+ protocol
|
||||
const CLIENT_PROTOCOL_41 = 1 << 9;
|
||||
|
||||
const CLIENT_INTERACTIVE = 1 << 10;
|
||||
|
||||
// Can use SSL
|
||||
const SSL = 1 << 11;
|
||||
|
||||
const TRANSACTIONS = 1 << 12;
|
||||
|
||||
// 4.1+ authentication
|
||||
const SECURE_CONNECTION = 1 << 13;
|
||||
|
||||
// Enable/disable multi-stmt support
|
||||
const MULTI_STATEMENTS = 1 << 16;
|
||||
|
||||
// Enable/disable multi-results
|
||||
const MULTI_RESULTS = 1 << 17;
|
||||
|
||||
// Enable/disable multi-results for PrepareStatement
|
||||
const PS_MULTI_RESULTS = 1 << 18;
|
||||
|
||||
// Client supports plugin authentication
|
||||
const PLUGIN_AUTH = 1 << 19;
|
||||
|
||||
// Client send connection attributes
|
||||
const CONNECT_ATTRS = 1 << 20;
|
||||
|
||||
// Enable authentication response packet to be larger than 255 bytes
|
||||
const PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21;
|
||||
|
||||
// Enable/disable session tracking in OK_Packet
|
||||
const CLIENT_SESSION_TRACK = 1 << 23;
|
||||
|
||||
// EOF_Packet deprecation
|
||||
const CLIENT_DEPRECATE_EOF = 1 << 24;
|
||||
|
||||
// Client support progress indicator (since 10.2)
|
||||
const MARIA_DB_CLIENT_PROGRESS = 1 << 32;
|
||||
|
||||
// Permit COM_MULTI protocol
|
||||
const MARIA_DB_CLIENT_COM_MULTI = 1 << 33;
|
||||
|
||||
// Permit bulk insert
|
||||
const MARIA_CLIENT_STMT_BULK_OPERATIONS = 1 << 34;
|
||||
}
|
||||
}
|
||||
21
src/mariadb/protocol/connect/auth_switch_request.rs
Normal file
21
src/mariadb/protocol/connect/auth_switch_request.rs
Normal file
@ -0,0 +1,21 @@
|
||||
use crate::{
|
||||
io::BufMut,
|
||||
mariadb::{
|
||||
io::BufMutExt,
|
||||
protocol::{Capabilities, Encode},
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct AuthenticationSwitchRequest<'a> {
|
||||
pub auth_plugin_name: &'a str,
|
||||
pub auth_plugin_data: &'a [u8],
|
||||
}
|
||||
|
||||
impl Encode for AuthenticationSwitchRequest<'_> {
|
||||
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
|
||||
buf.put_u8(0xFE);
|
||||
buf.put_str_nul(&self.auth_plugin_name);
|
||||
buf.put_bytes(&self.auth_plugin_data);
|
||||
}
|
||||
}
|
||||
164
src/mariadb/protocol/connect/initial.rs
Normal file
164
src/mariadb/protocol/connect/initial.rs
Normal file
@ -0,0 +1,164 @@
|
||||
use crate::{
|
||||
io::Buf,
|
||||
mariadb::{
|
||||
io::BufExt,
|
||||
protocol::{Capabilities, ServerStatusFlag},
|
||||
},
|
||||
};
|
||||
use byteorder::LittleEndian;
|
||||
use std::io;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct InitialHandshakePacket {
|
||||
pub protocol_version: u8,
|
||||
pub server_version: String,
|
||||
pub server_status: ServerStatusFlag,
|
||||
pub server_default_collation: u8,
|
||||
pub connection_id: u32,
|
||||
pub scramble: Box<[u8]>,
|
||||
pub capabilities: Capabilities,
|
||||
pub auth_plugin_name: Option<String>,
|
||||
}
|
||||
|
||||
impl InitialHandshakePacket {
|
||||
fn decode(mut buf: &[u8]) -> io::Result<Self> {
|
||||
let protocol_version = buf.get_u8()?;
|
||||
let server_version = buf.get_str_nul()?.to_owned();
|
||||
let connection_id = buf.get_u32::<LittleEndian>()?;
|
||||
let mut scramble = Vec::with_capacity(8);
|
||||
|
||||
// scramble 1st part (authentication seed) : string<8>
|
||||
scramble.extend_from_slice(&buf[..8]);
|
||||
buf.advance(8);
|
||||
|
||||
// reserved : string<1>
|
||||
buf.advance(1);
|
||||
|
||||
// server capabilities (1st part) : int<2>
|
||||
let capabilities_1 = buf.get_u16::<LittleEndian>()?;
|
||||
let mut capabilities = Capabilities::from_bits_truncate(capabilities_1.into());
|
||||
|
||||
// server default collation : int<1>
|
||||
let server_default_collation = buf.get_u8()?;
|
||||
|
||||
// status flags : int<2>
|
||||
let server_status = buf.get_u16::<LittleEndian>()?;
|
||||
|
||||
// server capabilities (2nd part) : int<2>
|
||||
let capabilities_2 = buf.get_u16::<LittleEndian>()?;
|
||||
capabilities |= Capabilities::from_bits_truncate(((capabilities_2 as u32) << 16).into());
|
||||
|
||||
// if (server_capabilities & PLUGIN_AUTH)
|
||||
let plugin_data_length = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
|
||||
// plugin data length : int<1>
|
||||
buf.get_u8()?
|
||||
} else {
|
||||
// 0x00 : int<1>
|
||||
buf.advance(0);
|
||||
0
|
||||
};
|
||||
|
||||
// filler : string<6>
|
||||
buf.advance(6);
|
||||
|
||||
// if (server_capabilities & CLIENT_MYSQL)
|
||||
if capabilities.contains(Capabilities::CLIENT_MYSQL) {
|
||||
// filler : string<4>
|
||||
buf.advance(4);
|
||||
} else {
|
||||
// server capabilities 3rd part . MariaDB specific flags : int<4>
|
||||
let capabilities_3 = buf.get_u32::<LittleEndian>()?;
|
||||
capabilities |= Capabilities::from_bits_truncate((capabilities_2 as u128) << 32);
|
||||
}
|
||||
|
||||
// if (server_capabilities & CLIENT_SECURE_CONNECTION)
|
||||
if capabilities.contains(Capabilities::SECURE_CONNECTION) {
|
||||
// scramble 2nd part . Length = max(12, plugin data length - 9) : string<N>
|
||||
let len = ((plugin_data_length as isize) - 9).max(12) as usize;
|
||||
scramble.extend_from_slice(&buf[..len]);
|
||||
buf.advance(len);
|
||||
|
||||
// reserved byte : string<1>
|
||||
buf.advance(1);
|
||||
}
|
||||
|
||||
// if (server_capabilities & PLUGIN_AUTH)
|
||||
let auth_plugin_name = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
|
||||
Some(buf.get_str_nul()?.to_owned())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
protocol_version,
|
||||
server_version,
|
||||
server_default_collation,
|
||||
server_status: ServerStatusFlag::from_bits_truncate(server_status),
|
||||
connection_id,
|
||||
scramble: scramble.into_boxed_slice(),
|
||||
capabilities,
|
||||
auth_plugin_name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::__bytes_builder;
|
||||
|
||||
#[test]
|
||||
fn it_decodes_initial_handshake_packet() -> io::Result<()> {
|
||||
#[rustfmt::skip]
|
||||
let buf = __bytes_builder!(
|
||||
// int<3> length
|
||||
1u8, 0u8, 0u8,
|
||||
// int<1> seq_no
|
||||
0u8,
|
||||
//int<1> protocol version
|
||||
10u8,
|
||||
//string<NUL> server version (MariaDB server version is by default prefixed by "5.5.5-")
|
||||
b"5.5.5-10.4.6-MariaDB-1:10.4.6+maria~bionic\0",
|
||||
//int<4> connection id
|
||||
13u8, 0u8, 0u8, 0u8,
|
||||
//string<8> scramble 1st part (authentication seed)
|
||||
b"?~~|vZAu",
|
||||
//string<1> reserved byte
|
||||
0u8,
|
||||
//int<2> server capabilities (1st part)
|
||||
0xFEu8, 0xF7u8,
|
||||
//int<1> server default collation
|
||||
8u8,
|
||||
//int<2> status flags
|
||||
2u8, 0u8,
|
||||
//int<2> server capabilities (2nd part)
|
||||
0xFF_u8, 0x81_u8,
|
||||
|
||||
//if (server_capabilities & PLUGIN_AUTH)
|
||||
// int<1> plugin data length
|
||||
15u8,
|
||||
//else
|
||||
// int<1> 0x00
|
||||
|
||||
//string<6> filler
|
||||
0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
|
||||
//if (server_capabilities & CLIENT_MYSQL)
|
||||
// string<4> filler
|
||||
//else
|
||||
// int<4> server capabilities 3rd part . MariaDB specific flags /* MariaDB 10.2 or later */
|
||||
7u8, 0u8, 0u8, 0u8,
|
||||
//if (server_capabilities & CLIENT_SECURE_CONNECTION)
|
||||
// string<n> scramble 2nd part . Length = max(12, plugin data length - 9)
|
||||
b"JQ8cihP4Q}Dx",
|
||||
// string<1> reserved byte
|
||||
0u8,
|
||||
//if (server_capabilities & PLUGIN_AUTH)
|
||||
// string<NUL> authentication plugin name
|
||||
b"mysql_native_password\0"
|
||||
);
|
||||
|
||||
let _message = InitialHandshakePacket::decode(&buf)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
9
src/mariadb/protocol/connect/mod.rs
Normal file
9
src/mariadb/protocol/connect/mod.rs
Normal file
@ -0,0 +1,9 @@
|
||||
mod auth_switch_request;
|
||||
mod initial;
|
||||
mod response;
|
||||
mod ssl_request;
|
||||
|
||||
pub use auth_switch_request::AuthenticationSwitchRequest;
|
||||
pub use initial::InitialHandshakePacket;
|
||||
pub use response::HandshakeResponsePacket;
|
||||
pub use ssl_request::SslRequest;
|
||||
@ -7,7 +7,7 @@ use crate::{
|
||||
};
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
#[derive(Debug)]
|
||||
pub struct HandshakeResponsePacket<'a> {
|
||||
pub capabilities: Capabilities,
|
||||
pub max_packet_size: u32,
|
||||
40
src/mariadb/protocol/connect/ssl_request.rs
Normal file
40
src/mariadb/protocol/connect/ssl_request.rs
Normal file
@ -0,0 +1,40 @@
|
||||
use crate::{
|
||||
io::BufMut,
|
||||
mariadb::{
|
||||
io::BufMutExt,
|
||||
protocol::{Capabilities, Encode},
|
||||
},
|
||||
};
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SslRequest {
|
||||
pub capabilities: Capabilities,
|
||||
pub max_packet_size: u32,
|
||||
pub client_collation: u8,
|
||||
}
|
||||
|
||||
impl Encode for SslRequest {
|
||||
fn encode(&self, buf: &mut Vec<u8>, capabilities: Capabilities) {
|
||||
// client capabilities : int<4>
|
||||
buf.put_u32::<LittleEndian>(self.capabilities.bits() as u32);
|
||||
|
||||
// max packet size : int<4>
|
||||
buf.put_u32::<LittleEndian>(self.max_packet_size);
|
||||
|
||||
// client character collation : int<1>
|
||||
buf.put_u8(self.client_collation);
|
||||
|
||||
// reserved : string<19>
|
||||
buf.advance(19);
|
||||
|
||||
// if not (capabilities & CLIENT_MYSQL)
|
||||
if !capabilities.contains(Capabilities::CLIENT_MYSQL) {
|
||||
// extended client capabilities : int<4>
|
||||
buf.put_u32::<LittleEndian>((self.capabilities.bits() >> 32) as u32);
|
||||
} else {
|
||||
// reserved : int<4>
|
||||
buf.advance(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3,219 +3,3 @@ use super::Capabilities;
|
||||
pub trait Encode {
|
||||
fn encode(&self, buf: &mut Vec<u8>, capabilities: Capabilities);
|
||||
}
|
||||
|
||||
pub const U24_MAX: usize = 0xFF_FF_FF;
|
||||
|
||||
// #[inline]
|
||||
// fn put_param(&mut self, bytes: &Bytes, ty: FieldType) {
|
||||
// match ty {
|
||||
// FieldType::MYSQL_TYPE_DECIMAL => self.put_string_lenenc(bytes),
|
||||
// FieldType::MYSQL_TYPE_TINY => self.put_int_1(bytes),
|
||||
// FieldType::MYSQL_TYPE_SHORT => self.put_int_2(bytes),
|
||||
// FieldType::MYSQL_TYPE_LONG => self.put_int_4(bytes),
|
||||
// FieldType::MYSQL_TYPE_FLOAT => self.put_int_4(bytes),
|
||||
// FieldType::MYSQL_TYPE_DOUBLE => self.put_int_8(bytes),
|
||||
// FieldType::MYSQL_TYPE_NULL => panic!("Type cannot be FieldType::MysqlTypeNull"),
|
||||
// FieldType::MYSQL_TYPE_TIMESTAMP => unimplemented!(),
|
||||
// FieldType::MYSQL_TYPE_LONGLONG => self.put_int_8(bytes),
|
||||
// FieldType::MYSQL_TYPE_INT24 => self.put_int_4(bytes),
|
||||
// FieldType::MYSQL_TYPE_DATE => unimplemented!(),
|
||||
// FieldType::MYSQL_TYPE_TIME => unimplemented!(),
|
||||
// FieldType::MYSQL_TYPE_DATETIME => unimplemented!(),
|
||||
// FieldType::MYSQL_TYPE_YEAR => self.put_int_4(bytes),
|
||||
// FieldType::MYSQL_TYPE_NEWDATE => unimplemented!(),
|
||||
// FieldType::MYSQL_TYPE_VARCHAR => self.put_string_lenenc(bytes),
|
||||
// FieldType::MYSQL_TYPE_BIT => self.put_string_lenenc(bytes),
|
||||
// FieldType::MYSQL_TYPE_TIMESTAMP2 => unimplemented!(),
|
||||
// FieldType::MYSQL_TYPE_DATETIME2 => unimplemented!(),
|
||||
// FieldType::MYSQL_TYPE_TIME2 => unimplemented!(),
|
||||
// FieldType::MYSQL_TYPE_JSON => self.put_byte_lenenc(bytes),
|
||||
// FieldType::MYSQL_TYPE_NEWDECIMAL => self.put_byte_lenenc(bytes),
|
||||
// FieldType::MYSQL_TYPE_ENUM => self.put_byte_lenenc(bytes),
|
||||
// FieldType::MYSQL_TYPE_SET => self.put_byte_lenenc(bytes),
|
||||
// FieldType::MYSQL_TYPE_TINY_BLOB => self.put_byte_lenenc(bytes),
|
||||
// FieldType::MYSQL_TYPE_MEDIUM_BLOB => self.put_byte_lenenc(bytes),
|
||||
// FieldType::MYSQL_TYPE_LONG_BLOB => self.put_byte_lenenc(bytes),
|
||||
// FieldType::MYSQL_TYPE_BLOB => self.put_byte_lenenc(bytes),
|
||||
// FieldType::MYSQL_TYPE_VAR_STRING => self.put_byte_lenenc(bytes),
|
||||
// FieldType::MYSQL_TYPE_STRING => self.put_byte_lenenc(bytes),
|
||||
// FieldType::MYSQL_TYPE_GEOMETRY => self.put_byte_lenenc(bytes),
|
||||
// _ => panic!("Unrecognized field type"),
|
||||
// }
|
||||
// }
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{io::BufMut, mariadb::io::BufMutExt};
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
// [X] it_encodes_int_lenenc_u64
|
||||
// [X] it_encodes_int_lenenc_u32
|
||||
// [X] it_encodes_int_lenenc_u24
|
||||
// [X] it_encodes_int_lenenc_u16
|
||||
// [X] it_encodes_int_lenenc_u8
|
||||
// [X] it_encodes_int_u64
|
||||
// [X] it_encodes_int_u32
|
||||
// [X] it_encodes_int_u24
|
||||
// [X] it_encodes_int_u16
|
||||
// [X] it_encodes_int_u8
|
||||
// [X] it_encodes_string_lenenc
|
||||
// [X] it_encodes_string_fix
|
||||
// [X] it_encodes_string_null
|
||||
// [X] it_encodes_string_eof
|
||||
// [X] it_encodes_byte_lenenc
|
||||
// [X] it_encodes_byte_fix
|
||||
// [X] it_encodes_byte_eof
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_none() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian>(Some(0u64));
|
||||
|
||||
assert_eq!(&buf[..], b"\xFB");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_u8() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian>(Some(0xFA as u64));
|
||||
|
||||
assert_eq!(&buf[..], b"\xFA");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_u16() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian>(Some(std::u16::MAX as u64));
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_u24() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian>(Some(U24_MAX as u64));
|
||||
|
||||
assert_eq!(&buf[..], b"\xFD\xFF\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_u64() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian>(Some(std::u64::MAX));
|
||||
|
||||
assert_eq!(&buf[..], b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_fb() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian>(Some(0xFB as u64));
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFB\x00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_fc() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian>(Some(0xFC as u64));
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFC\x00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_fd() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian>(Some(0xFD as u64));
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFD\x00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_lenenc_fe() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian>(Some(0xFE as u64));
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFE\x00");
|
||||
}
|
||||
|
||||
fn it_encodes_int_lenenc_ff() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_uint_lenenc::<LittleEndian>(Some(0xFF as u64));
|
||||
|
||||
assert_eq!(&buf[..], b"\xFC\xFF\x00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_u64() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_u64::<LittleEndian>(std::u64::MAX);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_u32() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_u32::<LittleEndian>(std::u32::MAX);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_u24() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_u24::<LittleEndian>(U24_MAX as u32);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFF\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_u16() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_u16::<LittleEndian>(std::u16::MAX);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFF\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_int_u8() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_u8(std::u8::MAX);
|
||||
|
||||
assert_eq!(&buf[..], b"\xFF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_string_lenenc() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_str_lenenc::<LittleEndian>("random_string");
|
||||
|
||||
assert_eq!(&buf[..], b"\x0Drandom_string");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_string_fix() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_str("random_string");
|
||||
|
||||
assert_eq!(&buf[..], b"random_string");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_string_null() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_str_nul("random_string");
|
||||
|
||||
assert_eq!(&buf[..], b"random_string\0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_encodes_byte_lenenc() {
|
||||
let mut buf = Vec::with_capacity(1024);
|
||||
buf.put_byte_lenenc::<LittleEndian>(b"random_string");
|
||||
|
||||
assert_eq!(&buf[..], b"\x0Drandom_string");
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
#[derive(Default, Debug)]
|
||||
pub struct ErrorCode(pub(crate) u16);
|
||||
|
||||
// TODO: It would be nice to figure out a clean way to go from 1152 to "ER_ABORTING_CONNECTION (1152)" in Debug.
|
||||
|
||||
// Values from https://mariadb.com/kb/en/library/mariadb-error-codes/
|
||||
impl ErrorCode {
|
||||
const ER_ABORTING_CONNECTION: ErrorCode = ErrorCode(1152);
|
||||
const ER_ACCESS_DENIED_CHANGE_USER_ERROR: ErrorCode = ErrorCode(1873);
|
||||
44
src/mariadb/protocol/field.rs
Normal file
44
src/mariadb/protocol/field.rs
Normal file
@ -0,0 +1,44 @@
|
||||
// https://mariadb.com/kb/en/library/resultset/#field-types
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub struct FieldType(pub u8);
|
||||
|
||||
impl FieldType {
|
||||
pub const MYSQL_TYPE_BIT: FieldType = FieldType(16);
|
||||
pub const MYSQL_TYPE_BLOB: FieldType = FieldType(252);
|
||||
pub const MYSQL_TYPE_DATE: FieldType = FieldType(10);
|
||||
pub const MYSQL_TYPE_DATETIME: FieldType = FieldType(12);
|
||||
pub const MYSQL_TYPE_DATETIME2: FieldType = FieldType(18);
|
||||
pub const MYSQL_TYPE_DECIMAL: FieldType = FieldType(0);
|
||||
pub const MYSQL_TYPE_DOUBLE: FieldType = FieldType(5);
|
||||
pub const MYSQL_TYPE_ENUM: FieldType = FieldType(247);
|
||||
pub const MYSQL_TYPE_FLOAT: FieldType = FieldType(4);
|
||||
pub const MYSQL_TYPE_GEOMETRY: FieldType = FieldType(255);
|
||||
pub const MYSQL_TYPE_INT24: FieldType = FieldType(9);
|
||||
pub const MYSQL_TYPE_JSON: FieldType = FieldType(245);
|
||||
pub const MYSQL_TYPE_LONG: FieldType = FieldType(3);
|
||||
pub const MYSQL_TYPE_LONGLONG: FieldType = FieldType(8);
|
||||
pub const MYSQL_TYPE_LONG_BLOB: FieldType = FieldType(251);
|
||||
pub const MYSQL_TYPE_MEDIUM_BLOB: FieldType = FieldType(250);
|
||||
pub const MYSQL_TYPE_NEWDATE: FieldType = FieldType(14);
|
||||
pub const MYSQL_TYPE_NEWDECIMAL: FieldType = FieldType(246);
|
||||
pub const MYSQL_TYPE_NULL: FieldType = FieldType(6);
|
||||
pub const MYSQL_TYPE_SET: FieldType = FieldType(248);
|
||||
pub const MYSQL_TYPE_SHORT: FieldType = FieldType(2);
|
||||
pub const MYSQL_TYPE_STRING: FieldType = FieldType(254);
|
||||
pub const MYSQL_TYPE_TIME: FieldType = FieldType(11);
|
||||
pub const MYSQL_TYPE_TIME2: FieldType = FieldType(19);
|
||||
pub const MYSQL_TYPE_TIMESTAMP: FieldType = FieldType(7);
|
||||
pub const MYSQL_TYPE_TIMESTAMP2: FieldType = FieldType(17);
|
||||
pub const MYSQL_TYPE_TINY: FieldType = FieldType(1);
|
||||
pub const MYSQL_TYPE_TINY_BLOB: FieldType = FieldType(249);
|
||||
pub const MYSQL_TYPE_VARCHAR: FieldType = FieldType(15);
|
||||
pub const MYSQL_TYPE_VAR_STRING: FieldType = FieldType(253);
|
||||
pub const MYSQL_TYPE_YEAR: FieldType = FieldType(13);
|
||||
}
|
||||
|
||||
// https://mariadb.com/kb/en/library/com_stmt_execute/#parameter-flag
|
||||
bitflags::bitflags! {
|
||||
pub struct ParameterFlag: u8 {
|
||||
const UNSIGNED = 128;
|
||||
}
|
||||
}
|
||||
@ -1,36 +1,22 @@
|
||||
// Reference: https://mariadb.com/kb/en/library/connection
|
||||
// Packets: https://mariadb.com/kb/en/library/0-packet
|
||||
|
||||
// TODO: Handle lengths which are greater than 3 bytes
|
||||
// Either break the packet into several smaller ones, or
|
||||
// return error
|
||||
mod capabilities;
|
||||
mod connect;
|
||||
mod encode;
|
||||
mod error_code;
|
||||
mod field;
|
||||
mod server_status;
|
||||
mod response;
|
||||
|
||||
// TODO: Handle different Capabilities for server and client
|
||||
|
||||
// TODO: Handle when capability is set, but field is None
|
||||
|
||||
pub mod encode;
|
||||
pub mod error_codes;
|
||||
pub mod packets;
|
||||
pub mod types;
|
||||
|
||||
// Re-export all the things
|
||||
// pub use packets::{
|
||||
// AuthenticationSwitchRequestPacket, ColumnDefPacket, ColumnPacket, ComDebug, ComInitDb, ComPing,
|
||||
// ComProcessKill, ComQuery, ComQuit, ComResetConnection, ComSetOption, ComShutdown, ComSleep,
|
||||
// ComStatistics, ComStmtClose, ComStmtExec, ComStmtFetch, ComStmtPrepare, ComStmtPrepareOk,
|
||||
// ComStmtPrepareResp, ComStmtReset, EofPacket, ErrPacket, HandshakeResponsePacket,
|
||||
// InitialHandshakePacket, OkPacket, PacketHeader, ResultRow, ResultRowBinary, ResultRowText,
|
||||
// ResultSet, SSLRequestPacket, SetOptionOptions, ShutdownOptions,
|
||||
// };
|
||||
|
||||
pub use packets::{ColumnCountPacket, ColumnDefinitionPacket};
|
||||
|
||||
pub use encode::Encode;
|
||||
|
||||
pub use error_codes::ErrorCode;
|
||||
|
||||
pub use types::{
|
||||
Capabilities, FieldDetailFlag, FieldType, ProtocolType, ServerStatusFlag, SessionChangeType,
|
||||
StmtExecFlag,
|
||||
pub use capabilities::Capabilities;
|
||||
pub use connect::{
|
||||
AuthenticationSwitchRequest, HandshakeResponsePacket, InitialHandshakePacket, SslRequest,
|
||||
};
|
||||
pub use response::{
|
||||
OkPacket, EofPacket, ErrPacket, ResultRow,
|
||||
};
|
||||
pub use encode::Encode;
|
||||
pub use error_code::ErrorCode;
|
||||
pub use field::{FieldType, ParameterFlag};
|
||||
pub use server_status::ServerStatusFlag;
|
||||
|
||||
@ -1,19 +0,0 @@
|
||||
use crate::mariadb::{BufMut, ConnContext, Encode, MariaDbRawConnection};
|
||||
use bytes::Bytes;
|
||||
use failure::Error;
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct AuthenticationSwitchRequestPacket {
|
||||
pub auth_plugin_name: Bytes,
|
||||
pub auth_plugin_data: Bytes,
|
||||
}
|
||||
|
||||
impl Encode for AuthenticationSwitchRequestPacket {
|
||||
fn encode(&self, buf: &mut Vec<u8>, ctx: &mut ConnContext) -> Result<(), Error> {
|
||||
buf.put_int_u8(0xFE);
|
||||
buf.put_string_null(&self.auth_plugin_name);
|
||||
buf.put_byte_eof(&self.auth_plugin_data);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -1,166 +0,0 @@
|
||||
use crate::mariadb::{Capabilities, DeContext, Decode, ServerStatusFlag};
|
||||
use bytes::Bytes;
|
||||
use failure::{err_msg, Error};
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct InitialHandshakePacket {
|
||||
pub length: u32,
|
||||
pub seq_no: u8,
|
||||
pub protocol_version: u8,
|
||||
pub server_version: Bytes,
|
||||
pub connection_id: i32,
|
||||
pub auth_seed: Bytes,
|
||||
pub capabilities: Capabilities,
|
||||
pub collation: u8,
|
||||
pub status: ServerStatusFlag,
|
||||
pub plugin_data_length: u8,
|
||||
pub scramble: Option<Bytes>,
|
||||
pub auth_plugin_name: Option<Bytes>,
|
||||
}
|
||||
|
||||
impl Decode for InitialHandshakePacket {
|
||||
fn decode(ctx: &mut DeContext) -> Result<Self, Error> {
|
||||
let decoder = &mut ctx.decoder;
|
||||
let length = decoder.decode_length()?;
|
||||
let seq_no = decoder.decode_int_u8();
|
||||
|
||||
if seq_no != 0 {
|
||||
return Err(err_msg(
|
||||
"Sequence Number of Initial Handshake Packet is not 0",
|
||||
));
|
||||
}
|
||||
|
||||
let protocol_version = decoder.decode_int_u8();
|
||||
let server_version = decoder.decode_string_null()?;
|
||||
let connection_id = decoder.decode_int_i32();
|
||||
let auth_seed = decoder.decode_string_fix(8);
|
||||
|
||||
// Skip reserved byte
|
||||
decoder.skip_bytes(1);
|
||||
|
||||
let mut capabilities = Capabilities::from_bits_truncate(decoder.decode_int_u16().into());
|
||||
|
||||
let collation = decoder.decode_int_u8();
|
||||
let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_u16().into());
|
||||
|
||||
capabilities |=
|
||||
Capabilities::from_bits_truncate(((decoder.decode_int_i16() as u32) << 16).into());
|
||||
|
||||
let mut plugin_data_length = 0;
|
||||
if !(capabilities & Capabilities::PLUGIN_AUTH).is_empty() {
|
||||
plugin_data_length = decoder.decode_int_u8();
|
||||
} else {
|
||||
// Skip reserve byte
|
||||
decoder.skip_bytes(1);
|
||||
}
|
||||
|
||||
// Skip filler
|
||||
decoder.skip_bytes(6);
|
||||
|
||||
if (capabilities & Capabilities::CLIENT_MYSQL).is_empty() {
|
||||
capabilities |=
|
||||
Capabilities::from_bits_truncate(((decoder.decode_int_u32() as u128) << 32).into());
|
||||
} else {
|
||||
// Skip filler
|
||||
decoder.skip_bytes(4);
|
||||
}
|
||||
|
||||
let mut scramble: Option<Bytes> = None;
|
||||
if !(capabilities & Capabilities::SECURE_CONNECTION).is_empty() {
|
||||
let len = std::cmp::max(12, plugin_data_length as usize - 9);
|
||||
scramble = Some(decoder.decode_string_fix(len as usize));
|
||||
// Skip reserve byte
|
||||
decoder.skip_bytes(1);
|
||||
}
|
||||
|
||||
let mut auth_plugin_name: Option<Bytes> = None;
|
||||
if !(capabilities & Capabilities::PLUGIN_AUTH).is_empty() {
|
||||
auth_plugin_name = Some(decoder.decode_string_null()?);
|
||||
}
|
||||
|
||||
ctx.ctx.last_seq_no = seq_no;
|
||||
|
||||
Ok(InitialHandshakePacket {
|
||||
length,
|
||||
seq_no,
|
||||
protocol_version,
|
||||
server_version,
|
||||
connection_id,
|
||||
auth_seed,
|
||||
capabilities,
|
||||
collation,
|
||||
status,
|
||||
plugin_data_length,
|
||||
scramble,
|
||||
auth_plugin_name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::{
|
||||
__bytes_builder,
|
||||
mariadb::{ConnContext, Decoder},
|
||||
};
|
||||
use bytes::BytesMut;
|
||||
|
||||
#[test]
|
||||
fn it_decodes_initial_handshake_packet() -> Result<(), Error> {
|
||||
#[rustfmt::skip]
|
||||
let buf = __bytes_builder!(
|
||||
// int<3> length
|
||||
1u8, 0u8, 0u8,
|
||||
// int<1> seq_no
|
||||
0u8,
|
||||
//int<1> protocol version
|
||||
10u8,
|
||||
//string<NUL> server version (MariaDB server version is by default prefixed by "5.5.5-")
|
||||
b"5.5.5-10.4.6-MariaDB-1:10.4.6+maria~bionic\0",
|
||||
//int<4> connection id
|
||||
13u8, 0u8, 0u8, 0u8,
|
||||
//string<8> scramble 1st part (authentication seed)
|
||||
b"?~~|vZAu",
|
||||
//string<1> reserved byte
|
||||
0u8,
|
||||
//int<2> server capabilities (1st part)
|
||||
0xFEu8, 0xF7u8,
|
||||
//int<1> server default collation
|
||||
8u8,
|
||||
//int<2> status flags
|
||||
2u8, 0u8,
|
||||
//int<2> server capabilities (2nd part)
|
||||
0xFF_u8, 0x81_u8,
|
||||
|
||||
//if (server_capabilities & PLUGIN_AUTH)
|
||||
// int<1> plugin data length
|
||||
15u8,
|
||||
//else
|
||||
// int<1> 0x00
|
||||
|
||||
//string<6> filler
|
||||
0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
|
||||
//if (server_capabilities & CLIENT_MYSQL)
|
||||
// string<4> filler
|
||||
//else
|
||||
// int<4> server capabilities 3rd part . MariaDB specific flags /* MariaDB 10.2 or later */
|
||||
7u8, 0u8, 0u8, 0u8,
|
||||
//if (server_capabilities & CLIENT_SECURE_CONNECTION)
|
||||
// string<n> scramble 2nd part . Length = max(12, plugin data length - 9)
|
||||
b"JQ8cihP4Q}Dx",
|
||||
// string<1> reserved byte
|
||||
0u8,
|
||||
//if (server_capabilities & PLUGIN_AUTH)
|
||||
// string<NUL> authentication plugin name
|
||||
b"mysql_native_password\0"
|
||||
);
|
||||
|
||||
let mut context = ConnContext::new();
|
||||
let mut ctx = DeContext::new(&mut context, buf);
|
||||
|
||||
let _message = InitialHandshakePacket::decode(&mut ctx)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -1,40 +0,0 @@
|
||||
use bytes::Bytes;
|
||||
use failure::Error;
|
||||
|
||||
use crate::mariadb::{BufMut, Capabilities, ConnContext, Encode, MariaDbRawConnection};
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct SSLRequestPacket {
|
||||
pub capabilities: Capabilities,
|
||||
pub max_packet_size: u32,
|
||||
pub collation: u8,
|
||||
pub extended_capabilities: Option<Capabilities>,
|
||||
}
|
||||
|
||||
impl Encode for SSLRequestPacket {
|
||||
fn encode(&self, buf: &mut Vec<u8>, ctx: &mut ConnContext) -> Result<(), Error> {
|
||||
buf.alloc_packet_header();
|
||||
buf.seq_no(0);
|
||||
|
||||
buf.put_int_u32(self.capabilities.bits() as u32);
|
||||
buf.put_int_u32(self.max_packet_size);
|
||||
buf.put_int_u8(self.collation);
|
||||
|
||||
// Filler
|
||||
buf.put_byte_fix(&Bytes::from_static(&[0u8; 19]), 19);
|
||||
|
||||
if !(ctx.capabilities & Capabilities::CLIENT_MYSQL).is_empty()
|
||||
&& !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty()
|
||||
{
|
||||
if let Some(capabilities) = self.extended_capabilities {
|
||||
buf.put_int_u32(capabilities.bits() as u32);
|
||||
}
|
||||
} else {
|
||||
buf.put_byte_fix(&Bytes::from_static(&[0u8; 4]), 4);
|
||||
}
|
||||
|
||||
buf.put_length();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
65
src/mariadb/protocol/response/eof.rs
Normal file
65
src/mariadb/protocol/response/eof.rs
Normal file
@ -0,0 +1,65 @@
|
||||
use crate::{
|
||||
io::Buf,
|
||||
mariadb::{
|
||||
io::BufExt,
|
||||
protocol::{ErrorCode, ServerStatusFlag},
|
||||
},
|
||||
};
|
||||
use byteorder::LittleEndian;
|
||||
use std::io;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EofPacket {
|
||||
pub warning_count: u16,
|
||||
pub status: ServerStatusFlag,
|
||||
}
|
||||
|
||||
impl EofPacket {
|
||||
fn decode(mut buf: &[u8]) -> io::Result<Self> {
|
||||
let header = buf.get_u8()?;
|
||||
if header != 0xFE {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("expected 0xFE; received {}", header),
|
||||
));
|
||||
}
|
||||
|
||||
let warning_count = buf.get_u16::<LittleEndian>()?;
|
||||
let status = ServerStatusFlag::from_bits_truncate(buf.get_u16::<LittleEndian>()?);
|
||||
|
||||
Ok(Self {
|
||||
warning_count,
|
||||
status,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::{__bytes_builder, mariadb::ConnContext};
|
||||
use bytes::Bytes;
|
||||
|
||||
#[test]
|
||||
fn it_decodes_eof_packet() -> Result<(), Error> {
|
||||
#[rustfmt::skip]
|
||||
let buf = __bytes_builder!(
|
||||
// int<3> length
|
||||
1u8, 0u8, 0u8,
|
||||
// int<1> seq_no
|
||||
1u8,
|
||||
// int<1> 0xfe : EOF header
|
||||
0xFE_u8,
|
||||
// int<2> warning count
|
||||
0u8, 0u8,
|
||||
// int<2> server status
|
||||
1u8, 1u8
|
||||
);
|
||||
|
||||
let _message = EofPacket::decode(&buf)?;
|
||||
|
||||
// TODO: Assert fields?
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
0
src/mariadb/protocol/response/err.rs
Normal file
0
src/mariadb/protocol/response/err.rs
Normal file
9
src/mariadb/protocol/response/mod.rs
Normal file
9
src/mariadb/protocol/response/mod.rs
Normal file
@ -0,0 +1,9 @@
|
||||
mod ok;
|
||||
mod err;
|
||||
mod eof;
|
||||
mod row;
|
||||
|
||||
pub use ok::OkPacket;
|
||||
pub use err::ErrPacket;
|
||||
pub use eof::EofPacket;
|
||||
pub use row::ResultRow;
|
||||
117
src/mariadb/protocol/response/ok.rs
Normal file
117
src/mariadb/protocol/response/ok.rs
Normal file
@ -0,0 +1,117 @@
|
||||
use crate::{
|
||||
io::Buf,
|
||||
mariadb::{
|
||||
io::BufExt,
|
||||
protocol::{Capabilities, ServerStatusFlag},
|
||||
},
|
||||
};
|
||||
use byteorder::LittleEndian;
|
||||
use std::io;
|
||||
|
||||
// https://mariadb.com/kb/en/library/ok_packet/
|
||||
#[derive(Debug)]
|
||||
pub struct OkPacket {
|
||||
pub affected_rows: u64,
|
||||
pub last_insert_id: u64,
|
||||
pub server_status: ServerStatusFlag,
|
||||
pub warning_count: u16,
|
||||
pub info: Box<str>,
|
||||
pub session_state_info: Option<Box<[u8]>>,
|
||||
pub value_of_variable: Option<Box<str>>,
|
||||
}
|
||||
|
||||
impl OkPacket {
|
||||
fn decode(mut buf: &[u8], capabilities: Capabilities) -> io::Result<Self> {
|
||||
let header = buf.get_u8()?;
|
||||
if header != 0 && header != 0xFE {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("expected 0x00 or 0xFE; received 0x{:X}", header),
|
||||
));
|
||||
}
|
||||
|
||||
let affected_rows = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or(0);
|
||||
let last_insert_id = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or(0);
|
||||
let server_status = ServerStatusFlag::from_bits_truncate(buf.get_u16::<LittleEndian>()?);
|
||||
let warning_count = buf.get_u16::<LittleEndian>()?;
|
||||
|
||||
let info;
|
||||
let mut session_state_info = None;
|
||||
let mut value_of_variable = None;
|
||||
|
||||
if capabilities.contains(Capabilities::CLIENT_SESSION_TRACK) {
|
||||
info = buf
|
||||
.get_str_lenenc::<LittleEndian>()?
|
||||
.unwrap_or_default()
|
||||
.to_owned()
|
||||
.into();
|
||||
session_state_info = buf.get_byte_lenenc::<LittleEndian>()?.map(Into::into);
|
||||
value_of_variable = buf.get_str_lenenc::<LittleEndian>()?.map(Into::into);
|
||||
} else {
|
||||
info = buf.get_str_eof()?.to_owned().into();
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
affected_rows,
|
||||
last_insert_id,
|
||||
server_status,
|
||||
warning_count,
|
||||
info,
|
||||
session_state_info,
|
||||
value_of_variable,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::{
|
||||
__bytes_builder,
|
||||
mariadb::{ConnContext, Decoder},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn it_decodes_ok_packet() -> Result<(), Error> {
|
||||
#[rustfmt::skip]
|
||||
let buf = __bytes_builder!(
|
||||
// int<3> length
|
||||
0u8, 0u8, 0u8,
|
||||
// // int<1> seq_no
|
||||
1u8,
|
||||
// 0x00 : OK_Packet header or (0xFE if CLIENT_DEPRECATE_EOF is set)
|
||||
0u8,
|
||||
// int<lenenc> affected rows
|
||||
0xFB_u8,
|
||||
// int<lenenc> last insert id
|
||||
0xFB_u8,
|
||||
// int<2> server status
|
||||
1u8, 1u8,
|
||||
// int<2> warning count
|
||||
0u8, 0u8,
|
||||
// if session_tracking_supported (see CLIENT_SESSION_TRACK) {
|
||||
// string<lenenc> info
|
||||
// if (status flags & SERVER_SESSION_STATE_CHANGED) {
|
||||
// string<lenenc> session state info
|
||||
// string<lenenc> value of variable
|
||||
// }
|
||||
// } else {
|
||||
// string<EOF> info
|
||||
b"info"
|
||||
// }
|
||||
);
|
||||
|
||||
let mut context = ConnContext::new();
|
||||
let mut ctx = DeContext::new(&mut context, buf);
|
||||
|
||||
let message = OkPacket::decode(&mut ctx)?;
|
||||
|
||||
assert_eq!(message.affected_rows, None);
|
||||
assert_eq!(message.last_insert_id, None);
|
||||
assert!(!(message.server_status & ServerStatusFlag::SERVER_STATUS_IN_TRANS).is_empty());
|
||||
assert_eq!(message.warning_count, 0);
|
||||
assert_eq!(message.info, b"info".to_vec());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
0
src/mariadb/protocol/response/row.rs
Normal file
0
src/mariadb/protocol/response/row.rs
Normal file
45
src/mariadb/protocol/server_status.rs
Normal file
45
src/mariadb/protocol/server_status.rs
Normal file
@ -0,0 +1,45 @@
|
||||
// https://mariadb.com/kb/en/library/mariadb-connectorc-types-and-definitions/#server-status
|
||||
bitflags::bitflags! {
|
||||
pub struct ServerStatusFlag: u16 {
|
||||
// A transaction is currently active
|
||||
const SERVER_STATUS_IN_TRANS = 1;
|
||||
|
||||
// Autocommit mode is set
|
||||
const SERVER_STATUS_AUTOCOMMIT = 2;
|
||||
|
||||
// more results exists (more packet follow)
|
||||
const SERVER_MORE_RESULTS_EXISTS = 8;
|
||||
|
||||
const SERVER_QUERY_NO_GOOD_INDEX_USED = 16;
|
||||
const SERVER_QUERY_NO_INDEX_USED = 32;
|
||||
|
||||
// when using COM_STMT_FETCH, indicate that current cursor still has result
|
||||
const SERVER_STATUS_CURSOR_EXISTS = 64;
|
||||
|
||||
// when using COM_STMT_FETCH, indicate that current cursor has finished to send results
|
||||
const SERVER_STATUS_LAST_ROW_SENT = 128;
|
||||
|
||||
// database has been dropped
|
||||
const SERVER_STATUS_DB_DROPPED = 1 << 8;
|
||||
|
||||
// current escape mode is "no backslash escape"
|
||||
const SERVER_STATUS_NO_BACKSLASH_ESAPES = 1 << 9;
|
||||
|
||||
// A DDL change did have an impact on an existing PREPARE (an
|
||||
// automatic reprepare has been executed)
|
||||
const SERVER_STATUS_METADATA_CHANGED = 1 << 10;
|
||||
|
||||
// Last statement took more than the time value specified in
|
||||
// server variable long_query_time.
|
||||
const SERVER_QUERY_WAS_SLOW = 1 << 11;
|
||||
|
||||
// this resultset contain stored procedure output parameter
|
||||
const SERVER_PS_OUT_PARAMS = 1 << 12;
|
||||
|
||||
// current transaction is a read-only transaction
|
||||
const SERVER_STATUS_IN_TRANS_READONLY = 1 << 13;
|
||||
|
||||
// session state change. see Session change type for more information
|
||||
const SERVER_SESSION_STATE_CHANGED = 1 << 14;
|
||||
}
|
||||
}
|
||||
@ -1,34 +1,3 @@
|
||||
pub enum ProtocolType {
|
||||
Text,
|
||||
Binary,
|
||||
}
|
||||
|
||||
bitflags::bitflags! {
|
||||
pub struct Capabilities: u128 {
|
||||
const CLIENT_MYSQL = 1;
|
||||
const FOUND_ROWS = 1 << 1;
|
||||
const CONNECT_WITH_DB = 1 << 3;
|
||||
const COMPRESS = 1 << 5;
|
||||
const LOCAL_FILES = 1 << 7;
|
||||
const IGNORE_SPACE = 1 << 8;
|
||||
const CLIENT_PROTOCOL_41 = 1 << 9;
|
||||
const CLIENT_INTERACTIVE = 1 << 10;
|
||||
const SSL = 1 << 11;
|
||||
const TRANSACTIONS = 1 << 12;
|
||||
const SECURE_CONNECTION = 1 << 13;
|
||||
const MULTI_STATEMENTS = 1 << 16;
|
||||
const MULTI_RESULTS = 1 << 17;
|
||||
const PS_MULTI_RESULTS = 1 << 18;
|
||||
const PLUGIN_AUTH = 1 << 19;
|
||||
const CONNECT_ATTRS = 1 << 20;
|
||||
const PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21;
|
||||
const CLIENT_SESSION_TRACK = 1 << 23;
|
||||
const CLIENT_DEPRECATE_EOF = 1 << 24;
|
||||
const MARIA_DB_CLIENT_PROGRESS = 1 << 32;
|
||||
const MARIA_DB_CLIENT_COM_MULTI = 1 << 33;
|
||||
const MARIA_CLIENT_STMT_BULK_OPERATIONS = 1 << 34;
|
||||
}
|
||||
}
|
||||
|
||||
bitflags::bitflags! {
|
||||
pub struct FieldDetailFlag: u16 {
|
||||
@ -78,42 +47,6 @@ pub enum SessionChangeType {
|
||||
SessionTrackTransactionState = 5,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub struct FieldType(pub u8);
|
||||
impl FieldType {
|
||||
pub const MYSQL_TYPE_BIT: FieldType = FieldType(16);
|
||||
pub const MYSQL_TYPE_BLOB: FieldType = FieldType(252);
|
||||
pub const MYSQL_TYPE_DATE: FieldType = FieldType(10);
|
||||
pub const MYSQL_TYPE_DATETIME: FieldType = FieldType(12);
|
||||
pub const MYSQL_TYPE_DATETIME2: FieldType = FieldType(18);
|
||||
pub const MYSQL_TYPE_DECIMAL: FieldType = FieldType(0);
|
||||
pub const MYSQL_TYPE_DOUBLE: FieldType = FieldType(5);
|
||||
pub const MYSQL_TYPE_ENUM: FieldType = FieldType(247);
|
||||
pub const MYSQL_TYPE_FLOAT: FieldType = FieldType(4);
|
||||
pub const MYSQL_TYPE_GEOMETRY: FieldType = FieldType(255);
|
||||
pub const MYSQL_TYPE_INT24: FieldType = FieldType(9);
|
||||
pub const MYSQL_TYPE_JSON: FieldType = FieldType(245);
|
||||
pub const MYSQL_TYPE_LONG: FieldType = FieldType(3);
|
||||
pub const MYSQL_TYPE_LONGLONG: FieldType = FieldType(8);
|
||||
pub const MYSQL_TYPE_LONG_BLOB: FieldType = FieldType(251);
|
||||
pub const MYSQL_TYPE_MEDIUM_BLOB: FieldType = FieldType(250);
|
||||
pub const MYSQL_TYPE_NEWDATE: FieldType = FieldType(14);
|
||||
pub const MYSQL_TYPE_NEWDECIMAL: FieldType = FieldType(246);
|
||||
pub const MYSQL_TYPE_NULL: FieldType = FieldType(6);
|
||||
pub const MYSQL_TYPE_SET: FieldType = FieldType(248);
|
||||
pub const MYSQL_TYPE_SHORT: FieldType = FieldType(2);
|
||||
pub const MYSQL_TYPE_STRING: FieldType = FieldType(254);
|
||||
pub const MYSQL_TYPE_TIME: FieldType = FieldType(11);
|
||||
pub const MYSQL_TYPE_TIME2: FieldType = FieldType(19);
|
||||
pub const MYSQL_TYPE_TIMESTAMP: FieldType = FieldType(7);
|
||||
pub const MYSQL_TYPE_TIMESTAMP2: FieldType = FieldType(17);
|
||||
pub const MYSQL_TYPE_TINY: FieldType = FieldType(1);
|
||||
pub const MYSQL_TYPE_TINY_BLOB: FieldType = FieldType(249);
|
||||
pub const MYSQL_TYPE_VARCHAR: FieldType = FieldType(15);
|
||||
pub const MYSQL_TYPE_VAR_STRING: FieldType = FieldType(253);
|
||||
pub const MYSQL_TYPE_YEAR: FieldType = FieldType(13);
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub struct StmtExecFlag(pub u8);
|
||||
impl StmtExecFlag {
|
||||
@ -122,63 +55,3 @@ impl StmtExecFlag {
|
||||
pub const READ_ONLY: StmtExecFlag = StmtExecFlag(1);
|
||||
pub const SCROLLABLE_CURSOR: StmtExecFlag = StmtExecFlag(3);
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub struct ParamFlag(pub u8);
|
||||
impl ParamFlag {
|
||||
pub const NONE: ParamFlag = ParamFlag(0);
|
||||
pub const UNSIGNED: ParamFlag = ParamFlag(128);
|
||||
}
|
||||
|
||||
// TODO: Remove these Default impls
|
||||
|
||||
impl Default for Capabilities {
|
||||
fn default() -> Self {
|
||||
Capabilities::CLIENT_PROTOCOL_41
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ServerStatusFlag {
|
||||
fn default() -> Self {
|
||||
ServerStatusFlag::SERVER_STATUS_IN_TRANS
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FieldDetailFlag {
|
||||
fn default() -> Self {
|
||||
FieldDetailFlag::NOT_NULL
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FieldType {
|
||||
fn default() -> Self {
|
||||
FieldType::MYSQL_TYPE_DECIMAL
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for StmtExecFlag {
|
||||
fn default() -> Self {
|
||||
StmtExecFlag::NO_CURSOR
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ParamFlag {
|
||||
fn default() -> Self {
|
||||
ParamFlag::UNSIGNED
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::Capabilities;
|
||||
use crate::{__bytes_builder, io::Buf};
|
||||
use byteorder::LittleEndian;
|
||||
|
||||
#[test]
|
||||
fn it_decodes_capabilities() -> std::io::Result<()> {
|
||||
let buf = &__bytes_builder!(b"\xfe\xf7")[..];
|
||||
Capabilities::from_bits_truncate(buf.get_u16::<LittleEndian>()? as u128);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user