Extract BufStream from PostgresRawConnection

Ryan Leckey 2019-08-25 14:01:07 -07:00
parent 6c98ba01b8
commit c33530b25c
13 changed files with 269 additions and 161 deletions

View File

@@ -82,14 +82,14 @@ async fn insert(pool: &PostgresPool, count: usize) -> Result<(), sqlx::Error> {
r#"
INSERT INTO contacts (name, username, password, email, phone)
VALUES ($1, $2, $3, $4, $5)
"#,
"#,
(
contact.name,
contact.username,
contact.password,
contact.email,
contact.phone,
)
),
)
.await
.unwrap();
@@ -115,14 +115,16 @@ async fn select(pool: &PostgresPool, iterations: usize) -> Result<(), sqlx::Erro
for _ in 0..iterations {
// TODO: Once we have FromRow derives we can replace this with Vec<Contact>
let contacts: Vec<(String, String, String, String, String)> = pool.fetch(
r#"
let contacts: Vec<(String, String, String, String, String)> = pool
.fetch(
r#"
SELECT name, username, password, email, phone
FROM contacts
"#, (),
)
.try_collect()
.await?;
"#,
(),
)
.try_collect()
.await?;
rows = contacts.len();
}

View File

@@ -76,6 +76,7 @@ SELECT id, text
FROM tasks
WHERE done_at IS NULL
"#,
(),
)
.try_for_each(|(id, text): (i64, String)| {
// language=text
@@ -89,7 +90,8 @@ WHERE done_at IS NULL
}
async fn add_task(conn: &mut Connection<Postgres>, text: &str) -> Fallible<()> {
conn.execute("INSERT INTO tasks ( text ) VALUES ( $1 )", (text,)).await?;
conn.execute("INSERT INTO tasks ( text ) VALUES ( $1 )", (text,))
.await?;
Ok(())
}

View File

@@ -1,5 +1,9 @@
use crate::{
backend::Backend, error::Error, executor::Executor, query::{QueryParameters, IntoQueryParameters}, row::FromSqlRow,
backend::Backend,
error::Error,
executor::Executor,
query::{IntoQueryParameters, QueryParameters},
row::FromSqlRow,
};
use crossbeam_queue::SegQueue;
use crossbeam_utils::atomic::AtomicCell;

View File

@@ -40,6 +40,13 @@ impl Display for Error {
}
}
impl From<io::Error> for Error {
#[inline]
fn from(err: io::Error) -> Self {
Error::Io(err)
}
}
// TODO: Define a RawError type for the database backend for forwarding error information
/// An error that was returned by the database backend.
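Not part of the diff: a minimal sketch of what this impl buys. With From<io::Error> for Error in place, functions returning Result<_, Error> can apply the ? operator directly to io::Result values instead of mapping with .map_err(Error::Io), which is exactly what the reworked connection code further down does. The helper name below is illustrative; it assumes the crate's Error type and the BufStream added by this commit.

async fn flush_buffered(stream: &mut BufStream<tokio::net::TcpStream>) -> Result<(), Error> {
    stream.flush().await?; // io::Error -> Error::Io via the From impl above
    Ok(())
}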

src/io/buf_stream.rs (new file, 102 lines added)
View File

@@ -0,0 +1,102 @@
use bytes::{BufMut, BytesMut};
use std::io;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
pub struct BufStream<S> {
stream: S,
// Have we reached end-of-file (been disconnected)
stream_eof: bool,
// Buffer used when sending outgoing messages
wbuf: Vec<u8>,
// Buffer used when reading incoming messages
rbuf: BytesMut,
}
impl<S> BufStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
pub fn new(stream: S) -> Self {
Self {
stream,
stream_eof: false,
wbuf: Vec::with_capacity(1 * 1024),
rbuf: BytesMut::with_capacity(8 * 1024),
}
}
pub async fn close(&mut self) -> io::Result<()> {
self.stream.shutdown().await
}
#[inline]
pub fn buffer_mut(&mut self) -> &mut Vec<u8> {
&mut self.wbuf
}
#[inline]
pub async fn flush(&mut self) -> io::Result<()> {
if self.wbuf.len() > 0 {
self.stream.write_all(&self.wbuf).await?;
self.wbuf.clear();
}
Ok(())
}
#[inline]
pub fn consume(&mut self, cnt: usize) {
self.rbuf.advance(cnt);
}
pub async fn peek(&mut self, cnt: usize) -> io::Result<Option<&[u8]>> {
loop {
// Reaching end-of-file (read 0 bytes) will continuously
// return None from all future calls to read
if self.stream_eof {
return Ok(None);
}
// If we have enough bytes in our read buffer,
// return immediately
if self.rbuf.len() >= cnt {
return Ok(Some(&self.rbuf[..cnt]));
}
if self.rbuf.capacity() < cnt {
// Ask for exactly how much we need with a lower bound of 32-bytes
let needed = (cnt - self.rbuf.capacity()).max(32);
self.rbuf.reserve(needed);
}
// SAFE: Read data in directly to buffer without zero-initializing the data.
// Postgres is a self-describing format and the TCP frames encode
// length headers. We will never attempt to decode more than we
// received.
let n = self.stream.read(unsafe { self.rbuf.bytes_mut() }).await?;
// SAFE: After we read in N bytes, we can tell the buffer that it actually
// has that many bytes MORE for the decode routines to look at
unsafe { self.rbuf.advance_mut(n) }
if n == 0 {
self.stream_eof = true;
}
}
}
}
// Return `Ok(None)` immediately from a function if the wrapped value is `None`
macro_rules! ret_if_none {
($val:expr) => {
match $val {
Some(val) => val,
None => {
return Ok(None);
}
}
};
}
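Not part of the commit: a test-style sketch of how the new peek / consume / flush API is meant to be driven, using an in-memory pipe. It assumes the BufStream type above is in scope and a tokio version that provides tokio::io::duplex and the test macro; the test name is illustrative.

use tokio::io::{AsyncReadExt, AsyncWriteExt};

#[tokio::test]
async fn buf_stream_sketch() -> std::io::Result<()> {
    let (client, mut server) = tokio::io::duplex(64);
    let mut stream = BufStream::new(client);

    // Writes are staged in the internal buffer and only hit the wire on flush().
    stream.buffer_mut().extend_from_slice(b"hello");
    stream.flush().await?;

    let mut received = [0u8; 5];
    server.read_exact(&mut received).await?;
    assert_eq!(&received, b"hello");

    // peek(cnt) fills the read buffer until cnt bytes are available without
    // consuming them; consume(cnt) then discards them.
    server.write_all(b"world").await?;
    assert_eq!(stream.peek(5).await?, Some(&b"world"[..]));
    stream.consume(5);

    Ok(())
}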

src/io/mod.rs (new file, 4 lines added)
View File

@@ -0,0 +1,4 @@
#[macro_use]
mod buf_stream;
pub use self::buf_stream::BufStream;

View File

@@ -11,6 +11,9 @@
#[macro_use]
mod macros;
#[macro_use]
mod io;
pub mod backend;
pub mod deserialize;

View File

@@ -1,6 +1,10 @@
use crate::{
backend::Backend, connection::RawConnection, error::Error, executor::Executor,
query::{QueryParameters, IntoQueryParameters}, row::FromSqlRow,
backend::Backend,
connection::RawConnection,
error::Error,
executor::Executor,
query::{IntoQueryParameters, QueryParameters},
row::FromSqlRow,
};
use crossbeam_queue::{ArrayQueue, SegQueue};
use futures_channel::oneshot;

View File

@@ -1,9 +1,10 @@
use super::{
protocol::{self, Encode, Message, Terminate},
protocol::{self, Decode, Encode, Message, Terminate},
Postgres, PostgresQueryParameters, PostgresRow,
};
use crate::{connection::RawConnection, error::Error, query::QueryParameters};
use bytes::{BufMut, BytesMut};
use crate::{connection::RawConnection, error::Error, io::BufStream, query::QueryParameters};
// use bytes::{BufMut, BytesMut};
use super::protocol::Buf;
use futures_core::{future::BoxFuture, stream::BoxStream};
use std::{
io,
@@ -21,20 +22,7 @@ mod fetch;
mod fetch_optional;
pub struct PostgresRawConnection {
stream: TcpStream,
// Do we think that there is data in the read buffer to be decoded
stream_readable: bool,
// Have we reached end-of-file (been disconnected)
stream_eof: bool,
// Buffer used when sending outgoing messages
pub(super) wbuf: Vec<u8>,
// Buffer used when reading incoming messages
// TODO: Evaluate if we _really_ want to use BytesMut here
rbuf: BytesMut,
stream: BufStream<TcpStream>,
// Process ID of the Backend
process_id: u32,
@@ -58,11 +46,7 @@ impl PostgresRawConnection {
let stream = TcpStream::connect(&addr).await.map_err(Error::Io)?;
let mut conn = Self {
wbuf: Vec::with_capacity(1024),
rbuf: BytesMut::with_capacity(1024 * 8),
stream,
stream_readable: false,
stream_eof: false,
stream: BufStream::new(stream),
process_id: 0,
secret_key: 0,
};
@@ -74,8 +58,8 @@ impl PostgresRawConnection {
async fn finalize(&mut self) -> Result<(), Error> {
self.write(Terminate);
self.flush().await?;
self.stream.shutdown(Shutdown::Both).map_err(Error::Io)?;
self.stream.flush().await?;
self.stream.close().await?;
Ok(())
}
@@ -83,69 +67,64 @@ impl PostgresRawConnection {
// Wait and return the next message to be received from Postgres.
async fn receive(&mut self) -> Result<Option<Message>, Error> {
loop {
if self.stream_eof {
// Reached end-of-file on a previous read call.
return Ok(None);
}
// Read the message header (id + len)
let mut header = ret_if_none!(self.stream.peek(5).await?);
let id = header.get_int_1()?;
let len = (header.get_int_4()? - 4) as usize;
if self.stream_readable {
loop {
match Message::decode(&mut self.rbuf) {
Some(Message::ParameterStatus(_body)) => {
// TODO: not sure what to do with these yet
}
// Read the message body
self.stream.consume(5);
let body = ret_if_none!(self.stream.peek(len).await?);
Some(Message::Response(_body)) => {
// TODO: Transform Errors+ into an error type and return
// TODO: Log all others
}
let message = match id {
b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body))),
b'D' => Message::DataRow(Box::new(protocol::DataRow::decode(body))),
b'S' => Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body))),
b'Z' => Message::ReadyForQuery(protocol::ReadyForQuery::decode(body)),
b'R' => Message::Authentication(Box::new(protocol::Authentication::decode(body))),
b'K' => Message::BackendKeyData(protocol::BackendKeyData::decode(body)),
b'T' => Message::RowDescription(Box::new(protocol::RowDescription::decode(body))),
b'C' => Message::CommandComplete(protocol::CommandComplete::decode(body)),
b'A' => Message::NotificationResponse(Box::new(
protocol::NotificationResponse::decode(body),
)),
b'1' => Message::ParseComplete,
b'2' => Message::BindComplete,
b'3' => Message::CloseComplete,
b'n' => Message::NoData,
b's' => Message::PortalSuspended,
b't' => Message::ParameterDescription(Box::new(
protocol::ParameterDescription::decode(body),
)),
Some(message) => {
return Ok(Some(message));
}
_ => unimplemented!("unknown message id: {}", id as char),
};
None => {
// Not enough data in the read buffer to parse a message
self.stream_readable = true;
break;
}
}
self.stream.consume(len);
match message {
Message::ParameterStatus(_body) => {
// TODO: not sure what to do with these yet
}
Message::Response(_body) => {
// TODO: Transform Errors+ into an error type and return
// TODO: Log all others
}
message => {
return Ok(Some(message));
}
}
// Ensure there is at least 32-bytes of space available
// in the read buffer so we can safely detect end-of-file
self.rbuf.reserve(32);
// SAFE: Read data in directly to buffer without zero-initializing the data.
// Postgres is a self-describing format and the TCP frames encode
// length headers. We will never attempt to decode more than we
// received.
let n = self
.stream
.read(unsafe { self.rbuf.bytes_mut() })
.await
.map_err(Error::Io)?;
// SAFE: After we read in N bytes, we can tell the buffer that it actually
// has that many bytes MORE for the decode routines to look at
unsafe { self.rbuf.advance_mut(n) }
if n == 0 {
self.stream_eof = true;
}
self.stream_readable = true;
}
}
pub(super) fn write(&mut self, message: impl Encode) {
message.encode(&mut self.wbuf);
message.encode(self.stream.buffer_mut());
}
async fn flush(&mut self) -> Result<(), Error> {
self.stream.write_all(&self.wbuf).await.map_err(Error::Io)?;
self.wbuf.clear();
self.stream.flush().await?;
Ok(())
}
@@ -195,7 +174,12 @@ impl RawConnection for PostgresRawConnection {
}
}
fn finish(conn: &mut PostgresRawConnection, query: &str, params: PostgresQueryParameters, limit: i32) {
fn finish(
conn: &mut PostgresRawConnection,
query: &str,
params: PostgresQueryParameters,
limit: i32,
) {
conn.write(protocol::Parse {
portal: "",
query,
@@ -213,10 +197,7 @@ fn finish(conn: &mut PostgresRawConnection, query: &str, params: PostgresQueryPa
});
// TODO: Make limit be 1 for fetch_optional
conn.write(protocol::Execute {
portal: "",
limit,
});
conn.write(protocol::Execute { portal: "", limit });
conn.write(protocol::Sync);
}
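Not part of the diff: a self-contained illustration of the framing that the reworked receive() parses. Each backend message is a one-byte tag followed by a big-endian length that counts itself but not the tag, which is why the body length above is computed as get_int_4()? - 4. The function name is illustrative.

use std::convert::TryInto;

fn split_frame(frame: &[u8]) -> (u8, &[u8]) {
    let tag = frame[0];
    let len = u32::from_be_bytes(frame[1..5].try_into().unwrap()) as usize;
    (tag, &frame[5..5 + (len - 4)])
}

fn main() {
    // ReadyForQuery ('Z') carries a single status byte, so its length is 4 + 1 = 5.
    let frame = [b'Z', 0, 0, 0, 5, b'I'];
    assert_eq!(split_frame(&frame), (b'Z', &b"I"[..]));
}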

View File

@@ -1,5 +1,5 @@
use memchr::memchr;
use std::str;
use std::{convert::TryInto, io, str};
pub trait Decode {
fn decode(src: &[u8]) -> Self
@@ -14,3 +14,50 @@ pub(crate) fn get_str(src: &[u8]) -> &str {
unsafe { str::from_utf8_unchecked(buf) }
}
pub trait Buf {
fn advance(&mut self, cnt: usize);
// An n-bit integer in network byte order
fn get_int_1(&mut self) -> io::Result<u8>;
fn get_int_4(&mut self) -> io::Result<u32>;
// A null-terminated string
fn get_str(&mut self) -> io::Result<&str>;
}
impl<'a> Buf for &'a [u8] {
#[inline]
fn advance(&mut self, cnt: usize) {
*self = &self[cnt..];
}
#[inline]
fn get_int_1(&mut self) -> io::Result<u8> {
let val = self[0];
self.advance(1);
Ok(val)
}
#[inline]
fn get_int_4(&mut self) -> io::Result<u32> {
let val: [u8; 4] = (*self)
.try_into()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
self.advance(4);
Ok(u32::from_be_bytes(val))
}
fn get_str(&mut self) -> io::Result<&str> {
let end = memchr(b'\0', &*self).ok_or(io::ErrorKind::InvalidData)?;
let buf = &self[..end];
self.advance(end);
str::from_utf8(buf).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
}
}
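Not part of the diff: a test-style sketch of the Buf impl in use, mirroring how receive() above reads the five-byte message header, plus a NUL-terminated string as found in ParameterStatus bodies. It assumes the trait is in scope (connection.rs imports it as use super::protocol::Buf); the test name is illustrative. Note that get_int_4, as written, converts the entire remaining slice, so it expects exactly four bytes to be left, which holds for the header once get_int_1 has consumed the tag byte.

#[test]
fn buf_sketch() -> std::io::Result<()> {
    // Five-byte header: tag plus a big-endian length that includes itself.
    let mut header: &[u8] = &[b'Z', 0, 0, 0, 5];
    let id = header.get_int_1()?;
    let len = (header.get_int_4()? - 4) as usize;
    assert_eq!((id, len), (b'Z', 1));

    // NUL-terminated string; get_str stops at (and leaves) the terminator.
    let mut body: &[u8] = b"TimeZone\0UTC\0";
    assert_eq!(body.get_str()?, "TimeZone");

    Ok(())
}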

View File

@@ -25,61 +25,3 @@ pub enum Message {
PortalSuspended,
ParameterDescription(Box<ParameterDescription>),
}
impl Message {
// FIXME: `Message::decode` shares the name of the remaining message type `::decode` despite being very
// different
pub fn decode(src: &mut BytesMut) -> Option<Self>
where
Self: Sized,
{
if src.len() < 5 {
// No message is less than 5 bytes
return None;
}
let token = src[0];
if token == 0 {
// FIXME: Handle end-of-stream
panic!("unexpectede end-of-stream");
}
// FIXME: What happens if len(u32) < len(usize) ?
let len = BigEndian::read_u32(&src[1..5]) as usize;
if src.len() >= (len + 1) {
let window = &src[5..=len];
let message = match token {
b'N' | b'E' => Message::Response(Box::new(Response::decode(window))),
b'D' => Message::DataRow(Box::new(DataRow::decode(window))),
b'S' => Message::ParameterStatus(Box::new(ParameterStatus::decode(window))),
b'Z' => Message::ReadyForQuery(ReadyForQuery::decode(window)),
b'R' => Message::Authentication(Box::new(Authentication::decode(window))),
b'K' => Message::BackendKeyData(BackendKeyData::decode(window)),
b'T' => Message::RowDescription(Box::new(RowDescription::decode(window))),
b'C' => Message::CommandComplete(CommandComplete::decode(window)),
b'A' => {
Message::NotificationResponse(Box::new(NotificationResponse::decode(window)))
}
b'1' => Message::ParseComplete,
b'2' => Message::BindComplete,
b'3' => Message::CloseComplete,
b'n' => Message::NoData,
b's' => Message::PortalSuspended,
b't' => {
Message::ParameterDescription(Box::new(ParameterDescription::decode(window)))
}
_ => unimplemented!("decode not implemented for token: {}", token as char),
};
src.advance(len + 1);
Some(message)
} else {
// We don't have enough in the stream yet
None
}
}
}

View File

@@ -59,9 +59,16 @@ mod response;
mod row_description;
pub use self::{
authentication::Authentication, backend_key_data::BackendKeyData,
command_complete::CommandComplete, data_row::DataRow, decode::Decode, message::Message,
notification_response::NotificationResponse, parameter_description::ParameterDescription,
parameter_status::ParameterStatus, ready_for_query::ReadyForQuery, response::Response,
authentication::Authentication,
backend_key_data::BackendKeyData,
command_complete::CommandComplete,
data_row::DataRow,
decode::{Buf, Decode},
message::Message,
notification_response::NotificationResponse,
parameter_description::ParameterDescription,
parameter_status::ParameterStatus,
ready_for_query::ReadyForQuery,
response::Response,
row_description::RowDescription,
};

View File

@@ -18,7 +18,10 @@ pub trait QueryParameters: Send {
T: ToSql<Self::Backend>;
}
pub trait IntoQueryParameters<DB> where DB: Backend {
pub trait IntoQueryParameters<DB>
where
DB: Backend,
{
fn into(self) -> DB::QueryParameters;
}
@@ -26,9 +29,9 @@ pub trait IntoQueryParameters<DB> where DB: Backend {
macro_rules! impl_into_query_parameters {
($( ($idx:tt) -> $T:ident );+;) => {
impl<$($T,)+ DB> IntoQueryParameters<DB> for ($($T,)+)
where
DB: Backend,
impl<$($T,)+ DB> IntoQueryParameters<DB> for ($($T,)+)
where
DB: Backend,
$(DB: crate::types::HasSqlType<$T>,)+
$($T: crate::serialize::ToSql<DB>,)+
{
@@ -41,9 +44,9 @@ macro_rules! impl_into_query_parameters {
};
}
impl<DB> IntoQueryParameters<DB> for ()
where
DB: Backend,
impl<DB> IntoQueryParameters<DB> for ()
where
DB: Backend,
{
fn into(self) -> DB::QueryParameters {
DB::QueryParameters::new()