WIP: ResultSet

This commit is contained in:
Daniel Akhterov 2019-07-08 19:19:43 -07:00
parent 163e154073
commit 8a1b9a89fd
9 changed files with 343 additions and 173 deletions

View File

@ -17,6 +17,8 @@ hex = "0.3.2"
bytes = "0.4.12"
memchr = "2.2.0"
bitflags = "1.1.0"
enum-tryfrom = "0.2.1"
enum-tryfrom-derive = "0.2.1"
[dev-dependencies]

View File

@ -1,6 +1,6 @@
use crate::protocol::{
client::{ComPing, ComQuit, Serialize},
serialize::serialize_length,
encode::encode_length,
server::{
Capabilities, Deserialize, Message as ServerMessage,
ServerStatusFlag, OkPacket
@ -73,7 +73,7 @@ impl Connection {
self.wbuf[3] = self.seq_no;
message.serialize(&mut self.wbuf, &self.capabilities)?;
serialize_length(&mut self.wbuf);
encode_length(&mut self.wbuf);
self.stream.inner.write_all(&self.wbuf).await?;
self.stream.inner.flush().await?;
@ -116,7 +116,7 @@ impl Framed {
}
async fn next_bytes(&mut self) -> Result<Bytes, Error> {
let mut rbuf = BytesMut::with_capacity(0);
let mut rbuf = BytesMut::new();
let mut len = 0;
let mut packet_len: u32 = 0;
@ -155,7 +155,7 @@ impl Framed {
}
async fn next(&mut self) -> Result<Option<ServerMessage>, Error> {
let mut rbuf = BytesMut::with_capacity(0);
let mut rbuf = BytesMut::new();
let mut len = 0;
loop {

View File

@ -5,6 +5,8 @@
#[macro_use]
extern crate bitflags;
#[macro_use]
extern crate enum_tryfrom_derive;
pub mod connection;
pub mod protocol;

View File

@ -8,7 +8,7 @@
// TODO: Handle when capability is set, but field is None
use super::server::Capabilities;
use crate::protocol::serialize::*;
use crate::protocol::encode::*;
use bytes::{Bytes, BytesMut};
use failure::Error;
@ -127,7 +127,7 @@ impl Serialize for ComQuit {
buf: &mut BytesMut,
_server_capabilities: &Capabilities,
) -> Result<(), Error> {
serialize_int_1(buf, TextProtocol::ComQuit.into());
encode_int_1(buf, TextProtocol::ComQuit.into());
Ok(())
}
@ -139,8 +139,8 @@ impl Serialize for ComInitDb {
buf: &mut BytesMut,
_server_capabilities: &Capabilities,
) -> Result<(), Error> {
serialize_int_1(buf, TextProtocol::ComInitDb.into());
serialize_string_null(buf, &self.schema_name);
encode_int_1(buf, TextProtocol::ComInitDb.into());
encode_string_null(buf, &self.schema_name);
Ok(())
}
@ -152,7 +152,7 @@ impl Serialize for ComDebug {
buf: &mut BytesMut,
_server_capabilities: &Capabilities,
) -> Result<(), Error> {
serialize_int_1(buf, TextProtocol::ComDebug.into());
encode_int_1(buf, TextProtocol::ComDebug.into());
Ok(())
}
@ -164,7 +164,7 @@ impl Serialize for ComPing {
buf: &mut BytesMut,
_server_capabilities: &Capabilities,
) -> Result<(), Error> {
serialize_int_1(buf, TextProtocol::ComPing.into());
encode_int_1(buf, TextProtocol::ComPing.into());
Ok(())
}
@ -176,8 +176,8 @@ impl Serialize for ComProcessKill {
buf: &mut BytesMut,
_server_capabilities: &Capabilities,
) -> Result<(), Error> {
serialize_int_1(buf, TextProtocol::ComProcessKill.into());
serialize_int_4(buf, self.process_id);
encode_int_1(buf, TextProtocol::ComProcessKill.into());
encode_int_4(buf, self.process_id);
Ok(())
}
@ -189,8 +189,8 @@ impl Serialize for ComQuery {
buf: &mut BytesMut,
_server_capabilities: &Capabilities,
) -> Result<(), Error> {
serialize_int_1(buf, TextProtocol::ComQuery.into());
serialize_string_eof(buf, &self.sql_statement);
encode_int_1(buf, TextProtocol::ComQuery.into());
encode_string_eof(buf, &self.sql_statement);
Ok(())
}
@ -202,7 +202,7 @@ impl Serialize for ComResetConnection {
buf: &mut BytesMut,
_server_capabilities: &Capabilities,
) -> Result<(), Error> {
serialize_int_1(buf, TextProtocol::ComResetConnection.into());
encode_int_1(buf, TextProtocol::ComResetConnection.into());
Ok(())
}
@ -214,8 +214,8 @@ impl Serialize for ComSetOption {
buf: &mut BytesMut,
_server_capabilities: &Capabilities,
) -> Result<(), Error> {
serialize_int_1(buf, TextProtocol::ComSetOption.into());
serialize_int_2(buf, self.option.into());
encode_int_1(buf, TextProtocol::ComSetOption.into());
encode_int_2(buf, self.option.into());
Ok(())
}
@ -227,8 +227,8 @@ impl Serialize for ComShutdown {
buf: &mut BytesMut,
_server_capabilities: &Capabilities,
) -> Result<(), Error> {
serialize_int_1(buf, TextProtocol::ComShutdown.into());
serialize_int_1(buf, self.option.into());
encode_int_1(buf, TextProtocol::ComShutdown.into());
encode_int_1(buf, self.option.into());
Ok(())
}
@ -240,7 +240,7 @@ impl Serialize for ComSleep {
buf: &mut BytesMut,
_server_capabilities: &Capabilities,
) -> Result<(), Error> {
serialize_int_1(buf, TextProtocol::ComSleep.into());
encode_int_1(buf, TextProtocol::ComSleep.into());
Ok(())
}
@ -252,7 +252,7 @@ impl Serialize for ComStatistics {
buf: &mut BytesMut,
_server_capabilities: &Capabilities,
) -> Result<(), Error> {
serialize_int_1(buf, TextProtocol::ComStatistics.into());
encode_int_1(buf, TextProtocol::ComStatistics.into());
Ok(())
}
@ -264,21 +264,21 @@ impl Serialize for SSLRequestPacket {
buf: &mut BytesMut,
server_capabilities: &Capabilities,
) -> Result<(), Error> {
serialize_int_4(buf, self.capabilities.bits() as u32);
serialize_int_4(buf, self.max_packet_size);
serialize_int_1(buf, self.collation);
encode_int_4(buf, self.capabilities.bits() as u32);
encode_int_4(buf, self.max_packet_size);
encode_int_1(buf, self.collation);
// Filler
serialize_byte_fix(buf, &Bytes::from_static(&[0u8; 19]), 19);
encode_byte_fix(buf, &Bytes::from_static(&[0u8; 19]), 19);
if !(*server_capabilities & Capabilities::CLIENT_MYSQL).is_empty()
&& !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty()
{
if let Some(capabilities) = self.extended_capabilities {
serialize_int_4(buf, capabilities.bits() as u32);
encode_int_4(buf, capabilities.bits() as u32);
}
} else {
serialize_byte_fix(buf, &Bytes::from_static(&[0u8; 4]), 4);
encode_byte_fix(buf, &Bytes::from_static(&[0u8; 4]), 4);
}
Ok(())
@ -291,61 +291,61 @@ impl Serialize for HandshakeResponsePacket {
buf: &mut BytesMut,
server_capabilities: &Capabilities,
) -> Result<(), Error> {
serialize_int_4(buf, self.capabilities.bits() as u32);
serialize_int_4(buf, self.max_packet_size);
serialize_int_1(buf, self.collation);
encode_int_4(buf, self.capabilities.bits() as u32);
encode_int_4(buf, self.max_packet_size);
encode_int_1(buf, self.collation);
// Filler
serialize_byte_fix(buf, &Bytes::from_static(&[0u8; 19]), 19);
encode_byte_fix(buf, &Bytes::from_static(&[0u8; 19]), 19);
if !(*server_capabilities & Capabilities::CLIENT_MYSQL).is_empty()
&& !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty()
{
if let Some(capabilities) = self.extended_capabilities {
serialize_int_4(buf, capabilities.bits() as u32);
encode_int_4(buf, capabilities.bits() as u32);
}
} else {
serialize_byte_fix(buf, &Bytes::from_static(&[0u8; 4]), 4);
encode_byte_fix(buf, &Bytes::from_static(&[0u8; 4]), 4);
}
serialize_string_null(buf, &self.username);
encode_string_null(buf, &self.username);
if !(*server_capabilities & Capabilities::PLUGIN_AUTH_LENENC_CLIENT_DATA).is_empty() {
if let Some(auth_data) = &self.auth_data {
serialize_string_lenenc(buf, &auth_data);
encode_string_lenenc(buf, &auth_data);
}
} else if !(*server_capabilities & Capabilities::SECURE_CONNECTION).is_empty() {
if let Some(auth_response) = &self.auth_response {
serialize_int_1(buf, self.auth_response_len.unwrap());
serialize_string_fix(buf, &auth_response, self.auth_response_len.unwrap() as usize);
encode_int_1(buf, self.auth_response_len.unwrap());
encode_string_fix(buf, &auth_response, self.auth_response_len.unwrap() as usize);
}
} else {
serialize_int_1(buf, 0);
encode_int_1(buf, 0);
}
if !(*server_capabilities & Capabilities::CONNECT_WITH_DB).is_empty() {
if let Some(database) = &self.database {
// string<NUL>
serialize_string_null(buf, &database);
encode_string_null(buf, &database);
}
}
if !(*server_capabilities & Capabilities::PLUGIN_AUTH).is_empty() {
if let Some(auth_plugin_name) = &self.auth_plugin_name {
// string<NUL>
serialize_string_null(buf, &auth_plugin_name);
encode_string_null(buf, &auth_plugin_name);
}
}
if !(*server_capabilities & Capabilities::CONNECT_ATTRS).is_empty() {
if let (Some(conn_attr_len), Some(conn_attr)) = (&self.conn_attr_len, &self.conn_attr) {
// int<lenenc>
serialize_int_lenenc(buf, Some(conn_attr_len));
encode_int_lenenc(buf, Some(conn_attr_len));
// Loop
for (key, value) in conn_attr {
serialize_string_lenenc(buf, &key);
serialize_string_lenenc(buf, &value);
encode_string_lenenc(buf, &key);
encode_string_lenenc(buf, &value);
}
}
}
@ -360,9 +360,9 @@ impl Serialize for AuthenticationSwitchRequestPacket {
buf: &mut BytesMut,
_server_capabilities: &Capabilities,
) -> Result<(), Error> {
serialize_int_1(buf, 0xFE);
serialize_string_null(buf, &self.auth_plugin_name);
serialize_byte_eof(buf, &self.auth_plugin_data);
encode_int_1(buf, 0xFE);
encode_string_null(buf, &self.auth_plugin_name);
encode_byte_eof(buf, &self.auth_plugin_data);
Ok(())
}

View File

@ -3,9 +3,14 @@ use byteorder::{ByteOrder, LittleEndian};
use bytes::Bytes;
use failure::{err_msg, Error};
//pub struct Decoder<'a> {
// pub buf: &'a Bytes,
// pub index: usize,
//}
#[inline]
pub fn deserialize_length(buf: &Bytes, index: &mut usize) -> Result<u32, Error> {
let length = deserialize_int_3(&buf, index);
pub fn decode_length(buf: &Bytes, index: &mut usize) -> Result<u32, Error> {
let length = decode_int_3(&buf, index);
if buf.len() < length as usize {
return Err(err_msg("Lengths to do not match"));
@ -15,7 +20,7 @@ pub fn deserialize_length(buf: &Bytes, index: &mut usize) -> Result<u32, Error>
}
#[inline]
pub fn deserialize_int_lenenc(buf: &Bytes, index: &mut usize) -> Option<usize> {
pub fn decode_int_lenenc(buf: &Bytes, index: &mut usize) -> Option<usize> {
match buf[*index] {
0xFB => {
*index += 1;
@ -46,64 +51,64 @@ pub fn deserialize_int_lenenc(buf: &Bytes, index: &mut usize) -> Option<usize> {
}
#[inline]
pub fn deserialize_int_8(buf: &Bytes, index: &mut usize) -> u64 {
pub fn decode_int_8(buf: &Bytes, index: &mut usize) -> u64 {
let value = LittleEndian::read_u64(&buf[*index..]);
*index += 8;
value
}
#[inline]
pub fn deserialize_int_4(buf: &Bytes, index: &mut usize) -> u32 {
pub fn decode_int_4(buf: &Bytes, index: &mut usize) -> u32 {
let value = LittleEndian::read_u32(&buf[*index..]);
*index += 4;
value
}
#[inline]
pub fn deserialize_int_3(buf: &Bytes, index: &mut usize) -> u32 {
pub fn decode_int_3(buf: &Bytes, index: &mut usize) -> u32 {
let value = LittleEndian::read_u24(&buf[*index..]);
*index += 3;
value
}
#[inline]
pub fn deserialize_int_2(buf: &Bytes, index: &mut usize) -> u16 {
pub fn decode_int_2(buf: &Bytes, index: &mut usize) -> u16 {
let value = LittleEndian::read_u16(&buf[*index..]);
*index += 2;
value
}
#[inline]
pub fn deserialize_int_1(buf: &Bytes, index: &mut usize) -> u8 {
pub fn decode_int_1(buf: &Bytes, index: &mut usize) -> u8 {
let value = buf[*index];
*index += 1;
value
}
#[inline]
pub fn deserialize_string_lenenc(buf: &Bytes, index: &mut usize) -> Bytes {
let length = deserialize_int_3(&buf, &mut *index);
pub fn decode_string_lenenc(buf: &Bytes, index: &mut usize) -> Bytes {
let length = decode_int_3(&buf, &mut *index);
let value = Bytes::from(&buf[*index..*index + length as usize]);
*index = *index + length as usize;
value
}
#[inline]
pub fn deserialize_string_fix(buf: &Bytes, index: &mut usize, length: usize) -> Bytes {
pub fn decode_string_fix(buf: &Bytes, index: &mut usize, length: usize) -> Bytes {
let value = Bytes::from(&buf[*index..*index + length as usize]);
*index = *index + length as usize;
value
}
#[inline]
pub fn deserialize_string_eof(buf: &Bytes, index: &mut usize) -> Bytes {
pub fn decode_string_eof(buf: &Bytes, index: &mut usize) -> Bytes {
let value = Bytes::from(&buf[*index..]);
*index = buf.len();
value
}
#[inline]
pub fn deserialize_string_null(buf: &Bytes, index: &mut usize) -> Result<Bytes, Error> {
pub fn decode_string_null(buf: &Bytes, index: &mut usize) -> Result<Bytes, Error> {
if let Some(null_index) = memchr::memchr(0, &buf[*index..]) {
let value = Bytes::from(&buf[*index..*index + null_index]);
*index = *index + null_index + 1;
@ -114,22 +119,22 @@ pub fn deserialize_string_null(buf: &Bytes, index: &mut usize) -> Result<Bytes,
}
#[inline]
pub fn deserialize_byte_fix(buf: &Bytes, index: &mut usize, length: usize) -> Bytes {
pub fn decode_byte_fix(buf: &Bytes, index: &mut usize, length: usize) -> Bytes {
let value = Bytes::from(&buf[*index..*index + length as usize]);
*index = *index + length as usize;
value
}
#[inline]
pub fn deserialize_byte_lenenc(buf: &Bytes, index: &mut usize) -> Bytes {
let length = deserialize_int_3(&buf, &mut *index);
pub fn decode_byte_lenenc(buf: &Bytes, index: &mut usize) -> Bytes {
let length = decode_int_3(&buf, &mut *index);
let value = Bytes::from(&buf[*index..*index + length as usize]);
*index = *index + length as usize;
value
}
#[inline]
pub fn deserialize_byte_eof(buf: &Bytes, index: &mut usize) -> Bytes {
pub fn decode_byte_eof(buf: &Bytes, index: &mut usize) -> Bytes {
let value = Bytes::from(&buf[*index..]);
*index = buf.len();
value
@ -158,7 +163,7 @@ mod tests {
fn it_decodes_int_lenenc_0x_fb() {
let buf: BytesMut = BytesMut::from(b"\xFB".to_vec());
let mut index = 0;
let int: Option<usize> = deserialize_int_lenenc(&buf.freeze(), &mut index);
let int: Option<usize> = decode_int_lenenc(&buf.freeze(), &mut index);
assert_eq!(int, None);
assert_eq!(index, 1);
@ -168,7 +173,7 @@ mod tests {
fn it_decodes_int_lenenc_0x_fc() {
let buf = BytesMut::from(b"\xFC\x01\x01".to_vec());
let mut index = 0;
let int: Option<usize> = deserialize_int_lenenc(&buf.freeze(), &mut index);
let int: Option<usize> = decode_int_lenenc(&buf.freeze(), &mut index);
assert_eq!(int, Some(257));
assert_eq!(index, 3);
@ -178,7 +183,7 @@ mod tests {
fn it_decodes_int_lenenc_0x_fd() {
let buf = BytesMut::from(b"\xFD\x01\x01\x01".to_vec());
let mut index = 0;
let int: Option<usize> = deserialize_int_lenenc(&buf.freeze(), &mut index);
let int: Option<usize> = decode_int_lenenc(&buf.freeze(), &mut index);
assert_eq!(int, Some(65793));
assert_eq!(index, 4);
@ -188,7 +193,7 @@ mod tests {
fn it_decodes_int_lenenc_0x_fe() {
let buf = BytesMut::from(b"\xFE\x01\x01\x01\x01\x01\x01\x01\x01".to_vec());
let mut index = 0;
let int: Option<usize> = deserialize_int_lenenc(&buf.freeze(), &mut index);
let int: Option<usize> = decode_int_lenenc(&buf.freeze(), &mut index);
assert_eq!(int, Some(72340172838076673));
assert_eq!(index, 9);
@ -198,7 +203,7 @@ mod tests {
fn it_decodes_int_lenenc_0x_fa() {
let buf = BytesMut::from(b"\xFA".to_vec());
let mut index = 0;
let int: Option<usize> = deserialize_int_lenenc(&buf.freeze(), &mut index);
let int: Option<usize> = decode_int_lenenc(&buf.freeze(), &mut index);
assert_eq!(int, Some(0xfA));
assert_eq!(index, 1);
@ -208,7 +213,7 @@ mod tests {
fn it_decodes_int_8() {
let buf = BytesMut::from(b"\x01\x01\x01\x01\x01\x01\x01\x01".to_vec());
let mut index = 0;
let int: u64 = deserialize_int_8(&buf.freeze(), &mut index);
let int: u64 = decode_int_8(&buf.freeze(), &mut index);
assert_eq!(int, 72340172838076673);
assert_eq!(index, 8);
@ -218,7 +223,7 @@ mod tests {
fn it_decodes_int_4() {
let buf = BytesMut::from(b"\x01\x01\x01\x01".to_vec());
let mut index = 0;
let int: u32 = deserialize_int_4(&buf.freeze(), &mut index);
let int: u32 = decode_int_4(&buf.freeze(), &mut index);
assert_eq!(int, 16843009);
assert_eq!(index, 4);
@ -228,7 +233,7 @@ mod tests {
fn it_decodes_int_3() {
let buf = BytesMut::from(b"\x01\x01\x01".to_vec());
let mut index = 0;
let int: u32 = deserialize_int_3(&buf.freeze(), &mut index);
let int: u32 = decode_int_3(&buf.freeze(), &mut index);
assert_eq!(int, 65793);
assert_eq!(index, 3);
@ -238,7 +243,7 @@ mod tests {
fn it_decodes_int_2() {
let buf = BytesMut::from(b"\x01\x01".to_vec());
let mut index = 0;
let int: u16 = deserialize_int_2(&buf.freeze(), &mut index);
let int: u16 = decode_int_2(&buf.freeze(), &mut index);
assert_eq!(int, 257);
assert_eq!(index, 2);
@ -248,7 +253,7 @@ mod tests {
fn it_decodes_int_1() {
let buf = BytesMut::from(b"\x01".to_vec());
let mut index = 0;
let int: u8 = deserialize_int_1(&buf.freeze(), &mut index);
let int: u8 = decode_int_1(&buf.freeze(), &mut index);
assert_eq!(int, 1);
assert_eq!(index, 1);
@ -258,7 +263,7 @@ mod tests {
fn it_decodes_string_lenenc() {
let buf = BytesMut::from(b"\x01\x00\x00\x01".to_vec());
let mut index = 0;
let string: Bytes = deserialize_string_lenenc(&buf.freeze(), &mut index);
let string: Bytes = decode_string_lenenc(&buf.freeze(), &mut index);
assert_eq!(string[0], b'\x01');
assert_eq!(string.len(), 1);
@ -269,7 +274,7 @@ mod tests {
fn it_decodes_string_fix() {
let buf = BytesMut::from(b"\x01".to_vec());
let mut index = 0;
let string: Bytes = deserialize_string_fix(&buf.freeze(), &mut index, 1);
let string: Bytes = decode_string_fix(&buf.freeze(), &mut index, 1);
assert_eq!(string[0], b'\x01');
assert_eq!(string.len(), 1);
@ -280,7 +285,7 @@ mod tests {
fn it_decodes_string_eof() {
let buf = BytesMut::from(b"\x01".to_vec());
let mut index = 0;
let string: Bytes = deserialize_string_eof(&buf.freeze(), &mut index);
let string: Bytes = decode_string_eof(&buf.freeze(), &mut index);
assert_eq!(string[0], b'\x01');
assert_eq!(string.len(), 1);
@ -291,7 +296,7 @@ mod tests {
fn it_decodes_string_null() -> Result<(), Error> {
let buf = BytesMut::from(b"random\x00\x01".to_vec());
let mut index = 0;
let string: Bytes = deserialize_string_null(&buf.freeze(), &mut index)?;
let string: Bytes = decode_string_null(&buf.freeze(), &mut index)?;
assert_eq!(&string[..], b"random");
@ -306,7 +311,7 @@ mod tests {
fn it_decodes_byte_fix() {
let buf = BytesMut::from(b"\x01".to_vec());
let mut index = 0;
let string: Bytes = deserialize_byte_fix(&buf.freeze(), &mut index, 1);
let string: Bytes = decode_byte_fix(&buf.freeze(), &mut index, 1);
assert_eq!(string[0], b'\x01');
assert_eq!(string.len(), 1);
@ -317,7 +322,7 @@ mod tests {
fn it_decodes_byte_eof() {
let buf = BytesMut::from(b"\x01".to_vec());
let mut index = 0;
let string: Bytes = deserialize_byte_eof(&buf.freeze(), &mut index);
let string: Bytes = decode_byte_eof(&buf.freeze(), &mut index);
assert_eq!(string[0], b'\x01');
assert_eq!(string.len(), 1);

View File

@ -4,7 +4,7 @@ use bytes::{BufMut, Bytes, BytesMut};
const U24_MAX: usize = 0xFF_FF_FF;
#[inline]
pub fn serialize_length(buf: &mut BytesMut) {
pub fn encode_length(buf: &mut BytesMut) {
let mut length = [0; 3];
if buf.len() > U24_MAX {
panic!("Buffer too long");
@ -22,46 +22,46 @@ pub fn serialize_length(buf: &mut BytesMut) {
}
#[inline]
pub fn serialize_int_8(buf: &mut BytesMut, value: u64) {
pub fn encode_int_8(buf: &mut BytesMut, value: u64) {
buf.put_u64_le(value);
}
#[inline]
pub fn serialize_int_4(buf: &mut BytesMut, value: u32) {
pub fn encode_int_4(buf: &mut BytesMut, value: u32) {
buf.put_u32_le(value);
}
#[inline]
pub fn serialize_int_3(buf: &mut BytesMut, value: u32) {
pub fn encode_int_3(buf: &mut BytesMut, value: u32) {
let length = value.to_le_bytes();
buf.extend_from_slice(&length[0..3]);
}
#[inline]
pub fn serialize_int_2(buf: &mut BytesMut, value: u16) {
pub fn encode_int_2(buf: &mut BytesMut, value: u16) {
buf.put_u16_le(value);
}
#[inline]
pub fn serialize_int_1(buf: &mut BytesMut, value: u8) {
pub fn encode_int_1(buf: &mut BytesMut, value: u8) {
buf.put_u8(value);
}
#[inline]
pub fn serialize_int_lenenc(buf: &mut BytesMut, value: Option<&usize>) {
pub fn encode_int_lenenc(buf: &mut BytesMut, value: Option<&usize>) {
if let Some(value) = value {
if *value > U24_MAX && *value <= std::u64::MAX as usize {
buf.put_u8(0xFE);
serialize_int_8(buf, *value as u64);
encode_int_8(buf, *value as u64);
} else if *value > std::u16::MAX as usize && *value <= U24_MAX {
buf.put_u8(0xFD);
serialize_int_3(buf, *value as u32);
encode_int_3(buf, *value as u32);
} else if *value > std::u8::MAX as usize && *value <= std::u16::MAX as usize {
buf.put_u8(0xFC);
serialize_int_2(buf, *value as u16);
encode_int_2(buf, *value as u16);
} else if *value <= std::u8::MAX as usize {
buf.put_u8(0xFA);
serialize_int_1(buf, *value as u8);
encode_int_1(buf, *value as u8);
} else {
panic!("Value is too long");
}
@ -71,25 +71,25 @@ pub fn serialize_int_lenenc(buf: &mut BytesMut, value: Option<&usize>) {
}
#[inline]
pub fn serialize_string_lenenc(buf: &mut BytesMut, string: &Bytes) {
pub fn encode_string_lenenc(buf: &mut BytesMut, string: &Bytes) {
if string.len() > 0xFFF {
panic!("String inside string lenenc serialization is too long");
}
serialize_int_3(buf, string.len() as u32);
encode_int_3(buf, string.len() as u32);
if string.len() > 0 {
buf.extend_from_slice(string);
}
}
#[inline]
pub fn serialize_string_null(buf: &mut BytesMut, string: &Bytes) {
pub fn encode_string_null(buf: &mut BytesMut, string: &Bytes) {
buf.extend_from_slice(string);
buf.put(0_u8);
}
#[inline]
pub fn serialize_string_fix(buf: &mut BytesMut, bytes: &Bytes, size: usize) {
pub fn encode_string_fix(buf: &mut BytesMut, bytes: &Bytes, size: usize) {
if size != bytes.len() {
panic!("Sizes do not match");
}
@ -98,22 +98,22 @@ pub fn serialize_string_fix(buf: &mut BytesMut, bytes: &Bytes, size: usize) {
}
#[inline]
pub fn serialize_string_eof(buf: &mut BytesMut, bytes: &Bytes) {
pub fn encode_string_eof(buf: &mut BytesMut, bytes: &Bytes) {
buf.extend_from_slice(bytes);
}
#[inline]
pub fn serialize_byte_lenenc(buf: &mut BytesMut, bytes: &Bytes) {
pub fn encode_byte_lenenc(buf: &mut BytesMut, bytes: &Bytes) {
if bytes.len() > 0xFFF {
panic!("String inside string lenenc serialization is too long");
}
serialize_int_3(buf, bytes.len() as u32);
encode_int_3(buf, bytes.len() as u32);
buf.extend_from_slice(bytes);
}
#[inline]
pub fn serialize_byte_fix(buf: &mut BytesMut, bytes: &Bytes, size: usize) {
pub fn encode_byte_fix(buf: &mut BytesMut, bytes: &Bytes, size: usize) {
if size != bytes.len() {
panic!("Sizes do not match");
}
@ -122,7 +122,7 @@ pub fn serialize_byte_fix(buf: &mut BytesMut, bytes: &Bytes, size: usize) {
}
#[inline]
pub fn serialize_byte_eof(buf: &mut BytesMut, bytes: &Bytes) {
pub fn encode_byte_eof(buf: &mut BytesMut, bytes: &Bytes) {
buf.extend_from_slice(bytes);
}
@ -130,28 +130,28 @@ pub fn serialize_byte_eof(buf: &mut BytesMut, bytes: &Bytes) {
mod tests {
use super::*;
// [X] serialize_int_lenenc_u64
// [X] serialize_int_lenenc_u32
// [X] serialize_int_lenenc_u24
// [X] serialize_int_lenenc_u16
// [X] serialize_int_lenenc_u8
// [X] serialize_int_u64
// [X] serialize_int_u32
// [X] serialize_int_u24
// [X] serialize_int_u16
// [X] serialize_int_u8
// [X] serialize_string_lenenc
// [X] serialize_string_fix
// [X] serialize_string_null
// [X] serialize_string_eof
// [X] serialize_byte_lenenc
// [X] serialize_byte_fix
// [X] serialize_byte_eof
// [X] encode_int_lenenc_u64
// [X] encode_int_lenenc_u32
// [X] encode_int_lenenc_u24
// [X] encode_int_lenenc_u16
// [X] encode_int_lenenc_u8
// [X] encode_int_u64
// [X] encode_int_u32
// [X] encode_int_u24
// [X] encode_int_u16
// [X] encode_int_u8
// [X] encode_string_lenenc
// [X] encode_string_fix
// [X] encode_string_null
// [X] encode_string_eof
// [X] encode_byte_lenenc
// [X] encode_byte_fix
// [X] encode_byte_eof
#[test]
fn it_encodes_int_lenenc_none() {
let mut buf = BytesMut::new();
serialize_int_lenenc(&mut buf, None);
encode_int_lenenc(&mut buf, None);
assert_eq!(&buf[..], b"\xFB");
}
@ -159,7 +159,7 @@ mod tests {
#[test]
fn it_encodes_int_lenenc_u8() {
let mut buf = BytesMut::new();
serialize_int_lenenc(&mut buf, Some(&(std::u8::MAX as usize)));
encode_int_lenenc(&mut buf, Some(&(std::u8::MAX as usize)));
assert_eq!(&buf[..], b"\xFA\xFF");
}
@ -167,7 +167,7 @@ mod tests {
#[test]
fn it_encodes_int_lenenc_u16() {
let mut buf = BytesMut::new();
serialize_int_lenenc(&mut buf, Some(&(std::u16::MAX as usize)));
encode_int_lenenc(&mut buf, Some(&(std::u16::MAX as usize)));
assert_eq!(&buf[..], b"\xFC\xFF\xFF");
}
@ -175,7 +175,7 @@ mod tests {
#[test]
fn it_encodes_int_lenenc_u24() {
let mut buf = BytesMut::new();
serialize_int_lenenc(&mut buf, Some(&U24_MAX));
encode_int_lenenc(&mut buf, Some(&U24_MAX));
assert_eq!(&buf[..], b"\xFD\xFF\xFF\xFF");
}
@ -183,7 +183,7 @@ mod tests {
#[test]
fn it_encodes_int_lenenc_u64() {
let mut buf = BytesMut::new();
serialize_int_lenenc(&mut buf, Some(&(std::u64::MAX as usize)));
encode_int_lenenc(&mut buf, Some(&(std::u64::MAX as usize)));
assert_eq!(&buf[..], b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
}
@ -191,7 +191,7 @@ mod tests {
#[test]
fn it_encodes_int_u64() {
let mut buf = BytesMut::new();
serialize_int_8(&mut buf, std::u64::MAX);
encode_int_8(&mut buf, std::u64::MAX);
assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF");
}
@ -199,7 +199,7 @@ mod tests {
#[test]
fn it_encodes_int_u32() {
let mut buf = BytesMut::new();
serialize_int_4(&mut buf, std::u32::MAX);
encode_int_4(&mut buf, std::u32::MAX);
assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF");
}
@ -207,7 +207,7 @@ mod tests {
#[test]
fn it_encodes_int_u24() {
let mut buf = BytesMut::new();
serialize_int_3(&mut buf, U24_MAX as u32);
encode_int_3(&mut buf, U24_MAX as u32);
assert_eq!(&buf[..], b"\xFF\xFF\xFF");
}
@ -215,7 +215,7 @@ mod tests {
#[test]
fn it_encodes_int_u16() {
let mut buf = BytesMut::new();
serialize_int_2(&mut buf, std::u16::MAX);
encode_int_2(&mut buf, std::u16::MAX);
assert_eq!(&buf[..], b"\xFF\xFF");
}
@ -223,7 +223,7 @@ mod tests {
#[test]
fn it_encodes_int_u8() {
let mut buf = BytesMut::new();
serialize_int_1(&mut buf, std::u8::MAX);
encode_int_1(&mut buf, std::u8::MAX);
assert_eq!(&buf[..], b"\xFF");
}
@ -231,7 +231,7 @@ mod tests {
#[test]
fn it_encodes_string_lenenc() {
let mut buf = BytesMut::new();
serialize_string_lenenc(&mut buf, &Bytes::from_static(b"random_string"));
encode_string_lenenc(&mut buf, &Bytes::from_static(b"random_string"));
assert_eq!(&buf[..], b"\x0D\x00\x00random_string");
}
@ -239,7 +239,7 @@ mod tests {
#[test]
fn it_encodes_string_fix() {
let mut buf = BytesMut::new();
serialize_string_fix(&mut buf, &Bytes::from_static(b"random_string"), 13);
encode_string_fix(&mut buf, &Bytes::from_static(b"random_string"), 13);
assert_eq!(&buf[..], b"random_string");
}
@ -247,7 +247,7 @@ mod tests {
#[test]
fn it_encodes_string_null() {
let mut buf = BytesMut::new();
serialize_string_null(&mut buf, &Bytes::from_static(b"random_string"));
encode_string_null(&mut buf, &Bytes::from_static(b"random_string"));
assert_eq!(&buf[..], b"random_string\0");
}
@ -255,7 +255,7 @@ mod tests {
#[test]
fn it_encodes_string_eof() {
let mut buf = BytesMut::new();
serialize_string_eof(&mut buf, &Bytes::from_static(b"random_string"));
encode_string_eof(&mut buf, &Bytes::from_static(b"random_string"));
assert_eq!(&buf[..], b"random_string");
}
@ -263,7 +263,7 @@ mod tests {
#[test]
fn it_encodes_byte_lenenc() {
let mut buf = BytesMut::new();
serialize_byte_lenenc(&mut buf, &Bytes::from("random_string"));
encode_byte_lenenc(&mut buf, &Bytes::from("random_string"));
assert_eq!(&buf[..], b"\x0D\x00\x00random_string");
}
@ -271,7 +271,7 @@ mod tests {
#[test]
fn it_encodes_byte_fix() {
let mut buf = BytesMut::new();
serialize_byte_fix(&mut buf, &Bytes::from("random_string"), 13);
encode_byte_fix(&mut buf, &Bytes::from("random_string"), 13);
assert_eq!(&buf[..], b"random_string");
}
@ -279,7 +279,7 @@ mod tests {
#[test]
fn it_encodes_byte_eof() {
let mut buf = BytesMut::new();
serialize_byte_eof(&mut buf, &Bytes::from("random_string"));
encode_byte_eof(&mut buf, &Bytes::from("random_string"));
assert_eq!(&buf[..], b"random_string");
}

View File

@ -1,4 +1,9 @@
pub enum ErrorCodes {
use std::convert::TryFrom;
#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive)]
#[TryFromPrimitiveType="u16"]
pub enum ErrorCode {
ErDefault = 0,
ErHashchk = 1000,
ErNisamchk = 1001,
ErNo = 1002,
@ -973,3 +978,9 @@ pub enum ErrorCodes {
ErNoEisForField = 1980,
ErWarnAggfuncDependence = 1981,
}
impl Default for ErrorCode {
fn default() -> Self {
ErrorCode::ErDefault
}
}

View File

@ -1,5 +1,5 @@
pub mod client;
pub mod deserialize;
pub mod decode;
pub mod error_codes;
pub mod serialize;
pub mod encode;
pub mod server;

View File

@ -1,9 +1,10 @@
// Reference: https://mariadb.com/kb/en/library/connection
use crate::protocol::deserialize::*;
use crate::protocol::{decode::*, error_codes::ErrorCode};
use byteorder::{ByteOrder, LittleEndian};
use bytes::{Bytes, BytesMut};
use failure::{err_msg, Error};
use std::convert::TryFrom;
pub trait Deserialize: Sized {
fn deserialize(buf: &Bytes) -> Result<Self, Error>;
@ -92,6 +93,42 @@ pub enum SessionChangeType {
SessionTrackTransactionState = 5,
}
#[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive)]
#[TryFromPrimitiveType="u8"]
pub enum FieldType {
MysqlTypeDecimal = 0,
MysqlTypeTiny = 1,
MysqlTypeShort = 2,
MysqlTypeLong = 3,
MysqlTypeFloat = 4,
MysqlTypeDouble = 5,
MysqlTypeNull = 6,
MysqlTypeTimestamp = 7,
MysqlTypeLonglong = 8,
MysqlTypeInt24 = 9,
MysqlTypeDate = 10,
MysqlTypeTime = 11,
MysqlTypeDatetime = 12,
MysqlTypeYear = 13,
MysqlTypeNewdate = 14,
MysqlTypeVarchar = 15,
MysqlTypeBit = 16,
MysqlTypeTimestamp2 = 17,
MysqlTypeDatetime2 = 18,
MysqlTypeTime2 = 19,
MysqlTypeJson = 245,
MysqlTypeNewdecimal = 246,
MysqlTypeEnum = 247,
MysqlTypeSet = 248,
MysqlTypeTinyBlob = 249,
MysqlTypeMediumBlob = 250,
MysqlTypeLongBlob = 251,
MysqlTypeBlob = 252,
MysqlTypeVarString = 253,
MysqlTypeString = 254,
MysqlTypeGeometry = 255,
}
impl Default for Capabilities {
fn default() -> Self {
Capabilities::CLIENT_MYSQL
@ -104,6 +141,12 @@ impl Default for ServerStatusFlag {
}
}
impl Default for FieldType {
fn default() -> Self {
FieldType::MysqlTypeDecimal
}
}
#[derive(Default, Debug)]
pub struct InitialHandshakePacket {
pub length: u32,
@ -137,7 +180,7 @@ pub struct OkPacket {
pub struct ErrPacket {
pub length: u32,
pub seq_no: u8,
pub error_code: u16,
pub error_code: ErrorCode,
pub stage: Option<u8>,
pub max_stage: Option<u8>,
pub progress: Option<u32>,
@ -147,6 +190,36 @@ pub struct ErrPacket {
pub error_message: Option<Bytes>,
}
#[derive(Default, Debug)]
pub struct ColumnPacket {
pub length: u32,
pub seq_no: u8,
pub columns: Option<usize>,
}
pub struct ColumnDefPacket {
pub length: u32,
pub seq_no: u8,
pub catalog: Bytes,
pub schema: Bytes,
pub table_alias: Bytes,
pub table: Bytes,
pub column_alias: Bytes,
pub column: Bytes,
pub length_of_fixed_fields: Option<usize>,
pub char_set: u16,
pub max_columns: u32,
pub field_type: FieldType,
pub field_details: FieldDetailFlag,
pub decimals: u8,
}
#[derive(Debug, Default)]
pub struct ResultSet {
pub columns: Vec<(ColumnPacket, ColumnDefPacket)>,
pub rows: Vec<Bytes>,
}
impl Message {
pub fn deserialize(buf: &mut BytesMut) -> Result<Option<Self>, Error> {
if buf.len() < 4 {
@ -174,35 +247,35 @@ impl Deserialize for InitialHandshakePacket {
fn deserialize(buf: &Bytes) -> Result<Self, Error> {
let mut index = 0;
let length = deserialize_length(&buf, &mut index)?;
let seq_no = deserialize_int_1(&buf, &mut index);
let length = decode_length(&buf, &mut index)?;
let seq_no = decode_int_1(&buf, &mut index);
if seq_no != 0 {
return Err(err_msg("Squence Number of Initial Handshake Packet is not 0"));
}
let protocol_version = deserialize_int_1(&buf, &mut index);
let server_version = deserialize_string_null(&buf, &mut index)?;
let connection_id = deserialize_int_4(&buf, &mut index);
let auth_seed = deserialize_string_fix(&buf, &mut index, 8);
let protocol_version = decode_int_1(&buf, &mut index);
let server_version = decode_string_null(&buf, &mut index)?;
let connection_id = decode_int_4(&buf, &mut index);
let auth_seed = decode_string_fix(&buf, &mut index, 8);
// Skip reserved byte
index += 1;
let mut capabilities =
Capabilities::from_bits_truncate(deserialize_int_2(&buf, &mut index).into());
Capabilities::from_bits_truncate(decode_int_2(&buf, &mut index).into());
let collation = deserialize_int_1(&buf, &mut index);
let collation = decode_int_1(&buf, &mut index);
let status =
ServerStatusFlag::from_bits_truncate(deserialize_int_2(&buf, &mut index).into());
ServerStatusFlag::from_bits_truncate(decode_int_2(&buf, &mut index).into());
capabilities |= Capabilities::from_bits_truncate(
((deserialize_int_2(&buf, &mut index) as u32) << 16).into(),
((decode_int_2(&buf, &mut index) as u32) << 16).into(),
);
let mut plugin_data_length = 0;
if !(capabilities & Capabilities::PLUGIN_AUTH).is_empty() {
plugin_data_length = deserialize_int_1(&buf, &mut index);
plugin_data_length = decode_int_1(&buf, &mut index);
} else {
// Skip reserve byte
index += 1;
@ -213,7 +286,7 @@ impl Deserialize for InitialHandshakePacket {
if (capabilities & Capabilities::CLIENT_MYSQL).is_empty() {
capabilities |= Capabilities::from_bits_truncate(
((deserialize_int_4(&buf, &mut index) as u128) << 32).into(),
((decode_int_4(&buf, &mut index) as u128) << 32).into(),
);
} else {
// Skip filler
@ -223,14 +296,14 @@ impl Deserialize for InitialHandshakePacket {
let mut scramble: Option<Bytes> = None;
if !(capabilities & Capabilities::SECURE_CONNECTION).is_empty() {
let len = std::cmp::max(12, plugin_data_length as usize - 9);
scramble = Some(deserialize_string_fix(&buf, &mut index, len));
scramble = Some(decode_string_fix(&buf, &mut index, len));
// Skip reserve byte
index += 1;
}
let mut auth_plugin_name: Option<Bytes> = None;
if !(capabilities & Capabilities::PLUGIN_AUTH).is_empty() {
auth_plugin_name = Some(deserialize_string_null(&buf, &mut index)?);
auth_plugin_name = Some(decode_string_null(&buf, &mut index)?);
}
Ok(InitialHandshakePacket {
@ -255,20 +328,20 @@ impl Deserialize for OkPacket {
let mut index = 0;
// Packet header
let length = deserialize_length(&buf, &mut index)?;
let seq_no = deserialize_int_1(&buf, &mut index);
let length = decode_length(&buf, &mut index)?;
let seq_no = decode_int_1(&buf, &mut index);
// Packet body
let packet_header = deserialize_int_1(&buf, &mut index);
let packet_header = decode_int_1(&buf, &mut index);
if packet_header != 0 && packet_header != 0xFE {
panic!("Packet header is not 0 or 0xFE for OkPacket");
}
let affected_rows = deserialize_int_lenenc(&buf, &mut index);
let last_insert_id = deserialize_int_lenenc(&buf, &mut index);
let affected_rows = decode_int_lenenc(&buf, &mut index);
let last_insert_id = decode_int_lenenc(&buf, &mut index);
let server_status =
ServerStatusFlag::from_bits_truncate(deserialize_int_2(&buf, &mut index).into());
let warning_count = deserialize_int_2(&buf, &mut index);
ServerStatusFlag::from_bits_truncate(decode_int_2(&buf, &mut index).into());
let warning_count = decode_int_2(&buf, &mut index);
// Assuming CLIENT_SESSION_TRACK is unsupported
let session_state_info = None;
@ -294,15 +367,15 @@ impl Deserialize for ErrPacket {
fn deserialize(buf: &Bytes) -> Result<Self, Error> {
let mut index = 0;
let length = deserialize_length(&buf, &mut index)?;
let seq_no = deserialize_int_1(&buf, &mut index);
let length = decode_length(&buf, &mut index)?;
let seq_no = decode_int_1(&buf, &mut index);
let packet_header = deserialize_int_1(&buf, &mut index);
let packet_header = decode_int_1(&buf, &mut index);
if packet_header != 0xFF {
panic!("Packet header is not 0xFF for ErrPacket");
}
let error_code = deserialize_int_2(&buf, &mut index);
let error_code = ErrorCode::try_from(decode_int_2(&buf, &mut index))?;
let mut stage = None;
let mut max_stage = None;
@ -314,18 +387,18 @@ impl Deserialize for ErrPacket {
let mut error_message = None;
// Progress Reporting
if error_code == 0xFFFF {
stage = Some(deserialize_int_1(buf, &mut index));
max_stage = Some(deserialize_int_1(buf, &mut index));
progress = Some(deserialize_int_3(buf, &mut index));
progress_info = Some(deserialize_string_lenenc(&buf, &mut index));
if error_code as u16 == 0xFFFF {
stage = Some(decode_int_1(buf, &mut index));
max_stage = Some(decode_int_1(buf, &mut index));
progress = Some(decode_int_3(buf, &mut index));
progress_info = Some(decode_string_lenenc(&buf, &mut index));
} else {
if buf[index] == b'#' {
sql_state_marker = Some(deserialize_string_fix(buf, &mut index, 1));
sql_state = Some(deserialize_string_fix(buf, &mut index, 5));
error_message = Some(deserialize_string_eof(buf, &mut index));
sql_state_marker = Some(decode_string_fix(buf, &mut index, 1));
sql_state = Some(decode_string_fix(buf, &mut index, 5));
error_message = Some(decode_string_eof(buf, &mut index));
} else {
error_message = Some(deserialize_string_eof(buf, &mut index));
error_message = Some(decode_string_eof(buf, &mut index));
}
}
@ -344,6 +417,83 @@ impl Deserialize for ErrPacket {
}
}
impl Deserialize for ColumnPacket {
fn deserialize(buf: &Bytes) -> Result<Self, Error> {
let mut index = 0;
let length = decode_length(&buf, &mut index)?;
let seq_no = decode_int_1(&buf, &mut index);
let columns = decode_int_lenenc(&buf, &mut index);
Ok(ColumnPacket {
length,
seq_no,
columns,
})
}
}
impl Deserialize for ColumnDefPacket {
fn deserialize(buf: &Bytes) -> Result<Self, Error> {
let mut index = 0;
let length = decode_length(&buf, &mut index)?;
let seq_no = decode_int_1(&buf, &mut index);
let catalog = decode_string_lenenc(&buf, &mut index);
let schema = decode_string_lenenc(&buf, &mut index);
let table_alias = decode_string_lenenc(&buf, &mut index);
let table = decode_string_lenenc(&buf, &mut index);
let column_alias = decode_string_lenenc(&buf, &mut index);
let column = decode_string_lenenc(&buf, &mut index);
let length_of_fixed_fields = decode_int_lenenc(&buf, &mut index);
let char_set = decode_int_2(&buf, &mut index);
let max_columns = decode_int_4(&buf, &mut index);
let field_type = FieldType::try_from(decode_int_1(&buf, &mut index))?;
let field_details = FieldDetailFlag::from_bits_truncate(decode_int_2(&buf, &mut index));
let decimals = decode_int_1(&buf, &mut index);
// Skip last two unused bytes
// index += 2;
Ok(ColumnDefPacket {
length,
seq_no,
catalog,
schema,
table_alias,
table,
column_alias,
column,
length_of_fixed_fields,
char_set,
max_columns,
field_type,
field_details,
decimals,
})
}
}
//impl Deserialize for ResultSet {
// fn deserialize(buf: &Bytes) -> Result<Self, Error> {
// let mut index = 0;
//
// let length = decode_length(&buf, &mut index)?;
// let seq_no = decode_int_1(&buf, &mut index);
//
// let column_packet = ColumnPacket::deserialize(&but)?;
//
// let column_definitions = if let Some(columns) = column_packet.columns {
// (0..columns).map(|_| {
// ColumnDefPacket::deserialize()
// })
// };
//
// Ok(ResultSet::default())
// }
//}
#[cfg(test)]
mod test {
use super::*;
@ -352,7 +502,7 @@ mod test {
fn it_decodes_capabilities() {
let buf = BytesMut::from(b"\xfe\xf7".to_vec());
let mut index = 0;
Capabilities::from_bits_truncate(deserialize_int_2(&buf.freeze(), &mut index).into());
Capabilities::from_bits_truncate(decode_int_2(&buf.freeze(), &mut index).into());
}
#[test]