Update mariadb connection to be able to prepare and execute

This commit is contained in:
Daniel Akhterov 2019-09-10 21:48:46 -07:00
parent 188982e12c
commit 351affc617
14 changed files with 363 additions and 41 deletions

View File

@ -1,20 +1,63 @@
use crate::{
io::{Buf, BufMut, BufStream},
mariadb::protocol::{ComPing, Encode},
};
use byteorder::{ByteOrder, LittleEndian};
use std::io;
use tokio::net::TcpStream;
use crate::mariadb::protocol::{OkPacket, ErrPacket, Capabilities};
use crate::mariadb::protocol::{OkPacket, ErrPacket, Capabilities, ComPing, ColumnCountPacket, ColumnDefinitionPacket, EofPacket, ComStmtExecute, StmtExecFlag, ResultRow, ComQuit, ComStmtPrepare, ComStmtPrepareOk};
use url::Url;
use std::net::{IpAddr, SocketAddr};
use std::future::Future;
use crate::error::DatabaseError;
use crate::{connection::RawConnection, io::{Buf, BufMut, BufStream}, mariadb::protocol::Encode, Error, Backend};
use super::establish;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use crate::mariadb::{MariaDb, MariaDbRow, MariaDbQueryParameters};
use crate::Result;
pub struct Connection {
stream: BufStream<TcpStream>,
capabilities: Capabilities,
pub struct MariaDbRawConnection {
pub(crate) stream: BufStream<TcpStream>,
pub(crate) rbuf: Vec<u8>,
pub(crate) capabilities: Capabilities,
next_seq_no: u8,
}
impl Connection {
pub async fn ping(&mut self) -> crate::Result<()> {
impl MariaDbRawConnection {
async fn establish(url: &str) -> Result<Self> {
// TODO: Handle errors
let url = Url::parse(url).unwrap();
let host = url.host_str().unwrap_or("127.0.0.1");
let port = url.port().unwrap_or(3306);
// TODO: handle errors
let host: IpAddr = host.parse().unwrap();
let addr: SocketAddr = (host, port).into();
let stream = TcpStream::connect(&addr).await?;
let mut conn = Self {
stream: BufStream::new(stream),
rbuf: Vec::with_capacity(8 * 1024),
capabilities: Capabilities::empty(),
next_seq_no: 0,
};
establish::establish(&mut conn, &url).await?;
Ok(conn)
}
pub async fn close(&mut self) -> Result<()> {
// Send the quit command
self.start_sequence();
self.write(ComQuit);
self.stream.flush().await?;
Ok(())
}
pub async fn ping(&mut self) -> Result<()> {
// Send the ping command and wait for (and drop) an OK packet
self.start_sequence();
@ -27,14 +70,14 @@ impl Connection {
Ok(())
}
async fn receive(&mut self) -> crate::Result<&[u8]> {
pub(crate) async fn receive(&mut self) -> Result<&[u8]> {
Ok(self
.try_receive()
.await?
.ok_or(io::ErrorKind::UnexpectedEof)?)
.ok_or(Error::Io(io::ErrorKind::UnexpectedEof.into()))?)
}
async fn try_receive(&mut self) -> crate::Result<Option<&[u8]>> {
async fn try_receive(&mut self) -> Result<Option<&[u8]>> {
// Read the packet header which contains the length and the sequence number
// https://mariadb.com/kb/en/library/0-packet/#standard-packet
let mut header = ret_if_none!(self.stream.peek(4).await?);
@ -62,7 +105,7 @@ impl Connection {
self.next_seq_no = 0;
}
fn write<T: Encode>(&mut self, packet: T) {
pub(crate) fn write<T: Encode>(&mut self, packet: T) {
let buf = self.stream.buffer_mut();
// Allocate room for the header that we write after the packet;
@ -89,16 +132,24 @@ impl Connection {
// Decode an OK packet or bubble an ERR packet as an error
// to terminate immediately
async fn receive_ok_or_err(&mut self) -> crate::Result<OkPacket> {
pub(crate) async fn receive_ok_or_err(&mut self) -> Result<OkPacket> {
let capabilities = self.capabilities;
let mut buf = self.receive().await?;
Ok(match buf[0] {
0xfe | 0x00 => OkPacket::decode(buf, self.capabilities)?,
0xfe | 0x00 => OkPacket::decode(buf, capabilities)?,
0xff => {
let err = ErrPacket::decode(buf)?;
// TODO: Bubble as Error::Database
panic!("received db err = {:?}", err);
// panic!("received db err = {:?}", err);
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("{:?}",
err
),
)
.into());
}
id => {
@ -113,4 +164,203 @@ impl Connection {
}
})
}
// This should not be used by the user. It's mean for `RawConnection` impl
// This assumes the buffer has been set and all it needs is a flush
async fn exec_prepare(&mut self) -> Result<u32> {
self.stream.flush().await?;
// COM_STMT_PREPARE returns COM_STMT_PREPARE_OK (0x00) or ERR (0xFF)
let mut packet = self.receive().await?;
let ok = match packet[0] {
0xFF => {
let err = ErrPacket::decode(packet)?;
// TODO: Bubble as Error::Database
panic!("received db err = {:?}", err);
}
_ => ComStmtPrepareOk::decode(packet)?,
};
// Skip decoding Column Definition packets for the result from a prepare statement
for _ in 0..ok.columns {
let _ = self.receive().await?;
}
if ok.columns > 0
&& !self
.capabilities
.contains(Capabilities::CLIENT_DEPRECATE_EOF)
{
// TODO: Should we do something with the warning indicators here?
let _eof = EofPacket::decode(self.receive().await?)?;
}
Ok(ok.statement_id)
}
async fn prepare<'c>(&'c mut self, statement: &'c str) -> Result<u32> {
self.stream.flush().await?;
self.start_sequence();
self.write(ComStmtPrepare { statement });
self.exec_prepare().await
}
async fn execute(&mut self, statement_id: u32, params: MariaDbQueryParameters) -> Result<u64> {
// TODO: EXECUTE(READ_ONLY) => FETCH instead of EXECUTE(NO)
// SEND ================
self.start_sequence();
self.write(ComStmtExecute {
statement_id,
params: &[],
null: &[],
flags: StmtExecFlag::NO_CURSOR,
param_types: &[]
});
self.stream.flush().await?;
// =====================
// Row Counter, used later
let mut rows = 0u64;
let capabilities = self.capabilities;
let has_eof = capabilities
.contains(Capabilities::CLIENT_DEPRECATE_EOF);
let packet = self.receive().await?;
if packet[0] == 0x00 {
let _ok = OkPacket::decode(packet, capabilities)?;
} else if packet[0] == 0xFF {
let err = ErrPacket::decode(packet)?;
panic!("received db err = {:?}", err);
} else {
// A Resultset starts with a [ColumnCountPacket] which is a single field that encodes
// how many columns we can expect when fetching rows from this statement
let column_count: u64 = ColumnCountPacket::decode(packet)?.columns;
// Next we have a [ColumnDefinitionPacket] which verbosely explains each minute
// detail about the column in question including table, aliasing, and type
// TODO: This information was *already* returned by PREPARE .., is there a way to suppress generation
let mut columns = vec![];
for _ in 0..column_count {
columns.push(ColumnDefinitionPacket::decode(self.receive().await?)?);
}
// When (legacy) EOFs are enabled, the fixed number column definitions are further terminated by
// an EOF packet
if !has_eof {
let _eof = EofPacket::decode(self.receive().await?)?;
}
// For each row in the result set we will receive a ResultRow packet.
// We may receive an [OkPacket], [EofPacket], or [ErrPacket] (depending on if EOFs are enabled) to finalize the iteration.
loop {
let packet = self.receive().await?;
if packet[0] == 0xFE && packet.len() < 0xFF_FF_FF {
// NOTE: It's possible for a ResultRow to start with 0xFE (which would normally signify end-of-rows)
// but it's not possible for an Ok/Eof to be larger than 0xFF_FF_FF.
if !has_eof {
let _eof = EofPacket::decode(packet)?;
} else {
let _ok = OkPacket::decode(packet, capabilities)?;
}
break;
} else if packet[0] == 0xFF {
let err = ErrPacket::decode(packet)?;
panic!("received db err = {:?}", err);
} else {
// Ignore result rows; exec only returns number of affected rows;
let _ = ResultRow::decode(packet, &columns)?;
// For every row we decode we increment counter
rows = rows + 1;
}
}
}
Ok(rows)
}
}
enum ExecResult {
NoRows(OkPacket),
Rows(Vec<ColumnDefinitionPacket>),
}
impl RawConnection for MariaDbRawConnection {
type Backend = MariaDb;
fn establish(url: &str) -> BoxFuture<Result<Self>>
where
Self: Sized {
Box::pin(MariaDbRawConnection::establish(url))
}
fn close(&mut self) -> BoxFuture<'_, Result<()>> {
Box::pin(self.close())
}
fn ping(&mut self) -> BoxFuture<'_, Result<()>> {
Box::pin(self.ping())
}
fn execute<'c>(
&'c mut self,
query: &str,
params: MariaDbQueryParameters,
) -> BoxFuture<'c, Result<u64>> {
// Write prepare statement to buffer
self.start_sequence();
self.write(ComStmtPrepare {
statement: query
});
Box::pin(async move {
let statement_id = self.exec_prepare().await?;
let affected = self.execute(statement_id, params).await?;
Ok(affected)
})
}
fn fetch<'c>(
&'c mut self,
query: &str,
params: MariaDbQueryParameters,
) -> BoxStream<'c, Result<MariaDbRow>> {
unimplemented!();
}
fn fetch_optional<'c>(
&'c mut self,
query: &str,
params: MariaDbQueryParameters,
) -> BoxFuture<'c, Result<Option<<Self::Backend as Backend>::Row>>> {
unimplemented!();
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::Error;
#[tokio::test]
async fn it_can_connect() -> Result<()> {
MariaDbRawConnection::establish("mariadb://root@127.0.0.1:3306/test").await?;
Ok(())
}
#[tokio::test]
async fn it_fails_to_connect_with_bad_username() -> Result<()> {
match MariaDbRawConnection::establish("mariadb://roote@127.0.0.1:3306/test").await {
Ok(_) => panic!("Somehow connected to database with incorrect username"),
Err(_) => Ok(())
}
}
}

42
src/mariadb/establish.rs Normal file
View File

@ -0,0 +1,42 @@
use crate::Result;
use url::Url;
use crate::mariadb::protocol::{HandshakeResponsePacket, InitialHandshakePacket, Encode, Capabilities};
use crate::mariadb::connection::MariaDbRawConnection;
pub(crate) async fn establish(conn: &mut MariaDbRawConnection, url: &Url) -> Result<()> {
let initial = InitialHandshakePacket::decode(conn.receive().await?)?;
// TODO: Capabilities::SECURE_CONNECTION
// TODO: Capabilities::CONNECT_ATTRS
// TODO: Capabilities::PLUGIN_AUTH
// TODO: Capabilities::PLUGIN_AUTH_LENENC_CLIENT_DATA
// TODO: Capabilities::TRANSACTIONS
// TODO: Capabilities::CLIENT_DEPRECATE_EOF
// TODO?: Capabilities::CLIENT_SESSION_TRACK
let mut capabilities = Capabilities::CLIENT_PROTOCOL_41 | Capabilities::CONNECT_WITH_DB;
let response = HandshakeResponsePacket {
// TODO: Find a good value for [max_packet_size]
capabilities,
max_packet_size: 1024,
client_collation: 192, // utf8_unicode_ci
username: url.username(),
database: &url.path()[1..],
auth_data: None,
auth_plugin_name: None,
connection_attrs: &[]
};
// The AND between our supported capabilities and the servers' is
// what we can use so remember it on the connection
conn.capabilities = capabilities & initial.capabilities;
conn.write(response);
conn.stream.flush().await?;
let _ = conn.receive_ok_or_err().await?;
// TODO: If CONNECT_WITH_DB is not supported we need to send an InitDb command just after establish
Ok(())
}

View File

@ -1,19 +1,21 @@
// TODO: Remove after acitve development
#![allow(unused)]
// mod backend;
mod row;
mod backend;
mod connection;
mod establish;
mod io;
mod protocol;
// mod query;
mod query;
pub mod types;
//pub use self::{
// backend::MariaDb,
// connection.bak::MariaDbRawConnection,
// query::MariaDbQueryParameters,
// row::MariaDbRow,
//};
pub use self::{
backend::MariaDb,
connection::MariaDbRawConnection,
query::MariaDbQueryParameters,
row::MariaDbRow,
};
// pub use io::{BufExt, BufMutExt};
// pub use protocol::{

View File

@ -21,7 +21,7 @@ bitflags::bitflags! {
// https://mariadb.com/kb/en/library/com_stmt_execute
/// Executes a previously prepared statement.
#[derive(Debug)]
pub struct ComStmtExec<'a> {
pub struct ComStmtExecute<'a> {
pub statement_id: u32,
pub flags: StmtExecFlag,
pub params: &'a [u8],
@ -29,7 +29,7 @@ pub struct ComStmtExec<'a> {
pub param_types: &'a [MariaDbTypeMetadata],
}
impl Encode for ComStmtExec<'_> {
impl Encode for ComStmtExecute<'_> {
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
// COM_STMT_EXECUTE : int<1>
buf.put_u8(BinaryProtocol::ComStmtExec as u8);
@ -75,7 +75,7 @@ mod tests {
fn it_encodes_com_stmt_exec() {
let mut buf = Vec::new();
ComStmtExec {
ComStmtExecute {
statement_id: 1,
flags: StmtExecFlag::NO_CURSOR,
null: &vec![],

View File

@ -5,7 +5,7 @@ use std::io;
// https://mariadb.com/kb/en/library/com_stmt_prepare/#com_stmt_prepare_ok
#[derive(Debug)]
pub struct ComStmtPrepareOk {
pub statement_id: i32,
pub statement_id: u32,
/// Number of columns in the returned result set (or 0 if statement does not return result set).
pub columns: u16,
@ -18,7 +18,7 @@ pub struct ComStmtPrepareOk {
}
impl ComStmtPrepareOk {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
pub(crate) fn decode(mut buf: &[u8]) -> io::Result<Self> {
let header = buf.get_u8()?;
if header != 0x00 {
@ -28,7 +28,7 @@ impl ComStmtPrepareOk {
));
}
let statement_id = buf.get_i32::<LittleEndian>()?;
let statement_id = buf.get_u32::<LittleEndian>()?;
let columns = buf.get_u16::<LittleEndian>()?;
let params = buf.get_u16::<LittleEndian>()?;

View File

@ -6,7 +6,7 @@ pub mod com_stmt_prepare_ok;
pub mod com_stmt_reset;
pub use com_stmt_close::ComStmtClose;
pub use com_stmt_exec::ComStmtExec;
pub use com_stmt_exec::{ComStmtExecute, StmtExecFlag};
pub use com_stmt_fetch::ComStmtFetch;
pub use com_stmt_prepare::ComStmtPrepare;
pub use com_stmt_prepare_ok::ComStmtPrepareOk;

View File

@ -21,7 +21,7 @@ pub struct InitialHandshakePacket {
}
impl InitialHandshakePacket {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
pub(crate) fn decode(mut buf: &[u8]) -> io::Result<Self> {
let protocol_version = buf.get_u8()?;
let server_version = buf.get_str_nul()?.to_owned();
let connection_id = buf.get_u32::<LittleEndian>()?;

View File

@ -12,7 +12,7 @@ mod server_status;
mod text;
pub use binary::{
ComStmtClose, ComStmtExec, ComStmtFetch, ComStmtPrepare, ComStmtPrepareOk, ComStmtReset,
ComStmtClose, ComStmtExecute, StmtExecFlag, ComStmtFetch, ComStmtPrepare, ComStmtPrepareOk, ComStmtReset,
};
pub use capabilities::Capabilities;
pub use connect::{

View File

@ -12,7 +12,7 @@ pub struct ColumnCountPacket {
}
impl ColumnCountPacket {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
pub(crate) fn decode(mut buf: &[u8]) -> io::Result<Self> {
let columns = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or(0);
Ok(Self { columns })

View File

@ -25,7 +25,7 @@ pub struct ColumnDefinitionPacket {
}
impl ColumnDefinitionPacket {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
pub(crate) fn decode(mut buf: &[u8]) -> io::Result<Self> {
// string<lenenc> catalog (always 'def')
let _catalog = buf.get_str_lenenc::<LittleEndian>()?;
// TODO: Assert that this is always DEF

View File

@ -15,7 +15,7 @@ pub struct EofPacket {
}
impl EofPacket {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
pub(crate) fn decode(mut buf: &[u8]) -> io::Result<Self> {
let header = buf.get_u8()?;
if header != 0xFE {
return Err(io::Error::new(

View File

@ -4,6 +4,7 @@ use crate::{
serialize::{IsNull, ToSql},
types::HasSqlType,
};
use crate::mariadb::types::MariaDbTypeMetadata;
pub struct MariaDbQueryParameters {
param_types: Vec<MariaDbTypeMetadata>,

26
src/mariadb/row.rs Normal file
View File

@ -0,0 +1,26 @@
use crate::row::Row;
use crate::mariadb::protocol::ResultRow;
use crate::mariadb::MariaDb;
pub struct MariaDbRow(pub(super) ResultRow);
impl Row for MariaDbRow {
type Backend = MariaDb;
#[inline]
fn is_empty(&self) -> bool {
self.0.values.is_empty()
}
#[inline]
fn len(&self) -> usize {
self.0.values.len()
}
#[inline]
fn get_raw(&self, index: usize) -> Option<&[u8]> {
self.0.values[index]
.as_ref()
.map(|value| unsafe { value.as_ref() })
}
}

View File

@ -1,5 +1,6 @@
use super::protocol::{FieldType, ParameterFlag};
use crate::types::TypeMetadata;
use crate::mariadb::MariaDb;
#[derive(Debug)]
pub struct MariaDbTypeMetadata {
@ -7,6 +8,6 @@ pub struct MariaDbTypeMetadata {
pub param_flag: ParameterFlag,
}
//impl TypeMetadata for MariaDb {
// type TypeMetadata = MariaDbTypeMetadata;
//}
impl TypeMetadata for MariaDb {
type TypeMetadata = MariaDbTypeMetadata;
}