diff --git a/src/mariadb/connection/establish.rs b/src/mariadb/connection/establish.rs index dfad90bd..92d3f1b7 100644 --- a/src/mariadb/connection/establish.rs +++ b/src/mariadb/connection/establish.rs @@ -2,7 +2,7 @@ use super::Connection; use crate::{ mariadb::{ Capabilities, ComStmtExec, DeContext, Decode, EofPacket, ErrPacket, - HandshakeResponsePacket, InitialHandshakePacket, OkPacket, StmtExecFlag, + HandshakeResponsePacket, InitialHandshakePacket, OkPacket, StmtExecFlag, ProtocolType }, ConnectOptions, }; @@ -109,10 +109,8 @@ mod test { }) .await?; - println!("selecting db"); conn.select_db("test").await?; - println!("querying"); conn.query("SELECT * FROM users").await?; Ok(()) @@ -154,17 +152,10 @@ mod test { .prepare("SELECT id FROM users WHERE username=?") .await?; - println!("{:?}", prepared); - - // if let Some(param_defs) = &mut prepared.param_defs { - // param_defs[0].field_type = FieldType::MysqlTypeBlob; - // } - let exec = ComStmtExec { stmt_id: prepared.ok.stmt_id, flags: StmtExecFlag::NoCursor, - // params: None, - params: Some(vec![Some(Bytes::from_static(b"daniel"))]), + params: Some(vec![Some(Bytes::from_static(b"josh"))]), param_defs: prepared.param_defs, }; @@ -172,26 +163,21 @@ mod test { let mut ctx = DeContext::with_stream(&mut conn.context, &mut conn.stream); ctx.next_packet().await?; + ctx.columns = Some(prepared.ok.columns as u64); + ctx.column_defs = prepared.res_columns; match ctx.decoder.peek_tag() { - 0xFF => println!("{:?}", ErrPacket::decode(&mut ctx)?), - 0x00 => println!("{:?}", OkPacket::decode(&mut ctx)?), - _ => println!("{:?}", ResultSet::deserialize(ctx).await?), + 0xFF => { + ErrPacket::decode(&mut ctx)?; + }, + 0x00 => { + OkPacket::decode(&mut ctx)?; + }, + _ => { + ResultSet::deserialize(ctx, ProtocolType::Binary).await?; + }, } - // let fetch = ComStmtFetch { - // stmt_id: -1, - // rows: 10, - // }; - // - // conn.send(fetch).await?; - // - // let buf = conn.stream.next_packet().await?; - // - // println!("{:?}", buf); - - // println!("{:?}", ResultSet::deserialize(&mut DeContext::new(&mut conn.context, &buf))?); - Ok(()) } diff --git a/src/mariadb/connection/mod.rs b/src/mariadb/connection/mod.rs index 9c6a0c6d..834fd7a5 100644 --- a/src/mariadb/connection/mod.rs +++ b/src/mariadb/connection/mod.rs @@ -2,7 +2,7 @@ use crate::{ mariadb::{ protocol::encode, Capabilities, ComInitDb, ComPing, ComQuery, ComQuit, ComStmtPrepare, ComStmtPrepareResp, DeContext, Decode, Decoder, Encode, ErrPacket, OkPacket, PacketHeader, - ResultSet, ServerStatusFlag, + ResultSet, ServerStatusFlag, ProtocolType }, ConnectOptions, }; @@ -130,7 +130,7 @@ impl Connection { Ok(None) } 0xFB => unimplemented!(), - _ => Ok(Some(ResultSet::deserialize(ctx).await?)), + _ => Ok(Some(ResultSet::deserialize(ctx, ProtocolType::Text).await?)), } } @@ -203,7 +203,6 @@ impl Framed { // After this operation we know that packet_headers.last() *SHOULD* always return valid data, // so the the use of packet_headers.last().unwrap() is allowed. // TODO: Stitch packets together by removing the length and seq_no from in-between packet definitions. - println!("{:?}", self.buf); if let Some(packet_header) = packet_headers.last() { if packet_header.length as usize == encode::U24_MAX { packet_headers.push(PacketHeader::try_from(&self.buf[self.index..])?); diff --git a/src/mariadb/mod.rs b/src/mariadb/mod.rs index 5d37130e..f2571936 100644 --- a/src/mariadb/mod.rs +++ b/src/mariadb/mod.rs @@ -9,6 +9,6 @@ pub use protocol::{ ComSetOption, ComShutdown, ComSleep, ComStatistics, ComStmtClose, ComStmtExec, ComStmtFetch, ComStmtPrepare, ComStmtPrepareOk, ComStmtPrepareResp, DeContext, Decode, Decoder, Encode, EofPacket, ErrPacket, ErrorCode, FieldDetailFlag, FieldType, HandshakeResponsePacket, - InitialHandshakePacket, OkPacket, PacketHeader, ResultRow, ResultSet, SSLRequestPacket, - ServerStatusFlag, SessionChangeType, SetOptionOptions, ShutdownOptions, StmtExecFlag, + InitialHandshakePacket, OkPacket, PacketHeader, ResultRowText, ResultRowBinary, ResultRow, ResultSet, SSLRequestPacket, + ServerStatusFlag, SessionChangeType, SetOptionOptions, ShutdownOptions, StmtExecFlag, ProtocolType }; diff --git a/src/mariadb/protocol/encode.rs b/src/mariadb/protocol/encode.rs index e2882b65..6ff59a32 100644 --- a/src/mariadb/protocol/encode.rs +++ b/src/mariadb/protocol/encode.rs @@ -316,9 +316,7 @@ impl BufMut for Vec { // Same as the string counterpart copied to maintain consistency with the spec. #[inline] fn put_byte_fix(&mut self, bytes: &Bytes, size: usize) { - if size != bytes.len() { - panic!("Sizes do not match"); - } + assert_eq!(size, bytes.len()); self.extend_from_slice(bytes); } diff --git a/src/mariadb/protocol/mod.rs b/src/mariadb/protocol/mod.rs index fdc03473..943dd293 100644 --- a/src/mariadb/protocol/mod.rs +++ b/src/mariadb/protocol/mod.rs @@ -19,8 +19,8 @@ pub use packets::{ ComProcessKill, ComQuery, ComQuit, ComResetConnection, ComSetOption, ComShutdown, ComSleep, ComStatistics, ComStmtClose, ComStmtExec, ComStmtFetch, ComStmtPrepare, ComStmtPrepareOk, ComStmtPrepareResp, ComStmtReset, EofPacket, ErrPacket, HandshakeResponsePacket, - InitialHandshakePacket, OkPacket, PacketHeader, ResultRow, ResultSet, SSLRequestPacket, - SetOptionOptions, ShutdownOptions, + InitialHandshakePacket, OkPacket, PacketHeader, ResultRowText, ResultRowBinary, ResultSet, SSLRequestPacket, + SetOptionOptions, ShutdownOptions, ResultRow }; pub use decode::{DeContext, Decode, Decoder}; @@ -30,5 +30,5 @@ pub use encode::{BufMut, Encode}; pub use error_codes::ErrorCode; pub use types::{ - Capabilities, FieldDetailFlag, FieldType, ServerStatusFlag, SessionChangeType, StmtExecFlag, + ProtocolType, Capabilities, FieldDetailFlag, FieldType, ServerStatusFlag, SessionChangeType, StmtExecFlag, }; diff --git a/src/mariadb/protocol/packets/binary/com_stmt_exec.rs b/src/mariadb/protocol/packets/binary/com_stmt_exec.rs index e774c7fc..ef4645fc 100644 --- a/src/mariadb/protocol/packets/binary/com_stmt_exec.rs +++ b/src/mariadb/protocol/packets/binary/com_stmt_exec.rs @@ -1,5 +1,5 @@ use crate::mariadb::{ - BufMut, ColumnDefPacket, ConnContext, Connection, Encode, FieldDetailFlag, StmtExecFlag, + BufMut, ColumnDefPacket, ConnContext, Connection, Encode, FieldDetailFlag, StmtExecFlag, FieldType }; use bytes::Bytes; use failure::Error; @@ -26,12 +26,12 @@ impl Encode for ComStmtExec { (Some(params), Some(param_defs)) if params.len() > 0 => { let null_bitmap_size = (params.len() + 7) / 8; let mut shift_amount = 0u8; - let mut bitmap = vec![1u8]; - let send_type = 0u8; + let mut bitmap = vec![0u8]; + let send_type = 1u8; // Generate NULL-bitmap from params for param in params { - if param.is_some() { + if param.is_none() { let last_byte = bitmap.pop().unwrap(); bitmap.push(last_byte & (1 << shift_amount)); } @@ -48,7 +48,7 @@ impl Encode for ComStmtExec { if send_type > 0 { for param in param_defs { - // buf.put_int_u8(param.field_type as u8); + buf.put_int_u8(param.field_type as u8); buf.put_int_u8(0); } } diff --git a/src/mariadb/protocol/packets/binary/mod.rs b/src/mariadb/protocol/packets/binary/mod.rs index 2db1ac51..7b91d4b9 100644 --- a/src/mariadb/protocol/packets/binary/mod.rs +++ b/src/mariadb/protocol/packets/binary/mod.rs @@ -14,6 +14,7 @@ pub use com_stmt_prepare::ComStmtPrepare; pub use com_stmt_prepare_ok::ComStmtPrepareOk; pub use com_stmt_prepare_resp::ComStmtPrepareResp; pub use com_stmt_reset::ComStmtReset; +pub use result_row::ResultRow; pub enum BinaryProtocol { ComStmtPrepare = 0x16, diff --git a/src/mariadb/protocol/packets/binary/result_row.rs b/src/mariadb/protocol/packets/binary/result_row.rs index d8fc3ab0..ce75e49c 100644 --- a/src/mariadb/protocol/packets/binary/result_row.rs +++ b/src/mariadb/protocol/packets/binary/result_row.rs @@ -3,6 +3,8 @@ use bytes::Bytes; #[derive(Debug, Default)] pub struct ResultRow { + pub length: u32, + pub seq_no: u8, pub columns: Vec>, } @@ -24,11 +26,11 @@ impl crate::mariadb::Decode for ResultRow { )) }?; - let row = match (&ctx.columns, &ctx.column_defs) { + let columns = match (&ctx.columns, &ctx.column_defs) { (Some(columns), Some(column_defs)) => { (0..*columns as usize) .map(|index| { - if (1 << (index % 8)) & bitmap[index / 8] as usize == 0 { + if (1 << (index % 8)) & bitmap[index / 8] as usize == 1 { None } else { match column_defs[index].field_type { @@ -105,6 +107,10 @@ impl crate::mariadb::Decode for ResultRow { _ => Vec::new(), }; - Ok(ResultRow::default()) + Ok(ResultRow { + length, + seq_no, + columns + }) } } diff --git a/src/mariadb/protocol/packets/mod.rs b/src/mariadb/protocol/packets/mod.rs index dfdf63e5..4759d368 100644 --- a/src/mariadb/protocol/packets/mod.rs +++ b/src/mariadb/protocol/packets/mod.rs @@ -29,9 +29,10 @@ pub use ssl_request::SSLRequestPacket; pub use text::{ ComDebug, ComInitDb, ComPing, ComProcessKill, ComQuery, ComQuit, ComResetConnection, ComSetOption, ComShutdown, ComSleep, ComStatistics, SetOptionOptions, ShutdownOptions, + ResultRow as ResultRowText }; pub use binary::{ ComStmtClose, ComStmtExec, ComStmtFetch, ComStmtPrepare, ComStmtPrepareOk, ComStmtPrepareResp, - ComStmtReset, + ComStmtReset, ResultRow as ResultRowBinary }; diff --git a/src/mariadb/protocol/packets/result_row.rs b/src/mariadb/protocol/packets/result_row.rs index bb48a709..4d57e74b 100644 --- a/src/mariadb/protocol/packets/result_row.rs +++ b/src/mariadb/protocol/packets/result_row.rs @@ -1,67 +1,29 @@ -use crate::mariadb::{DeContext, Decode, Decoder, ErrorCode, ServerStatusFlag}; -use bytes::Bytes; -use failure::Error; -use std::convert::TryFrom; +use crate::mariadb::{ResultRowText, ResultRowBinary}; -#[derive(Default, Debug)] +#[derive(Debug)] pub struct ResultRow { pub length: u32, pub seq_no: u8, - pub row: Vec, + pub columns: Vec> } -impl Decode for ResultRow { - fn decode(ctx: &mut DeContext) -> Result { - let decoder = &mut ctx.decoder; - - let length = decoder.decode_length()?; - let seq_no = decoder.decode_int_u8(); - - let row = if let Some(columns) = ctx.columns { - (0..columns) - .map(|_| decoder.decode_string_lenenc()) - .collect::>() - } else { - Vec::new() - }; - - Ok(ResultRow { - length, - seq_no, - row, - }) +impl From for ResultRow { + fn from(row: ResultRowText) -> Self { + ResultRow { + length: row.length, + seq_no: row.seq_no, + columns: row.columns, + } } } -#[cfg(test)] -mod test { - use super::*; - use crate::{ - __bytes_builder, - mariadb::{ConnContext, Decoder}, - ConnectOptions, - }; - use bytes::Bytes; - #[test] - fn it_decodes_result_row_packet() -> Result<(), Error> { - #[rustfmt::skip] - let buf = __bytes_builder!( - // int<3> length - 1u8, 0u8, 0u8, - // int<1> seq_no - 1u8, - // string column data - 1u8, b"s" - ); - - let mut context = ConnContext::new(); - let mut ctx = DeContext::new(&mut context, buf); - - ctx.columns = Some(1); - - let _message = ResultRow::decode(&mut ctx)?; - - Ok(()) +impl From for ResultRow { + fn from(row: ResultRowBinary) -> Self { + ResultRow { + length: row.length, + seq_no: row.seq_no, + columns: row.columns, + } } } diff --git a/src/mariadb/protocol/packets/result_set.rs b/src/mariadb/protocol/packets/result_set.rs index 26ba4d8d..d94509d8 100644 --- a/src/mariadb/protocol/packets/result_set.rs +++ b/src/mariadb/protocol/packets/result_set.rs @@ -3,7 +3,7 @@ use failure::Error; use crate::mariadb::{ Capabilities, ColumnDefPacket, ColumnPacket, ConnContext, DeContext, Decode, Decoder, - EofPacket, ErrPacket, Framed, OkPacket, ResultRow, + EofPacket, ErrPacket, Framed, OkPacket, ResultRowText, ResultRowBinary, ProtocolType, ResultRow }; #[derive(Debug, Default)] @@ -14,11 +14,9 @@ pub struct ResultSet { } impl ResultSet { - pub async fn deserialize<'a>(mut ctx: DeContext<'a>) -> Result { + pub async fn deserialize(mut ctx: DeContext<'_>, protocol: ProtocolType) -> Result { let column_packet = ColumnPacket::decode(&mut ctx)?; - println!("{:?}", column_packet); - let columns = if let Some(columns) = column_packet.columns { let mut column_defs = Vec::new(); for _ in 0..columns { @@ -30,8 +28,6 @@ impl ResultSet { Vec::new() }; - println!("{:?}", columns); - ctx.next_packet().await?; let eof_packet = if !ctx @@ -41,7 +37,6 @@ impl ResultSet { { // If we get an eof packet we must update ctx to hold a new buffer of the next packet. let eof_packet = Some(EofPacket::decode(&mut ctx)?); - println!("{:?}", eof_packet); ctx.next_packet().await?; eof_packet } else { @@ -59,19 +54,35 @@ impl ResultSet { }; let tag = ctx.decoder.peek_tag(); - if tag == &0xFE && packet_header.length <= 0xFFFFFF || packet_header.length == 0 { + if tag == &0xFE && packet_header.length <= 0xFFFFFF { break; } else { let index = ctx.decoder.index; - match ResultRow::decode(&mut ctx) { - Ok(v) => { - rows.push(v); - ctx.next_packet().await?; - } - Err(_) => { - ctx.decoder.index = index; - break; - } + match protocol { + ProtocolType::Text => { + match ResultRowText::decode(&mut ctx) { + Ok(row) => { + rows.push(ResultRow::from(row)); + ctx.next_packet().await?; + } + Err(_) => { + ctx.decoder.index = index; + break; + } + } + }, + ProtocolType::Binary => { + match ResultRowBinary::decode(&mut ctx) { + Ok(row) => { + rows.push(ResultRow::from(row)); + ctx.next_packet().await?; + } + Err(_) => { + ctx.decoder.index = index; + break; + } + } + }, } } } @@ -324,7 +335,7 @@ mod test { let mut context = ConnContext::new(); let mut ctx = DeContext::new(&mut context, buf); - ResultSet::deserialize(ctx).await?; + ResultSet::deserialize(ctx, ProtocolType::Text).await?; Ok(()) } diff --git a/src/mariadb/protocol/packets/text/mod.rs b/src/mariadb/protocol/packets/text/mod.rs index aa1c41de..2afa6ca2 100644 --- a/src/mariadb/protocol/packets/text/mod.rs +++ b/src/mariadb/protocol/packets/text/mod.rs @@ -9,6 +9,7 @@ pub mod com_set_option; pub mod com_shutdown; pub mod com_sleep; pub mod com_statistics; +pub mod result_row; pub use com_debug::ComDebug; pub use com_init_db::ComInitDb; @@ -21,6 +22,7 @@ pub use com_set_option::{ComSetOption, SetOptionOptions}; pub use com_shutdown::{ComShutdown, ShutdownOptions}; pub use com_sleep::ComSleep; pub use com_statistics::ComStatistics; +pub use result_row::ResultRow; // This is an enum of text protocol packet tags. // Tags are the 5th byte of the packet (1st byte of packet body) diff --git a/src/mariadb/protocol/packets/text/result_row.rs b/src/mariadb/protocol/packets/text/result_row.rs new file mode 100644 index 00000000..de38cfe8 --- /dev/null +++ b/src/mariadb/protocol/packets/text/result_row.rs @@ -0,0 +1,67 @@ +use crate::mariadb::{DeContext, Decode, Decoder, ErrorCode, ServerStatusFlag}; +use bytes::Bytes; +use failure::Error; +use std::convert::TryFrom; + +#[derive(Default, Debug)] +pub struct ResultRow { + pub length: u32, + pub seq_no: u8, + pub columns: Vec>, +} + +impl Decode for ResultRow { + fn decode(ctx: &mut DeContext) -> Result { + let decoder = &mut ctx.decoder; + + let length = decoder.decode_length()?; + let seq_no = decoder.decode_int_u8(); + + let columns = if let Some(columns) = ctx.columns { + (0..columns) + .map(|_| Some(decoder.decode_string_lenenc())) + .collect::>>() + } else { + Vec::new() + }; + + Ok(ResultRow { + length, + seq_no, + columns, + }) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{ + __bytes_builder, + mariadb::{ConnContext, Decoder}, + ConnectOptions, + }; + use bytes::Bytes; + + #[test] + fn it_decodes_result_row_packet() -> Result<(), Error> { + #[rustfmt::skip] + let buf = __bytes_builder!( + // int<3> length + 1u8, 0u8, 0u8, + // int<1> seq_no + 1u8, + // string column data + 1u8, b"s" + ); + + let mut context = ConnContext::new(); + let mut ctx = DeContext::new(&mut context, buf); + + ctx.columns = Some(1); + + let _message = ResultRow::decode(&mut ctx)?; + + Ok(()) + } +} diff --git a/src/mariadb/protocol/types.rs b/src/mariadb/protocol/types.rs index 7fb6cbd8..d29bdbf5 100644 --- a/src/mariadb/protocol/types.rs +++ b/src/mariadb/protocol/types.rs @@ -1,5 +1,10 @@ use std::convert::TryFrom; +pub enum ProtocolType { + Text, + Binary +} + bitflags! { pub struct Capabilities: u128 { const CLIENT_MYSQL = 1;