Adjust mariadb for async-trait usage

This commit is contained in:
Ryan Leckey 2019-09-10 22:51:23 -07:00
parent 47b06edad1
commit fb877fee28
4 changed files with 19 additions and 446 deletions

View File

@ -1,147 +0,0 @@
use super::MariaDbRawConnection;
use crate::mariadb::protocol::{
Capabilities, ComStmtExec, DeContext, Decode, EofPacket, ErrPacket, HandshakeResponsePacket,
InitialHandshakePacket, OkPacket, ProtocolType, StmtExecFlag,
};
use bytes::Bytes;
use failure::{err_msg, Error};
use std::ops::BitAnd;
use url::Url;
pub async fn establish(conn: &mut MariaDbRawConnection, url: Url) -> Result<(), Error> {
let buf = conn.stream.next_packet().await?;
let mut de_ctx = DeContext::new(&mut conn.context, buf);
let initial = InitialHandshakePacket::decode(&mut de_ctx)?;
de_ctx.ctx.capabilities = de_ctx.ctx.capabilities.bitand(initial.capabilities);
let handshake: HandshakeResponsePacket = HandshakeResponsePacket {
// Minimum client capabilities required to establish connection.bak
capabilities: de_ctx.ctx.capabilities,
max_packet_size: 1024,
extended_capabilities: Some(Capabilities::from_bits_truncate(0)),
username: url.username(),
..Default::default()
};
conn.send(handshake).await?;
let mut ctx = DeContext::new(&mut conn.context, conn.stream.next_packet().await?);
match ctx.decoder.peek_tag() {
0xFF => {
return Err(ErrPacket::decode(&mut ctx)?.into());
}
0x00 => {
OkPacket::decode(&mut ctx)?;
}
_ => failure::bail!("Did not receive an ErrPacket nor OkPacket when one is expected"),
}
Ok(())
}
#[cfg(test)]
mod test {
use super::*;
use crate::mariadb::{ComStmtFetch, ComStmtPrepareResp, FieldType, ResultSet};
use failure::Error;
#[tokio::test]
async fn it_can_connect() -> Result<(), Error> {
let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306").await?;
Ok(())
}
#[tokio::test]
async fn it_can_ping() -> Result<(), Error> {
let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306").await?;
conn.ping().await?;
Ok(())
}
#[tokio::test]
async fn it_can_select_db() -> Result<(), Error> {
let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306").await?;
conn.select_db("test").await?;
Ok(())
}
#[tokio::test]
async fn it_can_query() -> Result<(), Error> {
let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306").await?;
conn.select_db("test").await?;
conn.query("SELECT * FROM users").await?;
Ok(())
}
#[tokio::test]
async fn it_can_prepare() -> Result<(), Error> {
let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306").await?;
conn.select_db("test").await?;
conn.prepare("SELECT * FROM users WHERE username = ?")
.await?;
Ok(())
}
#[tokio::test]
async fn it_can_execute_prepared() -> Result<(), Error> {
let mut conn = MariaDbRawConnection::establish(&"mariadb://root@127.0.0.1:3306").await?;
conn.select_db("test").await?;
let mut prepared = conn
.prepare("SELECT id FROM users WHERE username=?")
.await?;
let exec = ComStmtExec {
stmt_id: prepared.ok.stmt_id,
flags: StmtExecFlag::NO_CURSOR,
params: Some(vec![Some(Bytes::from_static(b"josh"))]),
param_defs: prepared.param_defs,
};
conn.send(exec).await?;
let mut ctx = DeContext::with_stream(&mut conn.context, &mut conn.stream);
ctx.next_packet().await?;
ctx.columns = Some(prepared.ok.columns as u64);
ctx.column_defs = prepared.res_columns;
println!("{:?}", ctx.columns);
println!("{:?}", ctx.column_defs);
match ctx.decoder.peek_tag() {
0xFF => {
ErrPacket::decode(&mut ctx)?;
}
0x00 => {
OkPacket::decode(&mut ctx)?;
}
_ => {
ResultSet::deserialize(ctx, ProtocolType::Binary).await?;
}
}
Ok(())
}
#[tokio::test]
async fn it_does_not_connect() -> Result<(), Error> {
match MariaDbRawConnection::establish(&"mariadb//roote@127.0.0.1:3306").await {
Ok(_) => Err(err_msg("Bad username still worked?")),
Err(_) => Ok(()),
}
}
}

View File

@ -1,9 +0,0 @@
use crate::mariadb::MariaDbRawConnection;
use std::io;
pub async fn execute(conn: &mut MariaDbRawConnection) -> io::Result<u64> {
conn.flush().await?;
let mut rows: u64 = 0;
while let Some(message) = conn.receive().await? {}
}

View File

