diff --git a/mason-mariadb/src/connection/establish.rs b/mason-mariadb/src/connection/establish.rs index 8ab0cf14..502407bc 100644 --- a/mason-mariadb/src/connection/establish.rs +++ b/mason-mariadb/src/connection/establish.rs @@ -11,7 +11,7 @@ pub async fn establish<'a, 'b: 'a>( conn: &'a mut Connection, _options: ConnectOptions<'b>, ) -> Result<(), Error> { - let init_packet = InitialHandshakePacket::deserialize(&conn.stream.next_bytes().await?)?; + let init_packet = InitialHandshakePacket::deserialize(&conn.stream.next_bytes().await?, None)?; conn.capabilities = init_packet.capabilities; diff --git a/mason-mariadb/src/connection/mod.rs b/mason-mariadb/src/connection/mod.rs index 6695799c..a014e772 100644 --- a/mason-mariadb/src/connection/mod.rs +++ b/mason-mariadb/src/connection/mod.rs @@ -92,7 +92,7 @@ impl Connection { self.send(ComPing()).await?; // Ping response must be an OkPacket - OkPacket::deserialize(&self.stream.next_bytes().await?)?; + OkPacket::deserialize(&self.stream.next_bytes().await?, None)?; Ok(()) } diff --git a/mason-mariadb/src/protocol/decode.rs b/mason-mariadb/src/protocol/decode.rs index 36c802f2..da53c406 100644 --- a/mason-mariadb/src/protocol/decode.rs +++ b/mason-mariadb/src/protocol/decode.rs @@ -3,142 +3,157 @@ use byteorder::{ByteOrder, LittleEndian}; use bytes::Bytes; use failure::{err_msg, Error}; -//pub struct Decoder<'a> { -// pub buf: &'a Bytes, -// pub index: usize, -//} - -#[inline] -pub fn decode_length(buf: &Bytes, index: &mut usize) -> Result { - let length = decode_int_3(&buf, index); - - if buf.len() < length as usize { - return Err(err_msg("Lengths to do not match")); - } - - Ok(length) +pub struct Decoder<'a> { + pub buf: &'a Bytes, + pub index: usize, } -#[inline] -pub fn decode_int_lenenc(buf: &Bytes, index: &mut usize) -> Option { - match buf[*index] { - 0xFB => { - *index += 1; - None +impl<'a> Decoder<'a> { + pub fn new(buf: &'a Bytes) -> Self { + Decoder { + buf, + index: 0, } - 0xFC => { - let value = Some(LittleEndian::read_u16(&buf[*index + 1..]) as usize); - *index += 3; - value + } + + #[inline] + pub fn decode_length(&mut self) -> Result { + let length = self.decode_int_3(); + + if self.buf.len() < length as usize { + return Err(err_msg("Lengths to do not match")); } - 0xFD => { - let value = Some(LittleEndian::read_u24(&buf[*index + 1..]) as usize); - *index += 4; - value + + Ok(length) + } + + pub fn skip_bytes(&mut self, amount: usize) { + self.index += amount; + } + + #[inline] + pub fn decode_int_lenenc(&mut self) -> Option { + match self.buf[self.index] { + 0xFB => { + self.index += 1; + None + } + 0xFC => { + let value = Some(LittleEndian::read_u16(&self.buf[self.index + 1..]) as usize); + self.index += 3; + value + } + 0xFD => { + let value = Some(LittleEndian::read_u24(&self.buf[self.index + 1..]) as usize); + self.index += 4; + value + } + 0xFE => { + let value = Some(LittleEndian::read_u64(&self.buf[self.index + 1..]) as usize); + self.index += 9; + value + } + 0xFF => panic!("int unprocessable first byte 0xFF"), + _ => { + let value = Some(self.buf[self.index] as usize); + self.index += 1; + value + } } - 0xFE => { - let value = Some(LittleEndian::read_u64(&buf[*index + 1..]) as usize); - *index += 9; - value - } - 0xFF => panic!("int unprocessable first byte 0xFF"), - _ => { - let value = Some(buf[*index] as usize); - *index += 1; - value + } + + #[inline] + pub fn decode_int_8(&mut self) -> u64 { + let value = LittleEndian::read_u64(&self.buf[self.index..]); + self.index += 8; + value + } + + #[inline] + pub fn decode_int_4(&mut self) -> u32 { + let value = LittleEndian::read_u32(&self.buf[self.index..]); + self.index += 4; + value + } + + #[inline] + pub fn decode_int_3(&mut self) -> u32 { + let value = LittleEndian::read_u24(&self.buf[self.index..]); + self.index += 3; + value + } + + #[inline] + pub fn decode_int_2(&mut self) -> u16 { + let value = LittleEndian::read_u16(&self.buf[self.index..]); + self.index += 2; + value + } + + #[inline] + pub fn decode_int_1(&mut self) -> u8 { + let value = self.buf[self.index]; + self.index += 1; + value + } + + #[inline] + pub fn decode_string_lenenc(&mut self) -> Bytes { + let length = self.decode_int_3(); + let value = Bytes::from(&self.buf[self.index..self.index + length as usize]); + self.index = self.index + length as usize; + value + } + + #[inline] + pub fn decode_string_fix(&mut self, length: u32) -> Bytes { + let value = Bytes::from(&self.buf[self.index..self.index + length as usize]); + self.index = self.index + length as usize; + value + } + + #[inline] + pub fn decode_string_eof(&mut self) -> Bytes { + let value = Bytes::from(&self.buf[self.index..]); + self.index = self.buf.len(); + value + } + + #[inline] + pub fn decode_string_null(&mut self) -> Result { + if let Some(null_index) = memchr::memchr(0, &self.buf[self.index..]) { + let value = Bytes::from(&self.buf[self.index..self.index + null_index]); + self.index = self.index + null_index + 1; + Ok(value) + } else { + Err(err_msg("Null index no found")) } } + + #[inline] + pub fn decode_byte_fix(&mut self, length: u32) -> Bytes { + let value = Bytes::from(&self.buf[self.index..self.index + length as usize]); + self.index = self.index + length as usize; + value + } + + #[inline] + pub fn decode_byte_lenenc(&mut self) -> Bytes { + let length = self.decode_int_3(); + let value = Bytes::from(&self.buf[self.index..self.index + length as usize]); + self.index = self.index + length as usize; + value + } + + #[inline] + pub fn decode_byte_eof(&mut self) -> Bytes { + let value = Bytes::from(&self.buf[self.index..]); + self.index = self.buf.len(); + value + } } -#[inline] -pub fn decode_int_8(buf: &Bytes, index: &mut usize) -> u64 { - let value = LittleEndian::read_u64(&buf[*index..]); - *index += 8; - value -} -#[inline] -pub fn decode_int_4(buf: &Bytes, index: &mut usize) -> u32 { - let value = LittleEndian::read_u32(&buf[*index..]); - *index += 4; - value -} - -#[inline] -pub fn decode_int_3(buf: &Bytes, index: &mut usize) -> u32 { - let value = LittleEndian::read_u24(&buf[*index..]); - *index += 3; - value -} - -#[inline] -pub fn decode_int_2(buf: &Bytes, index: &mut usize) -> u16 { - let value = LittleEndian::read_u16(&buf[*index..]); - *index += 2; - value -} - -#[inline] -pub fn decode_int_1(buf: &Bytes, index: &mut usize) -> u8 { - let value = buf[*index]; - *index += 1; - value -} - -#[inline] -pub fn decode_string_lenenc(buf: &Bytes, index: &mut usize) -> Bytes { - let length = decode_int_3(&buf, &mut *index); - let value = Bytes::from(&buf[*index..*index + length as usize]); - *index = *index + length as usize; - value -} - -#[inline] -pub fn decode_string_fix(buf: &Bytes, index: &mut usize, length: usize) -> Bytes { - let value = Bytes::from(&buf[*index..*index + length as usize]); - *index = *index + length as usize; - value -} - -#[inline] -pub fn decode_string_eof(buf: &Bytes, index: &mut usize) -> Bytes { - let value = Bytes::from(&buf[*index..]); - *index = buf.len(); - value -} - -#[inline] -pub fn decode_string_null(buf: &Bytes, index: &mut usize) -> Result { - if let Some(null_index) = memchr::memchr(0, &buf[*index..]) { - let value = Bytes::from(&buf[*index..*index + null_index]); - *index = *index + null_index + 1; - Ok(value) - } else { - Err(err_msg("Null index no found")) - } -} - -#[inline] -pub fn decode_byte_fix(buf: &Bytes, index: &mut usize, length: usize) -> Bytes { - let value = Bytes::from(&buf[*index..*index + length as usize]); - *index = *index + length as usize; - value -} - -#[inline] -pub fn decode_byte_lenenc(buf: &Bytes, index: &mut usize) -> Bytes { - let length = decode_int_3(&buf, &mut *index); - let value = Bytes::from(&buf[*index..*index + length as usize]); - *index = *index + length as usize; - value -} - -#[inline] -pub fn decode_byte_eof(buf: &Bytes, index: &mut usize) -> Bytes { - let value = Bytes::from(&buf[*index..]); - *index = buf.len(); - value -} #[cfg(test)] mod tests { @@ -146,24 +161,23 @@ mod tests { use bytes::{Bytes, BytesMut}; use failure::Error; - // [X] deserialize_int_lenenc - // [X] deserialize_int_8 - // [X] deserialize_int_4 - // [X] deserialize_int_3 - // [X] deserialize_int_2 - // [X] deserialize_int_1 - // [X] deserialize_string_lenenc - // [X] deserialize_string_fix - // [X] deserialize_string_eof - // [X] deserialize_string_null - // [X] deserialize_byte_lenenc - // [X] deserialize_byte_eof + // [X] it_decodes_int_lenenc + // [X] it_decodes_int_8 + // [X] it_decodes_int_4 + // [X] it_decodes_int_3 + // [X] it_decodes_int_2 + // [X] it_decodes_int_1 + // [X] it_decodes_string_lenenc + // [X] it_decodes_string_fix + // [X] it_decodes_string_eof + // [X] it_decodes_string_null + // [X] it_decodes_byte_lenenc + // [X] it_decodes_byte_eof #[test] fn it_decodes_int_lenenc_0x_fb() { - let buf: BytesMut = BytesMut::from(b"\xFB".to_vec()); - let mut index = 0; - let int: Option = decode_int_lenenc(&buf.freeze(), &mut index); + let mut decoder = Decoder::new(&BytesMut::from(b"\xFB".to_vec()).freeze()); + let int: Option = decoder.decode_int_lenenc(); assert_eq!(int, None); assert_eq!(index, 1); @@ -171,9 +185,8 @@ mod tests { #[test] fn it_decodes_int_lenenc_0x_fc() { - let buf = BytesMut::from(b"\xFC\x01\x01".to_vec()); - let mut index = 0; - let int: Option = decode_int_lenenc(&buf.freeze(), &mut index); + let mut decoder = Decoder::new(&BytesMut::from(b"\xFC\x01\x01".to_vec()).freeze()); + let int: Option = decoder.decode_int_lenenc(); assert_eq!(int, Some(257)); assert_eq!(index, 3); @@ -181,9 +194,8 @@ mod tests { #[test] fn it_decodes_int_lenenc_0x_fd() { - let buf = BytesMut::from(b"\xFD\x01\x01\x01".to_vec()); - let mut index = 0; - let int: Option = decode_int_lenenc(&buf.freeze(), &mut index); + let mut decoder = Decoder::new(&BytesMut::from(b"\xFD\x01\x01\x01".to_vec()).freeze()); + let int: Option = decoder.decode_int_lenenc(); assert_eq!(int, Some(65793)); assert_eq!(index, 4); @@ -191,9 +203,8 @@ mod tests { #[test] fn it_decodes_int_lenenc_0x_fe() { - let buf = BytesMut::from(b"\xFE\x01\x01\x01\x01\x01\x01\x01\x01".to_vec()); - let mut index = 0; - let int: Option = decode_int_lenenc(&buf.freeze(), &mut index); + let mut decoder = Decoder::new(&BytesMut::from(b"\xFE\x01\x01\x01\x01\x01\x01\x01\x01".to_vec()).freeze()); + let int: Option = decoder.decode_int_lenenc(); assert_eq!(int, Some(72340172838076673)); assert_eq!(index, 9); @@ -201,9 +212,8 @@ mod tests { #[test] fn it_decodes_int_lenenc_0x_fa() { - let buf = BytesMut::from(b"\xFA".to_vec()); - let mut index = 0; - let int: Option = decode_int_lenenc(&buf.freeze(), &mut index); + let mut decoder = Decoder::new(&BytesMut::from(b"\xFA".to_vec()).freeze()); + let int: Option = decoder.decode_int_lenenc(); assert_eq!(int, Some(0xfA)); assert_eq!(index, 1); @@ -211,9 +221,8 @@ mod tests { #[test] fn it_decodes_int_8() { - let buf = BytesMut::from(b"\x01\x01\x01\x01\x01\x01\x01\x01".to_vec()); - let mut index = 0; - let int: u64 = decode_int_8(&buf.freeze(), &mut index); + let mut decoder = Decoder::new(&BytesMut::from(b"\x01\x01\x01\x01\x01\x01\x01\x01".to_vec()).freeze()); + let int: u64 = decoder.decode_int_8(); assert_eq!(int, 72340172838076673); assert_eq!(index, 8); @@ -221,9 +230,8 @@ mod tests { #[test] fn it_decodes_int_4() { - let buf = BytesMut::from(b"\x01\x01\x01\x01".to_vec()); - let mut index = 0; - let int: u32 = decode_int_4(&buf.freeze(), &mut index); + let mut decoder = Decoder::new(&BytesMut::from(b"\x01\x01\x01\x01".to_vec()).freeze()); + let int: u32 = decoder.decode_int_4(); assert_eq!(int, 16843009); assert_eq!(index, 4); @@ -231,9 +239,8 @@ mod tests { #[test] fn it_decodes_int_3() { - let buf = BytesMut::from(b"\x01\x01\x01".to_vec()); - let mut index = 0; - let int: u32 = decode_int_3(&buf.freeze(), &mut index); + let mut decoder = Decoder::new(&BytesMut::from(b"\x01\x01\x01".to_vec()).freeze()); + let int: u32 = decoder.decode_int_3(); assert_eq!(int, 65793); assert_eq!(index, 3); @@ -241,9 +248,8 @@ mod tests { #[test] fn it_decodes_int_2() { - let buf = BytesMut::from(b"\x01\x01".to_vec()); - let mut index = 0; - let int: u16 = decode_int_2(&buf.freeze(), &mut index); + let mut decoder = Decoder::new(&BytesMut::from(b"\x01\x01".to_vec()).freeze()); + let int: u16 = decoder.decode_int_2(); assert_eq!(int, 257); assert_eq!(index, 2); @@ -251,9 +257,8 @@ mod tests { #[test] fn it_decodes_int_1() { - let buf = BytesMut::from(b"\x01".to_vec()); - let mut index = 0; - let int: u8 = decode_int_1(&buf.freeze(), &mut index); + let mut decoder = Decoder::new(&BytesMut::from(b"\x01".to_vec()).freeze()); + let int: u8 = decoder.decode_int_1(); assert_eq!(int, 1); assert_eq!(index, 1); @@ -261,9 +266,8 @@ mod tests { #[test] fn it_decodes_string_lenenc() { - let buf = BytesMut::from(b"\x01\x00\x00\x01".to_vec()); - let mut index = 0; - let string: Bytes = decode_string_lenenc(&buf.freeze(), &mut index); + let mut decoder = Decoder::new(&BytesMut::from(b"\x01\x00\x00\x01".to_vec()).freeze()); + let string: Bytes = decoder.decode_string_lenenc(); assert_eq!(string[0], b'\x01'); assert_eq!(string.len(), 1); @@ -272,9 +276,8 @@ mod tests { #[test] fn it_decodes_string_fix() { - let buf = BytesMut::from(b"\x01".to_vec()); - let mut index = 0; - let string: Bytes = decode_string_fix(&buf.freeze(), &mut index, 1); + let mut decoder = Decoder::new(&BytesMut::from(b"\x01".to_vec()).freeze()); + let string: Bytes = decoder.decode_string_fix(1); assert_eq!(string[0], b'\x01'); assert_eq!(string.len(), 1); @@ -283,9 +286,8 @@ mod tests { #[test] fn it_decodes_string_eof() { - let buf = BytesMut::from(b"\x01".to_vec()); - let mut index = 0; - let string: Bytes = decode_string_eof(&buf.freeze(), &mut index); + let mut decoder = Decoder::new(&BytesMut::from(b"\x01".to_vec()).freeze()); + let string: Bytes = decoder.decode_string_eof(); assert_eq!(string[0], b'\x01'); assert_eq!(string.len(), 1); @@ -294,9 +296,8 @@ mod tests { #[test] fn it_decodes_string_null() -> Result<(), Error> { - let buf = BytesMut::from(b"random\x00\x01".to_vec()); - let mut index = 0; - let string: Bytes = decode_string_null(&buf.freeze(), &mut index)?; + let mut decoder = Decoder::new(&BytesMut::from(b"random\x00\x01".to_vec()).freeze()); + let string: Bytes = decoder.decode_string_null()?; assert_eq!(&string[..], b"random"); @@ -309,9 +310,8 @@ mod tests { #[test] fn it_decodes_byte_fix() { - let buf = BytesMut::from(b"\x01".to_vec()); - let mut index = 0; - let string: Bytes = decode_byte_fix(&buf.freeze(), &mut index, 1); + let mut decoder = Decoder::new(&BytesMut::from(b"\x01".to_vec()).freeze()); + let string: Bytes = decoder.decode_byte_fix(1); assert_eq!(string[0], b'\x01'); assert_eq!(string.len(), 1); @@ -320,9 +320,8 @@ mod tests { #[test] fn it_decodes_byte_eof() { - let buf = BytesMut::from(b"\x01".to_vec()); - let mut index = 0; - let string: Bytes = decode_byte_eof(&buf.freeze(), &mut index); + let mut decoder = Decoder::new(&BytesMut::from(b"\x01".to_vec()).freeze()); + let string: Bytes = decoder.decode_byte_eof(); assert_eq!(string[0], b'\x01'); assert_eq!(string.len(), 1); diff --git a/mason-mariadb/src/protocol/server.rs b/mason-mariadb/src/protocol/server.rs index 64497c63..9c13af91 100644 --- a/mason-mariadb/src/protocol/server.rs +++ b/mason-mariadb/src/protocol/server.rs @@ -7,7 +7,7 @@ use failure::{err_msg, Error}; use std::convert::TryFrom; pub trait Deserialize: Sized { - fn deserialize(buf: &Bytes) -> Result; + fn deserialize<'a: 'b, 'b>(buf: &Bytes, decoder: Option<&mut Decoder>) -> Result; } #[derive(Debug)] @@ -141,6 +141,12 @@ impl Default for ServerStatusFlag { } } +impl Default for FieldDetailFlag { + fn default() -> Self { + FieldDetailFlag::NOT_NULL + } +} + impl Default for FieldType { fn default() -> Self { FieldType::MysqlTypeDecimal @@ -197,6 +203,7 @@ pub struct ColumnPacket { pub columns: Option, } +#[derive(Debug, Default)] pub struct ColumnDefPacket { pub length: u32, pub seq_no: u8, @@ -236,74 +243,73 @@ impl Message { let tag = buf[4]; Ok(Some(match tag { - 0xFF => Message::ErrPacket(ErrPacket::deserialize(&buf)?), - 0x00 | 0xFE => Message::OkPacket(OkPacket::deserialize(&buf)?), + 0xFF => Message::ErrPacket(ErrPacket::deserialize(&buf, None)?), + 0x00 | 0xFE => Message::OkPacket(OkPacket::deserialize(&buf, None)?), _ => unimplemented!(), })) } } impl Deserialize for InitialHandshakePacket { - fn deserialize(buf: &Bytes) -> Result { - let mut index = 0; - - let length = decode_length(&buf, &mut index)?; - let seq_no = decode_int_1(&buf, &mut index); + fn deserialize<'a: 'b, 'b>(buf: &'a Bytes, decoder: Option<&mut Decoder<'b>>) -> Result { + let mut decoder: &mut Decoder = decoder.unwrap_or(&mut Decoder::new(&buf)); + let length = decoder.decode_length()?; + let seq_no = decoder.decode_int_1(); if seq_no != 0 { - return Err(err_msg("Squence Number of Initial Handshake Packet is not 0")); + return Err(err_msg("Sequence Number of Initial Handshake Packet is not 0")); } - let protocol_version = decode_int_1(&buf, &mut index); - let server_version = decode_string_null(&buf, &mut index)?; - let connection_id = decode_int_4(&buf, &mut index); - let auth_seed = decode_string_fix(&buf, &mut index, 8); + let protocol_version = decoder.decode_int_1(); + let server_version = decoder.decode_string_null()?; + let connection_id = decoder.decode_int_4(); + let auth_seed = decoder.decode_string_fix(8); // Skip reserved byte - index += 1; + decoder.skip_bytes(1); let mut capabilities = - Capabilities::from_bits_truncate(decode_int_2(&buf, &mut index).into()); + Capabilities::from_bits_truncate(decoder.decode_int_2().into()); - let collation = decode_int_1(&buf, &mut index); + let collation = decoder.decode_int_1(); let status = - ServerStatusFlag::from_bits_truncate(decode_int_2(&buf, &mut index).into()); + ServerStatusFlag::from_bits_truncate(decoder.decode_int_2().into()); capabilities |= Capabilities::from_bits_truncate( - ((decode_int_2(&buf, &mut index) as u32) << 16).into(), + ((decoder.decode_int_2() as u32) << 16).into(), ); let mut plugin_data_length = 0; if !(capabilities & Capabilities::PLUGIN_AUTH).is_empty() { - plugin_data_length = decode_int_1(&buf, &mut index); + plugin_data_length = decoder.decode_int_1(); } else { // Skip reserve byte - index += 1; + decoder.skip_bytes(1); } // Skip filler - index += 6; + decoder.skip_bytes(6); if (capabilities & Capabilities::CLIENT_MYSQL).is_empty() { capabilities |= Capabilities::from_bits_truncate( - ((decode_int_4(&buf, &mut index) as u128) << 32).into(), + ((decoder.decode_int_4() as u128) << 32).into(), ); } else { // Skip filler - index += 4; + decoder.skip_bytes(4); } let mut scramble: Option = None; if !(capabilities & Capabilities::SECURE_CONNECTION).is_empty() { let len = std::cmp::max(12, plugin_data_length as usize - 9); - scramble = Some(decode_string_fix(&buf, &mut index, len)); + scramble = Some(decoder.decode_string_fix(len as u32)); // Skip reserve byte - index += 1; + decoder.skip_bytes(1); } let mut auth_plugin_name: Option = None; if !(capabilities & Capabilities::PLUGIN_AUTH).is_empty() { - auth_plugin_name = Some(decode_string_null(&buf, &mut index)?); + auth_plugin_name = Some(decoder.decode_string_null()?); } Ok(InitialHandshakePacket { @@ -324,30 +330,30 @@ impl Deserialize for InitialHandshakePacket { } impl Deserialize for OkPacket { - fn deserialize(buf: &Bytes) -> Result { - let mut index = 0; + fn deserialize(buf: &Bytes, decoder: Option<&mut Decoder>) -> Result { + let mut decoder = decoder.unwrap_or(&mut Decoder::new(&buf)); // Packet header - let length = decode_length(&buf, &mut index)?; - let seq_no = decode_int_1(&buf, &mut index); + let length = decoder.decode_length()?; + let seq_no = decoder.decode_int_1(); // Packet body - let packet_header = decode_int_1(&buf, &mut index); + let packet_header = decoder.decode_int_1(); if packet_header != 0 && packet_header != 0xFE { panic!("Packet header is not 0 or 0xFE for OkPacket"); } - let affected_rows = decode_int_lenenc(&buf, &mut index); - let last_insert_id = decode_int_lenenc(&buf, &mut index); + let affected_rows = decoder.decode_int_lenenc(); + let last_insert_id = decoder.decode_int_lenenc(); let server_status = - ServerStatusFlag::from_bits_truncate(decode_int_2(&buf, &mut index).into()); - let warning_count = decode_int_2(&buf, &mut index); + ServerStatusFlag::from_bits_truncate(decoder.decode_int_2().into()); + let warning_count = decoder.decode_int_2(); // Assuming CLIENT_SESSION_TRACK is unsupported let session_state_info = None; let value = None; - let info = Bytes::from(&buf[index..]); + let info = decoder.decode_byte_eof(); Ok(OkPacket { length, @@ -364,18 +370,18 @@ impl Deserialize for OkPacket { } impl Deserialize for ErrPacket { - fn deserialize(buf: &Bytes) -> Result { - let mut index = 0; + fn deserialize(buf: &Bytes, decoder: Option<&mut Decoder>) -> Result { + let mut decoder = decoder.unwrap_or(&mut Decoder::new(&buf)); - let length = decode_length(&buf, &mut index)?; - let seq_no = decode_int_1(&buf, &mut index); + let length = decoder.decode_length()?; + let seq_no = decoder.decode_int_1(); - let packet_header = decode_int_1(&buf, &mut index); + let packet_header = decoder.decode_int_1(); if packet_header != 0xFF { panic!("Packet header is not 0xFF for ErrPacket"); } - let error_code = ErrorCode::try_from(decode_int_2(&buf, &mut index))?; + let error_code = ErrorCode::try_from(decoder.decode_int_2())?; let mut stage = None; let mut max_stage = None; @@ -388,17 +394,17 @@ impl Deserialize for ErrPacket { // Progress Reporting if error_code as u16 == 0xFFFF { - stage = Some(decode_int_1(buf, &mut index)); - max_stage = Some(decode_int_1(buf, &mut index)); - progress = Some(decode_int_3(buf, &mut index)); - progress_info = Some(decode_string_lenenc(&buf, &mut index)); + stage = Some(decoder.decode_int_1()); + max_stage = Some(decoder.decode_int_1()); + progress = Some(decoder.decode_int_3()); + progress_info = Some(decoder.decode_string_lenenc()); } else { - if buf[index] == b'#' { - sql_state_marker = Some(decode_string_fix(buf, &mut index, 1)); - sql_state = Some(decode_string_fix(buf, &mut index, 5)); - error_message = Some(decode_string_eof(buf, &mut index)); + if buf[decoder.index] == b'#' { + sql_state_marker = Some(decoder.decode_string_fix(1)); + sql_state = Some(decoder.decode_string_fix(5)); + error_message = Some(decoder.decode_string_eof()); } else { - error_message = Some(decode_string_eof(buf, &mut index)); + error_message = Some(decoder.decode_string_eof()); } } @@ -418,12 +424,12 @@ impl Deserialize for ErrPacket { } impl Deserialize for ColumnPacket { - fn deserialize(buf: &Bytes) -> Result { - let mut index = 0; + fn deserialize(buf: &Bytes, decoder: Option<&mut Decoder>) -> Result { + let mut decoder = decoder.unwrap_or(&mut Decoder::new(&buf)); - let length = decode_length(&buf, &mut index)?; - let seq_no = decode_int_1(&buf, &mut index); - let columns = decode_int_lenenc(&buf, &mut index); + let length = decoder.decode_length()?; + let seq_no = decoder.decode_int_1(); + let columns = decoder.decode_int_lenenc(); Ok(ColumnPacket { length, @@ -434,24 +440,24 @@ impl Deserialize for ColumnPacket { } impl Deserialize for ColumnDefPacket { - fn deserialize(buf: &Bytes) -> Result { - let mut index = 0; + fn deserialize(buf: &Bytes, decoder: Option<&mut Decoder>) -> Result { + let mut decoder = decoder.unwrap_or(&mut Decoder::new(&buf)); - let length = decode_length(&buf, &mut index)?; - let seq_no = decode_int_1(&buf, &mut index); + let length = decoder.decode_length()?; + let seq_no = decoder.decode_int_1(); - let catalog = decode_string_lenenc(&buf, &mut index); - let schema = decode_string_lenenc(&buf, &mut index); - let table_alias = decode_string_lenenc(&buf, &mut index); - let table = decode_string_lenenc(&buf, &mut index); - let column_alias = decode_string_lenenc(&buf, &mut index); - let column = decode_string_lenenc(&buf, &mut index); - let length_of_fixed_fields = decode_int_lenenc(&buf, &mut index); - let char_set = decode_int_2(&buf, &mut index); - let max_columns = decode_int_4(&buf, &mut index); - let field_type = FieldType::try_from(decode_int_1(&buf, &mut index))?; - let field_details = FieldDetailFlag::from_bits_truncate(decode_int_2(&buf, &mut index)); - let decimals = decode_int_1(&buf, &mut index); + let catalog = decoder.decode_string_lenenc(); + let schema = decoder.decode_string_lenenc(); + let table_alias = decoder.decode_string_lenenc(); + let table = decoder.decode_string_lenenc(); + let column_alias = decoder.decode_string_lenenc(); + let column = decoder.decode_string_lenenc(); + let length_of_fixed_fields = decoder.decode_int_lenenc(); + let char_set = decoder.decode_int_2(); + let max_columns = decoder.decode_int_4(); + let field_type = FieldType::try_from(decoder.decode_int_1())?; + let field_details = FieldDetailFlag::from_bits_truncate(decoder.decode_int_2()); + let decimals = decoder.decode_int_1(); // Skip last two unused bytes // index += 2;