diff --git a/Cargo.toml b/Cargo.toml index c3cb86c8..a92c67d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -224,7 +224,6 @@ sqlx-sqlite = { workspace = true, optional = true } anyhow = "1.0.52" time_ = { version = "0.3.2", package = "time" } futures-util = { version = "0.3.19", default-features = false, features = ["alloc"] } -env_logger = "0.11" async-std = { workspace = true, features = ["attributes"] } tokio = { version = "1.15.0", features = ["full"] } dotenvy = "0.15.0" @@ -241,7 +240,8 @@ tempfile = "3.10.1" criterion = { version = "0.7.0", features = ["async_tokio"] } libsqlite3-sys = { version = "0.30.1" } -tracing = { version = "0.1.44", features = ["attributes"] } +tracing = "0.1.41" +tracing-subscriber = "0.3.20" # If this is an unconditional dev-dependency then Cargo will *always* try to build `libsqlite3-sys`, # even when SQLite isn't the intended test target, and fail if the build environment is not set up for compiling C code. diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index af9229d4..eb6e827e 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -26,6 +26,8 @@ use futures_util::{stream, FutureExt, TryStreamExt}; use std::time::{Duration, Instant}; use tracing::Level; +const GRACEFUL_CLOSE_TIMEOUT: Duration = Duration::from_secs(5); + pub(crate) struct PoolInner { pub(super) connector: DynConnector, pub(super) counter: ConnectionCounter, @@ -47,6 +49,8 @@ impl PoolInner { let reconnect = move |slot| { let Some(pool) = pool_weak.upgrade() else { + // Prevent an infinite loop on pool drop. + DisconnectedSlot::leak(slot); return; }; @@ -104,7 +108,7 @@ impl PoolInner { self.sharded.drain(|slot| async move { let (conn, slot) = ConnectedSlot::take(slot); - let _ = conn.raw.close().await; + let _ = rt::timeout(GRACEFUL_CLOSE_TIMEOUT, conn.raw.close()).await; slot }) @@ -248,37 +252,42 @@ impl PoolInner { } } - pub(crate) async fn try_min_connections(self: &Arc) -> Result<(), Error> { - stream::iter( - self.sharded - .iter_min_connections() - .map(Result::<_, Error>::Ok), - ) - .try_for_each_concurrent(None, |slot| async move { - let shared = ConnectTaskShared::new_arc(); + pub(crate) async fn try_min_connections( + self: &Arc, + deadline: Option, + ) -> Result<(), Error> { + let shared = ConnectTaskShared::new_arc(); - let res = self - .connector - .connect( + let connect_min_connections = + future::try_join_all(self.sharded.iter_min_connections().map(|slot| { + self.connector.connect( Pool(self.clone()), ConnectionId::next(), slot, shared.clone(), ) - .await; + })); - match res { - Ok(conn) => { - drop(conn); - Ok(()) + let mut conns = if let Some(deadline) = deadline { + match rt::timeout_at(deadline, connect_min_connections).await { + Ok(Ok(conns)) => conns, + Err(_) | Ok(Err(Error::PoolTimedOut { .. })) => { + return Err(Error::PoolTimedOut { + last_connect_error: shared.take_error().map(Box::new), + }); } - Err(Error::PoolTimedOut { .. }) => Err(Error::PoolTimedOut { - last_connect_error: shared.take_error().map(Box::new), - }), - Err(other) => Err(other), + Ok(Err(e)) => return Err(e), } - }) - .await + } else { + connect_min_connections.await? + }; + + for mut conn in conns { + // Bypass `after_release` + drop(conn.return_to_pool()); + } + + Ok(()) } } @@ -378,7 +387,7 @@ fn spawn_maintenance_tasks(pool: &Arc>) { if pool.options.min_connections > 0 { rt::spawn(async move { if let Some(pool) = pool_weak.upgrade() { - if let Err(error) = pool.try_min_connections().await { + if let Err(error) = pool.try_min_connections(None).await { tracing::error!( target: "sqlx::pool", ?error, diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index 0b8d9452..7d2b18ed 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -369,7 +369,9 @@ impl Pool { /// Returns `None` immediately if there are no idle connections available in the pool /// or there are tasks waiting for a connection which have yet to wake. pub fn try_acquire(&self) -> Option> { - self.0.try_acquire().map(|conn| PoolConnection::new(conn, self.0.clone())) + self.0 + .try_acquire() + .map(|conn| PoolConnection::new(conn, self.0.clone())) } /// Retrieves a connection and immediately begins a new transaction. diff --git a/sqlx-core/src/pool/options.rs b/sqlx-core/src/pool/options.rs index e3469561..975583e6 100644 --- a/sqlx-core/src/pool/options.rs +++ b/sqlx-core/src/pool/options.rs @@ -574,6 +574,10 @@ impl PoolOptions { let inner = PoolInner::new_arc(self, connector); + if inner.options.min_connections > 0 { + inner.try_min_connections(Some(deadline)).await?; + } + Ok(Pool(inner)) } diff --git a/sqlx-core/src/pool/shard.rs b/sqlx-core/src/pool/shard.rs index 24750e0a..2385c998 100644 --- a/sqlx-core/src/pool/shard.rs +++ b/sqlx-core/src/pool/shard.rs @@ -47,6 +47,7 @@ struct SlotGuard { locked: Option>, shard: ArcShard, index: ConnectionIndex, + dropped: bool, } pub struct ConnectedSlot(SlotGuard); @@ -134,7 +135,7 @@ impl Sharded { pub fn count_unlocked(&self, connected: bool) -> usize { self.shards .iter() - .map(|shard| shard.unlocked_mask(connected).count_ones() as usize) + .map(|shard| shard.unlocked_mask(connected).count_ones()) .sum() } @@ -156,10 +157,10 @@ impl Sharded { } pub async fn acquire_disconnected(&self) -> DisconnectedSlot { - let guard = self.acquire(true).await; + let guard = self.acquire(false).await; assert!( - guard.get().is_some(), + guard.get().is_none(), "BUG: expected slot {}/{} NOT to be connected but it WAS", guard.shard.shard_id, guard.index @@ -347,6 +348,7 @@ impl Shard>>]> { locked: Some(locked), shard: self.clone(), index, + dropped: false, }) } @@ -370,6 +372,13 @@ impl Shard>>]> { }) } + fn all_leaked(&self) -> bool { + let all_leaked_mask = (1usize << self.connections.len()) - 1; + let leaked_set = self.leaked_set.load(Ordering::Acquire); + + leaked_set == all_leaked_mask + } + async fn drain(self: &Arc, close: F) where F: Fn(ConnectedSlot) -> Fut, @@ -396,11 +405,9 @@ impl Shard>>]> { } }); - let finished_mask = (1usize << self.connections.len()) - 1; - std::future::poll_fn(|cx| { // The connection set is drained once all slots are leaked. - if self.leaked_set.load(Ordering::Acquire) == finished_mask { + if self.all_leaked() { return Poll::Ready(()); } @@ -409,7 +416,12 @@ impl Shard>>]> { let _ = drain_disconnected.as_mut().poll(cx); let _ = drain_leaked.as_mut().poll(cx); - Poll::Pending + // Check again after driving the `drain` futures forward. + if self.all_leaked() { + Poll::Ready(()) + } else { + Poll::Pending + } }) .await; } @@ -443,6 +455,13 @@ impl ConnectedSlot { .take() .expect("BUG: expected slot to be populated, but it wasn't"); + atomic_set( + &this.0.shard.connected_set, + this.0.index, + false, + Ordering::AcqRel, + ); + (conn, DisconnectedSlot(this.0)) } } @@ -450,6 +469,14 @@ impl ConnectedSlot { impl DisconnectedSlot { pub fn put(mut self, connection: T) -> ConnectedSlot { *self.0.get_mut() = Some(connection); + + atomic_set( + &self.0.shard.connected_set, + self.0.index, + true, + Ordering::AcqRel, + ); + ConnectedSlot(self.0) } @@ -546,10 +573,13 @@ impl Drop for SlotGuard { locked: Some(locked), shard: self.shard.clone(), index: self.index, + // To avoid infinite recursion or deadlock, don't send another notification + // if this guard was already dropped once: just unlock it. + dropped: true, } }; - if connected { + if !self.dropped && connected { // Check for global waiters first. if self .shard @@ -564,7 +594,7 @@ impl Drop for SlotGuard { if self.shard.unlock_event.notify(1.tag_with(&mut self_as_tag)) > 0 { return; } - } else { + } else if !self.dropped { if self .shard .global @@ -584,7 +614,6 @@ impl Drop for SlotGuard { return; } - // If this connection is required to satisfy `min_connections` if self.should_reconnect() { (self.shard.global.do_reconnect)(DisconnectedSlot(self_as_tag())); return; @@ -695,7 +724,7 @@ impl Iterator for Mask { } let index = self.0.trailing_zeros() as usize; - self.0 &= 1 << index; + self.0 &= !(1 << index); Some(index) } diff --git a/sqlx-test/Cargo.toml b/sqlx-test/Cargo.toml index 32a341ad..4fdcb372 100644 --- a/sqlx-test/Cargo.toml +++ b/sqlx-test/Cargo.toml @@ -10,6 +10,7 @@ sqlx = { default-features = false, path = ".." } env_logger = "0.11" dotenvy = "0.15.0" anyhow = "1.0.26" +tracing-subscriber = { version = "0.3.20", features = ["env-filter"] } [lints] workspace = true diff --git a/sqlx-test/src/lib.rs b/sqlx-test/src/lib.rs index 3744724c..01cdc297 100644 --- a/sqlx-test/src/lib.rs +++ b/sqlx-test/src/lib.rs @@ -1,10 +1,14 @@ use sqlx::pool::PoolOptions; use sqlx::{Connection, Database, Error, Pool}; use std::env; +use tracing_subscriber::EnvFilter; pub fn setup_if_needed() { let _ = dotenvy::dotenv(); - let _ = env_logger::builder().is_test(true).try_init(); + let _ = tracing_subscriber::fmt::Subscriber::builder() + .with_env_filter(EnvFilter::from_default_env()) + .with_test_writer() + .finish(); } // Make a new connection diff --git a/tests/any/pool.rs b/tests/any/pool.rs index 1cc08380..d5d47d16 100644 --- a/tests/any/pool.rs +++ b/tests/any/pool.rs @@ -1,6 +1,6 @@ use sqlx::any::{AnyConnectOptions, AnyPoolOptions}; use sqlx::Executor; -use sqlx_core::connection::ConnectOptions; +use sqlx_core::connection::{ConnectOptions, Connection}; use sqlx_core::pool::PoolConnectMetadata; use sqlx_core::sql_str::AssertSqlSafe; use std::sync::{ @@ -9,6 +9,29 @@ use std::sync::{ }; use std::time::Duration; +#[sqlx_macros::test] +async fn pool_basic_functions() -> anyhow::Result<()> { + sqlx::any::install_default_drivers(); + + let pool = AnyPoolOptions::new() + .max_connections(2) + .acquire_timeout(Duration::from_secs(3)) + .connect(&dotenvy::var("DATABASE_URL")?) + .await?; + + let mut conn = pool.acquire().await?; + + conn.ping().await?; + + drop(conn); + + let b: bool = sqlx::query_scalar("SELECT true").fetch_one(&pool).await?; + + assert!(b); + + Ok(()) +} + // https://github.com/launchbadge/sqlx/issues/527 #[sqlx_macros::test] async fn pool_should_be_returned_failed_transactions() -> anyhow::Result<()> { @@ -43,6 +66,7 @@ async fn pool_should_be_returned_failed_transactions() -> anyhow::Result<()> { #[sqlx_macros::test] async fn test_pool_callbacks() -> anyhow::Result<()> { sqlx::any::install_default_drivers(); + tracing_subscriber::fmt::init(); #[derive(sqlx::FromRow, Debug, PartialEq, Eq)] struct ConnStats { @@ -131,7 +155,9 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { id ); - conn.execute(&statement[..]).await?; + sqlx::raw_sql(AssertSqlSafe(statement)) + .execute(&mut conn) + .await?; Ok(conn) } }); @@ -154,6 +180,8 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { ]; for (id, before_acquire_calls, after_release_calls) in pattern { + eprintln!("ID: {id}, before_acquire calls: {before_acquire_calls}, after_release calls: {after_release_calls}"); + let conn_stats: ConnStats = sqlx::query_as("SELECT * FROM conn_stats") .fetch_one(&pool) .await?; @@ -183,6 +211,7 @@ async fn test_connection_maintenance() -> anyhow::Result<()> { let last_meta = Arc::new(Mutex::new(None)); let last_meta_ = last_meta.clone(); let pool = AnyPoolOptions::new() + .acquire_timeout(Duration::from_secs(1)) .max_lifetime(Duration::from_millis(400)) .min_connections(3) .before_acquire(move |_conn, _meta| {