Clean up Postgres Connection and start (finally) bubbling database errors

This commit is contained in:
Ryan Leckey 2019-09-02 16:52:54 -07:00
parent 1b7e7c2729
commit d26f72e1e6
12 changed files with 469 additions and 416 deletions

View File

@ -33,6 +33,7 @@ url = "2.1.0"
[dev-dependencies]
matches = "0.1.8"
tokio = { version = "=0.2.0-alpha.2", default-features = false, features = [ "rt-full" ] }
[profile.release]
lto = true

View File

@ -1,9 +1,11 @@
use std::{
error::Error as StdError,
fmt::{self, Display},
fmt::{self, Debug, Display},
io,
};
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug)]
pub enum Error {
/// Error communicating with the database backend.
@ -22,7 +24,7 @@ pub enum Error {
Io(io::Error),
/// An error was returned by the database backend.
Database(DbError),
Database(Box<dyn DbError + Send + Sync>),
/// No rows were returned by a query expected to return at least one row.
NotFound,
@ -49,21 +51,19 @@ impl From<io::Error> for Error {
}
}
// TODO: Define a RawError type for the database backend for forwarding error information
impl<T> From<T> for Error
where
T: 'static + DbError,
{
#[inline]
fn from(err: T) -> Self {
Error::Database(Box::new(err))
}
}
/// An error that was returned by the database backend.
#[derive(Debug)]
pub struct DbError {
message: String,
}
pub trait DbError: Debug + Send + Sync {
fn message(&self) -> &str;
impl DbError {
pub(crate) fn new(message: String) -> Self {
Self { message }
}
/// The primary human-readable error message.
pub fn message(&self) -> &str {
&self.message
}
// TODO: Expose more error properties
}

View File

@ -32,7 +32,7 @@ pub mod types;
pub use self::{
connection::Connection,
error::Error,
error::{Error, Result},
executor::Executor,
pool::Pool,
sql::{query, SqlQuery},

372
src/postgres/connection.rs Normal file
View File

@ -0,0 +1,372 @@
use super::{
protocol::{self, Decode, Encode, Message, Terminate},
Postgres, PostgresError, PostgresQueryParameters, PostgresRow,
};
use crate::{
connection::RawConnection,
error::Error,
io::{Buf, BufStream},
query::QueryParameters,
url::Url,
};
use byteorder::NetworkEndian;
use futures_core::{future::BoxFuture, stream::BoxStream};
use std::{
io,
net::{IpAddr, Shutdown, SocketAddr},
sync::atomic::{AtomicU64, Ordering},
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
};
pub struct PostgresRawConnection {
stream: BufStream<TcpStream>,
// Process ID of the Backend
process_id: u32,
// Backend-unique key to use to send a cancel query message to the server
secret_key: u32,
// Statement ID counter
next_statement_id: AtomicU64,
// Portal ID counter
next_portal_id: AtomicU64,
}
impl PostgresRawConnection {
async fn establish(url: &str) -> crate::Result<Self> {
let url = Url::parse(url)?;
let stream = TcpStream::connect(&url.resolve(5432)).await?;
let mut conn = Self {
stream: BufStream::new(stream),
process_id: 0,
secret_key: 0,
next_statement_id: AtomicU64::new(0),
next_portal_id: AtomicU64::new(0),
};
let user = url.username();
let password = url.password().unwrap_or("");
let database = url.database();
// See this doc for more runtime parameters
// https://www.postgresql.org/docs/12/runtime-config-client.html
let params = &[
// FIXME: ConnectOptions user and database need to be required parameters and error
// before they get here
("user", user),
("database", database),
// Sets the display format for date and time values,
// as well as the rules for interpreting ambiguous date input values.
("DateStyle", "ISO, MDY"),
// Sets the display format for interval values.
("IntervalStyle", "iso_8601"),
// Sets the time zone for displaying and interpreting time stamps.
("TimeZone", "UTC"),
// Adjust postgres to return percise values for floats
// NOTE: This is default in postgres 12+
("extra_float_digits", "3"),
// Sets the client-side encoding (character set).
("client_encoding", "UTF-8"),
];
conn.write(protocol::StartupMessage { params });
conn.stream.flush().await?;
while let Some(message) = conn.receive().await? {
match message {
Message::Authentication(auth) => {
match *auth {
protocol::Authentication::Ok => {
// Do nothing. No password is needed to continue.
}
protocol::Authentication::CleartextPassword => {
conn.write(protocol::PasswordMessage::Cleartext(password));
conn.stream.flush().await?;
}
protocol::Authentication::Md5Password { salt } => {
conn.write(protocol::PasswordMessage::Md5 {
password,
user,
salt,
});
conn.stream.flush().await?;
}
auth => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("requires unimplemented authentication method: {:?}", auth),
)
.into());
}
}
}
Message::BackendKeyData(body) => {
conn.process_id = body.process_id();
conn.secret_key = body.secret_key();
}
Message::ReadyForQuery(_) => {
break;
}
message => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("received unexpected message: {:?}", message),
)
.into());
}
}
}
Ok(conn)
}
async fn finalize(&mut self) -> crate::Result<()> {
self.write(Terminate);
self.stream.flush().await?;
self.stream.close().await?;
Ok(())
}
// Wait and return the next message to be received from Postgres.
async fn receive(&mut self) -> crate::Result<Option<Message>> {
// Before we start the receive loop
// Flush any pending data from the send buffer
self.stream.flush().await?;
loop {
// Read the message header (id + len)
let mut header = ret_if_none!(self.stream.peek(5).await?);
let id = header.get_u8()?;
let len = (header.get_u32::<NetworkEndian>()? - 4) as usize;
// Read the message body
self.stream.consume(5);
let body = ret_if_none!(self.stream.peek(len).await?);
let message = match id {
b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body)?)),
b'D' => Message::DataRow(protocol::DataRow::decode(body)?),
b'S' => {
Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body)?))
}
b'Z' => Message::ReadyForQuery(protocol::ReadyForQuery::decode(body)?),
b'R' => Message::Authentication(Box::new(protocol::Authentication::decode(body)?)),
b'K' => Message::BackendKeyData(protocol::BackendKeyData::decode(body)?),
b'C' => Message::CommandComplete(protocol::CommandComplete::decode(body)?),
b'A' => Message::NotificationResponse(Box::new(
protocol::NotificationResponse::decode(body)?,
)),
b'1' => Message::ParseComplete,
b'2' => Message::BindComplete,
b'3' => Message::CloseComplete,
b'n' => Message::NoData,
b's' => Message::PortalSuspended,
b't' => Message::ParameterDescription(Box::new(
protocol::ParameterDescription::decode(body)?,
)),
id => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("received unexpected message id: {:?}", id),
)
.into());
}
};
self.stream.consume(len);
match message {
Message::ParameterStatus(_body) => {
// TODO: not sure what to do with these yet
}
Message::Response(body) => {
if body.severity().is_error() {
// This is an error, stop the world and bubble as an error
return Err(PostgresError(body).into());
} else {
// This is a _warning_
// TODO: Do we *want* to do anything with these
}
}
message => {
return Ok(Some(message));
}
}
}
}
pub(super) fn write(&mut self, message: impl Encode) {
message.encode(self.stream.buffer_mut());
}
fn execute(&mut self, query: &str, params: PostgresQueryParameters, limit: i32) {
self.write(protocol::Parse {
portal: "",
query,
param_types: &*params.types,
});
self.write(protocol::Bind {
portal: "",
statement: "",
formats: &[1], // [BINARY]
// TODO: Early error if there is more than i16
values_len: params.types.len() as i16,
values: &*params.buf,
result_formats: &[1], // [BINARY]
});
// TODO: Make limit be 1 for fetch_optional
self.write(protocol::Execute { portal: "", limit });
self.write(protocol::Sync);
}
// Ask for the next Row in the stream
async fn step(&mut self) -> crate::Result<Option<Step>> {
while let Some(message) = self.receive().await? {
match message {
Message::BindComplete
| Message::ParseComplete
| Message::PortalSuspended
| Message::CloseComplete => {}
Message::CommandComplete(body) => {
return Ok(Some(Step::Command(body.affected_rows())));
}
Message::DataRow(body) => {
return Ok(Some(Step::Row(PostgresRow(body))));
}
Message::ReadyForQuery(_) => {
return Ok(None);
}
message => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("received unexpected message: {:?}", message),
)
.into());
}
}
}
// Connection was (unexpectedly) closed
Err(io::Error::from(io::ErrorKind::UnexpectedEof).into())
}
// TODO: Remove usage of fmt!
fn next_portal_id(&self) -> String {
format!(
"__sqlx_portal_{}",
self.next_portal_id.fetch_add(1, Ordering::AcqRel)
)
}
// TODO: Remove usage of fmt!
fn next_statement_id(&self) -> String {
format!(
"__sqlx_statement_{}",
self.next_statement_id.fetch_add(1, Ordering::AcqRel)
)
}
}
enum Step {
Command(u64),
Row(PostgresRow),
}
impl RawConnection for PostgresRawConnection {
type Backend = Postgres;
#[inline]
fn establish(url: &str) -> BoxFuture<crate::Result<Self>> {
Box::pin(Self::establish(url))
}
#[inline]
fn finalize<'c>(&'c mut self) -> BoxFuture<'c, crate::Result<()>> {
Box::pin(self.finalize())
}
fn execute<'c>(
&'c mut self,
query: &str,
params: PostgresQueryParameters,
) -> BoxFuture<'c, crate::Result<u64>> {
self.execute(query, params, 1);
Box::pin(async move {
let mut affected = 0;
while let Some(step) = self.step().await? {
if let Step::Command(cnt) = step {
affected = cnt;
}
}
Ok(affected)
})
}
fn fetch<'c>(
&'c mut self,
query: &str,
params: PostgresQueryParameters,
) -> BoxStream<'c, crate::Result<PostgresRow>> {
self.execute(query, params, 0);
Box::pin(async_stream::try_stream! {
while let Some(step) = self.step().await? {
if let Step::Row(row) = step {
yield row;
}
}
})
}
fn fetch_optional<'c>(
&'c mut self,
query: &str,
params: PostgresQueryParameters,
) -> BoxFuture<'c, crate::Result<Option<PostgresRow>>> {
self.execute(query, params, 1);
Box::pin(async move {
let mut row: Option<PostgresRow> = None;
while let Some(step) = self.step().await? {
if let Step::Row(r) = step {
// This should only ever execute once because we used the
// protocol-level limit
debug_assert!(row.is_none());
row = Some(r);
}
}
Ok(row)
})
}
}

