From a2691b9635873e06a3d65b5e7e725506e762e5be Mon Sep 17 00:00:00 2001 From: LovecraftianHorror Date: Wed, 25 May 2022 19:22:09 -0600 Subject: [PATCH] Reuse a cached DB connection instead of always recreating for `sqlx-macros` (#1782) * refactor: Reuse a cached connection instead of always recreating for `sqlx-macros` * fix: Fix type inference issue when no database features used * refactor: Switch cached db conn to an `AnyConnection` * fix: Fix invalid variant name only exposed with features * fix: Tweak connection options for SQLite with `sqlx-macros` * fix: Remove read only option for SQLite connection * fix: Fix feature flags regarding usage of `sqlx_core::any` --- sqlx-core/src/any/connection/mod.rs | 10 +- sqlx-core/src/any/mod.rs | 3 + sqlx-core/src/pool/mod.rs | 21 ++++- sqlx-macros/Cargo.toml | 2 +- sqlx-macros/src/query/mod.rs | 138 +++++++++++++++++----------- sqlx-rt/src/lib.rs | 5 +- 6 files changed, 119 insertions(+), 60 deletions(-) diff --git a/sqlx-core/src/any/connection/mod.rs b/sqlx-core/src/any/connection/mod.rs index cabe7819..6f8d1e38 100644 --- a/sqlx-core/src/any/connection/mod.rs +++ b/sqlx-core/src/any/connection/mod.rs @@ -33,7 +33,9 @@ mod executor; pub struct AnyConnection(pub(super) AnyConnectionKind); #[derive(Debug)] -pub(crate) enum AnyConnectionKind { +// Used internally in `sqlx-macros` +#[doc(hidden)] +pub enum AnyConnectionKind { #[cfg(feature = "postgres")] Postgres(postgres::PgConnection), @@ -69,6 +71,12 @@ impl AnyConnection { pub fn kind(&self) -> AnyKind { self.0.kind() } + + // Used internally in `sqlx-macros` + #[doc(hidden)] + pub fn private_get_mut(&mut self) -> &mut AnyConnectionKind { + &mut self.0 + } } macro_rules! delegate_to { diff --git a/sqlx-core/src/any/mod.rs b/sqlx-core/src/any/mod.rs index e37b8e23..c026827b 100644 --- a/sqlx-core/src/any/mod.rs +++ b/sqlx-core/src/any/mod.rs @@ -31,6 +31,9 @@ mod migrate; pub use arguments::{AnyArgumentBuffer, AnyArguments}; pub use column::{AnyColumn, AnyColumnIndex}; pub use connection::AnyConnection; +// Used internally in `sqlx-macros` +#[doc(hidden)] +pub use connection::AnyConnectionKind; pub use database::Any; pub use decode::AnyDecode; pub use encode::AnyEncode; diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index 439f826f..6ca67391 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -55,7 +55,15 @@ //! [`Pool::begin`]. use self::inner::SharedPool; -#[cfg(feature = "any")] +#[cfg(all( + any( + feature = "postgres", + feature = "mysql", + feature = "mssql", + feature = "sqlite" + ), + feature = "any" +))] use crate::any::{Any, AnyKind}; use crate::connection::Connection; use crate::database::Database; @@ -429,12 +437,19 @@ impl Pool { } } -#[cfg(feature = "any")] +#[cfg(all( + any( + feature = "postgres", + feature = "mysql", + feature = "mssql", + feature = "sqlite" + ), + feature = "any" +))] impl Pool { /// Returns the database driver currently in-use by this `Pool`. /// /// Determined by the connection URI. - #[cfg(feature = "any")] pub fn any_kind(&self) -> AnyKind { self.0.connect_options.kind() } diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index bac78e5f..4215247a 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -84,7 +84,7 @@ heck = { version = "0.4", features = ["unicode"] } either = "1.6.1" once_cell = "1.9.0" proc-macro2 = { version = "1.0.36", default-features = false } -sqlx-core = { version = "0.5.12", default-features = false, path = "../sqlx-core" } +sqlx-core = { version = "0.5.12", default-features = false, features = ["any"], path = "../sqlx-core" } sqlx-rt = { version = "0.5.12", default-features = false, path = "../sqlx-rt" } serde = { version = "1.0.132", features = ["derive"], optional = true } serde_json = { version = "1.0.73", optional = true } diff --git a/sqlx-macros/src/query/mod.rs b/sqlx-macros/src/query/mod.rs index 94aa2824..dbd6bd92 100644 --- a/sqlx-macros/src/query/mod.rs +++ b/sqlx-macros/src/query/mod.rs @@ -1,3 +1,4 @@ +use std::collections::BTreeMap; use std::path::PathBuf; #[cfg(feature = "offline")] use std::sync::{Arc, Mutex}; @@ -12,7 +13,7 @@ use quote::{format_ident, quote}; use sqlx_core::connection::Connection; use sqlx_core::database::Database; use sqlx_core::{column::Column, describe::Describe, type_info::TypeInfo}; -use sqlx_rt::block_on; +use sqlx_rt::{block_on, AsyncMutex}; use crate::database::DatabaseExt; use crate::query::data::QueryData; @@ -117,6 +118,28 @@ static METADATA: Lazy = Lazy::new(|| { pub fn expand_input(input: QueryMacroInput) -> crate::Result { match &*METADATA { + #[cfg(not(any( + feature = "postgres", + feature = "mysql", + feature = "mssql", + feature = "sqlite" + )))] + Metadata { + offline: false, + database_url: Some(db_url), + .. + } => Err( + "At least one of the features ['postgres', 'mysql', 'mssql', 'sqlite'] must be enabled \ + to get information directly from a database" + .into(), + ), + + #[cfg(any( + feature = "postgres", + feature = "mysql", + feature = "mssql", + feature = "sqlite" + ))] Metadata { offline: false, database_url: Some(db_url), @@ -157,67 +180,76 @@ pub fn expand_input(input: QueryMacroInput) -> crate::Result { } } -#[allow(unused_variables)] +#[cfg(any( + feature = "postgres", + feature = "mysql", + feature = "mssql", + feature = "sqlite" +))] fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result { - // FIXME: Introduce [sqlx::any::AnyConnection] and [sqlx::any::AnyDatabase] to support - // runtime determinism here + use sqlx_core::any::{AnyConnection, AnyConnectionKind}; - let db_url = Url::parse(db_url)?; - match db_url.scheme() { - #[cfg(feature = "postgres")] - "postgres" | "postgresql" => { - let data = block_on(async { - let mut conn = sqlx_core::postgres::PgConnection::connect(db_url.as_str()).await?; - QueryData::from_db(&mut conn, &input.sql).await - })?; + static CONNECTION_CACHE: Lazy>> = + Lazy::new(|| AsyncMutex::new(BTreeMap::new())); - expand_with_data(input, data, false) - }, + let maybe_expanded: crate::Result = block_on(async { + let mut cache = CONNECTION_CACHE.lock().await; - #[cfg(not(feature = "postgres"))] - "postgres" | "postgresql" => Err("database URL has the scheme of a PostgreSQL database but the `postgres` feature is not enabled".into()), + if !cache.contains_key(db_url) { + let parsed_db_url = Url::parse(db_url)?; - #[cfg(feature = "mssql")] - "mssql" | "sqlserver" => { - let data = block_on(async { - let mut conn = sqlx_core::mssql::MssqlConnection::connect(db_url.as_str()).await?; - QueryData::from_db(&mut conn, &input.sql).await - })?; + let conn = match parsed_db_url.scheme() { + #[cfg(feature = "sqlite")] + "sqlite" => { + use sqlx_core::connection::ConnectOptions; + use sqlx_core::sqlite::{SqliteConnectOptions, SqliteJournalMode}; + use std::str::FromStr; - expand_with_data(input, data, false) - }, + let sqlite_conn = SqliteConnectOptions::from_str(db_url)? + // Connections in `CONNECTION_CACHE` won't get dropped so disable journaling + // to avoid `.db-wal` and `.db-shm` files from lingering around + .journal_mode(SqliteJournalMode::Off) + .connect() + .await?; + AnyConnection::from(sqlite_conn) + } + _ => AnyConnection::connect(db_url).await?, + }; - #[cfg(not(feature = "mssql"))] - "mssql" | "sqlserver" => Err("database URL has the scheme of a MSSQL database but the `mssql` feature is not enabled".into()), + let _ = cache.insert(db_url.to_owned(), conn); + } - #[cfg(feature = "mysql")] - "mysql" | "mariadb" => { - let data = block_on(async { - let mut conn = sqlx_core::mysql::MySqlConnection::connect(db_url.as_str()).await?; - QueryData::from_db(&mut conn, &input.sql).await - })?; + let conn_item = cache.get_mut(db_url).expect("Item was just inserted"); + match conn_item.private_get_mut() { + #[cfg(feature = "postgres")] + AnyConnectionKind::Postgres(conn) => { + let data = QueryData::from_db(conn, &input.sql).await?; + expand_with_data(input, data, false) + } + #[cfg(feature = "mssql")] + AnyConnectionKind::Mssql(conn) => { + let data = QueryData::from_db(conn, &input.sql).await?; + expand_with_data(input, data, false) + } + #[cfg(feature = "mysql")] + AnyConnectionKind::MySql(conn) => { + let data = QueryData::from_db(conn, &input.sql).await?; + expand_with_data(input, data, false) + } + #[cfg(feature = "sqlite")] + AnyConnectionKind::Sqlite(conn) => { + let data = QueryData::from_db(conn, &input.sql).await?; + expand_with_data(input, data, false) + } + // Variants depend on feature flags + #[allow(unreachable_patterns)] + item => { + return Err(format!("Missing expansion needed for: {:?}", item).into()); + } + } + }); - expand_with_data(input, data, false) - }, - - #[cfg(not(feature = "mysql"))] - "mysql" | "mariadb" => Err("database URL has the scheme of a MySQL/MariaDB database but the `mysql` feature is not enabled".into()), - - #[cfg(feature = "sqlite")] - "sqlite" => { - let data = block_on(async { - let mut conn = sqlx_core::sqlite::SqliteConnection::connect(db_url.as_str()).await?; - QueryData::from_db(&mut conn, &input.sql).await - })?; - - expand_with_data(input, data, false) - }, - - #[cfg(not(feature = "sqlite"))] - "sqlite" => Err("database URL has the scheme of a SQLite database but the `sqlite` feature is not enabled".into()), - - scheme => Err(format!("unknown database URL scheme {:?}", scheme).into()) - } + maybe_expanded.map_err(Into::into) } #[cfg(feature = "offline")] diff --git a/sqlx-rt/src/lib.rs b/sqlx-rt/src/lib.rs index 5e6b7322..a9a6f21d 100644 --- a/sqlx-rt/src/lib.rs +++ b/sqlx-rt/src/lib.rs @@ -37,7 +37,8 @@ pub use native_tls; ))] pub use tokio::{ self, fs, io::AsyncRead, io::AsyncReadExt, io::AsyncWrite, io::AsyncWriteExt, io::ReadBuf, - net::TcpStream, runtime::Handle, task::spawn, task::yield_now, time::sleep, time::timeout, + net::TcpStream, runtime::Handle, sync::Mutex as AsyncMutex, task::spawn, task::yield_now, + time::sleep, time::timeout, }; #[cfg(all( @@ -142,7 +143,7 @@ macro_rules! blocking { pub use async_std::{ self, fs, future::timeout, io::prelude::ReadExt as AsyncReadExt, io::prelude::WriteExt as AsyncWriteExt, io::Read as AsyncRead, io::Write as AsyncWrite, - net::TcpStream, task::sleep, task::spawn, task::yield_now, + net::TcpStream, sync::Mutex as AsyncMutex, task::sleep, task::spawn, task::yield_now, }; #[cfg(all(