WIP: Update next_packet to work more efficiently and correctly

This commit is contained in:
Daniel Akhterov 2019-07-31 20:40:42 -07:00
parent 31188aa919
commit b731dbe90f
20 changed files with 430 additions and 386 deletions

View File

@ -1,4 +1,4 @@
#![feature(non_exhaustive, async_await)]
#![feature(non_exhaustive, async_await, async_closure)]
#![cfg_attr(test, feature(test))]
#![allow(clippy::needless_lifetimes)]
// FIXME: Remove this once API has matured

View File

@ -1,5 +1,5 @@
use super::Connection;
use crate::mariadb::{DeContext, Deserialize, HandshakeResponsePacket, InitialHandshakePacket, Message, Capabilities, ComStmtExec, StmtExecFlag};
use crate::mariadb::{ErrPacket, OkPacket, DeContext, Deserialize, HandshakeResponsePacket, InitialHandshakePacket, Message, Capabilities, ComStmtExec, StmtExecFlag};
use bytes::{BufMut, Bytes};
use failure::{err_msg, Error};
use crate::ConnectOptions;
@ -9,8 +9,8 @@ pub async fn establish<'a, 'b: 'a>(
conn: &'a mut Connection,
options: ConnectOptions<'b>,
) -> Result<(), Error> {
let buf = &conn.stream.next_bytes().await?;
let mut de_ctx = DeContext::new(&mut conn.context, &buf);
let buf = conn.stream.next_packet().await?;
let mut de_ctx = DeContext::new(&mut conn.context, buf);
let initial = InitialHandshakePacket::deserialize(&mut de_ctx)?;
de_ctx.ctx.capabilities = de_ctx.ctx.capabilities.bitand(initial.capabilities);
@ -26,167 +26,167 @@ pub async fn establish<'a, 'b: 'a>(
conn.send(handshake).await?;
match conn.next().await? {
Some(Message::OkPacket(message)) => {
conn.context.seq_no = message.seq_no;
Ok(())
}
let mut ctx = DeContext::new(&mut conn.context, conn.stream.next_packet().await?);
Some(Message::ErrPacket(message)) => Err(err_msg(format!("{:?}", message))),
Some(message) => {
panic!("Did not receive OkPacket nor ErrPacket. Received: {:?}", message);
}
None => {
panic!("Did not receive packet");
if let Some(tag) = ctx.decoder.peek_tag() {
match tag {
0xFF => {
return Err(ErrPacket::deserialize(&mut ctx)?.into());
},
0x00 => {
OkPacket::deserialize(&mut ctx)?;
},
_ => failure::bail!("Did not receive an ErrPacket nor OkPacket when one is expected"),
}
} else {
failure::bail!("Did not receive an appropriately tagged packet when one is expected");
}
Ok(())
}
#[cfg(test)]
mod test {
// use super::*;
// use failure::Error;
// use crate::mariadb::{ComStmtPrepareResp, FieldType, ResultSet, ComStmtFetch};
//
// #[runtime::test]
// async fn it_can_connect() -> Result<(), Error> {
// let mut conn = Connection::establish(ConnectOptions {
// host: "127.0.0.1",
// port: 3306,
// user: Some("root"),
// database: None,
// password: None,
// })
// .await?;
//
// Ok(())
// }
//
// #[runtime::test]
// async fn it_can_ping() -> Result<(), Error> {
// let mut conn = Connection::establish(ConnectOptions {
// host: "127.0.0.1",
// port: 3306,
// user: Some("root"),
// database: None,
// password: None,
// }).await?;
//
// conn.ping().await?;
//
// Ok(())
// }
//
// #[runtime::test]
// async fn it_can_select_db() -> Result<(), Error> {
// let mut conn = Connection::establish(ConnectOptions {
// host: "127.0.0.1",
// port: 3306,
// user: Some("root"),
// database: None,
// password: None,
// }).await?;
//
// conn.select_db("test").await?;
//
// Ok(())
// }
//
// #[runtime::test]
// async fn it_can_query() -> Result<(), Error> {
// let mut conn = Connection::establish(ConnectOptions {
// host: "127.0.0.1",
// port: 3306,
// user: Some("root"),
// database: None,
// password: None,
// }).await?;
//
// conn.select_db("test").await?;
//
// conn.query("SELECT * FROM users").await?;
//
// Ok(())
// }
//
// #[runtime::test]
// async fn it_can_prepare() -> Result<(), Error> {
// let mut conn = Connection::establish(ConnectOptions {
// host: "127.0.0.1",
// port: 3306,
// user: Some("root"),
// database: None,
// password: None,
// }).await?;
//
// conn.select_db("test").await?;
//
// conn.prepare("SELECT * FROM users WHERE username = ?").await?;
//
// Ok(())
// }
//
// #[runtime::test]
// async fn it_can_execute_prepared() -> Result<(), Error> {
// let mut conn = Connection::establish(ConnectOptions {
// host: "127.0.0.1",
// port: 3306,
// user: Some("root"),
// database: None,
// password: None,
// }).await?;
//
// conn.select_db("test").await?;
//
// let mut prepared = conn.prepare("SELECT * FROM users WHERE username = ?;").await?;
//
// println!("{:?}", prepared);
//
// if let Some(param_defs) = &mut prepared.param_defs {
// param_defs[0].field_type = FieldType::MysqlTypeBlob;
// }
//
// let exec = ComStmtExec {
// stmt_id: prepared.ok.stmt_id,
// flags: StmtExecFlag::CursorForUpdate,
// params: Some(vec![Some(Bytes::from_static(b"'daniel'"))]),
// param_defs: prepared.param_defs,
// };
//
//
// conn.send(exec).await?;
//
// let buf = conn.stream.next_bytes().await?;
//
// println!("{:?}", buf);
//
// println!("{:?}", ResultSet::deserialize(&mut DeContext::new(&mut conn.context, &buf))?.rows);
//
// let fetch = ComStmtFetch {
// stmt_id: prepared.ok.stmt_id,
// rows: 1,
// };
//
// conn.send(exec).await?;
//
// Ok(())
// }
//
// #[runtime::test]
// async fn it_does_not_connect() -> Result<(), Error> {
// match Connection::establish(ConnectOptions {
// host: "127.0.0.1",
// port: 3306,
// user: Some("roote"),
// database: None,
// password: None,
// })
// .await
// {
// Ok(_) => Err(err_msg("Bad username still worked?")),
// Err(_) => Ok(()),
// }
// }
use super::*;
use failure::Error;
use crate::mariadb::{ComStmtPrepareResp, FieldType, ResultSet, ComStmtFetch};
#[runtime::test]
async fn it_can_connect() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
})
.await?;
Ok(())
}
#[runtime::test]
async fn it_can_ping() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
}).await?;
conn.ping().await?;
Ok(())
}
#[runtime::test]
async fn it_can_select_db() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
}).await?;
conn.select_db("test").await?;
Ok(())
}
#[runtime::test]
async fn it_can_query() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
}).await?;
conn.select_db("test").await?;
conn.query("SELECT * FROM users").await?;
Ok(())
}
#[runtime::test]
async fn it_can_prepare() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
}).await?;
conn.select_db("test").await?;
conn.prepare("SELECT * FROM users WHERE username = ?").await?;
Ok(())
}
#[runtime::test]
async fn it_can_execute_prepared() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
}).await?;
conn.select_db("test").await?;
let mut prepared = conn.prepare("SELECT username FROM users WHERE username = ?;").await?;
println!("{:?}", prepared);
if let Some(param_defs) = &mut prepared.param_defs {
param_defs[0].field_type = FieldType::MysqlTypeBlob;
}
let exec = ComStmtExec {
stmt_id: -1,
flags: StmtExecFlag::ReadOnly,
params: Some(vec![Some(Bytes::from_static(b"daniel"))]),
param_defs: prepared.param_defs,
};
println!("{:?}", ResultSet::deserialize(DeContext::with_stream(&mut conn.context, &mut conn.stream)).await?);
let fetch = ComStmtFetch {
stmt_id: -1,
rows: 10,
};
conn.send(fetch).await?;
let buf = conn.stream.next_packet().await?;
println!("{:?}", buf);
// println!("{:?}", ResultSet::deserialize(&mut DeContext::new(&mut conn.context, &buf))?);
Ok(())
}
#[runtime::test]
async fn it_does_not_connect() -> Result<(), Error> {
match Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("roote"),
database: None,
password: None,
})
.await
{
Ok(_) => Err(err_msg("Bad username still worked?")),
Err(_) => Ok(()),
}
}
}

