2019-07-26 15:49:00 -07:00

167 lines
5.3 KiB
Rust

use super::super::{
deserialize::{DeContext, Deserialize},
types::{Capabilities, ServerStatusFlag},
};
use bytes::Bytes;
use failure::{err_msg, Error};
/// Decoded form of the initial handshake packet that opens a connection.
///
/// Fields are listed in the order in which `deserialize` reads them from the
/// wire. The two `Option` fields are only populated when the corresponding
/// capability flag is present in `capabilities`.
#[derive(Default, Debug)]
pub struct InitialHandshakePacket {
/// Packet length as returned by the decoder's `decode_length`.
pub length: u32,
/// Packet sequence number; `deserialize` rejects any value other than 0.
pub seq_no: u8,
/// One-byte protocol version.
pub protocol_version: u8,
/// Server version, read as a null-terminated string.
pub server_version: Bytes,
/// Four-byte connection id.
pub connection_id: u32,
/// First part of the scramble (authentication seed), read as 8 fixed bytes.
pub auth_seed: Bytes,
/// Capability flags, OR-ed together from the 2-byte first part, the 2-byte
/// second part (shifted left by 16) and, when `CLIENT_MYSQL` is not set, the
/// 4-byte third part (shifted left by 32).
pub capabilities: Capabilities,
/// One-byte server default collation.
pub collation: u8,
/// Two-byte server status flags.
pub status: ServerStatusFlag,
/// Plugin data length; only read when `PLUGIN_AUTH` is set, otherwise 0.
pub plugin_data_length: u8,
/// Second part of the scramble; `Some` only when `SECURE_CONNECTION` is set.
pub scramble: Option<Bytes>,
/// Authentication plugin name; `Some` only when `PLUGIN_AUTH` is set.
pub auth_plugin_name: Option<Bytes>,
}
impl Deserialize for InitialHandshakePacket {
    /// Decodes an initial handshake packet from `ctx.decoder`.
    ///
    /// On success the decoded capability flags and the packet's sequence number
    /// are also stored on `ctx.conn` (`capabilities` and `last_seq_no`).
    ///
    /// # Errors
    ///
    /// Returns an error when the packet length or one of the null-terminated
    /// strings cannot be decoded, or when the sequence number is not 0.
    fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
        let decoder = &mut ctx.decoder;

        let length = decoder.decode_length()?;
        let seq_no = decoder.decode_int_1();
        if seq_no != 0 {
            return Err(err_msg("Sequence Number of Initial Handshake Packet is not 0"));
        }

        let protocol_version = decoder.decode_int_1();
        let server_version = decoder.decode_string_null()?;
        let connection_id = decoder.decode_int_4();
        let auth_seed = decoder.decode_string_fix(8);
        // Skip reserved byte
        decoder.skip_bytes(1);

