Fix prepared statement execution

This commit is contained in:
Daniel Akhterov 2019-08-06 17:40:30 -07:00
parent 433ec628da
commit c019f91fc6
14 changed files with 158 additions and 120 deletions

View File

@ -2,7 +2,7 @@ use super::Connection;
use crate::{
mariadb::{
Capabilities, ComStmtExec, DeContext, Decode, EofPacket, ErrPacket,
HandshakeResponsePacket, InitialHandshakePacket, OkPacket, StmtExecFlag,
HandshakeResponsePacket, InitialHandshakePacket, OkPacket, StmtExecFlag, ProtocolType
},
ConnectOptions,
};
@ -109,10 +109,8 @@ mod test {
})
.await?;
println!("selecting db");
conn.select_db("test").await?;
println!("querying");
conn.query("SELECT * FROM users").await?;
Ok(())
@ -154,17 +152,10 @@ mod test {
.prepare("SELECT id FROM users WHERE username=?")
.await?;
println!("{:?}", prepared);
// if let Some(param_defs) = &mut prepared.param_defs {
// param_defs[0].field_type = FieldType::MysqlTypeBlob;
// }
let exec = ComStmtExec {
stmt_id: prepared.ok.stmt_id,
flags: StmtExecFlag::NoCursor,
// params: None,
params: Some(vec![Some(Bytes::from_static(b"daniel"))]),
params: Some(vec![Some(Bytes::from_static(b"josh"))]),
param_defs: prepared.param_defs,
};
@ -172,26 +163,21 @@ mod test {
let mut ctx = DeContext::with_stream(&mut conn.context, &mut conn.stream);
ctx.next_packet().await?;
ctx.columns = Some(prepared.ok.columns as u64);
ctx.column_defs = prepared.res_columns;
match ctx.decoder.peek_tag() {
0xFF => println!("{:?}", ErrPacket::decode(&mut ctx)?),
0x00 => println!("{:?}", OkPacket::decode(&mut ctx)?),
_ => println!("{:?}", ResultSet::deserialize(ctx).await?),
0xFF => {
ErrPacket::decode(&mut ctx)?;
},
0x00 => {
OkPacket::decode(&mut ctx)?;
},
_ => {
ResultSet::deserialize(ctx, ProtocolType::Binary).await?;
},
}
// let fetch = ComStmtFetch {
// stmt_id: -1,
// rows: 10,
// };
//
// conn.send(fetch).await?;
//
// let buf = conn.stream.next_packet().await?;
//
// println!("{:?}", buf);
// println!("{:?}", ResultSet::deserialize(&mut DeContext::new(&mut conn.context, &buf))?);
Ok(())
}

View File

@ -2,7 +2,7 @@ use crate::{
mariadb::{
protocol::encode, Capabilities, ComInitDb, ComPing, ComQuery, ComQuit, ComStmtPrepare,
ComStmtPrepareResp, DeContext, Decode, Decoder, Encode, ErrPacket, OkPacket, PacketHeader,
ResultSet, ServerStatusFlag,
ResultSet, ServerStatusFlag, ProtocolType
},
ConnectOptions,
};
@ -130,7 +130,7 @@ impl Connection {
Ok(None)
}
0xFB => unimplemented!(),
_ => Ok(Some(ResultSet::deserialize(ctx).await?)),
_ => Ok(Some(ResultSet::deserialize(ctx, ProtocolType::Text).await?)),
}
}
@ -203,7 +203,6 @@ impl Framed {
// After this operation we know that packet_headers.last() *SHOULD* always return valid data,
// so the the use of packet_headers.last().unwrap() is allowed.
// TODO: Stitch packets together by removing the length and seq_no from in-between packet definitions.
println!("{:?}", self.buf);
if let Some(packet_header) = packet_headers.last() {
if packet_header.length as usize == encode::U24_MAX {
packet_headers.push(PacketHeader::try_from(&self.buf[self.index..])?);

View File

@ -9,6 +9,6 @@ pub use protocol::{
ComSetOption, ComShutdown, ComSleep, ComStatistics, ComStmtClose, ComStmtExec, ComStmtFetch,
ComStmtPrepare, ComStmtPrepareOk, ComStmtPrepareResp, DeContext, Decode, Decoder, Encode,
EofPacket, ErrPacket, ErrorCode, FieldDetailFlag, FieldType, HandshakeResponsePacket,
InitialHandshakePacket, OkPacket, PacketHeader, ResultRow, ResultSet, SSLRequestPacket,
ServerStatusFlag, SessionChangeType, SetOptionOptions, ShutdownOptions, StmtExecFlag,
InitialHandshakePacket, OkPacket, PacketHeader, ResultRowText, ResultRowBinary, ResultRow, ResultSet, SSLRequestPacket,
ServerStatusFlag, SessionChangeType, SetOptionOptions, ShutdownOptions, StmtExecFlag, ProtocolType
};

View File

@ -316,9 +316,7 @@ impl BufMut for Vec<u8> {
// Same as the string counterpart copied to maintain consistency with the spec.
#[inline]
fn put_byte_fix(&mut self, bytes: &Bytes, size: usize) {
if size != bytes.len() {
panic!("Sizes do not match");
}
assert_eq!(size, bytes.len());
self.extend_from_slice(bytes);
}

View File

@ -19,8 +19,8 @@ pub use packets::{
ComProcessKill, ComQuery, ComQuit, ComResetConnection, ComSetOption, ComShutdown, ComSleep,
ComStatistics, ComStmtClose, ComStmtExec, ComStmtFetch, ComStmtPrepare, ComStmtPrepareOk,
ComStmtPrepareResp, ComStmtReset, EofPacket, ErrPacket, HandshakeResponsePacket,
InitialHandshakePacket, OkPacket, PacketHeader, ResultRow, ResultSet, SSLRequestPacket,
SetOptionOptions, ShutdownOptions,
InitialHandshakePacket, OkPacket, PacketHeader, ResultRowText, ResultRowBinary, ResultSet, SSLRequestPacket,
SetOptionOptions, ShutdownOptions, ResultRow
};
pub use decode::{DeContext, Decode, Decoder};
@ -30,5 +30,5 @@ pub use encode::{BufMut, Encode};
pub use error_codes::ErrorCode;
pub use types::{
Capabilities, FieldDetailFlag, FieldType, ServerStatusFlag, SessionChangeType, StmtExecFlag,
ProtocolType, Capabilities, FieldDetailFlag, FieldType, ServerStatusFlag, SessionChangeType, StmtExecFlag,
};

View File

@ -1,5 +1,5 @@
use crate::mariadb::{
BufMut, ColumnDefPacket, ConnContext, Connection, Encode, FieldDetailFlag, StmtExecFlag,
BufMut, ColumnDefPacket, ConnContext, Connection, Encode, FieldDetailFlag, StmtExecFlag, FieldType
};
use bytes::Bytes;
use failure::Error;
@ -26,12 +26,12 @@ impl Encode for ComStmtExec {
(Some(params), Some(param_defs)) if params.len() > 0 => {
let null_bitmap_size = (params.len() + 7) / 8;
let mut shift_amount = 0u8;
let mut bitmap = vec![1u8];
let send_type = 0u8;
let mut bitmap = vec![0u8];
let send_type = 1u8;
// Generate NULL-bitmap from params
for param in params {
if param.is_some() {
if param.is_none() {
let last_byte = bitmap.pop().unwrap();
bitmap.push(last_byte & (1 << shift_amount));
}
@ -48,7 +48,7 @@ impl Encode for ComStmtExec {
if send_type > 0 {
for param in param_defs {
// buf.put_int_u8(param.field_type as u8);
buf.put_int_u8(param.field_type as u8);
buf.put_int_u8(0);
}
}

View File

@ -14,6 +14,7 @@ pub use com_stmt_prepare::ComStmtPrepare;
pub use com_stmt_prepare_ok::ComStmtPrepareOk;
pub use com_stmt_prepare_resp::ComStmtPrepareResp;
pub use com_stmt_reset::ComStmtReset;
pub use result_row::ResultRow;
pub enum BinaryProtocol {
ComStmtPrepare = 0x16,

View File

@ -3,6 +3,8 @@ use bytes::Bytes;
#[derive(Debug, Default)]
pub struct ResultRow {
pub length: u32,
pub seq_no: u8,
pub columns: Vec<Option<Bytes>>,
}
@ -24,11 +26,11 @@ impl crate::mariadb::Decode for ResultRow {
))
}?;
let row = match (&ctx.columns, &ctx.column_defs) {
let columns = match (&ctx.columns, &ctx.column_defs) {
(Some(columns), Some(column_defs)) => {
(0..*columns as usize)
.map(|index| {
if (1 << (index % 8)) & bitmap[index / 8] as usize == 0 {
if (1 << (index % 8)) & bitmap[index / 8] as usize == 1 {
None
} else {
match column_defs[index].field_type {
@ -105,6 +107,10 @@ impl crate::mariadb::Decode for ResultRow {
_ => Vec::new(),
};
Ok(ResultRow::default())
Ok(ResultRow {
length,
seq_no,
columns
})
}
}

View File

@ -29,9 +29,10 @@ pub use ssl_request::SSLRequestPacket;
pub use text::{
ComDebug, ComInitDb, ComPing, ComProcessKill, ComQuery, ComQuit, ComResetConnection,
ComSetOption, ComShutdown, ComSleep, ComStatistics, SetOptionOptions, ShutdownOptions,
ResultRow as ResultRowText
};
pub use binary::{
ComStmtClose, ComStmtExec, ComStmtFetch, ComStmtPrepare, ComStmtPrepareOk, ComStmtPrepareResp,
ComStmtReset,
ComStmtReset, ResultRow as ResultRowBinary
};

View File

@ -1,67 +1,29 @@
use crate::mariadb::{DeContext, Decode, Decoder, ErrorCode, ServerStatusFlag};
use bytes::Bytes;
use failure::Error;
use std::convert::TryFrom;
use crate::mariadb::{ResultRowText, ResultRowBinary};
#[derive(Default, Debug)]
#[derive(Debug)]
pub struct ResultRow {
pub length: u32,
pub seq_no: u8,
pub row: Vec<Bytes>,
pub columns: Vec<Option<bytes::Bytes>>
}
impl Decode for ResultRow {
fn decode(ctx: &mut DeContext) -> Result<Self, Error> {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_u8();
let row = if let Some(columns) = ctx.columns {
(0..columns)
.map(|_| decoder.decode_string_lenenc())
.collect::<Vec<Bytes>>()
} else {
Vec::new()
};
Ok(ResultRow {
length,
seq_no,
row,
})
impl From<ResultRowText> for ResultRow {
fn from(row: ResultRowText) -> Self {
ResultRow {
length: row.length,
seq_no: row.seq_no,
columns: row.columns,
}
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::{
__bytes_builder,
mariadb::{ConnContext, Decoder},
ConnectOptions,
};
use bytes::Bytes;
#[test]
fn it_decodes_result_row_packet() -> Result<(), Error> {
#[rustfmt::skip]
let buf = __bytes_builder!(
// int<3> length
1u8, 0u8, 0u8,
// int<1> seq_no
1u8,
// string<lenenc> column data
1u8, b"s"
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, buf);
ctx.columns = Some(1);
let _message = ResultRow::decode(&mut ctx)?;
Ok(())
impl From<ResultRowBinary> for ResultRow {
fn from(row: ResultRowBinary) -> Self {
ResultRow {
length: row.length,
seq_no: row.seq_no,
columns: row.columns,
}
}
}

View File

@ -3,7 +3,7 @@ use failure::Error;
use crate::mariadb::{
Capabilities, ColumnDefPacket, ColumnPacket, ConnContext, DeContext, Decode, Decoder,
EofPacket, ErrPacket, Framed, OkPacket, ResultRow,
EofPacket, ErrPacket, Framed, OkPacket, ResultRowText, ResultRowBinary, ProtocolType, ResultRow
};
#[derive(Debug, Default)]
@ -14,11 +14,9 @@ pub struct ResultSet {
}
impl ResultSet {
pub async fn deserialize<'a>(mut ctx: DeContext<'a>) -> Result<Self, Error> {
pub async fn deserialize(mut ctx: DeContext<'_>, protocol: ProtocolType) -> Result<Self, Error> {
let column_packet = ColumnPacket::decode(&mut ctx)?;
println!("{:?}", column_packet);
let columns = if let Some(columns) = column_packet.columns {
let mut column_defs = Vec::new();
for _ in 0..columns {
@ -30,8 +28,6 @@ impl ResultSet {
Vec::new()
};
println!("{:?}", columns);
ctx.next_packet().await?;
let eof_packet = if !ctx
@ -41,7 +37,6 @@ impl ResultSet {
{
// If we get an eof packet we must update ctx to hold a new buffer of the next packet.
let eof_packet = Some(EofPacket::decode(&mut ctx)?);
println!("{:?}", eof_packet);
ctx.next_packet().await?;
eof_packet
} else {
@ -59,19 +54,35 @@ impl ResultSet {
};
let tag = ctx.decoder.peek_tag();
if tag == &0xFE && packet_header.length <= 0xFFFFFF || packet_header.length == 0 {
if tag == &0xFE && packet_header.length <= 0xFFFFFF {
break;
} else {
let index = ctx.decoder.index;
match ResultRow::decode(&mut ctx) {
Ok(v) => {
rows.push(v);
ctx.next_packet().await?;
}
Err(_) => {
ctx.decoder.index = index;
break;
}
match protocol {
ProtocolType::Text => {
match ResultRowText::decode(&mut ctx) {
Ok(row) => {
rows.push(ResultRow::from(row));
ctx.next_packet().await?;
}
Err(_) => {
ctx.decoder.index = index;
break;
}
}
},
ProtocolType::Binary => {
match ResultRowBinary::decode(&mut ctx) {
Ok(row) => {
rows.push(ResultRow::from(row));
ctx.next_packet().await?;
}
Err(_) => {
ctx.decoder.index = index;
break;
}
}
},
}
}
}
@ -324,7 +335,7 @@ mod test {
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, buf);
ResultSet::deserialize(ctx).await?;
ResultSet::deserialize(ctx, ProtocolType::Text).await?;
Ok(())
}

View File

@ -9,6 +9,7 @@ pub mod com_set_option;
pub mod com_shutdown;
pub mod com_sleep;
pub mod com_statistics;
pub mod result_row;
pub use com_debug::ComDebug;
pub use com_init_db::ComInitDb;
@ -21,6 +22,7 @@ pub use com_set_option::{ComSetOption, SetOptionOptions};
pub use com_shutdown::{ComShutdown, ShutdownOptions};
pub use com_sleep::ComSleep;
pub use com_statistics::ComStatistics;
pub use result_row::ResultRow;
// This is an enum of text protocol packet tags.
// Tags are the 5th byte of the packet (1st byte of packet body)

View File

@ -0,0 +1,67 @@
use crate::mariadb::{DeContext, Decode, Decoder, ErrorCode, ServerStatusFlag};
use bytes::Bytes;
use failure::Error;
use std::convert::TryFrom;
#[derive(Default, Debug)]
pub struct ResultRow {
pub length: u32,
pub seq_no: u8,
pub columns: Vec<Option<Bytes>>,
}
impl Decode for ResultRow {
fn decode(ctx: &mut DeContext) -> Result<Self, Error> {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_u8();
let columns = if let Some(columns) = ctx.columns {
(0..columns)
.map(|_| Some(decoder.decode_string_lenenc()))
.collect::<Vec<Option<Bytes>>>()
} else {
Vec::new()
};
Ok(ResultRow {
length,
seq_no,
columns,
})
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::{
__bytes_builder,
mariadb::{ConnContext, Decoder},
ConnectOptions,
};
use bytes::Bytes;
#[test]
fn it_decodes_result_row_packet() -> Result<(), Error> {
#[rustfmt::skip]
let buf = __bytes_builder!(
// int<3> length
1u8, 0u8, 0u8,
// int<1> seq_no
1u8,
// string<lenenc> column data
1u8, b"s"
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, buf);
ctx.columns = Some(1);
let _message = ResultRow::decode(&mut ctx)?;
Ok(())
}
}

View File

@ -1,5 +1,10 @@
use std::convert::TryFrom;
pub enum ProtocolType {
Text,
Binary
}
bitflags! {
pub struct Capabilities: u128 {
const CLIENT_MYSQL = 1;