diff --git a/mason-mariadb/src/connection/establish.rs b/mason-mariadb/src/connection/establish.rs index 74210226..c38967bb 100644 --- a/mason-mariadb/src/connection/establish.rs +++ b/mason-mariadb/src/connection/establish.rs @@ -2,31 +2,59 @@ use super::Connection; use crate::protocol::{ server::Message as ServerMessage, server::InitialHandshakePacket, - server::Deserialize + server::Deserialize, + server::Capabilities, + client::HandshakeResponsePacket, + client::Serialize }; use futures::StreamExt; use mason_core::ConnectOptions; use std::io; use failure::Error; +use bytes::Bytes; pub async fn establish<'a, 'b: 'a>( conn: &'a mut Connection, options: ConnectOptions<'b>, ) -> Result<(), Error> { - // The actual connection establishing -// - if let Some(message) = conn.incoming.next().await { -// return -// match message { -// ServerMessage::InitialHandshakePacket(message) => { -// -// }, -// _ => unimplemented!("received {:?} unimplemented message", message), -// } - Ok(()) + let init_packet = if let Some(message) = conn.incoming.next().await { + match message { + ServerMessage::InitialHandshakePacket(message) => { + Ok(message) + }, + _ => Err(failure::err_msg("Incorrect First Packet")), + } } else { Err(failure::err_msg("Failed to connect")) - } + }?; + +// println!("{:?}", init_packet); + + let handshake = HandshakeResponsePacket { + server_capabilities: init_packet.capabilities, + sequence_number: 1, + capabilities: Capabilities::from_bits_truncate(0), + max_packet_size: 1024, + collation: 0, + extended_capabilities: Some(Capabilities::from_bits_truncate(0)), + username: Bytes::from("username"), + auth_data: None, + auth_response_len: None, + auth_response: None, + database: None, + auth_plugin_name: None, + conn_attr_len: None, + conn_attr: None, + }; + conn.send(handshake).await?; + + if let Some(message) = conn.incoming.next().await { + Ok(()) + } else { + Err(failure::err_msg("Handshake Failed")) + } + +// Ok(()) } #[cfg(test)] diff --git a/mason-mariadb/src/connection/mod.rs b/mason-mariadb/src/connection/mod.rs index 2a77a8b3..7702771f 100644 --- a/mason-mariadb/src/connection/mod.rs +++ b/mason-mariadb/src/connection/mod.rs @@ -105,17 +105,15 @@ async fn receiver( break; } + println!("{:?}", rbuf); + while len > 0 { let size = rbuf.len(); - println!("Buffer: {:?}", rbuf); let message = if first_packet { - println!("init"); ServerMessage::init(&mut rbuf)? } else { - println!("deser"); ServerMessage::deserialize(&mut rbuf)? }; - println!("Message: {:?}", message); len -= size - rbuf.len(); if let Some(message) = message { diff --git a/mason-mariadb/src/protocol/deserialize.rs b/mason-mariadb/src/protocol/deserialize.rs index 1a28e757..7ca23215 100644 --- a/mason-mariadb/src/protocol/deserialize.rs +++ b/mason-mariadb/src/protocol/deserialize.rs @@ -1,6 +1,19 @@ // Deserializing bytes and string do the same thing. Except that string also has a null terminated deserialzer use byteorder::{ByteOrder, LittleEndian}; use bytes::Bytes; +use failure::Error; +use failure::err_msg; + +#[inline] +pub fn deserialize_length(buf: &Vec, index: &mut usize) -> Result { + let length = deserialize_int_3(&buf, index); + + if buf.len() < length as usize { + return Err(err_msg("Lengths to do not match")); + } + + Ok(length) +} #[inline] pub fn deserialize_int_lenenc(buf: &Vec, index: &mut usize) -> Option { diff --git a/mason-mariadb/src/protocol/server.rs b/mason-mariadb/src/protocol/server.rs index dcffaf41..9b0c3e73 100644 --- a/mason-mariadb/src/protocol/server.rs +++ b/mason-mariadb/src/protocol/server.rs @@ -17,7 +17,6 @@ pub enum Message { } bitflags! { -// 1111011111111110 pub struct Capabilities: u128 { const CLIENT_MYSQL = 1; const FOUND_ROWS = 2; @@ -149,7 +148,7 @@ pub struct InitialHandshakePacket { pub auth_seed: Bytes, pub capabilities: Capabilities, pub collation: u8, - pub status: u16, + pub status: ServerStatusFlag, pub plugin_data_length: u8, pub scramble: Option, pub auth_plugin_name: Option, @@ -196,12 +195,7 @@ impl Deserialize for InitialHandshakePacket { fn deserialize(buf: &mut Vec) -> Result { let mut index = 0; - let length = deserialize_int_3(&buf, &mut index); - - if buf.len() < length as usize { - return Err(err_msg("Lengths to do not match")); - } - + let length = deserialize_length(&buf, &mut index)?; let sequence_number = deserialize_int_1(&buf, &mut index); if sequence_number != 0 { @@ -220,7 +214,7 @@ impl Deserialize for InitialHandshakePacket { Capabilities::from_bits_truncate(deserialize_int_2(&buf, &mut index).into()); let collation = deserialize_int_1(&buf, &mut index); - let status = deserialize_int_2(&buf, &mut index); + let status = ServerStatusFlag::from_bits_truncate(deserialize_int_2(&buf, &mut index).into()); capabilities |= Capabilities::from_bits_truncate(((deserialize_int_2(&buf, &mut index) as u32) << 16).into()); @@ -279,12 +273,7 @@ impl Deserialize for OkPacket { fn deserialize(buf: &mut Vec) -> Result { let mut index = 0; - let length = deserialize_int_3(&buf, &mut index); - - if buf.len() != length as usize { - return Err(err_msg("Lengths to do not match")); - } - + let length = deserialize_length(&buf, &mut index)?; let _sequence_number = deserialize_int_1(&buf, &mut index); let packet_header = deserialize_int_1(&buf, &mut index); @@ -319,12 +308,7 @@ impl Deserialize for ErrPacket { fn deserialize(buf: &mut Vec) -> Result { let mut index = 0; - let length = deserialize_int_3(&buf, &mut index); - - if buf.len() != length as usize { - return Err(err_msg("Lengths to do not match")); - } - + let length = deserialize_length(&buf, &mut index)?; let _sequence_number = deserialize_int_1(&buf, &mut index); let packet_header = deserialize_int_1(&buf, &mut index);