        // The capability flags arrive in up to three pieces; this is the first
        // 16 bits, the remaining pieces are OR-ed in as they are read.
        let mut capabilities = Capabilities::from_bits_truncate(decoder.decode_int_2().into());
        let collation = decoder.decode_int_1();
        let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_2().into());
        capabilities |=
            Capabilities::from_bits_truncate(((decoder.decode_int_2() as u32) << 16).into());

        // This byte carries the plugin data length only when PLUGIN_AUTH is
        // set; otherwise it is reserved and the length stays 0.
        let plugin_data_length = if !(capabilities & Capabilities::PLUGIN_AUTH).is_empty() {
            decoder.decode_int_1()
        } else {
            decoder.skip_bytes(1);
            0
        };

        // Skip filler
        decoder.skip_bytes(6);

        // Without CLIENT_MYSQL these four bytes hold the third capability part
        // (bits 32 and up); with it they are filler.
        if (capabilities & Capabilities::CLIENT_MYSQL).is_empty() {
            capabilities |=
                Capabilities::from_bits_truncate(((decoder.decode_int_4() as u128) << 32).into());
        } else {
            decoder.skip_bytes(4);
        }

        let scramble = if !(capabilities & Capabilities::SECURE_CONNECTION).is_empty() {
            // Length = max(12, plugin data length - 9). The subtraction must
            // saturate: the plugin data length is 0 when PLUGIN_AUTH is absent
            // (and may be below 9 on the wire), where a plain `- 9` would
            // underflow.
            let len = u32::from(plugin_data_length).saturating_sub(9).max(12);
            let second_part = decoder.decode_string_fix(len);
            // Skip reserved byte
            decoder.skip_bytes(1);
            Some(second_part)
        } else {
            None
        };

        let auth_plugin_name = if !(capabilities & Capabilities::PLUGIN_AUTH).is_empty() {
            Some(decoder.decode_string_null()?)
        } else {
            None
        };

        // Record what was decoded on the connection before handing the packet back.
        ctx.conn.capabilities = capabilities;
        ctx.conn.last_seq_no = seq_no;

        Ok(InitialHandshakePacket {
            length,
            seq_no,
            protocol_version,
            server_version,
            connection_id,
            auth_seed,
            capabilities,
            collation,
            status,
            plugin_data_length,
            scramble,
            auth_plugin_name,
        })
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::{__bytes_builder, mariadb::connection::ConnContext, mariadb::protocol::decode::Decoder};
    use bytes::BytesMut;
    use crate::ConnectOptions;

    /// Decodes a hand-built initial handshake packet and checks the decoded
    /// fields against the bytes that were fed in, not just that decoding
    /// succeeds.
    #[test]
    fn it_decodes_initial_handshake_packet() -> Result<(), Error> {
        #[rustfmt::skip]
        let buf = __bytes_builder!(
            // int<3> length
            1u8, 0u8, 0u8,
            // int<1> seq_no
            0u8,
            // int<1> protocol version
            10u8,
            // string<NUL> server version (MariaDB server version is by default prefixed by "5.5.5-")
            b"5.5.5-10.4.6-MariaDB-1:10.4.6+maria~bionic\0",
            // int<4> connection id
            13u8, 0u8, 0u8, 0u8,
            // string<8> scramble 1st part (authentication seed)
            b"?~~|vZAu",
            // string<1> reserved byte
            0u8,
            // int<2> server capabilities (1st part)
            0xFEu8, 0xF7u8,
            // int<1> server default collation
            8u8,
            // int<2> status flags
            2u8, 0u8,
            // int<2> server capabilities (2nd part)
            0xFF_u8, 0x81_u8,
            // if (server_capabilities & PLUGIN_AUTH)
            //   int<1> plugin data length
            15u8,
            // else
            //   int<1> 0x00
            // string<6> filler
            0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
            // if (server_capabilities & CLIENT_MYSQL)
            //   string<4> filler
            // else
            //   int<4> server capabilities 3rd part . MariaDB specific flags /* MariaDB 10.2 or later */
            7u8, 0u8, 0u8, 0u8,
            // if (server_capabilities & CLIENT_SECURE_CONNECTION)
            //   string<n> scramble 2nd part . Length = max(12, plugin data length - 9)
            b"JQ8cihP4Q}Dx",
            //   string<1> reserved byte
            0u8,
            // if (server_capabilities & PLUGIN_AUTH)
            //   string<NUL> authentication plugin name
            b"mysql_native_password\0"
        );

        let mut context = ConnContext::new();
        let mut ctx = DeContext::new(&mut context, &buf);

        let message = InitialHandshakePacket::deserialize(&mut ctx)?;

        assert_eq!(message.seq_no, 0);
        assert_eq!(message.protocol_version, 10);
        assert_eq!(message.connection_id, 13);
        assert_eq!(message.collation, 8);
        assert_eq!(message.plugin_data_length, 15);
        assert_eq!(&message.auth_seed[..], &b"?~~|vZAu"[..]);
        // max(12, 15 - 9) == 12 bytes of second-part scramble.
        assert_eq!(
            message.scramble.as_ref().map(|part| &part[..]),
            Some(&b"JQ8cihP4Q}Dx"[..])
        );
        // Prefix checks only: they hold whether or not the decoder keeps the
        // trailing NUL of a null-terminated string.
        assert!(message
            .server_version
            .starts_with(&b"5.5.5-10.4.6-MariaDB-1:10.4.6+maria~bionic"[..]));
        assert!(message
            .auth_plugin_name
            .as_ref()
            .map_or(false, |name| name.starts_with(&b"mysql_native_password"[..])));

        Ok(())
    }
}