fix: bugs in sharded pool

This commit is contained in:
Austin Bonander 2025-10-29 06:54:15 -07:00
parent 23643d7fe2
commit dd9cb718de
8 changed files with 119 additions and 41 deletions

View File

@ -224,7 +224,6 @@ sqlx-sqlite = { workspace = true, optional = true }
anyhow = "1.0.52"
time_ = { version = "0.3.2", package = "time" }
futures-util = { version = "0.3.19", default-features = false, features = ["alloc"] }
env_logger = "0.11"
async-std = { workspace = true, features = ["attributes"] }
tokio = { version = "1.15.0", features = ["full"] }
dotenvy = "0.15.0"
@ -241,7 +240,8 @@ tempfile = "3.10.1"
criterion = { version = "0.7.0", features = ["async_tokio"] }
libsqlite3-sys = { version = "0.30.1" }
tracing = { version = "0.1.44", features = ["attributes"] }
tracing = "0.1.41"
tracing-subscriber = "0.3.20"
# If this is an unconditional dev-dependency then Cargo will *always* try to build `libsqlite3-sys`,
# even when SQLite isn't the intended test target, and fail if the build environment is not set up for compiling C code.

View File

@ -26,6 +26,8 @@ use futures_util::{stream, FutureExt, TryStreamExt};
use std::time::{Duration, Instant};
use tracing::Level;
/// Time limit passed to `rt::timeout` around `conn.raw.close()` when the pool's slots are drained.
const GRACEFUL_CLOSE_TIMEOUT: Duration = Duration::from_secs(5);
pub(crate) struct PoolInner<DB: Database> {
pub(super) connector: DynConnector<DB>,
pub(super) counter: ConnectionCounter,
@ -47,6 +49,8 @@ impl<DB: Database> PoolInner<DB> {
let reconnect = move |slot| {
let Some(pool) = pool_weak.upgrade() else {
// Prevent an infinite loop on pool drop.
DisconnectedSlot::leak(slot);
return;
};
@ -104,7 +108,7 @@ impl<DB: Database> PoolInner<DB> {
self.sharded.drain(|slot| async move {
let (conn, slot) = ConnectedSlot::take(slot);
let _ = conn.raw.close().await;
let _ = rt::timeout(GRACEFUL_CLOSE_TIMEOUT, conn.raw.close()).await;
slot
})
@ -248,37 +252,42 @@ impl<DB: Database> PoolInner<DB> {
}
}
pub(crate) async fn try_min_connections(self: &Arc<Self>) -> Result<(), Error> {
stream::iter(
self.sharded
.iter_min_connections()
.map(Result::<_, Error>::Ok),
)
.try_for_each_concurrent(None, |slot| async move {
let shared = ConnectTaskShared::new_arc();
pub(crate) async fn try_min_connections(
self: &Arc<Self>,
deadline: Option<Instant>,
) -> Result<(), Error> {
let shared = ConnectTaskShared::new_arc();
let res = self
.connector
.connect(
let connect_min_connections =
future::try_join_all(self.sharded.iter_min_connections().map(|slot| {
self.connector.connect(
Pool(self.clone()),
ConnectionId::next(),
slot,
shared.clone(),
)
.await;
}));
match res {
Ok(conn) => {
drop(conn);
Ok(())
let mut conns = if let Some(deadline) = deadline {
match rt::timeout_at(deadline, connect_min_connections).await {
Ok(Ok(conns)) => conns,
Err(_) | Ok(Err(Error::PoolTimedOut { .. })) => {
return Err(Error::PoolTimedOut {
last_connect_error: shared.take_error().map(Box::new),
});
}
Err(Error::PoolTimedOut { .. }) => Err(Error::PoolTimedOut {
last_connect_error: shared.take_error().map(Box::new),
}),
Err(other) => Err(other),
Ok(Err(e)) => return Err(e),
}
})
.await
} else {
connect_min_connections.await?
};
for mut conn in conns {
// Bypass `after_release`
drop(conn.return_to_pool());
}
Ok(())
}
}
@ -378,7 +387,7 @@ fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {
if pool.options.min_connections > 0 {
rt::spawn(async move {
if let Some(pool) = pool_weak.upgrade() {
if let Err(error) = pool.try_min_connections().await {
if let Err(error) = pool.try_min_connections(None).await {
tracing::error!(
target: "sqlx::pool",
?error,

View File

@ -369,7 +369,9 @@ impl<DB: Database> Pool<DB> {
/// Returns `None` immediately if there are no idle connections available in the pool
/// or there are tasks waiting for a connection which have yet to wake.
pub fn try_acquire(&self) -> Option<PoolConnection<DB>> {
self.0.try_acquire().map(|conn| PoolConnection::new(conn, self.0.clone()))
self.0
.try_acquire()
.map(|conn| PoolConnection::new(conn, self.0.clone()))
}
/// Retrieves a connection and immediately begins a new transaction.

View File

@ -574,6 +574,10 @@ impl<DB: Database> PoolOptions<DB> {
let inner = PoolInner::new_arc(self, connector);
if inner.options.min_connections > 0 {
inner.try_min_connections(Some(deadline)).await?;
}
Ok(Pool(inner))
}

View File

@ -47,6 +47,7 @@ struct SlotGuard<T> {
locked: Option<ArcMutexGuard<T>>,
shard: ArcShard<T>,
index: ConnectionIndex,
dropped: bool,
}
/// Newtype over a [`SlotGuard`].
///
/// Returned by `DisconnectedSlot::put` after a value has been stored in the slot, and consumed by
/// `ConnectedSlot::take`, which hands back the stored value together with a `DisconnectedSlot`.
pub struct ConnectedSlot<T>(SlotGuard<T>);
@ -134,7 +135,7 @@ impl<T> Sharded<T> {
pub fn count_unlocked(&self, connected: bool) -> usize {
self.shards
.iter()
.map(|shard| shard.unlocked_mask(connected).count_ones() as usize)
.map(|shard| shard.unlocked_mask(connected).count_ones())
.sum()
}
@ -156,10 +157,10 @@ impl<T> Sharded<T> {
}
pub async fn acquire_disconnected(&self) -> DisconnectedSlot<T> {
let guard = self.acquire(true).await;
let guard = self.acquire(false).await;
assert!(
guard.get().is_some(),
guard.get().is_none(),
"BUG: expected slot {}/{} NOT to be connected but it WAS",
guard.shard.shard_id,
guard.index
@ -347,6 +348,7 @@ impl<T> Shard<T, [Arc<Mutex<Option<T>>>]> {
locked: Some(locked),
shard: self.clone(),
index,
dropped: false,
})
}
@ -370,6 +372,13 @@ impl<T> Shard<T, [Arc<Mutex<Option<T>>>]> {
})
}
/// Returns `true` when `leaked_set` has exactly one bit set for every entry in
/// `self.connections`.
fn all_leaked(&self) -> bool {
    let slot_count = self.connections.len();

    // `1usize << slot_count` overflows once the shard uses every bit of the `usize` bitset:
    // that panics in debug builds and silently produces an all-zero mask in release builds.
    // Use a full mask in that case instead.
    let all_leaked_mask = if slot_count >= usize::BITS as usize {
        usize::MAX
    } else {
        (1usize << slot_count) - 1
    };

    let leaked_set = self.leaked_set.load(Ordering::Acquire);

    leaked_set == all_leaked_mask
}
async fn drain<F, Fut>(self: &Arc<Self>, close: F)
where
F: Fn(ConnectedSlot<T>) -> Fut,
@ -396,11 +405,9 @@ impl<T> Shard<T, [Arc<Mutex<Option<T>>>]> {
}
});
let finished_mask = (1usize << self.connections.len()) - 1;
std::future::poll_fn(|cx| {
// The connection set is drained once all slots are leaked.
if self.leaked_set.load(Ordering::Acquire) == finished_mask {
if self.all_leaked() {
return Poll::Ready(());
}
@ -409,7 +416,12 @@ impl<T> Shard<T, [Arc<Mutex<Option<T>>>]> {
let _ = drain_disconnected.as_mut().poll(cx);
let _ = drain_leaked.as_mut().poll(cx);
Poll::Pending
// Check again after driving the `drain` futures forward.
if self.all_leaked() {
Poll::Ready(())
} else {
Poll::Pending
}
})
.await;
}
@ -443,6 +455,13 @@ impl<T> ConnectedSlot<T> {
.take()
.expect("BUG: expected slot to be populated, but it wasn't");
atomic_set(
&this.0.shard.connected_set,
this.0.index,
false,
Ordering::AcqRel,
);
(conn, DisconnectedSlot(this.0))
}
}
@ -450,6 +469,14 @@ impl<T> ConnectedSlot<T> {
impl<T> DisconnectedSlot<T> {
/// Stores `connection` in this slot and converts the slot into a [`ConnectedSlot`].
///
/// The value is written into the slot first; the shard's `connected_set` is updated afterwards.
pub fn put(mut self, connection: T) -> ConnectedSlot<T> {
*self.0.get_mut() = Some(connection);
// NOTE(review): `atomic_set` is defined outside this view; presumably it sets the bit for
// `index` in `connected_set` to the given value (`true` here) — confirm.
atomic_set(
&self.0.shard.connected_set,
self.0.index,
true,
Ordering::AcqRel,
);
// The same guard is reused, now wrapped in the connected-slot type.
ConnectedSlot(self.0)
}
@ -546,10 +573,13 @@ impl<T> Drop for SlotGuard<T> {
locked: Some(locked),
shard: self.shard.clone(),
index: self.index,
// To avoid infinite recursion or deadlock, don't send another notification
// if this guard was already dropped once: just unlock it.
dropped: true,
}
};
if connected {
if !self.dropped && connected {
// Check for global waiters first.
if self
.shard
@ -564,7 +594,7 @@ impl<T> Drop for SlotGuard<T> {
if self.shard.unlock_event.notify(1.tag_with(&mut self_as_tag)) > 0 {
return;
}
} else {
} else if !self.dropped {
if self
.shard
.global
@ -584,7 +614,6 @@ impl<T> Drop for SlotGuard<T> {
return;
}
// If this connection is required to satisfy `min_connections`
if self.should_reconnect() {
(self.shard.global.do_reconnect)(DisconnectedSlot(self_as_tag()));
return;
@ -695,7 +724,7 @@ impl Iterator for Mask {
}
let index = self.0.trailing_zeros() as usize;
self.0 &= 1 << index;
self.0 &= !(1 << index);
Some(index)
}

View File

@ -10,6 +10,7 @@ sqlx = { default-features = false, path = ".." }
env_logger = "0.11"
dotenvy = "0.15.0"
anyhow = "1.0.26"
tracing-subscriber = { version = "0.3.20", features = ["env-filter"] }
[lints]
workspace = true

View File

@ -1,10 +1,14 @@
use sqlx::pool::PoolOptions;
use sqlx::{Connection, Database, Error, Pool};
use std::env;
use tracing_subscriber::EnvFilter;
/// Best-effort test setup: loads `.env` and installs logging.
///
/// Every step discards its result, so a failure in any of them (for example a logger or
/// subscriber that is already installed) is ignored rather than reported.
pub fn setup_if_needed() {
    let _ = dotenvy::dotenv();
    let _ = env_logger::builder().is_test(true).try_init();
    // `finish()` only *builds* the subscriber; the value was then dropped, so it was never
    // installed. `try_init()` registers it as the global default and returns an error
    // (ignored here) if one has already been set.
    let _ = tracing_subscriber::fmt::Subscriber::builder()
        .with_env_filter(EnvFilter::from_default_env())
        .with_test_writer()
        .try_init();
}
// Make a new connection

View File

@ -1,6 +1,6 @@
use sqlx::any::{AnyConnectOptions, AnyPoolOptions};
use sqlx::Executor;
use sqlx_core::connection::ConnectOptions;
use sqlx_core::connection::{ConnectOptions, Connection};
use sqlx_core::pool::PoolConnectMetadata;
use sqlx_core::sql_str::AssertSqlSafe;
use std::sync::{
@ -9,6 +9,29 @@ use std::sync::{
};
use std::time::Duration;
// Smoke test: connect a two-connection pool, ping a checked-out connection,
// then run a scalar query directly against the pool.
#[sqlx_macros::test]
async fn pool_basic_functions() -> anyhow::Result<()> {
    sqlx::any::install_default_drivers();

    let options = AnyPoolOptions::new()
        .max_connections(2)
        .acquire_timeout(Duration::from_secs(3));
    let database_url = dotenvy::var("DATABASE_URL")?;
    let pool = options.connect(&database_url).await?;

    // Scope the checkout so the connection is dropped before the pool is queried.
    {
        let mut checked_out = pool.acquire().await?;
        checked_out.ping().await?;
    }

    let selected: bool = sqlx::query_scalar("SELECT true").fetch_one(&pool).await?;
    assert!(selected);

    Ok(())
}
// https://github.com/launchbadge/sqlx/issues/527
#[sqlx_macros::test]
async fn pool_should_be_returned_failed_transactions() -> anyhow::Result<()> {
@ -43,6 +66,7 @@ async fn pool_should_be_returned_failed_transactions() -> anyhow::Result<()> {
#[sqlx_macros::test]
async fn test_pool_callbacks() -> anyhow::Result<()> {
sqlx::any::install_default_drivers();
tracing_subscriber::fmt::init();
#[derive(sqlx::FromRow, Debug, PartialEq, Eq)]
struct ConnStats {
@ -131,7 +155,9 @@ async fn test_pool_callbacks() -> anyhow::Result<()> {
id
);
conn.execute(&statement[..]).await?;
sqlx::raw_sql(AssertSqlSafe(statement))
.execute(&mut conn)
.await?;
Ok(conn)
}
});
@ -154,6 +180,8 @@ async fn test_pool_callbacks() -> anyhow::Result<()> {
];
for (id, before_acquire_calls, after_release_calls) in pattern {
eprintln!("ID: {id}, before_acquire calls: {before_acquire_calls}, after_release calls: {after_release_calls}");
let conn_stats: ConnStats = sqlx::query_as("SELECT * FROM conn_stats")
.fetch_one(&pool)
.await?;
@ -183,6 +211,7 @@ async fn test_connection_maintenance() -> anyhow::Result<()> {
let last_meta = Arc::new(Mutex::new(None));
let last_meta_ = last_meta.clone();
let pool = AnyPoolOptions::new()
.acquire_timeout(Duration::from_secs(1))
.max_lifetime(Duration::from_millis(400))
.min_connections(3)
.before_acquire(move |_conn, _meta| {