mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-12-30 05:11:13 +00:00
WIP: Update next_packet to work more efficiently and correctly
This commit is contained in:
parent
31188aa919
commit
b731dbe90f
@ -1,4 +1,4 @@
|
||||
#![feature(non_exhaustive, async_await)]
|
||||
#![feature(non_exhaustive, async_await, async_closure)]
|
||||
#![cfg_attr(test, feature(test))]
|
||||
#![allow(clippy::needless_lifetimes)]
|
||||
// FIXME: Remove this once API has matured
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
use super::Connection;
|
||||
use crate::mariadb::{DeContext, Deserialize, HandshakeResponsePacket, InitialHandshakePacket, Message, Capabilities, ComStmtExec, StmtExecFlag};
|
||||
use crate::mariadb::{ErrPacket, OkPacket, DeContext, Deserialize, HandshakeResponsePacket, InitialHandshakePacket, Message, Capabilities, ComStmtExec, StmtExecFlag};
|
||||
use bytes::{BufMut, Bytes};
|
||||
use failure::{err_msg, Error};
|
||||
use crate::ConnectOptions;
|
||||
@ -9,8 +9,8 @@ pub async fn establish<'a, 'b: 'a>(
|
||||
conn: &'a mut Connection,
|
||||
options: ConnectOptions<'b>,
|
||||
) -> Result<(), Error> {
|
||||
let buf = &conn.stream.next_bytes().await?;
|
||||
let mut de_ctx = DeContext::new(&mut conn.context, &buf);
|
||||
let buf = conn.stream.next_packet().await?;
|
||||
let mut de_ctx = DeContext::new(&mut conn.context, buf);
|
||||
let initial = InitialHandshakePacket::deserialize(&mut de_ctx)?;
|
||||
|
||||
de_ctx.ctx.capabilities = de_ctx.ctx.capabilities.bitand(initial.capabilities);
|
||||
@ -26,167 +26,167 @@ pub async fn establish<'a, 'b: 'a>(
|
||||
|
||||
conn.send(handshake).await?;
|
||||
|
||||
match conn.next().await? {
|
||||
Some(Message::OkPacket(message)) => {
|
||||
conn.context.seq_no = message.seq_no;
|
||||
Ok(())
|
||||
}
|
||||
let mut ctx = DeContext::new(&mut conn.context, conn.stream.next_packet().await?);
|
||||
|
||||
Some(Message::ErrPacket(message)) => Err(err_msg(format!("{:?}", message))),
|
||||
|
||||
Some(message) => {
|
||||
panic!("Did not receive OkPacket nor ErrPacket. Received: {:?}", message);
|
||||
}
|
||||
|
||||
None => {
|
||||
panic!("Did not receive packet");
|
||||
if let Some(tag) = ctx.decoder.peek_tag() {
|
||||
match tag {
|
||||
0xFF => {
|
||||
return Err(ErrPacket::deserialize(&mut ctx)?.into());
|
||||
},
|
||||
0x00 => {
|
||||
OkPacket::deserialize(&mut ctx)?;
|
||||
},
|
||||
_ => failure::bail!("Did not receive an ErrPacket nor OkPacket when one is expected"),
|
||||
}
|
||||
} else {
|
||||
failure::bail!("Did not receive an appropriately tagged packet when one is expected");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
// use super::*;
|
||||
// use failure::Error;
|
||||
// use crate::mariadb::{ComStmtPrepareResp, FieldType, ResultSet, ComStmtFetch};
|
||||
//
|
||||
// #[runtime::test]
|
||||
// async fn it_can_connect() -> Result<(), Error> {
|
||||
// let mut conn = Connection::establish(ConnectOptions {
|
||||
// host: "127.0.0.1",
|
||||
// port: 3306,
|
||||
// user: Some("root"),
|
||||
// database: None,
|
||||
// password: None,
|
||||
// })
|
||||
// .await?;
|
||||
//
|
||||
// Ok(())
|
||||
// }
|
||||
//
|
||||
// #[runtime::test]
|
||||
// async fn it_can_ping() -> Result<(), Error> {
|
||||
// let mut conn = Connection::establish(ConnectOptions {
|
||||
// host: "127.0.0.1",
|
||||
// port: 3306,
|
||||
// user: Some("root"),
|
||||
// database: None,
|
||||
// password: None,
|
||||
// }).await?;
|
||||
//
|
||||
// conn.ping().await?;
|
||||
//
|
||||
// Ok(())
|
||||
// }
|
||||
//
|
||||
// #[runtime::test]
|
||||
// async fn it_can_select_db() -> Result<(), Error> {
|
||||
// let mut conn = Connection::establish(ConnectOptions {
|
||||
// host: "127.0.0.1",
|
||||
// port: 3306,
|
||||
// user: Some("root"),
|
||||
// database: None,
|
||||
// password: None,
|
||||
// }).await?;
|
||||
//
|
||||
// conn.select_db("test").await?;
|
||||
//
|
||||
// Ok(())
|
||||
// }
|
||||
//
|
||||
// #[runtime::test]
|
||||
// async fn it_can_query() -> Result<(), Error> {
|
||||
// let mut conn = Connection::establish(ConnectOptions {
|
||||
// host: "127.0.0.1",
|
||||
// port: 3306,
|
||||
// user: Some("root"),
|
||||
// database: None,
|
||||
// password: None,
|
||||
// }).await?;
|
||||
//
|
||||
// conn.select_db("test").await?;
|
||||
//
|
||||
// conn.query("SELECT * FROM users").await?;
|
||||
//
|
||||
// Ok(())
|
||||
// }
|
||||
//
|
||||
// #[runtime::test]
|
||||
// async fn it_can_prepare() -> Result<(), Error> {
|
||||
// let mut conn = Connection::establish(ConnectOptions {
|
||||
// host: "127.0.0.1",
|
||||
// port: 3306,
|
||||
// user: Some("root"),
|
||||
// database: None,
|
||||
// password: None,
|
||||
// }).await?;
|
||||
//
|
||||
// conn.select_db("test").await?;
|
||||
//
|
||||
// conn.prepare("SELECT * FROM users WHERE username = ?").await?;
|
||||
//
|
||||
// Ok(())
|
||||
// }
|
||||
//
|
||||
// #[runtime::test]
|
||||
// async fn it_can_execute_prepared() -> Result<(), Error> {
|
||||
// let mut conn = Connection::establish(ConnectOptions {
|
||||
// host: "127.0.0.1",
|
||||
// port: 3306,
|
||||
// user: Some("root"),
|
||||
// database: None,
|
||||
// password: None,
|
||||
// }).await?;
|
||||
//
|
||||
// conn.select_db("test").await?;
|
||||
//
|
||||
// let mut prepared = conn.prepare("SELECT * FROM users WHERE username = ?;").await?;
|
||||
//
|
||||
// println!("{:?}", prepared);
|
||||
//
|
||||
// if let Some(param_defs) = &mut prepared.param_defs {
|
||||
// param_defs[0].field_type = FieldType::MysqlTypeBlob;
|
||||
// }
|
||||
//
|
||||
// let exec = ComStmtExec {
|
||||
// stmt_id: prepared.ok.stmt_id,
|
||||
// flags: StmtExecFlag::CursorForUpdate,
|
||||
// params: Some(vec![Some(Bytes::from_static(b"'daniel'"))]),
|
||||
// param_defs: prepared.param_defs,
|
||||
// };
|
||||
//
|
||||
//
|
||||
// conn.send(exec).await?;
|
||||
//
|
||||
// let buf = conn.stream.next_bytes().await?;
|
||||
//
|
||||
// println!("{:?}", buf);
|
||||
//
|
||||
// println!("{:?}", ResultSet::deserialize(&mut DeContext::new(&mut conn.context, &buf))?.rows);
|
||||
//
|
||||
// let fetch = ComStmtFetch {
|
||||
// stmt_id: prepared.ok.stmt_id,
|
||||
// rows: 1,
|
||||
// };
|
||||
//
|
||||
// conn.send(exec).await?;
|
||||
//
|
||||
// Ok(())
|
||||
// }
|
||||
//
|
||||
// #[runtime::test]
|
||||
// async fn it_does_not_connect() -> Result<(), Error> {
|
||||
// match Connection::establish(ConnectOptions {
|
||||
// host: "127.0.0.1",
|
||||
// port: 3306,
|
||||
// user: Some("roote"),
|
||||
// database: None,
|
||||
// password: None,
|
||||
// })
|
||||
// .await
|
||||
// {
|
||||
// Ok(_) => Err(err_msg("Bad username still worked?")),
|
||||
// Err(_) => Ok(()),
|
||||
// }
|
||||
// }
|
||||
use super::*;
|
||||
use failure::Error;
|
||||
use crate::mariadb::{ComStmtPrepareResp, FieldType, ResultSet, ComStmtFetch};
|
||||
|
||||
#[runtime::test]
|
||||
async fn it_can_connect() -> Result<(), Error> {
|
||||
let mut conn = Connection::establish(ConnectOptions {
|
||||
host: "127.0.0.1",
|
||||
port: 3306,
|
||||
user: Some("root"),
|
||||
database: None,
|
||||
password: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[runtime::test]
|
||||
async fn it_can_ping() -> Result<(), Error> {
|
||||
let mut conn = Connection::establish(ConnectOptions {
|
||||
host: "127.0.0.1",
|
||||
port: 3306,
|
||||
user: Some("root"),
|
||||
database: None,
|
||||
password: None,
|
||||
}).await?;
|
||||
|
||||
conn.ping().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[runtime::test]
|
||||
async fn it_can_select_db() -> Result<(), Error> {
|
||||
let mut conn = Connection::establish(ConnectOptions {
|
||||
host: "127.0.0.1",
|
||||
port: 3306,
|
||||
user: Some("root"),
|
||||
database: None,
|
||||
password: None,
|
||||
}).await?;
|
||||
|
||||
conn.select_db("test").await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[runtime::test]
|
||||
async fn it_can_query() -> Result<(), Error> {
|
||||
let mut conn = Connection::establish(ConnectOptions {
|
||||
host: "127.0.0.1",
|
||||
port: 3306,
|
||||
user: Some("root"),
|
||||
database: None,
|
||||
password: None,
|
||||
}).await?;
|
||||
|
||||
conn.select_db("test").await?;
|
||||
|
||||
conn.query("SELECT * FROM users").await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[runtime::test]
|
||||
async fn it_can_prepare() -> Result<(), Error> {
|
||||
let mut conn = Connection::establish(ConnectOptions {
|
||||
host: "127.0.0.1",
|
||||
port: 3306,
|
||||
user: Some("root"),
|
||||
database: None,
|
||||
password: None,
|
||||
}).await?;
|
||||
|
||||
conn.select_db("test").await?;
|
||||
|
||||
conn.prepare("SELECT * FROM users WHERE username = ?").await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[runtime::test]
|
||||
async fn it_can_execute_prepared() -> Result<(), Error> {
|
||||
let mut conn = Connection::establish(ConnectOptions {
|
||||
host: "127.0.0.1",
|
||||
port: 3306,
|
||||
user: Some("root"),
|
||||
database: None,
|
||||
password: None,
|
||||
}).await?;
|
||||
|
||||
conn.select_db("test").await?;
|
||||
|
||||
let mut prepared = conn.prepare("SELECT username FROM users WHERE username = ?;").await?;
|
||||
|
||||
println!("{:?}", prepared);
|
||||
|
||||
if let Some(param_defs) = &mut prepared.param_defs {
|
||||
param_defs[0].field_type = FieldType::MysqlTypeBlob;
|
||||
}
|
||||
|
||||
let exec = ComStmtExec {
|
||||
stmt_id: -1,
|
||||
flags: StmtExecFlag::ReadOnly,
|
||||
params: Some(vec![Some(Bytes::from_static(b"daniel"))]),
|
||||
param_defs: prepared.param_defs,
|
||||
};
|
||||
|
||||
println!("{:?}", ResultSet::deserialize(DeContext::with_stream(&mut conn.context, &mut conn.stream)).await?);
|
||||
|
||||
let fetch = ComStmtFetch {
|
||||
stmt_id: -1,
|
||||
rows: 10,
|
||||
};
|
||||
|
||||
conn.send(fetch).await?;
|
||||
|
||||
let buf = conn.stream.next_packet().await?;
|
||||
|
||||
println!("{:?}", buf);
|
||||
|
||||
// println!("{:?}", ResultSet::deserialize(&mut DeContext::new(&mut conn.context, &buf))?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[runtime::test]
|
||||
async fn it_does_not_connect() -> Result<(), Error> {
|
||||
match Connection::establish(ConnectOptions {
|
||||
host: "127.0.0.1",
|
||||
port: 3306,
|
||||
user: Some("roote"),
|
||||
database: None,
|
||||
password: None,
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(_) => Err(err_msg("Bad username still worked?")),
|
||||
Err(_) => Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,7 +6,8 @@ use futures::{
|
||||
prelude::*,
|
||||
};
|
||||
use runtime::net::TcpStream;
|
||||
use crate::{ConnectOptions, mariadb::protocol::{DeContext, Deserialize, Encoder, ComInitDb, ComPing, ComQuery, ComQuit, OkPacket, Serialize, Message, Capabilities, ServerStatusFlag, ComStmtPrepare, ComStmtPrepareResp, ResultSet, ErrPacket}};
|
||||
use core::convert::TryFrom;
|
||||
use crate::{ConnectOptions, mariadb::{protocol::encode, PacketHeader, Decoder, DeContext, Deserialize, Encoder, ComInitDb, ComPing, ComQuery, ComQuit, OkPacket, Serialize, Message, Capabilities, ServerStatusFlag, ComStmtPrepare, ComStmtPrepareResp, ResultSet, ErrPacket}};
|
||||
|
||||
mod establish;
|
||||
|
||||
@ -103,8 +104,8 @@ impl Connection {
|
||||
pub async fn query<'a>(&'a mut self, sql_statement: &'a str) -> Result<Option<ResultSet>, Error> {
|
||||
self.send(ComQuery { sql_statement: bytes::Bytes::from(sql_statement) }).await?;
|
||||
|
||||
let buf = self.stream.next_bytes().await?;
|
||||
let mut ctx = DeContext::new(&mut self.context, &buf);
|
||||
let mut ctx = DeContext::with_stream(&mut self.context, &mut self.stream);
|
||||
ctx.next_packet().await?;
|
||||
if let Some(tag) = ctx.decoder.peek_tag() {
|
||||
match tag {
|
||||
0xFF => Err(ErrPacket::deserialize(&mut ctx)?.into()),
|
||||
@ -113,7 +114,9 @@ impl Connection {
|
||||
Ok(None)
|
||||
},
|
||||
0xFB => unimplemented!(),
|
||||
_ => Ok(Some(ResultSet::deserialize(&mut ctx)?))
|
||||
_ => {
|
||||
Ok(Some(ResultSet::deserialize(ctx).await?))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
panic!("Tag not found in result packet");
|
||||
@ -125,17 +128,19 @@ impl Connection {
|
||||
self.send(ComInitDb { schema_name: bytes::Bytes::from(db) }).await?;
|
||||
|
||||
|
||||
match self.next().await? {
|
||||
Some(Message::OkPacket(_)) => {},
|
||||
Some(message @ Message::ErrPacket(_)) => {
|
||||
failure::bail!("Received an ErrPacket packet: {:?}", message);
|
||||
},
|
||||
Some(message) => {
|
||||
failure::bail!("Received an unexpected packet type: {:?}", message);
|
||||
}
|
||||
None => {
|
||||
failure::bail!("Did not receive a packet when one was expected");
|
||||
let mut ctx = DeContext::new(&mut self.context, self.stream.next_packet().await?);
|
||||
if let Some(tag) = ctx.decoder.peek_tag() {
|
||||
match tag {
|
||||
0xFF => {
|
||||
ErrPacket::deserialize(&mut ctx)?;
|
||||
},
|
||||
0x00 => {
|
||||
OkPacket::deserialize(&mut ctx)?;
|
||||
},
|
||||
_ => failure::bail!("Did not receive an ErrPacket nor OkPacket when one was expected"),
|
||||
}
|
||||
} else {
|
||||
failure::bail!("No tag found");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@ -145,8 +150,7 @@ impl Connection {
|
||||
self.send(ComPing()).await?;
|
||||
|
||||
// Ping response must be an OkPacket
|
||||
let buf = self.stream.next_bytes().await?;
|
||||
OkPacket::deserialize(&mut DeContext::new(&mut self.context, &buf))?;
|
||||
OkPacket::deserialize(&mut DeContext::new(&mut self.context, self.stream.next_packet().await?))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -156,110 +160,95 @@ impl Connection {
|
||||
statement: Bytes::from(query),
|
||||
}).await?;
|
||||
|
||||
let buf = self.stream.next_bytes().await?;
|
||||
|
||||
ComStmtPrepareResp::deserialize(&mut DeContext::new(&mut self.context, &buf))
|
||||
}
|
||||
|
||||
pub async fn next(&mut self) -> Result<Option<Message>, Error> {
|
||||
let mut rbuf = BytesMut::new();
|
||||
let mut len = 0;
|
||||
|
||||
loop {
|
||||
if len == rbuf.len() {
|
||||
rbuf.reserve(32);
|
||||
|
||||
unsafe {
|
||||
// Set length to the capacity and efficiently
|
||||
// zero-out the memory
|
||||
rbuf.set_len(rbuf.capacity());
|
||||
self.stream.inner.initializer().initialize(&mut rbuf[len..]);
|
||||
}
|
||||
}
|
||||
|
||||
let bytes_read = self.stream.inner.read(&mut rbuf[len..]).await?;
|
||||
|
||||
if bytes_read > 0 {
|
||||
len += bytes_read;
|
||||
} else {
|
||||
// Read 0 bytes from the server; end-of-stream
|
||||
break;
|
||||
}
|
||||
|
||||
while len > 0 {
|
||||
let size = rbuf.len();
|
||||
let message = Message::deserialize(&mut DeContext::new(
|
||||
&mut self.context,
|
||||
&rbuf.as_ref().into(),
|
||||
))?;
|
||||
len -= size - rbuf.len();
|
||||
|
||||
match message {
|
||||
message @ Some(_) => return Ok(message),
|
||||
// Did not receive enough bytes to
|
||||
// deserialize a complete message
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
// let mut ctx = DeContext::with_stream(&mut self.context, &mut self.stream);
|
||||
// ctx.next_packet().await?;
|
||||
// ComStmtPrepareResp::deserialize(&mut ctx)
|
||||
Ok(ComStmtPrepareResp::default())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Framed {
|
||||
inner: TcpStream,
|
||||
readable: bool,
|
||||
eof: bool,
|
||||
buffer: BytesMut,
|
||||
buf: BytesMut,
|
||||
}
|
||||
|
||||
impl Framed {
|
||||
fn new(stream: TcpStream) -> Self {
|
||||
Self {
|
||||
readable: false,
|
||||
eof: false,
|
||||
inner: stream,
|
||||
buffer: BytesMut::with_capacity(8 * 1024),
|
||||
buf: BytesMut::with_capacity(8 * 1024),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn next_bytes(&mut self) -> Result<Bytes, Error> {
|
||||
pub async fn next_packet(&mut self) -> Result<Bytes, Error> {
|
||||
let mut rbuf = BytesMut::new();
|
||||
let mut len = 0;
|
||||
let mut packet_len: u32 = 0;
|
||||
let mut len = 0usize;
|
||||
let mut packet_headers: Vec<PacketHeader> = Vec::new();
|
||||
|
||||
loop {
|
||||
if len == rbuf.len() {
|
||||
rbuf.reserve(20000);
|
||||
if let Some(packet_header) = packet_headers.last() {
|
||||
if packet_header.combined_length() > rbuf.len() {
|
||||
let reserve = packet_header.combined_length() - rbuf.len();
|
||||
rbuf.reserve(reserve);
|
||||
|
||||
unsafe {
|
||||
rbuf.set_len(rbuf.capacity());
|
||||
self.inner.initializer().initialize(&mut rbuf[len..]);
|
||||
}
|
||||
}
|
||||
} else if rbuf.len() == len {
|
||||
rbuf.reserve(32);
|
||||
|
||||
unsafe {
|
||||
// Set length to the capacity and efficiently
|
||||
// zero-out the memory
|
||||
rbuf.set_len(rbuf.capacity());
|
||||
self.inner.initializer().initialize(&mut rbuf[len..]);
|
||||
}
|
||||
}
|
||||
|
||||
let bytes_read = self.inner.read(&mut rbuf[len..]).await?;
|
||||
// If we have a packet_header and the amount of currently read bytes (len) is less than
|
||||
// the specified length inside packet_header, then we can continue reading to rbuf; but
|
||||
// only up until packet_header.length.
|
||||
// Else if the total number of bytes read is equal to packet_header then we will
|
||||
// return rbuf as it should contain the entire packet.
|
||||
// Else we read too many bytes -- which shouldn't happen -- and will return an error.
|
||||
let bytes_read;
|
||||
|
||||
if let Some(packet_header) = packet_headers.last() {
|
||||
if packet_header.combined_length() > len {
|
||||
bytes_read = self.inner.read(&mut rbuf[len..packet_header.combined_length()]).await?;
|
||||
} else {
|
||||
return Ok(rbuf.freeze());
|
||||
}
|
||||
} else {
|
||||
// Only read header to make sure that we dont' read the next packets buffer.
|
||||
bytes_read = self.inner.read(&mut rbuf[len..len + 4]).await?;
|
||||
}
|
||||
|
||||
if bytes_read > 0 {
|
||||
len += bytes_read;
|
||||
// If we have read less than 4 bytes, and we don't already have a packet_header
|
||||
// we must try to read again. The packet_header is always present and is 4 bytes long.
|
||||
if bytes_read < 4 && packet_headers.len() == 0 {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
// Read 0 bytes from the server; end-of-stream
|
||||
return Ok(Bytes::new());
|
||||
}
|
||||
|
||||
if len > 0 && packet_len == 0 {
|
||||
packet_len = LittleEndian::read_u24(&rbuf[0..]);
|
||||
}
|
||||
|
||||
// Loop until the length of the buffer is the length of the packet
|
||||
if packet_len as usize > len {
|
||||
continue;
|
||||
} else {
|
||||
return Ok(rbuf.freeze());
|
||||
}
|
||||
|
||||
// If we don't have a packet header or the last packet header had a length of
|
||||
// 0xFF_FF_FF (the max possible length); then we must continue receiving packets
|
||||
// because the entire message hasn't been received.
|
||||
// After this operation we know that packet_headers.last() *SHOULD* always return valid data,
|
||||
// so the the use of packet_headers.last().unwrap() is allowed.
|
||||
// TODO: Stitch packets together by removing the length and seq_no from in-between packet definitions.
|
||||
if let Some(packet_header) = packet_headers.last() {
|
||||
if packet_header.length as usize == encode::U24_MAX {
|
||||
packet_headers.push(PacketHeader::try_from(&rbuf[0..])?);
|
||||
}
|
||||
} else {
|
||||
packet_headers.push(PacketHeader::try_from(&rbuf[0..])?);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ pub mod protocol;
|
||||
// Re-export all the things
|
||||
pub use connection::ConnContext;
|
||||
pub use connection::Connection;
|
||||
pub use connection::Framed;
|
||||
pub use protocol::AuthenticationSwitchRequestPacket;
|
||||
pub use protocol::ColumnPacket;
|
||||
pub use protocol::ColumnDefPacket;
|
||||
|
||||
@ -58,14 +58,14 @@ use super::packets::packet_header::PacketHeader;
|
||||
// - Byte 7 is the minutes (0 if DATE type) (0-59)
|
||||
// - Byte 8 is the seconds (0 if DATE type) (0-59)
|
||||
// - Bytes 10-13 are the micro-seconds on 4 bytes little-endian format (only if data-length is > 7)
|
||||
pub struct Decoder<'a> {
|
||||
pub buf: &'a Bytes,
|
||||
pub struct Decoder {
|
||||
pub buf: Bytes,
|
||||
pub index: usize,
|
||||
}
|
||||
|
||||
impl<'a> Decoder<'a> {
|
||||
impl Decoder {
|
||||
// Create a new Decoder from an existing Bytes
|
||||
pub fn new(buf: &'a Bytes) -> Self {
|
||||
pub fn new(buf: Bytes) -> Self {
|
||||
Decoder { buf, index: 0 }
|
||||
}
|
||||
|
||||
@ -441,7 +441,7 @@ mod tests {
|
||||
#[test]
|
||||
fn it_decodes_int_lenenc_0x_fb() {
|
||||
let buf = __bytes_builder!(0xFB_u8);
|
||||
let mut decoder = Decoder::new(&buf);
|
||||
let mut decoder = Decoder::new(buf);
|
||||
let int = decoder.decode_int_lenenc_unsigned();
|
||||
|
||||
assert_eq!(int, None);
|
||||
@ -451,7 +451,7 @@ mod tests {
|
||||
#[test]
|
||||
fn it_decodes_int_lenenc_0x_fc() {
|
||||
let buf =__bytes_builder!(0xFCu8, 1u8, 1u8);
|
||||
let mut decoder = Decoder::new(&buf);
|
||||
let mut decoder = Decoder::new(buf);
|
||||
let int = decoder.decode_int_lenenc_unsigned();
|
||||
|
||||
assert_eq!(int, Some(0x0101));
|
||||
@ -461,7 +461,7 @@ mod tests {
|
||||
#[test]
|
||||
fn it_decodes_int_lenenc_0x_fd() {
|
||||
let buf = __bytes_builder!(0xFDu8, 1u8, 1u8, 1u8);
|
||||
let mut decoder = Decoder::new(&buf);
|
||||
let mut decoder = Decoder::new(buf);
|
||||
let int = decoder.decode_int_lenenc_unsigned();
|
||||
|
||||
assert_eq!(int, Some(0x010101));
|
||||
@ -471,7 +471,7 @@ mod tests {
|
||||
#[test]
|
||||
fn it_decodes_int_lenenc_0x_fe() {
|
||||
let buf = __bytes_builder!(0xFE_u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8);
|
||||
let mut decoder = Decoder::new(&buf);
|
||||
let mut decoder = Decoder::new(buf);
|
||||
let int = decoder.decode_int_lenenc_unsigned();
|
||||
|
||||
assert_eq!(int, Some(0x0101010101010101));
|
||||
@ -481,7 +481,7 @@ mod tests {
|
||||
#[test]
|
||||
fn it_decodes_int_lenenc_0x_fa() {
|
||||
let buf = __bytes_builder!(0xFA_u8);
|
||||
let mut decoder = Decoder::new(&buf);
|
||||
let mut decoder = Decoder::new(buf);
|
||||
let int = decoder.decode_int_lenenc_unsigned();
|
||||
|
||||
assert_eq!(int, Some(0xFA));
|
||||
@ -491,7 +491,7 @@ mod tests {
|
||||
#[test]
|
||||
fn it_decodes_int_8() {
|
||||
let buf = __bytes_builder!(1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8, 1u8);
|
||||
let mut decoder = Decoder::new(&buf);
|
||||
let mut decoder = Decoder::new(buf);
|
||||
let int: i64 = decoder.decode_int_i64();
|
||||
|
||||
assert_eq!(int, 0x0101010101010101);
|
||||
@ -501,7 +501,7 @@ mod tests {
|
||||
#[test]
|
||||
fn it_decodes_int_4() {
|
||||
let buf = __bytes_builder!(1u8, 1u8, 1u8, 1u8);
|
||||
let mut decoder = Decoder::new(&buf);
|
||||
let mut decoder = Decoder::new(buf);
|
||||
let int: i32 = decoder.decode_int_i32();
|
||||
|
||||
assert_eq!(int, 0x01010101);
|
||||
@ -511,7 +511,7 @@ mod tests {
|
||||
#[test]
|
||||
fn it_decodes_int_3() {
|
||||
let buf = __bytes_builder!(1u8, 1u8, 1u8);
|
||||
let mut decoder = Decoder::new(&buf);
|
||||
let mut decoder = Decoder::new(buf);
|
||||
let int: i32 = decoder.decode_int_i24();
|
||||
|
||||
assert_eq!(int, 0x010101);
|
||||
@ -521,7 +521,7 @@ mod tests {
|
||||
#[test]
|
||||
fn it_decodes_int_2() {
|
||||
let buf = __bytes_builder!(1u8, 1u8);
|
||||
let mut decoder = Decoder::new(&buf);
|
||||
let mut decoder = Decoder::new(buf);
|
||||
let int: i16 = decoder.decode_int_i16();
|
||||
|
||||
assert_eq!(int, 0x0101);
|
||||
@ -531,7 +531,7 @@ mod tests {
|
||||
#[test]
|
||||
fn it_decodes_int_1() {
|
||||
let buf = __bytes_builder!(1u8);
|
||||
let mut decoder = Decoder::new(&buf);
|
||||
let mut decoder = Decoder::new(buf);
|
||||
let int: u8 = decoder.decode_int_u8();
|
||||
|
||||
assert_eq!(int, 1u8);
|
||||
@ -541,7 +541,7 @@ mod tests {
|
||||
#[test]
|
||||
fn it_decodes_string_lenenc() {
|
||||
let buf = __bytes_builder!(3u8, b"sup");
|
||||
let mut decoder = Decoder::new(&buf);
|
||||
let mut decoder = Decoder::new(buf);
|
||||
let string: Bytes = decoder.decode_string_lenenc();
|
||||
|
||||
assert_eq!(string[..], b"sup"[..]);
|
||||
@ -552,7 +552,7 @@ mod tests {
|
||||
#[test]
|
||||
fn it_decodes_string_fix() {
|
||||
let buf = __bytes_builder!(b"a");
|
||||
let mut decoder = Decoder::new(&buf);
|
||||
let mut decoder = Decoder::new(buf);
|
||||
let string: Bytes = decoder.decode_string_fix(1);
|
||||
|
||||
assert_eq!(&string[..], b"a");
|
||||
@ -563,7 +563,7 @@ mod tests {
|
||||
#[test]
|
||||
fn it_decodes_string_eof() {
|
||||
let buf = __bytes_builder!(b"a");
|
||||
let mut decoder = Decoder::new(&buf);
|
||||
let mut decoder = Decoder::new(buf);
|
||||
let string: Bytes = decoder.decode_string_eof(None);
|
||||
|
||||
assert_eq!(&string[..], b"a");
|
||||
@ -574,7 +574,7 @@ mod tests {
|
||||
#[test]
|
||||
fn it_decodes_string_null() -> Result<(), Error> {
|
||||
let buf = __bytes_builder!(b"random\0", 1u8);
|
||||
let mut decoder = Decoder::new(&buf);
|
||||
let mut decoder = Decoder::new(buf);
|
||||
let string: Bytes = decoder.decode_string_null()?;
|
||||
|
||||
assert_eq!(&string[..], b"random");
|
||||
@ -589,7 +589,7 @@ mod tests {
|
||||
#[test]
|
||||
fn it_decodes_byte_fix() {
|
||||
let buf = __bytes_builder!(b"a");
|
||||
let mut decoder = Decoder::new(&buf);
|
||||
let mut decoder = Decoder::new(buf);
|
||||
let string: Bytes = decoder.decode_byte_fix(1);
|
||||
|
||||
assert_eq!(&string[..], b"a");
|
||||
@ -600,7 +600,7 @@ mod tests {
|
||||
#[test]
|
||||
fn it_decodes_byte_eof() {
|
||||
let buf = __bytes_builder!(b"a");
|
||||
let mut decoder = Decoder::new(&buf);
|
||||
let mut decoder = Decoder::new(buf);
|
||||
let string: Bytes = decoder.decode_byte_eof(None);
|
||||
|
||||
assert_eq!(&string[..], b"a");
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
use crate::mariadb::{Decoder, ConnContext, Connection, ColumnDefPacket};
|
||||
use crate::mariadb::{Framed, Decoder, ConnContext, Connection, ColumnDefPacket};
|
||||
use bytes::Bytes;
|
||||
use failure::Error;
|
||||
|
||||
@ -8,14 +8,36 @@ use failure::Error;
|
||||
// Mainly used to simply to simplify number of parameters for deserializing functions
|
||||
pub struct DeContext<'a> {
|
||||
pub ctx: &'a mut ConnContext,
|
||||
pub decoder: Decoder<'a>,
|
||||
pub stream: Option<&'a mut Framed>,
|
||||
pub decoder: Decoder,
|
||||
pub columns: Option<u64>,
|
||||
pub column_defs: Option<Vec<ColumnDefPacket>>,
|
||||
}
|
||||
|
||||
impl<'a> DeContext<'a> {
|
||||
pub fn new(conn: &'a mut ConnContext, buf: &'a Bytes) -> Self {
|
||||
DeContext { ctx: conn, decoder: Decoder::new(&buf), columns: None , column_defs: None }
|
||||
pub fn new(conn: &'a mut ConnContext, buf: Bytes) -> Self {
|
||||
DeContext { ctx: conn, stream: None, decoder: Decoder::new(buf), columns: None , column_defs: None }
|
||||
}
|
||||
|
||||
pub fn with_stream(conn: &'a mut ConnContext, stream: &'a mut Framed) -> Self {
|
||||
DeContext {
|
||||
ctx: conn,
|
||||
stream: Some(stream),
|
||||
decoder: Decoder::new(Bytes::new()),
|
||||
columns: None ,
|
||||
column_defs: None
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn next_packet(&mut self) -> Result<(), failure::Error> {
|
||||
if let Some(stream) = &mut self.stream {
|
||||
println!("Called next packet");
|
||||
self.decoder = Decoder::new(stream.next_packet().await?);
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
failure::bail!("Calling next_packet on DeContext with no stream provided")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ use byteorder::{ByteOrder, LittleEndian};
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use crate::mariadb::FieldType;
|
||||
|
||||
const U24_MAX: usize = 0xFF_FF_FF;
|
||||
pub const U24_MAX: usize = 0xFF_FF_FF;
|
||||
|
||||
// A simple wrapper around a BytesMut to easily encode values
|
||||
pub struct Encoder {
|
||||
@ -253,17 +253,17 @@ impl Encoder {
|
||||
FieldType::MysqlTypeTimestamp2 => unimplemented!(),
|
||||
FieldType::MysqlTypeDatetime2 => unimplemented!(),
|
||||
FieldType::MysqlTypeTime2 =>unimplemented!(),
|
||||
FieldType::MysqlTypeJson => self.encode_string_lenenc(bytes),
|
||||
FieldType::MysqlTypeNewdecimal => self.encode_string_lenenc(bytes),
|
||||
FieldType::MysqlTypeEnum => self.encode_string_lenenc(bytes),
|
||||
FieldType::MysqlTypeSet => self.encode_string_lenenc(bytes),
|
||||
FieldType::MysqlTypeTinyBlob => self.encode_string_lenenc(bytes),
|
||||
FieldType::MysqlTypeMediumBlob => self.encode_string_lenenc(bytes),
|
||||
FieldType::MysqlTypeLongBlob => self.encode_string_lenenc(bytes),
|
||||
FieldType::MysqlTypeBlob => self.encode_string_lenenc(bytes),
|
||||
FieldType::MysqlTypeVarString => self.encode_string_lenenc(bytes),
|
||||
FieldType::MysqlTypeString => self.encode_string_lenenc(bytes),
|
||||
FieldType::MysqlTypeGeometry => self.encode_string_lenenc(bytes),
|
||||
FieldType::MysqlTypeJson => self.encode_byte_lenenc(bytes),
|
||||
FieldType::MysqlTypeNewdecimal => self.encode_byte_lenenc(bytes),
|
||||
FieldType::MysqlTypeEnum => self.encode_byte_lenenc(bytes),
|
||||
FieldType::MysqlTypeSet => self.encode_byte_lenenc(bytes),
|
||||
FieldType::MysqlTypeTinyBlob => self.encode_byte_lenenc(bytes),
|
||||
FieldType::MysqlTypeMediumBlob => self.encode_byte_lenenc(bytes),
|
||||
FieldType::MysqlTypeLongBlob => self.encode_byte_lenenc(bytes),
|
||||
FieldType::MysqlTypeBlob => self.encode_byte_lenenc(bytes),
|
||||
FieldType::MysqlTypeVarString => self.encode_byte_lenenc(bytes),
|
||||
FieldType::MysqlTypeString => self.encode_byte_lenenc(bytes),
|
||||
FieldType::MysqlTypeGeometry => self.encode_byte_lenenc(bytes),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,56 +17,48 @@ impl crate::mariadb::Serialize for ComStmtExec {
|
||||
encoder.encode_int_u8(super::BinaryProtocol::ComStmtExec.into());
|
||||
encoder.encode_int_i32(self.stmt_id);
|
||||
encoder.encode_int_u8(self.flags as u8);
|
||||
encoder.encode_int_u8(0);
|
||||
encoder.encode_int_u8(1);
|
||||
|
||||
if let Some(params) = &self.params {
|
||||
if let Some(param_defs) = &self.param_defs {
|
||||
if params.len() != param_defs.len() {
|
||||
failure::bail!("Unequal number of params and param definitions supplied");
|
||||
}
|
||||
}
|
||||
match (&self.params, &self.param_defs) {
|
||||
(Some(params), Some(param_defs)) if params.len() > 0 => {
|
||||
let null_bitmap_size = (params.len() + 7) / 8;
|
||||
let mut shift_amount = 0u8;
|
||||
let mut bitmap = vec![0u8];
|
||||
let send_type = 1u8;
|
||||
|
||||
let null_bitmap_size = (params.len() + 7) / 8;
|
||||
let mut shift_amount = 0u8;
|
||||
let mut bitmap = vec![0u8];
|
||||
// Generate NULL-bitmap from params
|
||||
for param in params {
|
||||
if param.is_some() {
|
||||
let last_byte = bitmap.pop().unwrap();
|
||||
bitmap.push(last_byte & (1 << shift_amount));
|
||||
}
|
||||
|
||||
// Generate NULL-bitmap from params
|
||||
for param in params {
|
||||
if param.is_none() {
|
||||
bitmap.push(bitmap.last().unwrap() & (1 << shift_amount));
|
||||
shift_amount = (shift_amount + 1) % 8;
|
||||
|
||||
if shift_amount % 8 == 0 {
|
||||
bitmap.push(0u8);
|
||||
}
|
||||
}
|
||||
|
||||
shift_amount = (shift_amount + 1) % 8;
|
||||
encoder.encode_byte_fix(&Bytes::from(bitmap), null_bitmap_size);
|
||||
encoder.encode_int_u8(send_type);
|
||||
|
||||
if shift_amount % 8 == 0 {
|
||||
bitmap.push(0u8);
|
||||
}
|
||||
}
|
||||
|
||||
// Do not send the param types
|
||||
encoder.encode_int_u8(if self.param_defs.is_some() {
|
||||
1u8
|
||||
} else {
|
||||
0u8
|
||||
});
|
||||
|
||||
if let Some(params_defs) = &self.param_defs {
|
||||
for param in params_defs {
|
||||
encoder.encode_int_u8(param.field_type as u8);
|
||||
encoder.encode_int_u8(if (param.field_details & FieldDetailFlag::UNSIGNED).is_empty() {
|
||||
1u8
|
||||
} else {
|
||||
0u8
|
||||
});
|
||||
if send_type > 0 {
|
||||
for param in param_defs {
|
||||
encoder.encode_int_u8(param.field_type as u8);
|
||||
encoder.encode_int_u8(0);
|
||||
}
|
||||
}
|
||||
|
||||
// Encode params
|
||||
for index in 0..params.len() {
|
||||
if let Some(bytes) = ¶ms[index] {
|
||||
encoder.encode_param(&bytes, ¶ms_defs[index].field_type);
|
||||
encoder.encode_param(&bytes, ¶m_defs[index].field_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
_ => {},
|
||||
|
||||
}
|
||||
|
||||
encoder.encode_length();
|
||||
|
||||
@ -64,7 +64,7 @@ mod tests {
|
||||
);
|
||||
|
||||
let mut context = ConnContext::new();
|
||||
let mut ctx = DeContext::new(&mut context, &buf);
|
||||
let mut ctx = DeContext::new(&mut context, buf);
|
||||
|
||||
let message = ComStmtPrepareOk::deserialize(&mut ctx)?;
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
use crate::mariadb::{ComStmtPrepareOk, ColumnDefPacket, Capabilities, EofPacket};
|
||||
use crate::mariadb::{DeContext, Deserialize, ComStmtPrepareOk, ColumnDefPacket, Capabilities, EofPacket};
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ComStmtPrepareResp {
|
||||
@ -7,18 +7,22 @@ pub struct ComStmtPrepareResp {
|
||||
pub res_columns: Option<Vec<ColumnDefPacket>>,
|
||||
}
|
||||
|
||||
impl crate::mariadb::Deserialize for ComStmtPrepareResp {
|
||||
fn deserialize(ctx: &mut crate::mariadb::DeContext) -> Result<Self, failure::Error> {
|
||||
let ok = ComStmtPrepareOk::deserialize(ctx)?;
|
||||
impl ComStmtPrepareResp {
|
||||
pub async fn deserialize<'a>(mut ctx: DeContext<'a>) -> Result<Self, failure::Error> {
|
||||
let ok = ComStmtPrepareOk::deserialize(&mut ctx)?;
|
||||
|
||||
let param_defs = if ok.params > 0 {
|
||||
let param_defs = (0..ok.params).map(|_| ColumnDefPacket::deserialize(ctx))
|
||||
.filter(Result::is_ok)
|
||||
.map(Result::unwrap)
|
||||
.collect::<Vec<ColumnDefPacket>>();
|
||||
let mut param_defs = Vec::new();
|
||||
|
||||
for _ in 0..ok.params {
|
||||
ctx.next_packet().await?;
|
||||
param_defs.push(ColumnDefPacket::deserialize(&mut ctx)?);
|
||||
}
|
||||
|
||||
ctx.next_packet().await?;
|
||||
|
||||
if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
|
||||
EofPacket::deserialize(ctx)?;
|
||||
EofPacket::deserialize(&mut ctx)?;
|
||||
}
|
||||
|
||||
Some(param_defs)
|
||||
@ -27,16 +31,20 @@ impl crate::mariadb::Deserialize for ComStmtPrepareResp {
|
||||
};
|
||||
|
||||
let res_columns = if ok.columns > 0 {
|
||||
let param_defs = (0..ok.columns).map(|_| ColumnDefPacket::deserialize(ctx))
|
||||
.filter(Result::is_ok)
|
||||
.map(Result::unwrap)
|
||||
.collect::<Vec<ColumnDefPacket>>();
|
||||
let mut res_columns = Vec::new();
|
||||
|
||||
if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
|
||||
EofPacket::deserialize(ctx)?;
|
||||
for _ in 0..ok.columns {
|
||||
ctx.next_packet().await?;
|
||||
res_columns.push(ColumnDefPacket::deserialize(&mut ctx)?);
|
||||
}
|
||||
|
||||
Some(param_defs)
|
||||
ctx.next_packet().await?;
|
||||
|
||||
if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
|
||||
EofPacket::deserialize(&mut ctx)?;
|
||||
}
|
||||
|
||||
Some(res_columns)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@ -54,8 +62,8 @@ mod test {
|
||||
use super::*;
|
||||
use crate::{__bytes_builder, ConnectOptions, mariadb::{ConnContext, DeContext, Deserialize}};
|
||||
|
||||
#[test]
|
||||
fn it_decodes_com_stmt_prepare_resp_eof() -> Result<(), failure::Error> {
|
||||
#[runtime::test]
|
||||
async fn it_decodes_com_stmt_prepare_resp_eof() -> Result<(), failure::Error> {
|
||||
#[rustfmt::skip]
|
||||
let buf = __bytes_builder!(
|
||||
// ---------------------------- //
|
||||
@ -153,15 +161,15 @@ mod test {
|
||||
);
|
||||
|
||||
let mut context = ConnContext::with_eof();
|
||||
let mut ctx = DeContext::new(&mut context, &buf);
|
||||
let mut ctx = DeContext::new(&mut context, buf);
|
||||
|
||||
let message = ComStmtPrepareResp::deserialize(&mut ctx)?;
|
||||
let message = ComStmtPrepareResp::deserialize(ctx).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_decodes_com_stmt_prepare_resp() -> Result<(), failure::Error> {
|
||||
#[runtime::test]
|
||||
async fn it_decodes_com_stmt_prepare_resp() -> Result<(), failure::Error> {
|
||||
#[rustfmt::skip]
|
||||
let buf = __bytes_builder!(
|
||||
// ---------------------------- //
|
||||
@ -287,9 +295,9 @@ mod test {
|
||||
);
|
||||
|
||||
let mut context = ConnContext::new();
|
||||
let mut ctx = DeContext::new(&mut context, &buf);
|
||||
let mut ctx = DeContext::new(&mut context, buf);
|
||||
|
||||
let message = ComStmtPrepareResp::deserialize(&mut ctx)?;
|
||||
let message = ComStmtPrepareResp::deserialize(ctx).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -46,7 +46,7 @@ mod test {
|
||||
);
|
||||
|
||||
let mut context = ConnContext::new();
|
||||
let mut ctx = DeContext::new(&mut context, &buf);
|
||||
let mut ctx = DeContext::new(&mut context, buf);
|
||||
|
||||
let message = ColumnPacket::deserialize(&mut ctx)?;
|
||||
|
||||
@ -70,7 +70,7 @@ mod test {
|
||||
);
|
||||
|
||||
let mut context = ConnContext::new();
|
||||
let mut ctx = DeContext::new(&mut context, &buf);
|
||||
let mut ctx = DeContext::new(&mut context, buf);
|
||||
|
||||
let message = ColumnPacket::deserialize(&mut ctx)?;
|
||||
|
||||
@ -94,7 +94,7 @@ mod test {
|
||||
);
|
||||
|
||||
let mut context = ConnContext::new();
|
||||
let mut ctx = DeContext::new(&mut context, &buf);
|
||||
let mut ctx = DeContext::new(&mut context, buf);
|
||||
|
||||
let message = ColumnPacket::deserialize(&mut ctx)?;
|
||||
|
||||
|
||||
@ -115,7 +115,7 @@ mod test {
|
||||
);
|
||||
|
||||
let mut context = ConnContext::new();
|
||||
let mut ctx = DeContext::new(&mut context, &buf);
|
||||
let mut ctx = DeContext::new(&mut context, buf);
|
||||
|
||||
let message = ColumnDefPacket::deserialize(&mut ctx)?;
|
||||
|
||||
|
||||
@ -58,7 +58,7 @@ mod test {
|
||||
let buf = Bytes::from_static(b"\x01\0\0\x01\xFE\x00\x00\x01\x00");
|
||||
|
||||
let mut context = ConnContext::new();
|
||||
let mut ctx = DeContext::new(&mut context, &buf);
|
||||
let mut ctx = DeContext::new(&mut context, buf);
|
||||
|
||||
let _message = EofPacket::deserialize(&mut ctx)?;
|
||||
|
||||
|
||||
@ -124,7 +124,7 @@ mod test {
|
||||
);
|
||||
|
||||
let mut context = ConnContext::new();
|
||||
let mut ctx = DeContext::new(&mut context, &buf);
|
||||
let mut ctx = DeContext::new(&mut context, buf);
|
||||
|
||||
let _message = ErrPacket::deserialize(&mut ctx)?;
|
||||
|
||||
|
||||
@ -153,7 +153,7 @@ mod test {
|
||||
);
|
||||
|
||||
let mut context = ConnContext::new();
|
||||
let mut ctx = DeContext::new(&mut context, &buf);
|
||||
let mut ctx = DeContext::new(&mut context, buf);
|
||||
|
||||
let _message = InitialHandshakePacket::deserialize(&mut ctx)?;
|
||||
|
||||
|
||||
@ -95,7 +95,7 @@ mod test {
|
||||
);
|
||||
|
||||
let mut context = ConnContext::new();
|
||||
let mut ctx = DeContext::new(&mut context, &buf);
|
||||
let mut ctx = DeContext::new(&mut context, buf);
|
||||
|
||||
let message = OkPacket::deserialize(&mut ctx)?;
|
||||
|
||||
|
||||
@ -1,5 +1,33 @@
|
||||
#[derive(Debug)]
|
||||
use byteorder::LittleEndian;
|
||||
use byteorder::ByteOrder;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct PacketHeader {
|
||||
pub length: u32,
|
||||
pub seq_no: u8,
|
||||
}
|
||||
|
||||
impl PacketHeader {
|
||||
pub fn size() -> usize {
|
||||
4
|
||||
}
|
||||
|
||||
pub fn combined_length(&self) -> usize {
|
||||
PacketHeader::size() + self.length as usize
|
||||
}
|
||||
}
|
||||
|
||||
impl core::convert::TryFrom<&[u8]> for PacketHeader {
|
||||
type Error = failure::Error;
|
||||
|
||||
fn try_from(buffer: &[u8]) -> Result<Self, Self::Error> {
|
||||
if buffer.len() < 4 {
|
||||
failure::bail!("Buffer length is too short")
|
||||
} else {
|
||||
Ok(PacketHeader {
|
||||
length: LittleEndian::read_u24(&buffer),
|
||||
seq_no: buffer[3],
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -48,7 +48,7 @@ mod test {
|
||||
);
|
||||
|
||||
let mut context = ConnContext::new();
|
||||
let mut ctx = DeContext::new(&mut context, &buf);
|
||||
let mut ctx = DeContext::new(&mut context, buf);
|
||||
|
||||
ctx.columns = Some(1);
|
||||
|
||||
|
||||
@ -1,14 +1,7 @@
|
||||
use bytes::Bytes;
|
||||
use failure::Error;
|
||||
|
||||
use crate::mariadb::Decoder;
|
||||
use crate::mariadb::Message;
|
||||
use crate::mariadb::Capabilities;
|
||||
|
||||
use super::super::{
|
||||
deserialize::{DeContext, Deserialize},
|
||||
packets::{column::ColumnPacket, column_def::ColumnDefPacket, eof::EofPacket, err::ErrPacket, ok::OkPacket, result_row::ResultRow},
|
||||
};
|
||||
use crate::mariadb::{Deserialize, ConnContext, Framed, Decoder, Message, Capabilities, DeContext, ColumnPacket, ColumnDefPacket, EofPacket, ErrPacket, OkPacket, ResultRow};
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ResultSet {
|
||||
@ -17,22 +10,28 @@ pub struct ResultSet {
|
||||
pub rows: Vec<ResultRow>,
|
||||
}
|
||||
|
||||
impl Deserialize for ResultSet {
|
||||
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
|
||||
let column_packet = ColumnPacket::deserialize(ctx)?;
|
||||
impl ResultSet {
|
||||
pub async fn deserialize<'a>(mut ctx: DeContext<'a>) -> Result<Self, Error> {
|
||||
let column_packet = ColumnPacket::deserialize(&mut ctx)?;
|
||||
|
||||
let columns = if let Some(columns) = column_packet.columns {
|
||||
(0..columns)
|
||||
.map(|_| ColumnDefPacket::deserialize(ctx))
|
||||
.filter(Result::is_ok)
|
||||
.map(Result::unwrap)
|
||||
.collect::<Vec<ColumnDefPacket>>()
|
||||
let mut column_defs = Vec::new();
|
||||
for _ in 0..columns {
|
||||
ctx.next_packet().await?;
|
||||
column_defs.push(ColumnDefPacket::deserialize(&mut ctx)?);
|
||||
}
|
||||
column_defs
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
ctx.next_packet().await?;
|
||||
|
||||
let eof_packet = if !ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
|
||||
Some(EofPacket::deserialize(ctx)?)
|
||||
// If we get an eof packet we must update ctx to hold a new buffer of the next packet.
|
||||
let eof_packet = Some(EofPacket::deserialize(&mut ctx)?);
|
||||
ctx.next_packet().await?;
|
||||
eof_packet
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@ -48,12 +47,15 @@ impl Deserialize for ResultSet {
|
||||
};
|
||||
|
||||
let tag = ctx.decoder.peek_tag();
|
||||
if tag == Some(&0xFE) && packet_header.length <= 0xFFFFFF {
|
||||
if tag == Some(&0xFE) && packet_header.length <= 0xFFFFFF || packet_header.length == 0 {
|
||||
break;
|
||||
} else {
|
||||
let index = ctx.decoder.index;
|
||||
match ResultRow::deserialize(ctx) {
|
||||
Ok(v) => rows.push(v),
|
||||
match ResultRow::deserialize(&mut ctx) {
|
||||
Ok(v) => {
|
||||
rows.push(v);
|
||||
ctx.next_packet().await?;
|
||||
},
|
||||
Err(_) => {
|
||||
ctx.decoder.index = index;
|
||||
break;
|
||||
@ -62,10 +64,12 @@ impl Deserialize for ResultSet {
|
||||
}
|
||||
}
|
||||
|
||||
if ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
|
||||
OkPacket::deserialize(ctx)?;
|
||||
} else {
|
||||
EofPacket::deserialize(ctx)?;
|
||||
if ctx.decoder.peek_packet_header()?.length > 0 {
|
||||
if ctx.ctx.capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
|
||||
OkPacket::deserialize(&mut ctx)?;
|
||||
} else {
|
||||
EofPacket::deserialize(&mut ctx)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ResultSet {
|
||||
@ -83,8 +87,8 @@ mod test {
|
||||
use crate::{__bytes_builder, mariadb::{Connection, EofPacket, ErrPacket, OkPacket, ResultRow, ServerStatusFlag, Capabilities, ConnContext}};
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn it_decodes_result_set_packet() -> Result<(), Error> {
|
||||
#[runtime::test]
|
||||
async fn it_decodes_result_set_packet() -> Result<(), Error> {
|
||||
// TODO: Use byte string as input for test; this is a valid return from a mariadb.
|
||||
#[rustfmt::skip]
|
||||
let buf = __bytes_builder!(
|
||||
@ -296,9 +300,9 @@ mod test {
|
||||
);
|
||||
|
||||
let mut context = ConnContext::new();
|
||||
let mut ctx = DeContext::new(&mut context, &buf);
|
||||
let mut ctx = DeContext::new(&mut context, buf);
|
||||
|
||||
ResultSet::deserialize(&mut ctx)?;
|
||||
ResultSet::deserialize(ctx).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -170,7 +170,7 @@ mod test {
|
||||
#[test]
|
||||
fn it_decodes_capabilities() {
|
||||
let buf = Bytes::from(b"\xfe\xf7".to_vec());
|
||||
let mut decoder = Decoder::new(&buf);
|
||||
let mut decoder = Decoder::new(buf);
|
||||
Capabilities::from_bits_truncate(decoder.decode_int_u16().into());
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user