View File

@ -1,91 +0,0 @@
use super::PostgresRawConnection;
use crate::{
error::Error,
postgres::protocol::{Authentication, Message, PasswordMessage, StartupMessage},
url::Url,
};
use std::io;
pub async fn establish<'a, 'b: 'a>(
conn: &'a mut PostgresRawConnection,
url: &'b Url,
) -> Result<(), Error> {
let user = url.username();
let password = url.password().unwrap_or("");
let database = url.database();
// See this doc for more runtime parameters
// https://www.postgresql.org/docs/12/runtime-config-client.html
let params = &[
// FIXME: ConnectOptions user and database need to be required parameters and error
// before they get here
("user", user),
("database", database),
// Sets the display format for date and time values,
// as well as the rules for interpreting ambiguous date input values.
("DateStyle", "ISO, MDY"),
// Sets the display format for interval values.
("IntervalStyle", "iso_8601"),
// Sets the time zone for displaying and interpreting time stamps.
("TimeZone", "UTC"),
// Adjust postgres to return percise values for floats
// NOTE: This is default in postgres 12+
("extra_float_digits", "3"),
// Sets the client-side encoding (character set).
("client_encoding", "UTF-8"),
];
let message = StartupMessage { params };
conn.write(message);
conn.stream.flush().await?;
while let Some(message) = conn.receive().await? {
match message {
Message::Authentication(auth) => {
match *auth {
Authentication::Ok => {
// Do nothing. No password is needed to continue.
}
Authentication::CleartextPassword => {
// FIXME: Should error early (before send) if the user did not supply a password
conn.write(PasswordMessage::Cleartext(password));
conn.stream.flush().await?;
}
Authentication::Md5Password { salt } => {
// FIXME: Should error early (before send) if the user did not supply a password
conn.write(PasswordMessage::Md5 {
password,
user,
salt,
});
conn.stream.flush().await?;
}
auth => {
unimplemented!("received {:?} unimplemented authentication message", auth);
}
}
}
Message::BackendKeyData(body) => {
conn.process_id = body.process_id();
conn.secret_key = body.secret_key();
}
Message::ReadyForQuery(_) => {
break;
}
message => {
unimplemented!("received {:?} unimplemented message", message);
}
}
}
Ok(())
}

