diff --git a/mason-mariadb/src/connection/establish.rs b/mason-mariadb/src/connection/establish.rs index f71ac108..19c52cd3 100644 --- a/mason-mariadb/src/connection/establish.rs +++ b/mason-mariadb/src/connection/establish.rs @@ -5,6 +5,8 @@ use crate::protocol::{ server::Deserialize, server::Capabilities, client::HandshakeResponsePacket, + client::ComQuit, + client::ComPing, client::Serialize }; use futures::StreamExt; @@ -18,6 +20,7 @@ pub async fn establish<'a, 'b: 'a>( options: ConnectOptions<'b>, ) -> Result<(), Error> { let init_packet = if let Some(message) = conn.incoming.next().await { + conn.sequence_number = message.sequence_number(); match message { ServerMessage::InitialHandshakePacket(message) => { Ok(message) @@ -50,6 +53,7 @@ pub async fn establish<'a, 'b: 'a>( if let Some(message) = conn.incoming.next().await { println!("{:?}", message); + conn.sequence_number = message.sequence_number(); Ok(()) } else { Err(failure::err_msg("Handshake Failed")) @@ -60,10 +64,11 @@ pub async fn establish<'a, 'b: 'a>( mod test { use super::*; use failure::Error; + use failure::err_msg; #[runtime::test] async fn it_connects() -> Result<(), Error> { - Connection::establish(ConnectOptions { + let mut conn = Connection::establish(ConnectOptions { host: "localhost", port: 3306, user: Some("root"), @@ -71,7 +76,22 @@ mod test { password: None, }).await?; - Ok(()) + conn.ping().await?; + + if let Some(message) = conn.incoming.next().await { + match message { + ServerMessage::OkPacket(packet) => { + conn.quit().await?; + Ok(()) + } + ServerMessage::ErrPacket(packet) => { + Err(err_msg(format!("{:?}", packet))) + } + _ => Err(err_msg("Server Failed")) + } + } else { + Err(err_msg("Server Failed")) + } } } diff --git a/mason-mariadb/src/connection/mod.rs b/mason-mariadb/src/connection/mod.rs index 78c19560..3e0a2edf 100644 --- a/mason-mariadb/src/connection/mod.rs +++ b/mason-mariadb/src/connection/mod.rs @@ -1,5 +1,7 @@ use crate::protocol::{ client::Serialize, + client::ComQuit, + client::ComPing, server::Message as ServerMessage, server::Capabilities, server::InitialHandshakePacket, @@ -79,17 +81,31 @@ impl Connection { */ // Reserve space for packet header; Packet Body Length (3 bytes) and sequence number (1 byte) self.wbuf.extend_from_slice(&[0; 4]); - self.wbuf[3] =self.sequence_number; - self.sequence_number += 1; + self.wbuf[3] = self.sequence_number; message.serialize(&mut self.wbuf, &self.server_capabilities)?; serialize_length(&mut self.wbuf); + println!("{:?}", self.wbuf); + self.writer.write_all(&self.wbuf).await?; self.writer.flush().await?; Ok(()) } + + async fn quit(&mut self) -> Result<(), Error> { + self.send(ComQuit()).await?; + + Ok(()) + } + + async fn ping(&mut self) -> Result<(), Error> { + self.sequence_number = 0; + self.send(ComPing()).await?; + + Ok(()) + } } async fn receiver( diff --git a/mason-mariadb/src/protocol/server.rs b/mason-mariadb/src/protocol/server.rs index 5ebef799..d0317e6d 100644 --- a/mason-mariadb/src/protocol/server.rs +++ b/mason-mariadb/src/protocol/server.rs @@ -18,6 +18,17 @@ pub enum Message { ErrPacket(ErrPacket), } +impl Message { + pub fn sequence_number(&self) -> u8 { + match self { + Message::InitialHandshakePacket(InitialHandshakePacket{ sequence_number, ..}) => sequence_number + 1, + Message::OkPacket(OkPacket{ sequence_number, ..}) => sequence_number + 1, + Message::ErrPacket(ErrPacket { sequence_number, .. }) => sequence_number + 1, + _ => 0 + } + } +} + bitflags! { pub struct Capabilities: u128 { const CLIENT_MYSQL = 1; @@ -123,6 +134,7 @@ pub struct InitialHandshakePacket { #[derive(Default, Debug)] pub struct OkPacket { + pub sequence_number: u8, pub affected_rows: Option, pub last_insert_id: Option, pub server_status: ServerStatusFlag, @@ -134,6 +146,7 @@ pub struct OkPacket { #[derive(Default, Debug)] pub struct ErrPacket { + pub sequence_number: u8, pub error_code: u16, pub stage: Option, pub max_stage: Option, @@ -269,7 +282,7 @@ impl Deserialize for OkPacket { // Packet header let length = deserialize_length(&buf, &mut index)?; - let _sequence_number = deserialize_int_1(&buf, &mut index); + let sequence_number = deserialize_int_1(&buf, &mut index); // Packet body let packet_header = deserialize_int_1(&buf, &mut index); @@ -289,6 +302,7 @@ impl Deserialize for OkPacket { let info = Bytes::from(&buf[index..]); Ok(OkPacket { + sequence_number, affected_rows, last_insert_id, server_status, @@ -305,7 +319,7 @@ impl Deserialize for ErrPacket { let mut index = 0; let length = deserialize_length(&buf, &mut index)?; - let _sequence_number = deserialize_int_1(&buf, &mut index); + let sequence_number = deserialize_int_1(&buf, &mut index); let packet_header = deserialize_int_1(&buf, &mut index); if packet_header != 0xFF { @@ -340,6 +354,7 @@ impl Deserialize for ErrPacket { } Ok(ErrPacket { + sequence_number, error_code, stage, max_stage, @@ -366,7 +381,7 @@ mod test { #[test] fn it_decodes_errpacket_real() -> Result<(), Error> { let buf = BytesMut::from(b"!\0\0\x01\xff\x84\x04#08S01Got packets out of order".to_vec()); - let _message = InitialHandshakePacket::deserialize(&buf.freeze())?; + let _message = ErrPacket::deserialize(&buf.freeze())?; Ok(()) }