WIP: Binary result set

This commit is contained in:
Daniel Akhterov 2019-07-30 21:33:41 -07:00
parent ffe25704fc
commit ac6006733c
11 changed files with 303 additions and 126 deletions

View File

@ -1,10 +1,5 @@
use super::Connection;
use crate::mariadb::protocol::{
deserialize::{DeContext, Deserialize},
packets::{handshake_response::HandshakeResponsePacket, initial::InitialHandshakePacket},
server::Message as ServerMessage,
types::Capabilities,
};
use crate::mariadb::{DeContext, Deserialize, HandshakeResponsePacket, InitialHandshakePacket, Message, Capabilities, ComStmtExec, StmtExecFlag};
use bytes::{BufMut, Bytes};
use failure::{err_msg, Error};
use crate::ConnectOptions;
@ -32,12 +27,12 @@ pub async fn establish<'a, 'b: 'a>(
conn.send(handshake).await?;
match conn.next().await? {
Some(ServerMessage::OkPacket(message)) => {
Some(Message::OkPacket(message)) => {
conn.context.seq_no = message.seq_no;
Ok(())
}
Some(ServerMessage::ErrPacket(message)) => Err(err_msg(format!("{:?}", message))),
Some(Message::ErrPacket(message)) => Err(err_msg(format!("{:?}", message))),
Some(message) => {
panic!("Did not receive OkPacket nor ErrPacket. Received: {:?}", message);
@ -51,100 +46,147 @@ pub async fn establish<'a, 'b: 'a>(
#[cfg(test)]
mod test {
use super::*;
use failure::Error;
#[runtime::test]
async fn it_can_connect() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
})
.await?;
Ok(())
}
#[runtime::test]
async fn it_can_ping() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
}).await?;
conn.ping().await?;
Ok(())
}
#[runtime::test]
async fn it_can_select_db() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
}).await?;
conn.select_db("test").await?;
Ok(())
}
#[runtime::test]
async fn it_can_query() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
}).await?;
conn.select_db("test").await?;
conn.query("SELECT * FROM users").await?;
Ok(())
}
#[runtime::test]
async fn it_can_prepare() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
}).await?;
conn.select_db("test").await?;
conn.prepare("SELECT * FROM users WHERE username = ?").await?;
Ok(())
}
#[runtime::test]
async fn it_does_not_connect() -> Result<(), Error> {
match Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("roote"),
database: None,
password: None,
})
.await
{
Ok(_) => Err(err_msg("Bad username still worked?")),
Err(_) => Ok(()),
}
}
// use super::*;
// use failure::Error;
// use crate::mariadb::{ComStmtPrepareResp, FieldType, ResultSet, ComStmtFetch};
//
// #[runtime::test]
// async fn it_can_connect() -> Result<(), Error> {
// let mut conn = Connection::establish(ConnectOptions {
// host: "127.0.0.1",
// port: 3306,
// user: Some("root"),
// database: None,
// password: None,
// })
// .await?;
//
// Ok(())
// }
//
// #[runtime::test]
// async fn it_can_ping() -> Result<(), Error> {
// let mut conn = Connection::establish(ConnectOptions {
// host: "127.0.0.1",
// port: 3306,
// user: Some("root"),
// database: None,
// password: None,
// }).await?;
//
// conn.ping().await?;
//
// Ok(())
// }
//
// #[runtime::test]
// async fn it_can_select_db() -> Result<(), Error> {
// let mut conn = Connection::establish(ConnectOptions {
// host: "127.0.0.1",
// port: 3306,
// user: Some("root"),
// database: None,
// password: None,
// }).await?;
//
// conn.select_db("test").await?;
//
// Ok(())
// }
//
// #[runtime::test]
// async fn it_can_query() -> Result<(), Error> {
// let mut conn = Connection::establish(ConnectOptions {
// host: "127.0.0.1",
// port: 3306,
// user: Some("root"),
// database: None,
// password: None,
// }).await?;
//
// conn.select_db("test").await?;
//
// conn.query("SELECT * FROM users").await?;
//
// Ok(())
// }
//
// #[runtime::test]
// async fn it_can_prepare() -> Result<(), Error> {
// let mut conn = Connection::establish(ConnectOptions {
// host: "127.0.0.1",
// port: 3306,
// user: Some("root"),
// database: None,
// password: None,
// }).await?;
//
// conn.select_db("test").await?;
//
// conn.prepare("SELECT * FROM users WHERE username = ?").await?;
//
// Ok(())
// }
//
// #[runtime::test]
// async fn it_can_execute_prepared() -> Result<(), Error> {
// let mut conn = Connection::establish(ConnectOptions {
// host: "127.0.0.1",
// port: 3306,
// user: Some("root"),
// database: None,
// password: None,
// }).await?;
//
// conn.select_db("test").await?;
//
// let mut prepared = conn.prepare("SELECT * FROM users WHERE username = ?;").await?;
//
// println!("{:?}", prepared);
//
// if let Some(param_defs) = &mut prepared.param_defs {
// param_defs[0].field_type = FieldType::MysqlTypeBlob;
// }
//
// let exec = ComStmtExec {
// stmt_id: prepared.ok.stmt_id,
// flags: StmtExecFlag::CursorForUpdate,
// params: Some(vec![Some(Bytes::from_static(b"'daniel'"))]),
// param_defs: prepared.param_defs,
// };
//
//
// conn.send(exec).await?;
//
// let buf = conn.stream.next_bytes().await?;
//
// println!("{:?}", buf);
//
// println!("{:?}", ResultSet::deserialize(&mut DeContext::new(&mut conn.context, &buf))?.rows);
//
// let fetch = ComStmtFetch {
// stmt_id: prepared.ok.stmt_id,
// rows: 1,
// };
//
// conn.send(exec).await?;
//
// Ok(())
// }
//
// #[runtime::test]
// async fn it_does_not_connect() -> Result<(), Error> {
// match Connection::establish(ConnectOptions {
// host: "127.0.0.1",
// port: 3306,
// user: Some("roote"),
// database: None,
// password: None,
// })
// .await
// {
// Ok(_) => Err(err_msg("Bad username still worked?")),
// Err(_) => Ok(()),
// }
// }
}

