From 3f83fcb24de20bc40cc7014d7559c5e2ddb5f951 Mon Sep 17 00:00:00 2001 From: Daniel Akhterov Date: Tue, 20 Aug 2019 22:02:41 -0700 Subject: [PATCH] Use tokio --- src/lib.rs | 4 +- src/mariadb/connection/establish.rs | 66 ++++++----------------------- src/mariadb/connection/mod.rs | 53 +++++++++++++---------- 3 files changed, 45 insertions(+), 78 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index caacafaf..de25f8fa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,8 +5,8 @@ // #[macro_use] // extern crate bitflags; -// #[macro_use] -// extern crate enum_tryfrom_derive; +#[macro_use] +extern crate bitflags; #[cfg(test)] extern crate test; diff --git a/src/mariadb/connection/establish.rs b/src/mariadb/connection/establish.rs index c793586e..f52d5571 100644 --- a/src/mariadb/connection/establish.rs +++ b/src/mariadb/connection/establish.rs @@ -9,10 +9,11 @@ use crate::{ use bytes::Bytes; use failure::{err_msg, Error}; use std::ops::BitAnd; +use url::Url; -pub async fn establish<'a, 'b: 'a>( - conn: &'a mut Connection, - options: ConnectOptions<'b>, +pub async fn establish( + conn: &mut Connection, + url: Url ) -> Result<(), Error> { let buf = conn.stream.next_packet().await?; let mut de_ctx = DeContext::new(&mut conn.context, buf); @@ -25,7 +26,7 @@ pub async fn establish<'a, 'b: 'a>( capabilities: de_ctx.ctx.capabilities, max_packet_size: 1024, extended_capabilities: Some(Capabilities::from_bits_truncate(0)), - username: options.user.unwrap_or(""), + username: url.username(), ..Default::default() }; @@ -54,13 +55,7 @@ mod test { #[runtime::test] async fn it_can_connect() -> Result<(), Error> { - let mut conn = Connection::establish(ConnectOptions { - host: "127.0.0.1", - port: 3306, - user: Some("root"), - database: None, - password: None, - }) + let mut conn = Connection::establish(&"mariadb://root@localhost:3306") .await?; Ok(()) @@ -68,13 +63,8 @@ mod test { #[runtime::test] async fn it_can_ping() -> Result<(), Error> { - let mut conn = Connection::establish(ConnectOptions { - host: "127.0.0.1", - port: 3306, - user: Some("root"), - database: None, - password: None, - }) + let mut conn = Connection::establish(&"mariadb://root@localhost:3306") + .await?; conn.ping().await?; @@ -84,13 +74,7 @@ mod test { #[runtime::test] async fn it_can_select_db() -> Result<(), Error> { - let mut conn = Connection::establish(ConnectOptions { - host: "127.0.0.1", - port: 3306, - user: Some("root"), - database: None, - password: None, - }) + let mut conn = Connection::establish(&"mariadb://root@localhost:3306") .await?; conn.select_db("test").await?; @@ -100,13 +84,7 @@ mod test { #[runtime::test] async fn it_can_query() -> Result<(), Error> { - let mut conn = Connection::establish(ConnectOptions { - host: "127.0.0.1", - port: 3306, - user: Some("root"), - database: None, - password: None, - }) + let mut conn = Connection::establish(&"mariadb://root@localhost:3306") .await?; conn.select_db("test").await?; @@ -118,13 +96,7 @@ mod test { #[runtime::test] async fn it_can_prepare() -> Result<(), Error> { - let mut conn = Connection::establish(ConnectOptions { - host: "127.0.0.1", - port: 3306, - user: Some("root"), - database: None, - password: None, - }) + let mut conn = Connection::establish(&"mariadb://root@localhost:3306") .await?; conn.select_db("test").await?; @@ -137,13 +109,7 @@ mod test { #[runtime::test] async fn it_can_execute_prepared() -> Result<(), Error> { - let mut conn = Connection::establish(ConnectOptions { - host: "127.0.0.1", - port: 3306, - user: Some("root"), - database: None, - password: None, - }) + let mut conn = Connection::establish(&"mariadb://root@localhost:3306") .await?; conn.select_db("test").await?; @@ -186,13 +152,7 @@ mod test { #[runtime::test] async fn it_does_not_connect() -> Result<(), Error> { - match Connection::establish(ConnectOptions { - host: "127.0.0.1", - port: 3306, - user: Some("roote"), - database: None, - password: None, - }) + match Connection::establish(&"mariadb//roote@localhost:3306") .await { Ok(_) => Err(err_msg("Bad username still worked?")), diff --git a/src/mariadb/connection/mod.rs b/src/mariadb/connection/mod.rs index 4a39f255..95c9f22b 100644 --- a/src/mariadb/connection/mod.rs +++ b/src/mariadb/connection/mod.rs @@ -11,10 +11,16 @@ use bytes::{Bytes, BytesMut}; use core::convert::TryFrom; use failure::Error; use futures::{ - io::{AsyncRead, AsyncWriteExt}, + io::{AsyncRead}, prelude::*, }; -use runtime::net::TcpStream; +use tokio::{ + io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}, + net::TcpStream, +}; +use std::net::{SocketAddr, IpAddr, Ipv4Addr}; +use url::Url; +use bytes::BufMut; mod establish; @@ -72,8 +78,17 @@ impl ConnContext { } impl Connection { - pub async fn establish(options: ConnectOptions<'static>) -> Result { - let stream: Framed = Framed::new(TcpStream::connect((options.host, options.port)).await?); + pub async fn establish(url: &str) -> Result { + // TODO: Handle errors + let url = Url::parse(url).unwrap(); + + let host = url.host_str().unwrap_or("localhost"); + let port = url.port().unwrap_or(3306); + + // FIXME: handle errors + let host: IpAddr = host.parse().unwrap(); + let addr: SocketAddr = (host, port).into(); + let stream: Framed = Framed::new(TcpStream::connect(&addr).await?); let mut conn: Connection = Self { stream, wbuf: Vec::with_capacity(1024), @@ -86,7 +101,7 @@ impl Connection { }, }; - establish::establish(&mut conn, options).await?; + establish::establish(&mut conn, url).await?; Ok(conn) } @@ -100,7 +115,6 @@ impl Connection { message.encode(&mut self.wbuf, &mut self.context)?; self.stream.inner.write_all(&self.wbuf).await?; - self.stream.inner.flush().await?; Ok(()) } @@ -193,6 +207,14 @@ impl Framed { } } + unsafe fn reserve(&mut self, size: usize) { + self.buf.reserve(size); + + unsafe { self.buf.set_len(self.buf.capacity()); } + + unsafe { self.buf.advance_mut(size) } + } + pub async fn next_packet(&mut self) -> Result { let mut packet_headers: Vec = Vec::new(); @@ -217,25 +239,10 @@ impl Framed { if let Some(packet_header) = packet_headers.last() { if packet_header.combined_length() > self.buf.len() { - self.buf - .reserve(packet_header.combined_length() - self.buf.len()); - - unsafe { - self.buf.set_len(self.buf.capacity()); - self.inner - .initializer() - .initialize(&mut self.buf[self.index..]); - } + unsafe { self.reserve(packet_header.combined_length() - self.buf.len()); } } } else if self.buf.len() == self.index { - self.buf.reserve(32); - - unsafe { - self.buf.set_len(self.buf.capacity()); - self.inner - .initializer() - .initialize(&mut self.buf[self.index..]); - } + unsafe { self.reserve(32); } } // If we have a packet_header and the amount of currently read bytes (len) is less than