Use DeContext and remove resultset as a packet

This commit is contained in:
Daniel Akhterov 2019-07-24 17:19:30 -07:00
parent 4cfb1d46a1
commit 6e282ee33b
19 changed files with 561 additions and 200 deletions

View File

@ -1,11 +1,11 @@
use super::{Connection};
use super::Connection;
use crate::protocol::{
deserialize::{Deserialize, DeContext},
deserialize::{DeContext, Deserialize},
packets::{handshake_response::HandshakeResponsePacket, initial::InitialHandshakePacket},
server::Message as ServerMessage,
types::Capabilities,
};
use bytes::Bytes;
use bytes::{BufMut, Bytes};
use failure::{err_msg, Error};
use mason_core::ConnectOptions;
@ -14,7 +14,7 @@ pub async fn establish<'a, 'b: 'a>(
options: ConnectOptions<'b>,
) -> Result<(), Error> {
let buf = &conn.stream.next_bytes().await?;
let mut de_ctx = DeContext::new(conn, &buf);
let mut de_ctx = DeContext::new(&mut conn.context, &buf);
let _ = InitialHandshakePacket::deserialize(&mut de_ctx)?;
let handshake: HandshakeResponsePacket = HandshakeResponsePacket {
@ -30,7 +30,7 @@ pub async fn establish<'a, 'b: 'a>(
match conn.next().await? {
Some(ServerMessage::OkPacket(message)) => {
conn.seq_no = message.seq_no;
conn.context.seq_no = message.seq_no;
Ok(())
}
@ -41,7 +41,7 @@ pub async fn establish<'a, 'b: 'a>(
}
None => {
panic!("Did not recieve packet");
panic!("Did not receive packet");
}
}
}
@ -76,9 +76,10 @@ mod test {
database: None,
password: None,
})
.await {
.await
{
Ok(_) => Err(err_msg("Bad username still worked?")),
Err(_) => Ok(())
Err(_) => Ok(()),
}
}
}

View File

