mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-12-29 21:00:54 +00:00
Fix capabilities
This commit is contained in:
parent
667ee3b56f
commit
ffe25704fc
@ -16,7 +16,7 @@ itoa = "0.4.4"
|
||||
log = "0.4.7"
|
||||
md-5 = "0.8.0"
|
||||
memchr = "2.2.1"
|
||||
runtime = { version = "=0.3.0-alpha.6", default-features = false }
|
||||
runtime = { version = "=0.3.0-alpha.6", default-features = true }
|
||||
bitflags = "1.1.0"
|
||||
enum-tryfrom = "0.2.1"
|
||||
enum-tryfrom-derive = "0.2.1"
|
||||
|
||||
@ -8,6 +8,7 @@ use crate::mariadb::protocol::{
|
||||
use bytes::{BufMut, Bytes};
|
||||
use failure::{err_msg, Error};
|
||||
use crate::ConnectOptions;
|
||||
use std::ops::BitAnd;
|
||||
|
||||
pub async fn establish<'a, 'b: 'a>(
|
||||
conn: &'a mut Connection,
|
||||
@ -15,11 +16,13 @@ pub async fn establish<'a, 'b: 'a>(
|
||||
) -> Result<(), Error> {
|
||||
let buf = &conn.stream.next_bytes().await?;
|
||||
let mut de_ctx = DeContext::new(&mut conn.context, &buf);
|
||||
let _ = InitialHandshakePacket::deserialize(&mut de_ctx)?;
|
||||
let initial = InitialHandshakePacket::deserialize(&mut de_ctx)?;
|
||||
|
||||
de_ctx.ctx.capabilities = de_ctx.ctx.capabilities.bitand(initial.capabilities);
|
||||
|
||||
let handshake: HandshakeResponsePacket = HandshakeResponsePacket {
|
||||
// Minimum client capabilities required to establish connection
|
||||
capabilities: Capabilities::CLIENT_PROTOCOL_41,
|
||||
capabilities: de_ctx.ctx.capabilities,
|
||||
max_packet_size: 1024,
|
||||
extended_capabilities: Some(Capabilities::from_bits_truncate(0)),
|
||||
username: Bytes::from(options.user.unwrap_or("")),
|
||||
@ -51,35 +54,97 @@ mod test {
|
||||
use super::*;
|
||||
use failure::Error;
|
||||
|
||||
// #[runtime::test]
|
||||
// async fn it_connects() -> Result<(), Error> {
|
||||
// let mut conn = Connection::establish(ConnectOptions {
|
||||
// host: "127.0.0.1",
|
||||
// port: 3306,
|
||||
// user: Some("root"),
|
||||
// database: None,
|
||||
// password: None,
|
||||
// })
|
||||
// .await?;
|
||||
//
|
||||
// conn.ping().await?;
|
||||
//
|
||||
// Ok(())
|
||||
// }
|
||||
//
|
||||
// #[runtime::test]
|
||||
// async fn it_does_not_connect() -> Result<(), Error> {
|
||||
// match Connection::establish(ConnectOptions {
|
||||
// host: "127.0.0.1",
|
||||
// port: 3306,
|
||||
// user: Some("roote"),
|
||||
// database: None,
|
||||
// password: None,
|
||||
// })
|
||||
// .await
|
||||
// {
|
||||
// Ok(_) => Err(err_msg("Bad username still worked?")),
|
||||
// Err(_) => Ok(()),
|
||||
// }
|
||||
// }
|
||||
#[runtime::test]
|
||||
async fn it_can_connect() -> Result<(), Error> {
|
||||
let mut conn = Connection::establish(ConnectOptions {
|
||||
host: "127.0.0.1",
|
||||
port: 3306,
|
||||
user: Some("root"),
|
||||
database: None,
|
||||
password: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[runtime::test]
|
||||
async fn it_can_ping() -> Result<(), Error> {
|
||||
let mut conn = Connection::establish(ConnectOptions {
|
||||
host: "127.0.0.1",
|
||||
port: 3306,
|
||||
user: Some("root"),
|
||||
database: None,
|
||||
password: None,
|
||||
}).await?;
|
||||
|
||||
conn.ping().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[runtime::test]
|
||||
async fn it_can_select_db() -> Result<(), Error> {
|
||||
let mut conn = Connection::establish(ConnectOptions {
|
||||
host: "127.0.0.1",
|
||||
port: 3306,
|
||||
user: Some("root"),
|
||||
database: None,
|
||||
password: None,
|
||||
}).await?;
|
||||
|
||||
conn.select_db("test").await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[runtime::test]
|
||||
async fn it_can_query() -> Result<(), Error> {
|
||||
let mut conn = Connection::establish(ConnectOptions {
|
||||
host: "127.0.0.1",
|
||||
port: 3306,
|
||||
user: Some("root"),
|
||||
database: None,
|
||||
password: None,
|
||||
}).await?;
|
||||
|
||||
conn.select_db("test").await?;
|
||||
|
||||
conn.query("SELECT * FROM users").await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[runtime::test]
|
||||
async fn it_can_prepare() -> Result<(), Error> {
|
||||
let mut conn = Connection::establish(ConnectOptions {
|
||||
host: "127.0.0.1",
|
||||
port: 3306,
|
||||
user: Some("root"),
|
||||
database: None,
|
||||
password: None,
|
||||
}).await?;
|
||||
|
||||
conn.select_db("test").await?;
|
||||
|
||||
conn.prepare("SELECT * FROM users WHERE username = ?").await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[runtime::test]
|
||||
async fn it_does_not_connect() -> Result<(), Error> {
|
||||
match Connection::establish(ConnectOptions {
|
||||
host: "127.0.0.1",
|
||||
port: 3306,
|
||||
user: Some("roote"),
|
||||
database: None,
|
||||
password: None,
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(_) => Err(err_msg("Bad username still worked?")),
|
||||
Err(_) => Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,7 +6,7 @@ use futures::{
|
||||
prelude::*,
|
||||
};
|
||||
use runtime::net::TcpStream;
|
||||
use crate::{ConnectOptions, mariadb::protocol::{DeContext, Deserialize, Encoder, ComInitDb, ComPing, ComQuery, ComQuit, OkPacket, Serialize, Message, Capabilities, ServerStatusFlag}};
|
||||
use crate::{ConnectOptions, mariadb::protocol::{DeContext, Deserialize, Encoder, ComInitDb, ComPing, ComQuery, ComQuit, OkPacket, Serialize, Message, Capabilities, ServerStatusFlag, ComStmtPrepare, ComStmtPrepareResp, ResultSet, ErrPacket}};
|
||||
|
||||
mod establish;
|
||||
|
||||
@ -46,26 +46,18 @@ impl ConnContext {
|
||||
connection_id: 0,
|
||||
seq_no: 2,
|
||||
last_seq_no: 0,
|
||||
capabilities: Capabilities::FOUND_ROWS
|
||||
| Capabilities::CONNECT_WITH_DB
|
||||
| Capabilities::COMPRESS
|
||||
| Capabilities::LOCAL_FILES
|
||||
| Capabilities::IGNORE_SPACE
|
||||
| Capabilities::CLIENT_PROTOCOL_41
|
||||
| Capabilities::CLIENT_INTERACTIVE
|
||||
| Capabilities::TRANSACTIONS
|
||||
| Capabilities::SECURE_CONNECTION
|
||||
| Capabilities::MULTI_STATEMENTS
|
||||
| Capabilities::MULTI_RESULTS
|
||||
| Capabilities::PS_MULTI_RESULTS
|
||||
| Capabilities::PLUGIN_AUTH
|
||||
| Capabilities::CONNECT_ATTRS
|
||||
| Capabilities::PLUGIN_AUTH_LENENC_CLIENT_DATA
|
||||
| Capabilities::CLIENT_SESSION_TRACK
|
||||
| Capabilities::CLIENT_DEPRECATE_EOF
|
||||
| Capabilities::MARIA_DB_CLIENT_PROGRESS
|
||||
| Capabilities::MARIA_DB_CLIENT_COM_MULTI
|
||||
| Capabilities::MARIA_CLIENT_STMT_BULK_OPERATIONS,
|
||||
capabilities: Capabilities::CLIENT_PROTOCOL_41,
|
||||
status: ServerStatusFlag::SERVER_STATUS_IN_TRANS
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn with_eof() -> Self {
|
||||
ConnContext {
|
||||
connection_id: 0,
|
||||
seq_no: 2,
|
||||
last_seq_no: 0,
|
||||
capabilities: Capabilities::CLIENT_PROTOCOL_41 | Capabilities::CLIENT_DEPRECATE_EOF,
|
||||
status: ServerStatusFlag::SERVER_STATUS_IN_TRANS
|
||||
}
|
||||
}
|
||||
@ -81,7 +73,7 @@ impl Connection {
|
||||
connection_id: -1,
|
||||
seq_no: 1,
|
||||
last_seq_no: 0,
|
||||
capabilities: Capabilities::default(),
|
||||
capabilities: Capabilities::CLIENT_PROTOCOL_41,
|
||||
status: ServerStatusFlag::default(),
|
||||
},
|
||||
};
|
||||
@ -103,22 +95,33 @@ impl Connection {
|
||||
}
|
||||
|
||||
pub async fn quit(&mut self) -> Result<(), Error> {
|
||||
self.context.seq_no = 0;
|
||||
self.send(ComQuit()).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn query<'a>(&'a mut self, sql_statement: &'a str) -> Result<(), Error> {
|
||||
self.context.seq_no = 0;
|
||||
pub async fn query<'a>(&'a mut self, sql_statement: &'a str) -> Result<Option<ResultSet>, Error> {
|
||||
self.send(ComQuery { sql_statement: bytes::Bytes::from(sql_statement) }).await?;
|
||||
|
||||
Ok(())
|
||||
let buf = self.stream.next_bytes().await?;
|
||||
let mut ctx = DeContext::new(&mut self.context, &buf);
|
||||
if let Some(tag) = ctx.decoder.peek_tag() {
|
||||
match tag {
|
||||
0xFF => Err(ErrPacket::deserialize(&mut ctx)?.into()),
|
||||
0x00 => {
|
||||
OkPacket::deserialize(&mut ctx)?;
|
||||
Ok(None)
|
||||
},
|
||||
0xFB => unimplemented!(),
|
||||
_ => Ok(Some(ResultSet::deserialize(&mut ctx)?))
|
||||
}
|
||||
} else {
|
||||
panic!("Tag not found in result packet");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub async fn select_db<'a>(&'a mut self, db: &'a str) -> Result<(), Error> {
|
||||
self.context.seq_no = 0;
|
||||
self.send(ComInitDb { schema_name: bytes::Bytes::from(db) }).await?;
|
||||
|
||||
|
||||
@ -139,7 +142,6 @@ impl Connection {
|
||||
}
|
||||
|
||||
pub async fn ping(&mut self) -> Result<(), Error> {
|
||||
self.context.seq_no = 0;
|
||||
self.send(ComPing()).await?;
|
||||
|
||||
// Ping response must be an OkPacket
|
||||
@ -149,6 +151,18 @@ impl Connection {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn prepare(&mut self, query: &str) -> Result<(), Error> {
|
||||
self.send(ComStmtPrepare {
|
||||
statement: Bytes::from(query),
|
||||
}).await?;
|
||||
|
||||
let buf = self.stream.next_bytes().await?;
|
||||
|
||||
ComStmtPrepareResp::deserialize(&mut DeContext::new(&mut self.context, &buf))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn next(&mut self) -> Result<Option<Message>, Error> {
|
||||
let mut rbuf = BytesMut::new();
|
||||
let mut len = 0;
|
||||
|
||||
@ -8,14 +8,14 @@ use failure::Error;
|
||||
// access to the connection context.
|
||||
// Mainly used to simply to simplify number of parameters for deserializing functions
|
||||
pub struct DeContext<'a> {
|
||||
pub conn: &'a mut ConnContext,
|
||||
pub ctx: &'a mut ConnContext,
|
||||
pub decoder: Decoder<'a>,
|
||||
pub columns: Option<i64>,
|
||||
}
|
||||
|
||||
impl<'a> DeContext<'a> {
|
||||
pub fn new(conn: &'a mut ConnContext, buf: &'a Bytes) -> Self {
|
||||
DeContext { conn, decoder: Decoder::new(&buf), columns: None }
|
||||
DeContext { ctx: conn, decoder: Decoder::new(&buf), columns: None }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ use bytes::Bytes;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ComStmtPrepare {
|
||||
statement: Bytes
|
||||
pub statement: Bytes
|
||||
}
|
||||
|
||||
impl crate::mariadb::Serialize for ComStmtPrepare {
|
||||
|
||||
@ -17,7 +17,7 @@ impl crate::mariadb::Deserialize for ComStmtPrepareResp {
|
||||
.map(Result::unwrap)
|
||||
.collect::<Vec<ColumnDefPacket>>();
|
||||
|
||||
if !ctx.conn.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
|
||||
if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
|
||||
EofPacket::deserialize(ctx)?;
|
||||
}
|
||||
|
||||
@ -32,7 +32,7 @@ impl crate::mariadb::Deserialize for ComStmtPrepareResp {
|
||||
.map(Result::unwrap)
|
||||
.collect::<Vec<ColumnDefPacket>>();
|
||||
|
||||
if !ctx.conn.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
|
||||
if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
|
||||
EofPacket::deserialize(ctx)?;
|
||||
}
|
||||
|
||||
@ -55,7 +55,7 @@ mod test {
|
||||
use crate::{__bytes_builder, ConnectOptions, mariadb::{ConnContext, DeContext, Deserialize}};
|
||||
|
||||
#[test]
|
||||
fn it_decodes_com_stmt_prepare_resp() -> Result<(), failure::Error> {
|
||||
fn it_decodes_com_stmt_prepare_resp_eof() -> Result<(), failure::Error> {
|
||||
#[rustfmt::skip]
|
||||
let buf = __bytes_builder!(
|
||||
// ---------------------------- //
|
||||
@ -152,6 +152,140 @@ mod test {
|
||||
0u8, 0u8
|
||||
);
|
||||
|
||||
let mut context = ConnContext::with_eof();
|
||||
let mut ctx = DeContext::new(&mut context, &buf);
|
||||
|
||||
let message = ComStmtPrepareResp::deserialize(&mut ctx)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_decodes_com_stmt_prepare_resp() -> Result<(), failure::Error> {
|
||||
#[rustfmt::skip]
|
||||
let buf = __bytes_builder!(
|
||||
// ---------------------------- //
|
||||
// Statement Prepared Ok Packet //
|
||||
// ---------------------------- //
|
||||
|
||||
// int<3> length
|
||||
0u8, 0u8, 0u8,
|
||||
// int<1> seq_no
|
||||
0u8,
|
||||
// int<1> 0x00 COM_STMT_PREPARE_OK header
|
||||
0u8,
|
||||
// int<4> statement id
|
||||
1u8, 0u8, 0u8, 0u8,
|
||||
// int<2> number of columns in the returned result set (or 0 if statement does not return result set)
|
||||
1u8, 0u8,
|
||||
// int<2> number of prepared statement parameters ('?' placeholders)
|
||||
1u8, 0u8,
|
||||
// string<1> -not used-
|
||||
0u8,
|
||||
// int<2> number of warnings
|
||||
0u8, 0u8,
|
||||
|
||||
// Param column definition
|
||||
|
||||
// ------------------------ //
|
||||
// Column Definition packet //
|
||||
// ------------------------ //
|
||||
// int<3> length
|
||||
52u8, 0u8, 0u8,
|
||||
// int<1> seq_no
|
||||
3u8,
|
||||
// string<lenenc> catalog (always 'def')
|
||||
3u8, b"def",
|
||||
// string<lenenc> schema
|
||||
4u8, b"test",
|
||||
// string<lenenc> table alias
|
||||
5u8, b"users",
|
||||
// string<lenenc> table
|
||||
5u8, b"users",
|
||||
// string<lenenc> column alias
|
||||
8u8, b"username",
|
||||
// string<lenenc> column
|
||||
8u8, b"username",
|
||||
// int<lenenc> length of fixed fields (=0xC)
|
||||
0x0C_u8,
|
||||
// int<2> character set number
|
||||
8u8, 0u8,
|
||||
// int<4> max. column size
|
||||
0xFF_u8, 0xFF_u8, 0u8, 0u8,
|
||||
// int<1> Field types
|
||||
0xFC_u8,
|
||||
// int<2> Field detail flag
|
||||
0x11_u8, 0x10_u8,
|
||||
// int<1> decimals
|
||||
0u8,
|
||||
// int<2> - unused -
|
||||
0u8, 0u8,
|
||||
|
||||
// ---------- //
|
||||
// EOF Packet //
|
||||
// ---------- //
|
||||
// int<3> length
|
||||
5u8, 0u8, 0u8,
|
||||
// int<1> seq_no
|
||||
6u8,
|
||||
// int<1> 0xfe : EOF header
|
||||
0xFE_u8,
|
||||
// int<2> warning count
|
||||
0u8, 0u8,
|
||||
// int<2> server status
|
||||
34u8, 0u8,
|
||||
|
||||
// Result column definitions
|
||||
|
||||
// ------------------------ //
|
||||
// Column Definition packet //
|
||||
// ------------------------ //
|
||||
// int<3> length
|
||||
52u8, 0u8, 0u8,
|
||||
// int<1> seq_no
|
||||
3u8,
|
||||
// string<lenenc> catalog (always 'def')
|
||||
3u8, b"def",
|
||||
// string<lenenc> schema
|
||||
4u8, b"test",
|
||||
// string<lenenc> table alias
|
||||
5u8, b"users",
|
||||
// string<lenenc> table
|
||||
5u8, b"users",
|
||||
// string<lenenc> column alias
|
||||
8u8, b"username",
|
||||
// string<lenenc> column
|
||||
8u8, b"username",
|
||||
// int<lenenc> length of fixed fields (=0xC)
|
||||
0x0C_u8,
|
||||
// int<2> character set number
|
||||
8u8, 0u8,
|
||||
// int<4> max. column size
|
||||
0xFF_u8, 0xFF_u8, 0u8, 0u8,
|
||||
// int<1> Field types
|
||||
0xFC_u8,
|
||||
// int<2> Field detail flag
|
||||
0x11_u8, 0x10_u8,
|
||||
// int<1> decimals
|
||||
0u8,
|
||||
// int<2> - unused -
|
||||
0u8, 0u8,
|
||||
|
||||
// ---------- //
|
||||
// EOF Packet //
|
||||
// ---------- //
|
||||
// int<3> length
|
||||
5u8, 0u8, 0u8,
|
||||
// int<1> seq_no
|
||||
6u8,
|
||||
// int<1> 0xfe : EOF header
|
||||
0xFE_u8,
|
||||
// int<2> warning count
|
||||
0u8, 0u8,
|
||||
// int<2> server status
|
||||
34u8, 0u8
|
||||
);
|
||||
|
||||
let mut context = ConnContext::new();
|
||||
let mut ctx = DeContext::new(&mut context, &buf);
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ pub struct HandshakeResponsePacket {
|
||||
impl Serialize for HandshakeResponsePacket {
|
||||
fn serialize<'a, 'b>(&self, ctx: &mut crate::mariadb::connection::ConnContext, encoder: &mut crate::mariadb::protocol::encode::Encoder) -> Result<(), Error> {
|
||||
encoder.alloc_packet_header();
|
||||
encoder.seq_no(0);
|
||||
encoder.seq_no(1);
|
||||
|
||||
encoder.encode_int_u32(self.capabilities.bits() as u32);
|
||||
encoder.encode_int_u32(self.max_packet_size);
|
||||
|
||||
@ -77,8 +77,7 @@ impl Deserialize for InitialHandshakePacket {
|
||||
auth_plugin_name = Some(decoder.decode_string_null()?);
|
||||
}
|
||||
|
||||
ctx.conn.capabilities = capabilities;
|
||||
ctx.conn.last_seq_no = seq_no;
|
||||
ctx.ctx.last_seq_no = seq_no;
|
||||
|
||||
Ok(InitialHandshakePacket {
|
||||
length,
|
||||
|
||||
@ -31,7 +31,7 @@ impl Deserialize for ResultSet {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
let eof_packet = if !ctx.conn.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
|
||||
let eof_packet = if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
|
||||
Some(EofPacket::deserialize(ctx)?)
|
||||
} else {
|
||||
None
|
||||
@ -62,7 +62,7 @@ impl Deserialize for ResultSet {
|
||||
}
|
||||
}
|
||||
|
||||
if ctx.conn.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
|
||||
if ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
|
||||
OkPacket::deserialize(ctx)?;
|
||||
} else {
|
||||
EofPacket::deserialize(ctx)?;
|
||||
|
||||
@ -128,7 +128,7 @@ pub enum ParamFlag {
|
||||
|
||||
impl Default for Capabilities {
|
||||
fn default() -> Self {
|
||||
Capabilities::CLIENT_MYSQL
|
||||
Capabilities::CLIENT_PROTOCOL_41
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user