From d401f63a3ebe85fecda24392c72e43b69074d005 Mon Sep 17 00:00:00 2001 From: Daniel Akhterov Date: Mon, 17 Jun 2019 16:04:51 -0700 Subject: [PATCH] Add okpacket test --- mason-mariadb/src/protocol/server.rs | 51 ++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 3 deletions(-) diff --git a/mason-mariadb/src/protocol/server.rs b/mason-mariadb/src/protocol/server.rs index 1fabda734..d61dad2c1 100644 --- a/mason-mariadb/src/protocol/server.rs +++ b/mason-mariadb/src/protocol/server.rs @@ -133,6 +133,12 @@ impl Default for Capabilities { } } +impl Default for ServerStatusFlag { + fn default() -> Self { + ServerStatusFlag::SERVER_STATUS_IN_TRANS + } +} + #[derive(Default, Debug)] pub struct InitialHandshakePacket { pub length: u32, @@ -153,7 +159,7 @@ pub struct InitialHandshakePacket { pub struct OkPacket { pub affected_rows: Option, pub last_insert_id: Option, - pub server_status: u16, + pub server_status: ServerStatusFlag, pub warning_count: u16, pub info: Bytes, pub session_state_info: Option, @@ -267,10 +273,24 @@ impl Deserialize for InitialHandshakePacket { impl Deserialize for OkPacket { fn deserialize(buf: &mut Vec) -> Result { - let mut index = 1; + let mut index = 0; + + let length = deserialize_int_3(&buf, &mut index); + + if buf.len() != length as usize { + return Err(err_msg("Lengths to do not match")); + } + + let _sequence_number = deserialize_int_1(&buf, &mut index); + + let packet_header = deserialize_int_1(&buf, &mut index); + if packet_header != 0 { + panic!("Packet header is not 0 for OkPacket"); + } + let affected_rows = deserialize_int_lenenc(&buf, &mut index); let last_insert_id = deserialize_int_lenenc(&buf, &mut index); - let server_status = deserialize_int_2(&buf, &mut index); + let server_status = ServerStatusFlag::from_bits(deserialize_int_2(&buf, &mut index).into()).unwrap(); let warning_count = deserialize_int_2(&buf, &mut index); // Assuming CLIENT_SESSION_TRACK is unsupported @@ -383,4 +403,29 @@ mod test { Ok(()) } + + #[test] + fn it_decodes_okpacket() -> Result<(), Error> { + let mut buf = b"\ + \x0F\x00\x00\ + \x01\ + \x00\ + \xFB\ + \xFB\ + \x01\x01\ + \x00\x00\ + info\ + " + .to_vec(); + + let message = OkPacket::deserialize(&mut buf)?; + + assert_eq!(message.affected_rows, None); + assert_eq!(message.last_insert_id, None); + assert!(!(message.server_status & ServerStatusFlag::SERVER_STATUS_IN_TRANS).is_empty()); + assert_eq!(message.warning_count, 0); + assert_eq!(message.info, b"info".to_vec()); + + Ok(()) + } }