diff --git a/src/lib.rs b/src/lib.rs
index 4b929d90..22f394e1 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,4 +1,4 @@
-#![feature(non_exhaustive, async_await)]
+#![feature(non_exhaustive, async_await, async_closure)]
#![cfg_attr(test, feature(test))]
#![allow(clippy::needless_lifetimes)]
// FIXME: Remove this once API has matured
diff --git a/src/mariadb/connection/establish.rs b/src/mariadb/connection/establish.rs
index 1b560c3e..55f160c4 100644
--- a/src/mariadb/connection/establish.rs
+++ b/src/mariadb/connection/establish.rs
@@ -1,5 +1,5 @@
use super::Connection;
-use crate::mariadb::{DeContext, Deserialize, HandshakeResponsePacket, InitialHandshakePacket, Message, Capabilities, ComStmtExec, StmtExecFlag};
+use crate::mariadb::{ErrPacket, OkPacket, DeContext, Deserialize, HandshakeResponsePacket, InitialHandshakePacket, Message, Capabilities, ComStmtExec, StmtExecFlag};
use bytes::{BufMut, Bytes};
use failure::{err_msg, Error};
use crate::ConnectOptions;
@@ -9,8 +9,8 @@ pub async fn establish<'a, 'b: 'a>(
conn: &'a mut Connection,
options: ConnectOptions<'b>,
) -> Result<(), Error> {
- let buf = &conn.stream.next_bytes().await?;
- let mut de_ctx = DeContext::new(&mut conn.context, &buf);
+ let buf = conn.stream.next_packet().await?;
+ let mut de_ctx = DeContext::new(&mut conn.context, buf);
let initial = InitialHandshakePacket::deserialize(&mut de_ctx)?;
de_ctx.ctx.capabilities = de_ctx.ctx.capabilities.bitand(initial.capabilities);
@@ -26,167 +26,167 @@ pub async fn establish<'a, 'b: 'a>(
conn.send(handshake).await?;
- match conn.next().await? {
- Some(Message::OkPacket(message)) => {
- conn.context.seq_no = message.seq_no;
- Ok(())
- }
+ let mut ctx = DeContext::new(&mut conn.context, conn.stream.next_packet().await?);
- Some(Message::ErrPacket(message)) => Err(err_msg(format!("{:?}", message))),
-
- Some(message) => {
- panic!("Did not receive OkPacket nor ErrPacket. Received: {:?}", message);
- }
-
- None => {
- panic!("Did not receive packet");
+ if let Some(tag) = ctx.decoder.peek_tag() {
+ match tag {
+ 0xFF => {
+ return Err(ErrPacket::deserialize(&mut ctx)?.into());
+ },
+ 0x00 => {
+ OkPacket::deserialize(&mut ctx)?;
+ },
+ _ => failure::bail!("Did not receive an ErrPacket nor OkPacket when one is expected"),
}
+ } else {
+ failure::bail!("Did not receive an appropriately tagged packet when one is expected");
}
+
+ Ok(())
}
#[cfg(test)]
mod test {
-// use super::*;
-// use failure::Error;
-// use crate::mariadb::{ComStmtPrepareResp, FieldType, ResultSet, ComStmtFetch};
-//
-// #[runtime::test]
-// async fn it_can_connect() -> Result<(), Error> {
-// let mut conn = Connection::establish(ConnectOptions {
-// host: "127.0.0.1",
-// port: 3306,
-// user: Some("root"),
-// database: None,
-// password: None,
-// })
-// .await?;
-//
-// Ok(())
-// }
-//
-// #[runtime::test]
-// async fn it_can_ping() -> Result<(), Error> {
-// let mut conn = Connection::establish(ConnectOptions {
-// host: "127.0.0.1",
-// port: 3306,
-// user: Some("root"),
-// database: None,
-// password: None,
-// }).await?;
-//
-// conn.ping().await?;
-//
-// Ok(())
-// }
-//
-// #[runtime::test]
-// async fn it_can_select_db() -> Result<(), Error> {
-// let mut conn = Connection::establish(ConnectOptions {
-// host: "127.0.0.1",
-// port: 3306,
-// user: Some("root"),
-// database: None,
-// password: None,
-// }).await?;
-//
-// conn.select_db("test").await?;
-//
-// Ok(())
-// }
-//
-// #[runtime::test]
-// async fn it_can_query() -> Result<(), Error> {
-// let mut conn = Connection::establish(ConnectOptions {
-// host: "127.0.0.1",
-// port: 3306,
-// user: Some("root"),
-// database: None,
-// password: None,
-// }).await?;
-//
-// conn.select_db("test").await?;
-//
-// conn.query("SELECT * FROM users").await?;
-//
-// Ok(())
-// }
-//
-// #[runtime::test]
-// async fn it_can_prepare() -> Result<(), Error> {
-// let mut conn = Connection::establish(ConnectOptions {
-// host: "127.0.0.1",
-// port: 3306,
-// user: Some("root"),
-// database: None,
-// password: None,
-// }).await?;
-//
-// conn.select_db("test").await?;
-//
-// conn.prepare("SELECT * FROM users WHERE username = ?").await?;
-//
-// Ok(())
-// }
-//
-// #[runtime::test]
-// async fn it_can_execute_prepared() -> Result<(), Error> {
-// let mut conn = Connection::establish(ConnectOptions {
-// host: "127.0.0.1",
-// port: 3306,
-// user: Some("root"),
-// database: None,
-// password: None,
-// }).await?;
-//
-// conn.select_db("test").await?;
-//
-// let mut prepared = conn.prepare("SELECT * FROM users WHERE username = ?;").await?;
-//
-// println!("{:?}", prepared);
-//
-// if let Some(param_defs) = &mut prepared.param_defs {
-// param_defs[0].field_type = FieldType::MysqlTypeBlob;
-// }
-//
-// let exec = ComStmtExec {
-// stmt_id: prepared.ok.stmt_id,
-// flags: StmtExecFlag::CursorForUpdate,
-// params: Some(vec![Some(Bytes::from_static(b"'daniel'"))]),
-// param_defs: prepared.param_defs,
-// };
-//
-//
-// conn.send(exec).await?;
-//
-// let buf = conn.stream.next_bytes().await?;
-//
-// println!("{:?}", buf);
-//
-// println!("{:?}", ResultSet::deserialize(&mut DeContext::new(&mut conn.context, &buf))?.rows);
-//
-// let fetch = ComStmtFetch {
-// stmt_id: prepared.ok.stmt_id,
-// rows: 1,
-// };
-//
-// conn.send(exec).await?;
-//
-// Ok(())
-// }
-//
-// #[runtime::test]
-// async fn it_does_not_connect() -> Result<(), Error> {
-// match Connection::establish(ConnectOptions {
-// host: "127.0.0.1",
-// port: 3306,
-// user: Some("roote"),
-// database: None,
-// password: None,
-// })
-// .await
-// {
-// Ok(_) => Err(err_msg("Bad username still worked?")),
-// Err(_) => Ok(()),
-// }
-// }
+ use super::*;
+ use failure::Error;
+ use crate::mariadb::{ComStmtPrepareResp, FieldType, ResultSet, ComStmtFetch};
+
+ #[runtime::test]
+ async fn it_can_connect() -> Result<(), Error> {
+ let mut conn = Connection::establish(ConnectOptions {
+ host: "127.0.0.1",
+ port: 3306,
+ user: Some("root"),
+ database: None,
+ password: None,
+ })
+ .await?;
+
+ Ok(())
+ }
+
+ #[runtime::test]
+ async fn it_can_ping() -> Result<(), Error> {
+ let mut conn = Connection::establish(ConnectOptions {
+ host: "127.0.0.1",
+ port: 3306,
+ user: Some("root"),
+ database: None,
+ password: None,
+ }).await?;
+
+ conn.ping().await?;
+
+ Ok(())
+ }
+
+ #[runtime::test]
+ async fn it_can_select_db() -> Result<(), Error> {
+ let mut conn = Connection::establish(ConnectOptions {
+ host: "127.0.0.1",
+ port: 3306,
+ user: Some("root"),
+ database: None,
+ password: None,
+ }).await?;
+
+ conn.select_db("test").await?;
+
+ Ok(())
+ }
+
+ #[runtime::test]
+ async fn it_can_query() -> Result<(), Error> {
+ let mut conn = Connection::establish(ConnectOptions {
+ host: "127.0.0.1",
+ port: 3306,
+ user: Some("root"),
+ database: None,
+ password: None,
+ }).await?;
+
+ conn.select_db("test").await?;
+
+ conn.query("SELECT * FROM users").await?;
+
+ Ok(())
+ }
+
+ #[runtime::test]
+ async fn it_can_prepare() -> Result<(), Error> {
+ let mut conn = Connection::establish(ConnectOptions {
+ host: "127.0.0.1",
+ port: 3306,
+ user: Some("root"),
+ database: None,
+ password: None,
+ }).await?;
+
+ conn.select_db("test").await?;
+
+ conn.prepare("SELECT * FROM users WHERE username = ?").await?;
+
+ Ok(())
+ }
+
+ #[runtime::test]
+ async fn it_can_execute_prepared() -> Result<(), Error> {
+ let mut conn = Connection::establish(ConnectOptions {
+ host: "127.0.0.1",
+ port: 3306,
+ user: Some("root"),
+ database: None,
+ password: None,
+ }).await?;
+
+ conn.select_db("test").await?;
+
+ let mut prepared = conn.prepare("SELECT username FROM users WHERE username = ?;").await?;
+
+ println!("{:?}", prepared);
+
+ if let Some(param_defs) = &mut prepared.param_defs {
+ param_defs[0].field_type = FieldType::MysqlTypeBlob;
+ }
+
+ let exec = ComStmtExec {
+ stmt_id: -1,
+ flags: StmtExecFlag::ReadOnly,
+ params: Some(vec![Some(Bytes::from_static(b"daniel"))]),
+ param_defs: prepared.param_defs,
+ };
+
+ println!("{:?}", ResultSet::deserialize(DeContext::with_stream(&mut conn.context, &mut conn.stream)).await?);
+
+ let fetch = ComStmtFetch {
+ stmt_id: -1,
+ rows: 10,
+ };
+
+ conn.send(fetch).await?;
+
+ let buf = conn.stream.next_packet().await?;
+
+ println!("{:?}", buf);
+
+// println!("{:?}", ResultSet::deserialize(&mut DeContext::new(&mut conn.context, &buf))?);
+
+ Ok(())
+ }
+
+ #[runtime::test]
+ async fn it_does_not_connect() -> Result<(), Error> {
+ match Connection::establish(ConnectOptions {
+ host: "127.0.0.1",
+ port: 3306,
+ user: Some("roote"),
+ database: None,
+ password: None,
+ })
+ .await
+ {
+ Ok(_) => Err(err_msg("Bad username still worked?")),
+ Err(_) => Ok(()),
+ }
+ }
}
diff --git a/src/mariadb/connection/mod.rs b/src/mariadb/connection/mod.rs
index dd5a9db4..c34c6d80 100644
--- a/src/mariadb/connection/mod.rs
+++ b/src/mariadb/connection/mod.rs
@@ -6,7 +6,8 @@ use futures::{
prelude::*,
};
use runtime::net::TcpStream;
-use crate::{ConnectOptions, mariadb::protocol::{DeContext, Deserialize, Encoder, ComInitDb, ComPing, ComQuery, ComQuit, OkPacket, Serialize, Message, Capabilities, ServerStatusFlag, ComStmtPrepare, ComStmtPrepareResp, ResultSet, ErrPacket}};
+use core::convert::TryFrom;
+use crate::{ConnectOptions, mariadb::{protocol::encode, PacketHeader, Decoder, DeContext, Deserialize, Encoder, ComInitDb, ComPing, ComQuery, ComQuit, OkPacket, Serialize, Message, Capabilities, ServerStatusFlag, ComStmtPrepare, ComStmtPrepareResp, ResultSet, ErrPacket}};
mod establish;
@@ -103,8 +104,8 @@ impl Connection {
pub async fn query<'a>(&'a mut self, sql_statement: &'a str) -> Result