This commit is contained in:
Daniel Akhterov 2019-08-22 19:58:47 -07:00 committed by Ryan Leckey
parent 91fc27a5b8
commit 7077790452
30 changed files with 211 additions and 83 deletions

14
src/mariadb/backend.rs Normal file
View File

@ -0,0 +1,14 @@
use crate::backend::{Backend, BackendAssocRawQuery};
pub struct MariaDB;
impl<'q> BackendAssocRawQuery<'q, MariaDB> for MariaDB {
type RawQuery = super::MariaDbRawQuery<'q>;
}
impl Backend for MariaDB {
type RawConnection = super::MariaDbRawConnection;
type Row = super::MariaDbRow;
}
impl_from_sql_row_tuples_for_backend!(MariaDb);

View File

@ -1,4 +1,4 @@
use super::Connection;
use super::MariaDbRawConnection;
use crate::{
mariadb::{
Capabilities, ComStmtExec, DeContext, Decode, EofPacket, ErrPacket,
@ -11,7 +11,7 @@ use std::ops::BitAnd;
use url::Url;
pub async fn establish(
conn: &mut Connection,
conn: &mut MariaDbRawConnection,
url: Url
) -> Result<(), Error> {
let buf = conn.stream.next_packet().await?;
@ -54,7 +54,7 @@ mod test {
#[tokio::test]
async fn it_can_connect() -> Result<(), Error> {
let mut conn = Connection::establish(&"mariadb://root@127.0.0.1:3306")
let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306")
.await?;
Ok(())
@ -62,7 +62,7 @@ mod test {
#[tokio::test]
async fn it_can_ping() -> Result<(), Error> {
let mut conn = Connection::establish(&"mariadb://root@127.0.0.1:3306")
let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306")
.await?;
@ -73,7 +73,7 @@ mod test {
#[tokio::test]
async fn it_can_select_db() -> Result<(), Error> {
let mut conn = Connection::establish(&"mariadb://root@127.0.0.1:3306")
let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306")
.await?;
conn.select_db("test").await?;
@ -83,7 +83,7 @@ mod test {
#[tokio::test]
async fn it_can_query() -> Result<(), Error> {
let mut conn = Connection::establish(&"mariadb://root@127.0.0.1:3306")
let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306")
.await?;
conn.select_db("test").await?;
@ -95,7 +95,7 @@ mod test {
#[tokio::test]
async fn it_can_prepare() -> Result<(), Error> {
let mut conn = Connection::establish(&"mariadb://root@127.0.0.1:3306")
let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306")
.await?;
conn.select_db("test").await?;
@ -108,7 +108,7 @@ mod test {
#[tokio::test]
async fn it_can_execute_prepared() -> Result<(), Error> {
let mut conn = Connection::establish(&"mariadb://root@127.0.0.1:3306")
let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306")
.await?;
conn.select_db("test").await?;
@ -151,7 +151,7 @@ mod test {
#[tokio::test]
async fn it_does_not_connect() -> Result<(), Error> {
match Connection::establish(&"mariadb//roote@127.0.0.1:3306")
match MariaDbRawConnection::establish(&"mariadb//roote@127.0.0.1:3306")
.await
{
Ok(_) => Err(err_msg("Bad username still worked?")),

View File

@ -0,0 +1,11 @@
use crate::mariadb::MariaDbRawConnection;
use std::io;
pub async fn execute(conn: &mut MariaDbRawConnection) -> io::Result<u64> {
conn.flush().await?;
let mut rows: u64 = 0;
while let Some(message) = conn.receive().await? {
}
}

View File

@ -17,15 +17,22 @@ use std::net::{SocketAddr, IpAddr, Ipv4Addr};
use url::Url;
use bytes::BufMut;
use crate::error::ErrorKind;
use crate::connection::RawConnection;
use futures_core::future::BoxFuture;
use crate::query::RawQuery;
mod establish;
mod execute;
pub struct Connection {
pub stream: Framed,
pub struct MariaDbRawConnection {
pub stream: TcpStream,
// Buffer used when serializing outgoing messages
pub wbuf: Vec<u8>,
pub rbuf: BytesMut,
pub read_index: usize,
// Context for the connection
// Explicitly declared to easily send to deserializers
pub context: ConnContext,
@ -73,7 +80,7 @@ impl ConnContext {
}
}
impl Connection {
impl MariaDbRawConnection {
pub async fn establish(url: &str) -> Result<Self, Error> {
// TODO: Handle errors
let url = Url::parse(url).map_err(ErrorKind::UrlParse)?;
@ -85,10 +92,12 @@ impl Connection {
// FIXME: handle errors
let host: IpAddr = host.parse().unwrap();
let addr: SocketAddr = (host, port).into();
let stream: Framed = Framed::new(TcpStream::connect(&addr).await?);
let mut conn: Connection = Self {
let stream = TcpStream::connect(&addr).await?;
let mut conn: MariaDbRawConnection = Self {
stream,
wbuf: Vec::with_capacity(1024),
rbuf: BytesMut::with_capacity(8 * 1024),
read_index: 0,
context: ConnContext {
connection_id: -1,
seq_no: 1,
@ -103,21 +112,29 @@ impl Connection {
Ok(conn)
}
pub async fn send<S>(&mut self, message: S) -> Result<(), Error>
where
S: Encode,
{
self.wbuf.clear();
// pub async fn send<S>(&mut self, message: S) -> Result<(), Error>
// where
// S: Encode,
// {
// self.wbuf.clear();
// message.encode(&mut self.wbuf, &mut self.context)?;
// self.stream.inner.write_all(&self.wbuf).await?;
// Ok(())
// }
message.encode(&mut self.wbuf, &mut self.context)?;
pub fn write(&mut self, message: impl Encode) {
message.encode(&mut self.wbuf);
}
self.stream.inner.write_all(&self.wbuf).await?;
pub async fn flush(&mut self) -> Result<(), Error> {
self.stream.flush().await?;
self.stream.clear().clear();
Ok(())
}
pub async fn quit(&mut self) -> Result<(), Error> {
self.send(ComQuit()).await?;
self.write(ComQuit()).await?;
Ok(())
}
@ -126,7 +143,7 @@ impl Connection {
&'a mut self,
sql_statement: &'a str,
) -> Result<Option<ResultSet>, Error> {
self.send(ComQuery {
self.write(ComQuery {
sql_statement: bytes::Bytes::from(sql_statement),
})
.await?;
@ -146,7 +163,7 @@ impl Connection {
}
pub async fn select_db<'a>(&'a mut self, db: &'a str) -> Result<(), Error> {
self.send(ComInitDb {
self.write(ComInitDb {
schema_name: bytes::Bytes::from(db),
})
.await?;
@ -166,7 +183,7 @@ impl Connection {
}
pub async fn ping(&mut self) -> Result<(), Error> {
self.send(ComPing()).await?;
self.write(ComPing()).await?;
// Ping response must be an OkPacket
OkPacket::decode(&mut DeContext::new(
@ -178,7 +195,7 @@ impl Connection {
}
pub async fn prepare(&mut self, query: &str) -> Result<ComStmtPrepareResp, Error> {
self.send(ComStmtPrepare {
self.write(ComStmtPrepare {
statement: Bytes::from(query),
})
.await?;
@ -187,28 +204,12 @@ impl Connection {
ctx.next_packet().await?;
Ok(ComStmtPrepareResp::deserialize(ctx).await?)
}
}
pub struct Framed {
inner: TcpStream,
buf: BytesMut,
index: usize,
}
impl Framed {
fn new(stream: TcpStream) -> Self {
Self {
inner: stream,
buf: BytesMut::with_capacity(8 * 1024),
index: 0,
}
}
pub async fn next_packet(&mut self) -> Result<Bytes, Error> {
let mut packet_headers: Vec<PacketHeader> = Vec::new();
loop {
println!("BUF: {:?}: ", self.buf);
println!("BUF: {:?}: ", self.rbuf);
// If we don't have a packet header or the last packet header had a length of
// 0xFF_FF_FF (the max possible length); then we must continue receiving packets
// because the entire message hasn't been received.
@ -217,45 +218,45 @@ impl Framed {
// TODO: Stitch packets together by removing the length and seq_no from in-between packet definitions.
if let Some(packet_header) = packet_headers.last() {
if packet_header.length as usize == encode::U24_MAX {
packet_headers.push(PacketHeader::try_from(&self.buf[self.index..])?);
packet_headers.push(PacketHeader::try_from(&self.rbuf[self.read_index..])?);
}
} else if self.buf.len() > 4 {
match PacketHeader::try_from(&self.buf[0..]) {
} else if self.rbuf.len() > 4 {
match PacketHeader::try_from(&self.rbuf[0..]) {
Ok(v) => packet_headers.push(v),
Err(_) => {}
}
}
if let Some(packet_header) = packet_headers.last() {
if packet_header.combined_length() > self.buf.len() {
unsafe { self.buf.reserve(packet_header.combined_length() - self.buf.len()); }
if packet_header.combined_length() > self.rbuf.len() {
unsafe { self.rbuf.reserve(packet_header.combined_length() - self.rbuf.len()); }
}
} else if self.buf.len() == self.index {
unsafe { self.buf.reserve(32); }
} else if self.rbuf.len() == self.read_index {
unsafe { self.rbuf.reserve(32); }
}
unsafe { self.buf.set_len(self.buf.capacity()); }
unsafe { self.rbuf.set_len(self.rbuf.capacity()); }
// If we have a packet_header and the amount of currently read bytes (len) is less than
// the specified length inside packet_header, then we can continue reading to self.buf.
// the specified length inside packet_header, then we can continue reading to self.rbuf.
// Else if the total number of bytes read is equal to packet_header then we will
// return self.buf from 0 to self.index as it should contain the entire packet.
// return self.rbuf from 0 to self.read_index as it should contain the entire packet.
let bytes_read;
if let Some(packet_header) = packet_headers.last() {
if packet_header.combined_length() > self.index {
bytes_read = self.inner.read(&mut self.buf[self.index..]).await?;
if packet_header.combined_length() > self.read_index {
bytes_read = self.stream.read(&mut self.rbuf[self.read_index..]).await?;
} else {
// Get the packet from the buffer, reset index, and return packet
let packet = self.buf.split_to(packet_header.combined_length()).freeze();
self.index -= packet.len();
// Get the packet from the rbuffer, reset read_index, and return packet
let packet = self.rbuf.split_to(packet_header.combined_length()).freeze();
self.read_index -= packet.len();
return Ok(packet);
}
} else {
bytes_read = self.inner.read(&mut self.buf[self.index..]).await?;
bytes_read = self.stream.read(&mut self.rbuf[self.read_index..]).await?;
}
if bytes_read > 0 {
self.index += bytes_read;
self.read_index += bytes_read;
// If we have read less than 4 bytes, and we don't already have a packet_header
// we must try to read again. The packet_header is always present and is 4 bytes long.
if bytes_read < 4 && packet_headers.len() == 0 {
@ -268,3 +269,25 @@ impl Framed {
}
}
}
impl RawConnection for MariaDbRawConnection {
type Backend = MariaDb;
#[inline]
fn establish(url: &str) -> BoxFuture<std::io::Result<Self>> {
Box::pin(MariaDbRawConnection::establish(url))
}
#[inline]
fn finalize<'c>(&'c mut self) -> BoxFuture<'c, std::io::Result<()>> {
Box::pin(self.finalize())
}
fn execute<'c, 'q, Q: 'q>(&'c mut self, query: Q) -> BoxFuture<'c, std::io::Result<()>>
where
Q: RawQuery<'q, Backend = Self::Backend>,
{
query.finish(self);
Box::pin(execute::execute(self))
}
}

View File

@ -1,8 +1,11 @@
pub mod connection;
pub mod protocol;
pub mod types;
pub mod backend;
pub mod query;
// Re-export all the things
pub use connection::{ConnContext, Connection, Framed};
pub use connection::{ConnContext, MariaDbRawConnection, Framed};
pub use protocol::{
AuthenticationSwitchRequestPacket, BufMut, Capabilities, ColumnDefPacket, ColumnPacket,
ComDebug, ComInitDb, ComPing, ComProcessKill, ComQuery, ComQuit, ComResetConnection,
@ -13,3 +16,5 @@ pub use protocol::{
ResultRowText, ResultSet, SSLRequestPacket, ServerStatusFlag, SessionChangeType,
SetOptionOptions, ShutdownOptions, StmtExecFlag,
};
pub use backend::MariaDB;

View File

@ -1,4 +1,4 @@
use crate::mariadb::{ColumnDefPacket, ConnContext, Connection, Framed, PacketHeader};
use crate::mariadb::{ColumnDefPacket, ConnContext, MariaDbRawConnection, Framed, PacketHeader};
use byteorder::{ByteOrder, LittleEndian};
use bytes::Bytes;
use failure::{err_msg, Error};

View File

@ -1,4 +1,4 @@
use crate::mariadb::{ConnContext, Connection, FieldType};
use crate::mariadb::{ConnContext, MariaDbRawConnection, FieldType};
use byteorder::{ByteOrder, LittleEndian};
use bytes::Bytes;
use failure::Error;

View File

@ -1,4 +1,4 @@
use crate::mariadb::{BufMut, ConnContext, Connection, Encode};
use crate::mariadb::{BufMut, ConnContext, MariaDbRawConnection, Encode};
use bytes::Bytes;
use failure::Error;

View File

@ -1,4 +1,4 @@
use crate::mariadb::{BufMut, ConnContext, Connection, Encode};
use crate::mariadb::{BufMut, ConnContext, MariaDbRawConnection, Encode};
use failure::Error;
use std::convert::TryInto;

View File

@ -1,5 +1,5 @@
use crate::mariadb::{
BufMut, ColumnDefPacket, ConnContext, Connection, Encode, FieldDetailFlag, FieldType,
BufMut, ColumnDefPacket, ConnContext, MariaDbRawConnection, Encode, FieldDetailFlag, FieldType,
StmtExecFlag,
};
use bytes::Bytes;

View File

@ -1,4 +1,4 @@
use crate::mariadb::{BufMut, ConnContext, Connection, Encode};
use crate::mariadb::{BufMut, ConnContext, MariaDbRawConnection, Encode};
use failure::Error;
#[derive(Debug)]

View File

@ -1,4 +1,4 @@
use crate::mariadb::{BufMut, ConnContext, Connection, Encode};
use crate::mariadb::{BufMut, ConnContext, MariaDbRawConnection, Encode};
use bytes::Bytes;
use failure::Error;

View File

@ -1,4 +1,4 @@
use crate::mariadb::{BufMut, ConnContext, Connection, Encode};
use crate::mariadb::{BufMut, ConnContext, MariaDbRawConnection, Encode};
use failure::Error;
#[derive(Debug)]

View File

@ -1,4 +1,4 @@
use crate::mariadb::{BufMut, Capabilities, ConnContext, Connection, Encode};
use crate::mariadb::{BufMut, Capabilities, ConnContext, MariaDbRawConnection, Encode};
use bytes::Bytes;
use failure::Error;

View File

@ -119,7 +119,7 @@ mod test {
use crate::{
__bytes_builder,
mariadb::{
Capabilities, ConnContext, Connection, EofPacket, ErrPacket, OkPacket, ResultRow,
Capabilities, ConnContext, MariaDbRawConnection, EofPacket, ErrPacket, OkPacket, ResultRow,
ServerStatusFlag,
},
};

View File

@ -1,7 +1,7 @@
use bytes::Bytes;
use failure::Error;
use crate::mariadb::{BufMut, Capabilities, ConnContext, Connection, Encode};
use crate::mariadb::{BufMut, Capabilities, ConnContext, MariaDbRawConnection, Encode};
#[derive(Default, Debug)]
pub struct SSLRequestPacket {

View File

@ -1,4 +1,4 @@
use crate::mariadb::{BufMut, ConnContext, Connection, Encode};
use crate::mariadb::{BufMut, ConnContext, MariaDbRawConnection, Encode};
use failure::Error;
pub struct ComDebug();

View File

@ -1,4 +1,4 @@
use crate::mariadb::{BufMut, ConnContext, Connection, Encode};
use crate::mariadb::{BufMut, ConnContext, MariaDbRawConnection, Encode};
use bytes::Bytes;
use failure::Error;

View File

@ -1,4 +1,4 @@
use crate::mariadb::{BufMut, ConnContext, Connection, Encode};
use crate::mariadb::{BufMut, ConnContext, MariaDbRawConnection, Encode};
use failure::Error;
pub struct ComPing();

View File

@ -1,4 +1,4 @@
use crate::mariadb::{BufMut, ConnContext, Connection, Encode};
use crate::mariadb::{BufMut, ConnContext, MariaDbRawConnection, Encode};
use failure::Error;
pub struct ComProcessKill {

View File

@ -1,4 +1,4 @@
use crate::mariadb::{BufMut, ConnContext, Connection, Encode};
use crate::mariadb::{BufMut, ConnContext, MariaDbRawConnection, Encode};
use bytes::Bytes;
use failure::Error;

View File

@ -1,4 +1,4 @@
use crate::mariadb::{BufMut, ConnContext, Connection, Encode};
use crate::mariadb::{BufMut, ConnContext, MariaDbRawConnection, Encode};
use failure::Error;
pub struct ComQuit();

View File

@ -1,4 +1,4 @@
use crate::mariadb::{BufMut, ConnContext, Connection, Encode};
use crate::mariadb::{BufMut, ConnContext, MariaDbRawConnection, Encode};
use failure::Error;
pub struct ComResetConnection();

View File

@ -1,4 +1,4 @@
use crate::mariadb::{BufMut, ConnContext, Connection, Encode};
use crate::mariadb::{BufMut, ConnContext, MariaDbRawConnection, Encode};
use failure::Error;
#[derive(Clone, Copy)]

View File

@ -1,4 +1,4 @@
use crate::mariadb::{BufMut, ConnContext, Connection, Encode};
use crate::mariadb::{BufMut, ConnContext, MariaDbRawConnection, Encode};
use failure::Error;
#[derive(Clone, Copy)]

View File

@ -1,4 +1,4 @@
use crate::mariadb::{BufMut, ConnContext, Connection, Encode};
use crate::mariadb::{BufMut, ConnContext, MariaDbRawConnection, Encode};
use failure::Error;
pub struct ComSleep();

View File

@ -1,4 +1,4 @@
use crate::mariadb::{BufMut, ConnContext, Connection, Encode};
use crate::mariadb::{BufMut, ConnContext, MariaDbRawConnection, Encode};
use failure::Error;
pub struct ComStatistics();

54
src/mariadb/query.rs Normal file
View File

@ -0,0 +1,54 @@
use crate::mariadb::{FieldType, MariaDbRawConnection};
use crate::mariadb::protocol::types::ParamFlag;
use crate::query::RawQuery;
use crate::types::HasSqlType;
use crate::serialize::{ToSql, IsNull};
pub struct MariaDbRawQuery<'q> {
query: &'q str,
types: Vec<u8>,
null_bitmap: Vec<u8>,
flags: Vec<u8>,
buf: Vec<u8>,
index: u64,
}
impl<'q> RawQueryQuery<'q> for MariaDbRawQuery<'q> {
type Backend = MariaDb;
fn new(query: &'q str) -> Self {
Self {
query,
types: Vec::with_capacity(4),
null_bitmap: vec![0, 0, 0, 0],
flags: Vec::with_capacity(4),
buf: Vec::with_capacity(32),
index: 0,
}
}
fn bind<T>(mut self, value: T) -> Self
where
Self: Sized,
Self::Backend: HasSqlType<T>,
T: ToSql<Self::Backend>,
{
self.types.push(<MariaDb as HasSqlType<T>>::metadata().field_type.0);
self.flags.push(<MariaDb as HasSqlType<T>>::metadata().param_flag.0);
match value.to_sql(&mut self.buf) {
IsNull::Yes => {
self.null_bitmap[self.index / 8] =
self.null_bitmap[self.index / 8] & (1 << self.index % 8);
},
IsNull::No => {}
}
self
}
fn finish(self, conn: &mut MariaDbRawConnection) {
conn.prepare(self.query);
}
}

View File

21
src/mariadb/types/mod.rs Normal file
View File

@ -0,0 +1,21 @@
use super::MariaDB;
use crate::types::TypeMetadata;
use crate::mariadb::FieldType;
use crate::mariadb::protocol::types::ParamFlag;
mod boolean;
pub enum MariaDbTypeFormat {
Text = 0,
Binary = 1,
}
pub struct MariaDbTypeMetadata {
pub format: MariaDbTypeFormat,
pub field_type: FieldType,
pub param_flag: ParamFlag,
}
impl TypeMetadata for MariaDb {
type TypeMetadata = MariaDbTypeMetadata;
}