Allow setting caching per-query

This commit is contained in:
Julius de Bruijn 2020-07-16 15:07:47 +02:00 committed by Ryan Leckey
parent c9c11c8302
commit e8a4c54ac7
9 changed files with 103 additions and 26 deletions

View File

@ -172,7 +172,7 @@ pub trait Execute<'q, DB: Database>: Send + Sized {
/// will be prepared (and cached) before execution.
fn take_arguments(&mut self) -> Option<<DB as HasArguments<'q>>::Arguments>;
/// Returns true if query has any parameters.
/// Returns `true` if the statement should be cached.
fn persistent(&self) -> bool;
}
@ -191,7 +191,7 @@ impl<'q, DB: Database> Execute<'q, DB> for &'q str {
#[inline]
fn persistent(&self) -> bool {
false
true
}
}
@ -208,6 +208,6 @@ impl<'q, DB: Database> Execute<'q, DB> for (&'q str, Option<<DB as HasArguments<
#[inline]
fn persistent(&self) -> bool {
self.1.is_some()
true
}
}

View File

@ -26,7 +26,11 @@ use crate::mysql::{
use crate::statement::StatementInfo;
impl MySqlConnection {
async fn prepare<'a>(&'a mut self, query: &str) -> Result<Cow<'a, MySqlStatement>, Error> {
async fn prepare<'a>(
&'a mut self,
query: &str,
persistent: bool,
) -> Result<Cow<'a, MySqlStatement>, Error> {
if self.cache_statement.contains_key(query) {
let stmt = self.cache_statement.get_mut(query).unwrap();
return Ok(Cow::Borrowed(&*stmt));
@ -81,7 +85,7 @@ impl MySqlConnection {
nullable,
};
if self.cache_statement.is_enabled() {
if persistent && self.cache_statement.is_enabled() {
// in case of the cache being full, close the least recently used statement
if let Some(statement) = self.cache_statement.insert(query, statement) {
self.stream
@ -142,12 +146,13 @@ impl MySqlConnection {
&'c mut self,
query: &str,
arguments: Option<MySqlArguments>,
persistent: bool,
) -> Result<impl Stream<Item = Result<Either<MySqlDone, MySqlRow>, Error>> + 'c, Error> {
self.stream.wait_until_ready().await?;
self.stream.busy = Busy::Result;
let format = if let Some(arguments) = arguments {
let statement = self.prepare(query).await?.id;
let statement = self.prepare(query, persistent).await?.id;
// https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
self.stream
@ -250,9 +255,10 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection {
{
let s = query.query();
let arguments = query.take_arguments();
let persistent = query.persistent();
Box::pin(try_stream! {
let s = self.run(s, arguments).await?;
let s = self.run(s, arguments, persistent).await?;
pin_mut!(s);
while let Some(v) = s.try_next().await? {
@ -295,7 +301,7 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection {
let query = query.query();
Box::pin(async move {
let statement = self.prepare(query).await?;
let statement = self.prepare(query, false).await?;
let columns = statement.columns.clone();
let nullable = statement.nullable.clone();

View File

@ -16,16 +16,12 @@ use crate::postgres::{
statement::PgStatement, PgArguments, PgConnection, PgDone, PgRow, PgValueFormat, Postgres,
};
use crate::statement::StatementInfo;
use message::Flush;
async fn prepare(
conn: &mut PgConnection,
query: &str,
arguments: &PgArguments,
) -> Result<PgStatement, Error> {
// before we continue, wait until we are "ready" to accept more queries
conn.wait_until_ready().await?;
let id = conn.next_statement_id;
conn.next_statement_id = conn.next_statement_id.wrapping_add(1);
@ -72,8 +68,8 @@ async fn prepare(
// get the statement columns and parameters
conn.stream.write(message::Describe::Statement(id));
conn.write_sync();
conn.write_sync();
conn.stream.flush().await?;
let parameters = recv_desc_params(conn).await?;
@ -87,6 +83,8 @@ async fn prepare(
let columns = (&*conn.scratch_row_columns).clone();
conn.wait_until_ready().await?;
Ok(PgStatement {
id,
parameters,
@ -174,11 +172,12 @@ impl PgConnection {
if store_to_cache && self.cache_statement.is_enabled() {
if let Some(statement) = self.cache_statement.insert(query, statement) {
self.stream.write(Close::Statement(statement.id));
self.stream.write(Flush);
self.write_sync();
self.stream.flush().await?;
self.wait_for_close_complete(1).await?;
self.recv_ready_for_query().await?;
}
Ok(Cow::Borrowed(
@ -194,6 +193,7 @@ impl PgConnection {
query: &str,
arguments: Option<PgArguments>,
limit: u8,
persistent: bool,
) -> Result<impl Stream<Item = Result<Either<PgDone, PgRow>, Error>> + '_, Error> {
// before we continue, wait until we are "ready" to accept more queries
self.wait_until_ready().await?;
@ -201,7 +201,7 @@ impl PgConnection {
let format = if let Some(mut arguments) = arguments {
// prepare the statement if this our first time executing it
// always return the statement ID here
let statement = self.prepare(query, &arguments, true).await?.id;
let statement = self.prepare(query, &arguments, persistent).await?.id;
// patch holes created during encoding
arguments.buffer.patch_type_holes(self).await?;
@ -334,9 +334,10 @@ impl<'c> Executor<'c> for &'c mut PgConnection {
{
let s = query.query();
let arguments = query.take_arguments();
let persistent = query.persistent();
Box::pin(try_stream! {
let s = self.run(s, arguments, 0).await?;
let s = self.run(s, arguments, 0, persistent).await?;
pin_mut!(s);
while let Some(v) = s.try_next().await? {
@ -357,9 +358,10 @@ impl<'c> Executor<'c> for &'c mut PgConnection {
{
let s = query.query();
let arguments = query.take_arguments();
let persistent = query.persistent();
Box::pin(async move {
let s = self.run(s, arguments, 1).await?;
let s = self.run(s, arguments, 1, persistent).await?;
pin_mut!(s);
while let Some(s) = s.try_next().await? {

View File

@ -5,7 +5,7 @@ use futures_core::stream::BoxStream;
use futures_util::{future, StreamExt, TryFutureExt, TryStreamExt};
use crate::arguments::{Arguments, IntoArguments};
use crate::database::{Database, HasArguments};
use crate::database::{Database, HasArguments, HasStatementCache};
use crate::encode::Encode;
use crate::error::Error;
use crate::executor::{Execute, Executor};
@ -17,6 +17,7 @@ pub struct Query<'q, DB: Database, A> {
pub(crate) query: &'q str,
pub(crate) arguments: Option<A>,
pub(crate) database: PhantomData<DB>,
pub(crate) persistent: bool,
}
/// SQL query that will map its results to owned Rust types.
@ -50,7 +51,7 @@ where
#[inline]
fn persistent(&self) -> bool {
self.arguments.is_some()
self.persistent
}
}
@ -72,6 +73,24 @@ impl<'q, DB: Database> Query<'q, DB, <DB as HasArguments<'q>>::Arguments> {
}
}
impl<'q, DB, A> Query<'q, DB, A>
where
DB: Database + HasStatementCache,
{
/// If `true`, the statement will get prepared once and cached to the
/// connection's statement cache.
///
/// If queried once with the flag set to `true`, all subsequent queries
/// matching the one with the flag will use the cached statement until the
/// cache is cleared.
///
/// Default: `true`.
pub fn persistent(mut self, value: bool) -> Self {
self.persistent = value;
self
}
}
impl<'q, DB, A: Send> Query<'q, DB, A>
where
DB: Database,
@ -360,6 +379,7 @@ where
database: PhantomData,
arguments: Some(Default::default()),
query: sql,
persistent: true,
}
}
@ -374,6 +394,7 @@ where
database: PhantomData,
arguments: Some(arguments),
query: sql,
persistent: true,
}
}

View File

@ -38,7 +38,7 @@ where
#[inline]
fn persistent(&self) -> bool {
self.inner.arguments.is_some()
self.inner.persistent()
}
}

View File

@ -109,6 +109,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
{
let s = query.query();
let arguments = query.take_arguments();
let persistent = query.persistent() && arguments.is_some();
Box::pin(try_stream! {
let SqliteConnection {
@ -121,7 +122,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
} = self;
// prepare statement object (or checkout from cache)
let mut stmt = prepare(conn, statements, statement, s, arguments.is_some())?;
let mut stmt = prepare(conn, statements, statement, s, persistent)?;
// bind arguments, if any, to the statement
bind(&mut stmt, arguments)?;

View File

@ -1,5 +1,5 @@
use futures::TryStreamExt;
use sqlx::mysql::{MySql, MySqlPool, MySqlPoolOptions, MySqlRow};
use sqlx::mysql::{MySql, MySqlConnection, MySqlPool, MySqlPoolOptions, MySqlRow};
use sqlx::{Connection, Done, Executor, Row};
use sqlx_test::{new, setup_if_needed};
use std::env;
@ -207,6 +207,7 @@ async fn it_caches_statements() -> anyhow::Result<()> {
for i in 0..2 {
let row = sqlx::query("SELECT ? AS val")
.bind(i)
.persistent(true)
.fetch_one(&mut conn)
.await?;
@ -219,6 +220,20 @@ async fn it_caches_statements() -> anyhow::Result<()> {
conn.clear_cached_statements().await?;
assert_eq!(0, conn.cached_statements_size());
for i in 0..2 {
let row = sqlx::query("SELECT ? AS val")
.bind(i)
.persistent(false)
.fetch_one(&mut conn)
.await?;
let val: u32 = row.get("val");
assert_eq!(i, val);
}
assert_eq!(0, conn.cached_statements_size());
Ok(())
}

View File

@ -3,7 +3,7 @@ use sqlx::postgres::{
PgConnectOptions, PgConnection, PgDatabaseError, PgErrorPosition, PgSeverity,
};
use sqlx::postgres::{PgPoolOptions, PgRow, Postgres};
use sqlx::{Connection, Done, Executor, PgPool, Row};
use sqlx::{Connection, Done, Executor, Row};
use sqlx_test::{new, setup_if_needed};
use std::env;
use std::thread;
@ -28,7 +28,7 @@ async fn it_can_select_void() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;
// pg_notify just happens to be a function that returns void
let _value: () = sqlx::query_scalar("select pg_notify('chan', 'message');")
let _: () = sqlx::query_scalar("select pg_notify('chan', 'message');")
.fetch_one(&mut conn)
.await?;
@ -132,12 +132,12 @@ CREATE TEMPORARY TABLE json_stuff (obj json);
let query = "INSERT INTO json_stuff (obj) VALUES ($1)";
let _ = conn.describe(query).await?;
let cnt = sqlx::query(query)
let done = sqlx::query(query)
.bind(serde_json::json!({ "a": "a" }))
.execute(&mut conn)
.await?;
assert_eq!(cnt, 1);
assert_eq!(done.rows_affected(), 1);
Ok(())
}
@ -563,6 +563,7 @@ async fn it_caches_statements() -> anyhow::Result<()> {
for i in 0..2 {
let row = sqlx::query("SELECT $1 AS val")
.bind(i)
.persistent(true)
.fetch_one(&mut conn)
.await?;
@ -575,6 +576,20 @@ async fn it_caches_statements() -> anyhow::Result<()> {
conn.clear_cached_statements().await?;
assert_eq!(0, conn.cached_statements_size());
for i in 0..2 {
let row = sqlx::query("SELECT $1 AS val")
.bind(i)
.persistent(false)
.fetch_one(&mut conn)
.await?;
let val: u32 = row.get("val");
assert_eq!(i, val);
}
assert_eq!(0, conn.cached_statements_size());
Ok(())
}

View File

@ -374,6 +374,7 @@ async fn it_caches_statements() -> anyhow::Result<()> {
for i in 0..2 {
let row = sqlx::query("SELECT ? AS val")
.bind(i)
.persistent(true)
.fetch_one(&mut conn)
.await?;
@ -386,5 +387,21 @@ async fn it_caches_statements() -> anyhow::Result<()> {
conn.clear_cached_statements().await?;
assert_eq!(0, conn.cached_statements_size());
let mut conn = new::<Sqlite>().await?;
for i in 0..2 {
let row = sqlx::query("SELECT ? AS val")
.bind(i)
.persistent(false)
.fetch_one(&mut conn)
.await?;
let val: i32 = row.get("val");
assert_eq!(i, val);
}
assert_eq!(0, conn.cached_statements_size());
Ok(())
}