View File

@ -1,31 +0,0 @@
use super::PostgresRawConnection;
use crate::{error::Error, postgres::protocol::Message};
use std::io;
pub async fn execute(conn: &mut PostgresRawConnection) -> Result<u64, Error> {
conn.stream.flush().await?;
let mut rows = 0;
while let Some(message) = conn.receive().await? {
match message {
Message::BindComplete | Message::ParseComplete | Message::DataRow(_) => {}
Message::CommandComplete(body) => {
rows = body.affected_rows();
}
Message::ReadyForQuery(_) => {
// Successful completion of the whole cycle
return Ok(rows);
}
message => {
unimplemented!("received {:?} unimplemented message", message);
}
}
}
// FIXME: This is an end-of-file error. How we should bubble this up here?
unreachable!()
}

View File

@ -1,37 +0,0 @@
use super::{PostgresRawConnection, PostgresRow};
use crate::{error::Error, postgres::protocol::Message};
use futures_core::stream::Stream;
use std::io;
pub fn fetch<'a>(
conn: &'a mut PostgresRawConnection,
) -> impl Stream<Item = Result<PostgresRow, Error>> + 'a {
async_stream::try_stream! {
conn.stream.flush().await.map_err(Error::from)?;
while let Some(message) = conn.receive().await? {
match message {
Message::BindComplete
| Message::ParseComplete
| Message::PortalSuspended
| Message::CloseComplete
| Message::CommandComplete(_) => {}
Message::DataRow(body) => {
yield PostgresRow(body);
}
Message::ReadyForQuery(_) => {
return;
}
message => {
unimplemented!("received {:?} unimplemented message", message);
}
}
}
// FIXME: This is an end-of-file error. How we should bubble this up here?
unreachable!()
}
}

