LRU caching for PostgreSQL

This commit is contained in:
Julius de Bruijn 2020-06-24 18:59:39 +02:00
parent 2b6f242a22
commit 5d64310004
6 changed files with 66 additions and 4 deletions

View File

@ -47,4 +47,9 @@ impl StatementCache {
pub fn remove_lru(&mut self) -> Option<u32> {
self.inner.remove_lru().map(|(_, v)| v)
}
/// Clear all cached statements from the cache.
pub fn clear(&mut self) {
self.inner.clear();
}
}

View File

@ -1,5 +1,6 @@
use hashbrown::HashMap;
use crate::common::StatementCache;
use crate::error::Error;
use crate::io::Decode;
use crate::postgres::connection::{sasl, stream::PgStream, tls};
@ -138,7 +139,7 @@ impl PgConnection {
transaction_status,
pending_ready_for_query_count: 0,
next_statement_id: 1,
cache_statement: HashMap::with_capacity(10),
cache_statement: StatementCache::new(options.statement_cache_size),
cache_type_oid: HashMap::new(),
cache_type_info: HashMap::new(),
scratch_row_columns: Default::default(),

View File

@ -88,15 +88,16 @@ async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription
Ok(rows)
}
impl PgConnection {
async fn prepare(&mut self, query: &str, arguments: &PgArguments) -> Result<u32, Error> {
if let Some(statement) = self.cache_statement.get(query) {
if let Some(statement) = self.cache_statement.get_mut(query) {
return Ok(*statement);
}
let statement = prepare(self, query, arguments).await?;
self.cache_statement.insert(query.to_owned(), statement);
self.cache_statement.insert(query, statement);
Ok(statement)
}

View File

@ -5,6 +5,8 @@ use futures_core::future::BoxFuture;
use futures_util::{FutureExt, TryFutureExt};
use hashbrown::HashMap;
use crate::caching_connection::CachingConnection;
use crate::common::StatementCache;
use crate::connection::{Connect, Connection};
use crate::error::Error;
use crate::executor::Executor;
@ -46,7 +48,7 @@ pub struct PgConnection {
next_statement_id: u32,
// cache statement by query string to the id and columns
cache_statement: HashMap<String, u32>,
cache_statement: StatementCache,
// cache user-defined types by id <-> info
cache_type_info: HashMap<u32, PgTypeInfo>,
@ -96,6 +98,19 @@ impl Debug for PgConnection {
}
}
impl CachingConnection for PgConnection {
fn cached_statements_count(&self) -> usize {
self.cache_statement.len()
}
fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> {
Box::pin(async move {
self.cache_statement.clear();
Ok(())
})
}
}
impl Connection for PgConnection {
type Database = Postgres;

View File

@ -115,6 +115,7 @@ pub struct PgConnectOptions {
pub(crate) database: Option<String>,
pub(crate) ssl_mode: PgSslMode,
pub(crate) ssl_root_cert: Option<PathBuf>,
pub(crate) statement_cache_size: usize,
}
impl Default for PgConnectOptions {
@ -162,6 +163,7 @@ impl PgConnectOptions {
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or_default(),
statement_cache_size: 100,
}
}
@ -285,6 +287,17 @@ impl PgConnectOptions {
self.ssl_root_cert = Some(cert.as_ref().to_path_buf());
self
}
/// Sets the size of the connection's statement cache in a number of stored
/// distinct statements. Caching is handled using LRU, meaning when the
/// amount of queries hits the defined limit, the oldest statement will get
/// dropped.
///
/// The default cache size is 100 statements.
pub fn statement_cache_size(mut self, size: usize) -> Self {
self.statement_cache_size = size;
self
}
}
fn default_host(port: u16) -> String {
@ -345,6 +358,10 @@ impl FromStr for PgConnectOptions {
options = options.ssl_root_cert(&*value);
}
"statement-cache-size" => {
options = options.statement_cache_size(value.parse()?);
}
_ => {}
}
}

View File

@ -1,6 +1,7 @@
use futures::TryStreamExt;
use sqlx::postgres::PgRow;
use sqlx::postgres::{PgDatabaseError, PgErrorPosition, PgSeverity};
use sqlx::CachingConnection;
use sqlx::{postgres::Postgres, Connection, Executor, PgPool, Row};
use sqlx_test::new;
use std::time::Duration;
@ -487,3 +488,25 @@ SELECT id, text FROM _sqlx_test_postgres_5112;
Ok(())
}
#[sqlx_macros::test]
async fn it_caches_statements() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;
for i in 0..2 {
let row = sqlx::query("SELECT $1 AS val")
.bind(i)
.fetch_one(&mut conn)
.await?;
let val: u32 = row.get("val");
assert_eq!(i, val);
}
assert_eq!(1, conn.cached_statements_count());
conn.clear_cached_statements().await?;
assert_eq!(0, conn.cached_statements_count());
Ok(())
}