diff --git a/examples/contacts/src/main.rs b/examples/contacts/src/main.rs index f8892a45..e287ff45 100644 --- a/examples/contacts/src/main.rs +++ b/examples/contacts/src/main.rs @@ -82,14 +82,14 @@ async fn insert(pool: &PostgresPool, count: usize) -> Result<(), sqlx::Error> { r#" INSERT INTO contacts (name, username, password, email, phone) VALUES ($1, $2, $3, $4, $5) - "#, + "#, ( contact.name, contact.username, contact.password, contact.email, contact.phone, - ) + ), ) .await .unwrap(); @@ -115,14 +115,16 @@ async fn select(pool: &PostgresPool, iterations: usize) -> Result<(), sqlx::Erro for _ in 0..iterations { // TODO: Once we have FromRow derives we can replace this with Vec - let contacts: Vec<(String, String, String, String, String)> = pool.fetch( - r#" + let contacts: Vec<(String, String, String, String, String)> = pool + .fetch( + r#" SELECT name, username, password, email, phone FROM contacts - "#, (), - ) - .try_collect() - .await?; + "#, + (), + ) + .try_collect() + .await?; rows = contacts.len(); } diff --git a/examples/todos/src/main.rs b/examples/todos/src/main.rs index 877434bd..4774e10b 100644 --- a/examples/todos/src/main.rs +++ b/examples/todos/src/main.rs @@ -76,6 +76,7 @@ SELECT id, text FROM tasks WHERE done_at IS NULL "#, + (), ) .try_for_each(|(id, text): (i64, String)| { // language=text @@ -89,7 +90,8 @@ WHERE done_at IS NULL } async fn add_task(conn: &mut Connection, text: &str) -> Fallible<()> { - conn.execute("INSERT INTO tasks ( text ) VALUES ( $1 )", (text,)).await?; + conn.execute("INSERT INTO tasks ( text ) VALUES ( $1 )", (text,)) + .await?; Ok(()) } diff --git a/src/connection.rs b/src/connection.rs index 505f6649..779dd170 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,5 +1,9 @@ use crate::{ - backend::Backend, error::Error, executor::Executor, query::{QueryParameters, IntoQueryParameters}, row::FromSqlRow, + backend::Backend, + error::Error, + executor::Executor, + query::{IntoQueryParameters, QueryParameters}, + row::FromSqlRow, }; use crossbeam_queue::SegQueue; use crossbeam_utils::atomic::AtomicCell; diff --git a/src/error.rs b/src/error.rs index 8a3c6b66..5934a5d0 100644 --- a/src/error.rs +++ b/src/error.rs @@ -40,6 +40,13 @@ impl Display for Error { } } +impl From for Error { + #[inline] + fn from(err: io::Error) -> Self { + Error::Io(err) + } +} + // TODO: Define a RawError type for the database backend for forwarding error information /// An error that was returned by the database backend. diff --git a/src/io/buf_stream.rs b/src/io/buf_stream.rs new file mode 100644 index 00000000..b90a78e5 --- /dev/null +++ b/src/io/buf_stream.rs @@ -0,0 +1,102 @@ +use bytes::{BufMut, BytesMut}; +use std::io; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +pub struct BufStream { + stream: S, + + // Have we reached end-of-file (been disconnected) + stream_eof: bool, + + // Buffer used when sending outgoing messages + wbuf: Vec, + + // Buffer used when reading incoming messages + rbuf: BytesMut, +} + +impl BufStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + pub fn new(stream: S) -> Self { + Self { + stream, + stream_eof: false, + wbuf: Vec::with_capacity(1 * 1024), + rbuf: BytesMut::with_capacity(8 * 1024), + } + } + + pub async fn close(&mut self) -> io::Result<()> { + self.stream.shutdown().await + } + + #[inline] + pub fn buffer_mut(&mut self) -> &mut Vec { + &mut self.wbuf + } + + #[inline] + pub async fn flush(&mut self) -> io::Result<()> { + if self.wbuf.len() > 0 { + self.stream.write_all(&self.wbuf).await?; + self.wbuf.clear(); + } + + Ok(()) + } + + #[inline] + pub fn consume(&mut self, cnt: usize) { + self.rbuf.advance(cnt); + } + + pub async fn peek(&mut self, cnt: usize) -> io::Result> { + loop { + // Reaching end-of-file (read 0 bytes) will continuously + // return None from all future calls to read + if self.stream_eof { + return Ok(None); + } + + // If we have enough bytes in our read buffer, + // return immediately + if self.rbuf.len() >= cnt { + return Ok(Some(&self.rbuf[..cnt])); + } + + if self.rbuf.capacity() < cnt { + // Ask for exactly how much we need with a lower bound of 32-bytes + let needed = (cnt - self.rbuf.capacity()).max(32); + self.rbuf.reserve(needed); + } + + // SAFE: Read data in directly to buffer without zero-initializing the data. + // Postgres is a self-describing format and the TCP frames encode + // length headers. We will never attempt to decode more than we + // received. + let n = self.stream.read(unsafe { self.rbuf.bytes_mut() }).await?; + + // SAFE: After we read in N bytes, we can tell the buffer that it actually + // has that many bytes MORE for the decode routines to look at + unsafe { self.rbuf.advance_mut(n) } + + if n == 0 { + self.stream_eof = true; + } + } + } +} + +// Return `Ok(None)` immediately from a function if the wrapped value is `None` +macro_rules! ret_if_none { + ($val:expr) => { + match $val { + Some(val) => val, + None => { + return Ok(None); + } + } + }; +} diff --git a/src/io/mod.rs b/src/io/mod.rs new file mode 100644 index 00000000..466e39b4 --- /dev/null +++ b/src/io/mod.rs @@ -0,0 +1,4 @@ +#[macro_use] +mod buf_stream; + +pub use self::buf_stream::BufStream; diff --git a/src/lib.rs b/src/lib.rs index 4432705a..231770a5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,9 @@ #[macro_use] mod macros; +#[macro_use] +mod io; + pub mod backend; pub mod deserialize; diff --git a/src/pool.rs b/src/pool.rs index 06b7dfa7..f143ffe1 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -1,6 +1,10 @@ use crate::{ - backend::Backend, connection::RawConnection, error::Error, executor::Executor, - query::{QueryParameters, IntoQueryParameters}, row::FromSqlRow, + backend::Backend, + connection::RawConnection, + error::Error, + executor::Executor, + query::{IntoQueryParameters, QueryParameters}, + row::FromSqlRow, }; use crossbeam_queue::{ArrayQueue, SegQueue}; use futures_channel::oneshot; diff --git a/src/postgres/connection/mod.rs b/src/postgres/connection/mod.rs index 27c8a92b..ce7e9502 100644 --- a/src/postgres/connection/mod.rs +++ b/src/postgres/connection/mod.rs @@ -1,9 +1,10 @@ use super::{ - protocol::{self, Encode, Message, Terminate}, + protocol::{self, Decode, Encode, Message, Terminate}, Postgres, PostgresQueryParameters, PostgresRow, }; -use crate::{connection::RawConnection, error::Error, query::QueryParameters}; -use bytes::{BufMut, BytesMut}; +use crate::{connection::RawConnection, error::Error, io::BufStream, query::QueryParameters}; +// use bytes::{BufMut, BytesMut}; +use super::protocol::Buf; use futures_core::{future::BoxFuture, stream::BoxStream}; use std::{ io, @@ -21,20 +22,7 @@ mod fetch; mod fetch_optional; pub struct PostgresRawConnection { - stream: TcpStream, - - // Do we think that there is data in the read buffer to be decoded - stream_readable: bool, - - // Have we reached end-of-file (been disconnected) - stream_eof: bool, - - // Buffer used when sending outgoing messages - pub(super) wbuf: Vec, - - // Buffer used when reading incoming messages - // TODO: Evaluate if we _really_ want to use BytesMut here - rbuf: BytesMut, + stream: BufStream, // Process ID of the Backend process_id: u32, @@ -58,11 +46,7 @@ impl PostgresRawConnection { let stream = TcpStream::connect(&addr).await.map_err(Error::Io)?; let mut conn = Self { - wbuf: Vec::with_capacity(1024), - rbuf: BytesMut::with_capacity(1024 * 8), - stream, - stream_readable: false, - stream_eof: false, + stream: BufStream::new(stream), process_id: 0, secret_key: 0, }; @@ -74,8 +58,8 @@ impl PostgresRawConnection { async fn finalize(&mut self) -> Result<(), Error> { self.write(Terminate); - self.flush().await?; - self.stream.shutdown(Shutdown::Both).map_err(Error::Io)?; + self.stream.flush().await?; + self.stream.close().await?; Ok(()) } @@ -83,69 +67,64 @@ impl PostgresRawConnection { // Wait and return the next message to be received from Postgres. async fn receive(&mut self) -> Result, Error> { loop { - if self.stream_eof { - // Reached end-of-file on a previous read call. - return Ok(None); - } + // Read the message header (id + len) + let mut header = ret_if_none!(self.stream.peek(5).await?); + let id = header.get_int_1()?; + let len = (header.get_int_4()? - 4) as usize; - if self.stream_readable { - loop { - match Message::decode(&mut self.rbuf) { - Some(Message::ParameterStatus(_body)) => { - // TODO: not sure what to do with these yet - } + // Read the message body + self.stream.consume(5); + let body = ret_if_none!(self.stream.peek(len).await?); - Some(Message::Response(_body)) => { - // TODO: Transform Errors+ into an error type and return - // TODO: Log all others - } + let message = match id { + b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body))), + b'D' => Message::DataRow(Box::new(protocol::DataRow::decode(body))), + b'S' => Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body))), + b'Z' => Message::ReadyForQuery(protocol::ReadyForQuery::decode(body)), + b'R' => Message::Authentication(Box::new(protocol::Authentication::decode(body))), + b'K' => Message::BackendKeyData(protocol::BackendKeyData::decode(body)), + b'T' => Message::RowDescription(Box::new(protocol::RowDescription::decode(body))), + b'C' => Message::CommandComplete(protocol::CommandComplete::decode(body)), + b'A' => Message::NotificationResponse(Box::new( + protocol::NotificationResponse::decode(body), + )), + b'1' => Message::ParseComplete, + b'2' => Message::BindComplete, + b'3' => Message::CloseComplete, + b'n' => Message::NoData, + b's' => Message::PortalSuspended, + b't' => Message::ParameterDescription(Box::new( + protocol::ParameterDescription::decode(body), + )), - Some(message) => { - return Ok(Some(message)); - } + _ => unimplemented!("unknown message id: {}", id as char), + }; - None => { - // Not enough data in the read buffer to parse a message - self.stream_readable = true; - break; - } - } + self.stream.consume(len); + + match message { + Message::ParameterStatus(_body) => { + // TODO: not sure what to do with these yet + } + + Message::Response(_body) => { + // TODO: Transform Errors+ into an error type and return + // TODO: Log all others + } + + message => { + return Ok(Some(message)); } } - - // Ensure there is at least 32-bytes of space available - // in the read buffer so we can safely detect end-of-file - self.rbuf.reserve(32); - - // SAFE: Read data in directly to buffer without zero-initializing the data. - // Postgres is a self-describing format and the TCP frames encode - // length headers. We will never attempt to decode more than we - // received. - let n = self - .stream - .read(unsafe { self.rbuf.bytes_mut() }) - .await - .map_err(Error::Io)?; - - // SAFE: After we read in N bytes, we can tell the buffer that it actually - // has that many bytes MORE for the decode routines to look at - unsafe { self.rbuf.advance_mut(n) } - - if n == 0 { - self.stream_eof = true; - } - - self.stream_readable = true; } } pub(super) fn write(&mut self, message: impl Encode) { - message.encode(&mut self.wbuf); + message.encode(self.stream.buffer_mut()); } async fn flush(&mut self) -> Result<(), Error> { - self.stream.write_all(&self.wbuf).await.map_err(Error::Io)?; - self.wbuf.clear(); + self.stream.flush().await?; Ok(()) } @@ -195,7 +174,12 @@ impl RawConnection for PostgresRawConnection { } } -fn finish(conn: &mut PostgresRawConnection, query: &str, params: PostgresQueryParameters, limit: i32) { +fn finish( + conn: &mut PostgresRawConnection, + query: &str, + params: PostgresQueryParameters, + limit: i32, +) { conn.write(protocol::Parse { portal: "", query, @@ -213,10 +197,7 @@ fn finish(conn: &mut PostgresRawConnection, query: &str, params: PostgresQueryPa }); // TODO: Make limit be 1 for fetch_optional - conn.write(protocol::Execute { - portal: "", - limit, - }); + conn.write(protocol::Execute { portal: "", limit }); conn.write(protocol::Sync); } diff --git a/src/postgres/protocol/decode.rs b/src/postgres/protocol/decode.rs index 76dc6b80..f56797d3 100644 --- a/src/postgres/protocol/decode.rs +++ b/src/postgres/protocol/decode.rs @@ -1,5 +1,5 @@ use memchr::memchr; -use std::str; +use std::{convert::TryInto, io, str}; pub trait Decode { fn decode(src: &[u8]) -> Self @@ -14,3 +14,50 @@ pub(crate) fn get_str(src: &[u8]) -> &str { unsafe { str::from_utf8_unchecked(buf) } } + +pub trait Buf { + fn advance(&mut self, cnt: usize); + + // An n-bit integer in network byte order + fn get_int_1(&mut self) -> io::Result; + fn get_int_4(&mut self) -> io::Result; + + // A null-terminated string + fn get_str(&mut self) -> io::Result<&str>; +} + +impl<'a> Buf for &'a [u8] { + #[inline] + fn advance(&mut self, cnt: usize) { + *self = &self[cnt..]; + } + + #[inline] + fn get_int_1(&mut self) -> io::Result { + let val = self[0]; + + self.advance(1); + + Ok(val) + } + + #[inline] + fn get_int_4(&mut self) -> io::Result { + let val: [u8; 4] = (*self) + .try_into() + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; + + self.advance(4); + + Ok(u32::from_be_bytes(val)) + } + + fn get_str(&mut self) -> io::Result<&str> { + let end = memchr(b'\0', &*self).ok_or(io::ErrorKind::InvalidData)?; + let buf = &self[..end]; + + self.advance(end); + + str::from_utf8(buf).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) + } +} diff --git a/src/postgres/protocol/message.rs b/src/postgres/protocol/message.rs index 4cbb2f76..409a7e20 100644 --- a/src/postgres/protocol/message.rs +++ b/src/postgres/protocol/message.rs @@ -25,61 +25,3 @@ pub enum Message { PortalSuspended, ParameterDescription(Box), } - -impl Message { - // FIXME: `Message::decode` shares the name of the remaining message type `::decode` despite being very - // different - pub fn decode(src: &mut BytesMut) -> Option - where - Self: Sized, - { - if src.len() < 5 { - // No message is less than 5 bytes - return None; - } - - let token = src[0]; - if token == 0 { - // FIXME: Handle end-of-stream - panic!("unexpectede end-of-stream"); - } - - // FIXME: What happens if len(u32) < len(usize) ? - let len = BigEndian::read_u32(&src[1..5]) as usize; - - if src.len() >= (len + 1) { - let window = &src[5..=len]; - - let message = match token { - b'N' | b'E' => Message::Response(Box::new(Response::decode(window))), - b'D' => Message::DataRow(Box::new(DataRow::decode(window))), - b'S' => Message::ParameterStatus(Box::new(ParameterStatus::decode(window))), - b'Z' => Message::ReadyForQuery(ReadyForQuery::decode(window)), - b'R' => Message::Authentication(Box::new(Authentication::decode(window))), - b'K' => Message::BackendKeyData(BackendKeyData::decode(window)), - b'T' => Message::RowDescription(Box::new(RowDescription::decode(window))), - b'C' => Message::CommandComplete(CommandComplete::decode(window)), - b'A' => { - Message::NotificationResponse(Box::new(NotificationResponse::decode(window))) - } - b'1' => Message::ParseComplete, - b'2' => Message::BindComplete, - b'3' => Message::CloseComplete, - b'n' => Message::NoData, - b's' => Message::PortalSuspended, - b't' => { - Message::ParameterDescription(Box::new(ParameterDescription::decode(window))) - } - - _ => unimplemented!("decode not implemented for token: {}", token as char), - }; - - src.advance(len + 1); - - Some(message) - } else { - // We don't have enough in the stream yet - None - } - } -} diff --git a/src/postgres/protocol/mod.rs b/src/postgres/protocol/mod.rs index 9dc3a778..e8447220 100644 --- a/src/postgres/protocol/mod.rs +++ b/src/postgres/protocol/mod.rs @@ -59,9 +59,16 @@ mod response; mod row_description; pub use self::{ - authentication::Authentication, backend_key_data::BackendKeyData, - command_complete::CommandComplete, data_row::DataRow, decode::Decode, message::Message, - notification_response::NotificationResponse, parameter_description::ParameterDescription, - parameter_status::ParameterStatus, ready_for_query::ReadyForQuery, response::Response, + authentication::Authentication, + backend_key_data::BackendKeyData, + command_complete::CommandComplete, + data_row::DataRow, + decode::{Buf, Decode}, + message::Message, + notification_response::NotificationResponse, + parameter_description::ParameterDescription, + parameter_status::ParameterStatus, + ready_for_query::ReadyForQuery, + response::Response, row_description::RowDescription, }; diff --git a/src/query.rs b/src/query.rs index 23242e39..0c9b288a 100644 --- a/src/query.rs +++ b/src/query.rs @@ -18,7 +18,10 @@ pub trait QueryParameters: Send { T: ToSql; } -pub trait IntoQueryParameters where DB: Backend { +pub trait IntoQueryParameters +where + DB: Backend, +{ fn into(self) -> DB::QueryParameters; } @@ -26,9 +29,9 @@ pub trait IntoQueryParameters where DB: Backend { macro_rules! impl_into_query_parameters { ($( ($idx:tt) -> $T:ident );+;) => { - impl<$($T,)+ DB> IntoQueryParameters for ($($T,)+) - where - DB: Backend, + impl<$($T,)+ DB> IntoQueryParameters for ($($T,)+) + where + DB: Backend, $(DB: crate::types::HasSqlType<$T>,)+ $($T: crate::serialize::ToSql,)+ { @@ -41,9 +44,9 @@ macro_rules! impl_into_query_parameters { }; } -impl IntoQueryParameters for () -where - DB: Backend, +impl IntoQueryParameters for () +where + DB: Backend, { fn into(self) -> DB::QueryParameters { DB::QueryParameters::new()