@ -1,7 +1,7 @@
use crate::protocol::{
deserialize::{Deserialize, DeContext},
deserialize::{DeContext, Deserialize},
encode::Encoder,
packets::{com_ping::ComPing, com_quit::ComQuit, ok::OkPacket},
packets::{com_ping::ComPing, com_query::ComQuery, com_quit::ComQuit, ok::OkPacket},
serialize::Serialize,
server::Message as ServerMessage,
types::{Capabilities, ServerStatusFlag},
@ -24,6 +24,12 @@ pub struct Connection {
// Buffer used when serializing outgoing messages
pub encoder: Encoder,
// Context for the connection
// Explicitly declared to easily send to deserializers
pub context: ConnContext,
}
pub struct ConnContext {
// MariaDB Connection ID
pub connection_id: i32,
@ -41,16 +47,18 @@ pub struct Connection {
}
impl Connection {
pub async fn establish(options: ConnectOptions<'static>) -> Result<Self, Error> {
pub(crate) async fn establish(options: ConnectOptions<'static>) -> Result<Self, Error> {
let stream: Framed = Framed::new(TcpStream::connect((options.host, options.port)).await?);
let mut conn: Connection = Self {
stream,
encoder: Encoder::new(1024),
connection_id: -1,
seq_no: 1,
last_seq_no: 0,
capabilities: Capabilities::default(),
status: ServerStatusFlag::default(),
context: ConnContext {
connection_id: -1,
seq_no: 1,
last_seq_no: 0,
capabilities: Capabilities::default(),
status: ServerStatusFlag::default(),
},
};
establish::establish(&mut conn, options).await?;
@ -58,13 +66,13 @@ impl Connection {
Ok(conn)
}
async fn send<S>(&mut self, message: S) -> Result<(), Error>
pub(crate) async fn send<S>(&mut self, message: S) -> Result<(), Error>
where
S: Serialize,
{
self.encoder.clear();
self.encoder.alloc_packet_header();
self.encoder.seq_no(self.seq_no);
self.encoder.seq_no(self.context.seq_no);
message.serialize(self)?;
self.encoder.encode_length();
@ -74,24 +82,30 @@ impl Connection {
Ok(())
}
async fn quit(&mut self) -> Result<(), Error> {
pub(crate) async fn quit(&mut self) -> Result<(), Error> {
self.send(ComQuit()).await?;
Ok(())
}
async fn ping(&mut self) -> Result<(), Error> {
self.seq_no = 0;
self.send(ComPing()).await?;
// Ping response must be an OkPacket
let buf = self.stream.next_bytes().await?;
OkPacket::deserialize(&mut DeContext::new(self, &buf))?;
pub(crate) async fn query<'a>(&'a mut self, sql_statement: &'a str) -> Result<(), Error> {
self.send(ComQuery { sql_statement: bytes::Bytes::from(sql_statement) }).await?;
Ok(())
}
async fn next(&mut self) -> Result<Option<ServerMessage>, Error> {
pub(crate) async fn ping(&mut self) -> Result<(), Error> {
self.context.seq_no = 0;
self.send(ComPing()).await?;
// Ping response must be an OkPacket
let buf = self.stream.next_bytes().await?;
OkPacket::deserialize(&mut DeContext::new(&mut self.context, &buf))?;
Ok(())
}
pub(crate) async fn next(&mut self) -> Result<Option<ServerMessage>, Error> {
let mut rbuf = BytesMut::new();
let mut len = 0;
@ -118,7 +132,10 @@ impl Connection {
while len > 0 {
let size = rbuf.len();
let message = ServerMessage::deserialize(&mut DeContext::new(self, &rbuf.as_ref().into()))?;
let message = ServerMessage::deserialize(&mut DeContext::new(
&mut self.context,
&rbuf.as_ref().into(),
))?;
len -= size - rbuf.len();
match message {

View File

@ -1,7 +1,7 @@
#![feature(non_exhaustive, async_await)]
#![allow(clippy::needless_lifetimes)]
// TODO: Remove this once API has matured
#![allow(dead_code)]
#![allow(dead_code, unused_imports, unused_variables)]
#[macro_use]
extern crate bitflags;
@ -10,3 +10,6 @@ extern crate enum_tryfrom_derive;
pub mod connection;
pub mod protocol;
#[macro_use]
pub mod macros;

View File

@ -0,0 +1,20 @@
#[cfg(test)]
#[doc(hidden)]
#[macro_export]
macro_rules! __bytes_builder (
($($b: expr), *) => {{
use bytes::Buf;
use bytes::IntoBuf;
use bytes::BufMut;
let mut bytes = bytes::BytesMut::new();
$(
{
let buf = $b.into_buf();
bytes.reserve(buf.remaining());
bytes.put(buf);
}
)*
bytes.freeze()
}}
);

View File

@ -1,4 +1,5 @@
// Deserializing bytes and string do the same thing. Except that string also has a null terminated deserialzer
use super::packets::packet_header::PacketHeader;
use byteorder::{ByteOrder, LittleEndian};
use bytes::Bytes;
use failure::{err_msg, Error};
@ -24,6 +25,23 @@ impl<'a> Decoder<'a> {
Ok(length)
}
#[inline]
pub fn peek_tag(&self) -> Option<&u8> {
self.buf.get(4)
}
#[inline]
pub fn peek_packet_header(&self) -> Result<PacketHeader, Error> {
let length = LittleEndian::read_u24(&self.buf[self.index..]);
let seq_no = self.buf[3];
if self.buf.len() < length as usize {
return Err(err_msg("Lengths to do not match"));
}
Ok(PacketHeader { length, seq_no })
}
#[inline]
pub fn skip_bytes(&mut self, amount: usize) {
self.index += amount;

View File

@ -1,19 +1,16 @@
use super::decode::Decoder;
use failure::Error;
use crate::connection::Connection;
use crate::connection::{ConnContext, Connection};
use bytes::Bytes;
use failure::Error;
pub struct DeContext<'a> {
pub conn: &'a mut Connection,
pub conn: &'a mut ConnContext,
pub decoder: Decoder<'a>,
}
impl<'a> DeContext<'a> {
pub fn new(conn: &'a mut Connection, buf: &'a Bytes) -> Self {
DeContext {
conn,
decoder: Decoder::new(&buf),
}
pub fn new(conn: &'a mut ConnContext, buf: &'a Bytes) -> Self {
DeContext { conn, decoder: Decoder::new(&buf) }
}
}

View File

@ -56,7 +56,7 @@ impl Encoder {
}
#[inline]
pub fn encode_int_3(&mut self, value: u32) {
pub fn encode_int_3(&mut self, value: u32) {
self.buf.extend_from_slice(&value.to_le_bytes()[0..3]);
}
@ -74,22 +74,22 @@ impl Encoder {
pub fn encode_int_lenenc(&mut self, value: Option<&usize>) {
if let Some(value) = value {
if *value > U24_MAX && *value <= std::u64::MAX as usize {
self.buf.push(0xFE);
self.buf.put_u8(0xFE);
self.encode_int_8(*value as u64);
} else if *value > std::u16::MAX as usize && *value <= U24_MAX {
self.buf.push(0xFD);
self.buf.put_u8(0xFD);
self.encode_int_3(*value as u32);
} else if *value > std::u8::MAX as usize && *value <= std::u16::MAX as usize {
self.buf.push(0xFC);
self.buf.put_u8(0xFC);
self.encode_int_2(*value as u16);
} else if *value <= std::u8::MAX as usize {
self.buf.push(0xFA);
self.buf.put_u8(0xFA);
self.encode_int_1(*value as u8);
} else {
panic!("Value is too long");
}
} else {
self.buf.push(0xFB);
self.buf.put_u8(0xFB);
}
}

View File

@ -1,28 +1,27 @@
use super::super::deserialize::{Deserialize, DeContext};
use super::super::deserialize::{DeContext, Deserialize};
use failure::Error;
#[derive(Default, Debug)]
// ColumnPacket doesn't have a packet header because
// it's nested inside a result set packet
pub struct ColumnPacket {
pub length: u32,
pub seq_no: u8,
pub columns: Option<usize>,
}
impl Deserialize for ColumnPacket {
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_1();
let columns = decoder.decode_int_lenenc();
Ok(ColumnPacket { length, seq_no, columns })
Ok(ColumnPacket { columns })
}
}
#[cfg(test)]
mod test {
use bytes::Bytes;
use super::*;
use crate::{__bytes_builder, connection::Connection, protocol::decode::Decoder};
use bytes::Bytes;
use mason_core::ConnectOptions;
#[runtime::test]
@ -33,10 +32,15 @@ mod test {
user: Some("root"),
database: None,
password: None,
}).await?;
})
.await?;
let buf = Bytes::from(b"\x01\0\0\x01\xFB".to_vec());
let message = ColumnPacket::deserialize(&mut conn, &mut Decoder::new(&buf))?;
let buf = __bytes_builder!(
// int<lenenc> tag code: None
0xFB_u8
);
let message = ColumnPacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?;
assert_eq!(message.columns, None);
@ -51,10 +55,16 @@ mod test {
user: Some("root"),
database: None,
password: None,
}).await?;
})
.await?;
let buf = Bytes::from(b"\x04\0\0\x01\xFD\x01\x01\x01".to_vec());
let message = ColumnPacket::deserialize(&mut conn, &mut Decoder::new(&buf))?;
let buf = __bytes_builder!(
// int<lenenc> tag code: Some(3 bytes)
0xFD_u8, // value: 3 bytes
0x01_u8, 0x01_u8, 0x01_u8
);
let message = ColumnPacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?;
assert_eq!(message.columns, Some(0x010101));
@ -69,14 +79,21 @@ mod test {
user: Some("root"),
database: None,
password: None,
}).await?;
})
.await?;
let buf = Bytes::from(b"\x03\0\0\x01\xFC\x01\x01".to_vec());
let message = ColumnPacket::deserialize(&mut conn, &mut Decoder::new(&buf))?;
#[rustfmt::skip]
let buf = __bytes_builder!(
// int<lenenc> tag code: Some(3 bytes)
0xFC_u8,
// value: 2 bytes
0x01_u8, 0x01_u8
);
let message = ColumnPacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?;
assert_ne!(message.columns, Some(0x0100));
Ok(())
}
}

View File

@ -1,15 +1,15 @@
use std::convert::TryFrom;
use bytes::Bytes;
use failure::Error;
use super::super::{
deserialize::{Deserialize, DeContext},
deserialize::{DeContext, Deserialize},
types::{FieldDetailFlag, FieldType},
};
use bytes::Bytes;
use failure::Error;
use std::convert::TryFrom;
#[derive(Debug, Default)]
// ColumnDefPacket doesn't have a packet header because
// it's nested inside a result set packet
pub struct ColumnDefPacket {
pub length: u32,
pub seq_no: u8,
pub catalog: Bytes,
pub schema: Bytes,
pub table_alias: Bytes,
@ -27,8 +27,6 @@ pub struct ColumnDefPacket {
impl Deserialize for ColumnDefPacket {
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_1();
let catalog = decoder.decode_string_lenenc();
let schema = decoder.decode_string_lenenc();
@ -47,8 +45,6 @@ impl Deserialize for ColumnDefPacket {
decoder.skip_bytes(2);
Ok(ColumnDefPacket {
length,
seq_no,
catalog,
schema,
table_alias,
@ -67,8 +63,9 @@ impl Deserialize for ColumnDefPacket {
#[cfg(test)]
mod test {
use bytes::Bytes;
use super::*;
use crate::{__bytes_builder, connection::Connection, protocol::decode::Decoder};
use bytes::Bytes;
use mason_core::ConnectOptions;
#[runtime::test]
@ -79,26 +76,40 @@ mod test {
user: Some("root"),
database: None,
password: None,
}).await?;
})
.await?;
let buf = Bytes::from(b"\
\0\0\0\
\x01\
\x01\0\0a\
\x01\0\0b\
\x01\0\0c\
\x01\0\0d\
\x01\0\0e\
\x01\0\0f\
\xfc\x01\x01\
\x01\x01\
\x01\x01\x01\x01\
\x00\
\x00\x00\
\x01\
\0\0
".to_vec());
let message = ColumnDefPacket::deserialize(&mut conn, &mut Decoder::new(&buf))?;
#[rustfmt::skip]
let buf = __bytes_builder!(
// string<lenenc> catalog (always 'def')
1u8, 0u8, 0u8, b'a',
// string<lenenc> schema
1u8, 0u8, 0u8, b'b',
// string<lenenc> table alias
1u8, 0u8, 0u8, b'c',
// string<lenenc> table
1u8, 0u8, 0u8, b'd',
// string<lenenc> column alias
1u8, 0u8, 0u8, b'e',
// string<lenenc> column
1u8, 0u8, 0u8, b'f',
// int<lenenc> length of fixed fields (=0xC)
0xFC_u8, 1u8, 1u8,
// int<2> character set number
1u8, 1u8,
// int<4> max. column size
1u8, 1u8, 1u8, 1u8,
// int<1> Field types
1u8,
// int<2> Field detail flag
1u8, 0u8,
// int<1> decimals
1u8,
// int<2> - unused -
0u8, 0u8
);
let message = ColumnDefPacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?;
assert_eq!(&message.catalog[..], b"a");
assert_eq!(&message.schema[..], b"b");

View File

@ -0,0 +1,76 @@
use super::super::{
decode::Decoder,
deserialize::{DeContext, Deserialize},
error_codes::ErrorCode,
types::ServerStatusFlag,
};
use bytes::Bytes;
use failure::Error;
use std::convert::TryFrom;
#[derive(Default, Debug)]
pub struct EofPacket {
pub length: u32,
pub seq_no: u8,
pub warning_count: u16,
pub status: ServerStatusFlag,
}
impl Deserialize for EofPacket {
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_1();
let packet_header = decoder.decode_int_1();
if packet_header != 0xFE {
panic!("Packet header is not 0xFE for ErrPacket");
}
let warning_count = decoder.decode_int_2();
let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_2());
Ok(EofPacket { length, seq_no, warning_count, status })
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::{__bytes_builder, connection::Connection};
use bytes::Bytes;
use mason_core::ConnectOptions;
#[runtime::test]
async fn it_decodes_eof_packet() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
})
.await?;
#[rustfmt::skip]
let buf = __bytes_builder!(
// int<3> length
1u8, 0u8, 0u8,
// int<1> seq_no
1u8,
// int<1> 0xfe : EOF header
0xFE_u8,
// int<2> warning count
0u8, 0u8,
// int<2> server status
1u8, 1u8
);
let buf = Bytes::from_static(b"\x01\0\0\x01\xFE\x00\x00\x01\x00");
let _message = EofPacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?;
Ok(())
}
}

View File

@ -1,7 +1,10 @@
use std::convert::TryFrom;
use super::super::{
deserialize::{DeContext, Deserialize},
error_codes::ErrorCode,
};
use bytes::Bytes;
use failure::Error;
use super::super::{deserialize::Deserialize, deserialize::DeContext, error_codes::ErrorCode};
use std::convert::TryFrom;
#[derive(Default, Debug)]
pub struct ErrPacket {
@ -72,8 +75,9 @@ impl Deserialize for ErrPacket {
#[cfg(test)]
mod test {
use bytes::Bytes;
use super::*;
use crate::{__bytes_builder, connection::Connection, protocol::decode::Decoder};
use bytes::Bytes;
use mason_core::ConnectOptions;
#[runtime::test]
@ -84,10 +88,39 @@ mod test {
user: Some("root"),
database: None,
password: None,
}).await?;
})
.await?;
let buf = Bytes::from(b"!\0\0\x01\xff\x84\x04#08S01Got packets out of order".to_vec());
let _message = ErrPacket::deserialize(&mut conn, &mut Decoder::new(&buf))?;
#[rustfmt::skip]
let buf = __bytes_builder!(
// int<3> length
1u8, 0u8, 0u8,
// int<1> seq_no
1u8,
// int<1> 0xfe : EOF header
0xFF_u8,
// int<2> error code
0x84_u8, 0x04_u8,
// if (errorcode == 0xFFFF) /* progress reporting */ {
// int<1> stage
// int<1> max_stage
// int<3> progress
// string<lenenc> progress_info
// } else {
// if (next byte = '#') {
// string<1> sql state marker '#'
b"#",
// string<5>sql state
b"08S01",
// string<EOF> error message
b"Got packets out of order"
// } else {
// string<EOF> error message
// }
// }
);
let _message = ErrPacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?;
Ok(())
}

View File

@ -28,7 +28,7 @@ impl Serialize for HandshakeResponsePacket {
// Filler
conn.encoder.encode_byte_fix(&Bytes::from_static(&[0u8; 19]), 19);
if !(conn.capabilities & Capabilities::CLIENT_MYSQL).is_empty()
if !(conn.context.capabilities & Capabilities::CLIENT_MYSQL).is_empty()
&& !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty()
{
if let Some(capabilities) = self.extended_capabilities {
@ -40,11 +40,11 @@ impl Serialize for HandshakeResponsePacket {
conn.encoder.encode_string_null(&self.username);
if !(conn.capabilities & Capabilities::PLUGIN_AUTH_LENENC_CLIENT_DATA).is_empty() {
if !(conn.context.capabilities & Capabilities::PLUGIN_AUTH_LENENC_CLIENT_DATA).is_empty() {
if let Some(auth_data) = &self.auth_data {
conn.encoder.encode_string_lenenc(&auth_data);
}
} else if !(conn.capabilities & Capabilities::SECURE_CONNECTION).is_empty() {
} else if !(conn.context.capabilities & Capabilities::SECURE_CONNECTION).is_empty() {
if let Some(auth_response) = &self.auth_response {
conn.encoder.encode_int_1(self.auth_response_len.unwrap());
conn.encoder
@ -54,21 +54,21 @@ impl Serialize for HandshakeResponsePacket {
conn.encoder.encode_int_1(0);
}
if !(conn.capabilities & Capabilities::CONNECT_WITH_DB).is_empty() {
if !(conn.context.capabilities & Capabilities::CONNECT_WITH_DB).is_empty() {
if let Some(database) = &self.database {
// string<NUL>
conn.encoder.encode_string_null(&database);
}
}
if !(conn.capabilities & Capabilities::PLUGIN_AUTH).is_empty() {
if !(conn.context.capabilities & Capabilities::PLUGIN_AUTH).is_empty() {
if let Some(auth_plugin_name) = &self.auth_plugin_name {
// string<NUL>
conn.encoder.encode_string_null(&auth_plugin_name);
}
}
if !(conn.capabilities & Capabilities::CONNECT_ATTRS).is_empty() {
if !(conn.context.capabilities & Capabilities::CONNECT_ATTRS).is_empty() {
if let (Some(conn_attr_len), Some(conn_attr)) = (&self.conn_attr_len, &self.conn_attr) {
// int<lenenc>
conn.encoder.encode_int_lenenc(Some(conn_attr_len));

View File

@ -1,5 +1,5 @@
use super::super::{
deserialize::{Deserialize, DeContext},
deserialize::{DeContext, Deserialize},
types::{Capabilities, ServerStatusFlag},
};
use bytes::Bytes;
@ -102,41 +102,84 @@ impl Deserialize for InitialHandshakePacket {
#[cfg(test)]
mod test {
use super::*;
use crate::{__bytes_builder, connection::Connection, protocol::decode::Decoder};
use bytes::BytesMut;
use mason_core::ConnectOptions;
#[runtime::test]
async fn it_decodes_initial_handshake_packet() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
let mut conn = crate::connection::Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
}).await?;
})
.await?;
let buf = BytesMut::from(b"\
n\0\0\
\0\
\n\
5.5.5-10.4.6-MariaDB-1:10.4.6+maria~bionic\0\
\x13\0\0\0\
?~~|vZAu\
\0\
\xfe\xf7\
\x08\
\x02\0\
\xff\x81\
\x15\
\0\0\0\0\0\0\
\x07\0\0\0\
JQ8cihP4Q}Dx\
\0\
mysql_native_password\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"
.to_vec(),
let buf = __bytes_builder!(
// int<3> length
1u8,
0u8,
0u8,
// int<1> seq_no
0u8,
//int<1> protocol version
10u8,
//string<NUL> server version (MariaDB server version is by default prefixed by "5.5.5-")
b"5.5.5-10.4.6-MariaDB-1:10.4.6+maria~bionic\0",
//int<4> connection id
13u8,
0u8,
0u8,
0u8,
//string<8> scramble 1st part (authentication seed)
b"?~~|vZAu",
//string<1> reserved byte
0u8,
//int<2> server capabilities (1st part)
0xFEu8,
0xF7u8,
//int<1> server default collation
8u8,
//int<2> status flags
2u8,
0u8,
//int<2> server capabilities (2nd part)
0xFF_u8,
0x81_u8,
//if (server_capabilities & PLUGIN_AUTH)
// int<1> plugin data length
//else
// int<1> 0x00
15u8,
//string<6> filler
0u8,
0u8,
0u8,
0u8,
0u8,
0u8,
//if (server_capabilities & CLIENT_MYSQL)
// string<4> filler
//else
// int<4> server capabilities 3rd part . MariaDB specific flags /* MariaDB 10.2 or later */
7u8,
0u8,
0u8,
0u8,
//if (server_capabilities & CLIENT_SECURE_CONNECTION)
// string<n> scramble 2nd part . Length = max(12, plugin data length - 9)
b"JQ8cihP4Q}Dx",
// string<1> reserved byte
0u8,
//if (server_capabilities & PLUGIN_AUTH)
// string<NUL> authentication plugin name
b"mysql_native_password\0"
);
let _message = InitialHandshakePacket::deserialize(&mut conn, &mut Decoder::new(&buf.freeze()))?;
let _message =
InitialHandshakePacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?;
Ok(())
}

View File

@ -12,9 +12,11 @@ pub mod com_set_option;
pub mod com_shutdown;
pub mod com_sleep;
pub mod com_statistics;
pub mod eof;
pub mod err;
pub mod handshake_response;
pub mod initial;
pub mod ok;
pub mod packet_header;
pub mod result_set;
pub mod ssl_request;

View File

@ -1,7 +1,9 @@
use super::super::{deserialize::Deserialize, deserialize::DeContext, types::ServerStatusFlag};
use super::super::{
deserialize::{DeContext, Deserialize},
types::ServerStatusFlag,
};
use bytes::Bytes;
use failure::Error;
use failure::err_msg;
use failure::{err_msg, Error};
#[derive(Default, Debug)]
pub struct OkPacket {
@ -57,7 +59,7 @@ impl Deserialize for OkPacket {
#[cfg(test)]
mod test {
use super::*;
use bytes::BytesMut;
use crate::{__bytes_builder, connection::Connection, protocol::decode::Decoder};
use mason_core::ConnectOptions;
#[runtime::test]
@ -68,22 +70,38 @@ mod test {
user: Some("root"),
database: None,
password: None,
}).await?;
})
.await?;
let buf = BytesMut::from(b"\
\x0F\x00\x00\
\x01\
\x00\
\xFB\
\xFB\
\x01\x01\
\x00\x00\
info\
"
.to_vec(),
#[rustfmt::skip]
let buf = __bytes_builder!(
// length
0x0F_u8, 0x0_u8, 0x0_u8,
// seq_no
0x01_u8,
// 0x00 : OK_Packet header or (0xFE if CLIENT_DEPRECATE_EOF is set)
0x00_u8,
// int<lenenc> affected rows
0xFB_u8,
// int<lenenc> last insert id
0xFB_u8,
// int<2> server status
0x01_u8, 0x01_u8,
// int<2> warning count
0x0_u8, 0x0_u8,
// if session_tracking_supported (see CLIENT_SESSION_TRACK) {
// string<lenenc> info
// if (status flags & SERVER_SESSION_STATE_CHANGED) {
// string<lenenc> session state info
// string<lenenc> value of variable
// }
// } else {
// string<EOF> info
b"info"
// }
);
let message = OkPacket::deserialize(&mut conn, &mut Decoder::new(&buf.freeze()))?;
let message = OkPacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?;
assert_eq!(message.affected_rows, None);
assert_eq!(message.last_insert_id, None);

View File

@ -0,0 +1,4 @@
pub struct PacketHeader {
pub length: u32,
pub seq_no: u8,
}

View File

@ -1,9 +1,9 @@
use bytes::Bytes;
use failure::Error;
use super::super::{
deserialize::{Deserialize, DeContext},
deserialize::{DeContext, Deserialize},
packets::{column::ColumnPacket, column_def::ColumnDefPacket},
};
use bytes::Bytes;
use failure::Error;
#[derive(Debug, Default)]
pub struct ResultSet {
@ -21,6 +21,18 @@ impl Deserialize for ResultSet {
let column_packet = ColumnPacket::deserialize(ctx)?;
match ctx.decoder.decode_int_1() {
// 0x00 -> PACKET_OK
0x00 => {}
// 0xFF -> PACKET_ERR
0xFF => {}
_ => {
panic!("Didn't receive 0x00 nor 0xFF");
}
}
let columns = if let Some(columns) = column_packet.columns {
(0..columns)
.map(|_| ColumnDefPacket::deserialize(ctx))
@ -33,68 +45,164 @@ impl Deserialize for ResultSet {
let mut rows = Vec::new();
for _ in 0.. {
loop {
// if end of buffer stop
if ctx.decoder.eof() {
break;
}
// Decode each column as string<lenenc>
rows.push(
(0..column_packet.columns.unwrap_or(0))
.map(|_| ctx.decoder.decode_string_lenenc())
.collect::<Vec<Bytes>>(),
)
let columns = if let Some(columns) = column_packet.columns {
(0..columns).map(|_| ctx.decoder.decode_string_lenenc()).collect::<Vec<Bytes>>()
} else {
Vec::new()
};
}
Ok(ResultSet {
length,
seq_no,
column_packet,
columns,
rows,
})
Ok(ResultSet { length, seq_no, column_packet, columns, rows })
}
}
#[cfg(test)]
mod test {
use bytes::Bytes;
use super::*;
use crate::{__bytes_builder, connection::Connection};
use bytes::{BufMut, Bytes};
#[runtime::test]
async fn it_decodes_result_set_packet() -> Result<(), Error> {
let buf = Bytes::from(b"\
\0\0\0\x01\
\x02\0\0\x02\xff\x02
\x01\0\0a\
\x01\0\0b\
\x01\0\0c\
\x01\0\0d\
\x01\0\0e\
\x01\0\0f\
\xfc\x01\x01\
\x01\x01\
\x01\x01\x01\x01\
\x00\
\x00\x00\
\x01\
\0\0\
\x01\0\0g\
\x01\0\0h\
\x01\0\0i\
\x01\0\0j\
\x01\0\0k\
\x01\0\0l\
\xfc\x01\x01\
\x01\x01\
\x01\x01\x01\x01\
\x00\
\x00\x00\
\x01\
\0\0
".to_vec());
// let message = ColumnDefPacket::deserialize(&mut Connection::mock().await, &mut Decoder::new(&buf))?;
let mut conn = Connection::establish(mason_core::ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
})
.await?;
// conn.query("SELECT * FROM users");
#[rustfmt::skip]
let buf = __bytes_builder!(
// ------------------- //
// Column Count packet //
// ------------------- //
// length
0x02_u8, 0x0_u8, 0x0_u8,
// seq_no
0x02_u8,
// int<lenenc> Column count packet
0x02_u8, 0x00_u8,
// ------------------------ //
// Column Definition packet //
// ------------------------ //
// length
0x02_u8, 0x0_u8, 0x0_u8,
// seq_no
0x02_u8,
// string<lenenc> catalog (always 'def')
0x03_u8, 0x0_u8, 0x0_u8, b"def",
// string<lenenc> schema
0x01_u8, 0x0_u8, 0x0_u8, b'b',
// string<lenenc> table alias
0x01_u8, 0x0_u8, 0x0_u8, b'c',
// string<lenenc> table
0x01_u8, 0x0_u8, 0x0_u8, b'd',
// string<lenenc> column alias
0x01_u8, 0x0_u8, 0x0_u8, b'e',
// string<lenenc> column
0x01_u8, 0x0_u8, 0x0_u8, b'f',
// int<lenenc> length of fixed fields (=0xC)
0xfc_u8, 0x01_u8, 0x01_u8,
// int<2> character set number
0x01_u8, 0x01_u8,
// int<4> max. column size
0x01_u8, 0x01_u8, 0x01_u8, 0x01_u8,
// int<1> Field types
0x00_u8,
// int<2> Field detail flag
0x00_u8, 0x00_u8,
// int<1> decimals
0x01_u8,
// int<2> - unused -
0x0_u8, 0x0_u8,
// ------------------------ //
// Column Definition packet //
// ------------------------ //
// length
0x02_u8, 0x0_u8, 0x0_u8,
// seq_no
0x02_u8,
// string<lenenc> catalog (always 'def')
0x03_u8, 0x0_u8, 0x0_u8, b"def",
// string<lenenc> schema
0x01_u8, 0x0_u8, 0x0_u8, b'b',
// string<lenenc> table alias
0x01_u8, 0x0_u8, 0x0_u8, b'c',
// string<lenenc> table
0x01_u8, 0x0_u8, 0x0_u8, b'd',
// string<lenenc> column alias
0x01_u8, 0x0_u8, 0x0_u8, b'e',
// string<lenenc> column
0x01_u8, 0x0_u8, 0x0_u8, b'f',
// int<lenenc> length of fixed fields (=0xC)
0xfc_u8, 0x01_u8, 0x01_u8,
// int<2> character set number
0x01_u8, 0x01_u8,
// int<4> max. column size
0x01_u8, 0x01_u8, 0x01_u8, 0x01_u8,
// int<1> Field types
0x00_u8,
// int<2> Field detail flag
0x00_u8, 0x00_u8,
// int<1> decimals
0x01_u8,
// int<2> - unused -
0x0_u8, 0x00_u8,
// ---------- //
// EOF Packet //
// ---------- //
// length
0x02_u8, 0x0_u8, 0x0_u8,
// seq_no
0x02_u8,
// int<1> 0xfe : EOF header
0xfe_u8,
// int<2> warning count
0x0_u8, 0x0_u8,
// int<2> server status
0x01_u8, 0x00_u8,
// ------------------- //
// N Result Row Packet //
// ------------------- //
// string<lenenc> column data
0x01_u8, 0x0_u8, 0x0_u8, b'h',
// string<lenenc> column data
0x01_u8, 0x0_u8, 0x0_u8, b'i',
// ---------- //
// EOF Packet //
// ---------- //
// length
0x02_u8, 0x0_u8, 0x0_u8,
// seq_no
0x02_u8,
// int<1> 0xfe : EOF header
0xfe_u8,
// int<2> warning count
0x0_u8, 0x0_u8,
// int<2> server status
0x01_u8, 0x00_u8
);
Ok(())
}

View File

@ -20,7 +20,7 @@ impl Serialize for SSLRequestPacket {
// Filler
conn.encoder.encode_byte_fix(&Bytes::from_static(&[0u8; 19]), 19);
if !(conn.capabilities & Capabilities::CLIENT_MYSQL).is_empty()
if !(conn.context.capabilities & Capabilities::CLIENT_MYSQL).is_empty()
&& !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty()
{
if let Some(capabilities) = self.extended_capabilities {

View File

@ -18,14 +18,7 @@ pub enum Message {
impl Message {
pub fn deserialize(ctx: &mut DeContext) -> Result<Option<Self>, Error> {
let decoder = &mut ctx.decoder;
if decoder.buf.len() < 4 {
return Ok(None);
}
let length = decoder.decode_length()?;
if decoder.buf.len() < (length + 4) as usize {
return Ok(None);
}
let _packet_header = decoder.peek_packet_header()?;
let tag = decoder.buf[4];