View File

@ -1,36 +0,0 @@
use super::{PostgresRawConnection, PostgresRow};
use crate::{error::Error, postgres::protocol::Message};
use std::io;
pub async fn fetch_optional<'a>(
conn: &'a mut PostgresRawConnection,
) -> Result<Option<PostgresRow>, Error> {
conn.stream.flush().await?;
let mut row: Option<PostgresRow> = None;
while let Some(message) = conn.receive().await? {
match message {
Message::BindComplete
| Message::ParseComplete
| Message::PortalSuspended
| Message::CloseComplete
| Message::CommandComplete(_) => {}
Message::DataRow(body) => {
row = Some(PostgresRow(body));
}
Message::ReadyForQuery(_) => {
return Ok(row);
}
message => {
unimplemented!("received {:?} unimplemented message", message);
}
}
}
// FIXME: This is an end-of-file error. How we should bubble this up here?
unreachable!()
}

View File

@ -1,199 +0,0 @@
use super::{
protocol::{self, Decode, Encode, Message, Terminate},
Postgres, PostgresQueryParameters, PostgresRow,
};
use crate::{connection::RawConnection, error::Error, io::BufStream, query::QueryParameters};
// use bytes::{BufMut, BytesMut};
use crate::{io::Buf, url::Url};
use byteorder::NetworkEndian;
use futures_core::{future::BoxFuture, stream::BoxStream};
use std::{
io,
net::{IpAddr, Shutdown, SocketAddr},
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
};
mod establish;
mod execute;
mod fetch;
mod fetch_optional;
pub struct PostgresRawConnection {
stream: BufStream<TcpStream>,
// Process ID of the Backend
process_id: u32,
// Backend-unique key to use to send a cancel query message to the server
secret_key: u32,
}
impl PostgresRawConnection {
async fn establish(url: &str) -> Result<Self, Error> {
let url = Url::parse(url);
let stream = TcpStream::connect(&url.address(5432)).await?;
let mut conn = Self {
stream: BufStream::new(stream),
process_id: 0,
secret_key: 0,
};
establish::establish(&mut conn, &url).await?;
Ok(conn)
}
async fn finalize(&mut self) -> Result<(), Error> {
self.write(Terminate);
self.stream.flush().await?;
self.stream.close().await?;
Ok(())
}
// Wait and return the next message to be received from Postgres.
async fn receive(&mut self) -> Result<Option<Message>, Error> {
loop {
// Read the message header (id + len)
let mut header = ret_if_none!(self.stream.peek(5).await?);
log::trace!("recv:header {:?}", bytes::Bytes::from(&*header));
let id = header.get_u8()?;
let len = (header.get_u32::<NetworkEndian>()? - 4) as usize;
// Read the message body
self.stream.consume(5);
let body = ret_if_none!(self.stream.peek(len).await?);
log::trace!("recv {:?}", bytes::Bytes::from(&*body));
let message = match id {
b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body)?)),
b'D' => Message::DataRow(protocol::DataRow::decode(body)?),
b'S' => {
Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body)?))
}
b'Z' => Message::ReadyForQuery(protocol::ReadyForQuery::decode(body)?),
b'R' => Message::Authentication(Box::new(protocol::Authentication::decode(body)?)),
b'K' => Message::BackendKeyData(protocol::BackendKeyData::decode(body)?),
b'C' => Message::CommandComplete(protocol::CommandComplete::decode(body)?),
b'A' => Message::NotificationResponse(Box::new(
protocol::NotificationResponse::decode(body)?,
)),
b'1' => Message::ParseComplete,
b'2' => Message::BindComplete,
b'3' => Message::CloseComplete,
b'n' => Message::NoData,
b's' => Message::PortalSuspended,
b't' => Message::ParameterDescription(Box::new(
protocol::ParameterDescription::decode(body)?,
)),
_ => unimplemented!("unknown message id: {}", id as char),
};
self.stream.consume(len);
match message {
Message::ParameterStatus(_body) => {
// TODO: not sure what to do with these yet
}
Message::Response(_body) => {
// TODO: Transform Errors+ into an error type and return
// TODO: Log all others
}
message => {
return Ok(Some(message));
}
}
}
}
pub(super) fn write(&mut self, message: impl Encode) {
let pos = self.stream.buffer_mut().len();
message.encode(self.stream.buffer_mut());
log::trace!(
"send {:?}",
bytes::Bytes::from(&self.stream.buffer_mut()[pos..])
);
}
}
impl RawConnection for PostgresRawConnection {
type Backend = Postgres;
#[inline]
fn establish(url: &str) -> BoxFuture<Result<Self, Error>> {
Box::pin(PostgresRawConnection::establish(url))
}
#[inline]
fn finalize<'c>(&'c mut self) -> BoxFuture<'c, Result<(), Error>> {
Box::pin(self.finalize())
}
fn execute<'c>(
&'c mut self,
query: &str,
params: PostgresQueryParameters,
) -> BoxFuture<'c, Result<u64, Error>> {
finish(self, query, params, 0);
Box::pin(execute::execute(self))
}
fn fetch<'c>(
&'c mut self,
query: &str,
params: PostgresQueryParameters,
) -> BoxStream<'c, Result<PostgresRow, Error>> {
finish(self, query, params, 0);
Box::pin(fetch::fetch(self))
}
fn fetch_optional<'c>(
&'c mut self,
query: &str,
params: PostgresQueryParameters,
) -> BoxFuture<'c, Result<Option<PostgresRow>, Error>> {
finish(self, query, params, 1);
Box::pin(fetch_optional::fetch_optional(self))
}
}
fn finish(
conn: &mut PostgresRawConnection,
query: &str,
params: PostgresQueryParameters,
limit: i32,
) {
conn.write(protocol::Parse {
portal: "",
query,
param_types: &*params.types,
});
conn.write(protocol::Bind {
portal: "",
statement: "",
formats: &[1], // [BINARY]
// TODO: Early error if there is more than i16
values_len: params.types.len() as i16,
values: &*params.buf,
result_formats: &[1], // [BINARY]
});
// TODO: Make limit be 1 for fetch_optional
conn.write(protocol::Execute { portal: "", limit });
conn.write(protocol::Sync);
}

