From 0c9bea4ab21c2de2af9de519b3548386e85ad013 Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Fri, 10 Jul 2020 11:11:43 +0200 Subject: [PATCH] Fixing panics for disabled statement cache --- Cargo.lock | 1 + Cargo.toml | 1 + sqlx-core/src/common/statement_cache.rs | 5 +++ sqlx-core/src/mysql/connection/executor.rs | 30 ++++++++++------- sqlx-core/src/mysql/statement.rs | 2 +- sqlx-core/src/postgres/connection/executor.rs | 33 +++++++++++-------- sqlx-core/src/postgres/statement.rs | 2 +- tests/mysql/mysql.rs | 23 ++++++++++++- tests/postgres/postgres.rs | 28 +++++++++++++--- 9 files changed, 93 insertions(+), 32 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 36847e7d..9204d0c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2429,6 +2429,7 @@ dependencies = [ "time 0.2.16", "tokio 0.2.21", "trybuild", + "url 2.1.1", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index e5706e39..390f6910 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -92,6 +92,7 @@ sqlx-test = { path = "./sqlx-test" } paste = "0.1.16" serde = { version = "1.0.111", features = [ "derive" ] } serde_json = "1.0.53" +url = "2.1.1" # # Any diff --git a/sqlx-core/src/common/statement_cache.rs b/sqlx-core/src/common/statement_cache.rs index dc2f324e..d5695a7c 100644 --- a/sqlx-core/src/common/statement_cache.rs +++ b/sqlx-core/src/common/statement_cache.rs @@ -63,4 +63,9 @@ impl StatementCache { pub fn capacity(&self) -> usize { self.inner.capacity() } + + /// Returns true if the cache capacity is more than 0. + pub fn is_enabled(&self) -> bool { + self.capacity() > 0 + } } diff --git a/sqlx-core/src/mysql/connection/executor.rs b/sqlx-core/src/mysql/connection/executor.rs index 8365750d..306794a9 100644 --- a/sqlx-core/src/mysql/connection/executor.rs +++ b/sqlx-core/src/mysql/connection/executor.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{borrow::Cow, sync::Arc}; use bytes::Bytes; use either::Either; @@ -26,9 +26,10 @@ use crate::mysql::{ use crate::statement::StatementInfo; impl MySqlConnection { - async fn prepare(&mut self, query: &str) -> Result<&mut MySqlStatement, Error> { + async fn prepare<'a>(&'a mut self, query: &str) -> Result, Error> { if self.cache_statement.contains_key(query) { - return Ok(self.cache_statement.get_mut(query).unwrap()); + let stmt = self.cache_statement.get_mut(query).unwrap(); + return Ok(Cow::Borrowed(&*stmt)); } // https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html @@ -80,16 +81,21 @@ impl MySqlConnection { nullable, }; - // in case of the cache being full, close the least recently used statement - if let Some(statement) = self.cache_statement.insert(query, statement) { - self.stream - .send_packet(StmtClose { - statement: statement.id, - }) - .await?; - } + if self.cache_statement.is_enabled() { + // in case of the cache being full, close the least recently used statement + if let Some(statement) = self.cache_statement.insert(query, statement) { + self.stream + .send_packet(StmtClose { + statement: statement.id, + }) + .await?; + } - Ok(self.cache_statement.get_mut(query).unwrap()) + let stmt = self.cache_statement.get_mut(query).unwrap(); + Ok(Cow::Borrowed(&*stmt)) + } else { + Ok(Cow::Owned(statement)) + } } async fn recv_result_metadata(&mut self, mut packet: Packet) -> Result<(), Error> { diff --git a/sqlx-core/src/mysql/statement.rs b/sqlx-core/src/mysql/statement.rs index 772a78e4..6ab41400 100644 --- a/sqlx-core/src/mysql/statement.rs +++ b/sqlx-core/src/mysql/statement.rs @@ -1,6 +1,6 @@ use super::MySqlColumn; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct MySqlStatement { pub(crate) id: u32, pub(crate) columns: Vec, diff --git a/sqlx-core/src/postgres/connection/executor.rs b/sqlx-core/src/postgres/connection/executor.rs index 939e8352..39809e75 100644 --- a/sqlx-core/src/postgres/connection/executor.rs +++ b/sqlx-core/src/postgres/connection/executor.rs @@ -3,7 +3,7 @@ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_core::Stream; use futures_util::{pin_mut, TryStreamExt}; -use std::sync::Arc; +use std::{borrow::Cow, sync::Arc}; use crate::error::Error; use crate::executor::{Execute, Executor}; @@ -155,30 +155,37 @@ impl PgConnection { self.pending_ready_for_query_count += 1; } - async fn prepare( - &mut self, + async fn prepare<'a>( + &'a mut self, query: &str, arguments: &PgArguments, - ) -> Result<&mut PgStatement, Error> { + ) -> Result, Error> { let contains = self.cache_statement.contains_key(query); if contains { - return Ok(self.cache_statement.get_mut(query).unwrap()); + return Ok(Cow::Borrowed( + &*self.cache_statement.get_mut(query).unwrap(), + )); } let statement = prepare(self, query, arguments).await?; - if let Some(statement) = self.cache_statement.insert(query, statement) { - self.stream.write(Close::Statement(statement.id)); - self.stream.write(Flush); + if self.cache_statement.is_enabled() { + if let Some(statement) = self.cache_statement.insert(query, statement) { + self.stream.write(Close::Statement(statement.id)); + self.stream.write(Flush); - self.stream.flush().await?; + self.stream.flush().await?; - self.wait_for_close_complete(1).await?; - self.recv_ready_for_query().await?; + self.wait_for_close_complete(1).await?; + } + + Ok(Cow::Borrowed( + &*self.cache_statement.get_mut(query).unwrap(), + )) + } else { + Ok(Cow::Owned(statement)) } - - Ok(self.cache_statement.get_mut(query).unwrap()) } async fn run( diff --git a/sqlx-core/src/postgres/statement.rs b/sqlx-core/src/postgres/statement.rs index 10ee0970..c9234c12 100644 --- a/sqlx-core/src/postgres/statement.rs +++ b/sqlx-core/src/postgres/statement.rs @@ -1,6 +1,6 @@ use super::{PgColumn, PgTypeInfo}; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct PgStatement { pub(crate) id: u32, pub(crate) columns: Vec, diff --git a/tests/mysql/mysql.rs b/tests/mysql/mysql.rs index bbd414c2..2951d1c2 100644 --- a/tests/mysql/mysql.rs +++ b/tests/mysql/mysql.rs @@ -1,7 +1,8 @@ use futures::TryStreamExt; use sqlx::mysql::{MySql, MySqlPool, MySqlPoolOptions, MySqlRow}; use sqlx::{Connection, Done, Executor, Row}; -use sqlx_test::new; +use sqlx_test::{new, setup_if_needed}; +use std::env; #[sqlx_macros::test] async fn it_connects() -> anyhow::Result<()> { @@ -97,6 +98,26 @@ async fn it_executes_with_pool() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn it_works_with_cache_disabled() -> anyhow::Result<()> { + setup_if_needed(); + + let mut url = url::Url::parse(&env::var("DATABASE_URL")?)?; + url.query_pairs_mut() + .append_pair("statement-cache-capacity", "0"); + + let mut conn = MySqlConnection::connect(url.as_ref()).await?; + + for index in 1..=10_i32 { + let _ = sqlx::query("SELECT ?") + .bind(index) + .execute(&mut conn) + .await?; + } + + Ok(()) +} + #[sqlx_macros::test] async fn it_drops_results_in_affected_rows() -> anyhow::Result<()> { let mut conn = new::().await?; diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 958d6c51..ba045839 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -2,9 +2,9 @@ use futures::TryStreamExt; use sqlx::postgres::{ PgConnectOptions, PgConnection, PgDatabaseError, PgErrorPosition, PgSeverity, }; -use sqlx::postgres::{PgPoolOptions, PgRow}; -use sqlx::{postgres::Postgres, Connection, Done, Executor, Row}; -use sqlx_test::new; +use sqlx::postgres::{PgPoolOptions, PgRow, Postgres}; +use sqlx::{Connection, Done, Executor, PgPool, Row}; +use sqlx_test::{new, setup_if_needed}; use std::env; use std::thread; use std::time::Duration; @@ -124,7 +124,7 @@ async fn it_describes_and_inserts_json() -> anyhow::Result<()> { let _ = conn .execute( r#" -CREATE TEMPORARY TABLE json_stuff (obj json); +CREATE TEMPORARY TABLE json_stuff (obj jsonb); "#, ) .await?; @@ -142,6 +142,26 @@ CREATE TEMPORARY TABLE json_stuff (obj json); Ok(()) } +#[sqlx_macros::test] +async fn it_works_with_cache_disabled() -> anyhow::Result<()> { + setup_if_needed(); + + let mut url = url::Url::parse(&env::var("DATABASE_URL")?)?; + url.query_pairs_mut() + .append_pair("statement-cache-capacity", "0"); + + let mut conn = PgConnection::connect(url.as_ref()).await?; + + for index in 1..=10_i32 { + let _ = sqlx::query("SELECT $1") + .bind(index) + .execute(&mut conn) + .await?; + } + + Ok(()) +} + #[sqlx_macros::test] async fn it_executes_with_pool() -> anyhow::Result<()> { let pool = sqlx_test::pool::().await?;