From b731dbe90f7e801cff0ac8a8f4405e823522adbf Mon Sep 17 00:00:00 2001 From: Daniel Akhterov Date: Wed, 31 Jul 2019 20:40:42 -0700 Subject: [PATCH] WIP: Update next_packet to work more efficiently and correctly --- src/lib.rs | 2 +- src/mariadb/connection/establish.rs | 318 +++++++++--------- src/mariadb/connection/mod.rs | 167 +++++---- src/mariadb/mod.rs | 1 + src/mariadb/protocol/decode.rs | 40 +-- src/mariadb/protocol/deserialize.rs | 30 +- src/mariadb/protocol/encode.rs | 24 +- .../protocol/packets/binary/com_stmt_exec.rs | 66 ++-- .../packets/binary/com_stmt_prepare_ok.rs | 2 +- .../packets/binary/com_stmt_prepare_resp.rs | 56 +-- src/mariadb/protocol/packets/column.rs | 6 +- src/mariadb/protocol/packets/column_def.rs | 2 +- src/mariadb/protocol/packets/eof.rs | 2 +- src/mariadb/protocol/packets/err.rs | 2 +- src/mariadb/protocol/packets/initial.rs | 2 +- src/mariadb/protocol/packets/ok.rs | 2 +- src/mariadb/protocol/packets/packet_header.rs | 30 +- src/mariadb/protocol/packets/result_row.rs | 2 +- src/mariadb/protocol/packets/result_set.rs | 60 ++-- src/mariadb/protocol/types.rs | 2 +- 20 files changed, 430 insertions(+), 386 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 4b929d90..22f394e1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -#![feature(non_exhaustive, async_await)] +#![feature(non_exhaustive, async_await, async_closure)] #![cfg_attr(test, feature(test))] #![allow(clippy::needless_lifetimes)] // FIXME: Remove this once API has matured diff --git a/src/mariadb/connection/establish.rs b/src/mariadb/connection/establish.rs index 1b560c3e..55f160c4 100644 --- a/src/mariadb/connection/establish.rs +++ b/src/mariadb/connection/establish.rs @@ -1,5 +1,5 @@ use super::Connection; -use crate::mariadb::{DeContext, Deserialize, HandshakeResponsePacket, InitialHandshakePacket, Message, Capabilities, ComStmtExec, StmtExecFlag}; +use crate::mariadb::{ErrPacket, OkPacket, DeContext, Deserialize, HandshakeResponsePacket, InitialHandshakePacket, Message, Capabilities, ComStmtExec, StmtExecFlag}; use bytes::{BufMut, Bytes}; use failure::{err_msg, Error}; use crate::ConnectOptions; @@ -9,8 +9,8 @@ pub async fn establish<'a, 'b: 'a>( conn: &'a mut Connection, options: ConnectOptions<'b>, ) -> Result<(), Error> { - let buf = &conn.stream.next_bytes().await?; - let mut de_ctx = DeContext::new(&mut conn.context, &buf); + let buf = conn.stream.next_packet().await?; + let mut de_ctx = DeContext::new(&mut conn.context, buf); let initial = InitialHandshakePacket::deserialize(&mut de_ctx)?; de_ctx.ctx.capabilities = de_ctx.ctx.capabilities.bitand(initial.capabilities); @@ -26,167 +26,167 @@ pub async fn establish<'a, 'b: 'a>( conn.send(handshake).await?; - match conn.next().await? { - Some(Message::OkPacket(message)) => { - conn.context.seq_no = message.seq_no; - Ok(()) - } + let mut ctx = DeContext::new(&mut conn.context, conn.stream.next_packet().await?); - Some(Message::ErrPacket(message)) => Err(err_msg(format!("{:?}", message))), - - Some(message) => { - panic!("Did not receive OkPacket nor ErrPacket. Received: {:?}", message); - } - - None => { - panic!("Did not receive packet"); + if let Some(tag) = ctx.decoder.peek_tag() { + match tag { + 0xFF => { + return Err(ErrPacket::deserialize(&mut ctx)?.into()); + }, + 0x00 => { + OkPacket::deserialize(&mut ctx)?; + }, + _ => failure::bail!("Did not receive an ErrPacket nor OkPacket when one is expected"), } + } else { + failure::bail!("Did not receive an appropriately tagged packet when one is expected"); } + + Ok(()) } #[cfg(test)] mod test { -// use super::*; -// use failure::Error; -// use crate::mariadb::{ComStmtPrepareResp, FieldType, ResultSet, ComStmtFetch}; -// -// #[runtime::test] -// async fn it_can_connect() -> Result<(), Error> { -// let mut conn = Connection::establish(ConnectOptions { -// host: "127.0.0.1", -// port: 3306, -// user: Some("root"), -// database: None, -// password: None, -// }) -// .await?; -// -// Ok(()) -// } -// -// #[runtime::test] -// async fn it_can_ping() -> Result<(), Error> { -// let mut conn = Connection::establish(ConnectOptions { -// host: "127.0.0.1", -// port: 3306, -// user: Some("root"), -// database: None, -// password: None, -// }).await?; -// -// conn.ping().await?; -// -// Ok(()) -// } -// -// #[runtime::test] -// async fn it_can_select_db() -> Result<(), Error> { -// let mut conn = Connection::establish(ConnectOptions { -// host: "127.0.0.1", -// port: 3306, -// user: Some("root"), -// database: None, -// password: None, -// }).await?; -// -// conn.select_db("test").await?; -// -// Ok(()) -// } -// -// #[runtime::test] -// async fn it_can_query() -> Result<(), Error> { -// let mut conn = Connection::establish(ConnectOptions { -// host: "127.0.0.1", -// port: 3306, -// user: Some("root"), -// database: None, -// password: None, -// }).await?; -// -// conn.select_db("test").await?; -// -// conn.query("SELECT * FROM users").await?; -// -// Ok(()) -// } -// -// #[runtime::test] -// async fn it_can_prepare() -> Result<(), Error> { -// let mut conn = Connection::establish(ConnectOptions { -// host: "127.0.0.1", -// port: 3306, -// user: Some("root"), -// database: None, -// password: None, -// }).await?; -// -// conn.select_db("test").await?; -// -// conn.prepare("SELECT * FROM users WHERE username = ?").await?; -// -// Ok(()) -// } -// -// #[runtime::test] -// async fn it_can_execute_prepared() -> Result<(), Error> { -// let mut conn = Connection::establish(ConnectOptions { -// host: "127.0.0.1", -// port: 3306, -// user: Some("root"), -// database: None, -// password: None, -// }).await?; -// -// conn.select_db("test").await?; -// -// let mut prepared = conn.prepare("SELECT * FROM users WHERE username = ?;").await?; -// -// println!("{:?}", prepared); -// -// if let Some(param_defs) = &mut prepared.param_defs { -// param_defs[0].field_type = FieldType::MysqlTypeBlob; -// } -// -// let exec = ComStmtExec { -// stmt_id: prepared.ok.stmt_id, -// flags: StmtExecFlag::CursorForUpdate, -// params: Some(vec![Some(Bytes::from_static(b"'daniel'"))]), -// param_defs: prepared.param_defs, -// }; -// -// -// conn.send(exec).await?; -// -// let buf = conn.stream.next_bytes().await?; -// -// println!("{:?}", buf); -// -// println!("{:?}", ResultSet::deserialize(&mut DeContext::new(&mut conn.context, &buf))?.rows); -// -// let fetch = ComStmtFetch { -// stmt_id: prepared.ok.stmt_id, -// rows: 1, -// }; -// -// conn.send(exec).await?; -// -// Ok(()) -// } -// -// #[runtime::test] -// async fn it_does_not_connect() -> Result<(), Error> { -// match Connection::establish(ConnectOptions { -// host: "127.0.0.1", -// port: 3306, -// user: Some("roote"), -// database: None, -// password: None, -// }) -// .await -// { -// Ok(_) => Err(err_msg("Bad username still worked?")), -// Err(_) => Ok(()), -// } -// } + use super::*; + use failure::Error; + use crate::mariadb::{ComStmtPrepareResp, FieldType, ResultSet, ComStmtFetch}; + + #[runtime::test] + async fn it_can_connect() -> Result<(), Error> { + let mut conn = Connection::establish(ConnectOptions { + host: "127.0.0.1", + port: 3306, + user: Some("root"), + database: None, + password: None, + }) + .await?; + + Ok(()) + } + + #[runtime::test] + async fn it_can_ping() -> Result<(), Error> { + let mut conn = Connection::establish(ConnectOptions { + host: "127.0.0.1", + port: 3306, + user: Some("root"), + database: None, + password: None, + }).await?; + + conn.ping().await?; + + Ok(()) + } + + #[runtime::test] + async fn it_can_select_db() -> Result<(), Error> { + let mut conn = Connection::establish(ConnectOptions { + host: "127.0.0.1", + port: 3306, + user: Some("root"), + database: None, + password: None, + }).await?; + + conn.select_db("test").await?; + + Ok(()) + } + + #[runtime::test] + async fn it_can_query() -> Result<(), Error> { + let mut conn = Connection::establish(ConnectOptions { + host: "127.0.0.1", + port: 3306, + user: Some("root"), + database: None, + password: None, + }).await?; + + conn.select_db("test").await?; + + conn.query("SELECT * FROM users").await?; + + Ok(()) + } + + #[runtime::test] + async fn it_can_prepare() -> Result<(), Error> { + let mut conn = Connection::establish(ConnectOptions { + host: "127.0.0.1", + port: 3306, + user: Some("root"), + database: None, + password: None, + }).await?; + + conn.select_db("test").await?; + + conn.prepare("SELECT * FROM users WHERE username = ?").await?; + + Ok(()) + } + + #[runtime::test] + async fn it_can_execute_prepared() -> Result<(), Error> { + let mut conn = Connection::establish(ConnectOptions { + host: "127.0.0.1", + port: 3306, + user: Some("root"), + database: None, + password: None, + }).await?; + + conn.select_db("test").await?; + + let mut prepared = conn.prepare("SELECT username FROM users WHERE username = ?;").await?; + + println!("{:?}", prepared); + + if let Some(param_defs) = &mut prepared.param_defs { + param_defs[0].field_type = FieldType::MysqlTypeBlob; + } + + let exec = ComStmtExec { + stmt_id: -1, + flags: StmtExecFlag::ReadOnly, + params: Some(vec![Some(Bytes::from_static(b"daniel"))]), + param_defs: prepared.param_defs, + }; + + println!("{:?}", ResultSet::deserialize(DeContext::with_stream(&mut conn.context, &mut conn.stream)).await?); + + let fetch = ComStmtFetch { + stmt_id: -1, + rows: 10, + }; + + conn.send(fetch).await?; + + let buf = conn.stream.next_packet().await?; + + println!("{:?}", buf); + +// println!("{:?}", ResultSet::deserialize(&mut DeContext::new(&mut conn.context, &buf))?); + + Ok(()) + } + + #[runtime::test] + async fn it_does_not_connect() -> Result<(), Error> { + match Connection::establish(ConnectOptions { + host: "127.0.0.1", + port: 3306, + user: Some("roote"), + database: None, + password: None, + }) + .await + { + Ok(_) => Err(err_msg("Bad username still worked?")), + Err(_) => Ok(()), + } + } } diff --git a/src/mariadb/connection/mod.rs b/src/mariadb/connection/mod.rs index dd5a9db4..c34c6d80 100644 --- a/src/mariadb/connection/mod.rs +++ b/src/mariadb/connection/mod.rs @@ -6,7 +6,8 @@ use futures::{ prelude::*, }; use runtime::net::TcpStream; -use crate::{ConnectOptions, mariadb::protocol::{DeContext, Deserialize, Encoder, ComInitDb, ComPing, ComQuery, ComQuit, OkPacket, Serialize, Message, Capabilities, ServerStatusFlag, ComStmtPrepare, ComStmtPrepareResp, ResultSet, ErrPacket}}; +use core::convert::TryFrom; +use crate::{ConnectOptions, mariadb::{protocol::encode, PacketHeader, Decoder, DeContext, Deserialize, Encoder, ComInitDb, ComPing, ComQuery, ComQuit, OkPacket, Serialize, Message, Capabilities, ServerStatusFlag, ComStmtPrepare, ComStmtPrepareResp, ResultSet, ErrPacket}}; mod establish; @@ -103,8 +104,8 @@ impl Connection { pub async fn query<'a>(&'a mut self, sql_statement: &'a str) -> Result, Error> { self.send(ComQuery { sql_statement: bytes::Bytes::from(sql_statement) }).await?; - let buf = self.stream.next_bytes().await?; - let mut ctx = DeContext::new(&mut self.context, &buf); + let mut ctx = DeContext::with_stream(&mut self.context, &mut self.stream); + ctx.next_packet().await?; if let Some(tag) = ctx.decoder.peek_tag() { match tag { 0xFF => Err(ErrPacket::deserialize(&mut ctx)?.into()), @@ -113,7 +114,9 @@ impl Connection { Ok(None) }, 0xFB => unimplemented!(), - _ => Ok(Some(ResultSet::deserialize(&mut ctx)?)) + _ => { + Ok(Some(ResultSet::deserialize(ctx).await?)) + } } } else { panic!("Tag not found in result packet"); @@ -125,17 +128,19 @@ impl Connection { self.send(ComInitDb { schema_name: bytes::Bytes::from(db) }).await?; - match self.next().await? { - Some(Message::OkPacket(_)) => {}, - Some(message @ Message::ErrPacket(_)) => { - failure::bail!("Received an ErrPacket packet: {:?}", message); - }, - Some(message) => { - failure::bail!("Received an unexpected packet type: {:?}", message); - } - None => { - failure::bail!("Did not receive a packet when one was expected"); + let mut ctx = DeContext::new(&mut self.context, self.stream.next_packet().await?); + if let Some(tag) = ctx.decoder.peek_tag() { + match tag { + 0xFF => { + ErrPacket::deserialize(&mut ctx)?; + }, + 0x00 => { + OkPacket::deserialize(&mut ctx)?; + }, + _ => failure::bail!("Did not receive an ErrPacket nor OkPacket when one was expected"), } + } else { + failure::bail!("No tag found"); } Ok(()) @@ -145,8 +150,7 @@ impl Connection { self.send(ComPing()).await?; // Ping response must be an OkPacket - let buf = self.stream.next_bytes().await?; - OkPacket::deserialize(&mut DeContext::new(&mut self.context, &buf))?; + OkPacket::deserialize(&mut DeContext::new(&mut self.context, self.stream.next_packet().await?))?; Ok(()) } @@ -156,110 +160,95 @@ impl Connection { statement: Bytes::from(query), }).await?; - let buf = self.stream.next_bytes().await?; - - ComStmtPrepareResp::deserialize(&mut DeContext::new(&mut self.context, &buf)) - } - - pub async fn next(&mut self) -> Result, Error> { - let mut rbuf = BytesMut::new(); - let mut len = 0; - - loop { - if len == rbuf.len() { - rbuf.reserve(32); - - unsafe { - // Set length to the capacity and efficiently - // zero-out the memory - rbuf.set_len(rbuf.capacity()); - self.stream.inner.initializer().initialize(&mut rbuf[len..]); - } - } - - let bytes_read = self.stream.inner.read(&mut rbuf[len..]).await?; - - if bytes_read > 0 { - len += bytes_read; - } else { - // Read 0 bytes from the server; end-of-stream - break; - } - - while len > 0 { - let size = rbuf.len(); - let message = Message::deserialize(&mut DeContext::new( - &mut self.context, - &rbuf.as_ref().into(), - ))?; - len -= size - rbuf.len(); - - match message { - message @ Some(_) => return Ok(message), - // Did not receive enough bytes to - // deserialize a complete message - None => break, - } - } - } - - Ok(None) +// let mut ctx = DeContext::with_stream(&mut self.context, &mut self.stream); +// ctx.next_packet().await?; +// ComStmtPrepareResp::deserialize(&mut ctx) + Ok(ComStmtPrepareResp::default()) } } pub struct Framed { inner: TcpStream, - readable: bool, - eof: bool, - buffer: BytesMut, + buf: BytesMut, } impl Framed { fn new(stream: TcpStream) -> Self { Self { - readable: false, - eof: false, inner: stream, - buffer: BytesMut::with_capacity(8 * 1024), + buf: BytesMut::with_capacity(8 * 1024), } } - pub async fn next_bytes(&mut self) -> Result { + pub async fn next_packet(&mut self) -> Result { let mut rbuf = BytesMut::new(); - let mut len = 0; - let mut packet_len: u32 = 0; + let mut len = 0usize; + let mut packet_headers: Vec = Vec::new(); loop { - if len == rbuf.len() { - rbuf.reserve(20000); + if let Some(packet_header) = packet_headers.last() { + if packet_header.combined_length() > rbuf.len() { + let reserve = packet_header.combined_length() - rbuf.len(); + rbuf.reserve(reserve); + + unsafe { + rbuf.set_len(rbuf.capacity()); + self.inner.initializer().initialize(&mut rbuf[len..]); + } + } + } else if rbuf.len() == len { + rbuf.reserve(32); unsafe { - // Set length to the capacity and efficiently - // zero-out the memory rbuf.set_len(rbuf.capacity()); self.inner.initializer().initialize(&mut rbuf[len..]); } } - let bytes_read = self.inner.read(&mut rbuf[len..]).await?; + // If we have a packet_header and the amount of currently read bytes (len) is less than + // the specified length inside packet_header, then we can continue reading to rbuf; but + // only up until packet_header.length. + // Else if the total number of bytes read is equal to packet_header then we will + // return rbuf as it should contain the entire packet. + // Else we read too many bytes -- which shouldn't happen -- and will return an error. + let bytes_read; + + if let Some(packet_header) = packet_headers.last() { + if packet_header.combined_length() > len { + bytes_read = self.inner.read(&mut rbuf[len..packet_header.combined_length()]).await?; + } else { + return Ok(rbuf.freeze()); + } + } else { + // Only read header to make sure that we dont' read the next packets buffer. + bytes_read = self.inner.read(&mut rbuf[len..len + 4]).await?; + } if bytes_read > 0 { len += bytes_read; + // If we have read less than 4 bytes, and we don't already have a packet_header + // we must try to read again. The packet_header is always present and is 4 bytes long. + if bytes_read < 4 && packet_headers.len() == 0 { + continue; + } } else { // Read 0 bytes from the server; end-of-stream - return Ok(Bytes::new()); - } - - if len > 0 && packet_len == 0 { - packet_len = LittleEndian::read_u24(&rbuf[0..]); - } - - // Loop until the length of the buffer is the length of the packet - if packet_len as usize > len { - continue; - } else { return Ok(rbuf.freeze()); } + + // If we don't have a packet header or the last packet header had a length of + // 0xFF_FF_FF (the max possible length); then we must continue receiving packets + // because the entire message hasn't been received. + // After this operation we know that packet_headers.last() *SHOULD* always return valid data, + // so the the use of packet_headers.last().unwrap() is allowed. + // TODO: Stitch packets together by removing the length and seq_no from in-between packet definitions. + if let Some(packet_header) = packet_headers.last() { + if packet_header.length as usize == encode::U24_MAX { + packet_headers.push(PacketHeader::try_from(&rbuf[0..])?); + } + } else { + packet_headers.push(PacketHeader::try_from(&rbuf[0..])?); + } } } } diff --git a/src/mariadb/mod.rs b/src/mariadb/mod.rs index e7e8702a..eca6c743 100644 --- a/src/mariadb/mod.rs +++ b/src/mariadb/mod.rs @@ -4,6 +4,7 @@ pub mod protocol; // Re-export all the things pub use connection::ConnContext; pub use connection::Connection; +pub use connection::Framed; pub use protocol::AuthenticationSwitchRequestPacket; pub use protocol::ColumnPacket; pub use protocol::ColumnDefPacket; diff --git a/src/mariadb/protocol/decode.rs b/src/mariadb/protocol/decode.rs index f82df65c..ab638d9b 100644 --- a/src/mariadb/protocol/decode.rs +++ b/src/mariadb/protocol/decode.rs @@ -58,14 +58,14 @@ use super::packets::packet_header::PacketHeader; // - Byte 7 is the minutes (0 if DATE type) (0-59) // - Byte 8 is the seconds (0 if DATE type) (0-59) // - Bytes 10-13 are the micro-seconds on 4 bytes little-endian format (only if data-length is > 7) -pub struct Decoder<'a> { - pub buf: &'a Bytes, +pub struct Decoder { + pub buf: Bytes, pub index: usize, } -impl<'a> Decoder<'a> { +impl Decoder { // Create a new Decoder from an existing Bytes - pub fn new(buf: &'a Bytes) -> Self { + pub fn new(buf: Bytes) -> Self { Decoder { buf, index: 0 } } @@ -441,7 +441,7 @@ mod tests { #[test] fn it_decodes_int_lenenc_0x_fb() { let buf = __bytes_builder!(0xFB_u8); - let mut decoder = Decoder::new(&buf); + let mut decoder = Decoder::new(buf); let int = decoder.decode_int_lenenc_unsigned(); assert_eq!(int, None); @@ -451,7 +451,7 @@ mod tests { #[test] fn it_decodes_int_lenenc_0x_fc() { let buf =__bytes_builder!(0xFCu8, 1u8, 1u8); - let mut decoder = Decoder::new(&buf); + let mut decoder = Decoder::new(buf); let int = decoder.decode_int_lenenc_unsigned(); assert_eq!(int, Some(0x0101)); @@ -461,7 +461,7 @@ mod tests { #[test] fn it_decodes_int_lenenc_0x_fd() { let buf = __bytes_builder!(0xFDu8, 1u8, 1u8, 1u8); - let mut decoder = Decoder::new(&buf); + let mut decoder = Decoder::new(buf); let int = decoder.decode_int_lenenc_unsigned(); assert_eq!(int, Some(0x010101)); @@ -471,7 +471,7 @@ mod tests { #[test] fn it_decodes_int_lenenc_0x_fe() { let buf = __bytes_builder!(0xFE_u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8); - let mut decoder = Decoder::new(&buf); + let mut decoder = Decoder::new(buf); let int = decoder.decode_int_lenenc_unsigned(); assert_eq!(int, Some(0x0101010101010101)); @@ -481,7 +481,7 @@ mod tests { #[test] fn it_decodes_int_lenenc_0x_fa() { let buf = __bytes_builder!(0xFA_u8); - let mut decoder = Decoder::new(&buf); + let mut decoder = Decoder::new(buf); let int = decoder.decode_int_lenenc_unsigned(); assert_eq!(int, Some(0xFA)); @@ -491,7 +491,7 @@ mod tests { #[test] fn it_decodes_int_8() { let buf = __bytes_builder!(1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8); - let mut decoder = Decoder::new(&buf); + let mut decoder = Decoder::new(buf); let int: i64 = decoder.decode_int_i64(); assert_eq!(int, 0x0101010101010101); @@ -501,7 +501,7 @@ mod tests { #[test] fn it_decodes_int_4() { let buf = __bytes_builder!(1u8, 1u8, 1u8, 1u8); - let mut decoder = Decoder::new(&buf); + let mut decoder = Decoder::new(buf); let int: i32 = decoder.decode_int_i32(); assert_eq!(int, 0x01010101); @@ -511,7 +511,7 @@ mod tests { #[test] fn it_decodes_int_3() { let buf = __bytes_builder!(1u8, 1u8, 1u8); - let mut decoder = Decoder::new(&buf); + let mut decoder = Decoder::new(buf); let int: i32 = decoder.decode_int_i24(); assert_eq!(int, 0x010101); @@ -521,7 +521,7 @@ mod tests { #[test] fn it_decodes_int_2() { let buf = __bytes_builder!(1u8, 1u8); - let mut decoder = Decoder::new(&buf); + let mut decoder = Decoder::new(buf); let int: i16 = decoder.decode_int_i16(); assert_eq!(int, 0x0101); @@ -531,7 +531,7 @@ mod tests { #[test] fn it_decodes_int_1() { let buf = __bytes_builder!(1u8); - let mut decoder = Decoder::new(&buf); + let mut decoder = Decoder::new(buf); let int: u8 = decoder.decode_int_u8(); assert_eq!(int, 1u8); @@ -541,7 +541,7 @@ mod tests { #[test] fn it_decodes_string_lenenc() { let buf = __bytes_builder!(3u8, b"sup"); - let mut decoder = Decoder::new(&buf); + let mut decoder = Decoder::new(buf); let string: Bytes = decoder.decode_string_lenenc(); assert_eq!(string[..], b"sup"[..]); @@ -552,7 +552,7 @@ mod tests { #[test] fn it_decodes_string_fix() { let buf = __bytes_builder!(b"a"); - let mut decoder = Decoder::new(&buf); + let mut decoder = Decoder::new(buf); let string: Bytes = decoder.decode_string_fix(1); assert_eq!(&string[..], b"a"); @@ -563,7 +563,7 @@ mod tests { #[test] fn it_decodes_string_eof() { let buf = __bytes_builder!(b"a"); - let mut decoder = Decoder::new(&buf); + let mut decoder = Decoder::new(buf); let string: Bytes = decoder.decode_string_eof(None); assert_eq!(&string[..], b"a"); @@ -574,7 +574,7 @@ mod tests { #[test] fn it_decodes_string_null() -> Result<(), Error> { let buf = __bytes_builder!(b"random\0", 1u8); - let mut decoder = Decoder::new(&buf); + let mut decoder = Decoder::new(buf); let string: Bytes = decoder.decode_string_null()?; assert_eq!(&string[..], b"random"); @@ -589,7 +589,7 @@ mod tests { #[test] fn it_decodes_byte_fix() { let buf = __bytes_builder!(b"a"); - let mut decoder = Decoder::new(&buf); + let mut decoder = Decoder::new(buf); let string: Bytes = decoder.decode_byte_fix(1); assert_eq!(&string[..], b"a"); @@ -600,7 +600,7 @@ mod tests { #[test] fn it_decodes_byte_eof() { let buf = __bytes_builder!(b"a"); - let mut decoder = Decoder::new(&buf); + let mut decoder = Decoder::new(buf); let string: Bytes = decoder.decode_byte_eof(None); assert_eq!(&string[..], b"a"); diff --git a/src/mariadb/protocol/deserialize.rs b/src/mariadb/protocol/deserialize.rs index c6a3d71a..d9c4da02 100644 --- a/src/mariadb/protocol/deserialize.rs +++ b/src/mariadb/protocol/deserialize.rs @@ -1,4 +1,4 @@ -use crate::mariadb::{Decoder, ConnContext, Connection, ColumnDefPacket}; +use crate::mariadb::{Framed, Decoder, ConnContext, Connection, ColumnDefPacket}; use bytes::Bytes; use failure::Error; @@ -8,14 +8,36 @@ use failure::Error; // Mainly used to simply to simplify number of parameters for deserializing functions pub struct DeContext<'a> { pub ctx: &'a mut ConnContext, - pub decoder: Decoder<'a>, + pub stream: Option<&'a mut Framed>, + pub decoder: Decoder, pub columns: Option, pub column_defs: Option>, } impl<'a> DeContext<'a> { - pub fn new(conn: &'a mut ConnContext, buf: &'a Bytes) -> Self { - DeContext { ctx: conn, decoder: Decoder::new(&buf), columns: None , column_defs: None } + pub fn new(conn: &'a mut ConnContext, buf: Bytes) -> Self { + DeContext { ctx: conn, stream: None, decoder: Decoder::new(buf), columns: None , column_defs: None } + } + + pub fn with_stream(conn: &'a mut ConnContext, stream: &'a mut Framed) -> Self { + DeContext { + ctx: conn, + stream: Some(stream), + decoder: Decoder::new(Bytes::new()), + columns: None , + column_defs: None + } + } + + pub async fn next_packet(&mut self) -> Result<(), failure::Error> { + if let Some(stream) = &mut self.stream { + println!("Called next packet"); + self.decoder = Decoder::new(stream.next_packet().await?); + + Ok(()) + } else { + failure::bail!("Calling next_packet on DeContext with no stream provided") + } } } diff --git a/src/mariadb/protocol/encode.rs b/src/mariadb/protocol/encode.rs index 410994cc..b30e2f39 100644 --- a/src/mariadb/protocol/encode.rs +++ b/src/mariadb/protocol/encode.rs @@ -2,7 +2,7 @@ use byteorder::{ByteOrder, LittleEndian}; use bytes::{BufMut, Bytes, BytesMut}; use crate::mariadb::FieldType; -const U24_MAX: usize = 0xFF_FF_FF; +pub const U24_MAX: usize = 0xFF_FF_FF; // A simple wrapper around a BytesMut to easily encode values pub struct Encoder { @@ -253,17 +253,17 @@ impl Encoder { FieldType::MysqlTypeTimestamp2 => unimplemented!(), FieldType::MysqlTypeDatetime2 => unimplemented!(), FieldType::MysqlTypeTime2 =>unimplemented!(), - FieldType::MysqlTypeJson => self.encode_string_lenenc(bytes), - FieldType::MysqlTypeNewdecimal => self.encode_string_lenenc(bytes), - FieldType::MysqlTypeEnum => self.encode_string_lenenc(bytes), - FieldType::MysqlTypeSet => self.encode_string_lenenc(bytes), - FieldType::MysqlTypeTinyBlob => self.encode_string_lenenc(bytes), - FieldType::MysqlTypeMediumBlob => self.encode_string_lenenc(bytes), - FieldType::MysqlTypeLongBlob => self.encode_string_lenenc(bytes), - FieldType::MysqlTypeBlob => self.encode_string_lenenc(bytes), - FieldType::MysqlTypeVarString => self.encode_string_lenenc(bytes), - FieldType::MysqlTypeString => self.encode_string_lenenc(bytes), - FieldType::MysqlTypeGeometry => self.encode_string_lenenc(bytes), + FieldType::MysqlTypeJson => self.encode_byte_lenenc(bytes), + FieldType::MysqlTypeNewdecimal => self.encode_byte_lenenc(bytes), + FieldType::MysqlTypeEnum => self.encode_byte_lenenc(bytes), + FieldType::MysqlTypeSet => self.encode_byte_lenenc(bytes), + FieldType::MysqlTypeTinyBlob => self.encode_byte_lenenc(bytes), + FieldType::MysqlTypeMediumBlob => self.encode_byte_lenenc(bytes), + FieldType::MysqlTypeLongBlob => self.encode_byte_lenenc(bytes), + FieldType::MysqlTypeBlob => self.encode_byte_lenenc(bytes), + FieldType::MysqlTypeVarString => self.encode_byte_lenenc(bytes), + FieldType::MysqlTypeString => self.encode_byte_lenenc(bytes), + FieldType::MysqlTypeGeometry => self.encode_byte_lenenc(bytes), } } } diff --git a/src/mariadb/protocol/packets/binary/com_stmt_exec.rs b/src/mariadb/protocol/packets/binary/com_stmt_exec.rs index 98b721ad..23a833dd 100644 --- a/src/mariadb/protocol/packets/binary/com_stmt_exec.rs +++ b/src/mariadb/protocol/packets/binary/com_stmt_exec.rs @@ -17,56 +17,48 @@ impl crate::mariadb::Serialize for ComStmtExec { encoder.encode_int_u8(super::BinaryProtocol::ComStmtExec.into()); encoder.encode_int_i32(self.stmt_id); encoder.encode_int_u8(self.flags as u8); - encoder.encode_int_u8(0); + encoder.encode_int_u8(1); - if let Some(params) = &self.params { - if let Some(param_defs) = &self.param_defs { - if params.len() != param_defs.len() { - failure::bail!("Unequal number of params and param definitions supplied"); - } - } + match (&self.params, &self.param_defs) { + (Some(params), Some(param_defs)) if params.len() > 0 => { + let null_bitmap_size = (params.len() + 7) / 8; + let mut shift_amount = 0u8; + let mut bitmap = vec![0u8]; + let send_type = 1u8; - let null_bitmap_size = (params.len() + 7) / 8; - let mut shift_amount = 0u8; - let mut bitmap = vec![0u8]; + // Generate NULL-bitmap from params + for param in params { + if param.is_some() { + let last_byte = bitmap.pop().unwrap(); + bitmap.push(last_byte & (1 << shift_amount)); + } - // Generate NULL-bitmap from params - for param in params { - if param.is_none() { - bitmap.push(bitmap.last().unwrap() & (1 << shift_amount)); + shift_amount = (shift_amount + 1) % 8; + + if shift_amount % 8 == 0 { + bitmap.push(0u8); + } } - shift_amount = (shift_amount + 1) % 8; + encoder.encode_byte_fix(&Bytes::from(bitmap), null_bitmap_size); + encoder.encode_int_u8(send_type); - if shift_amount % 8 == 0 { - bitmap.push(0u8); - } - } - - // Do not send the param types - encoder.encode_int_u8(if self.param_defs.is_some() { - 1u8 - } else { - 0u8 - }); - - if let Some(params_defs) = &self.param_defs { - for param in params_defs { - encoder.encode_int_u8(param.field_type as u8); - encoder.encode_int_u8(if (param.field_details & FieldDetailFlag::UNSIGNED).is_empty() { - 1u8 - } else { - 0u8 - }); + if send_type > 0 { + for param in param_defs { + encoder.encode_int_u8(param.field_type as u8); + encoder.encode_int_u8(0); + } } // Encode params for index in 0..params.len() { if let Some(bytes) = ¶ms[index] { - encoder.encode_param(&bytes, ¶ms_defs[index].field_type); + encoder.encode_param(&bytes, ¶m_defs[index].field_type); } } - } + }, + _ => {}, + } encoder.encode_length(); diff --git a/src/mariadb/protocol/packets/binary/com_stmt_prepare_ok.rs b/src/mariadb/protocol/packets/binary/com_stmt_prepare_ok.rs index 47b24aef..11292258 100644 --- a/src/mariadb/protocol/packets/binary/com_stmt_prepare_ok.rs +++ b/src/mariadb/protocol/packets/binary/com_stmt_prepare_ok.rs @@ -64,7 +64,7 @@ mod tests { ); let mut context = ConnContext::new(); - let mut ctx = DeContext::new(&mut context, &buf); + let mut ctx = DeContext::new(&mut context, buf); let message = ComStmtPrepareOk::deserialize(&mut ctx)?; diff --git a/src/mariadb/protocol/packets/binary/com_stmt_prepare_resp.rs b/src/mariadb/protocol/packets/binary/com_stmt_prepare_resp.rs index 31537b39..d876318c 100644 --- a/src/mariadb/protocol/packets/binary/com_stmt_prepare_resp.rs +++ b/src/mariadb/protocol/packets/binary/com_stmt_prepare_resp.rs @@ -1,4 +1,4 @@ -use crate::mariadb::{ComStmtPrepareOk, ColumnDefPacket, Capabilities, EofPacket}; +use crate::mariadb::{DeContext, Deserialize, ComStmtPrepareOk, ColumnDefPacket, Capabilities, EofPacket}; #[derive(Debug, Default)] pub struct ComStmtPrepareResp { @@ -7,18 +7,22 @@ pub struct ComStmtPrepareResp { pub res_columns: Option>, } -impl crate::mariadb::Deserialize for ComStmtPrepareResp { - fn deserialize(ctx: &mut crate::mariadb::DeContext) -> Result { - let ok = ComStmtPrepareOk::deserialize(ctx)?; +impl ComStmtPrepareResp { + pub async fn deserialize<'a>(mut ctx: DeContext<'a>) -> Result { + let ok = ComStmtPrepareOk::deserialize(&mut ctx)?; let param_defs = if ok.params > 0 { - let param_defs = (0..ok.params).map(|_| ColumnDefPacket::deserialize(ctx)) - .filter(Result::is_ok) - .map(Result::unwrap) - .collect::>(); + let mut param_defs = Vec::new(); + + for _ in 0..ok.params { + ctx.next_packet().await?; + param_defs.push(ColumnDefPacket::deserialize(&mut ctx)?); + } + + ctx.next_packet().await?; if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) { - EofPacket::deserialize(ctx)?; + EofPacket::deserialize(&mut ctx)?; } Some(param_defs) @@ -27,16 +31,20 @@ impl crate::mariadb::Deserialize for ComStmtPrepareResp { }; let res_columns = if ok.columns > 0 { - let param_defs = (0..ok.columns).map(|_| ColumnDefPacket::deserialize(ctx)) - .filter(Result::is_ok) - .map(Result::unwrap) - .collect::>(); + let mut res_columns = Vec::new(); - if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) { - EofPacket::deserialize(ctx)?; + for _ in 0..ok.columns { + ctx.next_packet().await?; + res_columns.push(ColumnDefPacket::deserialize(&mut ctx)?); } - Some(param_defs) + ctx.next_packet().await?; + + if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) { + EofPacket::deserialize(&mut ctx)?; + } + + Some(res_columns) } else { None }; @@ -54,8 +62,8 @@ mod test { use super::*; use crate::{__bytes_builder, ConnectOptions, mariadb::{ConnContext, DeContext, Deserialize}}; - #[test] - fn it_decodes_com_stmt_prepare_resp_eof() -> Result<(), failure::Error> { + #[runtime::test] + async fn it_decodes_com_stmt_prepare_resp_eof() -> Result<(), failure::Error> { #[rustfmt::skip] let buf = __bytes_builder!( // ---------------------------- // @@ -153,15 +161,15 @@ mod test { ); let mut context = ConnContext::with_eof(); - let mut ctx = DeContext::new(&mut context, &buf); + let mut ctx = DeContext::new(&mut context, buf); - let message = ComStmtPrepareResp::deserialize(&mut ctx)?; + let message = ComStmtPrepareResp::deserialize(ctx).await?; Ok(()) } - #[test] - fn it_decodes_com_stmt_prepare_resp() -> Result<(), failure::Error> { + #[runtime::test] + async fn it_decodes_com_stmt_prepare_resp() -> Result<(), failure::Error> { #[rustfmt::skip] let buf = __bytes_builder!( // ---------------------------- // @@ -287,9 +295,9 @@ mod test { ); let mut context = ConnContext::new(); - let mut ctx = DeContext::new(&mut context, &buf); + let mut ctx = DeContext::new(&mut context, buf); - let message = ComStmtPrepareResp::deserialize(&mut ctx)?; + let message = ComStmtPrepareResp::deserialize(ctx).await?; Ok(()) } diff --git a/src/mariadb/protocol/packets/column.rs b/src/mariadb/protocol/packets/column.rs index fedd1e98..6f76d44c 100644 --- a/src/mariadb/protocol/packets/column.rs +++ b/src/mariadb/protocol/packets/column.rs @@ -46,7 +46,7 @@ mod test { ); let mut context = ConnContext::new(); - let mut ctx = DeContext::new(&mut context, &buf); + let mut ctx = DeContext::new(&mut context, buf); let message = ColumnPacket::deserialize(&mut ctx)?; @@ -70,7 +70,7 @@ mod test { ); let mut context = ConnContext::new(); - let mut ctx = DeContext::new(&mut context, &buf); + let mut ctx = DeContext::new(&mut context, buf); let message = ColumnPacket::deserialize(&mut ctx)?; @@ -94,7 +94,7 @@ mod test { ); let mut context = ConnContext::new(); - let mut ctx = DeContext::new(&mut context, &buf); + let mut ctx = DeContext::new(&mut context, buf); let message = ColumnPacket::deserialize(&mut ctx)?; diff --git a/src/mariadb/protocol/packets/column_def.rs b/src/mariadb/protocol/packets/column_def.rs index 095bb300..15b79eea 100644 --- a/src/mariadb/protocol/packets/column_def.rs +++ b/src/mariadb/protocol/packets/column_def.rs @@ -115,7 +115,7 @@ mod test { ); let mut context = ConnContext::new(); - let mut ctx = DeContext::new(&mut context, &buf); + let mut ctx = DeContext::new(&mut context, buf); let message = ColumnDefPacket::deserialize(&mut ctx)?; diff --git a/src/mariadb/protocol/packets/eof.rs b/src/mariadb/protocol/packets/eof.rs index c43ad6bf..8868bd63 100644 --- a/src/mariadb/protocol/packets/eof.rs +++ b/src/mariadb/protocol/packets/eof.rs @@ -58,7 +58,7 @@ mod test { let buf = Bytes::from_static(b"\x01\0\0\x01\xFE\x00\x00\x01\x00"); let mut context = ConnContext::new(); - let mut ctx = DeContext::new(&mut context, &buf); + let mut ctx = DeContext::new(&mut context, buf); let _message = EofPacket::deserialize(&mut ctx)?; diff --git a/src/mariadb/protocol/packets/err.rs b/src/mariadb/protocol/packets/err.rs index 138305eb..5d5065f8 100644 --- a/src/mariadb/protocol/packets/err.rs +++ b/src/mariadb/protocol/packets/err.rs @@ -124,7 +124,7 @@ mod test { ); let mut context = ConnContext::new(); - let mut ctx = DeContext::new(&mut context, &buf); + let mut ctx = DeContext::new(&mut context, buf); let _message = ErrPacket::deserialize(&mut ctx)?; diff --git a/src/mariadb/protocol/packets/initial.rs b/src/mariadb/protocol/packets/initial.rs index 4a9e5d39..b024ba32 100644 --- a/src/mariadb/protocol/packets/initial.rs +++ b/src/mariadb/protocol/packets/initial.rs @@ -153,7 +153,7 @@ mod test { ); let mut context = ConnContext::new(); - let mut ctx = DeContext::new(&mut context, &buf); + let mut ctx = DeContext::new(&mut context, buf); let _message = InitialHandshakePacket::deserialize(&mut ctx)?; diff --git a/src/mariadb/protocol/packets/ok.rs b/src/mariadb/protocol/packets/ok.rs index b9b35e2a..16f7b221 100644 --- a/src/mariadb/protocol/packets/ok.rs +++ b/src/mariadb/protocol/packets/ok.rs @@ -95,7 +95,7 @@ mod test { ); let mut context = ConnContext::new(); - let mut ctx = DeContext::new(&mut context, &buf); + let mut ctx = DeContext::new(&mut context, buf); let message = OkPacket::deserialize(&mut ctx)?; diff --git a/src/mariadb/protocol/packets/packet_header.rs b/src/mariadb/protocol/packets/packet_header.rs index 75adefca..8a742810 100644 --- a/src/mariadb/protocol/packets/packet_header.rs +++ b/src/mariadb/protocol/packets/packet_header.rs @@ -1,5 +1,33 @@ -#[derive(Debug)] +use byteorder::LittleEndian; +use byteorder::ByteOrder; + +#[derive(Debug, Default)] pub struct PacketHeader { pub length: u32, pub seq_no: u8, } + +impl PacketHeader { + pub fn size() -> usize { + 4 + } + + pub fn combined_length(&self) -> usize { + PacketHeader::size() + self.length as usize + } +} + +impl core::convert::TryFrom<&[u8]> for PacketHeader { + type Error = failure::Error; + + fn try_from(buffer: &[u8]) -> Result { + if buffer.len() < 4 { + failure::bail!("Buffer length is too short") + } else { + Ok(PacketHeader { + length: LittleEndian::read_u24(&buffer), + seq_no: buffer[3], + }) + } + } +} diff --git a/src/mariadb/protocol/packets/result_row.rs b/src/mariadb/protocol/packets/result_row.rs index cccfdce2..b3105c98 100644 --- a/src/mariadb/protocol/packets/result_row.rs +++ b/src/mariadb/protocol/packets/result_row.rs @@ -48,7 +48,7 @@ mod test { ); let mut context = ConnContext::new(); - let mut ctx = DeContext::new(&mut context, &buf); + let mut ctx = DeContext::new(&mut context, buf); ctx.columns = Some(1); diff --git a/src/mariadb/protocol/packets/result_set.rs b/src/mariadb/protocol/packets/result_set.rs index e78a8018..00d65d76 100644 --- a/src/mariadb/protocol/packets/result_set.rs +++ b/src/mariadb/protocol/packets/result_set.rs @@ -1,14 +1,7 @@ use bytes::Bytes; use failure::Error; -use crate::mariadb::Decoder; -use crate::mariadb::Message; -use crate::mariadb::Capabilities; - -use super::super::{ - deserialize::{DeContext, Deserialize}, - packets::{column::ColumnPacket, column_def::ColumnDefPacket, eof::EofPacket, err::ErrPacket, ok::OkPacket, result_row::ResultRow}, -}; +use crate::mariadb::{Deserialize, ConnContext, Framed, Decoder, Message, Capabilities, DeContext, ColumnPacket, ColumnDefPacket, EofPacket, ErrPacket, OkPacket, ResultRow}; #[derive(Debug, Default)] pub struct ResultSet { @@ -17,22 +10,28 @@ pub struct ResultSet { pub rows: Vec, } -impl Deserialize for ResultSet { - fn deserialize(ctx: &mut DeContext) -> Result { - let column_packet = ColumnPacket::deserialize(ctx)?; +impl ResultSet { + pub async fn deserialize<'a>(mut ctx: DeContext<'a>) -> Result { + let column_packet = ColumnPacket::deserialize(&mut ctx)?; let columns = if let Some(columns) = column_packet.columns { - (0..columns) - .map(|_| ColumnDefPacket::deserialize(ctx)) - .filter(Result::is_ok) - .map(Result::unwrap) - .collect::>() + let mut column_defs = Vec::new(); + for _ in 0..columns { + ctx.next_packet().await?; + column_defs.push(ColumnDefPacket::deserialize(&mut ctx)?); + } + column_defs } else { Vec::new() }; + ctx.next_packet().await?; + let eof_packet = if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) { - Some(EofPacket::deserialize(ctx)?) + // If we get an eof packet we must update ctx to hold a new buffer of the next packet. + let eof_packet = Some(EofPacket::deserialize(&mut ctx)?); + ctx.next_packet().await?; + eof_packet } else { None }; @@ -48,12 +47,15 @@ impl Deserialize for ResultSet { }; let tag = ctx.decoder.peek_tag(); - if tag == Some(&0xFE) && packet_header.length <= 0xFFFFFF { + if tag == Some(&0xFE) && packet_header.length <= 0xFFFFFF || packet_header.length == 0 { break; } else { let index = ctx.decoder.index; - match ResultRow::deserialize(ctx) { - Ok(v) => rows.push(v), + match ResultRow::deserialize(&mut ctx) { + Ok(v) => { + rows.push(v); + ctx.next_packet().await?; + }, Err(_) => { ctx.decoder.index = index; break; @@ -62,10 +64,12 @@ impl Deserialize for ResultSet { } } - if ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) { - OkPacket::deserialize(ctx)?; - } else { - EofPacket::deserialize(ctx)?; + if ctx.decoder.peek_packet_header()?.length > 0 { + if ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) { + OkPacket::deserialize(&mut ctx)?; + } else { + EofPacket::deserialize(&mut ctx)?; + } } Ok(ResultSet { @@ -83,8 +87,8 @@ mod test { use crate::{__bytes_builder, mariadb::{Connection, EofPacket, ErrPacket, OkPacket, ResultRow, ServerStatusFlag, Capabilities, ConnContext}}; use super::*; - #[test] - fn it_decodes_result_set_packet() -> Result<(), Error> { + #[runtime::test] + async fn it_decodes_result_set_packet() -> Result<(), Error> { // TODO: Use byte string as input for test; this is a valid return from a mariadb. #[rustfmt::skip] let buf = __bytes_builder!( @@ -296,9 +300,9 @@ mod test { ); let mut context = ConnContext::new(); - let mut ctx = DeContext::new(&mut context, &buf); + let mut ctx = DeContext::new(&mut context, buf); - ResultSet::deserialize(&mut ctx)?; + ResultSet::deserialize(ctx).await?; Ok(()) } diff --git a/src/mariadb/protocol/types.rs b/src/mariadb/protocol/types.rs index 87f6418d..14c6e315 100644 --- a/src/mariadb/protocol/types.rs +++ b/src/mariadb/protocol/types.rs @@ -170,7 +170,7 @@ mod test { #[test] fn it_decodes_capabilities() { let buf = Bytes::from(b"\xfe\xf7".to_vec()); - let mut decoder = Decoder::new(&buf); + let mut decoder = Decoder::new(buf); Capabilities::from_bits_truncate(decoder.decode_int_u16().into()); } }