From 6e282ee33bfb1e3482c7bbf8bb27940e278941c0 Mon Sep 17 00:00:00 2001 From: Daniel Akhterov Date: Wed, 24 Jul 2019 17:19:30 -0700 Subject: [PATCH] Use DeContext and remove resultset as a packet --- mason-mariadb/src/connection/establish.rs | 17 +- mason-mariadb/src/connection/mod.rs | 57 +++-- mason-mariadb/src/lib.rs | 5 +- mason-mariadb/src/macros/mod.rs | 20 ++ mason-mariadb/src/protocol/decode.rs | 18 ++ mason-mariadb/src/protocol/deserialize.rs | 13 +- mason-mariadb/src/protocol/encode.rs | 12 +- mason-mariadb/src/protocol/packets/column.rs | 51 +++-- .../src/protocol/packets/column_def.rs | 71 +++--- mason-mariadb/src/protocol/packets/eof.rs | 76 +++++++ mason-mariadb/src/protocol/packets/err.rs | 45 +++- .../protocol/packets/handshake_response.rs | 12 +- mason-mariadb/src/protocol/packets/initial.rs | 89 ++++++-- mason-mariadb/src/protocol/packets/mod.rs | 2 + mason-mariadb/src/protocol/packets/ok.rs | 52 +++-- .../src/protocol/packets/packet_header.rs | 4 + .../src/protocol/packets/result_set.rs | 206 +++++++++++++----- .../src/protocol/packets/ssl_request.rs | 2 +- mason-mariadb/src/protocol/server.rs | 9 +- 19 files changed, 561 insertions(+), 200 deletions(-) create mode 100644 mason-mariadb/src/macros/mod.rs create mode 100644 mason-mariadb/src/protocol/packets/eof.rs create mode 100644 mason-mariadb/src/protocol/packets/packet_header.rs diff --git a/mason-mariadb/src/connection/establish.rs b/mason-mariadb/src/connection/establish.rs index 0b5d69e5..18cd2173 100644 --- a/mason-mariadb/src/connection/establish.rs +++ b/mason-mariadb/src/connection/establish.rs @@ -1,11 +1,11 @@ -use super::{Connection}; +use super::Connection; use crate::protocol::{ - deserialize::{Deserialize, DeContext}, + deserialize::{DeContext, Deserialize}, packets::{handshake_response::HandshakeResponsePacket, initial::InitialHandshakePacket}, server::Message as ServerMessage, types::Capabilities, }; -use bytes::Bytes; +use bytes::{BufMut, Bytes}; use failure::{err_msg, Error}; use mason_core::ConnectOptions; @@ -14,7 +14,7 @@ pub async fn establish<'a, 'b: 'a>( options: ConnectOptions<'b>, ) -> Result<(), Error> { let buf = &conn.stream.next_bytes().await?; - let mut de_ctx = DeContext::new(conn, &buf); + let mut de_ctx = DeContext::new(&mut conn.context, &buf); let _ = InitialHandshakePacket::deserialize(&mut de_ctx)?; let handshake: HandshakeResponsePacket = HandshakeResponsePacket { @@ -30,7 +30,7 @@ pub async fn establish<'a, 'b: 'a>( match conn.next().await? { Some(ServerMessage::OkPacket(message)) => { - conn.seq_no = message.seq_no; + conn.context.seq_no = message.seq_no; Ok(()) } @@ -41,7 +41,7 @@ pub async fn establish<'a, 'b: 'a>( } None => { - panic!("Did not recieve packet"); + panic!("Did not receive packet"); } } } @@ -76,9 +76,10 @@ mod test { database: None, password: None, }) - .await { + .await + { Ok(_) => Err(err_msg("Bad username still worked?")), - Err(_) => Ok(()) + Err(_) => Ok(()), } } } diff --git a/mason-mariadb/src/connection/mod.rs b/mason-mariadb/src/connection/mod.rs index 91db28d8..c2f6b8f6 100644 --- a/mason-mariadb/src/connection/mod.rs +++ b/mason-mariadb/src/connection/mod.rs @@ -1,7 +1,7 @@ use crate::protocol::{ - deserialize::{Deserialize, DeContext}, + deserialize::{DeContext, Deserialize}, encode::Encoder, - packets::{com_ping::ComPing, com_quit::ComQuit, ok::OkPacket}, + packets::{com_ping::ComPing, com_query::ComQuery, com_quit::ComQuit, ok::OkPacket}, serialize::Serialize, server::Message as ServerMessage, types::{Capabilities, ServerStatusFlag}, @@ -24,6 +24,12 @@ pub struct Connection { // Buffer used when serializing outgoing messages pub encoder: Encoder, + // Context for the connection + // Explicitly declared to easily send to deserializers + pub context: ConnContext, +} + +pub struct ConnContext { // MariaDB Connection ID pub connection_id: i32, @@ -41,16 +47,18 @@ pub struct Connection { } impl Connection { - pub async fn establish(options: ConnectOptions<'static>) -> Result { + pub(crate) async fn establish(options: ConnectOptions<'static>) -> Result { let stream: Framed = Framed::new(TcpStream::connect((options.host, options.port)).await?); let mut conn: Connection = Self { stream, encoder: Encoder::new(1024), - connection_id: -1, - seq_no: 1, - last_seq_no: 0, - capabilities: Capabilities::default(), - status: ServerStatusFlag::default(), + context: ConnContext { + connection_id: -1, + seq_no: 1, + last_seq_no: 0, + capabilities: Capabilities::default(), + status: ServerStatusFlag::default(), + }, }; establish::establish(&mut conn, options).await?; @@ -58,13 +66,13 @@ impl Connection { Ok(conn) } - async fn send(&mut self, message: S) -> Result<(), Error> + pub(crate) async fn send(&mut self, message: S) -> Result<(), Error> where S: Serialize, { self.encoder.clear(); self.encoder.alloc_packet_header(); - self.encoder.seq_no(self.seq_no); + self.encoder.seq_no(self.context.seq_no); message.serialize(self)?; self.encoder.encode_length(); @@ -74,24 +82,30 @@ impl Connection { Ok(()) } - async fn quit(&mut self) -> Result<(), Error> { + pub(crate) async fn quit(&mut self) -> Result<(), Error> { self.send(ComQuit()).await?; Ok(()) } - async fn ping(&mut self) -> Result<(), Error> { - self.seq_no = 0; - self.send(ComPing()).await?; - - // Ping response must be an OkPacket - let buf = self.stream.next_bytes().await?; - OkPacket::deserialize(&mut DeContext::new(self, &buf))?; + pub(crate) async fn query<'a>(&'a mut self, sql_statement: &'a str) -> Result<(), Error> { + self.send(ComQuery { sql_statement: bytes::Bytes::from(sql_statement) }).await?; Ok(()) } - async fn next(&mut self) -> Result, Error> { + pub(crate) async fn ping(&mut self) -> Result<(), Error> { + self.context.seq_no = 0; + self.send(ComPing()).await?; + + // Ping response must be an OkPacket + let buf = self.stream.next_bytes().await?; + OkPacket::deserialize(&mut DeContext::new(&mut self.context, &buf))?; + + Ok(()) + } + + pub(crate) async fn next(&mut self) -> Result, Error> { let mut rbuf = BytesMut::new(); let mut len = 0; @@ -118,7 +132,10 @@ impl Connection { while len > 0 { let size = rbuf.len(); - let message = ServerMessage::deserialize(&mut DeContext::new(self, &rbuf.as_ref().into()))?; + let message = ServerMessage::deserialize(&mut DeContext::new( + &mut self.context, + &rbuf.as_ref().into(), + ))?; len -= size - rbuf.len(); match message { diff --git a/mason-mariadb/src/lib.rs b/mason-mariadb/src/lib.rs index 35a10532..e4fcccde 100644 --- a/mason-mariadb/src/lib.rs +++ b/mason-mariadb/src/lib.rs @@ -1,7 +1,7 @@ #![feature(non_exhaustive, async_await)] #![allow(clippy::needless_lifetimes)] // TODO: Remove this once API has matured -#![allow(dead_code)] +#![allow(dead_code, unused_imports, unused_variables)] #[macro_use] extern crate bitflags; @@ -10,3 +10,6 @@ extern crate enum_tryfrom_derive; pub mod connection; pub mod protocol; + +#[macro_use] +pub mod macros; diff --git a/mason-mariadb/src/macros/mod.rs b/mason-mariadb/src/macros/mod.rs new file mode 100644 index 00000000..296cc9a7 --- /dev/null +++ b/mason-mariadb/src/macros/mod.rs @@ -0,0 +1,20 @@ +#[cfg(test)] +#[doc(hidden)] +#[macro_export] +macro_rules! __bytes_builder ( + ($($b: expr), *) => {{ + use bytes::Buf; + use bytes::IntoBuf; + use bytes::BufMut; + + let mut bytes = bytes::BytesMut::new(); + $( + { + let buf = $b.into_buf(); + bytes.reserve(buf.remaining()); + bytes.put(buf); + } + )* + bytes.freeze() + }} +); diff --git a/mason-mariadb/src/protocol/decode.rs b/mason-mariadb/src/protocol/decode.rs index f77ee70b..14619c15 100644 --- a/mason-mariadb/src/protocol/decode.rs +++ b/mason-mariadb/src/protocol/decode.rs @@ -1,4 +1,5 @@ // Deserializing bytes and string do the same thing. Except that string also has a null terminated deserialzer +use super::packets::packet_header::PacketHeader; use byteorder::{ByteOrder, LittleEndian}; use bytes::Bytes; use failure::{err_msg, Error}; @@ -24,6 +25,23 @@ impl<'a> Decoder<'a> { Ok(length) } + #[inline] + pub fn peek_tag(&self) -> Option<&u8> { + self.buf.get(4) + } + + #[inline] + pub fn peek_packet_header(&self) -> Result { + let length = LittleEndian::read_u24(&self.buf[self.index..]); + let seq_no = self.buf[3]; + + if self.buf.len() < length as usize { + return Err(err_msg("Lengths to do not match")); + } + + Ok(PacketHeader { length, seq_no }) + } + #[inline] pub fn skip_bytes(&mut self, amount: usize) { self.index += amount; diff --git a/mason-mariadb/src/protocol/deserialize.rs b/mason-mariadb/src/protocol/deserialize.rs index 683cc851..802f5dd5 100644 --- a/mason-mariadb/src/protocol/deserialize.rs +++ b/mason-mariadb/src/protocol/deserialize.rs @@ -1,19 +1,16 @@ use super::decode::Decoder; -use failure::Error; -use crate::connection::Connection; +use crate::connection::{ConnContext, Connection}; use bytes::Bytes; +use failure::Error; pub struct DeContext<'a> { - pub conn: &'a mut Connection, + pub conn: &'a mut ConnContext, pub decoder: Decoder<'a>, } impl<'a> DeContext<'a> { - pub fn new(conn: &'a mut Connection, buf: &'a Bytes) -> Self { - DeContext { - conn, - decoder: Decoder::new(&buf), - } + pub fn new(conn: &'a mut ConnContext, buf: &'a Bytes) -> Self { + DeContext { conn, decoder: Decoder::new(&buf) } } } diff --git a/mason-mariadb/src/protocol/encode.rs b/mason-mariadb/src/protocol/encode.rs index 74b7bef1..0810c1b5 100644 --- a/mason-mariadb/src/protocol/encode.rs +++ b/mason-mariadb/src/protocol/encode.rs @@ -56,7 +56,7 @@ impl Encoder { } #[inline] - pub fn encode_int_3(&mut self, value: u32) { + pub fn encode_int_3(&mut self, value: u32) { self.buf.extend_from_slice(&value.to_le_bytes()[0..3]); } @@ -74,22 +74,22 @@ impl Encoder { pub fn encode_int_lenenc(&mut self, value: Option<&usize>) { if let Some(value) = value { if *value > U24_MAX && *value <= std::u64::MAX as usize { - self.buf.push(0xFE); + self.buf.put_u8(0xFE); self.encode_int_8(*value as u64); } else if *value > std::u16::MAX as usize && *value <= U24_MAX { - self.buf.push(0xFD); + self.buf.put_u8(0xFD); self.encode_int_3(*value as u32); } else if *value > std::u8::MAX as usize && *value <= std::u16::MAX as usize { - self.buf.push(0xFC); + self.buf.put_u8(0xFC); self.encode_int_2(*value as u16); } else if *value <= std::u8::MAX as usize { - self.buf.push(0xFA); + self.buf.put_u8(0xFA); self.encode_int_1(*value as u8); } else { panic!("Value is too long"); } } else { - self.buf.push(0xFB); + self.buf.put_u8(0xFB); } } diff --git a/mason-mariadb/src/protocol/packets/column.rs b/mason-mariadb/src/protocol/packets/column.rs index 146bfc7d..d6c8cc6b 100644 --- a/mason-mariadb/src/protocol/packets/column.rs +++ b/mason-mariadb/src/protocol/packets/column.rs @@ -1,28 +1,27 @@ -use super::super::deserialize::{Deserialize, DeContext}; +use super::super::deserialize::{DeContext, Deserialize}; use failure::Error; #[derive(Default, Debug)] +// ColumnPacket doesn't have a packet header because +// it's nested inside a result set packet pub struct ColumnPacket { - pub length: u32, - pub seq_no: u8, pub columns: Option, } impl Deserialize for ColumnPacket { fn deserialize(ctx: &mut DeContext) -> Result { let decoder = &mut ctx.decoder; - let length = decoder.decode_length()?; - let seq_no = decoder.decode_int_1(); let columns = decoder.decode_int_lenenc(); - Ok(ColumnPacket { length, seq_no, columns }) + Ok(ColumnPacket { columns }) } } #[cfg(test)] mod test { - use bytes::Bytes; use super::*; + use crate::{__bytes_builder, connection::Connection, protocol::decode::Decoder}; + use bytes::Bytes; use mason_core::ConnectOptions; #[runtime::test] @@ -33,10 +32,15 @@ mod test { user: Some("root"), database: None, password: None, - }).await?; + }) + .await?; - let buf = Bytes::from(b"\x01\0\0\x01\xFB".to_vec()); - let message = ColumnPacket::deserialize(&mut conn, &mut Decoder::new(&buf))?; + let buf = __bytes_builder!( + // int tag code: None + 0xFB_u8 + ); + + let message = ColumnPacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?; assert_eq!(message.columns, None); @@ -51,10 +55,16 @@ mod test { user: Some("root"), database: None, password: None, - }).await?; + }) + .await?; - let buf = Bytes::from(b"\x04\0\0\x01\xFD\x01\x01\x01".to_vec()); - let message = ColumnPacket::deserialize(&mut conn, &mut Decoder::new(&buf))?; + let buf = __bytes_builder!( + // int tag code: Some(3 bytes) + 0xFD_u8, // value: 3 bytes + 0x01_u8, 0x01_u8, 0x01_u8 + ); + + let message = ColumnPacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?; assert_eq!(message.columns, Some(0x010101)); @@ -69,14 +79,21 @@ mod test { user: Some("root"), database: None, password: None, - }).await?; + }) + .await?; - let buf = Bytes::from(b"\x03\0\0\x01\xFC\x01\x01".to_vec()); - let message = ColumnPacket::deserialize(&mut conn, &mut Decoder::new(&buf))?; + #[rustfmt::skip] + let buf = __bytes_builder!( + // int tag code: Some(3 bytes) + 0xFC_u8, + // value: 2 bytes + 0x01_u8, 0x01_u8 + ); + + let message = ColumnPacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?; assert_ne!(message.columns, Some(0x0100)); Ok(()) } } - diff --git a/mason-mariadb/src/protocol/packets/column_def.rs b/mason-mariadb/src/protocol/packets/column_def.rs index 8423d282..7d382e80 100644 --- a/mason-mariadb/src/protocol/packets/column_def.rs +++ b/mason-mariadb/src/protocol/packets/column_def.rs @@ -1,15 +1,15 @@ -use std::convert::TryFrom; -use bytes::Bytes; -use failure::Error; use super::super::{ - deserialize::{Deserialize, DeContext}, + deserialize::{DeContext, Deserialize}, types::{FieldDetailFlag, FieldType}, }; +use bytes::Bytes; +use failure::Error; +use std::convert::TryFrom; #[derive(Debug, Default)] +// ColumnDefPacket doesn't have a packet header because +// it's nested inside a result set packet pub struct ColumnDefPacket { - pub length: u32, - pub seq_no: u8, pub catalog: Bytes, pub schema: Bytes, pub table_alias: Bytes, @@ -27,8 +27,6 @@ pub struct ColumnDefPacket { impl Deserialize for ColumnDefPacket { fn deserialize(ctx: &mut DeContext) -> Result { let decoder = &mut ctx.decoder; - let length = decoder.decode_length()?; - let seq_no = decoder.decode_int_1(); let catalog = decoder.decode_string_lenenc(); let schema = decoder.decode_string_lenenc(); @@ -47,8 +45,6 @@ impl Deserialize for ColumnDefPacket { decoder.skip_bytes(2); Ok(ColumnDefPacket { - length, - seq_no, catalog, schema, table_alias, @@ -67,8 +63,9 @@ impl Deserialize for ColumnDefPacket { #[cfg(test)] mod test { - use bytes::Bytes; use super::*; + use crate::{__bytes_builder, connection::Connection, protocol::decode::Decoder}; + use bytes::Bytes; use mason_core::ConnectOptions; #[runtime::test] @@ -79,26 +76,40 @@ mod test { user: Some("root"), database: None, password: None, - }).await?; + }) + .await?; - let buf = Bytes::from(b"\ - \0\0\0\ - \x01\ - \x01\0\0a\ - \x01\0\0b\ - \x01\0\0c\ - \x01\0\0d\ - \x01\0\0e\ - \x01\0\0f\ - \xfc\x01\x01\ - \x01\x01\ - \x01\x01\x01\x01\ - \x00\ - \x00\x00\ - \x01\ - \0\0 - ".to_vec()); - let message = ColumnDefPacket::deserialize(&mut conn, &mut Decoder::new(&buf))?; + #[rustfmt::skip] + let buf = __bytes_builder!( + // string catalog (always 'def') + 1u8, 0u8, 0u8, b'a', + // string schema + 1u8, 0u8, 0u8, b'b', + // string table alias + 1u8, 0u8, 0u8, b'c', + // string table + 1u8, 0u8, 0u8, b'd', + // string column alias + 1u8, 0u8, 0u8, b'e', + // string column + 1u8, 0u8, 0u8, b'f', + // int length of fixed fields (=0xC) + 0xFC_u8, 1u8, 1u8, + // int<2> character set number + 1u8, 1u8, + // int<4> max. column size + 1u8, 1u8, 1u8, 1u8, + // int<1> Field types + 1u8, + // int<2> Field detail flag + 1u8, 0u8, + // int<1> decimals + 1u8, + // int<2> - unused - + 0u8, 0u8 + ); + + let message = ColumnDefPacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?; assert_eq!(&message.catalog[..], b"a"); assert_eq!(&message.schema[..], b"b"); diff --git a/mason-mariadb/src/protocol/packets/eof.rs b/mason-mariadb/src/protocol/packets/eof.rs new file mode 100644 index 00000000..49aa7c0c --- /dev/null +++ b/mason-mariadb/src/protocol/packets/eof.rs @@ -0,0 +1,76 @@ +use super::super::{ + decode::Decoder, + deserialize::{DeContext, Deserialize}, + error_codes::ErrorCode, + types::ServerStatusFlag, +}; +use bytes::Bytes; +use failure::Error; +use std::convert::TryFrom; + +#[derive(Default, Debug)] +pub struct EofPacket { + pub length: u32, + pub seq_no: u8, + pub warning_count: u16, + pub status: ServerStatusFlag, +} + +impl Deserialize for EofPacket { + fn deserialize(ctx: &mut DeContext) -> Result { + let decoder = &mut ctx.decoder; + + let length = decoder.decode_length()?; + let seq_no = decoder.decode_int_1(); + + let packet_header = decoder.decode_int_1(); + + if packet_header != 0xFE { + panic!("Packet header is not 0xFE for ErrPacket"); + } + + let warning_count = decoder.decode_int_2(); + let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_2()); + + Ok(EofPacket { length, seq_no, warning_count, status }) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{__bytes_builder, connection::Connection}; + use bytes::Bytes; + use mason_core::ConnectOptions; + + #[runtime::test] + async fn it_decodes_eof_packet() -> Result<(), Error> { + let mut conn = Connection::establish(ConnectOptions { + host: "127.0.0.1", + port: 3306, + user: Some("root"), + database: None, + password: None, + }) + .await?; + + #[rustfmt::skip] + let buf = __bytes_builder!( + // int<3> length + 1u8, 0u8, 0u8, + // int<1> seq_no + 1u8, + // int<1> 0xfe : EOF header + 0xFE_u8, + // int<2> warning count + 0u8, 0u8, + // int<2> server status + 1u8, 1u8 + ); + + let buf = Bytes::from_static(b"\x01\0\0\x01\xFE\x00\x00\x01\x00"); + let _message = EofPacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?; + + Ok(()) + } +} diff --git a/mason-mariadb/src/protocol/packets/err.rs b/mason-mariadb/src/protocol/packets/err.rs index 200421e4..e9ccdfba 100644 --- a/mason-mariadb/src/protocol/packets/err.rs +++ b/mason-mariadb/src/protocol/packets/err.rs @@ -1,7 +1,10 @@ -use std::convert::TryFrom; +use super::super::{ + deserialize::{DeContext, Deserialize}, + error_codes::ErrorCode, +}; use bytes::Bytes; use failure::Error; -use super::super::{deserialize::Deserialize, deserialize::DeContext, error_codes::ErrorCode}; +use std::convert::TryFrom; #[derive(Default, Debug)] pub struct ErrPacket { @@ -72,8 +75,9 @@ impl Deserialize for ErrPacket { #[cfg(test)] mod test { - use bytes::Bytes; use super::*; + use crate::{__bytes_builder, connection::Connection, protocol::decode::Decoder}; + use bytes::Bytes; use mason_core::ConnectOptions; #[runtime::test] @@ -84,10 +88,39 @@ mod test { user: Some("root"), database: None, password: None, - }).await?; + }) + .await?; - let buf = Bytes::from(b"!\0\0\x01\xff\x84\x04#08S01Got packets out of order".to_vec()); - let _message = ErrPacket::deserialize(&mut conn, &mut Decoder::new(&buf))?; + #[rustfmt::skip] + let buf = __bytes_builder!( + // int<3> length + 1u8, 0u8, 0u8, + // int<1> seq_no + 1u8, + // int<1> 0xfe : EOF header + 0xFF_u8, + // int<2> error code + 0x84_u8, 0x04_u8, + // if (errorcode == 0xFFFF) /* progress reporting */ { + // int<1> stage + // int<1> max_stage + // int<3> progress + // string progress_info + // } else { + // if (next byte = '#') { + // string<1> sql state marker '#' + b"#", + // string<5>sql state + b"08S01", + // string error message + b"Got packets out of order" + // } else { + // string error message + // } + // } + ); + + let _message = ErrPacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?; Ok(()) } diff --git a/mason-mariadb/src/protocol/packets/handshake_response.rs b/mason-mariadb/src/protocol/packets/handshake_response.rs index b4998072..f6c10f21 100644 --- a/mason-mariadb/src/protocol/packets/handshake_response.rs +++ b/mason-mariadb/src/protocol/packets/handshake_response.rs @@ -28,7 +28,7 @@ impl Serialize for HandshakeResponsePacket { // Filler conn.encoder.encode_byte_fix(&Bytes::from_static(&[0u8; 19]), 19); - if !(conn.capabilities & Capabilities::CLIENT_MYSQL).is_empty() + if !(conn.context.capabilities & Capabilities::CLIENT_MYSQL).is_empty() && !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty() { if let Some(capabilities) = self.extended_capabilities { @@ -40,11 +40,11 @@ impl Serialize for HandshakeResponsePacket { conn.encoder.encode_string_null(&self.username); - if !(conn.capabilities & Capabilities::PLUGIN_AUTH_LENENC_CLIENT_DATA).is_empty() { + if !(conn.context.capabilities & Capabilities::PLUGIN_AUTH_LENENC_CLIENT_DATA).is_empty() { if let Some(auth_data) = &self.auth_data { conn.encoder.encode_string_lenenc(&auth_data); } - } else if !(conn.capabilities & Capabilities::SECURE_CONNECTION).is_empty() { + } else if !(conn.context.capabilities & Capabilities::SECURE_CONNECTION).is_empty() { if let Some(auth_response) = &self.auth_response { conn.encoder.encode_int_1(self.auth_response_len.unwrap()); conn.encoder @@ -54,21 +54,21 @@ impl Serialize for HandshakeResponsePacket { conn.encoder.encode_int_1(0); } - if !(conn.capabilities & Capabilities::CONNECT_WITH_DB).is_empty() { + if !(conn.context.capabilities & Capabilities::CONNECT_WITH_DB).is_empty() { if let Some(database) = &self.database { // string conn.encoder.encode_string_null(&database); } } - if !(conn.capabilities & Capabilities::PLUGIN_AUTH).is_empty() { + if !(conn.context.capabilities & Capabilities::PLUGIN_AUTH).is_empty() { if let Some(auth_plugin_name) = &self.auth_plugin_name { // string conn.encoder.encode_string_null(&auth_plugin_name); } } - if !(conn.capabilities & Capabilities::CONNECT_ATTRS).is_empty() { + if !(conn.context.capabilities & Capabilities::CONNECT_ATTRS).is_empty() { if let (Some(conn_attr_len), Some(conn_attr)) = (&self.conn_attr_len, &self.conn_attr) { // int conn.encoder.encode_int_lenenc(Some(conn_attr_len)); diff --git a/mason-mariadb/src/protocol/packets/initial.rs b/mason-mariadb/src/protocol/packets/initial.rs index c9272c12..d03ddea9 100644 --- a/mason-mariadb/src/protocol/packets/initial.rs +++ b/mason-mariadb/src/protocol/packets/initial.rs @@ -1,5 +1,5 @@ use super::super::{ - deserialize::{Deserialize, DeContext}, + deserialize::{DeContext, Deserialize}, types::{Capabilities, ServerStatusFlag}, }; use bytes::Bytes; @@ -102,41 +102,84 @@ impl Deserialize for InitialHandshakePacket { #[cfg(test)] mod test { use super::*; + use crate::{__bytes_builder, connection::Connection, protocol::decode::Decoder}; use bytes::BytesMut; use mason_core::ConnectOptions; #[runtime::test] async fn it_decodes_initial_handshake_packet() -> Result<(), Error> { - let mut conn = Connection::establish(ConnectOptions { + let mut conn = crate::connection::Connection::establish(ConnectOptions { host: "127.0.0.1", port: 3306, user: Some("root"), database: None, password: None, - }).await?; + }) + .await?; - let buf = BytesMut::from(b"\ - n\0\0\ - \0\ - \n\ - 5.5.5-10.4.6-MariaDB-1:10.4.6+maria~bionic\0\ - \x13\0\0\0\ - ?~~|vZAu\ - \0\ - \xfe\xf7\ - \x08\ - \x02\0\ - \xff\x81\ - \x15\ - \0\0\0\0\0\0\ - \x07\0\0\0\ - JQ8cihP4Q}Dx\ - \0\ - mysql_native_password\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0" - .to_vec(), + let buf = __bytes_builder!( + // int<3> length + 1u8, + 0u8, + 0u8, + // int<1> seq_no + 0u8, + //int<1> protocol version + 10u8, + //string server version (MariaDB server version is by default prefixed by "5.5.5-") + b"5.5.5-10.4.6-MariaDB-1:10.4.6+maria~bionic\0", + //int<4> connection id + 13u8, + 0u8, + 0u8, + 0u8, + //string<8> scramble 1st part (authentication seed) + b"?~~|vZAu", + //string<1> reserved byte + 0u8, + //int<2> server capabilities (1st part) + 0xFEu8, + 0xF7u8, + //int<1> server default collation + 8u8, + //int<2> status flags + 2u8, + 0u8, + //int<2> server capabilities (2nd part) + 0xFF_u8, + 0x81_u8, + //if (server_capabilities & PLUGIN_AUTH) + // int<1> plugin data length + //else + // int<1> 0x00 + 15u8, + //string<6> filler + 0u8, + 0u8, + 0u8, + 0u8, + 0u8, + 0u8, + //if (server_capabilities & CLIENT_MYSQL) + // string<4> filler + //else + // int<4> server capabilities 3rd part . MariaDB specific flags /* MariaDB 10.2 or later */ + 7u8, + 0u8, + 0u8, + 0u8, + //if (server_capabilities & CLIENT_SECURE_CONNECTION) + // string scramble 2nd part . Length = max(12, plugin data length - 9) + b"JQ8cihP4Q}Dx", + // string<1> reserved byte + 0u8, + //if (server_capabilities & PLUGIN_AUTH) + // string authentication plugin name + b"mysql_native_password\0" ); - let _message = InitialHandshakePacket::deserialize(&mut conn, &mut Decoder::new(&buf.freeze()))?; + let _message = + InitialHandshakePacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?; Ok(()) } diff --git a/mason-mariadb/src/protocol/packets/mod.rs b/mason-mariadb/src/protocol/packets/mod.rs index 66950464..f20654b5 100644 --- a/mason-mariadb/src/protocol/packets/mod.rs +++ b/mason-mariadb/src/protocol/packets/mod.rs @@ -12,9 +12,11 @@ pub mod com_set_option; pub mod com_shutdown; pub mod com_sleep; pub mod com_statistics; +pub mod eof; pub mod err; pub mod handshake_response; pub mod initial; pub mod ok; +pub mod packet_header; pub mod result_set; pub mod ssl_request; diff --git a/mason-mariadb/src/protocol/packets/ok.rs b/mason-mariadb/src/protocol/packets/ok.rs index b6fbc08c..e8961ba9 100644 --- a/mason-mariadb/src/protocol/packets/ok.rs +++ b/mason-mariadb/src/protocol/packets/ok.rs @@ -1,7 +1,9 @@ -use super::super::{deserialize::Deserialize, deserialize::DeContext, types::ServerStatusFlag}; +use super::super::{ + deserialize::{DeContext, Deserialize}, + types::ServerStatusFlag, +}; use bytes::Bytes; -use failure::Error; -use failure::err_msg; +use failure::{err_msg, Error}; #[derive(Default, Debug)] pub struct OkPacket { @@ -57,7 +59,7 @@ impl Deserialize for OkPacket { #[cfg(test)] mod test { use super::*; - use bytes::BytesMut; + use crate::{__bytes_builder, connection::Connection, protocol::decode::Decoder}; use mason_core::ConnectOptions; #[runtime::test] @@ -68,22 +70,38 @@ mod test { user: Some("root"), database: None, password: None, - }).await?; + }) + .await?; - let buf = BytesMut::from(b"\ - \x0F\x00\x00\ - \x01\ - \x00\ - \xFB\ - \xFB\ - \x01\x01\ - \x00\x00\ - info\ - " - .to_vec(), + #[rustfmt::skip] + let buf = __bytes_builder!( + // length + 0x0F_u8, 0x0_u8, 0x0_u8, + // seq_no + 0x01_u8, + // 0x00 : OK_Packet header or (0xFE if CLIENT_DEPRECATE_EOF is set) + 0x00_u8, + // int affected rows + 0xFB_u8, + // int last insert id + 0xFB_u8, + // int<2> server status + 0x01_u8, 0x01_u8, + // int<2> warning count + 0x0_u8, 0x0_u8, + // if session_tracking_supported (see CLIENT_SESSION_TRACK) { + // string info + // if (status flags & SERVER_SESSION_STATE_CHANGED) { + // string session state info + // string value of variable + // } + // } else { + // string info + b"info" + // } ); - let message = OkPacket::deserialize(&mut conn, &mut Decoder::new(&buf.freeze()))?; + let message = OkPacket::deserialize(&mut DeContext::new(&mut conn.context, &buf))?; assert_eq!(message.affected_rows, None); assert_eq!(message.last_insert_id, None); diff --git a/mason-mariadb/src/protocol/packets/packet_header.rs b/mason-mariadb/src/protocol/packets/packet_header.rs new file mode 100644 index 00000000..976460c7 --- /dev/null +++ b/mason-mariadb/src/protocol/packets/packet_header.rs @@ -0,0 +1,4 @@ +pub struct PacketHeader { + pub length: u32, + pub seq_no: u8, +} diff --git a/mason-mariadb/src/protocol/packets/result_set.rs b/mason-mariadb/src/protocol/packets/result_set.rs index 0691a14e..6006d38f 100644 --- a/mason-mariadb/src/protocol/packets/result_set.rs +++ b/mason-mariadb/src/protocol/packets/result_set.rs @@ -1,9 +1,9 @@ -use bytes::Bytes; -use failure::Error; use super::super::{ - deserialize::{Deserialize, DeContext}, + deserialize::{DeContext, Deserialize}, packets::{column::ColumnPacket, column_def::ColumnDefPacket}, }; +use bytes::Bytes; +use failure::Error; #[derive(Debug, Default)] pub struct ResultSet { @@ -21,6 +21,18 @@ impl Deserialize for ResultSet { let column_packet = ColumnPacket::deserialize(ctx)?; + match ctx.decoder.decode_int_1() { + // 0x00 -> PACKET_OK + 0x00 => {} + + // 0xFF -> PACKET_ERR + 0xFF => {} + + _ => { + panic!("Didn't receive 0x00 nor 0xFF"); + } + } + let columns = if let Some(columns) = column_packet.columns { (0..columns) .map(|_| ColumnDefPacket::deserialize(ctx)) @@ -33,68 +45,164 @@ impl Deserialize for ResultSet { let mut rows = Vec::new(); - for _ in 0.. { + loop { // if end of buffer stop if ctx.decoder.eof() { break; } - // Decode each column as string - rows.push( - (0..column_packet.columns.unwrap_or(0)) - .map(|_| ctx.decoder.decode_string_lenenc()) - .collect::>(), - ) + let columns = if let Some(columns) = column_packet.columns { + (0..columns).map(|_| ctx.decoder.decode_string_lenenc()).collect::>() + } else { + Vec::new() + }; } - Ok(ResultSet { - length, - seq_no, - column_packet, - columns, - rows, - }) + Ok(ResultSet { length, seq_no, column_packet, columns, rows }) } } #[cfg(test)] mod test { - use bytes::Bytes; use super::*; + use crate::{__bytes_builder, connection::Connection}; + use bytes::{BufMut, Bytes}; #[runtime::test] async fn it_decodes_result_set_packet() -> Result<(), Error> { - let buf = Bytes::from(b"\ - \0\0\0\x01\ - \x02\0\0\x02\xff\x02 - \x01\0\0a\ - \x01\0\0b\ - \x01\0\0c\ - \x01\0\0d\ - \x01\0\0e\ - \x01\0\0f\ - \xfc\x01\x01\ - \x01\x01\ - \x01\x01\x01\x01\ - \x00\ - \x00\x00\ - \x01\ - \0\0\ - \x01\0\0g\ - \x01\0\0h\ - \x01\0\0i\ - \x01\0\0j\ - \x01\0\0k\ - \x01\0\0l\ - \xfc\x01\x01\ - \x01\x01\ - \x01\x01\x01\x01\ - \x00\ - \x00\x00\ - \x01\ - \0\0 - ".to_vec()); -// let message = ColumnDefPacket::deserialize(&mut Connection::mock().await, &mut Decoder::new(&buf))?; + let mut conn = Connection::establish(mason_core::ConnectOptions { + host: "127.0.0.1", + port: 3306, + user: Some("root"), + database: None, + password: None, + }) + .await?; + + // conn.query("SELECT * FROM users"); + + #[rustfmt::skip] + let buf = __bytes_builder!( + // ------------------- // + // Column Count packet // + // ------------------- // + + // length + 0x02_u8, 0x0_u8, 0x0_u8, + // seq_no + 0x02_u8, + // int Column count packet + 0x02_u8, 0x00_u8, + + // ------------------------ // + // Column Definition packet // + // ------------------------ // + + // length + 0x02_u8, 0x0_u8, 0x0_u8, + // seq_no + 0x02_u8, + // string catalog (always 'def') + 0x03_u8, 0x0_u8, 0x0_u8, b"def", + // string schema + 0x01_u8, 0x0_u8, 0x0_u8, b'b', + // string table alias + 0x01_u8, 0x0_u8, 0x0_u8, b'c', + // string table + 0x01_u8, 0x0_u8, 0x0_u8, b'd', + // string column alias + 0x01_u8, 0x0_u8, 0x0_u8, b'e', + // string column + 0x01_u8, 0x0_u8, 0x0_u8, b'f', + // int length of fixed fields (=0xC) + 0xfc_u8, 0x01_u8, 0x01_u8, + // int<2> character set number + 0x01_u8, 0x01_u8, + // int<4> max. column size + 0x01_u8, 0x01_u8, 0x01_u8, 0x01_u8, + // int<1> Field types + 0x00_u8, + // int<2> Field detail flag + 0x00_u8, 0x00_u8, + // int<1> decimals + 0x01_u8, + // int<2> - unused - + 0x0_u8, 0x0_u8, + + // ------------------------ // + // Column Definition packet // + // ------------------------ // + + // length + 0x02_u8, 0x0_u8, 0x0_u8, + // seq_no + 0x02_u8, + // string catalog (always 'def') + 0x03_u8, 0x0_u8, 0x0_u8, b"def", + // string schema + 0x01_u8, 0x0_u8, 0x0_u8, b'b', + // string table alias + 0x01_u8, 0x0_u8, 0x0_u8, b'c', + // string table + 0x01_u8, 0x0_u8, 0x0_u8, b'd', + // string column alias + 0x01_u8, 0x0_u8, 0x0_u8, b'e', + // string column + 0x01_u8, 0x0_u8, 0x0_u8, b'f', + // int length of fixed fields (=0xC) + 0xfc_u8, 0x01_u8, 0x01_u8, + // int<2> character set number + 0x01_u8, 0x01_u8, + // int<4> max. column size + 0x01_u8, 0x01_u8, 0x01_u8, 0x01_u8, + // int<1> Field types + 0x00_u8, + // int<2> Field detail flag + 0x00_u8, 0x00_u8, + // int<1> decimals + 0x01_u8, + // int<2> - unused - + 0x0_u8, 0x00_u8, + + // ---------- // + // EOF Packet // + // ---------- // + + // length + 0x02_u8, 0x0_u8, 0x0_u8, + // seq_no + 0x02_u8, + // int<1> 0xfe : EOF header + 0xfe_u8, + // int<2> warning count + 0x0_u8, 0x0_u8, + // int<2> server status + 0x01_u8, 0x00_u8, + + // ------------------- // + // N Result Row Packet // + // ------------------- // + + // string column data + 0x01_u8, 0x0_u8, 0x0_u8, b'h', + // string column data + 0x01_u8, 0x0_u8, 0x0_u8, b'i', + + // ---------- // + // EOF Packet // + // ---------- // + + // length + 0x02_u8, 0x0_u8, 0x0_u8, + // seq_no + 0x02_u8, + // int<1> 0xfe : EOF header + 0xfe_u8, + // int<2> warning count + 0x0_u8, 0x0_u8, + // int<2> server status + 0x01_u8, 0x00_u8 + ); Ok(()) } diff --git a/mason-mariadb/src/protocol/packets/ssl_request.rs b/mason-mariadb/src/protocol/packets/ssl_request.rs index 165f5bdd..a64cae46 100644 --- a/mason-mariadb/src/protocol/packets/ssl_request.rs +++ b/mason-mariadb/src/protocol/packets/ssl_request.rs @@ -20,7 +20,7 @@ impl Serialize for SSLRequestPacket { // Filler conn.encoder.encode_byte_fix(&Bytes::from_static(&[0u8; 19]), 19); - if !(conn.capabilities & Capabilities::CLIENT_MYSQL).is_empty() + if !(conn.context.capabilities & Capabilities::CLIENT_MYSQL).is_empty() && !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty() { if let Some(capabilities) = self.extended_capabilities { diff --git a/mason-mariadb/src/protocol/server.rs b/mason-mariadb/src/protocol/server.rs index d70a5a34..c68d7512 100644 --- a/mason-mariadb/src/protocol/server.rs +++ b/mason-mariadb/src/protocol/server.rs @@ -18,14 +18,7 @@ pub enum Message { impl Message { pub fn deserialize(ctx: &mut DeContext) -> Result, Error> { let decoder = &mut ctx.decoder; - if decoder.buf.len() < 4 { - return Ok(None); - } - - let length = decoder.decode_length()?; - if decoder.buf.len() < (length + 4) as usize { - return Ok(None); - } + let _packet_header = decoder.peek_packet_header()?; let tag = decoder.buf[4];