implement prepared statement caching for postgres

This commit is contained in:
Austin Bonander 2019-12-12 18:40:54 -08:00 committed by Ryan Leckey
parent 0865549175
commit d8d93867b7
7 changed files with 126 additions and 36 deletions

46
sqlx-core/src/cache.rs Normal file
View File

@ -0,0 +1,46 @@
use std::collections::hash_map::{HashMap, Entry};
use bitflags::_core::cmp::Ordering;
use futures_core::Future;
pub struct StatementCache<Id> {
statements: HashMap<String, Id>
}
impl<Id> StatementCache<Id> {
pub fn new() -> Self {
StatementCache {
statements: HashMap::with_capacity(10),
}
}
#[cfg(feature = "mariadb")]
pub async fn get_or_compute<'a, E, Fut>(&'a mut self, query: &str, compute: impl FnOnce() -> Fut)
-> Result<&'a Id, E>
where
Fut: Future<Output = Result<Id, E>>
{
match self.statements.entry(query.to_string()) {
Entry::Occupied(occupied) => Ok(occupied.into_mut()),
Entry::Vacant(vacant) => {
Ok(vacant.insert(compute().await?))
}
}
}
// for Postgres so it can return the synthetic statement name instead of formatting twice
#[cfg(feature = "postgres")]
pub async fn map_or_compute<R, E, Fut>(&mut self, query: &str, map: impl FnOnce(&Id) -> R, compute: impl FnOnce() -> Fut)
-> Result<R, E>
where
Fut: Future<Output = Result<(Id, R), E>> {
match self.statements.entry(query.to_string()) {
Entry::Occupied(occupied) => Ok(map(occupied.get())),
Entry::Vacant(vacant) => {
let (id, ret) = compute().await?;
vacant.insert(id);
Ok(ret)
}
}
}
}

View File

