From e594a7fdca69bd64fe30c423ade1f1424fa609f9 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Mon, 9 Mar 2020 20:37:25 -0700 Subject: [PATCH] Postgres: don't cache failed statement --- sqlx-core/src/postgres/executor.rs | 21 ++++++++++++++++++--- tests/postgres.rs | 19 ++++++++++++++++++- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/sqlx-core/src/postgres/executor.rs b/sqlx-core/src/postgres/executor.rs index 31fabca5..41fc40b2 100644 --- a/sqlx-core/src/postgres/executor.rs +++ b/sqlx-core/src/postgres/executor.rs @@ -56,8 +56,6 @@ impl PgConnection { query, }); - self.cache_statement_id.insert(query.into(), id); - Ok(id) } } @@ -163,6 +161,24 @@ impl PgConnection { self.stream.flush().await?; self.is_ready = false; + // only cache + if let Some(statement) = statement { + // prefer redundant lookup to copying the query string + if !self.cache_statement_id.contains_key(query) { + // wait for `ParseComplete` on the stream or the + // error before we cache the statement + match self.stream.read().await? { + Message::ParseComplete => { + self.cache_statement_id.insert(query.into(), statement); + } + + message => { + return Err(protocol_err!("run: unexpected message: {:?}", message).into()); + } + } + } + } + Ok(statement) } @@ -214,7 +230,6 @@ impl PgConnection { let result_fields = result.map_or_else(Default::default, |r| r.fields); - // TODO: cache this result let type_names = self .get_type_names( params diff --git a/tests/postgres.rs b/tests/postgres.rs index a064570f..96898f3f 100644 --- a/tests/postgres.rs +++ b/tests/postgres.rs @@ -1,6 +1,6 @@ use futures::TryStreamExt; use sqlx::postgres::{PgPool, PgQueryAs, PgRow}; -use sqlx::{Connection, Executor, Postgres, Row}; +use sqlx::{Connection, Cursor, Executor, Postgres, Row}; use sqlx_test::new; use std::time::Duration; @@ -306,6 +306,23 @@ async fn pool_smoke_test() -> anyhow::Result<()> { Ok(()) } +#[cfg_attr(feature = "runtime-async-std", async_std::test)] +#[cfg_attr(feature = "runtime-tokio", tokio::test)] +async fn test_invalid_query() -> anyhow::Result<()> { + let mut conn = new::().await?; + + conn.execute("definitely not a correct query") + .await + .unwrap_err(); + + let mut cursor = conn.fetch("select 1"); + let row = cursor.next().await?.unwrap(); + + assert_eq!(row.get::(0), 1i32); + + Ok(()) +} + #[cfg_attr(feature = "runtime-async-std", async_std::test)] #[cfg_attr(feature = "runtime-tokio", tokio::test)] async fn test_describe() -> anyhow::Result<()> {