From 80143b945a9e01aed2b22873ec52ab2d5a6b618d Mon Sep 17 00:00:00 2001 From: Daniel Akhterov Date: Sun, 23 Jun 2019 21:38:01 -0700 Subject: [PATCH] WIP: Use bytes/bytesmut instead of vec --- mason-mariadb/src/protocol/deserialize.rs | 92 +++++++++++------------ mason-mariadb/src/protocol/server.rs | 80 +++++++++++++------- 2 files changed, 99 insertions(+), 73 deletions(-) diff --git a/mason-mariadb/src/protocol/deserialize.rs b/mason-mariadb/src/protocol/deserialize.rs index 16853d03..561bd006 100644 --- a/mason-mariadb/src/protocol/deserialize.rs +++ b/mason-mariadb/src/protocol/deserialize.rs @@ -5,7 +5,7 @@ use failure::Error; use failure::err_msg; #[inline] -pub fn deserialize_length(buf: &Vec, index: &mut usize) -> Result { +pub fn deserialize_length(buf: &Bytes, index: &mut usize) -> Result { let length = deserialize_int_3(&buf, index); if buf.len() < length as usize { @@ -16,7 +16,7 @@ pub fn deserialize_length(buf: &Vec, index: &mut usize) -> Result, index: &mut usize) -> Option { +pub fn deserialize_int_lenenc(buf: &Bytes, index: &mut usize) -> Option { match buf[*index] { 0xFB => { *index += 1; @@ -47,42 +47,42 @@ pub fn deserialize_int_lenenc(buf: &Vec, index: &mut usize) -> Option } #[inline] -pub fn deserialize_int_8(buf: &Vec, index: &mut usize) -> u64 { +pub fn deserialize_int_8(buf: &Bytes, index: &mut usize) -> u64 { let value = LittleEndian::read_u64(&buf[*index..]); *index += 8; value } #[inline] -pub fn deserialize_int_4(buf: &Vec, index: &mut usize) -> u32 { +pub fn deserialize_int_4(buf: &Bytes, index: &mut usize) -> u32 { let value = LittleEndian::read_u32(&buf[*index..]); *index += 4; value } #[inline] -pub fn deserialize_int_3(buf: &Vec, index: &mut usize) -> u32 { +pub fn deserialize_int_3(buf: &Bytes, index: &mut usize) -> u32 { let value = LittleEndian::read_u24(&buf[*index..]); *index += 3; value } #[inline] -pub fn deserialize_int_2(buf: &Vec, index: &mut usize) -> u16 { +pub fn deserialize_int_2(buf: &Bytes, index: &mut usize) -> u16 { let value = LittleEndian::read_u16(&buf[*index..]); *index += 2; value } #[inline] -pub fn deserialize_int_1(buf: &Vec, index: &mut usize) -> u8 { +pub fn deserialize_int_1(buf: &Bytes, index: &mut usize) -> u8 { let value = buf[*index]; *index += 1; value } #[inline] -pub fn deserialize_string_lenenc(buf: &Vec, index: &mut usize) -> Bytes { +pub fn deserialize_string_lenenc(buf: &Bytes, index: &mut usize) -> Bytes { let length = deserialize_int_3(&buf, &mut *index); let value = Bytes::from(&buf[*index..*index + length as usize]); *index = *index + length as usize; @@ -90,21 +90,21 @@ pub fn deserialize_string_lenenc(buf: &Vec, index: &mut usize) -> Bytes { } #[inline] -pub fn deserialize_string_fix(buf: &Vec, index: &mut usize, length: usize) -> Bytes { +pub fn deserialize_string_fix(buf: &Bytes, index: &mut usize, length: usize) -> Bytes { let value = Bytes::from(&buf[*index..*index + length as usize]); *index = *index + length as usize; value } #[inline] -pub fn deserialize_string_eof(buf: &Vec, index: &mut usize) -> Bytes { +pub fn deserialize_string_eof(buf: &Bytes, index: &mut usize) -> Bytes { let value = Bytes::from(&buf[*index..]); *index = buf.len(); value } #[inline] -pub fn deserialize_string_null(buf: &Vec, index: &mut usize) -> Bytes { +pub fn deserialize_string_null(buf: &Bytes, index: &mut usize) -> Bytes { let null_index = memchr::memchr(0, &buf[*index..]).unwrap(); let value = Bytes::from(&buf[*index..*index + null_index]); *index = *index + null_index + 1; @@ -112,14 +112,14 @@ pub fn deserialize_string_null(buf: &Vec, index: &mut usize) -> Bytes { } #[inline] -pub fn deserialize_byte_fix(buf: &Vec, index: &mut usize, length: usize) -> Bytes { +pub fn deserialize_byte_fix(buf: &Bytes, index: &mut usize, length: usize) -> Bytes { let value = Bytes::from(&buf[*index..*index + length as usize]); *index = *index + length as usize; value } #[inline] -pub fn deserialize_byte_lenenc(buf: &Vec, index: &mut usize) -> Bytes { +pub fn deserialize_byte_lenenc(buf: &Bytes, index: &mut usize) -> Bytes { let length = deserialize_int_3(&buf, &mut *index); let value = Bytes::from(&buf[*index..*index + length as usize]); *index = *index + length as usize; @@ -127,7 +127,7 @@ pub fn deserialize_byte_lenenc(buf: &Vec, index: &mut usize) -> Bytes { } #[inline] -pub fn deserialize_byte_eof(buf: &Vec, index: &mut usize) -> Bytes { +pub fn deserialize_byte_eof(buf: &Bytes, index: &mut usize) -> Bytes { let value = Bytes::from(&buf[*index..]); *index = buf.len(); value @@ -154,9 +154,9 @@ mod tests { #[test] fn it_decodes_int_lenenc_0x_fb() { - let mut buf: Vec = b"\xFB".to_vec(); + let buf: BytesMut = BytesMut::from(b"\xFB".to_vec()); let mut index = 0; - let int: Option = deserialize_int_lenenc(&buf, &mut index); + let int: Option = deserialize_int_lenenc(&buf.freeze(), &mut index); assert_eq!(int, None); assert_eq!(index, 1); @@ -164,9 +164,9 @@ mod tests { #[test] fn it_decodes_int_lenenc_0x_fc() { - let mut buf = b"\xFC\x01\x01".to_vec(); + let buf = BytesMut::from(b"\xFC\x01\x01".to_vec()); let mut index = 0; - let int: Option = deserialize_int_lenenc(&buf, &mut index); + let int: Option = deserialize_int_lenenc(&buf.freeze(), &mut index); assert_eq!(int, Some(257)); assert_eq!(index, 3); @@ -174,9 +174,9 @@ mod tests { #[test] fn it_decodes_int_lenenc_0x_fd() { - let mut buf = b"\xFD\x01\x01\x01".to_vec(); + let buf = BytesMut::from(b"\xFD\x01\x01\x01".to_vec()); let mut index = 0; - let int: Option = deserialize_int_lenenc(&buf, &mut index); + let int: Option = deserialize_int_lenenc(&buf.freeze(), &mut index); assert_eq!(int, Some(65793)); assert_eq!(index, 4); @@ -184,9 +184,9 @@ mod tests { #[test] fn it_decodes_int_lenenc_0x_fe() { - let mut buf = b"\xFE\x01\x01\x01\x01\x01\x01\x01\x01".to_vec(); + let buf = BytesMut::from(b"\xFE\x01\x01\x01\x01\x01\x01\x01\x01".to_vec()); let mut index = 0; - let int: Option = deserialize_int_lenenc(&buf, &mut index); + let int: Option = deserialize_int_lenenc(&buf.freeze(), &mut index); assert_eq!(int, Some(72340172838076673)); assert_eq!(index, 9); @@ -194,9 +194,9 @@ mod tests { #[test] fn it_decodes_int_lenenc_0x_fa() { - let mut buf = b"\xFA\x01".to_vec(); + let buf = BytesMut::from(b"\xFA\x01".to_vec()); let mut index = 0; - let int: Option = deserialize_int_lenenc(&buf, &mut index); + let int: Option = deserialize_int_lenenc(&buf.freeze(), &mut index); assert_eq!(int, Some(1)); assert_eq!(index, 2); @@ -204,9 +204,9 @@ mod tests { #[test] fn it_decodes_int_8() { - let mut buf = b"\x01\x01\x01\x01\x01\x01\x01\x01".to_vec(); + let buf = BytesMut::from(b"\x01\x01\x01\x01\x01\x01\x01\x01".to_vec()); let mut index = 0; - let int: u64 = deserialize_int_8(&buf, &mut index); + let int: u64 = deserialize_int_8(&buf.freeze(), &mut index); assert_eq!(int, 72340172838076673); assert_eq!(index, 8); @@ -214,9 +214,9 @@ mod tests { #[test] fn it_decodes_int_4() { - let mut buf = b"\x01\x01\x01\x01".to_vec(); + let buf = BytesMut::from(b"\x01\x01\x01\x01".to_vec()); let mut index = 0; - let int: u32 = deserialize_int_4(&buf, &mut index); + let int: u32 = deserialize_int_4(&buf.freeze(), &mut index); assert_eq!(int, 16843009); assert_eq!(index, 4); @@ -224,9 +224,9 @@ mod tests { #[test] fn it_decodes_int_3() { - let mut buf = b"\x01\x01\x01".to_vec(); + let buf = BytesMut::from(b"\x01\x01\x01".to_vec()); let mut index = 0; - let int: u32 = deserialize_int_3(&buf, &mut index); + let int: u32 = deserialize_int_3(&buf.freeze(), &mut index); assert_eq!(int, 65793); assert_eq!(index, 3); @@ -234,9 +234,9 @@ mod tests { #[test] fn it_decodes_int_2() { - let mut buf = b"\x01\x01".to_vec(); + let buf = BytesMut::from(b"\x01\x01".to_vec()); let mut index = 0; - let int: u16 = deserialize_int_2(&buf, &mut index); + let int: u16 = deserialize_int_2(&buf.freeze(), &mut index); assert_eq!(int, 257); assert_eq!(index, 2); @@ -244,9 +244,9 @@ mod tests { #[test] fn it_decodes_int_1() { - let mut buf = &b"\x01".to_vec(); + let buf = BytesMut::from(b"\x01".to_vec()); let mut index = 0; - let int: u8 = deserialize_int_1(&buf, &mut index); + let int: u8 = deserialize_int_1(&buf.freeze(), &mut index); assert_eq!(int, 1); assert_eq!(index, 1); @@ -254,9 +254,9 @@ mod tests { #[test] fn it_decodes_string_lenenc() { - let mut buf = &b"\x01\x00\x00\x01".to_vec(); + let buf = BytesMut::from(b"\x01\x00\x00\x01".to_vec()); let mut index = 0; - let string: Bytes = deserialize_string_lenenc(&buf, &mut index); + let string: Bytes = deserialize_string_lenenc(&buf.freeze(), &mut index); assert_eq!(string[0], b'\x01'); assert_eq!(string.len(), 1); @@ -265,9 +265,9 @@ mod tests { #[test] fn it_decodes_string_fix() { - let mut buf = &b"\x01".to_vec(); + let buf = BytesMut::from(b"\x01".to_vec()); let mut index = 0; - let string: Bytes = deserialize_string_fix(&buf, &mut index, 1); + let string: Bytes = deserialize_string_fix(&buf.freeze(), &mut index, 1); assert_eq!(string[0], b'\x01'); assert_eq!(string.len(), 1); @@ -276,9 +276,9 @@ mod tests { #[test] fn it_decodes_string_eof() { - let mut buf = &b"\x01".to_vec(); + let buf = BytesMut::from(b"\x01".to_vec()); let mut index = 0; - let string: Bytes = deserialize_string_eof(&buf, &mut index); + let string: Bytes = deserialize_string_eof(&buf.freeze(), &mut index); assert_eq!(string[0], b'\x01'); assert_eq!(string.len(), 1); @@ -287,9 +287,9 @@ mod tests { #[test] fn it_decodes_string_null() { - let mut buf = &b"random\x00\x01".to_vec(); + let buf = BytesMut::from(b"random\x00\x01".to_vec()); let mut index = 0; - let string: Bytes = deserialize_string_null(&buf, &mut index); + let string: Bytes = deserialize_string_null(&buf.freeze(), &mut index); assert_eq!(string[0], b'r'); assert_eq!(string[1], b'a'); @@ -304,9 +304,9 @@ mod tests { #[test] fn it_decodes_byte_fix() { - let mut buf = &b"\x01".to_vec(); + let buf = BytesMut::from(b"\x01".to_vec()); let mut index = 0; - let string: Bytes = deserialize_byte_fix(&buf, &mut index, 1); + let string: Bytes = deserialize_byte_fix(&buf.freeze(), &mut index, 1); assert_eq!(string[0], b'\x01'); assert_eq!(string.len(), 1); @@ -315,9 +315,9 @@ mod tests { #[test] fn it_decodes_byte_eof() { - let mut buf = &b"\x01".to_vec(); + let buf = BytesMut::from(b"\x01".to_vec()); let mut index = 0; - let string: Bytes = deserialize_byte_eof(&buf, &mut index); + let string: Bytes = deserialize_byte_eof(&buf.freeze(), &mut index); assert_eq!(string[0], b'\x01'); assert_eq!(string.len(), 1); diff --git a/mason-mariadb/src/protocol/server.rs b/mason-mariadb/src/protocol/server.rs index 0009d3e8..be104f68 100644 --- a/mason-mariadb/src/protocol/server.rs +++ b/mason-mariadb/src/protocol/server.rs @@ -3,9 +3,10 @@ use crate::protocol::deserialize::*; use bytes::{Bytes, BytesMut}; use failure::{err_msg, Error}; +use byteorder::{LittleEndian, ByteOrder}; pub trait Deserialize: Sized { - fn deserialize(buf: &mut Vec) -> Result; + fn deserialize(buf: &Bytes) -> Result; } #[derive(Debug)] @@ -179,12 +180,38 @@ pub struct ErrPacket { impl Message { pub fn deserialize(buf: &mut BytesMut) -> Result, Error> { - // let length = deserialize_int_3(buf, & - // let sequence_number = buf[3]; - Ok(None) + if buf.len() < 4 { + return Ok(None); + } + + let mut index = 0_usize; + let length = LittleEndian::read_u24(&buf[0..]) as usize; + if buf.len() < length + 4 { + return Ok(None); + } + + let buf = buf.split_to(length + 1).freeze(); + let serial_number = deserialize_int_1(&buf, &mut index); + let tag = deserialize_int_1(&buf, &mut index); + + Ok(Some(match tag { + 0xFF => { + Message::ErrPacket(ErrPacket::deserialize(&buf)?) + } + 0x00 => { + Message::OkPacket(OkPacket::deserialize(&buf)?) + } + _ => { + unimplemented!() + } + })) } pub fn init(buf: &mut BytesMut) -> Result, Error> { - match InitialHandshakePacket::deserialize(&mut buf.to_vec()) { + let length = LittleEndian::read_u24(&buf[0..]) as usize; + if buf.len() < length + 4 { + return Ok(None); + } + match InitialHandshakePacket::deserialize(&buf.split_to(length + 1).freeze()) { Ok(v) => Ok(Some(Message::InitialHandshakePacket(v))), Err(_) => Ok(None), } @@ -192,7 +219,7 @@ impl Message { } impl Deserialize for InitialHandshakePacket { - fn deserialize(buf: &mut Vec) -> Result { + fn deserialize(buf: &Bytes) -> Result { let mut index = 0; let length = deserialize_length(&buf, &mut index)?; @@ -270,7 +297,7 @@ impl Deserialize for InitialHandshakePacket { } impl Deserialize for OkPacket { - fn deserialize(buf: &mut Vec) -> Result { + fn deserialize(buf: &Bytes) -> Result { let mut index = 0; let length = deserialize_length(&buf, &mut index)?; @@ -305,7 +332,7 @@ impl Deserialize for OkPacket { } impl Deserialize for ErrPacket { - fn deserialize(buf: &mut Vec) -> Result { + fn deserialize(buf: &Bytes) -> Result { let mut index = 0; let length = deserialize_length(&buf, &mut index)?; @@ -329,17 +356,17 @@ impl Deserialize for ErrPacket { // Progress Reporting if error_code == 0xFFFF { - stage = Some(deserialize_int_1(&buf, &mut index)); - max_stage = Some(deserialize_int_1(&buf, &mut index)); - progress = Some(deserialize_int_3(&buf, &mut index)); + stage = Some(deserialize_int_1(buf, &mut index)); + max_stage = Some(deserialize_int_1(buf, &mut index)); + progress = Some(deserialize_int_3(buf, &mut index)); progress_info = Some(deserialize_string_lenenc(&buf, &mut index)); } else { if buf[index] == b'#' { - sql_state_marker = Some(deserialize_string_fix(&buf, &mut index, 1)); - sql_state = Some(deserialize_string_fix(&buf, &mut index, 5)); - error_message = Some(deserialize_string_eof(&buf, &mut index)); + sql_state_marker = Some(deserialize_string_fix(buf, &mut index, 1)); + sql_state = Some(deserialize_string_fix(buf, &mut index, 5)); + error_message = Some(deserialize_string_eof(buf, &mut index)); } else { - error_message = Some(deserialize_string_eof(&buf, &mut index)); + error_message = Some(deserialize_string_eof(buf, &mut index)); } } @@ -362,14 +389,14 @@ mod test { #[test] fn it_decodes_capabilities() { - let buf = b"\xfe\xf7".to_vec(); + let buf = BytesMut::from(b"\xfe\xf7".to_vec()); let mut index = 0; - Capabilities::from_bits_truncate(deserialize_int_2(&buf, &mut index).into()); + Capabilities::from_bits_truncate(deserialize_int_2(&buf.freeze(), &mut index).into()); } #[test] fn it_decodes_initialhandshakepacket() -> Result<(), Error> { - let mut buf = b"\ + let buf = BytesMut::from(b"\ n\0\0\ \0\ \n\ @@ -386,16 +413,16 @@ mod test { \x07\0\0\0\ JQ8cihP4Q}Dx\ \0\ - mysql_native_password\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0".to_vec(); + mysql_native_password\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0".to_vec()); - let _message = InitialHandshakePacket::deserialize(&mut buf)?; + let _message = InitialHandshakePacket::deserialize(&buf.freeze())?; Ok(()) } #[test] fn it_decodes_okpacket() -> Result<(), Error> { - let mut buf = b"\ + let buf = BytesMut::from(b"\ \x0F\x00\x00\ \x01\ \x00\ @@ -404,10 +431,9 @@ mod test { \x01\x01\ \x00\x00\ info\ - " - .to_vec(); + ".to_vec()); - let message = OkPacket::deserialize(&mut buf)?; + let message = OkPacket::deserialize(&buf.freeze())?; assert_eq!(message.affected_rows, None); assert_eq!(message.last_insert_id, None); @@ -420,7 +446,7 @@ mod test { #[test] fn it_decodes_errpacket() -> Result<(), Error> { - let mut buf = b"\ + let buf = BytesMut::from(b"\ \x0F\x00\x00\ \x01\ \xFF\ @@ -429,9 +455,9 @@ mod test { HY000\ NO\ " - .to_vec(); + .to_vec()); - let message = ErrPacket::deserialize(&mut buf)?; + let message = ErrPacket::deserialize(&buf.freeze())?; assert_eq!(message.error_code, 1002); assert_eq!(message.sql_state_marker, Some(Bytes::from(b"#".to_vec())));