Fix capabilities

This commit is contained in:
Daniel Akhterov 2019-07-30 20:30:33 -07:00
parent 667ee3b56f
commit ffe25704fc
10 changed files with 286 additions and 74 deletions

View File

@ -16,7 +16,7 @@ itoa = "0.4.4"
log = "0.4.7"
md-5 = "0.8.0"
memchr = "2.2.1"
runtime = { version = "=0.3.0-alpha.6", default-features = false }
runtime = { version = "=0.3.0-alpha.6", default-features = true }
bitflags = "1.1.0"
enum-tryfrom = "0.2.1"
enum-tryfrom-derive = "0.2.1"

View File

@ -8,6 +8,7 @@ use crate::mariadb::protocol::{
use bytes::{BufMut, Bytes};
use failure::{err_msg, Error};
use crate::ConnectOptions;
use std::ops::BitAnd;
pub async fn establish<'a, 'b: 'a>(
conn: &'a mut Connection,
@ -15,11 +16,13 @@ pub async fn establish<'a, 'b: 'a>(
) -> Result<(), Error> {
let buf = &conn.stream.next_bytes().await?;
let mut de_ctx = DeContext::new(&mut conn.context, &buf);
let _ = InitialHandshakePacket::deserialize(&mut de_ctx)?;
let initial = InitialHandshakePacket::deserialize(&mut de_ctx)?;
de_ctx.ctx.capabilities = de_ctx.ctx.capabilities.bitand(initial.capabilities);
let handshake: HandshakeResponsePacket = HandshakeResponsePacket {
// Minimum client capabilities required to establish connection
capabilities: Capabilities::CLIENT_PROTOCOL_41,
capabilities: de_ctx.ctx.capabilities,
max_packet_size: 1024,
extended_capabilities: Some(Capabilities::from_bits_truncate(0)),
username: Bytes::from(options.user.unwrap_or("")),
@ -51,35 +54,97 @@ mod test {
use super::*;
use failure::Error;
// #[runtime::test]
// async fn it_connects() -> Result<(), Error> {
// let mut conn = Connection::establish(ConnectOptions {
// host: "127.0.0.1",
// port: 3306,
// user: Some("root"),
// database: None,
// password: None,
// })
// .await?;
//
// conn.ping().await?;
//
// Ok(())
// }
//
// #[runtime::test]
// async fn it_does_not_connect() -> Result<(), Error> {
// match Connection::establish(ConnectOptions {
// host: "127.0.0.1",
// port: 3306,
// user: Some("roote"),
// database: None,
// password: None,
// })
// .await
// {
// Ok(_) => Err(err_msg("Bad username still worked?")),
// Err(_) => Ok(()),
// }
// }
#[runtime::test]
async fn it_can_connect() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
})
.await?;
Ok(())
}
#[runtime::test]
async fn it_can_ping() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
}).await?;
conn.ping().await?;
Ok(())
}
#[runtime::test]
async fn it_can_select_db() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
}).await?;
conn.select_db("test").await?;
Ok(())
}
#[runtime::test]
async fn it_can_query() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
}).await?;
conn.select_db("test").await?;
conn.query("SELECT * FROM users").await?;
Ok(())
}
#[runtime::test]
async fn it_can_prepare() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
}).await?;
conn.select_db("test").await?;
conn.prepare("SELECT * FROM users WHERE username = ?").await?;
Ok(())
}
#[runtime::test]
async fn it_does_not_connect() -> Result<(), Error> {
match Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("roote"),
database: None,
password: None,
})
.await
{
Ok(_) => Err(err_msg("Bad username still worked?")),
Err(_) => Ok(()),
}
}
}

View File

