diff --git a/src/mariadb/connection/establish.rs b/src/mariadb/connection/establish.rs index 6a4c262f..1b560c3e 100644 --- a/src/mariadb/connection/establish.rs +++ b/src/mariadb/connection/establish.rs @@ -1,10 +1,5 @@ use super::Connection; -use crate::mariadb::protocol::{ - deserialize::{DeContext, Deserialize}, - packets::{handshake_response::HandshakeResponsePacket, initial::InitialHandshakePacket}, - server::Message as ServerMessage, - types::Capabilities, -}; +use crate::mariadb::{DeContext, Deserialize, HandshakeResponsePacket, InitialHandshakePacket, Message, Capabilities, ComStmtExec, StmtExecFlag}; use bytes::{BufMut, Bytes}; use failure::{err_msg, Error}; use crate::ConnectOptions; @@ -32,12 +27,12 @@ pub async fn establish<'a, 'b: 'a>( conn.send(handshake).await?; match conn.next().await? { - Some(ServerMessage::OkPacket(message)) => { + Some(Message::OkPacket(message)) => { conn.context.seq_no = message.seq_no; Ok(()) } - Some(ServerMessage::ErrPacket(message)) => Err(err_msg(format!("{:?}", message))), + Some(Message::ErrPacket(message)) => Err(err_msg(format!("{:?}", message))), Some(message) => { panic!("Did not receive OkPacket nor ErrPacket. Received: {:?}", message); @@ -51,100 +46,147 @@ pub async fn establish<'a, 'b: 'a>( #[cfg(test)] mod test { - use super::*; - use failure::Error; - - #[runtime::test] - async fn it_can_connect() -> Result<(), Error> { - let mut conn = Connection::establish(ConnectOptions { - host: "127.0.0.1", - port: 3306, - user: Some("root"), - database: None, - password: None, - }) - .await?; - - Ok(()) - } - - #[runtime::test] - async fn it_can_ping() -> Result<(), Error> { - let mut conn = Connection::establish(ConnectOptions { - host: "127.0.0.1", - port: 3306, - user: Some("root"), - database: None, - password: None, - }).await?; - - conn.ping().await?; - - Ok(()) - } - - #[runtime::test] - async fn it_can_select_db() -> Result<(), Error> { - let mut conn = Connection::establish(ConnectOptions { - host: "127.0.0.1", - port: 3306, - user: Some("root"), - database: None, - password: None, - }).await?; - - conn.select_db("test").await?; - - Ok(()) - } - - #[runtime::test] - async fn it_can_query() -> Result<(), Error> { - let mut conn = Connection::establish(ConnectOptions { - host: "127.0.0.1", - port: 3306, - user: Some("root"), - database: None, - password: None, - }).await?; - - conn.select_db("test").await?; - - conn.query("SELECT * FROM users").await?; - - Ok(()) - } - - #[runtime::test] - async fn it_can_prepare() -> Result<(), Error> { - let mut conn = Connection::establish(ConnectOptions { - host: "127.0.0.1", - port: 3306, - user: Some("root"), - database: None, - password: None, - }).await?; - - conn.select_db("test").await?; - - conn.prepare("SELECT * FROM users WHERE username = ?").await?; - - Ok(()) - } - - #[runtime::test] - async fn it_does_not_connect() -> Result<(), Error> { - match Connection::establish(ConnectOptions { - host: "127.0.0.1", - port: 3306, - user: Some("roote"), - database: None, - password: None, - }) - .await - { - Ok(_) => Err(err_msg("Bad username still worked?")), - Err(_) => Ok(()), - } - } +// use super::*; +// use failure::Error; +// use crate::mariadb::{ComStmtPrepareResp, FieldType, ResultSet, ComStmtFetch}; +// +// #[runtime::test] +// async fn it_can_connect() -> Result<(), Error> { +// let mut conn = Connection::establish(ConnectOptions { +// host: "127.0.0.1", +// port: 3306, +// user: Some("root"), +// database: None, +// password: None, +// }) +// .await?; +// +// Ok(()) +// } +// +// #[runtime::test] +// async fn it_can_ping() -> Result<(), Error> { +// let mut conn = Connection::establish(ConnectOptions { +// host: "127.0.0.1", +// port: 3306, +// user: Some("root"), +// database: None, +// password: None, +// }).await?; +// +// conn.ping().await?; +// +// Ok(()) +// } +// +// #[runtime::test] +// async fn it_can_select_db() -> Result<(), Error> { +// let mut conn = Connection::establish(ConnectOptions { +// host: "127.0.0.1", +// port: 3306, +// user: Some("root"), +// database: None, +// password: None, +// }).await?; +// +// conn.select_db("test").await?; +// +// Ok(()) +// } +// +// #[runtime::test] +// async fn it_can_query() -> Result<(), Error> { +// let mut conn = Connection::establish(ConnectOptions { +// host: "127.0.0.1", +// port: 3306, +// user: Some("root"), +// database: None, +// password: None, +// }).await?; +// +// conn.select_db("test").await?; +// +// conn.query("SELECT * FROM users").await?; +// +// Ok(()) +// } +// +// #[runtime::test] +// async fn it_can_prepare() -> Result<(), Error> { +// let mut conn = Connection::establish(ConnectOptions { +// host: "127.0.0.1", +// port: 3306, +// user: Some("root"), +// database: None, +// password: None, +// }).await?; +// +// conn.select_db("test").await?; +// +// conn.prepare("SELECT * FROM users WHERE username = ?").await?; +// +// Ok(()) +// } +// +// #[runtime::test] +// async fn it_can_execute_prepared() -> Result<(), Error> { +// let mut conn = Connection::establish(ConnectOptions { +// host: "127.0.0.1", +// port: 3306, +// user: Some("root"), +// database: None, +// password: None, +// }).await?; +// +// conn.select_db("test").await?; +// +// let mut prepared = conn.prepare("SELECT * FROM users WHERE username = ?;").await?; +// +// println!("{:?}", prepared); +// +// if let Some(param_defs) = &mut prepared.param_defs { +// param_defs[0].field_type = FieldType::MysqlTypeBlob; +// } +// +// let exec = ComStmtExec { +// stmt_id: prepared.ok.stmt_id, +// flags: StmtExecFlag::CursorForUpdate, +// params: Some(vec![Some(Bytes::from_static(b"'daniel'"))]), +// param_defs: prepared.param_defs, +// }; +// +// +// conn.send(exec).await?; +// +// let buf = conn.stream.next_bytes().await?; +// +// println!("{:?}", buf); +// +// println!("{:?}", ResultSet::deserialize(&mut DeContext::new(&mut conn.context, &buf))?.rows); +// +// let fetch = ComStmtFetch { +// stmt_id: prepared.ok.stmt_id, +// rows: 1, +// }; +// +// conn.send(exec).await?; +// +// Ok(()) +// } +// +// #[runtime::test] +// async fn it_does_not_connect() -> Result<(), Error> { +// match Connection::establish(ConnectOptions { +// host: "127.0.0.1", +// port: 3306, +// user: Some("roote"), +// database: None, +// password: None, +// }) +// .await +// { +// Ok(_) => Err(err_msg("Bad username still worked?")), +// Err(_) => Ok(()), +// } +// } } diff --git a/src/mariadb/connection/mod.rs b/src/mariadb/connection/mod.rs index 4d5531dc..dd5a9db4 100644 --- a/src/mariadb/connection/mod.rs +++ b/src/mariadb/connection/mod.rs @@ -151,16 +151,14 @@ impl Connection { Ok(()) } - pub async fn prepare(&mut self, query: &str) -> Result<(), Error> { + pub async fn prepare(&mut self, query: &str) -> Result { self.send(ComStmtPrepare { statement: Bytes::from(query), }).await?; let buf = self.stream.next_bytes().await?; - ComStmtPrepareResp::deserialize(&mut DeContext::new(&mut self.context, &buf))?; - - Ok(()) + ComStmtPrepareResp::deserialize(&mut DeContext::new(&mut self.context, &buf)) } pub async fn next(&mut self) -> Result, Error> { diff --git a/src/mariadb/mod.rs b/src/mariadb/mod.rs index 913f6489..e7e8702a 100644 --- a/src/mariadb/mod.rs +++ b/src/mariadb/mod.rs @@ -47,3 +47,4 @@ pub use protocol::FieldType; pub use protocol::FieldDetailFlag; pub use protocol::SessionChangeType; pub use protocol::StmtExecFlag; +pub use protocol::ComStmtFetch; diff --git a/src/mariadb/protocol/decode.rs b/src/mariadb/protocol/decode.rs index 5a6709cc..7ea7c1d9 100644 --- a/src/mariadb/protocol/decode.rs +++ b/src/mariadb/protocol/decode.rs @@ -76,30 +76,30 @@ impl<'a> Decoder<'a> { // 0xFF then there was an error. // If the first byte is not in the previous list then that byte is the int value. #[inline] - pub fn decode_int_lenenc(&mut self) -> Option { + pub fn decode_int_lenenc(&mut self) -> Option { match self.buf[self.index] { 0xFB => { self.index += 1; None } 0xFC => { - let value = Some(LittleEndian::read_i16(&self.buf[self.index + 1..]) as i64); + let value = Some(LittleEndian::read_i16(&self.buf[self.index + 1..]) as u64); self.index += 3; value } 0xFD => { - let value = Some(LittleEndian::read_i24(&self.buf[self.index + 1..]) as i64); + let value = Some(LittleEndian::read_i24(&self.buf[self.index + 1..]) as u64); self.index += 4; value } 0xFE => { - let value = Some(LittleEndian::read_i64(&self.buf[self.index + 1..]) as i64); + let value = Some(LittleEndian::read_i64(&self.buf[self.index + 1..]) as u64); self.index += 9; value } 0xFF => panic!("int unprocessable first byte 0xFF"), _ => { - let value = Some(self.buf[self.index] as i64); + let value = Some(self.buf[self.index] as u64); self.index += 1; value } @@ -176,8 +176,8 @@ impl<'a> Decoder<'a> { // Decode a string which is a string of fixed length. #[inline] - pub fn decode_string_fix(&mut self, length: u32) -> Bytes { - let value = self.buf.slice(self.index, self.index + length as usize); + pub fn decode_string_fix(&mut self, length: usize) -> Bytes { + let value = self.buf.slice(self.index, self.index + length); self.index = self.index + length as usize; value } @@ -212,8 +212,8 @@ impl<'a> Decoder<'a> { // Same as the string counter part, but copied to maintain consistency with the spec. #[inline] - pub fn decode_byte_fix(&mut self, length: u32) -> Bytes { - let value = self.buf.slice(self.index, self.index + length as usize); + pub fn decode_byte_fix(&mut self, length: usize) -> Bytes { + let value = self.buf.slice(self.index, self.index + length); self.index = self.index + length as usize; value } @@ -242,6 +242,100 @@ impl<'a> Decoder<'a> { self.index = self.buf.len(); value } + + #[inline] + pub fn decode_binary_decimal(&mut self) -> Bytes { + self.decode_string_lenenc() + } + + #[inline] + pub fn decode_binary_double(&mut self) -> Bytes { + let value = self.buf.slice(self.index, self.index + 8); + self.index += 8; + value + } + + #[inline] + pub fn decode_binary_bigint(&mut self) -> Bytes { + let value = self.buf.slice(self.index, self.index + 8); + self.index += 8; + value + } + + #[inline] + pub fn decode_binary_int(&mut self) -> Bytes { + let value = self.buf.slice(self.index, self.index + 4); + self.index += 4; + value + } + + #[inline] + pub fn decode_binary_mediumint(&mut self) -> Bytes { + let value = self.buf.slice(self.index, self.index + 4); + self.index += 4; + value + } + + #[inline] + pub fn decode_binary_float(&mut self) -> Bytes { + let value = self.buf.slice(self.index, self.index + 4); + self.index += 4; + value + } + + #[inline] + pub fn decode_binary_smallint(&mut self) -> Bytes { + let value = self.buf.slice(self.index, self.index + 2); + self.index += 2; + value + } + + #[inline] + pub fn decode_binary_year(&mut self) -> Bytes { + let value = self.buf.slice(self.index, self.index + 2); + self.index += 2; + value + } + + #[inline] + pub fn decode_binary_tinyint(&mut self) -> Bytes { + let value = self.buf.slice(self.index, self.index + 1); + self.index += 1; + value + } + + #[inline] + pub fn decode_binary_date(&mut self) -> Bytes { + let value = self.buf.slice(self.index, self.index + 5); + self.index += 5; + value + } + + #[inline] + pub fn decode_binary_timestamp(&mut self) -> Bytes { + let value = self.buf.slice(self.index, self.index + 12); + self.index += 12; + value + } + + #[inline] + pub fn decode_binary_datetime(&mut self) -> Bytes { + let value = self.buf.slice(self.index, self.index + 12); + self.index += 12; + value + } + + #[inline] + pub fn decode_binary_time(&mut self) -> Bytes { + let length = self.decode_int_u8(); + if length != 8 && length != 12 { + panic!("Date length is not 8 or 12 (the only two possible values)"); + } + let value = self.buf.slice(self.index, self.index + length as usize); + self.index += length as usize; + value + + } } #[cfg(test)] diff --git a/src/mariadb/protocol/deserialize.rs b/src/mariadb/protocol/deserialize.rs index d9a50c9f..c6a3d71a 100644 --- a/src/mariadb/protocol/deserialize.rs +++ b/src/mariadb/protocol/deserialize.rs @@ -1,5 +1,4 @@ -use super::decode::Decoder; -use crate::mariadb::connection::{ConnContext, Connection}; +use crate::mariadb::{Decoder, ConnContext, Connection, ColumnDefPacket}; use bytes::Bytes; use failure::Error; @@ -10,12 +9,13 @@ use failure::Error; pub struct DeContext<'a> { pub ctx: &'a mut ConnContext, pub decoder: Decoder<'a>, - pub columns: Option, + pub columns: Option, + pub column_defs: Option>, } impl<'a> DeContext<'a> { pub fn new(conn: &'a mut ConnContext, buf: &'a Bytes) -> Self { - DeContext { ctx: conn, decoder: Decoder::new(&buf), columns: None } + DeContext { ctx: conn, decoder: Decoder::new(&buf), columns: None , column_defs: None } } } diff --git a/src/mariadb/protocol/packets/binary/mod.rs b/src/mariadb/protocol/packets/binary/mod.rs index 7fb6736f..6702a663 100644 --- a/src/mariadb/protocol/packets/binary/mod.rs +++ b/src/mariadb/protocol/packets/binary/mod.rs @@ -5,6 +5,7 @@ pub mod com_stmt_close; pub mod com_stmt_exec; pub mod com_stmt_fetch; pub mod com_stmt_reset; +pub mod result_row; pub use com_stmt_prepare::ComStmtPrepare; pub use com_stmt_prepare_ok::ComStmtPrepareOk; diff --git a/src/mariadb/protocol/packets/binary/result_row.rs b/src/mariadb/protocol/packets/binary/result_row.rs new file mode 100644 index 00000000..5cad8b85 --- /dev/null +++ b/src/mariadb/protocol/packets/binary/result_row.rs @@ -0,0 +1,41 @@ +use bytes::Bytes; + +#[derive(Debug, Default)] +pub struct ResultRow { + pub columns: Vec> +} + +impl crate::mariadb::Deserialize for ResultRow { + fn deserialize(ctx: &mut crate::mariadb::DeContext) -> Result { + let decoder = &mut ctx.decoder; + + let length = decoder.decode_length()?; + let seq_no = decoder.decode_int_u8(); + + let header = decoder.decode_int_u8(); + + let bitmap = if let Some(columns) = ctx.columns { + let size = (columns + 9) / 8; + decoder.decode_byte_fix(size as usize) + } else { + Bytes::new() + }; + + let row = if let Some(columns) = ctx.columns { + (0..columns).map(|index| { + if (1 << index) & (bitmap[index/8] << (index % 8)) == 0 { + None + } else { + match ctx.column_defs[index] { + + } + decoder.decode_binary_column(&ctx.column_defs) + } + }).collect::>() + } else { + Vec::new() + }; + + Ok(ResultRow::default()) + } +} diff --git a/src/mariadb/protocol/packets/column.rs b/src/mariadb/protocol/packets/column.rs index e728b3e9..763b9617 100644 --- a/src/mariadb/protocol/packets/column.rs +++ b/src/mariadb/protocol/packets/column.rs @@ -9,7 +9,7 @@ use crate::mariadb::{DeContext, Deserialize}; pub struct ColumnPacket { pub length: u32, pub seq_no: u8, - pub columns: Option, + pub columns: Option, } impl Deserialize for ColumnPacket { diff --git a/src/mariadb/protocol/packets/column_def.rs b/src/mariadb/protocol/packets/column_def.rs index 68890ed6..695f287f 100644 --- a/src/mariadb/protocol/packets/column_def.rs +++ b/src/mariadb/protocol/packets/column_def.rs @@ -13,7 +13,7 @@ pub struct ColumnDefPacket { pub table: Bytes, pub column_alias: Bytes, pub column: Bytes, - pub length_of_fixed_fields: Option, + pub length_of_fixed_fields: Option, pub char_set: i16, pub max_columns: i32, pub field_type: FieldType, diff --git a/src/mariadb/protocol/packets/initial.rs b/src/mariadb/protocol/packets/initial.rs index 6868462f..4a9e5d39 100644 --- a/src/mariadb/protocol/packets/initial.rs +++ b/src/mariadb/protocol/packets/initial.rs @@ -67,7 +67,7 @@ impl Deserialize for InitialHandshakePacket { let mut scramble: Option = None; if !(capabilities & Capabilities::SECURE_CONNECTION).is_empty() { let len = std::cmp::max(12, plugin_data_length as usize - 9); - scramble = Some(decoder.decode_string_fix(len as u32)); + scramble = Some(decoder.decode_string_fix(len as usize)); // Skip reserve byte decoder.skip_bytes(1); } diff --git a/src/mariadb/protocol/packets/ok.rs b/src/mariadb/protocol/packets/ok.rs index cdd47a66..f804fa0e 100644 --- a/src/mariadb/protocol/packets/ok.rs +++ b/src/mariadb/protocol/packets/ok.rs @@ -8,8 +8,8 @@ use crate::mariadb::{DeContext, Deserialize, ServerStatusFlag, pub struct OkPacket { pub length: u32, pub seq_no: u8, - pub affected_rows: Option, - pub last_insert_id: Option, + pub affected_rows: Option, + pub last_insert_id: Option, pub server_status: ServerStatusFlag, pub warning_count: i16, pub info: Bytes,