diff --git a/src/pg/connection/execute.rs b/src/pg/connection/execute.rs index fff88791..9c24c65a 100644 --- a/src/pg/connection/execute.rs +++ b/src/pg/connection/execute.rs @@ -12,7 +12,7 @@ pub async fn execute(conn: &mut PgRawConnection) -> io::Result { Message::BindComplete | Message::ParseComplete | Message::DataRow(_) => {} Message::CommandComplete(body) => { - rows = body.rows(); + rows = body.rows; } Message::ReadyForQuery(_) => { diff --git a/src/pg/protocol/authentication.rs b/src/pg/protocol/authentication.rs index cdc51b7d..ca3dcb64 100644 --- a/src/pg/protocol/authentication.rs +++ b/src/pg/protocol/authentication.rs @@ -26,21 +26,21 @@ pub enum Authentication { Sspi, /// This message contains GSSAPI or SSPI data. - GssContinue { data: Bytes }, + GssContinue { data: Vec }, /// SASL authentication is required. // FIXME: authentication mechanisms Sasl, /// This message contains a SASL challenge. - SaslContinue { data: Bytes }, + SaslContinue { data: Vec }, /// SASL authentication has completed. - SaslFinal { data: Bytes }, + SaslFinal { data: Vec }, } impl Decode for Authentication { - fn decode(src: Bytes) -> io::Result { + fn decode(src: &[u8]) -> io::Result { Ok(match src[0] { 0 => Authentication::Ok, 2 => Authentication::KerberosV5, diff --git a/src/pg/protocol/backend_key_data.rs b/src/pg/protocol/backend_key_data.rs index 4a1389a2..f9fbf768 100644 --- a/src/pg/protocol/backend_key_data.rs +++ b/src/pg/protocol/backend_key_data.rs @@ -23,15 +23,17 @@ impl BackendKeyData { } } -impl Decode for BackendKeyData { - fn decode(src: Bytes) -> io::Result { - let process_id = u32::from_be_bytes(src.as_ref()[0..4].try_into().unwrap()); - let secret_key = u32::from_be_bytes(src.as_ref()[4..8].try_into().unwrap()); +impl BackendKeyData { + pub fn decode2(src: &[u8]) -> Self { + // todo: error handling + assert_eq!(src.len(), 8); + let process_id = u32::from_be_bytes(src[0..4].try_into().unwrap()); + let secret_key = u32::from_be_bytes(src[4..8].try_into().unwrap()); - Ok(Self { + Self { process_id, secret_key, - }) + } } } @@ -45,8 +47,8 @@ mod tests { #[test] fn it_decodes_backend_key_data() -> io::Result<()> { - let src = Bytes::from_static(BACKEND_KEY_DATA); - let message = BackendKeyData::decode(src)?; + let src = BACKEND_KEY_DATA; + let message = BackendKeyData::decode2(src); assert_eq!(message.process_id(), 10182); assert_eq!(message.secret_key(), 2303903019); diff --git a/src/pg/protocol/command_complete.rs b/src/pg/protocol/command_complete.rs index 4b1e034f..a90bf21b 100644 --- a/src/pg/protocol/command_complete.rs +++ b/src/pg/protocol/command_complete.rs @@ -5,32 +5,23 @@ use std::{io, str}; #[derive(Debug)] pub struct CommandComplete { - tag: Bytes, + pub rows: u64, } impl CommandComplete { - #[inline] - pub fn tag(&self) -> &str { - unsafe { str::from_utf8_unchecked(&self.tag.as_ref()[..self.tag.len() - 1]) } - } - - pub fn rows(&self) -> u64 { + pub fn decode2(src: &[u8]) -> Self { // Attempt to parse the last word in the command tag as an integer // If it can't be parased, the tag is probably "CREATE TABLE" or something // and we should return 0 rows - let rows_start = memrchr(b' ', &*self.tag).unwrap_or(0); - let rows_s = unsafe { - str::from_utf8_unchecked(&self.tag.as_ref()[(rows_start + 1)..(self.tag.len() - 1)]) + let rows_start = memrchr(b' ',src).unwrap_or(0); + let rows = unsafe { + str::from_utf8_unchecked(&src[(rows_start + 1)..(src.len() - 1)]) }; - rows_s.parse().unwrap_or(0) - } -} - -impl Decode for CommandComplete { - fn decode(src: Bytes) -> io::Result { - Ok(Self { tag: src }) + Self { + rows: rows.parse().unwrap_or(0) + } } } @@ -47,44 +38,36 @@ mod tests { #[test] fn it_decodes_command_complete_for_insert() -> io::Result<()> { - let src = Bytes::from_static(COMMAND_COMPLETE_INSERT); - let message = CommandComplete::decode(src)?; + let message = CommandComplete::decode2(COMMAND_COMPLETE_INSERT); - assert_eq!(message.tag(), "INSERT 0 1"); - assert_eq!(message.rows(), 1); + assert_eq!(message.rows, 1); Ok(()) } #[test] fn it_decodes_command_complete_for_update() -> io::Result<()> { - let src = Bytes::from_static(COMMAND_COMPLETE_UPDATE); - let message = CommandComplete::decode(src)?; + let message = CommandComplete::decode2(COMMAND_COMPLETE_UPDATE); - assert_eq!(message.tag(), "UPDATE 512"); - assert_eq!(message.rows(), 512); + assert_eq!(message.rows, 512); Ok(()) } #[test] fn it_decodes_command_complete_for_begin() -> io::Result<()> { - let src = Bytes::from_static(COMMAND_COMPLETE_BEGIN); - let message = CommandComplete::decode(src)?; + let message = CommandComplete::decode2(COMMAND_COMPLETE_BEGIN); - assert_eq!(message.tag(), "BEGIN"); - assert_eq!(message.rows(), 0); + assert_eq!(message.rows, 0); Ok(()) } #[test] fn it_decodes_command_complete_for_create_table() -> io::Result<()> { - let src = Bytes::from_static(COMMAND_COMPLETE_CREATE_TABLE); - let message = CommandComplete::decode(src)?; + let message = CommandComplete::decode2(COMMAND_COMPLETE_CREATE_TABLE); - assert_eq!(message.tag(), "CREATE TABLE"); - assert_eq!(message.rows(), 0); + assert_eq!(message.rows, 0); Ok(()) } diff --git a/src/pg/protocol/decode.rs b/src/pg/protocol/decode.rs index 94c697b9..124851a7 100644 --- a/src/pg/protocol/decode.rs +++ b/src/pg/protocol/decode.rs @@ -3,7 +3,7 @@ use memchr::memchr; use std::{io, str}; pub trait Decode { - fn decode(src: Bytes) -> io::Result + fn decode(src: &[u8]) -> io::Result where Self: Sized; } diff --git a/src/pg/protocol/message.rs b/src/pg/protocol/message.rs index 1c3f7db4..f0f1a79c 100644 --- a/src/pg/protocol/message.rs +++ b/src/pg/protocol/message.rs @@ -54,38 +54,28 @@ impl Message { return Ok(None); } - let mut old = false; + let src_ = &src.as_ref()[5..(len + 1)]; let message = match token { - b'D' => Message::DataRow(DataRow::decode2(&src.as_ref()[5..(len + 1)])), - - token => { - let src = src.split_to(len + 1).freeze().slice_from(5); - old = true; - - match token { - b'N' | b'E' => Message::Response(Box::new(Response::decode(src)?)), - b'S' => Message::ParameterStatus(ParameterStatus::decode(src)?), - b'Z' => Message::ReadyForQuery(ReadyForQuery::decode(src)?), - b'R' => Message::Authentication(Authentication::decode(src)?), - b'K' => Message::BackendKeyData(BackendKeyData::decode(src)?), - b'T' => Message::RowDescription(RowDescription::decode(src)?), - b'C' => Message::CommandComplete(CommandComplete::decode(src)?), - b'A' => Message::NotificationResponse(NotificationResponse::decode(src)?), - b'1' => Message::ParseComplete, - b'2' => Message::BindComplete, - b'3' => Message::CloseComplete, - b'n' => Message::NoData, - b's' => Message::PortalSuspended, - b't' => Message::ParameterDescription(ParameterDescription::decode(src)?), - _ => unimplemented!("decode not implemented for token: {}", token as char), - } - }, + b'N' | b'E' => Message::Response(Box::new(Response::decode(src_)?)), + b'D' => Message::DataRow(DataRow::decode2(src_)), + b'S' => Message::ParameterStatus(ParameterStatus::decode(src_)?), + b'Z' => Message::ReadyForQuery(ReadyForQuery::decode(src_)?), + b'R' => Message::Authentication(Authentication::decode(src_)?), + b'K' => Message::BackendKeyData(BackendKeyData::decode2(src_)), + b'T' => Message::RowDescription(RowDescription::decode(src_)?), + b'C' => Message::CommandComplete(CommandComplete::decode2(src_)), + b'A' => Message::NotificationResponse(NotificationResponse::decode(src_)?), + b'1' => Message::ParseComplete, + b'2' => Message::BindComplete, + b'3' => Message::CloseComplete, + b'n' => Message::NoData, + b's' => Message::PortalSuspended, + b't' => Message::ParameterDescription(ParameterDescription::decode2(src_)), + _ => unimplemented!("decode not implemented for token: {}", token as char), }; - if !old { - src.advance(len + 1); - } + src.advance(len + 1); Ok(Some(message)) } diff --git a/src/pg/protocol/notification_response.rs b/src/pg/protocol/notification_response.rs index 9027ff11..96645486 100644 --- a/src/pg/protocol/notification_response.rs +++ b/src/pg/protocol/notification_response.rs @@ -6,7 +6,7 @@ use std::{fmt, io, pin::Pin, ptr::NonNull}; pub struct NotificationResponse { #[used] - storage: Pin, + storage: Pin>, pid: u32, channel_name: NonNull, message: NonNull, @@ -46,15 +46,16 @@ impl fmt::Debug for NotificationResponse { } impl Decode for NotificationResponse { - fn decode(src: Bytes) -> io::Result { - let storage = Pin::new(src); - let pid = BigEndian::read_u32(&*storage); + fn decode(src: &[u8]) -> io::Result { + let pid = BigEndian::read_u32(&src); // offset from pid=4 - let channel_name = get_str(&storage[4..])?; + let storage = Pin::new(Vec::from(&src[4..])); - // offset = pid + channel_name.len() + \0 - let message = get_str(&storage[(4 + channel_name.len() + 1)..])?; + let channel_name = get_str(&storage)?; + + // offset = channel_name.len() + \0 + let message = get_str(&storage[(channel_name.len() + 1)..])?; let channel_name = NonNull::from(channel_name); let message = NonNull::from(message); @@ -78,8 +79,7 @@ mod tests { #[test] fn it_decodes_notification_response() -> io::Result<()> { - let src = Bytes::from_static(NOTIFICATION_RESPONSE); - let message = NotificationResponse::decode(src)?; + let message = NotificationResponse::decode(NOTIFICATION_RESPONSE)?; assert_eq!(message.pid(), 0x34201002); assert_eq!(message.channel_name(), "TEST-CHANNEL"); diff --git a/src/pg/protocol/parameter_description.rs b/src/pg/protocol/parameter_description.rs index 810bf7ef..478316f8 100644 --- a/src/pg/protocol/parameter_description.rs +++ b/src/pg/protocol/parameter_description.rs @@ -11,8 +11,8 @@ pub struct ParameterDescription { ids: Vec, } -impl Decode for ParameterDescription { - fn decode(src: Bytes) -> io::Result { +impl ParameterDescription { + pub fn decode2(src: &[u8]) -> Self { let count = BigEndian::read_u16(&*src) as usize; // todo: error handling @@ -24,7 +24,7 @@ impl Decode for ParameterDescription { ids.push(BigEndian::read_u32(&src[offset..])); } - Ok(ParameterDescription { ids }) + ParameterDescription { ids } } } @@ -36,8 +36,8 @@ mod test { #[test] fn it_decodes_parameter_description() -> io::Result<()> { - let src = Bytes::from_static(b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00"); - let desc = ParameterDescription::decode(src)?; + let src = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00"; + let desc = ParameterDescription::decode2(src); assert_eq!(desc.ids.len(), 2); assert_eq!(desc.ids[0], 0x0000_0000); @@ -47,8 +47,8 @@ mod test { #[test] fn it_decodes_empty_parameter_description() -> io::Result<()> { - let src = Bytes::from_static(b"\x00\x00"); - let desc = ParameterDescription::decode(src)?; + let src = b"\x00\x00"; + let desc = ParameterDescription::decode2(src); assert_eq!(desc.ids.len(), 0); Ok(()) @@ -57,7 +57,7 @@ mod test { #[test] #[should_panic] fn parameter_description_wrong_length_fails() -> () { - let src = Bytes::from_static(b"\x00\x00\x00\x01\x02\x03"); - ParameterDescription::decode(src).unwrap(); + let src = b"\x00\x00\x00\x01\x02\x03"; + ParameterDescription::decode2(src); } } diff --git a/src/pg/protocol/parameter_status.rs b/src/pg/protocol/parameter_status.rs index 65c2556a..c35c1fc1 100644 --- a/src/pg/protocol/parameter_status.rs +++ b/src/pg/protocol/parameter_status.rs @@ -5,8 +5,8 @@ use std::{io, str}; // FIXME: Use &str functions for a custom Debug #[derive(Debug)] pub struct ParameterStatus { - name: Bytes, - value: Bytes, + name: Vec, + value: Vec, } impl ParameterStatus { @@ -22,12 +22,12 @@ impl ParameterStatus { } impl Decode for ParameterStatus { - fn decode(src: Bytes) -> io::Result { + fn decode(src: &[u8]) -> io::Result { let name_end = memchr::memchr(0, &src).unwrap(); let value_end = memchr::memchr(0, &src[(name_end + 1)..]).unwrap(); - let name = src.slice_to(name_end); - let value = src.slice(name_end + 1, name_end + 1 + value_end); + let name = src[..name_end].into(); + let value = src[(name_end + 1)..(name_end + 1 + value_end)].into(); Ok(Self { name, value }) } @@ -43,8 +43,7 @@ mod tests { #[test] fn it_decodes_param_status() -> io::Result<()> { - let src = Bytes::from_static(PARAM_STATUS); - let message = ParameterStatus::decode(src)?; + let message = ParameterStatus::decode(PARAM_STATUS)?; assert_eq!(message.name(), "session_authorization"); assert_eq!(message.value(), "postgres"); diff --git a/src/pg/protocol/ready_for_query.rs b/src/pg/protocol/ready_for_query.rs index 2e759868..bfdd58e1 100644 --- a/src/pg/protocol/ready_for_query.rs +++ b/src/pg/protocol/ready_for_query.rs @@ -22,7 +22,7 @@ pub struct ReadyForQuery { } impl Decode for ReadyForQuery { - fn decode(src: Bytes) -> io::Result { + fn decode(src: &[u8]) -> io::Result { if src.len() != 1 { return Err(io::ErrorKind::InvalidInput)?; } @@ -50,8 +50,7 @@ mod tests { #[test] fn it_decodes_ready_for_query() -> io::Result<()> { - let src = Bytes::from_static(READY_FOR_QUERY); - let message = ReadyForQuery::decode(src)?; + let message = ReadyForQuery::decode(READY_FOR_QUERY)?; assert_eq!(message.status, TransactionStatus::Error); diff --git a/src/pg/protocol/response.rs b/src/pg/protocol/response.rs index 41dbe002..8bdc19ad 100644 --- a/src/pg/protocol/response.rs +++ b/src/pg/protocol/response.rs @@ -77,7 +77,7 @@ impl FromStr for Severity { #[derive(Clone)] pub struct Response { #[used] - storage: Pin, + storage: Pin>, severity: Severity, code: NonNull, message: NonNull, @@ -226,8 +226,8 @@ impl fmt::Debug for Response { } impl Decode for Response { - fn decode(src: Bytes) -> io::Result { - let storage = Pin::new(src); + fn decode(src: &[u8]) -> io::Result { + let storage = Pin::new(Vec::from(src)); let mut code = None::<&str>; let mut message = None::<&str>; @@ -407,8 +407,7 @@ mod tests { #[test] fn it_decodes_response() -> io::Result<()> { - let src = Bytes::from_static(RESPONSE); - let message = Response::decode(src)?; + let message = Response::decode(RESPONSE)?; assert_eq!(message.severity(), Severity::Notice); assert_eq!( diff --git a/src/pg/protocol/row_description.rs b/src/pg/protocol/row_description.rs index 28030702..19eb86fa 100644 --- a/src/pg/protocol/row_description.rs +++ b/src/pg/protocol/row_description.rs @@ -64,25 +64,25 @@ impl<'a> FieldDescription<'a> { pub struct RowDescription { // The number of fields in a row (can be zero). len: u16, - data: Bytes, + data: Vec, } impl RowDescription { pub fn fields(&self) -> FieldDescriptions<'_> { FieldDescriptions { rem: self.len, - buf: &*self.data, + buf: &self.data, } } } impl Decode for RowDescription { - fn decode(src: Bytes) -> io::Result { + fn decode(src: &[u8]) -> io::Result { let len = BigEndian::read_u16(&src[..2]); Ok(Self { len, - data: src.slice_from(2), + data: src[2..].into(), }) } } @@ -153,8 +153,7 @@ mod tests { #[test] fn it_decodes_row_description() -> io::Result<()> { - let src = Bytes::from_static(ROW_DESC); - let message = RowDescription::decode(src)?; + let message = RowDescription::decode(ROW_DESC)?; assert_eq!(message.fields().len(), 3); for field in message.fields() {