diff --git a/Cargo.lock b/Cargo.lock index 5c464105..adb5504c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1205,6 +1205,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linked-hash-map" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8dd5a6d5999d9907cda8ed67bbd137d3af8085216c2ac62de5be860bd41f304a" + [[package]] name = "lock_api" version = "0.3.4" @@ -1234,6 +1240,15 @@ dependencies = [ "scoped-tls 0.1.2", ] +[[package]] +name = "lru-cache" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31e24f1ad8321ca0e8a1e0ac13f23cb668e6f5466c2c57319f6a5cf1cc8e3b1c" +dependencies = [ + "linked-hash-map", +] + [[package]] name = "maplit" version = "1.0.2" @@ -2281,6 +2296,7 @@ dependencies = [ "libc", "libsqlite3-sys", "log", + "lru-cache", "md-5", "memchr", "num-bigint", diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 0a5a692f..d8219763 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -84,3 +84,4 @@ url = { version = "2.1.1", default-features = false } uuid = { version = "0.8.1", default-features = false, optional = true, features = [ "std" ] } whoami = "0.8.1" stringprep = "0.1.2" +lru-cache = "0.1.2" diff --git a/sqlx-core/src/caching_connection.rs b/sqlx-core/src/caching_connection.rs new file mode 100644 index 00000000..88482514 --- /dev/null +++ b/sqlx-core/src/caching_connection.rs @@ -0,0 +1,13 @@ +use futures_core::future::BoxFuture; + +use crate::error::Error; + +/// A connection that is capable of caching prepared statements. +pub trait CachingConnection: Send { + /// The number of statements currently cached in the connection. + fn cached_statements_count(&self) -> usize; + + /// Removes all statements from the cache, closing them on the server if + /// needed. + fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>>; +} diff --git a/sqlx-core/src/common/mod.rs b/sqlx-core/src/common/mod.rs new file mode 100644 index 00000000..f9698f28 --- /dev/null +++ b/sqlx-core/src/common/mod.rs @@ -0,0 +1,3 @@ +mod statement_cache; + +pub(crate) use statement_cache::StatementCache; diff --git a/sqlx-core/src/common/statement_cache.rs b/sqlx-core/src/common/statement_cache.rs new file mode 100644 index 00000000..e87785eb --- /dev/null +++ b/sqlx-core/src/common/statement_cache.rs @@ -0,0 +1,50 @@ +use lru_cache::LruCache; + +/// A cache for prepared statements. When full, the least recently used +/// statement gets removed. +#[derive(Debug)] +pub struct StatementCache { + inner: LruCache, +} + +impl StatementCache { + /// Create a new cache with the given capacity. + pub fn new(capacity: usize) -> Self { + Self { + inner: LruCache::new(capacity), + } + } + + /// Returns a mutable reference to the value corresponding to the given key + /// in the cache, if any. + pub fn get_mut(&mut self, k: &str) -> Option<&mut u32> { + self.inner.get_mut(k) + } + + /// Inserts a new statement to the cache, returning the least recently used + /// statement id if the cache is full, or if inserting with an existing key, + /// the replaced existing statement. + pub fn insert(&mut self, k: &str, v: u32) -> Option { + let mut lru_item = None; + + if self.inner.capacity() == self.len() && !self.inner.contains_key(k) { + lru_item = self.remove_lru(); + } else if self.inner.contains_key(k) { + lru_item = self.inner.remove(k); + } + + self.inner.insert(k.into(), v); + + lru_item + } + + /// The number of statements in the cache. + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Removes the least recently used item from the cache. + pub fn remove_lru(&mut self) -> Option { + self.inner.remove_lru().map(|(_, v)| v) + } +} diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index cda49dbf..3785c2bd 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -29,6 +29,7 @@ pub mod arguments; #[macro_use] pub mod pool; +pub mod caching_connection; pub mod connection; #[macro_use] @@ -37,6 +38,7 @@ pub mod transaction; #[macro_use] pub mod encode; +mod common; pub mod database; pub mod decode; pub mod describe; diff --git a/sqlx-core/src/mysql/connection/establish.rs b/sqlx-core/src/mysql/connection/establish.rs index 34f6bf6c..31745894 100644 --- a/sqlx-core/src/mysql/connection/establish.rs +++ b/sqlx-core/src/mysql/connection/establish.rs @@ -1,6 +1,6 @@ use bytes::Bytes; -use hashbrown::HashMap; +use crate::common::StatementCache; use crate::error::Error; use crate::mysql::connection::{tls, MySqlStream, COLLATE_UTF8MB4_UNICODE_CI, MAX_PACKET_SIZE}; use crate::mysql::protocol::connect::{ @@ -98,7 +98,7 @@ impl MySqlConnection { Ok(Self { stream, - cache_statement: HashMap::new(), + cache_statement: StatementCache::new(options.statement_cache_size), scratch_row_columns: Default::default(), scratch_row_column_names: Default::default(), }) diff --git a/sqlx-core/src/mysql/connection/executor.rs b/sqlx-core/src/mysql/connection/executor.rs index 5110f3c3..7b28f360 100644 --- a/sqlx-core/src/mysql/connection/executor.rs +++ b/sqlx-core/src/mysql/connection/executor.rs @@ -15,7 +15,7 @@ use crate::mysql::connection::stream::Busy; use crate::mysql::io::MySqlBufExt; use crate::mysql::protocol::response::Status; use crate::mysql::protocol::statement::{ - BinaryRow, Execute as StatementExecute, Prepare, PrepareOk, + BinaryRow, Execute as StatementExecute, Prepare, PrepareOk, StmtClose, }; use crate::mysql::protocol::text::{ColumnDefinition, ColumnFlags, Query, TextRow}; use crate::mysql::protocol::Packet; @@ -26,8 +26,8 @@ use crate::mysql::{ impl MySqlConnection { async fn prepare(&mut self, query: &str) -> Result { - if let Some(&statement) = self.cache_statement.get(query) { - return Ok(statement); + if let Some(statement) = self.cache_statement.get_mut(query) { + return Ok(*statement); } // https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html @@ -60,8 +60,10 @@ impl MySqlConnection { self.stream.maybe_recv_eof().await?; } - self.cache_statement - .insert(query.to_owned(), ok.statement_id); + // in case of the cache being full, close the least recently used statement + if let Some(statement) = self.cache_statement.insert(query, ok.statement_id) { + self.stream.send_packet(StmtClose { statement }).await?; + } Ok(ok.statement_id) } diff --git a/sqlx-core/src/mysql/connection/mod.rs b/sqlx-core/src/mysql/connection/mod.rs index a3d1d35f..0711eae4 100644 --- a/sqlx-core/src/mysql/connection/mod.rs +++ b/sqlx-core/src/mysql/connection/mod.rs @@ -6,10 +6,13 @@ use futures_core::future::BoxFuture; use futures_util::FutureExt; use hashbrown::HashMap; +use crate::caching_connection::CachingConnection; +use crate::common::StatementCache; use crate::connection::{Connect, Connection}; use crate::error::Error; use crate::executor::Executor; use crate::ext::ustr::UStr; +use crate::mysql::protocol::statement::StmtClose; use crate::mysql::protocol::text::{Ping, Quit}; use crate::mysql::row::MySqlColumn; use crate::mysql::{MySql, MySqlConnectOptions}; @@ -34,7 +37,7 @@ pub struct MySqlConnection { pub(crate) stream: MySqlStream, // cache by query string to the statement id - cache_statement: HashMap, + cache_statement: StatementCache, // working memory for the active row's column information // this allows us to re-use these allocations unless the user is persisting the @@ -43,6 +46,22 @@ pub struct MySqlConnection { scratch_row_column_names: Arc>, } +impl CachingConnection for MySqlConnection { + fn cached_statements_count(&self) -> usize { + self.cache_statement.len() + } + + fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> { + Box::pin(async move { + while let Some(statement) = self.cache_statement.remove_lru() { + self.stream.send_packet(StmtClose { statement }).await?; + } + + Ok(()) + }) + } +} + impl Debug for MySqlConnection { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("MySqlConnection").finish() diff --git a/sqlx-core/src/mysql/options.rs b/sqlx-core/src/mysql/options.rs index 3720f444..339a9605 100644 --- a/sqlx-core/src/mysql/options.rs +++ b/sqlx-core/src/mysql/options.rs @@ -101,6 +101,7 @@ pub struct MySqlConnectOptions { pub(crate) database: Option, pub(crate) ssl_mode: MySqlSslMode, pub(crate) ssl_ca: Option, + pub(crate) statement_cache_size: usize, } impl Default for MySqlConnectOptions { @@ -120,6 +121,7 @@ impl MySqlConnectOptions { database: None, ssl_mode: MySqlSslMode::Preferred, ssl_ca: None, + statement_cache_size: 100, } } @@ -190,6 +192,17 @@ impl MySqlConnectOptions { self.ssl_ca = Some(file_name.as_ref().to_owned()); self } + + /// Sets the size of the connection's statement cache in a number of stored + /// distinct statements. Caching is handled using LRU, meaning when the + /// amount of queries hits the defined limit, the oldest statement will get + /// dropped. + /// + /// The default cache size is 100 statements. + pub fn statement_cache_size(mut self, size: usize) -> Self { + self.statement_cache_size = size; + self + } } impl FromStr for MySqlConnectOptions { @@ -231,6 +244,10 @@ impl FromStr for MySqlConnectOptions { options = options.ssl_ca(&*value); } + "statement-cache-size" => { + options = options.statement_cache_size(value.parse()?); + } + _ => {} } } diff --git a/sqlx-core/src/mysql/protocol/statement/mod.rs b/sqlx-core/src/mysql/protocol/statement/mod.rs index 5ad292f5..9ae6b3c9 100644 --- a/sqlx-core/src/mysql/protocol/statement/mod.rs +++ b/sqlx-core/src/mysql/protocol/statement/mod.rs @@ -2,8 +2,10 @@ mod execute; mod prepare; mod prepare_ok; mod row; +mod stmt_close; pub(crate) use execute::Execute; pub(crate) use prepare::Prepare; pub(crate) use prepare_ok::PrepareOk; pub(crate) use row::BinaryRow; +pub(crate) use stmt_close::StmtClose; diff --git a/sqlx-core/src/mysql/protocol/statement/stmt_close.rs b/sqlx-core/src/mysql/protocol/statement/stmt_close.rs new file mode 100644 index 00000000..13f095f9 --- /dev/null +++ b/sqlx-core/src/mysql/protocol/statement/stmt_close.rs @@ -0,0 +1,16 @@ +use crate::io::Encode; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/internals/en/com-stmt-close.html + +#[derive(Debug)] +pub struct StmtClose { + pub statement: u32, +} + +impl Encode<'_, Capabilities> for StmtClose { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) { + buf.push(0x19); // COM_STMT_CLOSE + buf.extend(&self.statement.to_le_bytes()); + } +} diff --git a/src/lib.rs b/src/lib.rs index a11f3ce3..f05cc6dc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ #![cfg_attr(docsrs, feature(doc_cfg))] pub use sqlx_core::arguments::{Arguments, IntoArguments}; +pub use sqlx_core::caching_connection::CachingConnection; pub use sqlx_core::connection::{Connect, Connection}; pub use sqlx_core::database::{self, Database}; pub use sqlx_core::executor::{Execute, Executor}; diff --git a/tests/mysql/mysql.rs b/tests/mysql/mysql.rs index dfdf0cf5..95a1815b 100644 --- a/tests/mysql/mysql.rs +++ b/tests/mysql/mysql.rs @@ -1,6 +1,6 @@ use futures::TryStreamExt; use sqlx::mysql::{MySql, MySqlPool, MySqlRow}; -use sqlx::{Connection, Executor, Row}; +use sqlx::{CachingConnection, Connection, Executor, Row}; use sqlx_test::new; #[sqlx_macros::test] @@ -177,3 +177,25 @@ SELECT id, text FROM messages; Ok(()) } + +#[sqlx_macros::test] +async fn it_caches_statements() -> anyhow::Result<()> { + let mut conn = new::().await?; + + for i in 0..2 { + let row = sqlx::query("SELECT ? AS val") + .bind(i) + .fetch_one(&mut conn) + .await?; + + let val: u32 = row.get("val"); + + assert_eq!(i, val); + } + + assert_eq!(1, conn.cached_statements_count()); + conn.clear_cached_statements().await?; + assert_eq!(0, conn.cached_statements_count()); + + Ok(()) +}