Reuse a cached DB connection instead of always recreating for sqlx-macros (#1782)

* refactor: Reuse a cached connection instead of always recreating for `sqlx-macros`

* fix: Fix type inference issue when no database features used

* refactor: Switch cached db conn to an `AnyConnection`

* fix: Fix invalid variant name only exposed with features

* fix: Tweak connection options for SQLite with `sqlx-macros`

* fix: Remove read only option for SQLite connection

* fix: Fix feature flags regarding usage of `sqlx_core::any`
This commit is contained in:
LovecraftianHorror 2022-05-25 19:22:09 -06:00 committed by GitHub
parent fa5c436918
commit a2691b9635
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 119 additions and 60 deletions

View File

@ -33,7 +33,9 @@ mod executor;
pub struct AnyConnection(pub(super) AnyConnectionKind);
#[derive(Debug)]
pub(crate) enum AnyConnectionKind {
// Used internally in `sqlx-macros`
#[doc(hidden)]
pub enum AnyConnectionKind {
#[cfg(feature = "postgres")]
Postgres(postgres::PgConnection),
@ -69,6 +71,12 @@ impl AnyConnection {
pub fn kind(&self) -> AnyKind {
self.0.kind()
}
// Used internally in `sqlx-macros`
#[doc(hidden)]
pub fn private_get_mut(&mut self) -> &mut AnyConnectionKind {
&mut self.0
}
}
macro_rules! delegate_to {

View File

@ -31,6 +31,9 @@ mod migrate;
pub use arguments::{AnyArgumentBuffer, AnyArguments};
pub use column::{AnyColumn, AnyColumnIndex};
pub use connection::AnyConnection;
// Used internally in `sqlx-macros`
#[doc(hidden)]
pub use connection::AnyConnectionKind;
pub use database::Any;
pub use decode::AnyDecode;
pub use encode::AnyEncode;

View File

@ -55,7 +55,15 @@
//! [`Pool::begin`].
use self::inner::SharedPool;
#[cfg(feature = "any")]
#[cfg(all(
any(
feature = "postgres",
feature = "mysql",
feature = "mssql",
feature = "sqlite"
),
feature = "any"
))]
use crate::any::{Any, AnyKind};
use crate::connection::Connection;
use crate::database::Database;
@ -429,12 +437,19 @@ impl<DB: Database> Pool<DB> {
}
}
#[cfg(feature = "any")]
#[cfg(all(
any(
feature = "postgres",
feature = "mysql",
feature = "mssql",
feature = "sqlite"
),
feature = "any"
))]
impl Pool<Any> {
/// Returns the database driver currently in-use by this `Pool`.
///
/// Determined by the connection URI.
#[cfg(feature = "any")]
pub fn any_kind(&self) -> AnyKind {
self.0.connect_options.kind()
}

View File

@ -84,7 +84,7 @@ heck = { version = "0.4", features = ["unicode"] }
either = "1.6.1"
once_cell = "1.9.0"
proc-macro2 = { version = "1.0.36", default-features = false }
sqlx-core = { version = "0.5.12", default-features = false, path = "../sqlx-core" }
sqlx-core = { version = "0.5.12", default-features = false, features = ["any"], path = "../sqlx-core" }
sqlx-rt = { version = "0.5.12", default-features = false, path = "../sqlx-rt" }
serde = { version = "1.0.132", features = ["derive"], optional = true }
serde_json = { version = "1.0.73", optional = true }

View File

@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::path::PathBuf;
#[cfg(feature = "offline")]
use std::sync::{Arc, Mutex};
@ -12,7 +13,7 @@ use quote::{format_ident, quote};
use sqlx_core::connection::Connection;
use sqlx_core::database::Database;
use sqlx_core::{column::Column, describe::Describe, type_info::TypeInfo};
use sqlx_rt::block_on;
use sqlx_rt::{block_on, AsyncMutex};
use crate::database::DatabaseExt;
use crate::query::data::QueryData;
@ -117,6 +118,28 @@ static METADATA: Lazy<Metadata> = Lazy::new(|| {
pub fn expand_input(input: QueryMacroInput) -> crate::Result<TokenStream> {
match &*METADATA {
#[cfg(not(any(
feature = "postgres",
feature = "mysql",
feature = "mssql",
feature = "sqlite"
)))]
Metadata {
offline: false,
database_url: Some(db_url),
..
} => Err(
"At least one of the features ['postgres', 'mysql', 'mssql', 'sqlite'] must be enabled \
to get information directly from a database"
.into(),
),
#[cfg(any(
feature = "postgres",
feature = "mysql",
feature = "mssql",
feature = "sqlite"
))]
Metadata {
offline: false,
database_url: Some(db_url),
@ -157,67 +180,76 @@ pub fn expand_input(input: QueryMacroInput) -> crate::Result<TokenStream> {
}
}
#[allow(unused_variables)]
#[cfg(any(
feature = "postgres",
feature = "mysql",
feature = "mssql",
feature = "sqlite"
))]
fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result<TokenStream> {
// FIXME: Introduce [sqlx::any::AnyConnection] and [sqlx::any::AnyDatabase] to support
// runtime determinism here
use sqlx_core::any::{AnyConnection, AnyConnectionKind};
let db_url = Url::parse(db_url)?;
match db_url.scheme() {
#[cfg(feature = "postgres")]
"postgres" | "postgresql" => {
let data = block_on(async {
let mut conn = sqlx_core::postgres::PgConnection::connect(db_url.as_str()).await?;
QueryData::from_db(&mut conn, &input.sql).await
})?;
static CONNECTION_CACHE: Lazy<AsyncMutex<BTreeMap<String, AnyConnection>>> =
Lazy::new(|| AsyncMutex::new(BTreeMap::new()));
expand_with_data(input, data, false)
},
let maybe_expanded: crate::Result<TokenStream> = block_on(async {
let mut cache = CONNECTION_CACHE.lock().await;
#[cfg(not(feature = "postgres"))]
"postgres" | "postgresql" => Err("database URL has the scheme of a PostgreSQL database but the `postgres` feature is not enabled".into()),
if !cache.contains_key(db_url) {
let parsed_db_url = Url::parse(db_url)?;
#[cfg(feature = "mssql")]
"mssql" | "sqlserver" => {
let data = block_on(async {
let mut conn = sqlx_core::mssql::MssqlConnection::connect(db_url.as_str()).await?;
QueryData::from_db(&mut conn, &input.sql).await
})?;
let conn = match parsed_db_url.scheme() {
#[cfg(feature = "sqlite")]
"sqlite" => {
use sqlx_core::connection::ConnectOptions;
use sqlx_core::sqlite::{SqliteConnectOptions, SqliteJournalMode};
use std::str::FromStr;
expand_with_data(input, data, false)
},
let sqlite_conn = SqliteConnectOptions::from_str(db_url)?
// Connections in `CONNECTION_CACHE` won't get dropped so disable journaling
// to avoid `.db-wal` and `.db-shm` files from lingering around
.journal_mode(SqliteJournalMode::Off)
.connect()
.await?;
AnyConnection::from(sqlite_conn)
}
_ => AnyConnection::connect(db_url).await?,
};
#[cfg(not(feature = "mssql"))]
"mssql" | "sqlserver" => Err("database URL has the scheme of a MSSQL database but the `mssql` feature is not enabled".into()),
let _ = cache.insert(db_url.to_owned(), conn);
}
#[cfg(feature = "mysql")]
"mysql" | "mariadb" => {
let data = block_on(async {
let mut conn = sqlx_core::mysql::MySqlConnection::connect(db_url.as_str()).await?;
QueryData::from_db(&mut conn, &input.sql).await
})?;
let conn_item = cache.get_mut(db_url).expect("Item was just inserted");
match conn_item.private_get_mut() {
#[cfg(feature = "postgres")]
AnyConnectionKind::Postgres(conn) => {
let data = QueryData::from_db(conn, &input.sql).await?;
expand_with_data(input, data, false)
}
#[cfg(feature = "mssql")]
AnyConnectionKind::Mssql(conn) => {
let data = QueryData::from_db(conn, &input.sql).await?;
expand_with_data(input, data, false)
}
#[cfg(feature = "mysql")]
AnyConnectionKind::MySql(conn) => {
let data = QueryData::from_db(conn, &input.sql).await?;
expand_with_data(input, data, false)
}
#[cfg(feature = "sqlite")]
AnyConnectionKind::Sqlite(conn) => {
let data = QueryData::from_db(conn, &input.sql).await?;
expand_with_data(input, data, false)
}
// Variants depend on feature flags
#[allow(unreachable_patterns)]
item => {
return Err(format!("Missing expansion needed for: {:?}", item).into());
}
}
});
expand_with_data(input, data, false)
},
#[cfg(not(feature = "mysql"))]
"mysql" | "mariadb" => Err("database URL has the scheme of a MySQL/MariaDB database but the `mysql` feature is not enabled".into()),
#[cfg(feature = "sqlite")]
"sqlite" => {
let data = block_on(async {
let mut conn = sqlx_core::sqlite::SqliteConnection::connect(db_url.as_str()).await?;
QueryData::from_db(&mut conn, &input.sql).await
})?;
expand_with_data(input, data, false)
},
#[cfg(not(feature = "sqlite"))]
"sqlite" => Err("database URL has the scheme of a SQLite database but the `sqlite` feature is not enabled".into()),
scheme => Err(format!("unknown database URL scheme {:?}", scheme).into())
}
maybe_expanded.map_err(Into::into)
}
#[cfg(feature = "offline")]

View File

@ -37,7 +37,8 @@ pub use native_tls;
))]
pub use tokio::{
self, fs, io::AsyncRead, io::AsyncReadExt, io::AsyncWrite, io::AsyncWriteExt, io::ReadBuf,
net::TcpStream, runtime::Handle, task::spawn, task::yield_now, time::sleep, time::timeout,
net::TcpStream, runtime::Handle, sync::Mutex as AsyncMutex, task::spawn, task::yield_now,
time::sleep, time::timeout,
};
#[cfg(all(
@ -142,7 +143,7 @@ macro_rules! blocking {
pub use async_std::{
self, fs, future::timeout, io::prelude::ReadExt as AsyncReadExt,
io::prelude::WriteExt as AsyncWriteExt, io::Read as AsyncRead, io::Write as AsyncWrite,
net::TcpStream, task::sleep, task::spawn, task::yield_now,
net::TcpStream, sync::Mutex as AsyncMutex, task::sleep, task::spawn, task::yield_now,
};
#[cfg(all(