@ -6,7 +6,7 @@ use futures::{
prelude::*,
};
use runtime::net::TcpStream;
use crate::{ConnectOptions, mariadb::protocol::{DeContext, Deserialize, Encoder, ComInitDb, ComPing, ComQuery, ComQuit, OkPacket, Serialize, Message, Capabilities, ServerStatusFlag}};
use crate::{ConnectOptions, mariadb::protocol::{DeContext, Deserialize, Encoder, ComInitDb, ComPing, ComQuery, ComQuit, OkPacket, Serialize, Message, Capabilities, ServerStatusFlag, ComStmtPrepare, ComStmtPrepareResp, ResultSet, ErrPacket}};
mod establish;
@ -46,26 +46,18 @@ impl ConnContext {
connection_id: 0,
seq_no: 2,
last_seq_no: 0,
capabilities: Capabilities::FOUND_ROWS
| Capabilities::CONNECT_WITH_DB
| Capabilities::COMPRESS
| Capabilities::LOCAL_FILES
| Capabilities::IGNORE_SPACE
| Capabilities::CLIENT_PROTOCOL_41
| Capabilities::CLIENT_INTERACTIVE
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PS_MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::CONNECT_ATTRS
| Capabilities::PLUGIN_AUTH_LENENC_CLIENT_DATA
| Capabilities::CLIENT_SESSION_TRACK
| Capabilities::CLIENT_DEPRECATE_EOF
| Capabilities::MARIA_DB_CLIENT_PROGRESS
| Capabilities::MARIA_DB_CLIENT_COM_MULTI
| Capabilities::MARIA_CLIENT_STMT_BULK_OPERATIONS,
capabilities: Capabilities::CLIENT_PROTOCOL_41,
status: ServerStatusFlag::SERVER_STATUS_IN_TRANS
}
}
#[cfg(test)]
pub fn with_eof() -> Self {
ConnContext {
connection_id: 0,
seq_no: 2,
last_seq_no: 0,
capabilities: Capabilities::CLIENT_PROTOCOL_41 | Capabilities::CLIENT_DEPRECATE_EOF,
status: ServerStatusFlag::SERVER_STATUS_IN_TRANS
}
}
@ -81,7 +73,7 @@ impl Connection {
connection_id: -1,
seq_no: 1,
last_seq_no: 0,
capabilities: Capabilities::default(),
capabilities: Capabilities::CLIENT_PROTOCOL_41,
status: ServerStatusFlag::default(),
},
};
@ -103,22 +95,33 @@ impl Connection {
}
pub async fn quit(&mut self) -> Result<(), Error> {
self.context.seq_no = 0;
self.send(ComQuit()).await?;
Ok(())
}
pub async fn query<'a>(&'a mut self, sql_statement: &'a str) -> Result<(), Error> {
self.context.seq_no = 0;
pub async fn query<'a>(&'a mut self, sql_statement: &'a str) -> Result<Option<ResultSet>, Error> {
self.send(ComQuery { sql_statement: bytes::Bytes::from(sql_statement) }).await?;
Ok(())
let buf = self.stream.next_bytes().await?;
let mut ctx = DeContext::new(&mut self.context, &buf);
if let Some(tag) = ctx.decoder.peek_tag() {
match tag {
0xFF => Err(ErrPacket::deserialize(&mut ctx)?.into()),
0x00 => {
OkPacket::deserialize(&mut ctx)?;
Ok(None)
},
0xFB => unimplemented!(),
_ => Ok(Some(ResultSet::deserialize(&mut ctx)?))
}
} else {
panic!("Tag not found in result packet");
}
}
pub async fn select_db<'a>(&'a mut self, db: &'a str) -> Result<(), Error> {
self.context.seq_no = 0;
self.send(ComInitDb { schema_name: bytes::Bytes::from(db) }).await?;
@ -139,7 +142,6 @@ impl Connection {
}
pub async fn ping(&mut self) -> Result<(), Error> {
self.context.seq_no = 0;
self.send(ComPing()).await?;
// Ping response must be an OkPacket
@ -149,6 +151,18 @@ impl Connection {
Ok(())
}
pub async fn prepare(&mut self, query: &str) -> Result<(), Error> {
self.send(ComStmtPrepare {
statement: Bytes::from(query),
}).await?;
let buf = self.stream.next_bytes().await?;
ComStmtPrepareResp::deserialize(&mut DeContext::new(&mut self.context, &buf))?;
Ok(())
}
pub async fn next(&mut self) -> Result<Option<Message>, Error> {
let mut rbuf = BytesMut::new();
let mut len = 0;

View File

@ -8,14 +8,14 @@ use failure::Error;
// access to the connection context.
// Mainly used to simply to simplify number of parameters for deserializing functions
pub struct DeContext<'a> {
pub conn: &'a mut ConnContext,
pub ctx: &'a mut ConnContext,
pub decoder: Decoder<'a>,
pub columns: Option<i64>,
}
impl<'a> DeContext<'a> {
pub fn new(conn: &'a mut ConnContext, buf: &'a Bytes) -> Self {
DeContext { conn, decoder: Decoder::new(&buf), columns: None }
DeContext { ctx: conn, decoder: Decoder::new(&buf), columns: None }
}
}

View File

@ -2,7 +2,7 @@ use bytes::Bytes;
#[derive(Debug)]
pub struct ComStmtPrepare {
statement: Bytes
pub statement: Bytes
}
impl crate::mariadb::Serialize for ComStmtPrepare {

View File

@ -17,7 +17,7 @@ impl crate::mariadb::Deserialize for ComStmtPrepareResp {
.map(Result::unwrap)
.collect::<Vec<ColumnDefPacket>>();
if !ctx.conn.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
EofPacket::deserialize(ctx)?;
}
@ -32,7 +32,7 @@ impl crate::mariadb::Deserialize for ComStmtPrepareResp {
.map(Result::unwrap)
.collect::<Vec<ColumnDefPacket>>();
if !ctx.conn.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
EofPacket::deserialize(ctx)?;
}
@ -55,7 +55,7 @@ mod test {
use crate::{__bytes_builder, ConnectOptions, mariadb::{ConnContext, DeContext, Deserialize}};
#[test]
fn it_decodes_com_stmt_prepare_resp() -> Result<(), failure::Error> {
fn it_decodes_com_stmt_prepare_resp_eof() -> Result<(), failure::Error> {
#[rustfmt::skip]
let buf = __bytes_builder!(
// ---------------------------- //
@ -152,6 +152,140 @@ mod test {
0u8, 0u8
);
let mut context = ConnContext::with_eof();
let mut ctx = DeContext::new(&mut context, &buf);
let message = ComStmtPrepareResp::deserialize(&mut ctx)?;
Ok(())
}
#[test]
fn it_decodes_com_stmt_prepare_resp() -> Result<(), failure::Error> {
#[rustfmt::skip]
let buf = __bytes_builder!(
// ---------------------------- //
// Statement Prepared Ok Packet //
// ---------------------------- //
// int<3> length
0u8, 0u8, 0u8,
// int<1> seq_no
0u8,
// int<1> 0x00 COM_STMT_PREPARE_OK header
0u8,
// int<4> statement id
1u8, 0u8, 0u8, 0u8,
// int<2> number of columns in the returned result set (or 0 if statement does not return result set)
1u8, 0u8,
// int<2> number of prepared statement parameters ('?' placeholders)
1u8, 0u8,
// string<1> -not used-
0u8,
// int<2> number of warnings
0u8, 0u8,
// Param column definition
// ------------------------ //
// Column Definition packet //
// ------------------------ //
// int<3> length
52u8, 0u8, 0u8,
// int<1> seq_no
3u8,
// string<lenenc> catalog (always 'def')
3u8, b"def",
// string<lenenc> schema
4u8, b"test",
// string<lenenc> table alias
5u8, b"users",
// string<lenenc> table
5u8, b"users",
// string<lenenc> column alias
8u8, b"username",
// string<lenenc> column
8u8, b"username",
// int<lenenc> length of fixed fields (=0xC)
0x0C_u8,
// int<2> character set number
8u8, 0u8,
// int<4> max. column size
0xFF_u8, 0xFF_u8, 0u8, 0u8,
// int<1> Field types
0xFC_u8,
// int<2> Field detail flag
0x11_u8, 0x10_u8,
// int<1> decimals
0u8,
// int<2> - unused -
0u8, 0u8,
// ---------- //
// EOF Packet //
// ---------- //
// int<3> length
5u8, 0u8, 0u8,
// int<1> seq_no
6u8,
// int<1> 0xfe : EOF header
0xFE_u8,
// int<2> warning count
0u8, 0u8,
// int<2> server status
34u8, 0u8,
// Result column definitions
// ------------------------ //
// Column Definition packet //
// ------------------------ //
// int<3> length
52u8, 0u8, 0u8,
// int<1> seq_no
3u8,
// string<lenenc> catalog (always 'def')
3u8, b"def",
// string<lenenc> schema
4u8, b"test",
// string<lenenc> table alias
5u8, b"users",
// string<lenenc> table
5u8, b"users",
// string<lenenc> column alias
8u8, b"username",
// string<lenenc> column
8u8, b"username",
// int<lenenc> length of fixed fields (=0xC)
0x0C_u8,
// int<2> character set number
8u8, 0u8,
// int<4> max. column size
0xFF_u8, 0xFF_u8, 0u8, 0u8,
// int<1> Field types
0xFC_u8,
// int<2> Field detail flag
0x11_u8, 0x10_u8,
// int<1> decimals
0u8,
// int<2> - unused -
0u8, 0u8,
// ---------- //
// EOF Packet //
// ---------- //
// int<3> length
5u8, 0u8, 0u8,
// int<1> seq_no
6u8,
// int<1> 0xfe : EOF header
0xFE_u8,
// int<2> warning count
0u8, 0u8,
// int<2> server status
34u8, 0u8
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, &buf);

View File

@ -21,7 +21,7 @@ pub struct HandshakeResponsePacket {
impl Serialize for HandshakeResponsePacket {
fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> {
encoder.alloc_packet_header();
encoder.seq_no(0);
encoder.seq_no(1);
encoder.encode_int_u32(self.capabilities.bits() as u32);
encoder.encode_int_u32(self.max_packet_size);

View File

@ -77,8 +77,7 @@ impl Deserialize for InitialHandshakePacket {
auth_plugin_name = Some(decoder.decode_string_null()?);
}
ctx.conn.capabilities = capabilities;
ctx.conn.last_seq_no = seq_no;
ctx.ctx.last_seq_no = seq_no;
Ok(InitialHandshakePacket {
length,

View File

@ -31,7 +31,7 @@ impl Deserialize for ResultSet {
Vec::new()
};
let eof_packet = if !ctx.conn.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
let eof_packet = if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
Some(EofPacket::deserialize(ctx)?)
} else {
None
@ -62,7 +62,7 @@ impl Deserialize for ResultSet {
}
}
if ctx.conn.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
if ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
OkPacket::deserialize(ctx)?;
} else {
EofPacket::deserialize(ctx)?;

View File

@ -128,7 +128,7 @@ pub enum ParamFlag {
impl Default for Capabilities {
fn default() -> Self {
Capabilities::CLIENT_MYSQL
Capabilities::CLIENT_PROTOCOL_41
}
}