11
src/postgres/error.rs Normal file
View File

@ -0,0 +1,11 @@
use super::protocol::Response;
use crate::error::DbError;
#[derive(Debug)]
pub struct PostgresError(pub(super) Box<Response>);
impl DbError for PostgresError {
fn message(&self) -> &str {
self.0.message()
}
}

View File

@ -1,11 +1,74 @@
mod backend;
mod connection;
mod error;
mod protocol;
mod query;
mod row;
pub mod types;
pub use self::{
backend::Postgres, connection::PostgresRawConnection, query::PostgresQueryParameters,
row::PostgresRow,
backend::Postgres, connection::PostgresRawConnection, error::PostgresError,
query::PostgresQueryParameters, row::PostgresRow,
};
#[cfg(test)]
mod tests {
use super::{Postgres, PostgresRawConnection};
use crate::connection::{Connection, RawConnection};
use futures_util::TryStreamExt;
const DATABASE_URL: &str = "postgres://postgres@127.0.0.1:5432/";
#[tokio::test]
async fn it_connects() {
let mut conn = PostgresRawConnection::establish(DATABASE_URL)
.await
.unwrap();
conn.finalize().await.unwrap();
}
#[tokio::test]
async fn it_fails_on_connect_with_an_unknown_user() {
let res = PostgresRawConnection::establish("postgres://not_a_user@127.0.0.1:5432/").await;
match res {
Err(crate::Error::Database(err)) => {
assert_eq!(err.message(), "role \"not_a_user\" does not exist");
}
_ => panic!("unexpected result"),
}
}
#[tokio::test]
async fn it_fails_on_connect_with_an_unknown_database() {
let res =
PostgresRawConnection::establish("postgres://postgres@127.0.0.1:5432/fdggsdfgsdaf")
.await;
match res {
Err(crate::Error::Database(err)) => {
assert_eq!(err.message(), "database \"fdggsdfgsdaf\" does not exist");
}
_ => panic!("unexpected result"),
}
}
#[tokio::test]
async fn it_fetches_tuples() {
let conn = Connection::<Postgres>::establish(DATABASE_URL)
.await
.unwrap();
let roles: Vec<(String, bool)> = crate::query("SELECT rolname, rolsuper FROM pg_roles")
.fetch(&conn)
.try_collect()
.await
.unwrap();
// Sanity check to be sure we did indeed fetch tuples
assert!(roles.binary_search(&("postgres".to_string(), true)).is_ok());
}
}

View File

@ -3,9 +3,9 @@ use std::net::{IpAddr, SocketAddr};
pub struct Url(url::Url);
impl Url {
pub fn parse(url: &str) -> Self {
pub fn parse(url: &str) -> crate::Result<Self> {
// TODO: Handle parse errors
Url(url::Url::parse(url).unwrap())
Ok(Url(url::Url::parse(url).unwrap()))
}
pub fn host(&self) -> &str {
@ -16,7 +16,7 @@ impl Url {
self.0.port().unwrap_or(default)
}
pub fn address(&self, default_port: u16) -> SocketAddr {
pub fn resolve(&self, default_port: u16) -> SocketAddr {
// TODO: DNS
let host: IpAddr = self.host().parse().unwrap();
let port = self.port(default_port);