diff --git a/sqlx-core/src/executor.rs b/sqlx-core/src/executor.rs index 6711d9e4..281b1614 100644 --- a/sqlx-core/src/executor.rs +++ b/sqlx-core/src/executor.rs @@ -172,7 +172,7 @@ pub trait Execute<'q, DB: Database>: Send + Sized { /// will be prepared (and cached) before execution. fn take_arguments(&mut self) -> Option<>::Arguments>; - /// Returns true if query has any parameters. + /// Returns `true` if the statement should be cached. fn persistent(&self) -> bool; } @@ -191,7 +191,7 @@ impl<'q, DB: Database> Execute<'q, DB> for &'q str { #[inline] fn persistent(&self) -> bool { - false + true } } @@ -208,6 +208,6 @@ impl<'q, DB: Database> Execute<'q, DB> for (&'q str, Option< bool { - self.1.is_some() + true } } diff --git a/sqlx-core/src/mysql/connection/executor.rs b/sqlx-core/src/mysql/connection/executor.rs index 306794a9..4f592c62 100644 --- a/sqlx-core/src/mysql/connection/executor.rs +++ b/sqlx-core/src/mysql/connection/executor.rs @@ -26,7 +26,11 @@ use crate::mysql::{ use crate::statement::StatementInfo; impl MySqlConnection { - async fn prepare<'a>(&'a mut self, query: &str) -> Result, Error> { + async fn prepare<'a>( + &'a mut self, + query: &str, + persistent: bool, + ) -> Result, Error> { if self.cache_statement.contains_key(query) { let stmt = self.cache_statement.get_mut(query).unwrap(); return Ok(Cow::Borrowed(&*stmt)); @@ -81,7 +85,7 @@ impl MySqlConnection { nullable, }; - if self.cache_statement.is_enabled() { + if persistent && self.cache_statement.is_enabled() { // in case of the cache being full, close the least recently used statement if let Some(statement) = self.cache_statement.insert(query, statement) { self.stream @@ -142,12 +146,13 @@ impl MySqlConnection { &'c mut self, query: &str, arguments: Option, + persistent: bool, ) -> Result, Error>> + 'c, Error> { self.stream.wait_until_ready().await?; self.stream.busy = Busy::Result; let format = if let Some(arguments) = arguments { - let statement = self.prepare(query).await?.id; + let statement = self.prepare(query, persistent).await?.id; // https://dev.mysql.com/doc/internals/en/com-stmt-execute.html self.stream @@ -250,9 +255,10 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { { let s = query.query(); let arguments = query.take_arguments(); + let persistent = query.persistent(); Box::pin(try_stream! { - let s = self.run(s, arguments).await?; + let s = self.run(s, arguments, persistent).await?; pin_mut!(s); while let Some(v) = s.try_next().await? { @@ -295,7 +301,7 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { let query = query.query(); Box::pin(async move { - let statement = self.prepare(query).await?; + let statement = self.prepare(query, false).await?; let columns = statement.columns.clone(); let nullable = statement.nullable.clone(); diff --git a/sqlx-core/src/postgres/connection/executor.rs b/sqlx-core/src/postgres/connection/executor.rs index b5ba909b..3073bca7 100644 --- a/sqlx-core/src/postgres/connection/executor.rs +++ b/sqlx-core/src/postgres/connection/executor.rs @@ -16,16 +16,12 @@ use crate::postgres::{ statement::PgStatement, PgArguments, PgConnection, PgDone, PgRow, PgValueFormat, Postgres, }; use crate::statement::StatementInfo; -use message::Flush; async fn prepare( conn: &mut PgConnection, query: &str, arguments: &PgArguments, ) -> Result { - // before we continue, wait until we are "ready" to accept more queries - conn.wait_until_ready().await?; - let id = conn.next_statement_id; conn.next_statement_id = conn.next_statement_id.wrapping_add(1); @@ -72,8 +68,8 @@ async fn prepare( // get the statement columns and parameters conn.stream.write(message::Describe::Statement(id)); - conn.write_sync(); + conn.write_sync(); conn.stream.flush().await?; let parameters = recv_desc_params(conn).await?; @@ -87,6 +83,8 @@ async fn prepare( let columns = (&*conn.scratch_row_columns).clone(); + conn.wait_until_ready().await?; + Ok(PgStatement { id, parameters, @@ -174,11 +172,12 @@ impl PgConnection { if store_to_cache && self.cache_statement.is_enabled() { if let Some(statement) = self.cache_statement.insert(query, statement) { self.stream.write(Close::Statement(statement.id)); - self.stream.write(Flush); + self.write_sync(); self.stream.flush().await?; self.wait_for_close_complete(1).await?; + self.recv_ready_for_query().await?; } Ok(Cow::Borrowed( @@ -194,6 +193,7 @@ impl PgConnection { query: &str, arguments: Option, limit: u8, + persistent: bool, ) -> Result, Error>> + '_, Error> { // before we continue, wait until we are "ready" to accept more queries self.wait_until_ready().await?; @@ -201,7 +201,7 @@ impl PgConnection { let format = if let Some(mut arguments) = arguments { // prepare the statement if this our first time executing it // always return the statement ID here - let statement = self.prepare(query, &arguments, true).await?.id; + let statement = self.prepare(query, &arguments, persistent).await?.id; // patch holes created during encoding arguments.buffer.patch_type_holes(self).await?; @@ -334,9 +334,10 @@ impl<'c> Executor<'c> for &'c mut PgConnection { { let s = query.query(); let arguments = query.take_arguments(); + let persistent = query.persistent(); Box::pin(try_stream! { - let s = self.run(s, arguments, 0).await?; + let s = self.run(s, arguments, 0, persistent).await?; pin_mut!(s); while let Some(v) = s.try_next().await? { @@ -357,9 +358,10 @@ impl<'c> Executor<'c> for &'c mut PgConnection { { let s = query.query(); let arguments = query.take_arguments(); + let persistent = query.persistent(); Box::pin(async move { - let s = self.run(s, arguments, 1).await?; + let s = self.run(s, arguments, 1, persistent).await?; pin_mut!(s); while let Some(s) = s.try_next().await? { diff --git a/sqlx-core/src/query.rs b/sqlx-core/src/query.rs index d9efd441..bf173741 100644 --- a/sqlx-core/src/query.rs +++ b/sqlx-core/src/query.rs @@ -5,7 +5,7 @@ use futures_core::stream::BoxStream; use futures_util::{future, StreamExt, TryFutureExt, TryStreamExt}; use crate::arguments::{Arguments, IntoArguments}; -use crate::database::{Database, HasArguments}; +use crate::database::{Database, HasArguments, HasStatementCache}; use crate::encode::Encode; use crate::error::Error; use crate::executor::{Execute, Executor}; @@ -17,6 +17,7 @@ pub struct Query<'q, DB: Database, A> { pub(crate) query: &'q str, pub(crate) arguments: Option, pub(crate) database: PhantomData, + pub(crate) persistent: bool, } /// SQL query that will map its results to owned Rust types. @@ -50,7 +51,7 @@ where #[inline] fn persistent(&self) -> bool { - self.arguments.is_some() + self.persistent } } @@ -72,6 +73,24 @@ impl<'q, DB: Database> Query<'q, DB, >::Arguments> { } } +impl<'q, DB, A> Query<'q, DB, A> +where + DB: Database + HasStatementCache, +{ + /// If `true`, the statement will get prepared once and cached to the + /// connection's statement cache. + /// + /// If queried once with the flag set to `true`, all subsequent queries + /// matching the one with the flag will use the cached statement until the + /// cache is cleared. + /// + /// Default: `true`. + pub fn persistent(mut self, value: bool) -> Self { + self.persistent = value; + self + } +} + impl<'q, DB, A: Send> Query<'q, DB, A> where DB: Database, @@ -360,6 +379,7 @@ where database: PhantomData, arguments: Some(Default::default()), query: sql, + persistent: true, } } @@ -374,6 +394,7 @@ where database: PhantomData, arguments: Some(arguments), query: sql, + persistent: true, } } diff --git a/sqlx-core/src/query_as.rs b/sqlx-core/src/query_as.rs index f5573d1f..0860410f 100644 --- a/sqlx-core/src/query_as.rs +++ b/sqlx-core/src/query_as.rs @@ -38,7 +38,7 @@ where #[inline] fn persistent(&self) -> bool { - self.inner.arguments.is_some() + self.inner.persistent() } } diff --git a/sqlx-core/src/sqlite/connection/executor.rs b/sqlx-core/src/sqlite/connection/executor.rs index 14b488b1..a6703dba 100644 --- a/sqlx-core/src/sqlite/connection/executor.rs +++ b/sqlx-core/src/sqlite/connection/executor.rs @@ -109,6 +109,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { { let s = query.query(); let arguments = query.take_arguments(); + let persistent = query.persistent() && arguments.is_some(); Box::pin(try_stream! { let SqliteConnection { @@ -121,7 +122,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { } = self; // prepare statement object (or checkout from cache) - let mut stmt = prepare(conn, statements, statement, s, arguments.is_some())?; + let mut stmt = prepare(conn, statements, statement, s, persistent)?; // bind arguments, if any, to the statement bind(&mut stmt, arguments)?; diff --git a/tests/mysql/mysql.rs b/tests/mysql/mysql.rs index 2951d1c2..3b871d44 100644 --- a/tests/mysql/mysql.rs +++ b/tests/mysql/mysql.rs @@ -1,5 +1,5 @@ use futures::TryStreamExt; -use sqlx::mysql::{MySql, MySqlPool, MySqlPoolOptions, MySqlRow}; +use sqlx::mysql::{MySql, MySqlConnection, MySqlPool, MySqlPoolOptions, MySqlRow}; use sqlx::{Connection, Done, Executor, Row}; use sqlx_test::{new, setup_if_needed}; use std::env; @@ -207,6 +207,7 @@ async fn it_caches_statements() -> anyhow::Result<()> { for i in 0..2 { let row = sqlx::query("SELECT ? AS val") .bind(i) + .persistent(true) .fetch_one(&mut conn) .await?; @@ -219,6 +220,20 @@ async fn it_caches_statements() -> anyhow::Result<()> { conn.clear_cached_statements().await?; assert_eq!(0, conn.cached_statements_size()); + for i in 0..2 { + let row = sqlx::query("SELECT ? AS val") + .bind(i) + .persistent(false) + .fetch_one(&mut conn) + .await?; + + let val: u32 = row.get("val"); + + assert_eq!(i, val); + } + + assert_eq!(0, conn.cached_statements_size()); + Ok(()) } diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index e8907b94..804cb23e 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -3,7 +3,7 @@ use sqlx::postgres::{ PgConnectOptions, PgConnection, PgDatabaseError, PgErrorPosition, PgSeverity, }; use sqlx::postgres::{PgPoolOptions, PgRow, Postgres}; -use sqlx::{Connection, Done, Executor, PgPool, Row}; +use sqlx::{Connection, Done, Executor, Row}; use sqlx_test::{new, setup_if_needed}; use std::env; use std::thread; @@ -28,7 +28,7 @@ async fn it_can_select_void() -> anyhow::Result<()> { let mut conn = new::().await?; // pg_notify just happens to be a function that returns void - let _value: () = sqlx::query_scalar("select pg_notify('chan', 'message');") + let _: () = sqlx::query_scalar("select pg_notify('chan', 'message');") .fetch_one(&mut conn) .await?; @@ -132,12 +132,12 @@ CREATE TEMPORARY TABLE json_stuff (obj json); let query = "INSERT INTO json_stuff (obj) VALUES ($1)"; let _ = conn.describe(query).await?; - let cnt = sqlx::query(query) + let done = sqlx::query(query) .bind(serde_json::json!({ "a": "a" })) .execute(&mut conn) .await?; - assert_eq!(cnt, 1); + assert_eq!(done.rows_affected(), 1); Ok(()) } @@ -563,6 +563,7 @@ async fn it_caches_statements() -> anyhow::Result<()> { for i in 0..2 { let row = sqlx::query("SELECT $1 AS val") .bind(i) + .persistent(true) .fetch_one(&mut conn) .await?; @@ -575,6 +576,20 @@ async fn it_caches_statements() -> anyhow::Result<()> { conn.clear_cached_statements().await?; assert_eq!(0, conn.cached_statements_size()); + for i in 0..2 { + let row = sqlx::query("SELECT $1 AS val") + .bind(i) + .persistent(false) + .fetch_one(&mut conn) + .await?; + + let val: u32 = row.get("val"); + + assert_eq!(i, val); + } + + assert_eq!(0, conn.cached_statements_size()); + Ok(()) } diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index a2cca849..1cc93cc1 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -374,6 +374,7 @@ async fn it_caches_statements() -> anyhow::Result<()> { for i in 0..2 { let row = sqlx::query("SELECT ? AS val") .bind(i) + .persistent(true) .fetch_one(&mut conn) .await?; @@ -386,5 +387,21 @@ async fn it_caches_statements() -> anyhow::Result<()> { conn.clear_cached_statements().await?; assert_eq!(0, conn.cached_statements_size()); + let mut conn = new::().await?; + + for i in 0..2 { + let row = sqlx::query("SELECT ? AS val") + .bind(i) + .persistent(false) + .fetch_one(&mut conn) + .await?; + + let val: i32 = row.get("val"); + + assert_eq!(i, val); + } + + assert_eq!(0, conn.cached_statements_size()); + Ok(()) }