@ -33,6 +33,8 @@ pub mod types;
mod describe;
mod cache;
#[doc(inline)]
pub use self::{
backend::Backend,

View File

@ -1,4 +1,4 @@
use super::{connection::Step, Postgres};
use super::{connection::{PostgresConn, Step}, Postgres};
use crate::{
backend::Backend,
describe::{Describe, ResultField},
@ -7,6 +7,7 @@ use crate::{
url::Url,
};
use futures_core::{future::BoxFuture, stream::BoxStream};
use crate::cache::StatementCache;
impl Backend for Postgres {
type QueryParameters = PostgresQueryParameters;
@ -21,7 +22,7 @@ impl Backend for Postgres {
Box::pin(async move {
let url = url?;
let address = url.resolve(5432);
let mut conn = Self::new(address).await?;
let mut conn = PostgresConn::new(address).await?;
conn.startup(
url.username(),
@ -30,12 +31,16 @@ impl Backend for Postgres {
)
.await?;
Ok(conn)
Ok(Postgres {
conn,
statements: StatementCache::new(),
next_id: 0
})
})
}
fn close(self) -> BoxFuture<'static, crate::Result<()>> {
Box::pin(self.terminate())
Box::pin(self.conn.terminate())
}
}

View File

@ -13,7 +13,7 @@ use std::{
net::{Shutdown, SocketAddr},
};
pub struct Postgres {
pub struct PostgresConn {
stream: BufStream<TcpStream>,
// Process ID of the Backend
@ -34,7 +34,7 @@ pub struct Postgres {
// [ ] 52.2.9. SSL Session Encryption
// [ ] 52.2.10. GSSAPI Session Encryption
impl Postgres {
impl PostgresConn {
pub(super) async fn new(address: SocketAddr) -> crate::Result<Self> {
let stream = TcpStream::connect(&address).await?;
@ -139,7 +139,7 @@ impl Postgres {
Ok(())
}
pub(super) fn parse(&mut self, statement: &str, query: &str, params: &PostgresQueryParameters) {
pub(super) fn buffer_parse(&mut self, statement: &str, query: &str, params: &PostgresQueryParameters) {
protocol::Parse {
statement,
query,
@ -148,6 +148,13 @@ impl Postgres {
.encode(self.stream.buffer_mut());
}
pub(super) async fn try_parse(&mut self, statement: &str, query: &str, params: &PostgresQueryParameters) -> crate::Result<()> {
self.buffer_parse(statement, query, params);
self.sync().await?;
while let Some(_) = self.step().await? {}
Ok(())
}
pub(super) fn describe(&mut self, statement: &str) {
protocol::Describe {
kind: protocol::DescribeKind::PreparedStatement,

View File

@ -10,6 +10,28 @@ use crate::{
use futures_core::{future::BoxFuture, stream::BoxStream};
use crate::postgres::query::PostgresQueryParameters;
impl Postgres {
async fn prepare_cached(&mut self, query: &str, params: &PostgresQueryParameters) -> crate::Result<String> {
fn get_stmt_name(id: u64) -> String {
format!("sqlx_postgres_stmt_{}", id)
}
let conn = &mut self.conn;
let next_id = &mut self.next_id;
self.statements.map_or_compute(
query,
|&id| get_stmt_name(id),
|| async {
let stmt_id = *next_id;
let stmt_name = get_stmt_name(stmt_id);
conn.try_parse(&stmt_name, query, params).await?;
*next_id += 1;
Ok((stmt_id, stmt_name))
}).await
}
}
impl Executor for Postgres {
type Backend = Self;
@ -19,14 +41,15 @@ impl Executor for Postgres {
params: PostgresQueryParameters,
) -> BoxFuture<'e, crate::Result<u64>> {
Box::pin(async move {
self.parse("", query, &params);
self.bind("", "", &params);
self.execute("", 1);
self.sync().await?;
let stmt = self.prepare_cached(query, &params).await?;
self.conn.bind("", &stmt, &params);
self.conn.execute("", 1);
self.conn.sync().await?;
let mut affected = 0;
while let Some(step) = self.step().await? {
while let Some(step) = self.conn.step().await? {
if let Step::Command(cnt) = step {
affected = cnt;
}
@ -41,17 +64,16 @@ impl Executor for Postgres {
query: &'q str,
params: PostgresQueryParameters,
) -> BoxStream<'e, crate::Result<T>>
where
T: FromRow<Self::Backend> + Send + Unpin,
where
T: FromRow<Self::Backend> + Send + Unpin,
{
self.parse("", query, &params);
self.bind("", "", &params);
self.execute("", 0);
Box::pin(async_stream::try_stream! {
self.sync().await?;
let stmt = self.prepare_cached(query, &params).await?;
self.conn.bind("", &stmt, &params);
self.conn.execute("", 0);
self.conn.sync().await?;
while let Some(step) = self.step().await? {
while let Some(step) = self.conn.step().await? {
if let Step::Row(row) = step {
yield FromRow::from_row(row);
}
@ -64,18 +86,18 @@ impl Executor for Postgres {
query: &'q str,
params: PostgresQueryParameters,
) -> BoxFuture<'e, crate::Result<Option<T>>>
where
T: FromRow<Self::Backend> + Send,
where
T: FromRow<Self::Backend> + Send,
{
Box::pin(async move {
self.parse("", query, &params);
self.bind("", "", &params);
self.execute("", 2);
self.sync().await?;
let stmt = self.prepare_cached(query, &params).await?;
self.conn.bind("", &stmt, &params);
self.conn.execute("", 2);
self.conn.sync().await?;
let mut row: Option<_> = None;
while let Some(step) = self.step().await? {
while let Some(step) = self.conn.step().await? {
if let Step::Row(r) = step {
if row.is_some() {
return Err(crate::Error::FoundMoreThanOne);
@ -94,13 +116,13 @@ impl Executor for Postgres {
query: &'q str,
) -> BoxFuture<'e, crate::Result<Describe<Self::Backend>>> {
Box::pin(async move {
self.parse("", query, &Default::default());
self.describe("");
self.sync().await?;
let stmt = self.prepare_cached(query, &PostgresQueryParameters::default()).await?;
self.conn.describe(&stmt);
self.conn.sync().await?;
let param_desc = loop {
let step = self
.step()
.conn.step()
.await?
.ok_or(protocol_err!("did not receive ParameterDescription"));
@ -111,7 +133,7 @@ impl Executor for Postgres {
let row_desc = loop {
let step = self
.step()
.conn.step()
.await?
.ok_or(protocol_err!("did not receive RowDescription"));

View File

@ -1,3 +1,6 @@
use crate::postgres::connection::PostgresConn;
use crate::cache::StatementCache;
mod backend;
mod connection;
mod error;
@ -13,4 +16,8 @@ pub mod protocol;
pub mod types;
pub use self::connection::Postgres;
pub struct Postgres {
conn: PostgresConn,
statements: StatementCache<u64>,
next_id: u64,
}

View File

@ -3,17 +3,18 @@ use sqlx::{Connection, Postgres, Row};
macro_rules! test {
($name:ident: $ty:ty: $($text:literal == $value:expr),+) => {
#[async_std::test]
async fn $name () -> sqlx::Result<()> {
async fn $name () -> Result<(), String> {
let mut conn =
Connection::<Postgres>::open(
&dotenv::var("DATABASE_URL").expect("DATABASE_URL must be set")
).await?;
).await.map_err(|e| format!("failed to connect to Postgres: {}", e))?;
$(
let row = sqlx::query(&format!("SELECT {} = $1, $1", $text))
.bind($value)
.fetch_one(&mut conn)
.await?;
.await
.map_err(|e| format!("failed to run query: {}", e))?;
assert!(row.get::<bool>(0));
assert!($value == row.get::<$ty>(1));