From 69cc962d7bb62190cccdcb7d11c45013b380bc95 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Mon, 19 Aug 2019 22:10:28 -0700 Subject: [PATCH] Clean up Decode some more, Box some fields in Message --- benches/postgres_protocol_decode.rs | 36 ++++++++-- src/pg/connection/establish.rs | 41 ++++++----- src/pg/protocol/authentication.rs | 12 ++-- src/pg/protocol/backend_key_data.rs | 18 ++--- src/pg/protocol/command_complete.rs | 37 ++++------ src/pg/protocol/data_row.rs | 61 ++++------------- src/pg/protocol/decode.rs | 9 ++- src/pg/protocol/message.rs | 87 ++++++++++++++---------- src/pg/protocol/notification_response.rs | 15 ++-- src/pg/protocol/parameter_description.rs | 33 +++------ src/pg/protocol/parameter_status.rs | 25 ++++--- src/pg/protocol/ready_for_query.rs | 28 ++++---- src/pg/protocol/response.rs | 47 +++++-------- src/pg/protocol/row_description.rs | 15 ++-- src/pg/row.rs | 2 +- 15 files changed, 214 insertions(+), 252 deletions(-) diff --git a/benches/postgres_protocol_decode.rs b/benches/postgres_protocol_decode.rs index fad0d4d3..5432f83d 100644 --- a/benches/postgres_protocol_decode.rs +++ b/benches/postgres_protocol_decode.rs @@ -34,12 +34,36 @@ fn bench(c: &mut Criterion, name: &'static str, input: &'static [u8]) { } fn criterion_benchmark(c: &mut Criterion) { - bench(c, "postgres - decode - Message - DataRow (x 1000)", MESSAGE_DATA_ROW); - bench(c, "postgres - decode - Message - ReadyForQuery (x 1000)", MESSAGE_READY_FOR_QUERY); - bench(c, "postgres - decode - Message - CommandComplete (x 1000)", MESSAGE_COMMAND_COMPLETE); - bench(c, "postgres - decode - Message - Response (x 1000)", MESSAGE_RESPONSE); - bench(c, "postgres - decode - Message - BackendKeyData (x 1000)", MESSAGE_BACKEND_KEY_DATA); - bench(c, "postgres - decode - Message - ParameterStatus (x 1000)", MESSAGE_PARAMETER_STATUS); + bench( + c, + "postgres - decode - Message - DataRow (x 1000)", + MESSAGE_DATA_ROW, + ); + bench( + c, + "postgres - decode - Message - ReadyForQuery (x 1000)", + MESSAGE_READY_FOR_QUERY, + ); + bench( + c, + "postgres - decode - Message - CommandComplete (x 1000)", + MESSAGE_COMMAND_COMPLETE, + ); + bench( + c, + "postgres - decode - Message - Response (x 1000)", + MESSAGE_RESPONSE, + ); + bench( + c, + "postgres - decode - Message - BackendKeyData (x 1000)", + MESSAGE_BACKEND_KEY_DATA, + ); + bench( + c, + "postgres - decode - Message - ParameterStatus (x 1000)", + MESSAGE_PARAMETER_STATUS, + ); } criterion_group!(benches, criterion_benchmark); diff --git a/src/pg/connection/establish.rs b/src/pg/connection/establish.rs index b26e29f4..56ef9490 100644 --- a/src/pg/connection/establish.rs +++ b/src/pg/connection/establish.rs @@ -36,27 +36,34 @@ pub async fn establish<'a, 'b: 'a>(conn: &'a mut PgRawConnection, url: &'b Url) while let Some(message) = conn.receive().await? { match message { - Message::Authentication(Authentication::Ok) => { - // Do nothing; server is just telling us that - // there is no password needed - } + Message::Authentication(auth) => { + match *auth { + Authentication::Ok => { + // Do nothing. No password is needed to continue. + } - Message::Authentication(Authentication::CleartextPassword) => { - // FIXME: Should error early (before send) if the user did not supply a password - conn.write(PasswordMessage::Cleartext(password)); + Authentication::CleartextPassword => { + // FIXME: Should error early (before send) if the user did not supply a password + conn.write(PasswordMessage::Cleartext(password)); - conn.flush().await?; - } + conn.flush().await?; + } - Message::Authentication(Authentication::Md5Password { salt }) => { - // FIXME: Should error early (before send) if the user did not supply a password - conn.write(PasswordMessage::Md5 { - password, - user, - salt, - }); + Authentication::Md5Password { salt } => { + // FIXME: Should error early (before send) if the user did not supply a password + conn.write(PasswordMessage::Md5 { + password, + user, + salt, + }); - conn.flush().await?; + conn.flush().await?; + } + + auth => { + unimplemented!("received {:?} unimplemented authentication message", auth); + } + } } Message::BackendKeyData(body) => { diff --git a/src/pg/protocol/authentication.rs b/src/pg/protocol/authentication.rs index ca3dcb64..84b69ba6 100644 --- a/src/pg/protocol/authentication.rs +++ b/src/pg/protocol/authentication.rs @@ -26,22 +26,22 @@ pub enum Authentication { Sspi, /// This message contains GSSAPI or SSPI data. - GssContinue { data: Vec }, + GssContinue { data: Box<[u8]> }, /// SASL authentication is required. // FIXME: authentication mechanisms Sasl, /// This message contains a SASL challenge. - SaslContinue { data: Vec }, + SaslContinue { data: Box<[u8]> }, /// SASL authentication has completed. - SaslFinal { data: Vec }, + SaslFinal { data: Box<[u8]> }, } impl Decode for Authentication { - fn decode(src: &[u8]) -> io::Result { - Ok(match src[0] { + fn decode(src: &[u8]) -> Self { + match src[0] { 0 => Authentication::Ok, 2 => Authentication::KerberosV5, 3 => Authentication::CleartextPassword, @@ -58,6 +58,6 @@ impl Decode for Authentication { 9 => Authentication::Sspi, token => unimplemented!("decode not implemented for token: {}", token), - }) + } } } diff --git a/src/pg/protocol/backend_key_data.rs b/src/pg/protocol/backend_key_data.rs index f9fbf768..0e074eb3 100644 --- a/src/pg/protocol/backend_key_data.rs +++ b/src/pg/protocol/backend_key_data.rs @@ -23,12 +23,10 @@ impl BackendKeyData { } } -impl BackendKeyData { - pub fn decode2(src: &[u8]) -> Self { - // todo: error handling - assert_eq!(src.len(), 8); - let process_id = u32::from_be_bytes(src[0..4].try_into().unwrap()); - let secret_key = u32::from_be_bytes(src[4..8].try_into().unwrap()); +impl Decode for BackendKeyData { + fn decode(src: &[u8]) -> Self { + let process_id = u32::from_be_bytes(src[..4].try_into().unwrap()); + let secret_key = u32::from_be_bytes(src[4..].try_into().unwrap()); Self { process_id, @@ -41,18 +39,14 @@ impl BackendKeyData { mod tests { use super::{BackendKeyData, Decode}; use bytes::Bytes; - use std::io; const BACKEND_KEY_DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+"; #[test] - fn it_decodes_backend_key_data() -> io::Result<()> { - let src = BACKEND_KEY_DATA; - let message = BackendKeyData::decode2(src); + fn it_decodes_backend_key_data() { + let message = BackendKeyData::decode(BACKEND_KEY_DATA); assert_eq!(message.process_id(), 10182); assert_eq!(message.secret_key(), 2303903019); - - Ok(()) } } diff --git a/src/pg/protocol/command_complete.rs b/src/pg/protocol/command_complete.rs index a90bf21b..39a8d475 100644 --- a/src/pg/protocol/command_complete.rs +++ b/src/pg/protocol/command_complete.rs @@ -8,19 +8,17 @@ pub struct CommandComplete { pub rows: u64, } -impl CommandComplete { - pub fn decode2(src: &[u8]) -> Self { +impl Decode for CommandComplete { + fn decode(src: &[u8]) -> Self { // Attempt to parse the last word in the command tag as an integer // If it can't be parased, the tag is probably "CREATE TABLE" or something // and we should return 0 rows - let rows_start = memrchr(b' ',src).unwrap_or(0); - let rows = unsafe { - str::from_utf8_unchecked(&src[(rows_start + 1)..(src.len() - 1)]) - }; + let rows_start = memrchr(b' ', src).unwrap_or(0); + let rows = unsafe { str::from_utf8_unchecked(&src[(rows_start + 1)..(src.len() - 1)]) }; Self { - rows: rows.parse().unwrap_or(0) + rows: rows.parse().unwrap_or(0), } } } @@ -29,7 +27,6 @@ impl CommandComplete { mod tests { use super::{CommandComplete, Decode}; use bytes::Bytes; - use std::io; const COMMAND_COMPLETE_INSERT: &[u8] = b"INSERT 0 1\0"; const COMMAND_COMPLETE_UPDATE: &[u8] = b"UPDATE 512\0"; @@ -37,38 +34,30 @@ mod tests { const COMMAND_COMPLETE_BEGIN: &[u8] = b"BEGIN\0"; #[test] - fn it_decodes_command_complete_for_insert() -> io::Result<()> { - let message = CommandComplete::decode2(COMMAND_COMPLETE_INSERT); + fn it_decodes_command_complete_for_insert() { + let message = CommandComplete::decode(COMMAND_COMPLETE_INSERT); assert_eq!(message.rows, 1); - - Ok(()) } #[test] - fn it_decodes_command_complete_for_update() -> io::Result<()> { - let message = CommandComplete::decode2(COMMAND_COMPLETE_UPDATE); + fn it_decodes_command_complete_for_update() { + let message = CommandComplete::decode(COMMAND_COMPLETE_UPDATE); assert_eq!(message.rows, 512); - - Ok(()) } #[test] - fn it_decodes_command_complete_for_begin() -> io::Result<()> { - let message = CommandComplete::decode2(COMMAND_COMPLETE_BEGIN); + fn it_decodes_command_complete_for_begin() { + let message = CommandComplete::decode(COMMAND_COMPLETE_BEGIN); assert_eq!(message.rows, 0); - - Ok(()) } #[test] - fn it_decodes_command_complete_for_create_table() -> io::Result<()> { - let message = CommandComplete::decode2(COMMAND_COMPLETE_CREATE_TABLE); + fn it_decodes_command_complete_for_create_table() { + let message = CommandComplete::decode(COMMAND_COMPLETE_CREATE_TABLE); assert_eq!(message.rows, 0); - - Ok(()) } } diff --git a/src/pg/protocol/data_row.rs b/src/pg/protocol/data_row.rs index 7317900a..6c28cac6 100644 --- a/src/pg/protocol/data_row.rs +++ b/src/pg/protocol/data_row.rs @@ -11,21 +11,19 @@ use std::{ pub struct DataRow { #[used] - buffer: Pin>, - values: Vec>>, + buffer: Pin>, + values: Box<[Option>]>, } // SAFE: Raw pointers point to pinned memory inside the struct unsafe impl Send for DataRow {} unsafe impl Sync for DataRow {} -impl DataRow { - pub fn decode2(buf: &[u8]) -> Self { - let buffer = Pin::new(Vec::from(buf)); - let buf: &[u8] = &*buffer; +impl Decode for DataRow { + fn decode(buf: &[u8]) -> Self { + let buffer: Pin> = Pin::new(buf.into()); - // TODO: Handle unwrap - let len_b: [u8; 2] = buf[..2].try_into().unwrap(); + let len_b: [u8; 2] = buffer[..2].try_into().unwrap(); let len = u16::from_be_bytes(len_b) as usize; let mut values = Vec::with_capacity(len); @@ -36,7 +34,7 @@ impl DataRow { // Can be zero. As a special case, -1 indicates a NULL column value. // No value bytes follow in the NULL case. // TODO: Handle unwrap - let value_len_b: [u8; 4] = buf[index..(index + 4)].try_into().unwrap(); + let value_len_b: [u8; 4] = buffer[index..(index + 4)].try_into().unwrap(); let value_len = i32::from_be_bytes(value_len_b); index += 4; @@ -44,53 +42,20 @@ impl DataRow { values.push(None); } else { let value_len = value_len as usize; - let value = &buf[index..(index + value_len)]; + let value = &buffer[index..(index + value_len)]; index += value_len as usize; values.push(Some(value.into())); } } - Self { values, buffer } + Self { + values: values.into_boxed_slice(), + buffer, + } } } -// impl Decode for DataRow { -// fn decode(src: Bytes) -> io::Result { -// let buffer = Pin::new(src); -// let buf: &[u8] = &*buffer.as_ref(); - -// // TODO: Handle unwrap -// let len_b: [u8; 2] = buf[..2].try_into().unwrap(); -// let len = u16::from_be_bytes(len_b) as usize; - -// let mut values = Vec::with_capacity(len); -// let mut index = 2; - -// while values.len() < len { -// // The length of the column value, in bytes (this count does not include itself). -// // Can be zero. As a special case, -1 indicates a NULL column value. -// // No value bytes follow in the NULL case. -// // TODO: Handle unwrap -// let value_len_b: [u8; 4] = buf[index..(index + 4)].try_into().unwrap(); -// let value_len = i32::from_be_bytes(value_len_b); -// index += 4; - -// if value_len == -1 { -// values.push(None); -// } else { -// let value_len = value_len as usize; -// let value = &buf[index..(index + value_len)]; -// index += value_len as usize; - -// values.push(Some(value.into())); -// } -// } - -// Ok(Self { values, buffer }) -// } -// } - impl DataRow { #[inline] pub fn is_empty(&self) -> bool { @@ -126,7 +91,7 @@ mod tests { #[test] fn it_decodes_data_row() { - let message = DataRow::decode2(DATA_ROW); + let message = DataRow::decode(DATA_ROW); assert_eq!(message.len(), 3); diff --git a/src/pg/protocol/decode.rs b/src/pg/protocol/decode.rs index 124851a7..51d2005b 100644 --- a/src/pg/protocol/decode.rs +++ b/src/pg/protocol/decode.rs @@ -3,16 +3,15 @@ use memchr::memchr; use std::{io, str}; pub trait Decode { - fn decode(src: &[u8]) -> io::Result + fn decode(src: &[u8]) -> Self where Self: Sized; } #[inline] -pub(crate) fn get_str(src: &[u8]) -> io::Result<&str> { - let end = memchr(b'\0', &src).ok_or(io::ErrorKind::UnexpectedEof)?; +pub(crate) fn get_str(src: &[u8]) -> &str { + let end = memchr(b'\0', &src).expect("expected null terminator in UTF-8 string"); let buf = &src[..end]; - let s = unsafe { str::from_utf8_unchecked(buf) }; - Ok(s) + unsafe { str::from_utf8_unchecked(buf) } } diff --git a/src/pg/protocol/message.rs b/src/pg/protocol/message.rs index f0f1a79c..e3e02338 100644 --- a/src/pg/protocol/message.rs +++ b/src/pg/protocol/message.rs @@ -9,23 +9,37 @@ use std::io; #[derive(Debug)] #[repr(u8)] pub enum Message { - Authentication(Authentication), - ParameterStatus(ParameterStatus), + Authentication(Box), + ParameterStatus(Box), BackendKeyData(BackendKeyData), ReadyForQuery(ReadyForQuery), CommandComplete(CommandComplete), - RowDescription(RowDescription), - DataRow(DataRow), + RowDescription(Box), + DataRow(Box), Response(Box), - NotificationResponse(NotificationResponse), + NotificationResponse(Box), ParseComplete, BindComplete, CloseComplete, NoData, PortalSuspended, - ParameterDescription(ParameterDescription), + ParameterDescription(Box), } +/* + +size:Authentication = 32 +size:ParameterStatus = 56 +size:BackendKeyData = 8 +size:CommandComplete = 8 +size:ReadyForQuery = 1 +size:DataRow = 48 +size:NotificationResponse = 64 +size:ParameterDescription = 24 +size:Message = 72 + + */ + impl Message { // FIXME: `Message::decode` shares the name of the remaining message type `::decode` despite being very // different @@ -38,8 +52,6 @@ impl Message { return Ok(None); } - log::trace!("[postgres] [decode] {:?}", bytes::Bytes::from(src.as_ref())); - let token = src[0]; if token == 0 { // FIXME: Handle end-of-stream @@ -49,34 +61,39 @@ impl Message { // FIXME: What happens if len(u32) < len(usize) ? let len = BigEndian::read_u32(&src[1..5]) as usize; - if src.len() < (len + 1) { + if src.len() >= (len + 1) { + let window = &src[5..(len + 1)]; + + let message = match token { + b'N' | b'E' => Message::Response(Box::new(Response::decode(window))), + b'D' => Message::DataRow(Box::new(DataRow::decode(window))), + b'S' => Message::ParameterStatus(Box::new(ParameterStatus::decode(window))), + b'Z' => Message::ReadyForQuery(ReadyForQuery::decode(window)), + b'R' => Message::Authentication(Box::new(Authentication::decode(window))), + b'K' => Message::BackendKeyData(BackendKeyData::decode(window)), + b'T' => Message::RowDescription(Box::new(RowDescription::decode(window))), + b'C' => Message::CommandComplete(CommandComplete::decode(window)), + b'A' => { + Message::NotificationResponse(Box::new(NotificationResponse::decode(window))) + } + b'1' => Message::ParseComplete, + b'2' => Message::BindComplete, + b'3' => Message::CloseComplete, + b'n' => Message::NoData, + b's' => Message::PortalSuspended, + b't' => { + Message::ParameterDescription(Box::new(ParameterDescription::decode(window))) + } + + _ => unimplemented!("decode not implemented for token: {}", token as char), + }; + + src.advance(len + 1); + + Ok(Some(message)) + } else { // We don't have enough in the stream yet - return Ok(None); + Ok(None) } - - let src_ = &src.as_ref()[5..(len + 1)]; - - let message = match token { - b'N' | b'E' => Message::Response(Box::new(Response::decode(src_)?)), - b'D' => Message::DataRow(DataRow::decode2(src_)), - b'S' => Message::ParameterStatus(ParameterStatus::decode(src_)?), - b'Z' => Message::ReadyForQuery(ReadyForQuery::decode(src_)?), - b'R' => Message::Authentication(Authentication::decode(src_)?), - b'K' => Message::BackendKeyData(BackendKeyData::decode2(src_)), - b'T' => Message::RowDescription(RowDescription::decode(src_)?), - b'C' => Message::CommandComplete(CommandComplete::decode2(src_)), - b'A' => Message::NotificationResponse(NotificationResponse::decode(src_)?), - b'1' => Message::ParseComplete, - b'2' => Message::BindComplete, - b'3' => Message::CloseComplete, - b'n' => Message::NoData, - b's' => Message::PortalSuspended, - b't' => Message::ParameterDescription(ParameterDescription::decode2(src_)), - _ => unimplemented!("decode not implemented for token: {}", token as char), - }; - - src.advance(len + 1); - - Ok(Some(message)) } } diff --git a/src/pg/protocol/notification_response.rs b/src/pg/protocol/notification_response.rs index 96645486..4582855d 100644 --- a/src/pg/protocol/notification_response.rs +++ b/src/pg/protocol/notification_response.rs @@ -46,26 +46,26 @@ impl fmt::Debug for NotificationResponse { } impl Decode for NotificationResponse { - fn decode(src: &[u8]) -> io::Result { + fn decode(src: &[u8]) -> Self { let pid = BigEndian::read_u32(&src); // offset from pid=4 let storage = Pin::new(Vec::from(&src[4..])); - let channel_name = get_str(&storage)?; + let channel_name = get_str(&storage); // offset = channel_name.len() + \0 - let message = get_str(&storage[(channel_name.len() + 1)..])?; + let message = get_str(&storage[(channel_name.len() + 1)..]); let channel_name = NonNull::from(channel_name); let message = NonNull::from(message); - Ok(Self { + Self { storage, pid, channel_name, message, - }) + } } } @@ -78,12 +78,11 @@ mod tests { const NOTIFICATION_RESPONSE: &[u8] = b"\x34\x20\x10\x02TEST-CHANNEL\0THIS IS A TEST\0"; #[test] - fn it_decodes_notification_response() -> io::Result<()> { - let message = NotificationResponse::decode(NOTIFICATION_RESPONSE)?; + fn it_decodes_notification_response() { + let message = NotificationResponse::decode(NOTIFICATION_RESPONSE); assert_eq!(message.pid(), 0x34201002); assert_eq!(message.channel_name(), "TEST-CHANNEL"); assert_eq!(message.message(), "THIS IS A TEST"); - Ok(()) } } diff --git a/src/pg/protocol/parameter_description.rs b/src/pg/protocol/parameter_description.rs index 478316f8..f5d0a0ad 100644 --- a/src/pg/protocol/parameter_description.rs +++ b/src/pg/protocol/parameter_description.rs @@ -1,30 +1,26 @@ use super::Decode; use byteorder::{BigEndian, ByteOrder}; use bytes::Bytes; - -use std::io; +use std::mem::size_of; type ObjectId = u32; #[derive(Debug)] pub struct ParameterDescription { - ids: Vec, + ids: Box<[ObjectId]>, } -impl ParameterDescription { - pub fn decode2(src: &[u8]) -> Self { +impl Decode for ParameterDescription { + fn decode(src: &[u8]) -> Self { let count = BigEndian::read_u16(&*src) as usize; - // todo: error handling - assert_eq!(src.len(), count * 4 + 2); - let mut ids = Vec::with_capacity(count); for i in 0..count { - let offset = i * 4 + 2; // 4==size_of(u32), 2==size_of(u16) + let offset = i * size_of::() + size_of::(); ids.push(BigEndian::read_u32(&src[offset..])); } - ParameterDescription { ids } + ParameterDescription { ids: ids.into_boxed_slice() } } } @@ -35,29 +31,20 @@ mod test { use std::io; #[test] - fn it_decodes_parameter_description() -> io::Result<()> { + fn it_decodes_parameter_description() { let src = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00"; - let desc = ParameterDescription::decode2(src); + let desc = ParameterDescription::decode(src); assert_eq!(desc.ids.len(), 2); assert_eq!(desc.ids[0], 0x0000_0000); assert_eq!(desc.ids[1], 0x0000_0500); - Ok(()) } #[test] - fn it_decodes_empty_parameter_description() -> io::Result<()> { + fn it_decodes_empty_parameter_description() { let src = b"\x00\x00"; - let desc = ParameterDescription::decode2(src); + let desc = ParameterDescription::decode(src); assert_eq!(desc.ids.len(), 0); - Ok(()) - } - - #[test] - #[should_panic] - fn parameter_description_wrong_length_fails() -> () { - let src = b"\x00\x00\x00\x01\x02\x03"; - ParameterDescription::decode2(src); } } diff --git a/src/pg/protocol/parameter_status.rs b/src/pg/protocol/parameter_status.rs index a033316d..4ae69dff 100644 --- a/src/pg/protocol/parameter_status.rs +++ b/src/pg/protocol/parameter_status.rs @@ -1,7 +1,5 @@ -use super::decode::{Decode, get_str}; -use std::pin::Pin; -use std::ptr::NonNull; -use std::{io, str}; +use super::decode::{get_str, Decode}; +use std::{io, pin::Pin, ptr::NonNull, str}; // FIXME: Use &str functions for a custom Debug #[derive(Debug)] @@ -31,33 +29,34 @@ impl ParameterStatus { } impl Decode for ParameterStatus { - fn decode(src: &[u8]) -> io::Result { + fn decode(src: &[u8]) -> Self { let storage = Pin::new(Vec::from(src)); - let name = get_str(&storage).unwrap(); - let value = get_str(&storage[name.len() + 1..]).unwrap(); + let name = get_str(&storage); + let value = get_str(&storage[name.len() + 1..]); let name = NonNull::from(name); let value = NonNull::from(value); - Ok(Self { storage, name, value }) + Self { + storage, + name, + value, + } } } #[cfg(test)] mod tests { use super::{Decode, ParameterStatus}; - use std::io; const PARAM_STATUS: &[u8] = b"session_authorization\0postgres\0"; #[test] - fn it_decodes_param_status() -> io::Result<()> { - let message = ParameterStatus::decode(PARAM_STATUS)?; + fn it_decodes_param_status() { + let message = ParameterStatus::decode(PARAM_STATUS); assert_eq!(message.name(), "session_authorization"); assert_eq!(message.value(), "postgres"); - - Ok(()) } } diff --git a/src/pg/protocol/ready_for_query.rs b/src/pg/protocol/ready_for_query.rs index bfdd58e1..d12046be 100644 --- a/src/pg/protocol/ready_for_query.rs +++ b/src/pg/protocol/ready_for_query.rs @@ -18,25 +18,28 @@ pub enum TransactionStatus { /// `ReadyForQuery` is sent whenever the backend is ready for a new query cycle. #[derive(Debug)] pub struct ReadyForQuery { - pub status: TransactionStatus, + status: TransactionStatus, +} + +impl ReadyForQuery { + #[inline] + pub fn status(&self) -> TransactionStatus { + self.status + } } impl Decode for ReadyForQuery { - fn decode(src: &[u8]) -> io::Result { - if src.len() != 1 { - return Err(io::ErrorKind::InvalidInput)?; - } - - Ok(Self { + fn decode(src: &[u8]) -> Self { + Self { status: match src[0] { // FIXME: Variant value are duplicated with declaration b'I' => TransactionStatus::Idle, b'T' => TransactionStatus::Transaction, b'E' => TransactionStatus::Error, - _ => unreachable!(), + status => panic!("received {:?} for TransactionStatus", status), }, - }) + } } } @@ -44,16 +47,13 @@ impl Decode for ReadyForQuery { mod tests { use super::{Decode, ReadyForQuery, TransactionStatus}; use bytes::Bytes; - use std::io; const READY_FOR_QUERY: &[u8] = b"E"; #[test] - fn it_decodes_ready_for_query() -> io::Result<()> { - let message = ReadyForQuery::decode(READY_FOR_QUERY)?; + fn it_decodes_ready_for_query() { + let message = ReadyForQuery::decode(READY_FOR_QUERY); assert_eq!(message.status, TransactionStatus::Error); - - Ok(()) } } diff --git a/src/pg/protocol/response.rs b/src/pg/protocol/response.rs index 8bdc19ad..6e7d30a6 100644 --- a/src/pg/protocol/response.rs +++ b/src/pg/protocol/response.rs @@ -77,7 +77,7 @@ impl FromStr for Severity { #[derive(Clone)] pub struct Response { #[used] - storage: Pin>, + storage: Pin>, severity: Severity, code: NonNull, message: NonNull, @@ -226,8 +226,8 @@ impl fmt::Debug for Response { } impl Decode for Response { - fn decode(src: &[u8]) -> io::Result { - let storage = Pin::new(Vec::from(src)); + fn decode(src: &[u8]) -> Self { + let storage: Pin> = Pin::new(src.into()); let mut code = None::<&str>; let mut message = None::<&str>; @@ -258,7 +258,7 @@ impl Decode for Response { break; } - let field_value = get_str(&storage[idx..])?; + let field_value = get_str(&storage[idx..]); idx += field_value.len() + 1; match field_type { @@ -267,7 +267,7 @@ impl Decode for Response { } b'V' => { - severity_non_local = Some(field_value.parse()?); + severity_non_local = Some(field_value.parse().unwrap()); } b'C' => { @@ -287,19 +287,11 @@ impl Decode for Response { } b'P' => { - position = Some( - field_value - .parse() - .map_err(|_| io::ErrorKind::InvalidData)?, - ); + position = Some(field_value.parse().unwrap()); } b'p' => { - internal_position = Some( - field_value - .parse() - .map_err(|_| io::ErrorKind::InvalidData)?, - ); + internal_position = Some(field_value.parse().unwrap()); } b'q' => { @@ -335,11 +327,7 @@ impl Decode for Response { } b'L' => { - line = Some( - field_value - .parse() - .map_err(|_| io::ErrorKind::InvalidData)?, - ); + line = Some(field_value.parse().unwrap()); } b'R' => { @@ -373,7 +361,7 @@ impl Decode for Response { let file = file.map(NonNull::from); let routine = routine.map(NonNull::from); - Ok(Self { + Self { storage, severity, code, @@ -392,7 +380,7 @@ impl Decode for Response { line, position, internal_position, - }) + } } } @@ -400,25 +388,22 @@ impl Decode for Response { mod tests { use super::{Decode, Response, Severity}; use bytes::Bytes; - use std::io; const RESPONSE: &[u8] = b"SNOTICE\0VNOTICE\0C42710\0Mextension \"uuid-ossp\" already exists, \ skipping\0Fextension.c\0L1656\0RCreateExtension\0\0"; #[test] - fn it_decodes_response() -> io::Result<()> { - let message = Response::decode(RESPONSE)?; + fn it_decodes_response() { + let message = Response::decode(RESPONSE); assert_eq!(message.severity(), Severity::Notice); - assert_eq!( - message.message(), - "extension \"uuid-ossp\" already exists, skipping" - ); assert_eq!(message.code(), "42710"); assert_eq!(message.file(), Some("extension.c")); assert_eq!(message.line(), Some(1656)); assert_eq!(message.routine(), Some("CreateExtension")); - - Ok(()) + assert_eq!( + message.message(), + "extension \"uuid-ossp\" already exists, skipping" + ); } } diff --git a/src/pg/protocol/row_description.rs b/src/pg/protocol/row_description.rs index 19eb86fa..5cfa9772 100644 --- a/src/pg/protocol/row_description.rs +++ b/src/pg/protocol/row_description.rs @@ -64,7 +64,7 @@ impl<'a> FieldDescription<'a> { pub struct RowDescription { // The number of fields in a row (can be zero). len: u16, - data: Vec, + data: Box<[u8]>, } impl RowDescription { @@ -77,13 +77,13 @@ impl RowDescription { } impl Decode for RowDescription { - fn decode(src: &[u8]) -> io::Result { + fn decode(src: &[u8]) -> Self { let len = BigEndian::read_u16(&src[..2]); - Ok(Self { + Self { len, data: src[2..].into(), - }) + } } } @@ -147,13 +147,12 @@ impl<'a> ExactSizeIterator for FieldDescriptions<'a> {} mod tests { use super::{Decode, RowDescription}; use bytes::Bytes; - use std::io; const ROW_DESC: &[u8] = b"\0\x03?column?\0\0\0\0\0\0\0\0\0\0\x17\0\x04\xff\xff\xff\xff\0\0?column?\0\0\0\0\0\0\0\0\0\0\x17\0\x04\xff\xff\xff\xff\0\0?column?\0\0\0\0\0\0\0\0\0\0\x17\0\x04\xff\xff\xff\xff\0\0"; #[test] - fn it_decodes_row_description() -> io::Result<()> { - let message = RowDescription::decode(ROW_DESC)?; + fn it_decodes_row_description() { + let message = RowDescription::decode(ROW_DESC); assert_eq!(message.fields().len(), 3); for field in message.fields() { @@ -165,7 +164,5 @@ mod tests { assert_eq!(field.type_modifier(), -1); assert_eq!(field.format(), 0); } - - Ok(()) } } diff --git a/src/pg/row.rs b/src/pg/row.rs index 53bf517c..e2d7027f 100644 --- a/src/pg/row.rs +++ b/src/pg/row.rs @@ -1,7 +1,7 @@ use super::{protocol::DataRow, Pg}; use crate::row::Row; -pub struct PgRow(pub(crate) DataRow); +pub struct PgRow(pub(crate) Box); impl Row for PgRow { type Backend = Pg;