WIP: Handshake

This commit is contained in:
Daniel Akhterov 2019-06-22 03:01:20 -07:00 committed by Daniel Akhterov
parent 3a07e393e8
commit 8ea28f36e5
4 changed files with 61 additions and 38 deletions

View File

@ -2,31 +2,59 @@ use super::Connection;
use crate::protocol::{
server::Message as ServerMessage,
server::InitialHandshakePacket,
server::Deserialize
server::Deserialize,
server::Capabilities,
client::HandshakeResponsePacket,
client::Serialize
};
use futures::StreamExt;
use mason_core::ConnectOptions;
use std::io;
use failure::Error;
use bytes::Bytes;
pub async fn establish<'a, 'b: 'a>(
conn: &'a mut Connection,
options: ConnectOptions<'b>,
) -> Result<(), Error> {
// The actual connection establishing
//
if let Some(message) = conn.incoming.next().await {
// return
// match message {
// ServerMessage::InitialHandshakePacket(message) => {
//
// },
// _ => unimplemented!("received {:?} unimplemented message", message),
// }
Ok(())
let init_packet = if let Some(message) = conn.incoming.next().await {
match message {
ServerMessage::InitialHandshakePacket(message) => {
Ok(message)
},
_ => Err(failure::err_msg("Incorrect First Packet")),
}
} else {
Err(failure::err_msg("Failed to connect"))
}
}?;
// println!("{:?}", init_packet);
let handshake = HandshakeResponsePacket {
server_capabilities: init_packet.capabilities,
sequence_number: 1,
capabilities: Capabilities::from_bits_truncate(0),
max_packet_size: 1024,
collation: 0,
extended_capabilities: Some(Capabilities::from_bits_truncate(0)),
username: Bytes::from("username"),
auth_data: None,
auth_response_len: None,
auth_response: None,
database: None,
auth_plugin_name: None,
conn_attr_len: None,
conn_attr: None,
};
conn.send(handshake).await?;
if let Some(message) = conn.incoming.next().await {
Ok(())
} else {
Err(failure::err_msg("Handshake Failed"))
}
// Ok(())
}
#[cfg(test)]

View File

@ -105,17 +105,15 @@ async fn receiver(
break;
}
println!("{:?}", rbuf);
while len > 0 {
let size = rbuf.len();
println!("Buffer: {:?}", rbuf);
let message = if first_packet {
println!("init");
ServerMessage::init(&mut rbuf)?
} else {
println!("deser");
ServerMessage::deserialize(&mut rbuf)?
};
println!("Message: {:?}", message);
len -= size - rbuf.len();
if let Some(message) = message {

View File

@ -1,6 +1,19 @@
// Deserializing bytes and string do the same thing. Except that string also has a null terminated deserialzer
use byteorder::{ByteOrder, LittleEndian};
use bytes::Bytes;
use failure::Error;
use failure::err_msg;
#[inline]
pub fn deserialize_length(buf: &Vec<u8>, index: &mut usize) -> Result<u32, Error> {
let length = deserialize_int_3(&buf, index);
if buf.len() < length as usize {
return Err(err_msg("Lengths to do not match"));
}
Ok(length)
}
#[inline]
pub fn deserialize_int_lenenc(buf: &Vec<u8>, index: &mut usize) -> Option<usize> {

View File

@ -17,7 +17,6 @@ pub enum Message {
}
bitflags! {
// 1111011111111110
pub struct Capabilities: u128 {
const CLIENT_MYSQL = 1;
const FOUND_ROWS = 2;
@ -149,7 +148,7 @@ pub struct InitialHandshakePacket {
pub auth_seed: Bytes,
pub capabilities: Capabilities,
pub collation: u8,
pub status: u16,
pub status: ServerStatusFlag,
pub plugin_data_length: u8,
pub scramble: Option<Bytes>,
pub auth_plugin_name: Option<Bytes>,
@ -196,12 +195,7 @@ impl Deserialize for InitialHandshakePacket {
fn deserialize(buf: &mut Vec<u8>) -> Result<Self, Error> {
let mut index = 0;
let length = deserialize_int_3(&buf, &mut index);
if buf.len() < length as usize {
return Err(err_msg("Lengths to do not match"));
}
let length = deserialize_length(&buf, &mut index)?;
let sequence_number = deserialize_int_1(&buf, &mut index);
if sequence_number != 0 {
@ -220,7 +214,7 @@ impl Deserialize for InitialHandshakePacket {
Capabilities::from_bits_truncate(deserialize_int_2(&buf, &mut index).into());
let collation = deserialize_int_1(&buf, &mut index);
let status = deserialize_int_2(&buf, &mut index);
let status = ServerStatusFlag::from_bits_truncate(deserialize_int_2(&buf, &mut index).into());
capabilities |=
Capabilities::from_bits_truncate(((deserialize_int_2(&buf, &mut index) as u32) << 16).into());
@ -279,12 +273,7 @@ impl Deserialize for OkPacket {
fn deserialize(buf: &mut Vec<u8>) -> Result<Self, Error> {
let mut index = 0;
let length = deserialize_int_3(&buf, &mut index);
if buf.len() != length as usize {
return Err(err_msg("Lengths to do not match"));
}
let length = deserialize_length(&buf, &mut index)?;
let _sequence_number = deserialize_int_1(&buf, &mut index);
let packet_header = deserialize_int_1(&buf, &mut index);
@ -319,12 +308,7 @@ impl Deserialize for ErrPacket {
fn deserialize(buf: &mut Vec<u8>) -> Result<Self, Error> {
let mut index = 0;
let length = deserialize_int_3(&buf, &mut index);
if buf.len() != length as usize {
return Err(err_msg("Lengths to do not match"));
}
let length = deserialize_length(&buf, &mut index)?;
let _sequence_number = deserialize_int_1(&buf, &mut index);
let packet_header = deserialize_int_1(&buf, &mut index);