Merge pull request #7 from izik1/zg-debyte

Remove bytes from row decoding
This commit is contained in:
Ryan Leckey 2019-08-19 20:56:26 -07:00 committed by GitHub
commit e5d2283eb0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 85 additions and 114 deletions

View File

@ -12,7 +12,7 @@ pub async fn execute(conn: &mut PgRawConnection) -> io::Result<u64> {
Message::BindComplete | Message::ParseComplete | Message::DataRow(_) => {}
Message::CommandComplete(body) => {
rows = body.rows();
rows = body.rows;
}
Message::ReadyForQuery(_) => {

View File

@ -26,21 +26,21 @@ pub enum Authentication {
Sspi,
/// This message contains GSSAPI or SSPI data.
GssContinue { data: Bytes },
GssContinue { data: Vec<u8> },
/// SASL authentication is required.
// FIXME: authentication mechanisms
Sasl,
/// This message contains a SASL challenge.
SaslContinue { data: Bytes },
SaslContinue { data: Vec<u8> },
/// SASL authentication has completed.
SaslFinal { data: Bytes },
SaslFinal { data: Vec<u8> },
}
impl Decode for Authentication {
fn decode(src: Bytes) -> io::Result<Self> {
fn decode(src: &[u8]) -> io::Result<Self> {
Ok(match src[0] {
0 => Authentication::Ok,
2 => Authentication::KerberosV5,

View File

@ -23,15 +23,17 @@ impl BackendKeyData {
}
}
impl Decode for BackendKeyData {
fn decode(src: Bytes) -> io::Result<Self> {
let process_id = u32::from_be_bytes(src.as_ref()[0..4].try_into().unwrap());
let secret_key = u32::from_be_bytes(src.as_ref()[4..8].try_into().unwrap());
impl BackendKeyData {
pub fn decode2(src: &[u8]) -> Self {
// todo: error handling
assert_eq!(src.len(), 8);
let process_id = u32::from_be_bytes(src[0..4].try_into().unwrap());
let secret_key = u32::from_be_bytes(src[4..8].try_into().unwrap());
Ok(Self {
Self {
process_id,
secret_key,
})
}
}
}
@ -45,8 +47,8 @@ mod tests {
#[test]
fn it_decodes_backend_key_data() -> io::Result<()> {
let src = Bytes::from_static(BACKEND_KEY_DATA);
let message = BackendKeyData::decode(src)?;
let src = BACKEND_KEY_DATA;
let message = BackendKeyData::decode2(src);
assert_eq!(message.process_id(), 10182);
assert_eq!(message.secret_key(), 2303903019);

View File

@ -5,32 +5,23 @@ use std::{io, str};
#[derive(Debug)]
pub struct CommandComplete {
tag: Bytes,
pub rows: u64,
}
impl CommandComplete {
#[inline]
pub fn tag(&self) -> &str {
unsafe { str::from_utf8_unchecked(&self.tag.as_ref()[..self.tag.len() - 1]) }
}
pub fn rows(&self) -> u64 {
pub fn decode2(src: &[u8]) -> Self {
// Attempt to parse the last word in the command tag as an integer
// If it can't be parased, the tag is probably "CREATE TABLE" or something
// and we should return 0 rows
let rows_start = memrchr(b' ', &*self.tag).unwrap_or(0);
let rows_s = unsafe {
str::from_utf8_unchecked(&self.tag.as_ref()[(rows_start + 1)..(self.tag.len() - 1)])
let rows_start = memrchr(b' ',src).unwrap_or(0);
let rows = unsafe {
str::from_utf8_unchecked(&src[(rows_start + 1)..(src.len() - 1)])
};
rows_s.parse().unwrap_or(0)
}
}
impl Decode for CommandComplete {
fn decode(src: Bytes) -> io::Result<Self> {
Ok(Self { tag: src })
Self {
rows: rows.parse().unwrap_or(0)
}
}
}
@ -47,44 +38,36 @@ mod tests {
#[test]
fn it_decodes_command_complete_for_insert() -> io::Result<()> {
let src = Bytes::from_static(COMMAND_COMPLETE_INSERT);
let message = CommandComplete::decode(src)?;
let message = CommandComplete::decode2(COMMAND_COMPLETE_INSERT);
assert_eq!(message.tag(), "INSERT 0 1");
assert_eq!(message.rows(), 1);
assert_eq!(message.rows, 1);
Ok(())
}
#[test]
fn it_decodes_command_complete_for_update() -> io::Result<()> {
let src = Bytes::from_static(COMMAND_COMPLETE_UPDATE);
let message = CommandComplete::decode(src)?;
let message = CommandComplete::decode2(COMMAND_COMPLETE_UPDATE);
assert_eq!(message.tag(), "UPDATE 512");
assert_eq!(message.rows(), 512);
assert_eq!(message.rows, 512);
Ok(())
}
#[test]
fn it_decodes_command_complete_for_begin() -> io::Result<()> {
let src = Bytes::from_static(COMMAND_COMPLETE_BEGIN);
let message = CommandComplete::decode(src)?;
let message = CommandComplete::decode2(COMMAND_COMPLETE_BEGIN);
assert_eq!(message.tag(), "BEGIN");
assert_eq!(message.rows(), 0);
assert_eq!(message.rows, 0);
Ok(())
}
#[test]
fn it_decodes_command_complete_for_create_table() -> io::Result<()> {
let src = Bytes::from_static(COMMAND_COMPLETE_CREATE_TABLE);
let message = CommandComplete::decode(src)?;
let message = CommandComplete::decode2(COMMAND_COMPLETE_CREATE_TABLE);
assert_eq!(message.tag(), "CREATE TABLE");
assert_eq!(message.rows(), 0);
assert_eq!(message.rows, 0);
Ok(())
}

View File

@ -3,7 +3,7 @@ use memchr::memchr;
use std::{io, str};
pub trait Decode {
fn decode(src: Bytes) -> io::Result<Self>
fn decode(src: &[u8]) -> io::Result<Self>
where
Self: Sized;
}

View File

@ -54,38 +54,28 @@ impl Message {
return Ok(None);
}
let mut old = false;
let src_ = &src.as_ref()[5..(len + 1)];
let message = match token {
b'D' => Message::DataRow(DataRow::decode2(&src.as_ref()[5..(len + 1)])),
token => {
let src = src.split_to(len + 1).freeze().slice_from(5);
old = true;
match token {
b'N' | b'E' => Message::Response(Box::new(Response::decode(src)?)),
b'S' => Message::ParameterStatus(ParameterStatus::decode(src)?),
b'Z' => Message::ReadyForQuery(ReadyForQuery::decode(src)?),
b'R' => Message::Authentication(Authentication::decode(src)?),
b'K' => Message::BackendKeyData(BackendKeyData::decode(src)?),
b'T' => Message::RowDescription(RowDescription::decode(src)?),
b'C' => Message::CommandComplete(CommandComplete::decode(src)?),
b'A' => Message::NotificationResponse(NotificationResponse::decode(src)?),
b'1' => Message::ParseComplete,
b'2' => Message::BindComplete,
b'3' => Message::CloseComplete,
b'n' => Message::NoData,
b's' => Message::PortalSuspended,
b't' => Message::ParameterDescription(ParameterDescription::decode(src)?),
_ => unimplemented!("decode not implemented for token: {}", token as char),
}
},
b'N' | b'E' => Message::Response(Box::new(Response::decode(src_)?)),
b'D' => Message::DataRow(DataRow::decode2(src_)),
b'S' => Message::ParameterStatus(ParameterStatus::decode(src_)?),
b'Z' => Message::ReadyForQuery(ReadyForQuery::decode(src_)?),
b'R' => Message::Authentication(Authentication::decode(src_)?),
b'K' => Message::BackendKeyData(BackendKeyData::decode2(src_)),
b'T' => Message::RowDescription(RowDescription::decode(src_)?),
b'C' => Message::CommandComplete(CommandComplete::decode2(src_)),
b'A' => Message::NotificationResponse(NotificationResponse::decode(src_)?),
b'1' => Message::ParseComplete,
b'2' => Message::BindComplete,
b'3' => Message::CloseComplete,
b'n' => Message::NoData,
b's' => Message::PortalSuspended,
b't' => Message::ParameterDescription(ParameterDescription::decode2(src_)),
_ => unimplemented!("decode not implemented for token: {}", token as char),
};
if !old {
src.advance(len + 1);
}
src.advance(len + 1);
Ok(Some(message))
}

View File

@ -6,7 +6,7 @@ use std::{fmt, io, pin::Pin, ptr::NonNull};
pub struct NotificationResponse {
#[used]
storage: Pin<Bytes>,
storage: Pin<Vec<u8>>,
pid: u32,
channel_name: NonNull<str>,
message: NonNull<str>,
@ -46,15 +46,16 @@ impl fmt::Debug for NotificationResponse {
}
impl Decode for NotificationResponse {
fn decode(src: Bytes) -> io::Result<Self> {
let storage = Pin::new(src);
let pid = BigEndian::read_u32(&*storage);
fn decode(src: &[u8]) -> io::Result<Self> {
let pid = BigEndian::read_u32(&src);
// offset from pid=4
let channel_name = get_str(&storage[4..])?;
let storage = Pin::new(Vec::from(&src[4..]));
// offset = pid + channel_name.len() + \0
let message = get_str(&storage[(4 + channel_name.len() + 1)..])?;
let channel_name = get_str(&storage)?;
// offset = channel_name.len() + \0
let message = get_str(&storage[(channel_name.len() + 1)..])?;
let channel_name = NonNull::from(channel_name);
let message = NonNull::from(message);
@ -78,8 +79,7 @@ mod tests {
#[test]
fn it_decodes_notification_response() -> io::Result<()> {
let src = Bytes::from_static(NOTIFICATION_RESPONSE);
let message = NotificationResponse::decode(src)?;
let message = NotificationResponse::decode(NOTIFICATION_RESPONSE)?;
assert_eq!(message.pid(), 0x34201002);
assert_eq!(message.channel_name(), "TEST-CHANNEL");

View File

@ -11,8 +11,8 @@ pub struct ParameterDescription {
ids: Vec<ObjectId>,
}
impl Decode for ParameterDescription {
fn decode(src: Bytes) -> io::Result<Self> {
impl ParameterDescription {
pub fn decode2(src: &[u8]) -> Self {
let count = BigEndian::read_u16(&*src) as usize;
// todo: error handling
@ -24,7 +24,7 @@ impl Decode for ParameterDescription {
ids.push(BigEndian::read_u32(&src[offset..]));
}
Ok(ParameterDescription { ids })
ParameterDescription { ids }
}
}
@ -36,8 +36,8 @@ mod test {
#[test]
fn it_decodes_parameter_description() -> io::Result<()> {
let src = Bytes::from_static(b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00");
let desc = ParameterDescription::decode(src)?;
let src = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00";
let desc = ParameterDescription::decode2(src);
assert_eq!(desc.ids.len(), 2);
assert_eq!(desc.ids[0], 0x0000_0000);
@ -47,8 +47,8 @@ mod test {
#[test]
fn it_decodes_empty_parameter_description() -> io::Result<()> {
let src = Bytes::from_static(b"\x00\x00");
let desc = ParameterDescription::decode(src)?;
let src = b"\x00\x00";
let desc = ParameterDescription::decode2(src);
assert_eq!(desc.ids.len(), 0);
Ok(())
@ -57,7 +57,7 @@ mod test {
#[test]
#[should_panic]
fn parameter_description_wrong_length_fails() -> () {
let src = Bytes::from_static(b"\x00\x00\x00\x01\x02\x03");
ParameterDescription::decode(src).unwrap();
let src = b"\x00\x00\x00\x01\x02\x03";
ParameterDescription::decode2(src);
}
}

View File

@ -5,8 +5,8 @@ use std::{io, str};
// FIXME: Use &str functions for a custom Debug
#[derive(Debug)]
pub struct ParameterStatus {
name: Bytes,
value: Bytes,
name: Vec<u8>,
value: Vec<u8>,
}
impl ParameterStatus {
@ -22,12 +22,12 @@ impl ParameterStatus {
}
impl Decode for ParameterStatus {
fn decode(src: Bytes) -> io::Result<Self> {
fn decode(src: &[u8]) -> io::Result<Self> {
let name_end = memchr::memchr(0, &src).unwrap();
let value_end = memchr::memchr(0, &src[(name_end + 1)..]).unwrap();
let name = src.slice_to(name_end);
let value = src.slice(name_end + 1, name_end + 1 + value_end);
let name = src[..name_end].into();
let value = src[(name_end + 1)..(name_end + 1 + value_end)].into();
Ok(Self { name, value })
}
@ -43,8 +43,7 @@ mod tests {
#[test]
fn it_decodes_param_status() -> io::Result<()> {
let src = Bytes::from_static(PARAM_STATUS);
let message = ParameterStatus::decode(src)?;
let message = ParameterStatus::decode(PARAM_STATUS)?;
assert_eq!(message.name(), "session_authorization");
assert_eq!(message.value(), "postgres");

View File

@ -22,7 +22,7 @@ pub struct ReadyForQuery {
}
impl Decode for ReadyForQuery {
fn decode(src: Bytes) -> io::Result<Self> {
fn decode(src: &[u8]) -> io::Result<Self> {
if src.len() != 1 {
return Err(io::ErrorKind::InvalidInput)?;
}
@ -50,8 +50,7 @@ mod tests {
#[test]
fn it_decodes_ready_for_query() -> io::Result<()> {
let src = Bytes::from_static(READY_FOR_QUERY);
let message = ReadyForQuery::decode(src)?;
let message = ReadyForQuery::decode(READY_FOR_QUERY)?;
assert_eq!(message.status, TransactionStatus::Error);

View File

@ -77,7 +77,7 @@ impl FromStr for Severity {
#[derive(Clone)]
pub struct Response {
#[used]
storage: Pin<Bytes>,
storage: Pin<Vec<u8>>,
severity: Severity,
code: NonNull<str>,
message: NonNull<str>,
@ -226,8 +226,8 @@ impl fmt::Debug for Response {
}
impl Decode for Response {
fn decode(src: Bytes) -> io::Result<Self> {
let storage = Pin::new(src);
fn decode(src: &[u8]) -> io::Result<Self> {
let storage = Pin::new(Vec::from(src));
let mut code = None::<&str>;
let mut message = None::<&str>;
@ -407,8 +407,7 @@ mod tests {
#[test]
fn it_decodes_response() -> io::Result<()> {
let src = Bytes::from_static(RESPONSE);
let message = Response::decode(src)?;
let message = Response::decode(RESPONSE)?;
assert_eq!(message.severity(), Severity::Notice);
assert_eq!(

View File

@ -64,25 +64,25 @@ impl<'a> FieldDescription<'a> {
pub struct RowDescription {
// The number of fields in a row (can be zero).
len: u16,
data: Bytes,
data: Vec<u8>,
}
impl RowDescription {
pub fn fields(&self) -> FieldDescriptions<'_> {
FieldDescriptions {
rem: self.len,
buf: &*self.data,
buf: &self.data,
}
}
}
impl Decode for RowDescription {
fn decode(src: Bytes) -> io::Result<Self> {
fn decode(src: &[u8]) -> io::Result<Self> {
let len = BigEndian::read_u16(&src[..2]);
Ok(Self {
len,
data: src.slice_from(2),
data: src[2..].into(),
})
}
}
@ -153,8 +153,7 @@ mod tests {
#[test]
fn it_decodes_row_description() -> io::Result<()> {
let src = Bytes::from_static(ROW_DESC);
let message = RowDescription::decode(src)?;
let message = RowDescription::decode(ROW_DESC)?;
assert_eq!(message.fields().len(), 3);
for field in message.fields() {