diff --git a/sqlx-core/src/cache.rs b/sqlx-core/src/cache.rs new file mode 100644 index 00000000..716ce52f --- /dev/null +++ b/sqlx-core/src/cache.rs @@ -0,0 +1,46 @@ +use std::collections::hash_map::{HashMap, Entry}; +use bitflags::_core::cmp::Ordering; +use futures_core::Future; + +pub struct StatementCache { + statements: HashMap +} + +impl StatementCache { + pub fn new() -> Self { + StatementCache { + statements: HashMap::with_capacity(10), + } + } + + #[cfg(feature = "mariadb")] + pub async fn get_or_compute<'a, E, Fut>(&'a mut self, query: &str, compute: impl FnOnce() -> Fut) + -> Result<&'a Id, E> + where + Fut: Future> + { + match self.statements.entry(query.to_string()) { + Entry::Occupied(occupied) => Ok(occupied.into_mut()), + Entry::Vacant(vacant) => { + Ok(vacant.insert(compute().await?)) + } + } + } + + // for Postgres so it can return the synthetic statement name instead of formatting twice + #[cfg(feature = "postgres")] + pub async fn map_or_compute(&mut self, query: &str, map: impl FnOnce(&Id) -> R, compute: impl FnOnce() -> Fut) + -> Result + where + Fut: Future> { + + match self.statements.entry(query.to_string()) { + Entry::Occupied(occupied) => Ok(map(occupied.get())), + Entry::Vacant(vacant) => { + let (id, ret) = compute().await?; + vacant.insert(id); + Ok(ret) + } + } + } +} diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index 3fa38184..31c0ce5b 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -33,6 +33,8 @@ pub mod types; mod describe; +mod cache; + #[doc(inline)] pub use self::{ backend::Backend, diff --git a/sqlx-core/src/postgres/backend.rs b/sqlx-core/src/postgres/backend.rs index b6587d1c..49d8d2a2 100644 --- a/sqlx-core/src/postgres/backend.rs +++ b/sqlx-core/src/postgres/backend.rs @@ -1,4 +1,4 @@ -use super::{connection::Step, Postgres}; +use super::{connection::{PostgresConn, Step}, Postgres}; use crate::{ backend::Backend, describe::{Describe, ResultField}, @@ -7,6 +7,7 @@ use crate::{ url::Url, }; use futures_core::{future::BoxFuture, stream::BoxStream}; +use crate::cache::StatementCache; impl Backend for Postgres { type QueryParameters = PostgresQueryParameters; @@ -21,7 +22,7 @@ impl Backend for Postgres { Box::pin(async move { let url = url?; let address = url.resolve(5432); - let mut conn = Self::new(address).await?; + let mut conn = PostgresConn::new(address).await?; conn.startup( url.username(), @@ -30,12 +31,16 @@ impl Backend for Postgres { ) .await?; - Ok(conn) + Ok(Postgres { + conn, + statements: StatementCache::new(), + next_id: 0 + }) }) } fn close(self) -> BoxFuture<'static, crate::Result<()>> { - Box::pin(self.terminate()) + Box::pin(self.conn.terminate()) } } diff --git a/sqlx-core/src/postgres/connection.rs b/sqlx-core/src/postgres/connection.rs index 228d98a2..d3d56273 100644 --- a/sqlx-core/src/postgres/connection.rs +++ b/sqlx-core/src/postgres/connection.rs @@ -13,7 +13,7 @@ use std::{ net::{Shutdown, SocketAddr}, }; -pub struct Postgres { +pub struct PostgresConn { stream: BufStream, // Process ID of the Backend @@ -34,7 +34,7 @@ pub struct Postgres { // [ ] 52.2.9. SSL Session Encryption // [ ] 52.2.10. GSSAPI Session Encryption -impl Postgres { +impl PostgresConn { pub(super) async fn new(address: SocketAddr) -> crate::Result { let stream = TcpStream::connect(&address).await?; @@ -139,7 +139,7 @@ impl Postgres { Ok(()) } - pub(super) fn parse(&mut self, statement: &str, query: &str, params: &PostgresQueryParameters) { + pub(super) fn buffer_parse(&mut self, statement: &str, query: &str, params: &PostgresQueryParameters) { protocol::Parse { statement, query, @@ -148,6 +148,13 @@ impl Postgres { .encode(self.stream.buffer_mut()); } + pub(super) async fn try_parse(&mut self, statement: &str, query: &str, params: &PostgresQueryParameters) -> crate::Result<()> { + self.buffer_parse(statement, query, params); + self.sync().await?; + while let Some(_) = self.step().await? {} + Ok(()) + } + pub(super) fn describe(&mut self, statement: &str) { protocol::Describe { kind: protocol::DescribeKind::PreparedStatement, diff --git a/sqlx-core/src/postgres/executor.rs b/sqlx-core/src/postgres/executor.rs index 54fe4735..871c6522 100644 --- a/sqlx-core/src/postgres/executor.rs +++ b/sqlx-core/src/postgres/executor.rs @@ -10,6 +10,28 @@ use crate::{ use futures_core::{future::BoxFuture, stream::BoxStream}; use crate::postgres::query::PostgresQueryParameters; +impl Postgres { + async fn prepare_cached(&mut self, query: &str, params: &PostgresQueryParameters) -> crate::Result { + fn get_stmt_name(id: u64) -> String { + format!("sqlx_postgres_stmt_{}", id) + } + + let conn = &mut self.conn; + let next_id = &mut self.next_id; + + self.statements.map_or_compute( + query, + |&id| get_stmt_name(id), + || async { + let stmt_id = *next_id; + let stmt_name = get_stmt_name(stmt_id); + conn.try_parse(&stmt_name, query, params).await?; + *next_id += 1; + Ok((stmt_id, stmt_name)) + }).await + } +} + impl Executor for Postgres { type Backend = Self; @@ -19,14 +41,15 @@ impl Executor for Postgres { params: PostgresQueryParameters, ) -> BoxFuture<'e, crate::Result> { Box::pin(async move { - self.parse("", query, ¶ms); - self.bind("", "", ¶ms); - self.execute("", 1); - self.sync().await?; + let stmt = self.prepare_cached(query, ¶ms).await?; + + self.conn.bind("", &stmt, ¶ms); + self.conn.execute("", 1); + self.conn.sync().await?; let mut affected = 0; - while let Some(step) = self.step().await? { + while let Some(step) = self.conn.step().await? { if let Step::Command(cnt) = step { affected = cnt; } @@ -41,17 +64,16 @@ impl Executor for Postgres { query: &'q str, params: PostgresQueryParameters, ) -> BoxStream<'e, crate::Result> - where - T: FromRow + Send + Unpin, + where + T: FromRow + Send + Unpin, { - self.parse("", query, ¶ms); - self.bind("", "", ¶ms); - self.execute("", 0); - Box::pin(async_stream::try_stream! { - self.sync().await?; + let stmt = self.prepare_cached(query, ¶ms).await?; + self.conn.bind("", &stmt, ¶ms); + self.conn.execute("", 0); + self.conn.sync().await?; - while let Some(step) = self.step().await? { + while let Some(step) = self.conn.step().await? { if let Step::Row(row) = step { yield FromRow::from_row(row); } @@ -64,18 +86,18 @@ impl Executor for Postgres { query: &'q str, params: PostgresQueryParameters, ) -> BoxFuture<'e, crate::Result>> - where - T: FromRow + Send, + where + T: FromRow + Send, { Box::pin(async move { - self.parse("", query, ¶ms); - self.bind("", "", ¶ms); - self.execute("", 2); - self.sync().await?; + let stmt = self.prepare_cached(query, ¶ms).await?; + self.conn.bind("", &stmt, ¶ms); + self.conn.execute("", 2); + self.conn.sync().await?; let mut row: Option<_> = None; - while let Some(step) = self.step().await? { + while let Some(step) = self.conn.step().await? { if let Step::Row(r) = step { if row.is_some() { return Err(crate::Error::FoundMoreThanOne); @@ -94,13 +116,13 @@ impl Executor for Postgres { query: &'q str, ) -> BoxFuture<'e, crate::Result>> { Box::pin(async move { - self.parse("", query, &Default::default()); - self.describe(""); - self.sync().await?; + let stmt = self.prepare_cached(query, &PostgresQueryParameters::default()).await?; + self.conn.describe(&stmt); + self.conn.sync().await?; let param_desc = loop { let step = self - .step() + .conn.step() .await? .ok_or(protocol_err!("did not receive ParameterDescription")); @@ -111,7 +133,7 @@ impl Executor for Postgres { let row_desc = loop { let step = self - .step() + .conn.step() .await? .ok_or(protocol_err!("did not receive RowDescription")); diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs index 14cf93ec..1a2453a7 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -1,3 +1,6 @@ +use crate::postgres::connection::PostgresConn; +use crate::cache::StatementCache; + mod backend; mod connection; mod error; @@ -13,4 +16,8 @@ pub mod protocol; pub mod types; -pub use self::connection::Postgres; +pub struct Postgres { + conn: PostgresConn, + statements: StatementCache, + next_id: u64, +} diff --git a/tests/postgres-types.rs b/tests/postgres-types.rs index f28522d4..e41e3b29 100644 --- a/tests/postgres-types.rs +++ b/tests/postgres-types.rs @@ -3,17 +3,18 @@ use sqlx::{Connection, Postgres, Row}; macro_rules! test { ($name:ident: $ty:ty: $($text:literal == $value:expr),+) => { #[async_std::test] - async fn $name () -> sqlx::Result<()> { + async fn $name () -> Result<(), String> { let mut conn = Connection::::open( &dotenv::var("DATABASE_URL").expect("DATABASE_URL must be set") - ).await?; + ).await.map_err(|e| format!("failed to connect to Postgres: {}", e))?; $( let row = sqlx::query(&format!("SELECT {} = $1, $1", $text)) .bind($value) .fetch_one(&mut conn) - .await?; + .await + .map_err(|e| format!("failed to run query: {}", e))?; assert!(row.get::(0)); assert!($value == row.get::<$ty>(1));