View File

@ -151,16 +151,14 @@ impl Connection {
Ok(())
}
pub async fn prepare(&mut self, query: &str) -> Result<(), Error> {
pub async fn prepare(&mut self, query: &str) -> Result<ComStmtPrepareResp, Error> {
self.send(ComStmtPrepare {
statement: Bytes::from(query),
}).await?;
let buf = self.stream.next_bytes().await?;
ComStmtPrepareResp::deserialize(&mut DeContext::new(&mut self.context, &buf))?;
Ok(())
ComStmtPrepareResp::deserialize(&mut DeContext::new(&mut self.context, &buf))
}
pub async fn next(&mut self) -> Result<Option<Message>, Error> {

View File

@ -47,3 +47,4 @@ pub use protocol::FieldType;
pub use protocol::FieldDetailFlag;
pub use protocol::SessionChangeType;
pub use protocol::StmtExecFlag;
pub use protocol::ComStmtFetch;

View File

@ -76,30 +76,30 @@ impl<'a> Decoder<'a> {
// 0xFF then there was an error.
// If the first byte is not in the previous list then that byte is the int value.
#[inline]
pub fn decode_int_lenenc(&mut self) -> Option<i64> {
pub fn decode_int_lenenc(&mut self) -> Option<u64> {
match self.buf[self.index] {
0xFB => {
self.index += 1;
None
}
0xFC => {
let value = Some(LittleEndian::read_i16(&self.buf[self.index + 1..]) as i64);
let value = Some(LittleEndian::read_i16(&self.buf[self.index + 1..]) as u64);
self.index += 3;
value
}
0xFD => {
let value = Some(LittleEndian::read_i24(&self.buf[self.index + 1..]) as i64);
let value = Some(LittleEndian::read_i24(&self.buf[self.index + 1..]) as u64);
self.index += 4;
value
}
0xFE => {
let value = Some(LittleEndian::read_i64(&self.buf[self.index + 1..]) as i64);
let value = Some(LittleEndian::read_i64(&self.buf[self.index + 1..]) as u64);
self.index += 9;
value
}
0xFF => panic!("int<lenenc> unprocessable first byte 0xFF"),
_ => {
let value = Some(self.buf[self.index] as i64);
let value = Some(self.buf[self.index] as u64);
self.index += 1;
value
}
@ -176,8 +176,8 @@ impl<'a> Decoder<'a> {
// Decode a string<fix> which is a string of fixed length.
#[inline]
pub fn decode_string_fix(&mut self, length: u32) -> Bytes {
let value = self.buf.slice(self.index, self.index + length as usize);
pub fn decode_string_fix(&mut self, length: usize) -> Bytes {
let value = self.buf.slice(self.index, self.index + length);
self.index = self.index + length as usize;
value
}
@ -212,8 +212,8 @@ impl<'a> Decoder<'a> {
// Same as the string counter part, but copied to maintain consistency with the spec.
#[inline]
pub fn decode_byte_fix(&mut self, length: u32) -> Bytes {
let value = self.buf.slice(self.index, self.index + length as usize);
pub fn decode_byte_fix(&mut self, length: usize) -> Bytes {
let value = self.buf.slice(self.index, self.index + length);
self.index = self.index + length as usize;
value
}
@ -242,6 +242,100 @@ impl<'a> Decoder<'a> {
self.index = self.buf.len();
value
}
#[inline]
pub fn decode_binary_decimal(&mut self) -> Bytes {
self.decode_string_lenenc()
}
#[inline]
pub fn decode_binary_double(&mut self) -> Bytes {
let value = self.buf.slice(self.index, self.index + 8);
self.index += 8;
value
}
#[inline]
pub fn decode_binary_bigint(&mut self) -> Bytes {
let value = self.buf.slice(self.index, self.index + 8);
self.index += 8;
value
}
#[inline]
pub fn decode_binary_int(&mut self) -> Bytes {
let value = self.buf.slice(self.index, self.index + 4);
self.index += 4;
value
}
#[inline]
pub fn decode_binary_mediumint(&mut self) -> Bytes {
let value = self.buf.slice(self.index, self.index + 4);
self.index += 4;
value
}
#[inline]
pub fn decode_binary_float(&mut self) -> Bytes {
let value = self.buf.slice(self.index, self.index + 4);
self.index += 4;
value
}
#[inline]
pub fn decode_binary_smallint(&mut self) -> Bytes {
let value = self.buf.slice(self.index, self.index + 2);
self.index += 2;
value
}
#[inline]
pub fn decode_binary_year(&mut self) -> Bytes {
let value = self.buf.slice(self.index, self.index + 2);
self.index += 2;
value
}
#[inline]
pub fn decode_binary_tinyint(&mut self) -> Bytes {
let value = self.buf.slice(self.index, self.index + 1);
self.index += 1;
value
}
#[inline]
pub fn decode_binary_date(&mut self) -> Bytes {
let value = self.buf.slice(self.index, self.index + 5);
self.index += 5;
value
}
#[inline]
pub fn decode_binary_timestamp(&mut self) -> Bytes {
let value = self.buf.slice(self.index, self.index + 12);
self.index += 12;
value
}
#[inline]
pub fn decode_binary_datetime(&mut self) -> Bytes {
let value = self.buf.slice(self.index, self.index + 12);
self.index += 12;
value
}
#[inline]
pub fn decode_binary_time(&mut self) -> Bytes {
let length = self.decode_int_u8();
if length != 8 && length != 12 {
panic!("Date length is not 8 or 12 (the only two possible values)");
}
let value = self.buf.slice(self.index, self.index + length as usize);
self.index += length as usize;
value
}
}
#[cfg(test)]

View File

@ -1,5 +1,4 @@
use super::decode::Decoder;
use crate::mariadb::connection::{ConnContext, Connection};
use crate::mariadb::{Decoder, ConnContext, Connection, ColumnDefPacket};
use bytes::Bytes;
use failure::Error;
@ -10,12 +9,13 @@ use failure::Error;
pub struct DeContext<'a> {
pub ctx: &'a mut ConnContext,
pub decoder: Decoder<'a>,
pub columns: Option<i64>,
pub columns: Option<u64>,
pub column_defs: Option<Vec<ColumnDefPacket>>,
}
impl<'a> DeContext<'a> {
pub fn new(conn: &'a mut ConnContext, buf: &'a Bytes) -> Self {
DeContext { ctx: conn, decoder: Decoder::new(&buf), columns: None }
DeContext { ctx: conn, decoder: Decoder::new(&buf), columns: None , column_defs: None }
}
}

View File

@ -5,6 +5,7 @@ pub mod com_stmt_close;
pub mod com_stmt_exec;
pub mod com_stmt_fetch;
pub mod com_stmt_reset;
pub mod result_row;
pub use com_stmt_prepare::ComStmtPrepare;
pub use com_stmt_prepare_ok::ComStmtPrepareOk;

View File

@ -0,0 +1,41 @@
use bytes::Bytes;
#[derive(Debug, Default)]
pub struct ResultRow {
pub columns: Vec<Option<Bytes>>
}
impl crate::mariadb::Deserialize for ResultRow {
fn deserialize(ctx: &mut crate::mariadb::DeContext) -> Result<Self, failure::Error> {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_u8();
let header = decoder.decode_int_u8();
let bitmap = if let Some(columns) = ctx.columns {
let size = (columns + 9) / 8;
decoder.decode_byte_fix(size as usize)
} else {
Bytes::new()
};
let row = if let Some(columns) = ctx.columns {
(0..columns).map(|index| {
if (1 << index) & (bitmap[index/8] << (index % 8)) == 0 {
None
} else {
match ctx.column_defs[index] {
}
decoder.decode_binary_column(&ctx.column_defs)
}
}).collect::<Vec<Bytes>>()
} else {
Vec::new()
};
Ok(ResultRow::default())
}
}

View File

@ -9,7 +9,7 @@ use crate::mariadb::{DeContext, Deserialize};
pub struct ColumnPacket {
pub length: u32,
pub seq_no: u8,
pub columns: Option<i64>,
pub columns: Option<u64>,
}
impl Deserialize for ColumnPacket {

View File

@ -13,7 +13,7 @@ pub struct ColumnDefPacket {
pub table: Bytes,
pub column_alias: Bytes,
pub column: Bytes,
pub length_of_fixed_fields: Option<i64>,
pub length_of_fixed_fields: Option<u64>,
pub char_set: i16,
pub max_columns: i32,
pub field_type: FieldType,

View File

@ -67,7 +67,7 @@ impl Deserialize for InitialHandshakePacket {
let mut scramble: Option<Bytes> = None;
if !(capabilities & Capabilities::SECURE_CONNECTION).is_empty() {
let len = std::cmp::max(12, plugin_data_length as usize - 9);
scramble = Some(decoder.decode_string_fix(len as u32));
scramble = Some(decoder.decode_string_fix(len as usize));
// Skip reserve byte
decoder.skip_bytes(1);
}

View File

@ -8,8 +8,8 @@ use crate::mariadb::{DeContext, Deserialize, ServerStatusFlag,
pub struct OkPacket {
pub length: u32,
pub seq_no: u8,
pub affected_rows: Option<i64>,
pub last_insert_id: Option<i64>,
pub affected_rows: Option<u64>,
pub last_insert_id: Option<u64>,
pub server_status: ServerStatusFlag,
pub warning_count: i16,
pub info: Bytes,