diff --git a/src/mariadb/connection.rs b/src/mariadb/connection.rs index 4cea0c56..10c2aeae 100644 --- a/src/mariadb/connection.rs +++ b/src/mariadb/connection.rs @@ -1,20 +1,63 @@ -use crate::{ - io::{Buf, BufMut, BufStream}, - mariadb::protocol::{ComPing, Encode}, -}; use byteorder::{ByteOrder, LittleEndian}; use std::io; use tokio::net::TcpStream; -use crate::mariadb::protocol::{OkPacket, ErrPacket, Capabilities}; +use crate::mariadb::protocol::{OkPacket, ErrPacket, Capabilities, ComPing, ColumnCountPacket, ColumnDefinitionPacket, EofPacket, ComStmtExecute, StmtExecFlag, ResultRow, ComQuit, ComStmtPrepare, ComStmtPrepareOk}; +use url::Url; +use std::net::{IpAddr, SocketAddr}; +use std::future::Future; +use crate::error::DatabaseError; +use crate::{connection::RawConnection, io::{Buf, BufMut, BufStream}, mariadb::protocol::Encode, Error, Backend}; +use super::establish; +use futures_core::future::BoxFuture; +use futures_core::stream::BoxStream; +use crate::mariadb::{MariaDb, MariaDbRow, MariaDbQueryParameters}; +use crate::Result; -pub struct Connection { - stream: BufStream, - capabilities: Capabilities, +pub struct MariaDbRawConnection { + pub(crate) stream: BufStream, + pub(crate) rbuf: Vec, + pub(crate) capabilities: Capabilities, next_seq_no: u8, } -impl Connection { - pub async fn ping(&mut self) -> crate::Result<()> { +impl MariaDbRawConnection { + async fn establish(url: &str) -> Result { + // TODO: Handle errors + let url = Url::parse(url).unwrap(); + + let host = url.host_str().unwrap_or("127.0.0.1"); + let port = url.port().unwrap_or(3306); + + // TODO: handle errors + let host: IpAddr = host.parse().unwrap(); + let addr: SocketAddr = (host, port).into(); + + let stream = TcpStream::connect(&addr).await?; + + let mut conn = Self { + stream: BufStream::new(stream), + rbuf: Vec::with_capacity(8 * 1024), + capabilities: Capabilities::empty(), + next_seq_no: 0, + }; + + establish::establish(&mut conn, &url).await?; + + Ok(conn) + } + + pub async fn close(&mut self) -> Result<()> { + // Send the quit command + + self.start_sequence(); + self.write(ComQuit); + + self.stream.flush().await?; + + Ok(()) + } + + pub async fn ping(&mut self) -> Result<()> { // Send the ping command and wait for (and drop) an OK packet self.start_sequence(); @@ -27,14 +70,14 @@ impl Connection { Ok(()) } - async fn receive(&mut self) -> crate::Result<&[u8]> { + pub(crate) async fn receive(&mut self) -> Result<&[u8]> { Ok(self .try_receive() .await? - .ok_or(io::ErrorKind::UnexpectedEof)?) + .ok_or(Error::Io(io::ErrorKind::UnexpectedEof.into()))?) } - async fn try_receive(&mut self) -> crate::Result> { + async fn try_receive(&mut self) -> Result> { // Read the packet header which contains the length and the sequence number // https://mariadb.com/kb/en/library/0-packet/#standard-packet let mut header = ret_if_none!(self.stream.peek(4).await?); @@ -62,7 +105,7 @@ impl Connection { self.next_seq_no = 0; } - fn write(&mut self, packet: T) { + pub(crate) fn write(&mut self, packet: T) { let buf = self.stream.buffer_mut(); // Allocate room for the header that we write after the packet; @@ -89,16 +132,24 @@ impl Connection { // Decode an OK packet or bubble an ERR packet as an error // to terminate immediately - async fn receive_ok_or_err(&mut self) -> crate::Result { + pub(crate) async fn receive_ok_or_err(&mut self) -> Result { + let capabilities = self.capabilities; let mut buf = self.receive().await?; Ok(match buf[0] { - 0xfe | 0x00 => OkPacket::decode(buf, self.capabilities)?, + 0xfe | 0x00 => OkPacket::decode(buf, capabilities)?, 0xff => { let err = ErrPacket::decode(buf)?; // TODO: Bubble as Error::Database - panic!("received db err = {:?}", err); +// panic!("received db err = {:?}", err); + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("{:?}", + err + ), + ) + .into()); } id => { @@ -113,4 +164,203 @@ impl Connection { } }) } -} \ No newline at end of file + + // This should not be used by the user. It's mean for `RawConnection` impl + // This assumes the buffer has been set and all it needs is a flush + async fn exec_prepare(&mut self) -> Result { + self.stream.flush().await?; + + // COM_STMT_PREPARE returns COM_STMT_PREPARE_OK (0x00) or ERR (0xFF) + let mut packet = self.receive().await?; + let ok = match packet[0] { + 0xFF => { + let err = ErrPacket::decode(packet)?; + + // TODO: Bubble as Error::Database + panic!("received db err = {:?}", err); + } + + _ => ComStmtPrepareOk::decode(packet)?, + }; + + // Skip decoding Column Definition packets for the result from a prepare statement + for _ in 0..ok.columns { + let _ = self.receive().await?; + } + + if ok.columns > 0 + && !self + .capabilities + .contains(Capabilities::CLIENT_DEPRECATE_EOF) + { + // TODO: Should we do something with the warning indicators here? + let _eof = EofPacket::decode(self.receive().await?)?; + } + + Ok(ok.statement_id) + } + + async fn prepare<'c>(&'c mut self, statement: &'c str) -> Result { + self.stream.flush().await?; + + self.start_sequence(); + self.write(ComStmtPrepare { statement }); + + self.exec_prepare().await + } + + async fn execute(&mut self, statement_id: u32, params: MariaDbQueryParameters) -> Result { + // TODO: EXECUTE(READ_ONLY) => FETCH instead of EXECUTE(NO) + + // SEND ================ + self.start_sequence(); + self.write(ComStmtExecute { + statement_id, + params: &[], + null: &[], + flags: StmtExecFlag::NO_CURSOR, + param_types: &[] + }); + self.stream.flush().await?; + // ===================== + + // Row Counter, used later + let mut rows = 0u64; + let capabilities = self.capabilities; + let has_eof = capabilities + .contains(Capabilities::CLIENT_DEPRECATE_EOF); + + let packet = self.receive().await?; + if packet[0] == 0x00 { + let _ok = OkPacket::decode(packet, capabilities)?; + } else if packet[0] == 0xFF { + let err = ErrPacket::decode(packet)?; + panic!("received db err = {:?}", err); + } else { + // A Resultset starts with a [ColumnCountPacket] which is a single field that encodes + // how many columns we can expect when fetching rows from this statement + let column_count: u64 = ColumnCountPacket::decode(packet)?.columns; + + // Next we have a [ColumnDefinitionPacket] which verbosely explains each minute + // detail about the column in question including table, aliasing, and type + // TODO: This information was *already* returned by PREPARE .., is there a way to suppress generation + let mut columns = vec![]; + for _ in 0..column_count { + columns.push(ColumnDefinitionPacket::decode(self.receive().await?)?); + } + + // When (legacy) EOFs are enabled, the fixed number column definitions are further terminated by + // an EOF packet + if !has_eof { + let _eof = EofPacket::decode(self.receive().await?)?; + } + + // For each row in the result set we will receive a ResultRow packet. + // We may receive an [OkPacket], [EofPacket], or [ErrPacket] (depending on if EOFs are enabled) to finalize the iteration. + loop { + let packet = self.receive().await?; + if packet[0] == 0xFE && packet.len() < 0xFF_FF_FF { + // NOTE: It's possible for a ResultRow to start with 0xFE (which would normally signify end-of-rows) + // but it's not possible for an Ok/Eof to be larger than 0xFF_FF_FF. + if !has_eof { + let _eof = EofPacket::decode(packet)?; + } else { + let _ok = OkPacket::decode(packet, capabilities)?; + } + + break; + } else if packet[0] == 0xFF { + let err = ErrPacket::decode(packet)?; + panic!("received db err = {:?}", err); + } else { + // Ignore result rows; exec only returns number of affected rows; + let _ = ResultRow::decode(packet, &columns)?; + + // For every row we decode we increment counter + rows = rows + 1; + } + } + } + + Ok(rows) + } +} + +enum ExecResult { + NoRows(OkPacket), + Rows(Vec), +} + +impl RawConnection for MariaDbRawConnection { + type Backend = MariaDb; + + fn establish(url: &str) -> BoxFuture> + where + Self: Sized { + Box::pin(MariaDbRawConnection::establish(url)) + } + + fn close(&mut self) -> BoxFuture<'_, Result<()>> { + Box::pin(self.close()) + } + + fn ping(&mut self) -> BoxFuture<'_, Result<()>> { + Box::pin(self.ping()) + } + + fn execute<'c>( + &'c mut self, + query: &str, + params: MariaDbQueryParameters, + ) -> BoxFuture<'c, Result> { + // Write prepare statement to buffer + self.start_sequence(); + self.write(ComStmtPrepare { + statement: query + }); + + Box::pin(async move { + let statement_id = self.exec_prepare().await?; + + let affected = self.execute(statement_id, params).await?; + + Ok(affected) + }) + } + + fn fetch<'c>( + &'c mut self, + query: &str, + params: MariaDbQueryParameters, + ) -> BoxStream<'c, Result> { + unimplemented!(); + } + + fn fetch_optional<'c>( + &'c mut self, + query: &str, + params: MariaDbQueryParameters, + ) -> BoxFuture<'c, Result::Row>>> { + unimplemented!(); + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::Error; + + #[tokio::test] + async fn it_can_connect() -> Result<()> { + MariaDbRawConnection::establish("mariadb://root@127.0.0.1:3306/test").await?; + Ok(()) + } + + #[tokio::test] + async fn it_fails_to_connect_with_bad_username() -> Result<()> { + match MariaDbRawConnection::establish("mariadb://roote@127.0.0.1:3306/test").await { + Ok(_) => panic!("Somehow connected to database with incorrect username"), + Err(_) => Ok(()) + } + } +} diff --git a/src/mariadb/establish.rs b/src/mariadb/establish.rs new file mode 100644 index 00000000..51510dc3 --- /dev/null +++ b/src/mariadb/establish.rs @@ -0,0 +1,42 @@ +use crate::Result; +use url::Url; +use crate::mariadb::protocol::{HandshakeResponsePacket, InitialHandshakePacket, Encode, Capabilities}; +use crate::mariadb::connection::MariaDbRawConnection; + +pub(crate) async fn establish(conn: &mut MariaDbRawConnection, url: &Url) -> Result<()> { + let initial = InitialHandshakePacket::decode(conn.receive().await?)?; + + // TODO: Capabilities::SECURE_CONNECTION + // TODO: Capabilities::CONNECT_ATTRS + // TODO: Capabilities::PLUGIN_AUTH + // TODO: Capabilities::PLUGIN_AUTH_LENENC_CLIENT_DATA + // TODO: Capabilities::TRANSACTIONS + // TODO: Capabilities::CLIENT_DEPRECATE_EOF + // TODO?: Capabilities::CLIENT_SESSION_TRACK + let mut capabilities = Capabilities::CLIENT_PROTOCOL_41 | Capabilities::CONNECT_WITH_DB; + + let response = HandshakeResponsePacket { + // TODO: Find a good value for [max_packet_size] + capabilities, + max_packet_size: 1024, + client_collation: 192, // utf8_unicode_ci + username: url.username(), + database: &url.path()[1..], + auth_data: None, + auth_plugin_name: None, + connection_attrs: &[] + }; + + // The AND between our supported capabilities and the servers' is + // what we can use so remember it on the connection + conn.capabilities = capabilities & initial.capabilities; + + conn.write(response); + conn.stream.flush().await?; + + let _ = conn.receive_ok_or_err().await?; + + // TODO: If CONNECT_WITH_DB is not supported we need to send an InitDb command just after establish + + Ok(()) +} diff --git a/src/mariadb/mod.rs b/src/mariadb/mod.rs index 63fa5702..75bb16d1 100644 --- a/src/mariadb/mod.rs +++ b/src/mariadb/mod.rs @@ -1,19 +1,21 @@ // TODO: Remove after acitve development #![allow(unused)] -// mod backend; +mod row; +mod backend; mod connection; +mod establish; mod io; mod protocol; -// mod query; +mod query; pub mod types; -//pub use self::{ -// backend::MariaDb, -// connection.bak::MariaDbRawConnection, -// query::MariaDbQueryParameters, -// row::MariaDbRow, -//}; +pub use self::{ + backend::MariaDb, + connection::MariaDbRawConnection, + query::MariaDbQueryParameters, + row::MariaDbRow, +}; // pub use io::{BufExt, BufMutExt}; // pub use protocol::{ diff --git a/src/mariadb/protocol/binary/com_stmt_exec.rs b/src/mariadb/protocol/binary/com_stmt_exec.rs index 24537482..8bb43576 100644 --- a/src/mariadb/protocol/binary/com_stmt_exec.rs +++ b/src/mariadb/protocol/binary/com_stmt_exec.rs @@ -21,7 +21,7 @@ bitflags::bitflags! { // https://mariadb.com/kb/en/library/com_stmt_execute /// Executes a previously prepared statement. #[derive(Debug)] -pub struct ComStmtExec<'a> { +pub struct ComStmtExecute<'a> { pub statement_id: u32, pub flags: StmtExecFlag, pub params: &'a [u8], @@ -29,7 +29,7 @@ pub struct ComStmtExec<'a> { pub param_types: &'a [MariaDbTypeMetadata], } -impl Encode for ComStmtExec<'_> { +impl Encode for ComStmtExecute<'_> { fn encode(&self, buf: &mut Vec, _: Capabilities) { // COM_STMT_EXECUTE : int<1> buf.put_u8(BinaryProtocol::ComStmtExec as u8); @@ -75,7 +75,7 @@ mod tests { fn it_encodes_com_stmt_exec() { let mut buf = Vec::new(); - ComStmtExec { + ComStmtExecute { statement_id: 1, flags: StmtExecFlag::NO_CURSOR, null: &vec![], diff --git a/src/mariadb/protocol/binary/com_stmt_prepare_ok.rs b/src/mariadb/protocol/binary/com_stmt_prepare_ok.rs index 7d86f55b..5a07dd71 100644 --- a/src/mariadb/protocol/binary/com_stmt_prepare_ok.rs +++ b/src/mariadb/protocol/binary/com_stmt_prepare_ok.rs @@ -5,7 +5,7 @@ use std::io; // https://mariadb.com/kb/en/library/com_stmt_prepare/#com_stmt_prepare_ok #[derive(Debug)] pub struct ComStmtPrepareOk { - pub statement_id: i32, + pub statement_id: u32, /// Number of columns in the returned result set (or 0 if statement does not return result set). pub columns: u16, @@ -18,7 +18,7 @@ pub struct ComStmtPrepareOk { } impl ComStmtPrepareOk { - fn decode(mut buf: &[u8]) -> io::Result { + pub(crate) fn decode(mut buf: &[u8]) -> io::Result { let header = buf.get_u8()?; if header != 0x00 { @@ -28,7 +28,7 @@ impl ComStmtPrepareOk { )); } - let statement_id = buf.get_i32::()?; + let statement_id = buf.get_u32::()?; let columns = buf.get_u16::()?; let params = buf.get_u16::()?; diff --git a/src/mariadb/protocol/binary/mod.rs b/src/mariadb/protocol/binary/mod.rs index 56210f98..6c3e6bec 100644 --- a/src/mariadb/protocol/binary/mod.rs +++ b/src/mariadb/protocol/binary/mod.rs @@ -6,7 +6,7 @@ pub mod com_stmt_prepare_ok; pub mod com_stmt_reset; pub use com_stmt_close::ComStmtClose; -pub use com_stmt_exec::ComStmtExec; +pub use com_stmt_exec::{ComStmtExecute, StmtExecFlag}; pub use com_stmt_fetch::ComStmtFetch; pub use com_stmt_prepare::ComStmtPrepare; pub use com_stmt_prepare_ok::ComStmtPrepareOk; diff --git a/src/mariadb/protocol/connect/initial.rs b/src/mariadb/protocol/connect/initial.rs index 71a36277..b93426d3 100644 --- a/src/mariadb/protocol/connect/initial.rs +++ b/src/mariadb/protocol/connect/initial.rs @@ -21,7 +21,7 @@ pub struct InitialHandshakePacket { } impl InitialHandshakePacket { - fn decode(mut buf: &[u8]) -> io::Result { + pub(crate) fn decode(mut buf: &[u8]) -> io::Result { let protocol_version = buf.get_u8()?; let server_version = buf.get_str_nul()?.to_owned(); let connection_id = buf.get_u32::()?; diff --git a/src/mariadb/protocol/mod.rs b/src/mariadb/protocol/mod.rs index bebd502f..35ff24e2 100644 --- a/src/mariadb/protocol/mod.rs +++ b/src/mariadb/protocol/mod.rs @@ -12,7 +12,7 @@ mod server_status; mod text; pub use binary::{ - ComStmtClose, ComStmtExec, ComStmtFetch, ComStmtPrepare, ComStmtPrepareOk, ComStmtReset, + ComStmtClose, ComStmtExecute, StmtExecFlag, ComStmtFetch, ComStmtPrepare, ComStmtPrepareOk, ComStmtReset, }; pub use capabilities::Capabilities; pub use connect::{ diff --git a/src/mariadb/protocol/response/column_count.rs b/src/mariadb/protocol/response/column_count.rs index 43112873..c4cdd1b2 100644 --- a/src/mariadb/protocol/response/column_count.rs +++ b/src/mariadb/protocol/response/column_count.rs @@ -12,7 +12,7 @@ pub struct ColumnCountPacket { } impl ColumnCountPacket { - fn decode(mut buf: &[u8]) -> io::Result { + pub(crate) fn decode(mut buf: &[u8]) -> io::Result { let columns = buf.get_uint_lenenc::()?.unwrap_or(0); Ok(Self { columns }) diff --git a/src/mariadb/protocol/response/column_def.rs b/src/mariadb/protocol/response/column_def.rs index 30f3c9ed..9a6ac64a 100644 --- a/src/mariadb/protocol/response/column_def.rs +++ b/src/mariadb/protocol/response/column_def.rs @@ -25,7 +25,7 @@ pub struct ColumnDefinitionPacket { } impl ColumnDefinitionPacket { - fn decode(mut buf: &[u8]) -> io::Result { + pub(crate) fn decode(mut buf: &[u8]) -> io::Result { // string catalog (always 'def') let _catalog = buf.get_str_lenenc::()?; // TODO: Assert that this is always DEF diff --git a/src/mariadb/protocol/response/eof.rs b/src/mariadb/protocol/response/eof.rs index 2f515e97..34b90dc2 100644 --- a/src/mariadb/protocol/response/eof.rs +++ b/src/mariadb/protocol/response/eof.rs @@ -15,7 +15,7 @@ pub struct EofPacket { } impl EofPacket { - fn decode(mut buf: &[u8]) -> io::Result { + pub(crate) fn decode(mut buf: &[u8]) -> io::Result { let header = buf.get_u8()?; if header != 0xFE { return Err(io::Error::new( diff --git a/src/mariadb/query.rs b/src/mariadb/query.rs index 76946f0b..d82065be 100644 --- a/src/mariadb/query.rs +++ b/src/mariadb/query.rs @@ -4,6 +4,7 @@ use crate::{ serialize::{IsNull, ToSql}, types::HasSqlType, }; +use crate::mariadb::types::MariaDbTypeMetadata; pub struct MariaDbQueryParameters { param_types: Vec, diff --git a/src/mariadb/row.rs b/src/mariadb/row.rs new file mode 100644 index 00000000..660876ab --- /dev/null +++ b/src/mariadb/row.rs @@ -0,0 +1,26 @@ +use crate::row::Row; +use crate::mariadb::protocol::ResultRow; +use crate::mariadb::MariaDb; + +pub struct MariaDbRow(pub(super) ResultRow); + +impl Row for MariaDbRow { + type Backend = MariaDb; + + #[inline] + fn is_empty(&self) -> bool { + self.0.values.is_empty() + } + + #[inline] + fn len(&self) -> usize { + self.0.values.len() + } + + #[inline] + fn get_raw(&self, index: usize) -> Option<&[u8]> { + self.0.values[index] + .as_ref() + .map(|value| unsafe { value.as_ref() }) + } +} diff --git a/src/mariadb/types/mod.rs b/src/mariadb/types/mod.rs index 3e17fb54..05073696 100644 --- a/src/mariadb/types/mod.rs +++ b/src/mariadb/types/mod.rs @@ -1,5 +1,6 @@ use super::protocol::{FieldType, ParameterFlag}; use crate::types::TypeMetadata; +use crate::mariadb::MariaDb; #[derive(Debug)] pub struct MariaDbTypeMetadata { @@ -7,6 +8,6 @@ pub struct MariaDbTypeMetadata { pub param_flag: ParameterFlag, } -//impl TypeMetadata for MariaDb { -// type TypeMetadata = MariaDbTypeMetadata; -//} +impl TypeMetadata for MariaDb { + type TypeMetadata = MariaDbTypeMetadata; +}