diff --git a/src/executor.rs b/src/executor.rs index 177bb57f..0d9af406 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -29,7 +29,7 @@ pub trait Executor: Send { ) -> BoxFuture<'c, Result, Error>> where A: IntoQueryParameters + Send, - T: FromSqlRow + Send + Unpin + T: FromSqlRow + Send + Unpin, { Box::pin(self.fetch(query, params).try_collect()) } @@ -50,12 +50,10 @@ pub trait Executor: Send { ) -> BoxFuture<'c, Result> where A: IntoQueryParameters + Send, - T: FromSqlRow + Send + T: FromSqlRow + Send, { let fut = self.fetch_optional(query, params); - Box::pin(async move { - fut.await?.ok_or(Error::NotFound) - }) + Box::pin(async move { fut.await?.ok_or(Error::NotFound) }) } } diff --git a/src/postgres/connection.rs b/src/postgres/connection.rs index 5abb8e3d..2738eba6 100644 --- a/src/postgres/connection.rs +++ b/src/postgres/connection.rs @@ -229,13 +229,7 @@ impl PostgresRawConnection { async fn step(&mut self) -> crate::Result> { while let Some(message) = self.receive().await? { match message { - Message::BindComplete - | Message::ParseComplete - | Message::CloseComplete => {} - - Message::PortalSuspended => { - return Ok(Some(Step::MoreRowsAvailable)); - } + Message::BindComplete | Message::ParseComplete | Message::PortalSuspended | Message::CloseComplete => {} Message::CommandComplete(body) => { return Ok(Some(Step::Command(body.affected_rows()))); @@ -267,7 +261,6 @@ impl PostgresRawConnection { enum Step { Command(u64), Row(PostgresRow), - MoreRowsAvailable, } impl RawConnection for PostgresRawConnection { @@ -324,27 +317,18 @@ impl RawConnection for PostgresRawConnection { query: &str, params: PostgresQueryParameters, ) -> BoxFuture<'c, crate::Result>> { - self.execute(query, params, 1); + self.execute(query, params, 2); Box::pin(async move { let mut row: Option = None; while let Some(step) = self.step().await? { - match step { - Step::Row(r) => { - // This should only ever execute once because we used the - // protocol-level limit - debug_assert!(row.is_none()); - row = Some(r); - } - - Step::MoreRowsAvailable => { - // Command execution finished but there was more than - // one row available + if let Step::Row(r) = step { + if row.is_some() { return Err(crate::Error::FoundMoreThanOne); } - _ => {} + row = Some(r); } } diff --git a/src/postgres/mod.rs b/src/postgres/mod.rs index 40e13925..4b1d8f20 100644 --- a/src/postgres/mod.rs +++ b/src/postgres/mod.rs @@ -64,7 +64,7 @@ mod tests { } #[tokio::test] - async fn it_fetches_tuples_from_a_system_table() { + async fn it_fetches_tuples_from_system_roles() { let conn = Connection::::establish(DATABASE_URL) .await .unwrap(); @@ -78,4 +78,52 @@ mod tests { // Sanity check to be sure we did indeed fetch tuples assert!(roles.binary_search(&("postgres".to_string(), true)).is_ok()); } + + #[tokio::test] + async fn it_fetches_nothing_for_no_rows_from_system_roles() { + let conn = Connection::::establish(DATABASE_URL) + .await + .unwrap(); + + let res: Option<(String, bool)> = crate::query("SELECT rolname, rolsuper FROM pg_roles WHERE rolname = 'not-a-user'") + .fetch_optional(&conn) + .await + .unwrap(); + + assert!(res.is_none()); + + let res: crate::Result<(String, bool)> = crate::query("SELECT rolname, rolsuper FROM pg_roles WHERE rolname = 'not-a-user'") + .fetch_one(&conn) + .await; + + matches::assert_matches!(res, Err(crate::Error::NotFound)); + } + + #[tokio::test] + async fn it_errors_on_fetching_more_than_one_row_from_system_roles() { + let conn = Connection::::establish(DATABASE_URL) + .await + .unwrap(); + + let res: crate::Result<(String, bool)> = crate::query("SELECT rolname, rolsuper FROM pg_roles") + .fetch_one(&conn) + .await; + + matches::assert_matches!(res, Err(crate::Error::FoundMoreThanOne)); + } + + #[tokio::test] + async fn it_fetches_one_row_from_system_roles() { + let conn = Connection::::establish(DATABASE_URL) + .await + .unwrap(); + + let res: (String, bool) = crate::query("SELECT rolname, rolsuper FROM pg_roles WHERE rolname = 'postgres'") + .fetch_one(&conn) + .await + .unwrap(); + + assert_eq!(res.0, "postgres"); + assert!(res.1); + } }