View File

@ -6,7 +6,8 @@ use futures::{
prelude::*,
};
use runtime::net::TcpStream;
use crate::{ConnectOptions, mariadb::protocol::{DeContext, Deserialize, Encoder, ComInitDb, ComPing, ComQuery, ComQuit, OkPacket, Serialize, Message, Capabilities, ServerStatusFlag, ComStmtPrepare, ComStmtPrepareResp, ResultSet, ErrPacket}};
use core::convert::TryFrom;
use crate::{ConnectOptions, mariadb::{protocol::encode, PacketHeader, Decoder, DeContext, Deserialize, Encoder, ComInitDb, ComPing, ComQuery, ComQuit, OkPacket, Serialize, Message, Capabilities, ServerStatusFlag, ComStmtPrepare, ComStmtPrepareResp, ResultSet, ErrPacket}};
mod establish;
@ -103,8 +104,8 @@ impl Connection {
pub async fn query<'a>(&'a mut self, sql_statement: &'a str) -> Result<Option<ResultSet>, Error> {
self.send(ComQuery { sql_statement: bytes::Bytes::from(sql_statement) }).await?;
let buf = self.stream.next_bytes().await?;
let mut ctx = DeContext::new(&mut self.context, &buf);
let mut ctx = DeContext::with_stream(&mut self.context, &mut self.stream);
ctx.next_packet().await?;
if let Some(tag) = ctx.decoder.peek_tag() {
match tag {
0xFF => Err(ErrPacket::deserialize(&mut ctx)?.into()),
@ -113,7 +114,9 @@ impl Connection {
Ok(None)
},
0xFB => unimplemented!(),
_ => Ok(Some(ResultSet::deserialize(&mut ctx)?))
_ => {
Ok(Some(ResultSet::deserialize(ctx).await?))
}
}
} else {
panic!("Tag not found in result packet");
@ -125,17 +128,19 @@ impl Connection {
self.send(ComInitDb { schema_name: bytes::Bytes::from(db) }).await?;
match self.next().await? {
Some(Message::OkPacket(_)) => {},
Some(message @ Message::ErrPacket(_)) => {
failure::bail!("Received an ErrPacket packet: {:?}", message);
},
Some(message) => {
failure::bail!("Received an unexpected packet type: {:?}", message);
}
None => {
failure::bail!("Did not receive a packet when one was expected");
let mut ctx = DeContext::new(&mut self.context, self.stream.next_packet().await?);
if let Some(tag) = ctx.decoder.peek_tag() {
match tag {
0xFF => {
ErrPacket::deserialize(&mut ctx)?;
},
0x00 => {
OkPacket::deserialize(&mut ctx)?;
},
_ => failure::bail!("Did not receive an ErrPacket nor OkPacket when one was expected"),
}
} else {
failure::bail!("No tag found");
}
Ok(())
@ -145,8 +150,7 @@ impl Connection {
self.send(ComPing()).await?;
// Ping response must be an OkPacket
let buf = self.stream.next_bytes().await?;
OkPacket::deserialize(&mut DeContext::new(&mut self.context, &buf))?;
OkPacket::deserialize(&mut DeContext::new(&mut self.context, self.stream.next_packet().await?))?;
Ok(())
}
@ -156,110 +160,95 @@ impl Connection {
statement: Bytes::from(query),
}).await?;
let buf = self.stream.next_bytes().await?;
ComStmtPrepareResp::deserialize(&mut DeContext::new(&mut self.context, &buf))
}
pub async fn next(&mut self) -> Result<Option<Message>, Error> {
let mut rbuf = BytesMut::new();
let mut len = 0;
loop {
if len == rbuf.len() {
rbuf.reserve(32);
unsafe {
// Set length to the capacity and efficiently
// zero-out the memory
rbuf.set_len(rbuf.capacity());
self.stream.inner.initializer().initialize(&mut rbuf[len..]);
}
}
let bytes_read = self.stream.inner.read(&mut rbuf[len..]).await?;
if bytes_read > 0 {
len += bytes_read;
} else {
// Read 0 bytes from the server; end-of-stream
break;
}
while len > 0 {
let size = rbuf.len();
let message = Message::deserialize(&mut DeContext::new(
&mut self.context,
&rbuf.as_ref().into(),
))?;
len -= size - rbuf.len();
match message {
message @ Some(_) => return Ok(message),
// Did not receive enough bytes to
// deserialize a complete message
None => break,
}
}
}
Ok(None)
// let mut ctx = DeContext::with_stream(&mut self.context, &mut self.stream);
// ctx.next_packet().await?;
// ComStmtPrepareResp::deserialize(&mut ctx)
Ok(ComStmtPrepareResp::default())
}
}
pub struct Framed {
inner: TcpStream,
readable: bool,
eof: bool,
buffer: BytesMut,
buf: BytesMut,
}
impl Framed {
fn new(stream: TcpStream) -> Self {
Self {
readable: false,
eof: false,
inner: stream,
buffer: BytesMut::with_capacity(8 * 1024),
buf: BytesMut::with_capacity(8 * 1024),
}
}
pub async fn next_bytes(&mut self) -> Result<Bytes, Error> {
pub async fn next_packet(&mut self) -> Result<Bytes, Error> {
let mut rbuf = BytesMut::new();
let mut len = 0;
let mut packet_len: u32 = 0;
let mut len = 0usize;
let mut packet_headers: Vec<PacketHeader> = Vec::new();
loop {
if len == rbuf.len() {
rbuf.reserve(20000);
if let Some(packet_header) = packet_headers.last() {
if packet_header.combined_length() > rbuf.len() {
let reserve = packet_header.combined_length() - rbuf.len();
rbuf.reserve(reserve);
unsafe {
rbuf.set_len(rbuf.capacity());
self.inner.initializer().initialize(&mut rbuf[len..]);
}
}
} else if rbuf.len() == len {
rbuf.reserve(32);
unsafe {
// Set length to the capacity and efficiently
// zero-out the memory
rbuf.set_len(rbuf.capacity());
self.inner.initializer().initialize(&mut rbuf[len..]);
}
}
let bytes_read = self.inner.read(&mut rbuf[len..]).await?;
// If we have a packet_header and the amount of currently read bytes (len) is less than
// the specified length inside packet_header, then we can continue reading to rbuf; but
// only up until packet_header.length.
// Else if the total number of bytes read is equal to packet_header then we will
// return rbuf as it should contain the entire packet.
// Else we read too many bytes -- which shouldn't happen -- and will return an error.
let bytes_read;
if let Some(packet_header) = packet_headers.last() {
if packet_header.combined_length() > len {
bytes_read = self.inner.read(&mut rbuf[len..packet_header.combined_length()]).await?;
} else {
return Ok(rbuf.freeze());
}
} else {
// Only read header to make sure that we dont' read the next packets buffer.
bytes_read = self.inner.read(&mut rbuf[len..len + 4]).await?;
}
if bytes_read > 0 {
len += bytes_read;
// If we have read less than 4 bytes, and we don't already have a packet_header
// we must try to read again. The packet_header is always present and is 4 bytes long.
if bytes_read < 4 && packet_headers.len() == 0 {
continue;
}
} else {
// Read 0 bytes from the server; end-of-stream
return Ok(Bytes::new());
}
if len > 0 && packet_len == 0 {
packet_len = LittleEndian::read_u24(&rbuf[0..]);
}
// Loop until the length of the buffer is the length of the packet
if packet_len as usize > len {
continue;
} else {
return Ok(rbuf.freeze());
}
// If we don't have a packet header or the last packet header had a length of
// 0xFF_FF_FF (the max possible length); then we must continue receiving packets
// because the entire message hasn't been received.
// After this operation we know that packet_headers.last() *SHOULD* always return valid data,
// so the the use of packet_headers.last().unwrap() is allowed.
// TODO: Stitch packets together by removing the length and seq_no from in-between packet definitions.
if let Some(packet_header) = packet_headers.last() {
if packet_header.length as usize == encode::U24_MAX {
packet_headers.push(PacketHeader::try_from(&rbuf[0..])?);
}
} else {
packet_headers.push(PacketHeader::try_from(&rbuf[0..])?);
}
}
}
}

View File

@ -4,6 +4,7 @@ pub mod protocol;
// Re-export all the things
pub use connection::ConnContext;
pub use connection::Connection;
pub use connection::Framed;
pub use protocol::AuthenticationSwitchRequestPacket;
pub use protocol::ColumnPacket;
pub use protocol::ColumnDefPacket;

View File

@ -58,14 +58,14 @@ use super::packets::packet_header::PacketHeader;
// - Byte 7 is the minutes (0 if DATE type) (0-59)
// - Byte 8 is the seconds (0 if DATE type) (0-59)
// - Bytes 10-13 are the micro-seconds on 4 bytes little-endian format (only if data-length is > 7)
pub struct Decoder<'a> {
pub buf: &'a Bytes,
pub struct Decoder {
pub buf: Bytes,
pub index: usize,
}
impl<'a> Decoder<'a> {
impl Decoder {
// Create a new Decoder from an existing Bytes
pub fn new(buf: &'a Bytes) -> Self {
pub fn new(buf: Bytes) -> Self {
Decoder { buf, index: 0 }
}
@ -441,7 +441,7 @@ mod tests {
#[test]
fn it_decodes_int_lenenc_0x_fb() {
let buf = __bytes_builder!(0xFB_u8);
let mut decoder = Decoder::new(&buf);
let mut decoder = Decoder::new(buf);
let int = decoder.decode_int_lenenc_unsigned();
assert_eq!(int, None);
@ -451,7 +451,7 @@ mod tests {
#[test]
fn it_decodes_int_lenenc_0x_fc() {
let buf =__bytes_builder!(0xFCu8, 1u8, 1u8);
let mut decoder = Decoder::new(&buf);
let mut decoder = Decoder::new(buf);
let int = decoder.decode_int_lenenc_unsigned();
assert_eq!(int, Some(0x0101));
@ -461,7 +461,7 @@ mod tests {
#[test]
fn it_decodes_int_lenenc_0x_fd() {
let buf = __bytes_builder!(0xFDu8, 1u8, 1u8, 1u8);
let mut decoder = Decoder::new(&buf);
let mut decoder = Decoder::new(buf);
let int = decoder.decode_int_lenenc_unsigned();
assert_eq!(int, Some(0x010101));
@ -471,7 +471,7 @@ mod tests {
#[test]
fn it_decodes_int_lenenc_0x_fe() {
let buf = __bytes_builder!(0xFE_u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8);
let mut decoder = Decoder::new(&buf);
let mut decoder = Decoder::new(buf);
let int = decoder.decode_int_lenenc_unsigned();
assert_eq!(int, Some(0x0101010101010101));
@ -481,7 +481,7 @@ mod tests {
#[test]
fn it_decodes_int_lenenc_0x_fa() {
let buf = __bytes_builder!(0xFA_u8);
let mut decoder = Decoder::new(&buf);
let mut decoder = Decoder::new(buf);
let int = decoder.decode_int_lenenc_unsigned();
assert_eq!(int, Some(0xFA));
@ -491,7 +491,7 @@ mod tests {
#[test]
fn it_decodes_int_8() {
let buf = __bytes_builder!(1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8);
let mut decoder = Decoder::new(&buf);
let mut decoder = Decoder::new(buf);
let int: i64 = decoder.decode_int_i64();
assert_eq!(int, 0x0101010101010101);
@ -501,7 +501,7 @@ mod tests {
#[test]
fn it_decodes_int_4() {
let buf = __bytes_builder!(1u8, 1u8, 1u8, 1u8);
let mut decoder = Decoder::new(&buf);
let mut decoder = Decoder::new(buf);
let int: i32 = decoder.decode_int_i32();
assert_eq!(int, 0x01010101);
@ -511,7 +511,7 @@ mod tests {
#[test]
fn it_decodes_int_3() {
let buf = __bytes_builder!(1u8, 1u8, 1u8);
let mut decoder = Decoder::new(&buf);
let mut decoder = Decoder::new(buf);
let int: i32 = decoder.decode_int_i24();
assert_eq!(int, 0x010101);
@ -521,7 +521,7 @@ mod tests {
#[test]
fn it_decodes_int_2() {
let buf = __bytes_builder!(1u8, 1u8);
let mut decoder = Decoder::new(&buf);
let mut decoder = Decoder::new(buf);
let int: i16 = decoder.decode_int_i16();
assert_eq!(int, 0x0101);
@ -531,7 +531,7 @@ mod tests {
#[test]
fn it_decodes_int_1() {
let buf = __bytes_builder!(1u8);
let mut decoder = Decoder::new(&buf);
let mut decoder = Decoder::new(buf);
let int: u8 = decoder.decode_int_u8();
assert_eq!(int, 1u8);
@ -541,7 +541,7 @@ mod tests {
#[test]
fn it_decodes_string_lenenc() {
let buf = __bytes_builder!(3u8, b"sup");
let mut decoder = Decoder::new(&buf);
let mut decoder = Decoder::new(buf);
let string: Bytes = decoder.decode_string_lenenc();
assert_eq!(string[..], b"sup"[..]);
@ -552,7 +552,7 @@ mod tests {
#[test]
fn it_decodes_string_fix() {
let buf = __bytes_builder!(b"a");
let mut decoder = Decoder::new(&buf);
let mut decoder = Decoder::new(buf);
let string: Bytes = decoder.decode_string_fix(1);
assert_eq!(&string[..], b"a");
@ -563,7 +563,7 @@ mod tests {
#[test]
fn it_decodes_string_eof() {
let buf = __bytes_builder!(b"a");
let mut decoder = Decoder::new(&buf);
let mut decoder = Decoder::new(buf);
let string: Bytes = decoder.decode_string_eof(None);
assert_eq!(&string[..], b"a");
@ -574,7 +574,7 @@ mod tests {
#[test]
fn it_decodes_string_null() -> Result<(), Error> {
let buf = __bytes_builder!(b"random\0", 1u8);
let mut decoder = Decoder::new(&buf);
let mut decoder = Decoder::new(buf);
let string: Bytes = decoder.decode_string_null()?;
assert_eq!(&string[..], b"random");
@ -589,7 +589,7 @@ mod tests {
#[test]
fn it_decodes_byte_fix() {
let buf = __bytes_builder!(b"a");
let mut decoder = Decoder::new(&buf);
let mut decoder = Decoder::new(buf);
let string: Bytes = decoder.decode_byte_fix(1);
assert_eq!(&string[..], b"a");
@ -600,7 +600,7 @@ mod tests {
#[test]
fn it_decodes_byte_eof() {
let buf = __bytes_builder!(b"a");
let mut decoder = Decoder::new(&buf);
let mut decoder = Decoder::new(buf);
let string: Bytes = decoder.decode_byte_eof(None);
assert_eq!(&string[..], b"a");

View File

@ -1,4 +1,4 @@
use crate::mariadb::{Decoder, ConnContext, Connection, ColumnDefPacket};
use crate::mariadb::{Framed, Decoder, ConnContext, Connection, ColumnDefPacket};
use bytes::Bytes;
use failure::Error;
@ -8,14 +8,36 @@ use failure::Error;
// Mainly used to simply to simplify number of parameters for deserializing functions
pub struct DeContext<'a> {
pub ctx: &'a mut ConnContext,
pub decoder: Decoder<'a>,
pub stream: Option<&'a mut Framed>,
pub decoder: Decoder,
pub columns: Option<u64>,
pub column_defs: Option<Vec<ColumnDefPacket>>,
}
impl<'a> DeContext<'a> {
pub fn new(conn: &'a mut ConnContext, buf: &'a Bytes) -> Self {
DeContext { ctx: conn, decoder: Decoder::new(&buf), columns: None , column_defs: None }
pub fn new(conn: &'a mut ConnContext, buf: Bytes) -> Self {
DeContext { ctx: conn, stream: None, decoder: Decoder::new(buf), columns: None , column_defs: None }
}
pub fn with_stream(conn: &'a mut ConnContext, stream: &'a mut Framed) -> Self {
DeContext {
ctx: conn,
stream: Some(stream),
decoder: Decoder::new(Bytes::new()),
columns: None ,
column_defs: None
}
}
pub async fn next_packet(&mut self) -> Result<(), failure::Error> {
if let Some(stream) = &mut self.stream {
println!("Called next packet");
self.decoder = Decoder::new(stream.next_packet().await?);
Ok(())
} else {
failure::bail!("Calling next_packet on DeContext with no stream provided")
}
}
}

View File

@ -2,7 +2,7 @@ use byteorder::{ByteOrder, LittleEndian};
use bytes::{BufMut, Bytes, BytesMut};
use crate::mariadb::FieldType;
const U24_MAX: usize = 0xFF_FF_FF;
pub const U24_MAX: usize = 0xFF_FF_FF;
// A simple wrapper around a BytesMut to easily encode values
pub struct Encoder {
@ -253,17 +253,17 @@ impl Encoder {
FieldType::MysqlTypeTimestamp2 => unimplemented!(),
FieldType::MysqlTypeDatetime2 => unimplemented!(),
FieldType::MysqlTypeTime2 =>unimplemented!(),
FieldType::MysqlTypeJson => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeNewdecimal => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeEnum => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeSet => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeTinyBlob => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeMediumBlob => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeLongBlob => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeBlob => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeVarString => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeString => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeGeometry => self.encode_string_lenenc(bytes),
FieldType::MysqlTypeJson => self.encode_byte_lenenc(bytes),
FieldType::MysqlTypeNewdecimal => self.encode_byte_lenenc(bytes),
FieldType::MysqlTypeEnum => self.encode_byte_lenenc(bytes),
FieldType::MysqlTypeSet => self.encode_byte_lenenc(bytes),
FieldType::MysqlTypeTinyBlob => self.encode_byte_lenenc(bytes),
FieldType::MysqlTypeMediumBlob => self.encode_byte_lenenc(bytes),
FieldType::MysqlTypeLongBlob => self.encode_byte_lenenc(bytes),
FieldType::MysqlTypeBlob => self.encode_byte_lenenc(bytes),
FieldType::MysqlTypeVarString => self.encode_byte_lenenc(bytes),
FieldType::MysqlTypeString => self.encode_byte_lenenc(bytes),
FieldType::MysqlTypeGeometry => self.encode_byte_lenenc(bytes),
}
}
}

View File

@ -17,56 +17,48 @@ impl crate::mariadb::Serialize for ComStmtExec {
encoder.encode_int_u8(super::BinaryProtocol::ComStmtExec.into());
encoder.encode_int_i32(self.stmt_id);
encoder.encode_int_u8(self.flags as u8);
encoder.encode_int_u8(0);
encoder.encode_int_u8(1);
if let Some(params) = &self.params {
if let Some(param_defs) = &self.param_defs {
if params.len() != param_defs.len() {
failure::bail!("Unequal number of params and param definitions supplied");
}
}
match (&self.params, &self.param_defs) {
(Some(params), Some(param_defs)) if params.len() > 0 => {
let null_bitmap_size = (params.len() + 7) / 8;
let mut shift_amount = 0u8;
let mut bitmap = vec![0u8];
let send_type = 1u8;
let null_bitmap_size = (params.len() + 7) / 8;
let mut shift_amount = 0u8;
let mut bitmap = vec![0u8];
// Generate NULL-bitmap from params
for param in params {
if param.is_some() {
let last_byte = bitmap.pop().unwrap();
bitmap.push(last_byte & (1 << shift_amount));
}
// Generate NULL-bitmap from params
for param in params {
if param.is_none() {
bitmap.push(bitmap.last().unwrap() & (1 << shift_amount));
shift_amount = (shift_amount + 1) % 8;
if shift_amount % 8 == 0 {
bitmap.push(0u8);
}
}
shift_amount = (shift_amount + 1) % 8;
encoder.encode_byte_fix(&Bytes::from(bitmap), null_bitmap_size);
encoder.encode_int_u8(send_type);
if shift_amount % 8 == 0 {
bitmap.push(0u8);
}
}
// Do not send the param types
encoder.encode_int_u8(if self.param_defs.is_some() {
1u8
} else {
0u8
});
if let Some(params_defs) = &self.param_defs {
for param in params_defs {
encoder.encode_int_u8(param.field_type as u8);
encoder.encode_int_u8(if (param.field_details & FieldDetailFlag::UNSIGNED).is_empty() {
1u8
} else {
0u8
});
if send_type > 0 {
for param in param_defs {
encoder.encode_int_u8(param.field_type as u8);
encoder.encode_int_u8(0);
}
}
// Encode params
for index in 0..params.len() {
if let Some(bytes) = &params[index] {
encoder.encode_param(&bytes, &params_defs[index].field_type);
encoder.encode_param(&bytes, &param_defs[index].field_type);
}
}
}
},
_ => {},
}
encoder.encode_length();

View File

@ -64,7 +64,7 @@ mod tests {
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, &buf);
let mut ctx = DeContext::new(&mut context, buf);
let message = ComStmtPrepareOk::deserialize(&mut ctx)?;

View File

@ -1,4 +1,4 @@
use crate::mariadb::{ComStmtPrepareOk, ColumnDefPacket, Capabilities, EofPacket};
use crate::mariadb::{DeContext, Deserialize, ComStmtPrepareOk, ColumnDefPacket, Capabilities, EofPacket};
#[derive(Debug, Default)]
pub struct ComStmtPrepareResp {
@ -7,18 +7,22 @@ pub struct ComStmtPrepareResp {
pub res_columns: Option<Vec<ColumnDefPacket>>,
}
impl crate::mariadb::Deserialize for ComStmtPrepareResp {
fn deserialize(ctx: &mut crate::mariadb::DeContext) -> Result<Self, failure::Error> {
let ok = ComStmtPrepareOk::deserialize(ctx)?;
impl ComStmtPrepareResp {
pub async fn deserialize<'a>(mut ctx: DeContext<'a>) -> Result<Self, failure::Error> {
let ok = ComStmtPrepareOk::deserialize(&mut ctx)?;
let param_defs = if ok.params > 0 {
let param_defs = (0..ok.params).map(|_| ColumnDefPacket::deserialize(ctx))
.filter(Result::is_ok)
.map(Result::unwrap)
.collect::<Vec<ColumnDefPacket>>();
let mut param_defs = Vec::new();
for _ in 0..ok.params {
ctx.next_packet().await?;
param_defs.push(ColumnDefPacket::deserialize(&mut ctx)?);
}
ctx.next_packet().await?;
if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
EofPacket::deserialize(ctx)?;
EofPacket::deserialize(&mut ctx)?;
}
Some(param_defs)
@ -27,16 +31,20 @@ impl crate::mariadb::Deserialize for ComStmtPrepareResp {
};
let res_columns = if ok.columns > 0 {
let param_defs = (0..ok.columns).map(|_| ColumnDefPacket::deserialize(ctx))
.filter(Result::is_ok)
.map(Result::unwrap)
.collect::<Vec<ColumnDefPacket>>();
let mut res_columns = Vec::new();
if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
EofPacket::deserialize(ctx)?;
for _ in 0..ok.columns {
ctx.next_packet().await?;
res_columns.push(ColumnDefPacket::deserialize(&mut ctx)?);
}
Some(param_defs)
ctx.next_packet().await?;
if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
EofPacket::deserialize(&mut ctx)?;
}
Some(res_columns)
} else {
None
};
@ -54,8 +62,8 @@ mod test {
use super::*;
use crate::{__bytes_builder, ConnectOptions, mariadb::{ConnContext, DeContext, Deserialize}};
#[test]
fn it_decodes_com_stmt_prepare_resp_eof() -> Result<(), failure::Error> {
#[runtime::test]
async fn it_decodes_com_stmt_prepare_resp_eof() -> Result<(), failure::Error> {
#[rustfmt::skip]
let buf = __bytes_builder!(
// ---------------------------- //
@ -153,15 +161,15 @@ mod test {
);
let mut context = ConnContext::with_eof();
let mut ctx = DeContext::new(&mut context, &buf);
let mut ctx = DeContext::new(&mut context, buf);
let message = ComStmtPrepareResp::deserialize(&mut ctx)?;
let message = ComStmtPrepareResp::deserialize(ctx).await?;
Ok(())
}
#[test]
fn it_decodes_com_stmt_prepare_resp() -> Result<(), failure::Error> {
#[runtime::test]
async fn it_decodes_com_stmt_prepare_resp() -> Result<(), failure::Error> {
#[rustfmt::skip]
let buf = __bytes_builder!(
// ---------------------------- //
@ -287,9 +295,9 @@ mod test {
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, &buf);
let mut ctx = DeContext::new(&mut context, buf);
let message = ComStmtPrepareResp::deserialize(&mut ctx)?;
let message = ComStmtPrepareResp::deserialize(ctx).await?;
Ok(())
}

View File

@ -46,7 +46,7 @@ mod test {
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, &buf);
let mut ctx = DeContext::new(&mut context, buf);
let message = ColumnPacket::deserialize(&mut ctx)?;
@ -70,7 +70,7 @@ mod test {
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, &buf);
let mut ctx = DeContext::new(&mut context, buf);
let message = ColumnPacket::deserialize(&mut ctx)?;
@ -94,7 +94,7 @@ mod test {
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, &buf);
let mut ctx = DeContext::new(&mut context, buf);
let message = ColumnPacket::deserialize(&mut ctx)?;

View File

@ -115,7 +115,7 @@ mod test {
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, &buf);
let mut ctx = DeContext::new(&mut context, buf);
let message = ColumnDefPacket::deserialize(&mut ctx)?;

View File

@ -58,7 +58,7 @@ mod test {
let buf = Bytes::from_static(b"\x01\0\0\x01\xFE\x00\x00\x01\x00");
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, &buf);
let mut ctx = DeContext::new(&mut context, buf);
let _message = EofPacket::deserialize(&mut ctx)?;

View File

@ -124,7 +124,7 @@ mod test {
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, &buf);
let mut ctx = DeContext::new(&mut context, buf);
let _message = ErrPacket::deserialize(&mut ctx)?;

View File

@ -153,7 +153,7 @@ mod test {
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, &buf);
let mut ctx = DeContext::new(&mut context, buf);
let _message = InitialHandshakePacket::deserialize(&mut ctx)?;

View File

@ -95,7 +95,7 @@ mod test {
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, &buf);
let mut ctx = DeContext::new(&mut context, buf);
let message = OkPacket::deserialize(&mut ctx)?;

View File

@ -1,5 +1,33 @@
#[derive(Debug)]
use byteorder::LittleEndian;
use byteorder::ByteOrder;
#[derive(Debug, Default)]
pub struct PacketHeader {
pub length: u32,
pub seq_no: u8,
}
impl PacketHeader {
pub fn size() -> usize {
4
}
pub fn combined_length(&self) -> usize {
PacketHeader::size() + self.length as usize
}
}
impl core::convert::TryFrom<&[u8]> for PacketHeader {
type Error = failure::Error;
fn try_from(buffer: &[u8]) -> Result<Self, Self::Error> {
if buffer.len() < 4 {
failure::bail!("Buffer length is too short")
} else {
Ok(PacketHeader {
length: LittleEndian::read_u24(&buffer),
seq_no: buffer[3],
})
}
}
}

View File

@ -48,7 +48,7 @@ mod test {
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, &buf);
let mut ctx = DeContext::new(&mut context, buf);
ctx.columns = Some(1);

View File

@ -1,14 +1,7 @@
use bytes::Bytes;
use failure::Error;
use crate::mariadb::Decoder;
use crate::mariadb::Message;
use crate::mariadb::Capabilities;
use super::super::{
deserialize::{DeContext, Deserialize},
packets::{column::ColumnPacket, column_def::ColumnDefPacket, eof::EofPacket, err::ErrPacket, ok::OkPacket, result_row::ResultRow},
};
use crate::mariadb::{Deserialize, ConnContext, Framed, Decoder, Message, Capabilities, DeContext, ColumnPacket, ColumnDefPacket, EofPacket, ErrPacket, OkPacket, ResultRow};
#[derive(Debug, Default)]
pub struct ResultSet {
@ -17,22 +10,28 @@ pub struct ResultSet {
pub rows: Vec<ResultRow>,
}
impl Deserialize for ResultSet {
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
let column_packet = ColumnPacket::deserialize(ctx)?;
impl ResultSet {
pub async fn deserialize<'a>(mut ctx: DeContext<'a>) -> Result<Self, Error> {
let column_packet = ColumnPacket::deserialize(&mut ctx)?;
let columns = if let Some(columns) = column_packet.columns {
(0..columns)
.map(|_| ColumnDefPacket::deserialize(ctx))
.filter(Result::is_ok)
.map(Result::unwrap)
.collect::<Vec<ColumnDefPacket>>()
let mut column_defs = Vec::new();
for _ in 0..columns {
ctx.next_packet().await?;
column_defs.push(ColumnDefPacket::deserialize(&mut ctx)?);
}
column_defs
} else {
Vec::new()
};
ctx.next_packet().await?;
let eof_packet = if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
Some(EofPacket::deserialize(ctx)?)
// If we get an eof packet we must update ctx to hold a new buffer of the next packet.
let eof_packet = Some(EofPacket::deserialize(&mut ctx)?);
ctx.next_packet().await?;
eof_packet
} else {
None
};
@ -48,12 +47,15 @@ impl Deserialize for ResultSet {
};
let tag = ctx.decoder.peek_tag();
if tag == Some(&0xFE) && packet_header.length <= 0xFFFFFF {
if tag == Some(&0xFE) && packet_header.length <= 0xFFFFFF || packet_header.length == 0 {
break;
} else {
let index = ctx.decoder.index;
match ResultRow::deserialize(ctx) {
Ok(v) => rows.push(v),
match ResultRow::deserialize(&mut ctx) {
Ok(v) => {
rows.push(v);
ctx.next_packet().await?;
},
Err(_) => {
ctx.decoder.index = index;
break;
@ -62,10 +64,12 @@ impl Deserialize for ResultSet {
}
}
if ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
OkPacket::deserialize(ctx)?;
} else {
EofPacket::deserialize(ctx)?;
if ctx.decoder.peek_packet_header()?.length > 0 {
if ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
OkPacket::deserialize(&mut ctx)?;
} else {
EofPacket::deserialize(&mut ctx)?;
}
}
Ok(ResultSet {
@ -83,8 +87,8 @@ mod test {
use crate::{__bytes_builder, mariadb::{Connection, EofPacket, ErrPacket, OkPacket, ResultRow, ServerStatusFlag, Capabilities, ConnContext}};
use super::*;
#[test]
fn it_decodes_result_set_packet() -> Result<(), Error> {
#[runtime::test]
async fn it_decodes_result_set_packet() -> Result<(), Error> {
// TODO: Use byte string as input for test; this is a valid return from a mariadb.
#[rustfmt::skip]
let buf = __bytes_builder!(
@ -296,9 +300,9 @@ mod test {
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, &buf);
let mut ctx = DeContext::new(&mut context, buf);
ResultSet::deserialize(&mut ctx)?;
ResultSet::deserialize(ctx).await?;
Ok(())
}

View File

@ -170,7 +170,7 @@ mod test {
#[test]
fn it_decodes_capabilities() {
let buf = Bytes::from(b"\xfe\xf7".to_vec());
let mut decoder = Decoder::new(&buf);
let mut decoder = Decoder::new(buf);
Capabilities::from_bits_truncate(decoder.decode_int_u16().into());
}
}