diff --git a/Cargo.toml b/Cargo.toml index 1d57c3f0..a26da4dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ itoa = "0.4.4" log = "0.4.7" md-5 = "0.8.0" memchr = "2.2.1" -runtime = { version = "=0.3.0-alpha.6", default-features = false } +runtime = { version = "=0.3.0-alpha.6", default-features = true } bitflags = "1.1.0" enum-tryfrom = "0.2.1" enum-tryfrom-derive = "0.2.1" diff --git a/src/mariadb/connection/establish.rs b/src/mariadb/connection/establish.rs index 873fc33e..6a4c262f 100644 --- a/src/mariadb/connection/establish.rs +++ b/src/mariadb/connection/establish.rs @@ -8,6 +8,7 @@ use crate::mariadb::protocol::{ use bytes::{BufMut, Bytes}; use failure::{err_msg, Error}; use crate::ConnectOptions; +use std::ops::BitAnd; pub async fn establish<'a, 'b: 'a>( conn: &'a mut Connection, @@ -15,11 +16,13 @@ pub async fn establish<'a, 'b: 'a>( ) -> Result<(), Error> { let buf = &conn.stream.next_bytes().await?; let mut de_ctx = DeContext::new(&mut conn.context, &buf); - let _ = InitialHandshakePacket::deserialize(&mut de_ctx)?; + let initial = InitialHandshakePacket::deserialize(&mut de_ctx)?; + + de_ctx.ctx.capabilities = de_ctx.ctx.capabilities.bitand(initial.capabilities); let handshake: HandshakeResponsePacket = HandshakeResponsePacket { // Minimum client capabilities required to establish connection - capabilities: Capabilities::CLIENT_PROTOCOL_41, + capabilities: de_ctx.ctx.capabilities, max_packet_size: 1024, extended_capabilities: Some(Capabilities::from_bits_truncate(0)), username: Bytes::from(options.user.unwrap_or("")), @@ -51,35 +54,97 @@ mod test { use super::*; use failure::Error; -// #[runtime::test] -// async fn it_connects() -> Result<(), Error> { -// let mut conn = Connection::establish(ConnectOptions { -// host: "127.0.0.1", -// port: 3306, -// user: Some("root"), -// database: None, -// password: None, -// }) -// .await?; -// -// conn.ping().await?; -// -// Ok(()) -// } -// -// #[runtime::test] -// async fn it_does_not_connect() -> Result<(), Error> { -// match Connection::establish(ConnectOptions { -// host: "127.0.0.1", -// port: 3306, -// user: Some("roote"), -// database: None, -// password: None, -// }) -// .await -// { -// Ok(_) => Err(err_msg("Bad username still worked?")), -// Err(_) => Ok(()), -// } -// } + #[runtime::test] + async fn it_can_connect() -> Result<(), Error> { + let mut conn = Connection::establish(ConnectOptions { + host: "127.0.0.1", + port: 3306, + user: Some("root"), + database: None, + password: None, + }) + .await?; + + Ok(()) + } + + #[runtime::test] + async fn it_can_ping() -> Result<(), Error> { + let mut conn = Connection::establish(ConnectOptions { + host: "127.0.0.1", + port: 3306, + user: Some("root"), + database: None, + password: None, + }).await?; + + conn.ping().await?; + + Ok(()) + } + + #[runtime::test] + async fn it_can_select_db() -> Result<(), Error> { + let mut conn = Connection::establish(ConnectOptions { + host: "127.0.0.1", + port: 3306, + user: Some("root"), + database: None, + password: None, + }).await?; + + conn.select_db("test").await?; + + Ok(()) + } + + #[runtime::test] + async fn it_can_query() -> Result<(), Error> { + let mut conn = Connection::establish(ConnectOptions { + host: "127.0.0.1", + port: 3306, + user: Some("root"), + database: None, + password: None, + }).await?; + + conn.select_db("test").await?; + + conn.query("SELECT * FROM users").await?; + + Ok(()) + } + + #[runtime::test] + async fn it_can_prepare() -> Result<(), Error> { + let mut conn = Connection::establish(ConnectOptions { + host: "127.0.0.1", + port: 3306, + user: Some("root"), + database: None, + password: None, + }).await?; + + conn.select_db("test").await?; + + conn.prepare("SELECT * FROM users WHERE username = ?").await?; + + Ok(()) + } + + #[runtime::test] + async fn it_does_not_connect() -> Result<(), Error> { + match Connection::establish(ConnectOptions { + host: "127.0.0.1", + port: 3306, + user: Some("roote"), + database: None, + password: None, + }) + .await + { + Ok(_) => Err(err_msg("Bad username still worked?")), + Err(_) => Ok(()), + } + } } diff --git a/src/mariadb/connection/mod.rs b/src/mariadb/connection/mod.rs index afcb7e5c..4d5531dc 100644 --- a/src/mariadb/connection/mod.rs +++ b/src/mariadb/connection/mod.rs @@ -6,7 +6,7 @@ use futures::{ prelude::*, }; use runtime::net::TcpStream; -use crate::{ConnectOptions, mariadb::protocol::{DeContext, Deserialize, Encoder, ComInitDb, ComPing, ComQuery, ComQuit, OkPacket, Serialize, Message, Capabilities, ServerStatusFlag}}; +use crate::{ConnectOptions, mariadb::protocol::{DeContext, Deserialize, Encoder, ComInitDb, ComPing, ComQuery, ComQuit, OkPacket, Serialize, Message, Capabilities, ServerStatusFlag, ComStmtPrepare, ComStmtPrepareResp, ResultSet, ErrPacket}}; mod establish; @@ -46,26 +46,18 @@ impl ConnContext { connection_id: 0, seq_no: 2, last_seq_no: 0, - capabilities: Capabilities::FOUND_ROWS - | Capabilities::CONNECT_WITH_DB - | Capabilities::COMPRESS - | Capabilities::LOCAL_FILES - | Capabilities::IGNORE_SPACE - | Capabilities::CLIENT_PROTOCOL_41 - | Capabilities::CLIENT_INTERACTIVE - | Capabilities::TRANSACTIONS - | Capabilities::SECURE_CONNECTION - | Capabilities::MULTI_STATEMENTS - | Capabilities::MULTI_RESULTS - | Capabilities::PS_MULTI_RESULTS - | Capabilities::PLUGIN_AUTH - | Capabilities::CONNECT_ATTRS - | Capabilities::PLUGIN_AUTH_LENENC_CLIENT_DATA - | Capabilities::CLIENT_SESSION_TRACK - | Capabilities::CLIENT_DEPRECATE_EOF - | Capabilities::MARIA_DB_CLIENT_PROGRESS - | Capabilities::MARIA_DB_CLIENT_COM_MULTI - | Capabilities::MARIA_CLIENT_STMT_BULK_OPERATIONS, + capabilities: Capabilities::CLIENT_PROTOCOL_41, + status: ServerStatusFlag::SERVER_STATUS_IN_TRANS + } + } + + #[cfg(test)] + pub fn with_eof() -> Self { + ConnContext { + connection_id: 0, + seq_no: 2, + last_seq_no: 0, + capabilities: Capabilities::CLIENT_PROTOCOL_41 | Capabilities::CLIENT_DEPRECATE_EOF, status: ServerStatusFlag::SERVER_STATUS_IN_TRANS } } @@ -81,7 +73,7 @@ impl Connection { connection_id: -1, seq_no: 1, last_seq_no: 0, - capabilities: Capabilities::default(), + capabilities: Capabilities::CLIENT_PROTOCOL_41, status: ServerStatusFlag::default(), }, }; @@ -103,22 +95,33 @@ impl Connection { } pub async fn quit(&mut self) -> Result<(), Error> { - self.context.seq_no = 0; self.send(ComQuit()).await?; Ok(()) } - pub async fn query<'a>(&'a mut self, sql_statement: &'a str) -> Result<(), Error> { - self.context.seq_no = 0; + pub async fn query<'a>(&'a mut self, sql_statement: &'a str) -> Result, Error> { self.send(ComQuery { sql_statement: bytes::Bytes::from(sql_statement) }).await?; - Ok(()) + let buf = self.stream.next_bytes().await?; + let mut ctx = DeContext::new(&mut self.context, &buf); + if let Some(tag) = ctx.decoder.peek_tag() { + match tag { + 0xFF => Err(ErrPacket::deserialize(&mut ctx)?.into()), + 0x00 => { + OkPacket::deserialize(&mut ctx)?; + Ok(None) + }, + 0xFB => unimplemented!(), + _ => Ok(Some(ResultSet::deserialize(&mut ctx)?)) + } + } else { + panic!("Tag not found in result packet"); + } } pub async fn select_db<'a>(&'a mut self, db: &'a str) -> Result<(), Error> { - self.context.seq_no = 0; self.send(ComInitDb { schema_name: bytes::Bytes::from(db) }).await?; @@ -139,7 +142,6 @@ impl Connection { } pub async fn ping(&mut self) -> Result<(), Error> { - self.context.seq_no = 0; self.send(ComPing()).await?; // Ping response must be an OkPacket @@ -149,6 +151,18 @@ impl Connection { Ok(()) } + pub async fn prepare(&mut self, query: &str) -> Result<(), Error> { + self.send(ComStmtPrepare { + statement: Bytes::from(query), + }).await?; + + let buf = self.stream.next_bytes().await?; + + ComStmtPrepareResp::deserialize(&mut DeContext::new(&mut self.context, &buf))?; + + Ok(()) + } + pub async fn next(&mut self) -> Result, Error> { let mut rbuf = BytesMut::new(); let mut len = 0; diff --git a/src/mariadb/protocol/deserialize.rs b/src/mariadb/protocol/deserialize.rs index 1ddb0bd6..d9a50c9f 100644 --- a/src/mariadb/protocol/deserialize.rs +++ b/src/mariadb/protocol/deserialize.rs @@ -8,14 +8,14 @@ use failure::Error; // access to the connection context. // Mainly used to simply to simplify number of parameters for deserializing functions pub struct DeContext<'a> { - pub conn: &'a mut ConnContext, + pub ctx: &'a mut ConnContext, pub decoder: Decoder<'a>, pub columns: Option, } impl<'a> DeContext<'a> { pub fn new(conn: &'a mut ConnContext, buf: &'a Bytes) -> Self { - DeContext { conn, decoder: Decoder::new(&buf), columns: None } + DeContext { ctx: conn, decoder: Decoder::new(&buf), columns: None } } } diff --git a/src/mariadb/protocol/packets/binary/com_stmt_prepare.rs b/src/mariadb/protocol/packets/binary/com_stmt_prepare.rs index 3aa47ae5..ca041261 100644 --- a/src/mariadb/protocol/packets/binary/com_stmt_prepare.rs +++ b/src/mariadb/protocol/packets/binary/com_stmt_prepare.rs @@ -2,7 +2,7 @@ use bytes::Bytes; #[derive(Debug)] pub struct ComStmtPrepare { - statement: Bytes + pub statement: Bytes } impl crate::mariadb::Serialize for ComStmtPrepare { diff --git a/src/mariadb/protocol/packets/binary/com_stmt_prepare_resp.rs b/src/mariadb/protocol/packets/binary/com_stmt_prepare_resp.rs index 90231379..31537b39 100644 --- a/src/mariadb/protocol/packets/binary/com_stmt_prepare_resp.rs +++ b/src/mariadb/protocol/packets/binary/com_stmt_prepare_resp.rs @@ -17,7 +17,7 @@ impl crate::mariadb::Deserialize for ComStmtPrepareResp { .map(Result::unwrap) .collect::>(); - if !ctx.conn.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) { + if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) { EofPacket::deserialize(ctx)?; } @@ -32,7 +32,7 @@ impl crate::mariadb::Deserialize for ComStmtPrepareResp { .map(Result::unwrap) .collect::>(); - if !ctx.conn.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) { + if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) { EofPacket::deserialize(ctx)?; } @@ -55,7 +55,7 @@ mod test { use crate::{__bytes_builder, ConnectOptions, mariadb::{ConnContext, DeContext, Deserialize}}; #[test] - fn it_decodes_com_stmt_prepare_resp() -> Result<(), failure::Error> { + fn it_decodes_com_stmt_prepare_resp_eof() -> Result<(), failure::Error> { #[rustfmt::skip] let buf = __bytes_builder!( // ---------------------------- // @@ -152,6 +152,140 @@ mod test { 0u8, 0u8 ); + let mut context = ConnContext::with_eof(); + let mut ctx = DeContext::new(&mut context, &buf); + + let message = ComStmtPrepareResp::deserialize(&mut ctx)?; + + Ok(()) + } + + #[test] + fn it_decodes_com_stmt_prepare_resp() -> Result<(), failure::Error> { + #[rustfmt::skip] + let buf = __bytes_builder!( + // ---------------------------- // + // Statement Prepared Ok Packet // + // ---------------------------- // + + // int<3> length + 0u8, 0u8, 0u8, + // int<1> seq_no + 0u8, + // int<1> 0x00 COM_STMT_PREPARE_OK header + 0u8, + // int<4> statement id + 1u8, 0u8, 0u8, 0u8, + // int<2> number of columns in the returned result set (or 0 if statement does not return result set) + 1u8, 0u8, + // int<2> number of prepared statement parameters ('?' placeholders) + 1u8, 0u8, + // string<1> -not used- + 0u8, + // int<2> number of warnings + 0u8, 0u8, + + // Param column definition + + // ------------------------ // + // Column Definition packet // + // ------------------------ // + // int<3> length + 52u8, 0u8, 0u8, + // int<1> seq_no + 3u8, + // string catalog (always 'def') + 3u8, b"def", + // string schema + 4u8, b"test", + // string table alias + 5u8, b"users", + // string table + 5u8, b"users", + // string column alias + 8u8, b"username", + // string column + 8u8, b"username", + // int length of fixed fields (=0xC) + 0x0C_u8, + // int<2> character set number + 8u8, 0u8, + // int<4> max. column size + 0xFF_u8, 0xFF_u8, 0u8, 0u8, + // int<1> Field types + 0xFC_u8, + // int<2> Field detail flag + 0x11_u8, 0x10_u8, + // int<1> decimals + 0u8, + // int<2> - unused - + 0u8, 0u8, + + // ---------- // + // EOF Packet // + // ---------- // + // int<3> length + 5u8, 0u8, 0u8, + // int<1> seq_no + 6u8, + // int<1> 0xfe : EOF header + 0xFE_u8, + // int<2> warning count + 0u8, 0u8, + // int<2> server status + 34u8, 0u8, + + // Result column definitions + + // ------------------------ // + // Column Definition packet // + // ------------------------ // + // int<3> length + 52u8, 0u8, 0u8, + // int<1> seq_no + 3u8, + // string catalog (always 'def') + 3u8, b"def", + // string schema + 4u8, b"test", + // string table alias + 5u8, b"users", + // string table + 5u8, b"users", + // string column alias + 8u8, b"username", + // string column + 8u8, b"username", + // int length of fixed fields (=0xC) + 0x0C_u8, + // int<2> character set number + 8u8, 0u8, + // int<4> max. column size + 0xFF_u8, 0xFF_u8, 0u8, 0u8, + // int<1> Field types + 0xFC_u8, + // int<2> Field detail flag + 0x11_u8, 0x10_u8, + // int<1> decimals + 0u8, + // int<2> - unused - + 0u8, 0u8, + + // ---------- // + // EOF Packet // + // ---------- // + // int<3> length + 5u8, 0u8, 0u8, + // int<1> seq_no + 6u8, + // int<1> 0xfe : EOF header + 0xFE_u8, + // int<2> warning count + 0u8, 0u8, + // int<2> server status + 34u8, 0u8 + ); + let mut context = ConnContext::new(); let mut ctx = DeContext::new(&mut context, &buf); diff --git a/src/mariadb/protocol/packets/handshake_response.rs b/src/mariadb/protocol/packets/handshake_response.rs index 96db1851..29ff1f47 100644 --- a/src/mariadb/protocol/packets/handshake_response.rs +++ b/src/mariadb/protocol/packets/handshake_response.rs @@ -21,7 +21,7 @@ pub struct HandshakeResponsePacket { impl Serialize for HandshakeResponsePacket { fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> { encoder.alloc_packet_header(); - encoder.seq_no(0); + encoder.seq_no(1); encoder.encode_int_u32(self.capabilities.bits() as u32); encoder.encode_int_u32(self.max_packet_size); diff --git a/src/mariadb/protocol/packets/initial.rs b/src/mariadb/protocol/packets/initial.rs index 95aade5a..6868462f 100644 --- a/src/mariadb/protocol/packets/initial.rs +++ b/src/mariadb/protocol/packets/initial.rs @@ -77,8 +77,7 @@ impl Deserialize for InitialHandshakePacket { auth_plugin_name = Some(decoder.decode_string_null()?); } - ctx.conn.capabilities = capabilities; - ctx.conn.last_seq_no = seq_no; + ctx.ctx.last_seq_no = seq_no; Ok(InitialHandshakePacket { length, diff --git a/src/mariadb/protocol/packets/result_set.rs b/src/mariadb/protocol/packets/result_set.rs index 07565b91..e78a8018 100644 --- a/src/mariadb/protocol/packets/result_set.rs +++ b/src/mariadb/protocol/packets/result_set.rs @@ -31,7 +31,7 @@ impl Deserialize for ResultSet { Vec::new() }; - let eof_packet = if !ctx.conn.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) { + let eof_packet = if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) { Some(EofPacket::deserialize(ctx)?) } else { None @@ -62,7 +62,7 @@ impl Deserialize for ResultSet { } } - if ctx.conn.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) { + if ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) { OkPacket::deserialize(ctx)?; } else { EofPacket::deserialize(ctx)?; diff --git a/src/mariadb/protocol/types.rs b/src/mariadb/protocol/types.rs index 896dd84e..87f6418d 100644 --- a/src/mariadb/protocol/types.rs +++ b/src/mariadb/protocol/types.rs @@ -128,7 +128,7 @@ pub enum ParamFlag { impl Default for Capabilities { fn default() -> Self { - Capabilities::CLIENT_MYSQL + Capabilities::CLIENT_PROTOCOL_41 } }