diff --git a/src/mariadb/connection.bak/establish.rs b/src/mariadb/connection.bak/establish.rs deleted file mode 100644 index 032ff491..00000000 --- a/src/mariadb/connection.bak/establish.rs +++ /dev/null @@ -1,147 +0,0 @@ -use super::MariaDbRawConnection; -use crate::mariadb::protocol::{ - Capabilities, ComStmtExec, DeContext, Decode, EofPacket, ErrPacket, HandshakeResponsePacket, - InitialHandshakePacket, OkPacket, ProtocolType, StmtExecFlag, -}; -use bytes::Bytes; -use failure::{err_msg, Error}; -use std::ops::BitAnd; -use url::Url; - -pub async fn establish(conn: &mut MariaDbRawConnection, url: Url) -> Result<(), Error> { - let buf = conn.stream.next_packet().await?; - let mut de_ctx = DeContext::new(&mut conn.context, buf); - let initial = InitialHandshakePacket::decode(&mut de_ctx)?; - - de_ctx.ctx.capabilities = de_ctx.ctx.capabilities.bitand(initial.capabilities); - - let handshake: HandshakeResponsePacket = HandshakeResponsePacket { - // Minimum client capabilities required to establish connection.bak - capabilities: de_ctx.ctx.capabilities, - max_packet_size: 1024, - extended_capabilities: Some(Capabilities::from_bits_truncate(0)), - username: url.username(), - ..Default::default() - }; - - conn.send(handshake).await?; - - let mut ctx = DeContext::new(&mut conn.context, conn.stream.next_packet().await?); - - match ctx.decoder.peek_tag() { - 0xFF => { - return Err(ErrPacket::decode(&mut ctx)?.into()); - } - 0x00 => { - OkPacket::decode(&mut ctx)?; - } - _ => failure::bail!("Did not receive an ErrPacket nor OkPacket when one is expected"), - } - - Ok(()) -} - -#[cfg(test)] -mod test { - use super::*; - use crate::mariadb::{ComStmtFetch, ComStmtPrepareResp, FieldType, ResultSet}; - use failure::Error; - - #[tokio::test] - async fn it_can_connect() -> Result<(), Error> { - let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306").await?; - - Ok(()) - } - - #[tokio::test] - async fn it_can_ping() -> Result<(), Error> { - let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306").await?; - - conn.ping().await?; - - Ok(()) - } - - #[tokio::test] - async fn it_can_select_db() -> Result<(), Error> { - let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306").await?; - - conn.select_db("test").await?; - - Ok(()) - } - - #[tokio::test] - async fn it_can_query() -> Result<(), Error> { - let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306").await?; - - conn.select_db("test").await?; - - conn.query("SELECT * FROM users").await?; - - Ok(()) - } - - #[tokio::test] - async fn it_can_prepare() -> Result<(), Error> { - let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306").await?; - - conn.select_db("test").await?; - - conn.prepare("SELECT * FROM users WHERE username = ?") - .await?; - - Ok(()) - } - - #[tokio::test] - async fn it_can_execute_prepared() -> Result<(), Error> { - let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306").await?; - - conn.select_db("test").await?; - - let mut prepared = conn - .prepare("SELECT id FROM users WHERE username=?") - .await?; - - let exec = ComStmtExec { - stmt_id: prepared.ok.stmt_id, - flags: StmtExecFlag::NO_CURSOR, - params: Some(vec![Some(Bytes::from_static(b"josh"))]), - param_defs: prepared.param_defs, - }; - - conn.send(exec).await?; - - let mut ctx = DeContext::with_stream(&mut conn.context, &mut conn.stream); - ctx.next_packet().await?; - ctx.columns = Some(prepared.ok.columns as u64); - ctx.column_defs = prepared.res_columns; - - println!("{:?}", ctx.columns); - println!("{:?}", ctx.column_defs); - - match ctx.decoder.peek_tag() { - 0xFF => { - ErrPacket::decode(&mut ctx)?; - } - 0x00 => { - OkPacket::decode(&mut ctx)?; - } - _ => { - ResultSet::deserialize(ctx, ProtocolType::Binary).await?; - } - } - - Ok(()) - } - - #[tokio::test] - async fn it_does_not_connect() -> Result<(), Error> { - match MariaDbRawConnection::establish(&"mariadb//roote@127.0.0.1:3306").await { - Ok(_) => Err(err_msg("Bad username still worked?")), - Err(_) => Ok(()), - } - } -} diff --git a/src/mariadb/connection.bak/execute.rs b/src/mariadb/connection.bak/execute.rs deleted file mode 100644 index 17dccffa..00000000 --- a/src/mariadb/connection.bak/execute.rs +++ /dev/null @@ -1,9 +0,0 @@ -use crate::mariadb::MariaDbRawConnection; -use std::io; - -pub async fn execute(conn: &mut MariaDbRawConnection) -> io::Result { - conn.flush().await?; - - let mut rows: u64 = 0; - while let Some(message) = conn.receive().await? {} -} diff --git a/src/mariadb/connection.bak/mod.rs b/src/mariadb/connection.bak/mod.rs deleted file mode 100644 index d327ca4d..00000000 --- a/src/mariadb/connection.bak/mod.rs +++ /dev/null @@ -1,267 +0,0 @@ -use crate::{ - error::ErrorKind, - mariadb::protocol::{ - encode, Capabilities, ComInitDb, ComPing, ComQuery, ComQuit, ComStmtPrepare, - ComStmtPrepareResp, DeContext, Decode, Decoder, Encode, ErrPacket, OkPacket, PacketHeader, - ProtocolType, ResultSet, ServerStatusFlag, - }, -}; -use byteorder::{ByteOrder, LittleEndian}; -use bytes::{BufMut, Bytes, BytesMut}; -use core::convert::TryFrom; -use failure::Error; -use futures_core::future::BoxFuture; -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use tokio::{ - io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}, - net::TcpStream, -}; -use url::Url; - -mod establish; -mod execute; - -pub struct MariaDbRawConnection { - pub stream: TcpStream, - - // Buffer used when serializing outgoing messages - pub wbuf: Vec, - - pub rbuf: BytesMut, - pub read_index: usize, - - // Context for the connection.bak - // Explicitly declared to easily send to deserializers - pub context: ConnContext, -} - -#[derive(Debug)] -pub struct ConnContext { - // MariaDB Connection ID - pub connection_id: i32, - - // Sequence Number - pub seq_no: u8, - - // Last sequence number return by MariaDB - pub last_seq_no: u8, - - // Server Capabilities - pub capabilities: Capabilities, - - // Server status - pub status: ServerStatusFlag, -} - -impl ConnContext { - #[cfg(test)] - pub fn new() -> Self { - ConnContext { - connection_id: 0, - seq_no: 2, - last_seq_no: 0, - capabilities: Capabilities::CLIENT_PROTOCOL_41, - status: ServerStatusFlag::SERVER_STATUS_IN_TRANS, - } - } - - #[cfg(test)] - pub fn with_eof() -> Self { - ConnContext { - connection_id: 0, - seq_no: 2, - last_seq_no: 0, - capabilities: Capabilities::CLIENT_PROTOCOL_41 | Capabilities::CLIENT_DEPRECATE_EOF, - status: ServerStatusFlag::SERVER_STATUS_IN_TRANS, - } - } -} - -impl MariaDbRawConnection { - pub async fn establish(url: &str) -> Result { - // TODO: Handle errors - let url = Url::parse(url).map_err(ErrorKind::UrlParse)?; - println!("{:?}", url); - - let host = url.host_str().unwrap_or("localhost"); - let port = url.port().unwrap_or(3306); - - // FIXME: handle errors - let host: IpAddr = host.parse().unwrap(); - let addr: SocketAddr = (host, port).into(); - let stream = TcpStream::connect(&addr).await?; - let mut conn: MariaDbRawConnection = Self { - stream, - wbuf: Vec::with_capacity(1024), - rbuf: BytesMut::with_capacity(8 * 1024), - read_index: 0, - context: ConnContext { - connection_id: -1, - seq_no: 1, - last_seq_no: 0, - capabilities: Capabilities::CLIENT_PROTOCOL_41, - status: ServerStatusFlag::default(), - }, - }; - - establish::establish(&mut conn, url).await?; - - Ok(conn) - } - - // pub async fn send(&mut self, message: S) -> Result<(), Error> - // where - // S: Encode, - // { - // self.wbuf.clear(); - // message.encode(&mut self.wbuf, &mut self.context)?; - // self.stream.inner.write_all(&self.wbuf).await?; - // Ok(()) - // } - - pub fn write(&mut self, message: impl Encode) { - message.encode(&mut self.wbuf); - } - - pub async fn flush(&mut self) -> Result<(), Error> { - self.stream.flush().await?; - self.stream.clear().clear(); - - Ok(()) - } - - pub async fn quit(&mut self) -> Result<(), Error> { - self.write(ComQuit()).await?; - - Ok(()) - } - - pub async fn ping(&mut self) -> Result<(), Error> { - self.write(ComPing()).await?; - - // Ping response must be an OkPacket - OkPacket::decode(&mut DeContext::new( - &mut self.context, - self.stream.next_packet().await?, - ))?; - - Ok(()) - } - - pub async fn ping(&mut self) -> Result<(), Error> { - // Send the ping command and wait for (and drop) an OK packet - // SEND ================ - self.last_seq_no = None; - self.write(ComPing); - self.stream.flush().await?; - // ===================== - - let _ = decode_ok_or_err(self.receive().await?)?; - - Ok(()) - } - - pub async fn prepare(&mut self, query: &str) -> Result { - self.write(ComStmtPrepare { - statement: Bytes::from(query), - }) - .await?; - - let mut ctx = DeContext::with_stream(&mut self.context, &mut self.stream); - ctx.next_packet().await?; - Ok(ComStmtPrepareResp::deserialize(ctx).await?) - } - - pub async fn next_packet(&mut self) -> Result { - let mut packet_headers: Vec = Vec::new(); - - loop { - println!("BUF: {:?}: ", self.rbuf); - // If we don't have a packet header or the last packet header had a length of - // 0xFF_FF_FF (the max possible length); then we must continue receiving packets - // because the entire message hasn't been received. - // After this operation we know that packet_headers.last() *SHOULD* always return valid data, - // so the the use of packet_headers.last().unwrap() is allowed. - // TODO: Stitch packets together by removing the length and seq_no from in-between packet definitions. - if let Some(packet_header) = packet_headers.last() { - if packet_header.length as usize == encode::U24_MAX { - packet_headers.push(PacketHeader::try_from(&self.rbuf[self.read_index..])?); - } - } else if self.rbuf.len() > 4 { - match PacketHeader::try_from(&self.rbuf[0..]) { - Ok(v) => packet_headers.push(v), - Err(_) => {} - } - } - - if let Some(packet_header) = packet_headers.last() { - if packet_header.combined_length() > self.rbuf.len() { - unsafe { - self.rbuf - .reserve(packet_header.combined_length() - self.rbuf.len()); - } - } - } else if self.rbuf.len() == self.read_index { - unsafe { - self.rbuf.reserve(32); - } - } - unsafe { - self.rbuf.set_len(self.rbuf.capacity()); - } - - // If we have a packet_header and the amount of currently read bytes (len) is less than - // the specified length inside packet_header, then we can continue reading to self.rbuf. - // Else if the total number of bytes read is equal to packet_header then we will - // return self.rbuf from 0 to self.read_index as it should contain the entire packet. - let bytes_read; - - if let Some(packet_header) = packet_headers.last() { - if packet_header.combined_length() > self.read_index { - bytes_read = self.stream.read(&mut self.rbuf[self.read_index..]).await?; - } else { - // Get the packet from the rbuffer, reset read_index, and return packet - let packet = self.rbuf.split_to(packet_header.combined_length()).freeze(); - self.read_index -= packet.len(); - return Ok(packet); - } - } else { - bytes_read = self.stream.read(&mut self.rbuf[self.read_index..]).await?; - } - - if bytes_read > 0 { - self.read_index += bytes_read; - // If we have read less than 4 bytes, and we don't already have a packet_header - // we must try to read again. The packet_header is always present and is 4 bytes long. - if bytes_read < 4 && packet_headers.len() == 0 { - continue; - } - } else { - // Read 0 bytes from the server; end-of-stream - panic!("Cannot read 0 bytes from stream"); - } - } - } -} - -// impl RawConnection for MariaDbRawConnection { -// type Backend = MariaDb; - -// #[inline] -// fn establish(url: &str) -> BoxFuture> { -// Box::pin(MariaDbRawConnection::establish(url)) -// } - -// #[inline] -// fn finalize<'c>(&'c mut self) -> BoxFuture<'c, std::io::Result<()>> { -// Box::pin(self.finalize()) -// } - -// fn execute<'c, 'q, Q: 'q>(&'c mut self, query: Q) -> BoxFuture<'c, std::io::Result<()>> -// where -// Q: RawQuery<'q, Backend = Self::Backend>, -// { -// query.finish(self); -// Box::pin(execute::execute(self)) -// } -// } diff --git a/src/mariadb/connection.rs b/src/mariadb/connection.rs index 0deb4d61..f372c661 100644 --- a/src/mariadb/connection.rs +++ b/src/mariadb/connection.rs @@ -13,6 +13,7 @@ use crate::{ }, Backend, Error, Result, }; +use async_trait::async_trait; use byteorder::{ByteOrder, LittleEndian}; use futures_core::{future::BoxFuture, stream::BoxStream}; use std::{ @@ -56,7 +57,7 @@ impl MariaDbRawConnection { Ok(conn) } - pub async fn close(&mut self) -> Result<()> { + pub async fn close(mut self) -> Result<()> { // Send the quit command self.start_sequence(); @@ -297,55 +298,50 @@ enum ExecResult { Rows(Vec), } +#[async_trait] impl RawConnection for MariaDbRawConnection { type Backend = MariaDb; - fn establish(url: &str) -> BoxFuture> + async fn establish(url: &str) -> crate::Result where Self: Sized, { - Box::pin(MariaDbRawConnection::establish(url)) + MariaDbRawConnection::establish(url).await } - fn close(&mut self) -> BoxFuture<'_, Result<()>> { - Box::pin(self.close()) + async fn close(mut self) -> crate::Result<()> { + self.close().await } - fn ping(&mut self) -> BoxFuture<'_, Result<()>> { - Box::pin(self.ping()) + async fn ping(&mut self) -> crate::Result<()> { + self.ping().await } - fn execute<'c>( - &'c mut self, - query: &str, - params: MariaDbQueryParameters, - ) -> BoxFuture<'c, Result> { + async fn execute(&mut self, query: &str, params: MariaDbQueryParameters) -> crate::Result { // Write prepare statement to buffer self.start_sequence(); self.write(ComStmtPrepare { statement: query }); - Box::pin(async move { - let statement_id = self.exec_prepare().await?; + let statement_id = self.exec_prepare().await?; - let affected = self.execute(statement_id, params).await?; + let affected = self.execute(statement_id, params).await?; - Ok(affected) - }) + Ok(affected) } - fn fetch<'c>( - &'c mut self, + fn fetch( + &mut self, query: &str, params: MariaDbQueryParameters, - ) -> BoxStream<'c, Result> { + ) -> BoxStream<'_, Result> { unimplemented!(); } - fn fetch_optional<'c>( - &'c mut self, + async fn fetch_optional( + &mut self, query: &str, params: MariaDbQueryParameters, - ) -> BoxFuture<'c, Result::Row>>> { + ) -> crate::Result::Row>> { unimplemented!(); } }