@ -1,267 +0,0 @@
use crate::{
error::ErrorKind,
mariadb::protocol::{
encode, Capabilities, ComInitDb, ComPing, ComQuery, ComQuit, ComStmtPrepare,
ComStmtPrepareResp, DeContext, Decode, Decoder, Encode, ErrPacket, OkPacket, PacketHeader,
ProtocolType, ResultSet, ServerStatusFlag,
},
};
use byteorder::{ByteOrder, LittleEndian};
use bytes::{BufMut, Bytes, BytesMut};
use core::convert::TryFrom;
use failure::Error;
use futures_core::future::BoxFuture;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tokio::{
io::{AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::TcpStream,
};
use url::Url;
mod establish;
mod execute;
pub struct MariaDbRawConnection {
pub stream: TcpStream,
// Buffer used when serializing outgoing messages
pub wbuf: Vec<u8>,
pub rbuf: BytesMut,
pub read_index: usize,
// Context for the connection.bak
// Explicitly declared to easily send to deserializers
pub context: ConnContext,
}
#[derive(Debug)]
pub struct ConnContext {
// MariaDB Connection ID
pub connection_id: i32,
// Sequence Number
pub seq_no: u8,
// Last sequence number return by MariaDB
pub last_seq_no: u8,
// Server Capabilities
pub capabilities: Capabilities,
// Server status
pub status: ServerStatusFlag,
}
impl ConnContext {
#[cfg(test)]
pub fn new() -> Self {
ConnContext {
connection_id: 0,
seq_no: 2,
last_seq_no: 0,
capabilities: Capabilities::CLIENT_PROTOCOL_41,
status: ServerStatusFlag::SERVER_STATUS_IN_TRANS,
}
}
#[cfg(test)]
pub fn with_eof() -> Self {
ConnContext {
connection_id: 0,
seq_no: 2,
last_seq_no: 0,
capabilities: Capabilities::CLIENT_PROTOCOL_41 | Capabilities::CLIENT_DEPRECATE_EOF,
status: ServerStatusFlag::SERVER_STATUS_IN_TRANS,
}
}
}
impl MariaDbRawConnection {
pub async fn establish(url: &str) -> Result<Self, Error> {
// TODO: Handle errors
let url = Url::parse(url).map_err(ErrorKind::UrlParse)?;
println!("{:?}", url);
let host = url.host_str().unwrap_or("localhost");
let port = url.port().unwrap_or(3306);
// FIXME: handle errors
let host: IpAddr = host.parse().unwrap();
let addr: SocketAddr = (host, port).into();
let stream = TcpStream::connect(&addr).await?;
let mut conn: MariaDbRawConnection = Self {
stream,
wbuf: Vec::with_capacity(1024),
rbuf: BytesMut::with_capacity(8 * 1024),
read_index: 0,
context: ConnContext {
connection_id: -1,
seq_no: 1,
last_seq_no: 0,
capabilities: Capabilities::CLIENT_PROTOCOL_41,
status: ServerStatusFlag::default(),
},
};
establish::establish(&mut conn, url).await?;
Ok(conn)
}
// pub async fn send<S>(&mut self, message: S) -> Result<(), Error>
// where
// S: Encode,
// {
// self.wbuf.clear();
// message.encode(&mut self.wbuf, &mut self.context)?;
// self.stream.inner.write_all(&self.wbuf).await?;
// Ok(())
// }
pub fn write(&mut self, message: impl Encode) {
message.encode(&mut self.wbuf);
}
pub async fn flush(&mut self) -> Result<(), Error> {
self.stream.flush().await?;
self.stream.clear().clear();
Ok(())
}
pub async fn quit(&mut self) -> Result<(), Error> {
self.write(ComQuit()).await?;
Ok(())
}
pub async fn ping(&mut self) -> Result<(), Error> {
self.write(ComPing()).await?;
// Ping response must be an OkPacket
OkPacket::decode(&mut DeContext::new(
&mut self.context,
self.stream.next_packet().await?,
))?;
Ok(())
}
pub async fn ping(&mut self) -> Result<(), Error> {
// Send the ping command and wait for (and drop) an OK packet
// SEND ================
self.last_seq_no = None;
self.write(ComPing);
self.stream.flush().await?;
// =====================
let _ = decode_ok_or_err(self.receive().await?)?;
Ok(())
}
pub async fn prepare(&mut self, query: &str) -> Result<ComStmtPrepareResp, Error> {
self.write(ComStmtPrepare {
statement: Bytes::from(query),
})
.await?;
let mut ctx = DeContext::with_stream(&mut self.context, &mut self.stream);
ctx.next_packet().await?;
Ok(ComStmtPrepareResp::deserialize(ctx).await?)
}
pub async fn next_packet(&mut self) -> Result<Bytes, Error> {
let mut packet_headers: Vec<PacketHeader> = Vec::new();
loop {
println!("BUF: {:?}: ", self.rbuf);
// If we don't have a packet header or the last packet header had a length of
// 0xFF_FF_FF (the max possible length); then we must continue receiving packets
// because the entire message hasn't been received.
// After this operation we know that packet_headers.last() *SHOULD* always return valid data,
// so the the use of packet_headers.last().unwrap() is allowed.
// TODO: Stitch packets together by removing the length and seq_no from in-between packet definitions.
if let Some(packet_header) = packet_headers.last() {
if packet_header.length as usize == encode::U24_MAX {
packet_headers.push(PacketHeader::try_from(&self.rbuf[self.read_index..])?);
}
} else if self.rbuf.len() > 4 {
match PacketHeader::try_from(&self.rbuf[0..]) {
Ok(v) => packet_headers.push(v),
Err(_) => {}
}
}
if let Some(packet_header) = packet_headers.last() {
if packet_header.combined_length() > self.rbuf.len() {
unsafe {
self.rbuf
.reserve(packet_header.combined_length() - self.rbuf.len());
}
}
} else if self.rbuf.len() == self.read_index {
unsafe {
self.rbuf.reserve(32);
}
}
unsafe {
self.rbuf.set_len(self.rbuf.capacity());
}
// If we have a packet_header and the amount of currently read bytes (len) is less than
// the specified length inside packet_header, then we can continue reading to self.rbuf.
// Else if the total number of bytes read is equal to packet_header then we will
// return self.rbuf from 0 to self.read_index as it should contain the entire packet.
let bytes_read;
if let Some(packet_header) = packet_headers.last() {
if packet_header.combined_length() > self.read_index {
bytes_read = self.stream.read(&mut self.rbuf[self.read_index..]).await?;
} else {
// Get the packet from the rbuffer, reset read_index, and return packet
let packet = self.rbuf.split_to(packet_header.combined_length()).freeze();
self.read_index -= packet.len();
return Ok(packet);
}
} else {
bytes_read = self.stream.read(&mut self.rbuf[self.read_index..]).await?;
}
if bytes_read > 0 {
self.read_index += bytes_read;
// If we have read less than 4 bytes, and we don't already have a packet_header
// we must try to read again. The packet_header is always present and is 4 bytes long.
if bytes_read < 4 && packet_headers.len() == 0 {
continue;
}
} else {
// Read 0 bytes from the server; end-of-stream
panic!("Cannot read 0 bytes from stream");
}
}
}
}
// impl RawConnection for MariaDbRawConnection {
// type Backend = MariaDb;
// #[inline]
// fn establish(url: &str) -> BoxFuture<std::io::Result<Self>> {
// Box::pin(MariaDbRawConnection::establish(url))
// }
// #[inline]
// fn finalize<'c>(&'c mut self) -> BoxFuture<'c, std::io::Result<()>> {
// Box::pin(self.finalize())
// }
// fn execute<'c, 'q, Q: 'q>(&'c mut self, query: Q) -> BoxFuture<'c, std::io::Result<()>>
// where
// Q: RawQuery<'q, Backend = Self::Backend>,
// {
// query.finish(self);
// Box::pin(execute::execute(self))
// }
// }

View File

@ -13,6 +13,7 @@ use crate::{
},
Backend, Error, Result,
};
use async_trait::async_trait;
use byteorder::{ByteOrder, LittleEndian};
use futures_core::{future::BoxFuture, stream::BoxStream};
use std::{
@ -56,7 +57,7 @@ impl MariaDbRawConnection {
Ok(conn)
}
pub async fn close(&mut self) -> Result<()> {
pub async fn close(mut self) -> Result<()> {
// Send the quit command
self.start_sequence();
@ -297,55 +298,50 @@ enum ExecResult {
Rows(Vec<ColumnDefinitionPacket>),
}
#[async_trait]
impl RawConnection for MariaDbRawConnection {
type Backend = MariaDb;
fn establish(url: &str) -> BoxFuture<Result<Self>>
async fn establish(url: &str) -> crate::Result<Self>
where
Self: Sized,
{
Box::pin(MariaDbRawConnection::establish(url))
MariaDbRawConnection::establish(url).await
}
fn close(&mut self) -> BoxFuture<'_, Result<()>> {
Box::pin(self.close())
async fn close(mut self) -> crate::Result<()> {
self.close().await
}
fn ping(&mut self) -> BoxFuture<'_, Result<()>> {
Box::pin(self.ping())
async fn ping(&mut self) -> crate::Result<()> {
self.ping().await
}
fn execute<'c>(
&'c mut self,
query: &str,
params: MariaDbQueryParameters,
) -> BoxFuture<'c, Result<u64>> {
async fn execute(&mut self, query: &str, params: MariaDbQueryParameters) -> crate::Result<u64> {
// Write prepare statement to buffer
self.start_sequence();
self.write(ComStmtPrepare { statement: query });
Box::pin(async move {
let statement_id = self.exec_prepare().await?;
let statement_id = self.exec_prepare().await?;
let affected = self.execute(statement_id, params).await?;
let affected = self.execute(statement_id, params).await?;
Ok(affected)
})
Ok(affected)
}
fn fetch<'c>(
&'c mut self,
fn fetch(
&mut self,
query: &str,
params: MariaDbQueryParameters,
) -> BoxStream<'c, Result<MariaDbRow>> {
) -> BoxStream<'_, Result<MariaDbRow>> {
unimplemented!();
}
fn fetch_optional<'c>(
&'c mut self,
async fn fetch_optional(
&mut self,
query: &str,
params: MariaDbQueryParameters,
) -> BoxFuture<'c, Result<Option<<Self::Backend as Backend>::Row>>> {
) -> crate::Result<Option<<Self::Backend as Backend>::Row>> {
unimplemented!();
}
}