From d26f72e1e635d96332de2ae390fa58841bdede79 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Mon, 2 Sep 2019 16:52:54 -0700 Subject: [PATCH] Clean up Postgres Connection and start (finally) bubbling database errors --- Cargo.toml | 1 + src/error.rs | 32 +- src/lib.rs | 2 +- src/postgres/connection.rs | 372 ++++++++++++++++++++++ src/postgres/connection/establish.rs | 91 ------ src/postgres/connection/execute.rs | 31 -- src/postgres/connection/fetch.rs | 37 --- src/postgres/connection/fetch_optional.rs | 36 --- src/postgres/connection/mod.rs | 199 ------------ src/postgres/error.rs | 11 + src/postgres/mod.rs | 67 +++- src/url.rs | 6 +- 12 files changed, 469 insertions(+), 416 deletions(-) create mode 100644 src/postgres/connection.rs delete mode 100644 src/postgres/connection/establish.rs delete mode 100644 src/postgres/connection/execute.rs delete mode 100644 src/postgres/connection/fetch.rs delete mode 100644 src/postgres/connection/fetch_optional.rs delete mode 100644 src/postgres/connection/mod.rs create mode 100644 src/postgres/error.rs diff --git a/Cargo.toml b/Cargo.toml index 207cad74..b5cc6dff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ url = "2.1.0" [dev-dependencies] matches = "0.1.8" +tokio = { version = "=0.2.0-alpha.2", default-features = false, features = [ "rt-full" ] } [profile.release] lto = true diff --git a/src/error.rs b/src/error.rs index 82b6d11e..97914205 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,9 +1,11 @@ use std::{ error::Error as StdError, - fmt::{self, Display}, + fmt::{self, Debug, Display}, io, }; +pub type Result = std::result::Result; + #[derive(Debug)] pub enum Error { /// Error communicating with the database backend. @@ -22,7 +24,7 @@ pub enum Error { Io(io::Error), /// An error was returned by the database backend. - Database(DbError), + Database(Box), /// No rows were returned by a query expected to return at least one row. NotFound, @@ -49,21 +51,19 @@ impl From for Error { } } -// TODO: Define a RawError type for the database backend for forwarding error information +impl From for Error +where + T: 'static + DbError, +{ + #[inline] + fn from(err: T) -> Self { + Error::Database(Box::new(err)) + } +} /// An error that was returned by the database backend. -#[derive(Debug)] -pub struct DbError { - message: String, -} +pub trait DbError: Debug + Send + Sync { + fn message(&self) -> &str; -impl DbError { - pub(crate) fn new(message: String) -> Self { - Self { message } - } - - /// The primary human-readable error message. - pub fn message(&self) -> &str { - &self.message - } + // TODO: Expose more error properties } diff --git a/src/lib.rs b/src/lib.rs index ac7b2bff..c98a8795 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,7 +32,7 @@ pub mod types; pub use self::{ connection::Connection, - error::Error, + error::{Error, Result}, executor::Executor, pool::Pool, sql::{query, SqlQuery}, diff --git a/src/postgres/connection.rs b/src/postgres/connection.rs new file mode 100644 index 00000000..30c43356 --- /dev/null +++ b/src/postgres/connection.rs @@ -0,0 +1,372 @@ +use super::{ + protocol::{self, Decode, Encode, Message, Terminate}, + Postgres, PostgresError, PostgresQueryParameters, PostgresRow, +}; +use crate::{ + connection::RawConnection, + error::Error, + io::{Buf, BufStream}, + query::QueryParameters, + url::Url, +}; +use byteorder::NetworkEndian; +use futures_core::{future::BoxFuture, stream::BoxStream}; +use std::{ + io, + net::{IpAddr, Shutdown, SocketAddr}, + sync::atomic::{AtomicU64, Ordering}, +}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpStream, +}; + +pub struct PostgresRawConnection { + stream: BufStream, + + // Process ID of the Backend + process_id: u32, + + // Backend-unique key to use to send a cancel query message to the server + secret_key: u32, + + // Statement ID counter + next_statement_id: AtomicU64, + + // Portal ID counter + next_portal_id: AtomicU64, +} + +impl PostgresRawConnection { + async fn establish(url: &str) -> crate::Result { + let url = Url::parse(url)?; + let stream = TcpStream::connect(&url.resolve(5432)).await?; + + let mut conn = Self { + stream: BufStream::new(stream), + process_id: 0, + secret_key: 0, + next_statement_id: AtomicU64::new(0), + next_portal_id: AtomicU64::new(0), + }; + + let user = url.username(); + let password = url.password().unwrap_or(""); + let database = url.database(); + + // See this doc for more runtime parameters + // https://www.postgresql.org/docs/12/runtime-config-client.html + let params = &[ + // FIXME: ConnectOptions user and database need to be required parameters and error + // before they get here + ("user", user), + ("database", database), + // Sets the display format for date and time values, + // as well as the rules for interpreting ambiguous date input values. + ("DateStyle", "ISO, MDY"), + // Sets the display format for interval values. + ("IntervalStyle", "iso_8601"), + // Sets the time zone for displaying and interpreting time stamps. + ("TimeZone", "UTC"), + // Adjust postgres to return percise values for floats + // NOTE: This is default in postgres 12+ + ("extra_float_digits", "3"), + // Sets the client-side encoding (character set). + ("client_encoding", "UTF-8"), + ]; + + conn.write(protocol::StartupMessage { params }); + conn.stream.flush().await?; + + while let Some(message) = conn.receive().await? { + match message { + Message::Authentication(auth) => { + match *auth { + protocol::Authentication::Ok => { + // Do nothing. No password is needed to continue. + } + + protocol::Authentication::CleartextPassword => { + conn.write(protocol::PasswordMessage::Cleartext(password)); + + conn.stream.flush().await?; + } + + protocol::Authentication::Md5Password { salt } => { + conn.write(protocol::PasswordMessage::Md5 { + password, + user, + salt, + }); + + conn.stream.flush().await?; + } + + auth => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("requires unimplemented authentication method: {:?}", auth), + ) + .into()); + } + } + } + + Message::BackendKeyData(body) => { + conn.process_id = body.process_id(); + conn.secret_key = body.secret_key(); + } + + Message::ReadyForQuery(_) => { + break; + } + + message => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("received unexpected message: {:?}", message), + ) + .into()); + } + } + } + + Ok(conn) + } + + async fn finalize(&mut self) -> crate::Result<()> { + self.write(Terminate); + self.stream.flush().await?; + self.stream.close().await?; + + Ok(()) + } + + // Wait and return the next message to be received from Postgres. + async fn receive(&mut self) -> crate::Result> { + // Before we start the receive loop + // Flush any pending data from the send buffer + self.stream.flush().await?; + + loop { + // Read the message header (id + len) + let mut header = ret_if_none!(self.stream.peek(5).await?); + + let id = header.get_u8()?; + let len = (header.get_u32::()? - 4) as usize; + + // Read the message body + self.stream.consume(5); + let body = ret_if_none!(self.stream.peek(len).await?); + + let message = match id { + b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body)?)), + b'D' => Message::DataRow(protocol::DataRow::decode(body)?), + b'S' => { + Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body)?)) + } + b'Z' => Message::ReadyForQuery(protocol::ReadyForQuery::decode(body)?), + b'R' => Message::Authentication(Box::new(protocol::Authentication::decode(body)?)), + b'K' => Message::BackendKeyData(protocol::BackendKeyData::decode(body)?), + b'C' => Message::CommandComplete(protocol::CommandComplete::decode(body)?), + b'A' => Message::NotificationResponse(Box::new( + protocol::NotificationResponse::decode(body)?, + )), + b'1' => Message::ParseComplete, + b'2' => Message::BindComplete, + b'3' => Message::CloseComplete, + b'n' => Message::NoData, + b's' => Message::PortalSuspended, + b't' => Message::ParameterDescription(Box::new( + protocol::ParameterDescription::decode(body)?, + )), + + id => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("received unexpected message id: {:?}", id), + ) + .into()); + } + }; + + self.stream.consume(len); + + match message { + Message::ParameterStatus(_body) => { + // TODO: not sure what to do with these yet + } + + Message::Response(body) => { + if body.severity().is_error() { + // This is an error, stop the world and bubble as an error + return Err(PostgresError(body).into()); + } else { + // This is a _warning_ + // TODO: Do we *want* to do anything with these + } + } + + message => { + return Ok(Some(message)); + } + } + } + } + + pub(super) fn write(&mut self, message: impl Encode) { + message.encode(self.stream.buffer_mut()); + } + + fn execute(&mut self, query: &str, params: PostgresQueryParameters, limit: i32) { + self.write(protocol::Parse { + portal: "", + query, + param_types: &*params.types, + }); + + self.write(protocol::Bind { + portal: "", + statement: "", + formats: &[1], // [BINARY] + // TODO: Early error if there is more than i16 + values_len: params.types.len() as i16, + values: &*params.buf, + result_formats: &[1], // [BINARY] + }); + + // TODO: Make limit be 1 for fetch_optional + self.write(protocol::Execute { portal: "", limit }); + + self.write(protocol::Sync); + } + + // Ask for the next Row in the stream + async fn step(&mut self) -> crate::Result> { + while let Some(message) = self.receive().await? { + match message { + Message::BindComplete + | Message::ParseComplete + | Message::PortalSuspended + | Message::CloseComplete => {} + + Message::CommandComplete(body) => { + return Ok(Some(Step::Command(body.affected_rows()))); + } + + Message::DataRow(body) => { + return Ok(Some(Step::Row(PostgresRow(body)))); + } + + Message::ReadyForQuery(_) => { + return Ok(None); + } + + message => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("received unexpected message: {:?}", message), + ) + .into()); + } + } + } + + // Connection was (unexpectedly) closed + Err(io::Error::from(io::ErrorKind::UnexpectedEof).into()) + } + + // TODO: Remove usage of fmt! + fn next_portal_id(&self) -> String { + format!( + "__sqlx_portal_{}", + self.next_portal_id.fetch_add(1, Ordering::AcqRel) + ) + } + + // TODO: Remove usage of fmt! + fn next_statement_id(&self) -> String { + format!( + "__sqlx_statement_{}", + self.next_statement_id.fetch_add(1, Ordering::AcqRel) + ) + } +} + +enum Step { + Command(u64), + Row(PostgresRow), +} + +impl RawConnection for PostgresRawConnection { + type Backend = Postgres; + + #[inline] + fn establish(url: &str) -> BoxFuture> { + Box::pin(Self::establish(url)) + } + + #[inline] + fn finalize<'c>(&'c mut self) -> BoxFuture<'c, crate::Result<()>> { + Box::pin(self.finalize()) + } + + fn execute<'c>( + &'c mut self, + query: &str, + params: PostgresQueryParameters, + ) -> BoxFuture<'c, crate::Result> { + self.execute(query, params, 1); + + Box::pin(async move { + let mut affected = 0; + + while let Some(step) = self.step().await? { + if let Step::Command(cnt) = step { + affected = cnt; + } + } + + Ok(affected) + }) + } + + fn fetch<'c>( + &'c mut self, + query: &str, + params: PostgresQueryParameters, + ) -> BoxStream<'c, crate::Result> { + self.execute(query, params, 0); + + Box::pin(async_stream::try_stream! { + while let Some(step) = self.step().await? { + if let Step::Row(row) = step { + yield row; + } + } + }) + } + + fn fetch_optional<'c>( + &'c mut self, + query: &str, + params: PostgresQueryParameters, + ) -> BoxFuture<'c, crate::Result>> { + self.execute(query, params, 1); + + Box::pin(async move { + let mut row: Option = None; + + while let Some(step) = self.step().await? { + if let Step::Row(r) = step { + // This should only ever execute once because we used the + // protocol-level limit + debug_assert!(row.is_none()); + row = Some(r); + } + } + + Ok(row) + }) + } +} diff --git a/src/postgres/connection/establish.rs b/src/postgres/connection/establish.rs deleted file mode 100644 index 69f4d40c..00000000 --- a/src/postgres/connection/establish.rs +++ /dev/null @@ -1,91 +0,0 @@ -use super::PostgresRawConnection; -use crate::{ - error::Error, - postgres::protocol::{Authentication, Message, PasswordMessage, StartupMessage}, - url::Url, -}; -use std::io; - -pub async fn establish<'a, 'b: 'a>( - conn: &'a mut PostgresRawConnection, - url: &'b Url, -) -> Result<(), Error> { - let user = url.username(); - let password = url.password().unwrap_or(""); - let database = url.database(); - - // See this doc for more runtime parameters - // https://www.postgresql.org/docs/12/runtime-config-client.html - let params = &[ - // FIXME: ConnectOptions user and database need to be required parameters and error - // before they get here - ("user", user), - ("database", database), - // Sets the display format for date and time values, - // as well as the rules for interpreting ambiguous date input values. - ("DateStyle", "ISO, MDY"), - // Sets the display format for interval values. - ("IntervalStyle", "iso_8601"), - // Sets the time zone for displaying and interpreting time stamps. - ("TimeZone", "UTC"), - // Adjust postgres to return percise values for floats - // NOTE: This is default in postgres 12+ - ("extra_float_digits", "3"), - // Sets the client-side encoding (character set). - ("client_encoding", "UTF-8"), - ]; - - let message = StartupMessage { params }; - - conn.write(message); - conn.stream.flush().await?; - - while let Some(message) = conn.receive().await? { - match message { - Message::Authentication(auth) => { - match *auth { - Authentication::Ok => { - // Do nothing. No password is needed to continue. - } - - Authentication::CleartextPassword => { - // FIXME: Should error early (before send) if the user did not supply a password - conn.write(PasswordMessage::Cleartext(password)); - - conn.stream.flush().await?; - } - - Authentication::Md5Password { salt } => { - // FIXME: Should error early (before send) if the user did not supply a password - conn.write(PasswordMessage::Md5 { - password, - user, - salt, - }); - - conn.stream.flush().await?; - } - - auth => { - unimplemented!("received {:?} unimplemented authentication message", auth); - } - } - } - - Message::BackendKeyData(body) => { - conn.process_id = body.process_id(); - conn.secret_key = body.secret_key(); - } - - Message::ReadyForQuery(_) => { - break; - } - - message => { - unimplemented!("received {:?} unimplemented message", message); - } - } - } - - Ok(()) -} diff --git a/src/postgres/connection/execute.rs b/src/postgres/connection/execute.rs deleted file mode 100644 index 2e12e239..00000000 --- a/src/postgres/connection/execute.rs +++ /dev/null @@ -1,31 +0,0 @@ -use super::PostgresRawConnection; -use crate::{error::Error, postgres::protocol::Message}; -use std::io; - -pub async fn execute(conn: &mut PostgresRawConnection) -> Result { - conn.stream.flush().await?; - - let mut rows = 0; - - while let Some(message) = conn.receive().await? { - match message { - Message::BindComplete | Message::ParseComplete | Message::DataRow(_) => {} - - Message::CommandComplete(body) => { - rows = body.affected_rows(); - } - - Message::ReadyForQuery(_) => { - // Successful completion of the whole cycle - return Ok(rows); - } - - message => { - unimplemented!("received {:?} unimplemented message", message); - } - } - } - - // FIXME: This is an end-of-file error. How we should bubble this up here? - unreachable!() -} diff --git a/src/postgres/connection/fetch.rs b/src/postgres/connection/fetch.rs deleted file mode 100644 index 4b37b1b8..00000000 --- a/src/postgres/connection/fetch.rs +++ /dev/null @@ -1,37 +0,0 @@ -use super::{PostgresRawConnection, PostgresRow}; -use crate::{error::Error, postgres::protocol::Message}; -use futures_core::stream::Stream; -use std::io; - -pub fn fetch<'a>( - conn: &'a mut PostgresRawConnection, -) -> impl Stream> + 'a { - async_stream::try_stream! { - conn.stream.flush().await.map_err(Error::from)?; - - while let Some(message) = conn.receive().await? { - match message { - Message::BindComplete - | Message::ParseComplete - | Message::PortalSuspended - | Message::CloseComplete - | Message::CommandComplete(_) => {} - - Message::DataRow(body) => { - yield PostgresRow(body); - } - - Message::ReadyForQuery(_) => { - return; - } - - message => { - unimplemented!("received {:?} unimplemented message", message); - } - } - } - - // FIXME: This is an end-of-file error. How we should bubble this up here? - unreachable!() - } -} diff --git a/src/postgres/connection/fetch_optional.rs b/src/postgres/connection/fetch_optional.rs deleted file mode 100644 index b87f1d55..00000000 --- a/src/postgres/connection/fetch_optional.rs +++ /dev/null @@ -1,36 +0,0 @@ -use super::{PostgresRawConnection, PostgresRow}; -use crate::{error::Error, postgres::protocol::Message}; -use std::io; - -pub async fn fetch_optional<'a>( - conn: &'a mut PostgresRawConnection, -) -> Result, Error> { - conn.stream.flush().await?; - - let mut row: Option = None; - - while let Some(message) = conn.receive().await? { - match message { - Message::BindComplete - | Message::ParseComplete - | Message::PortalSuspended - | Message::CloseComplete - | Message::CommandComplete(_) => {} - - Message::DataRow(body) => { - row = Some(PostgresRow(body)); - } - - Message::ReadyForQuery(_) => { - return Ok(row); - } - - message => { - unimplemented!("received {:?} unimplemented message", message); - } - } - } - - // FIXME: This is an end-of-file error. How we should bubble this up here? - unreachable!() -} diff --git a/src/postgres/connection/mod.rs b/src/postgres/connection/mod.rs deleted file mode 100644 index 4e31f8eb..00000000 --- a/src/postgres/connection/mod.rs +++ /dev/null @@ -1,199 +0,0 @@ -use super::{ - protocol::{self, Decode, Encode, Message, Terminate}, - Postgres, PostgresQueryParameters, PostgresRow, -}; -use crate::{connection::RawConnection, error::Error, io::BufStream, query::QueryParameters}; -// use bytes::{BufMut, BytesMut}; -use crate::{io::Buf, url::Url}; -use byteorder::NetworkEndian; -use futures_core::{future::BoxFuture, stream::BoxStream}; -use std::{ - io, - net::{IpAddr, Shutdown, SocketAddr}, -}; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::TcpStream, -}; - -mod establish; -mod execute; -mod fetch; -mod fetch_optional; - -pub struct PostgresRawConnection { - stream: BufStream, - - // Process ID of the Backend - process_id: u32, - - // Backend-unique key to use to send a cancel query message to the server - secret_key: u32, -} - -impl PostgresRawConnection { - async fn establish(url: &str) -> Result { - let url = Url::parse(url); - let stream = TcpStream::connect(&url.address(5432)).await?; - - let mut conn = Self { - stream: BufStream::new(stream), - process_id: 0, - secret_key: 0, - }; - - establish::establish(&mut conn, &url).await?; - - Ok(conn) - } - - async fn finalize(&mut self) -> Result<(), Error> { - self.write(Terminate); - self.stream.flush().await?; - self.stream.close().await?; - - Ok(()) - } - - // Wait and return the next message to be received from Postgres. - async fn receive(&mut self) -> Result, Error> { - loop { - // Read the message header (id + len) - let mut header = ret_if_none!(self.stream.peek(5).await?); - log::trace!("recv:header {:?}", bytes::Bytes::from(&*header)); - - let id = header.get_u8()?; - let len = (header.get_u32::()? - 4) as usize; - - // Read the message body - self.stream.consume(5); - let body = ret_if_none!(self.stream.peek(len).await?); - log::trace!("recv {:?}", bytes::Bytes::from(&*body)); - - let message = match id { - b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body)?)), - b'D' => Message::DataRow(protocol::DataRow::decode(body)?), - b'S' => { - Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body)?)) - } - b'Z' => Message::ReadyForQuery(protocol::ReadyForQuery::decode(body)?), - b'R' => Message::Authentication(Box::new(protocol::Authentication::decode(body)?)), - b'K' => Message::BackendKeyData(protocol::BackendKeyData::decode(body)?), - b'C' => Message::CommandComplete(protocol::CommandComplete::decode(body)?), - b'A' => Message::NotificationResponse(Box::new( - protocol::NotificationResponse::decode(body)?, - )), - b'1' => Message::ParseComplete, - b'2' => Message::BindComplete, - b'3' => Message::CloseComplete, - b'n' => Message::NoData, - b's' => Message::PortalSuspended, - b't' => Message::ParameterDescription(Box::new( - protocol::ParameterDescription::decode(body)?, - )), - - _ => unimplemented!("unknown message id: {}", id as char), - }; - - self.stream.consume(len); - - match message { - Message::ParameterStatus(_body) => { - // TODO: not sure what to do with these yet - } - - Message::Response(_body) => { - // TODO: Transform Errors+ into an error type and return - // TODO: Log all others - } - - message => { - return Ok(Some(message)); - } - } - } - } - - pub(super) fn write(&mut self, message: impl Encode) { - let pos = self.stream.buffer_mut().len(); - - message.encode(self.stream.buffer_mut()); - - log::trace!( - "send {:?}", - bytes::Bytes::from(&self.stream.buffer_mut()[pos..]) - ); - } -} - -impl RawConnection for PostgresRawConnection { - type Backend = Postgres; - - #[inline] - fn establish(url: &str) -> BoxFuture> { - Box::pin(PostgresRawConnection::establish(url)) - } - - #[inline] - fn finalize<'c>(&'c mut self) -> BoxFuture<'c, Result<(), Error>> { - Box::pin(self.finalize()) - } - - fn execute<'c>( - &'c mut self, - query: &str, - params: PostgresQueryParameters, - ) -> BoxFuture<'c, Result> { - finish(self, query, params, 0); - - Box::pin(execute::execute(self)) - } - - fn fetch<'c>( - &'c mut self, - query: &str, - params: PostgresQueryParameters, - ) -> BoxStream<'c, Result> { - finish(self, query, params, 0); - - Box::pin(fetch::fetch(self)) - } - - fn fetch_optional<'c>( - &'c mut self, - query: &str, - params: PostgresQueryParameters, - ) -> BoxFuture<'c, Result, Error>> { - finish(self, query, params, 1); - - Box::pin(fetch_optional::fetch_optional(self)) - } -} - -fn finish( - conn: &mut PostgresRawConnection, - query: &str, - params: PostgresQueryParameters, - limit: i32, -) { - conn.write(protocol::Parse { - portal: "", - query, - param_types: &*params.types, - }); - - conn.write(protocol::Bind { - portal: "", - statement: "", - formats: &[1], // [BINARY] - // TODO: Early error if there is more than i16 - values_len: params.types.len() as i16, - values: &*params.buf, - result_formats: &[1], // [BINARY] - }); - - // TODO: Make limit be 1 for fetch_optional - conn.write(protocol::Execute { portal: "", limit }); - - conn.write(protocol::Sync); -} diff --git a/src/postgres/error.rs b/src/postgres/error.rs new file mode 100644 index 00000000..4fd098ea --- /dev/null +++ b/src/postgres/error.rs @@ -0,0 +1,11 @@ +use super::protocol::Response; +use crate::error::DbError; + +#[derive(Debug)] +pub struct PostgresError(pub(super) Box); + +impl DbError for PostgresError { + fn message(&self) -> &str { + self.0.message() + } +} diff --git a/src/postgres/mod.rs b/src/postgres/mod.rs index ea5bfcd9..f59f034f 100644 --- a/src/postgres/mod.rs +++ b/src/postgres/mod.rs @@ -1,11 +1,74 @@ mod backend; mod connection; +mod error; mod protocol; mod query; mod row; pub mod types; pub use self::{ - backend::Postgres, connection::PostgresRawConnection, query::PostgresQueryParameters, - row::PostgresRow, + backend::Postgres, connection::PostgresRawConnection, error::PostgresError, + query::PostgresQueryParameters, row::PostgresRow, }; + +#[cfg(test)] +mod tests { + use super::{Postgres, PostgresRawConnection}; + use crate::connection::{Connection, RawConnection}; + use futures_util::TryStreamExt; + + const DATABASE_URL: &str = "postgres://postgres@127.0.0.1:5432/"; + + #[tokio::test] + async fn it_connects() { + let mut conn = PostgresRawConnection::establish(DATABASE_URL) + .await + .unwrap(); + + conn.finalize().await.unwrap(); + } + + #[tokio::test] + async fn it_fails_on_connect_with_an_unknown_user() { + let res = PostgresRawConnection::establish("postgres://not_a_user@127.0.0.1:5432/").await; + + match res { + Err(crate::Error::Database(err)) => { + assert_eq!(err.message(), "role \"not_a_user\" does not exist"); + } + + _ => panic!("unexpected result"), + } + } + + #[tokio::test] + async fn it_fails_on_connect_with_an_unknown_database() { + let res = + PostgresRawConnection::establish("postgres://postgres@127.0.0.1:5432/fdggsdfgsdaf") + .await; + + match res { + Err(crate::Error::Database(err)) => { + assert_eq!(err.message(), "database \"fdggsdfgsdaf\" does not exist"); + } + + _ => panic!("unexpected result"), + } + } + + #[tokio::test] + async fn it_fetches_tuples() { + let conn = Connection::::establish(DATABASE_URL) + .await + .unwrap(); + + let roles: Vec<(String, bool)> = crate::query("SELECT rolname, rolsuper FROM pg_roles") + .fetch(&conn) + .try_collect() + .await + .unwrap(); + + // Sanity check to be sure we did indeed fetch tuples + assert!(roles.binary_search(&("postgres".to_string(), true)).is_ok()); + } +} diff --git a/src/url.rs b/src/url.rs index ddd06f95..348ffaff 100644 --- a/src/url.rs +++ b/src/url.rs @@ -3,9 +3,9 @@ use std::net::{IpAddr, SocketAddr}; pub struct Url(url::Url); impl Url { - pub fn parse(url: &str) -> Self { + pub fn parse(url: &str) -> crate::Result { // TODO: Handle parse errors - Url(url::Url::parse(url).unwrap()) + Ok(Url(url::Url::parse(url).unwrap())) } pub fn host(&self) -> &str { @@ -16,7 +16,7 @@ impl Url { self.0.port().unwrap_or(default) } - pub fn address(&self, default_port: u16) -> SocketAddr { + pub fn resolve(&self, default_port: u16) -> SocketAddr { // TODO: DNS let host: IpAddr = self.host().parse().unwrap(); let port = self.port(default_port);