diff --git a/sqlx-postgres-protocol/src/command_complete.rs b/sqlx-postgres-protocol/src/command_complete.rs new file mode 100644 index 000000000..41e798e69 --- /dev/null +++ b/sqlx-postgres-protocol/src/command_complete.rs @@ -0,0 +1,51 @@ +use crate::Decode; +use bytes::Bytes; +use memchr::{memchr, memrchr}; +use std::{io, str}; + +#[derive(Debug)] +pub struct CommandComplete { + tag: Bytes, +} + +impl CommandComplete { + pub fn tag(&self) -> &str { + let tag_end = memchr(b' ', &*self.tag).unwrap(); + unsafe { str::from_utf8_unchecked(&self.tag[..tag_end]) } + } + + pub fn rows(&self) -> u64 { + let rows_start = memrchr(b' ', &*self.tag).unwrap(); + let rows_s = + unsafe { str::from_utf8_unchecked(&self.tag[(rows_start + 1)..(self.tag.len() - 1)]) }; + + rows_s.parse().unwrap() + } +} + +impl Decode for CommandComplete { + fn decode(src: Bytes) -> io::Result { + Ok(Self { tag: src }) + } +} + +#[cfg(test)] +mod tests { + use super::CommandComplete; + use crate::Decode; + use bytes::Bytes; + use std::io; + + const COMMAND_COMPLETE: &[u8] = b"INSERT 0 512\0"; + + #[test] + fn it_decodes_command_complete() -> io::Result<()> { + let src = Bytes::from_static(COMMAND_COMPLETE); + let message = CommandComplete::decode(src)?; + + assert_eq!(message.tag(), "INSERT"); + assert_eq!(message.rows(), 512); + + Ok(()) + } +} diff --git a/sqlx-postgres-protocol/src/data_row.rs b/sqlx-postgres-protocol/src/data_row.rs new file mode 100644 index 000000000..9470d72b0 --- /dev/null +++ b/sqlx-postgres-protocol/src/data_row.rs @@ -0,0 +1,91 @@ +use crate::Decode; +use byteorder::{BigEndian, ByteOrder}; +use bytes::Bytes; +use std::io; + +// TODO: Custom Debug for DataRow + +#[derive(Debug)] +pub struct DataRow { + len: u16, + data: Bytes, +} + +impl DataRow { + pub fn values(&self) -> DataValues<'_> { + DataValues { + rem: self.len, + buf: &*self.data, + } + } +} + +impl Decode for DataRow { + fn decode(src: Bytes) -> io::Result { + let len = BigEndian::read_u16(&src[..2]); + + Ok(Self { + len, + data: src.slice_from(2), + }) + } +} + +pub struct DataValues<'a> { + rem: u16, + buf: &'a [u8], +} + +impl<'a> Iterator for DataValues<'a> { + type Item = Option<&'a [u8]>; + + fn size_hint(&self) -> (usize, Option) { + (self.rem as usize, Some(self.rem as usize)) + } + + fn next(&mut self) -> Option { + if self.rem == 0 { + return None; + } + + let len = BigEndian::read_i32(self.buf); + let size = (if len < 0 { 0 } else { len }) as usize; + + let value = if len == -1 { + None + } else { + Some(&self.buf[4..(4 + len) as usize]) + }; + + self.rem -= 1; + self.buf = &self.buf[(size + 4)..]; + + Some(value) + } +} + +impl<'a> ExactSizeIterator for DataValues<'a> {} + +#[cfg(test)] +mod tests { + use super::DataRow; + use crate::Decode; + use bytes::Bytes; + use std::io; + + const DATA_ROW: &[u8] = b"\0\x03\0\0\0\x011\0\0\0\x012\0\0\0\x013"; + + #[test] + fn it_decodes_data_row() -> io::Result<()> { + let src = Bytes::from_static(DATA_ROW); + let message = DataRow::decode(src)?; + assert_eq!(message.values().len(), 3); + + for (index, value) in message.values().enumerate() { + // "1", "2", "3" + assert_eq!(value, Some(&[(index + 1 + 48) as u8][..])); + } + + Ok(()) + } +} diff --git a/sqlx-postgres-protocol/src/lib.rs b/sqlx-postgres-protocol/src/lib.rs index 7422959e0..0c771b9f4 100644 --- a/sqlx-postgres-protocol/src/lib.rs +++ b/sqlx-postgres-protocol/src/lib.rs @@ -2,6 +2,8 @@ mod authentication; mod backend_key_data; +mod command_complete; +mod data_row; mod decode; mod encode; mod message; @@ -17,6 +19,8 @@ mod terminate; pub use self::{ authentication::Authentication, backend_key_data::BackendKeyData, + command_complete::CommandComplete, + data_row::{DataRow, DataValues}, decode::Decode, encode::Encode, message::Message, diff --git a/sqlx-postgres-protocol/src/message.rs b/sqlx-postgres-protocol/src/message.rs index 53a7f7623..f34dfac5d 100644 --- a/sqlx-postgres-protocol/src/message.rs +++ b/sqlx-postgres-protocol/src/message.rs @@ -1,6 +1,6 @@ use crate::{ - Authentication, BackendKeyData, Decode, ParameterStatus, ReadyForQuery, Response, - RowDescription, + Authentication, BackendKeyData, CommandComplete, DataRow, Decode, ParameterStatus, + ReadyForQuery, Response, RowDescription, }; use byteorder::{BigEndian, ByteOrder}; use bytes::BytesMut; @@ -12,7 +12,9 @@ pub enum Message { ParameterStatus(ParameterStatus), BackendKeyData(BackendKeyData), ReadyForQuery(ReadyForQuery), + CommandComplete(CommandComplete), RowDescription(RowDescription), + DataRow(DataRow), Response(Box), } @@ -51,6 +53,8 @@ impl Message { b'R' => Message::Authentication(Authentication::decode(src)?), b'K' => Message::BackendKeyData(BackendKeyData::decode(src)?), b'T' => Message::RowDescription(RowDescription::decode(src)?), + b'D' => Message::DataRow(DataRow::decode(src)?), + b'C' => Message::CommandComplete(CommandComplete::decode(src)?), _ => unimplemented!("decode not implemented for token: {}", token as char), })) diff --git a/sqlx-postgres-protocol/src/row_description.rs b/sqlx-postgres-protocol/src/row_description.rs index fe806012f..428855cbc 100644 --- a/sqlx-postgres-protocol/src/row_description.rs +++ b/sqlx-postgres-protocol/src/row_description.rs @@ -104,8 +104,6 @@ impl<'a> Iterator for FieldDescriptions<'a> { return None; } - self.rem -= 1; - let name_end = memchr(0, &self.buf).unwrap(); let mut idx = name_end + 1; let name = unsafe { str::from_utf8_unchecked(&self.buf[..name_end]) }; @@ -126,6 +124,10 @@ impl<'a> Iterator for FieldDescriptions<'a> { idx += size_of_val(&type_modifier); let format = BigEndian::read_i16(&self.buf[idx..]); + idx += size_of_val(&format); + + self.rem -= 1; + self.buf = &self.buf[idx..]; Some(FieldDescription { name, @@ -148,24 +150,23 @@ mod tests { use bytes::Bytes; use std::io; - const ROW_DESC_1: &[u8] = b"\0\x01?column?\0\0\0\0\0\0\0\0\0\0\x17\0\x04\xff\xff\xff\xff\0\0D\0\0\0\x0b\0\x01\0\0\0\x011"; + const ROW_DESC: &[u8] = b"\0\x03?column?\0\0\0\0\0\0\0\0\0\0\x17\0\x04\xff\xff\xff\xff\0\0?column?\0\0\0\0\0\0\0\0\0\0\x17\0\x04\xff\xff\xff\xff\0\0?column?\0\0\0\0\0\0\0\0\0\0\x17\0\x04\xff\xff\xff\xff\0\0"; #[test] fn it_decodes_row_description() -> io::Result<()> { - let src = Bytes::from_static(ROW_DESC_1); + let src = Bytes::from_static(ROW_DESC); let message = RowDescription::decode(src)?; - assert_eq!(message.fields().len(), 1); + assert_eq!(message.fields().len(), 3); - let mut fields = message.fields(); - - let field_1 = fields.next().unwrap(); - assert_eq!(field_1.name(), "?column?"); - assert_eq!(field_1.table_oid(), None); - assert_eq!(field_1.column_attribute_num(), None); - assert_eq!(field_1.type_oid(), 23); - assert_eq!(field_1.type_size(), 4); - assert_eq!(field_1.type_modifier(), -1); - assert_eq!(field_1.format(), 0); + for field in message.fields() { + assert_eq!(field.name(), "?column?"); + assert_eq!(field.table_oid(), None); + assert_eq!(field.column_attribute_num(), None); + assert_eq!(field.type_oid(), 23); + assert_eq!(field.type_size(), 4); + assert_eq!(field.type_modifier(), -1); + assert_eq!(field.format(), 0); + } Ok(()) } diff --git a/sqlx-postgres/src/connection/query.rs b/sqlx-postgres/src/connection/query.rs index e0a5c58db..455bd7fcc 100644 --- a/sqlx-postgres/src/connection/query.rs +++ b/sqlx-postgres/src/connection/query.rs @@ -12,10 +12,18 @@ pub async fn query<'a: 'b, 'b>(conn: &'a mut Connection, query: &'b str) -> io:: // Do nothing } + Message::DataRow(_) => { + // Do nothing (for now) + } + Message::ReadyForQuery(_) => { break; } + Message::CommandComplete(_) => { + // Do nothing (for now) + } + message => { unimplemented!("received {:?} unimplemented message", message); } diff --git a/src/main.rs b/src/main.rs index 58c021df9..0e799f1ec 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,7 +16,7 @@ async fn main() -> io::Result<()> { ) .await?; - conn.execute("SELECT 1").await?; + conn.execute("SELECT 1, 2, 3").await?; conn.close().await?;