From 1a7480774b2526ab01e97f44fc07f3c391be287d Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Thu, 2 Jul 2020 22:37:04 -0700 Subject: [PATCH] fix(postgres): after closing a statement, the connection should await CloseComplete --- sqlx-core/src/postgres/connection/executor.rs | 31 ++++++++++++++++++ sqlx-core/src/postgres/connection/mod.rs | 10 ++++-- sqlx-core/src/postgres/message/mod.rs | 2 ++ sqlx-test/src/lib.rs | 2 +- tests/postgres/postgres.rs | 32 +++++++++++++++++-- 5 files changed, 71 insertions(+), 6 deletions(-) diff --git a/sqlx-core/src/postgres/connection/executor.rs b/sqlx-core/src/postgres/connection/executor.rs index e18a9c72..5cb7eeed 100644 --- a/sqlx-core/src/postgres/connection/executor.rs +++ b/sqlx-core/src/postgres/connection/executor.rs @@ -90,6 +90,34 @@ async fn recv_desc_rows(conn: &mut PgConnection) -> Result Result<(), Error> { + // we need to wait for the [CloseComplete] to be returned from the server + while count > 0 { + match self.stream.recv().await? { + message if message.format == MessageFormat::PortalSuspended => { + // there was an open portal + // this can happen if the last time a statement was used it was not fully executed + // such as in [fetch_one] + } + + message if message.format == MessageFormat::CloseComplete => { + // successfully closed the statement (and freed up the server resources) + count -= 1; + } + + message => { + return Err(err_protocol!( + "expecting PortalSuspended or CloseComplete but received {:?}", + message.format + )); + } + } + } + + Ok(()) + } + async fn prepare(&mut self, query: &str, arguments: &PgArguments) -> Result { if let Some(statement) = self.cache_statement.get_mut(query) { return Ok(*statement); @@ -100,7 +128,10 @@ impl PgConnection { if let Some(statement) = self.cache_statement.insert(query, statement) { self.stream.write(Close::Statement(statement)); self.stream.write(Flush); + self.stream.flush().await?; + + self.wait_for_close_complete(1).await?; } Ok(statement) diff --git a/sqlx-core/src/postgres/connection/mod.rs b/sqlx-core/src/postgres/connection/mod.rs index 885ba92f..d4f6ea6a 100644 --- a/sqlx-core/src/postgres/connection/mod.rs +++ b/sqlx-core/src/postgres/connection/mod.rs @@ -126,16 +126,20 @@ impl Connection for PgConnection { fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { - let mut needs_flush = false; + let mut cleared = 0_usize; + + self.wait_until_ready().await?; while let Some(statement) = self.cache_statement.remove_lru() { self.stream.write(Close::Statement(statement)); - needs_flush = true; + cleared += 1; } - if needs_flush { + if cleared > 0 { self.stream.write(Flush); self.stream.flush().await?; + + self.wait_for_close_complete(cleared).await?; } Ok(()) diff --git a/sqlx-core/src/postgres/message/mod.rs b/sqlx-core/src/postgres/message/mod.rs index 87f11feb..6c8d1f30 100644 --- a/sqlx-core/src/postgres/message/mod.rs +++ b/sqlx-core/src/postgres/message/mod.rs @@ -55,6 +55,7 @@ pub enum MessageFormat { Authentication, BackendKeyData, BindComplete, + CloseComplete, CommandComplete, DataRow, EmptyQueryResponse, @@ -93,6 +94,7 @@ impl MessageFormat { Ok(match v { b'1' => MessageFormat::ParseComplete, b'2' => MessageFormat::BindComplete, + b'3' => MessageFormat::CloseComplete, b'C' => MessageFormat::CommandComplete, b'D' => MessageFormat::DataRow, b'E' => MessageFormat::ErrorResponse, diff --git a/sqlx-test/src/lib.rs b/sqlx-test/src/lib.rs index e54204c5..918e95eb 100644 --- a/sqlx-test/src/lib.rs +++ b/sqlx-test/src/lib.rs @@ -1,7 +1,7 @@ use sqlx::{database::Database, Connect, Pool}; use std::env; -fn setup_if_needed() { +pub fn setup_if_needed() { let _ = dotenv::dotenv(); let _ = env_logger::builder().is_test(true).try_init(); } diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index eb83f490..f4564b9b 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -1,7 +1,8 @@ use futures::TryStreamExt; use sqlx::postgres::PgRow; -use sqlx::postgres::{PgDatabaseError, PgErrorPosition, PgSeverity}; -use sqlx::{postgres::Postgres, Connection, Executor, PgPool, Row}; +use std::env; +use sqlx::postgres::{PgConnection, PgConnectOptions, PgDatabaseError, PgErrorPosition, PgSeverity}; +use sqlx::{postgres::Postgres, Connect, Connection, Executor, PgPool, Row}; use sqlx_test::new; use std::time::Duration; @@ -513,3 +514,30 @@ async fn it_caches_statements() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_closes_statement_from_cache_issue_470() -> anyhow::Result<()> { + sqlx_test::setup_if_needed(); + + let mut options: PgConnectOptions = env::var("DATABASE_URL")?.parse().unwrap(); + + // a capacity of 1 means that before each statement (after the first) + // we will close the previous statement + options = options.statement_cache_capacity(1); + + let mut conn = PgConnection::connect_with(&options).await?; + + for i in 0..5 { + let row = sqlx::query(&*format!("SELECT {}::int4 AS val", i)) + .fetch_one(&mut conn) + .await?; + + let val: i32 = row.get("val"); + + assert_eq!(i, val); + } + + assert_eq!(1, conn.cached_statements_size()); + + Ok(()) +}