Mirror of https://github.com/launchbadge/sqlx.git, synced 2025-12-29 21:00:54 +00:00
fix: bugs in sharded pool
This commit is contained in: parent 23643d7fe2, commit dd9cb718de
@@ -224,7 +224,6 @@ sqlx-sqlite = { workspace = true, optional = true }
anyhow = "1.0.52"
time_ = { version = "0.3.2", package = "time" }
futures-util = { version = "0.3.19", default-features = false, features = ["alloc"] }
env_logger = "0.11"
async-std = { workspace = true, features = ["attributes"] }
tokio = { version = "1.15.0", features = ["full"] }
dotenvy = "0.15.0"
@@ -241,7 +240,8 @@ tempfile = "3.10.1"
criterion = { version = "0.7.0", features = ["async_tokio"] }
libsqlite3-sys = { version = "0.30.1" }

tracing = { version = "0.1.44", features = ["attributes"] }
tracing = "0.1.41"
tracing-subscriber = "0.3.20"

# If this is an unconditional dev-dependency then Cargo will *always* try to build `libsqlite3-sys`,
# even when SQLite isn't the intended test target, and fail if the build environment is not set up for compiling C code.
@@ -26,6 +26,8 @@ use futures_util::{stream, FutureExt, TryStreamExt};
use std::time::{Duration, Instant};
use tracing::Level;

const GRACEFUL_CLOSE_TIMEOUT: Duration = Duration::from_secs(5);

pub(crate) struct PoolInner<DB: Database> {
pub(super) connector: DynConnector<DB>,
pub(super) counter: ConnectionCounter,
@@ -47,6 +49,8 @@ impl<DB: Database> PoolInner<DB> {

let reconnect = move |slot| {
let Some(pool) = pool_weak.upgrade() else {
// Prevent an infinite loop on pool drop.
DisconnectedSlot::leak(slot);
return;
};
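The hunk above guards the reconnect callback against a pool that has already been dropped: the callback holds only a `Weak` handle, and when the upgrade fails it leaks the slot instead of re-entering the reconnect path. A minimal standalone sketch of that pattern — the `PoolState`/`Slot` types and the `leak` helper here are illustrative stand-ins, not sqlx's actual types:

```rust
use std::sync::{Arc, Weak};

struct PoolState {
    name: &'static str,
}

struct Slot(usize);

impl Slot {
    // Take the slot out of service for good instead of scheduling a reconnect.
    fn leak(self) {
        // Dropping the slot normally would re-enter the reconnect path, so on
        // pool shutdown we simply forget it.
        std::mem::forget(self);
    }
}

fn make_reconnect(pool_weak: Weak<PoolState>) -> impl Fn(Slot) {
    move |slot: Slot| {
        let Some(pool) = pool_weak.upgrade() else {
            // The pool is already gone: leak the slot so this callback can
            // never loop on itself during pool drop.
            Slot::leak(slot);
            return;
        };
        // With a live pool we would hand the slot to the connector here.
        println!("reconnecting slot {} for pool {}", slot.0, pool.name);
    }
}

fn main() {
    let pool = Arc::new(PoolState { name: "demo" });
    let reconnect = make_reconnect(Arc::downgrade(&pool));

    reconnect(Slot(0)); // pool alive: reconnect path
    drop(pool);
    reconnect(Slot(1)); // pool dropped: slot is leaked, no infinite loop
}
```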
@@ -104,7 +108,7 @@ impl<DB: Database> PoolInner<DB> {
self.sharded.drain(|slot| async move {
let (conn, slot) = ConnectedSlot::take(slot);

let _ = conn.raw.close().await;
let _ = rt::timeout(GRACEFUL_CLOSE_TIMEOUT, conn.raw.close()).await;

slot
})
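The change here bounds each connection's graceful close with `GRACEFUL_CLOSE_TIMEOUT` so a slow or unresponsive server cannot stall the drain. A rough sketch of the same idea using `tokio::time::timeout` directly — sqlx routes this through its internal `rt` shim, and `FakeConn` is a made-up stand-in:

```rust
use std::time::Duration;
use tokio::time::timeout;

const GRACEFUL_CLOSE_TIMEOUT: Duration = Duration::from_secs(5);

// Stand-in for a database connection with an async close handshake.
struct FakeConn;

impl FakeConn {
    async fn close(self) -> std::io::Result<()> {
        // Pretend the server takes a moment to acknowledge the close.
        tokio::time::sleep(Duration::from_millis(10)).await;
        Ok(())
    }
}

#[tokio::main]
async fn main() {
    let conn = FakeConn;

    // Ignore both a close error and a timeout: during shutdown the pool only
    // gives each connection a bounded chance to say goodbye; it must never
    // let a slow server hold up the drain.
    let _ = timeout(GRACEFUL_CLOSE_TIMEOUT, conn.close()).await;

    println!("drain continued");
}
```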
@@ -248,37 +252,42 @@ impl<DB: Database> PoolInner<DB> {
}
}

pub(crate) async fn try_min_connections(self: &Arc<Self>) -> Result<(), Error> {
stream::iter(
self.sharded
.iter_min_connections()
.map(Result::<_, Error>::Ok),
)
.try_for_each_concurrent(None, |slot| async move {
let shared = ConnectTaskShared::new_arc();
pub(crate) async fn try_min_connections(
self: &Arc<Self>,
deadline: Option<Instant>,
) -> Result<(), Error> {
let shared = ConnectTaskShared::new_arc();

let res = self
.connector
.connect(
let connect_min_connections =
future::try_join_all(self.sharded.iter_min_connections().map(|slot| {
self.connector.connect(
Pool(self.clone()),
ConnectionId::next(),
slot,
shared.clone(),
)
.await;
}));

match res {
Ok(conn) => {
drop(conn);
Ok(())
let mut conns = if let Some(deadline) = deadline {
match rt::timeout_at(deadline, connect_min_connections).await {
Ok(Ok(conns)) => conns,
Err(_) | Ok(Err(Error::PoolTimedOut { .. })) => {
return Err(Error::PoolTimedOut {
last_connect_error: shared.take_error().map(Box::new),
});
}
Err(Error::PoolTimedOut { .. }) => Err(Error::PoolTimedOut {
last_connect_error: shared.take_error().map(Box::new),
}),
Err(other) => Err(other),
Ok(Err(e)) => return Err(e),
}
})
.await
} else {
connect_min_connections.await?
};

for mut conn in conns {
// Bypass `after_release`
drop(conn.return_to_pool());
}

Ok(())
}
}
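The rewritten `try_min_connections` switches from `try_for_each_concurrent` to `future::try_join_all`, takes an optional `deadline`, and collapses both an elapsed timeout and an individual `PoolTimedOut` into a single `PoolTimedOut` carrying the last connect error. A simplified sketch of that control flow, using `futures_util` and tokio directly with plain strings in place of sqlx's `rt::timeout_at` and error types (all names here are illustrative):

```rust
use std::time::Duration;

use futures_util::future::try_join_all;
use tokio::time::{timeout_at, Instant};

// Stand-in for one connection attempt.
async fn connect_one(i: usize) -> std::io::Result<String> {
    tokio::time::sleep(Duration::from_millis(5 * i as u64)).await;
    Ok(format!("conn-{i}"))
}

// Simplified shape of the new `try_min_connections`: connect every missing slot
// concurrently, and if a deadline was supplied, bound the whole batch with it.
async fn connect_min(n: usize, deadline: Option<Instant>) -> Result<Vec<String>, String> {
    let connect_all = try_join_all((0..n).map(connect_one));

    let conns = match deadline {
        Some(deadline) => match timeout_at(deadline, connect_all).await {
            Ok(Ok(conns)) => conns,
            // The deadline elapsed before every connection was established.
            Err(_elapsed) => return Err("pool timed out".into()),
            // An individual connection attempt failed outright.
            Ok(Err(e)) => return Err(e.to_string()),
        },
        None => connect_all.await.map_err(|e| e.to_string())?,
    };

    Ok(conns)
}

#[tokio::main]
async fn main() {
    let deadline = Instant::now() + Duration::from_secs(1);
    let conns = connect_min(3, Some(deadline)).await.expect("connect failed");

    // The real pool would now return each connection to its slot.
    assert_eq!(conns.len(), 3);
}
```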
@@ -378,7 +387,7 @@ fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {
if pool.options.min_connections > 0 {
rt::spawn(async move {
if let Some(pool) = pool_weak.upgrade() {
if let Err(error) = pool.try_min_connections().await {
if let Err(error) = pool.try_min_connections(None).await {
tracing::error!(
target: "sqlx::pool",
?error,
@@ -369,7 +369,9 @@ impl<DB: Database> Pool<DB> {
/// Returns `None` immediately if there are no idle connections available in the pool
/// or there are tasks waiting for a connection which have yet to wake.
pub fn try_acquire(&self) -> Option<PoolConnection<DB>> {
self.0.try_acquire().map(|conn| PoolConnection::new(conn, self.0.clone()))
self.0
.try_acquire()
.map(|conn| PoolConnection::new(conn, self.0.clone()))
}

/// Retrieves a connection and immediately begins a new transaction.
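For callers, the reformatted `try_acquire` still behaves as the doc comment describes: it never waits, and returns `None` when no idle connection can be handed over immediately. A hedged usage sketch against the published sqlx API — the Postgres feature, tokio runtime, connection URL, and pool settings are all assumptions for the example:

```rust
use sqlx::postgres::PgPoolOptions;

#[tokio::main]
async fn main() -> Result<(), sqlx::Error> {
    let pool = PgPoolOptions::new()
        .max_connections(5)
        .connect("postgres://localhost/mydb")
        .await?;

    match pool.try_acquire() {
        Some(mut conn) => {
            // An idle connection was free: use it without ever waiting.
            sqlx::query("SELECT 1").execute(&mut *conn).await?;
        }
        None => {
            // Nothing idle (or other tasks are already queued): fall back to the
            // waiting acquire, or skip the optional work entirely.
            let mut conn = pool.acquire().await?;
            sqlx::query("SELECT 1").execute(&mut *conn).await?;
        }
    }

    Ok(())
}
```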
@@ -574,6 +574,10 @@ impl<DB: Database> PoolOptions<DB> {

let inner = PoolInner::new_arc(self, connector);

if inner.options.min_connections > 0 {
inner.try_min_connections(Some(deadline)).await?;
}

Ok(Pool(inner))
}
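With this hunk, a pool built with `min_connections > 0` now opens those connections eagerly inside `connect`, bounded by the surrounding `deadline` (presumably derived from `acquire_timeout`). A sketch of what that means for callers, again with assumed connection details and features:

```rust
use std::time::Duration;

use sqlx::postgres::PgPoolOptions;

#[tokio::main]
async fn main() -> Result<(), sqlx::Error> {
    let pool = PgPoolOptions::new()
        .min_connections(3)
        .max_connections(10)
        .acquire_timeout(Duration::from_secs(5))
        .connect("postgres://localhost/mydb")
        .await?;

    // Per the change above, `connect` should now return with the minimum number of
    // connections already open, or fail within the deadline if it cannot open them.
    println!("pool size after connect: {}", pool.size());

    Ok(())
}
```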
@@ -47,6 +47,7 @@ struct SlotGuard<T> {
locked: Option<ArcMutexGuard<T>>,
shard: ArcShard<T>,
index: ConnectionIndex,
dropped: bool,
}

pub struct ConnectedSlot<T>(SlotGuard<T>);
@@ -134,7 +135,7 @@ impl<T> Sharded<T> {
pub fn count_unlocked(&self, connected: bool) -> usize {
self.shards
.iter()
.map(|shard| shard.unlocked_mask(connected).count_ones() as usize)
.map(|shard| shard.unlocked_mask(connected).count_ones())
.sum()
}
@@ -156,10 +157,10 @@ impl<T> Sharded<T> {
}

pub async fn acquire_disconnected(&self) -> DisconnectedSlot<T> {
let guard = self.acquire(true).await;
let guard = self.acquire(false).await;

assert!(
guard.get().is_some(),
guard.get().is_none(),
"BUG: expected slot {}/{} NOT to be connected but it WAS",
guard.shard.shard_id,
guard.index
@@ -347,6 +348,7 @@ impl<T> Shard<T, [Arc<Mutex<Option<T>>>]> {
locked: Some(locked),
shard: self.clone(),
index,
dropped: false,
})
}
@@ -370,6 +372,13 @@ impl<T> Shard<T, [Arc<Mutex<Option<T>>>]> {
})
}

fn all_leaked(&self) -> bool {
let all_leaked_mask = (1usize << self.connections.len()) - 1;
let leaked_set = self.leaked_set.load(Ordering::Acquire);

leaked_set == all_leaked_mask
}

async fn drain<F, Fut>(self: &Arc<Self>, close: F)
where
F: Fn(ConnectedSlot<T>) -> Fut,
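`all_leaked` centralizes the "every slot is gone" test that `drain` previously open-coded: one bit per slot in `leaked_set`, compared against a mask with the low `connections.len()` bits set. A standalone sketch of the same bookkeeping — the `Shard` struct here is a made-up miniature, and it assumes the slot count stays below the word size:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Miniature of the per-shard bookkeeping: one bit per slot in `leaked_set`.
struct Shard {
    num_slots: usize,
    leaked_set: AtomicUsize,
}

impl Shard {
    fn mark_leaked(&self, index: usize) {
        self.leaked_set.fetch_or(1 << index, Ordering::AcqRel);
    }

    fn all_leaked(&self) -> bool {
        // Every one of the low `num_slots` bits set means every slot is gone.
        // (Assumes `num_slots < usize::BITS`, i.e. shards smaller than a word.)
        let all_leaked_mask = (1usize << self.num_slots) - 1;
        self.leaked_set.load(Ordering::Acquire) == all_leaked_mask
    }
}

fn main() {
    let shard = Shard {
        num_slots: 3,
        leaked_set: AtomicUsize::new(0),
    };

    shard.mark_leaked(0);
    shard.mark_leaked(2);
    assert!(!shard.all_leaked());

    shard.mark_leaked(1);
    assert!(shard.all_leaked());
}
```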
@@ -396,11 +405,9 @@ impl<T> Shard<T, [Arc<Mutex<Option<T>>>]> {
}
});

let finished_mask = (1usize << self.connections.len()) - 1;

std::future::poll_fn(|cx| {
// The connection set is drained once all slots are leaked.
if self.leaked_set.load(Ordering::Acquire) == finished_mask {
if self.all_leaked() {
return Poll::Ready(());
}
@@ -409,7 +416,12 @@ impl<T> Shard<T, [Arc<Mutex<Option<T>>>]> {
let _ = drain_disconnected.as_mut().poll(cx);
let _ = drain_leaked.as_mut().poll(cx);

Poll::Pending
// Check again after driving the `drain` futures forward.
if self.all_leaked() {
Poll::Ready(())
} else {
Poll::Pending
}
})
.await;
}
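This is the subtler fix: the old `poll_fn` body returned `Poll::Pending` unconditionally after polling the two drain futures, so if that very poll leaked the final slot there might be no wake-up left and the drain could hang. The new code re-checks `all_leaked()` before sleeping. A self-contained illustration of the re-check pattern — the names and the helper future are invented for the example:

```rust
use std::future::{poll_fn, Future};
use std::pin::pin;
use std::task::Poll;

#[tokio::main]
async fn main() {
    let mut done = false;

    // Stand-in for the `drain_disconnected` / `drain_leaked` helper futures.
    let mut helper = pin!(async {
        tokio::task::yield_now().await;
    });

    poll_fn(|cx| {
        // 1. Fast path: the condition may already hold.
        if done {
            return Poll::Ready(());
        }

        // 2. Drive the helper future forward; this may be what completes the drain.
        if helper.as_mut().poll(cx).is_ready() {
            done = true;
        }

        // 3. Re-check before returning Pending. Without this, progress made in
        //    step 2 could leave us asleep with nobody left to wake us.
        if done {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    })
    .await;

    println!("drained");
}
```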
@@ -443,6 +455,13 @@ impl<T> ConnectedSlot<T> {
.take()
.expect("BUG: expected slot to be populated, but it wasn't");

atomic_set(
&this.0.shard.connected_set,
this.0.index,
false,
Ordering::AcqRel,
);

(conn, DisconnectedSlot(this.0))
}
}
@@ -450,6 +469,14 @@ impl<T> ConnectedSlot<T> {
impl<T> DisconnectedSlot<T> {
pub fn put(mut self, connection: T) -> ConnectedSlot<T> {
*self.0.get_mut() = Some(connection);

atomic_set(
&self.0.shard.connected_set,
self.0.index,
true,
Ordering::AcqRel,
);

ConnectedSlot(self.0)
}
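`take` and `put` keep `connected_set` in sync by clearing or setting the slot's bit. A plausible shape for an `atomic_set`-style helper over an `AtomicUsize` bit set — the helper's exact signature in sqlx may differ, so treat this as a sketch:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Set or clear one bit of a shared bit set.
fn atomic_set(bits: &AtomicUsize, index: usize, value: bool, ordering: Ordering) {
    if value {
        bits.fetch_or(1 << index, ordering);
    } else {
        bits.fetch_and(!(1 << index), ordering);
    }
}

fn main() {
    let connected_set = AtomicUsize::new(0);

    // `DisconnectedSlot::put` sets the slot's bit once a connection is stored...
    atomic_set(&connected_set, 2, true, Ordering::AcqRel);
    assert_eq!(connected_set.load(Ordering::Acquire), 0b100);

    // ...and `ConnectedSlot::take` clears it when the connection is removed.
    atomic_set(&connected_set, 2, false, Ordering::AcqRel);
    assert_eq!(connected_set.load(Ordering::Acquire), 0);
}
```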
@@ -546,10 +573,13 @@ impl<T> Drop for SlotGuard<T> {
locked: Some(locked),
shard: self.shard.clone(),
index: self.index,
// To avoid infinite recursion or deadlock, don't send another notification
// if this guard was already dropped once: just unlock it.
dropped: true,
}
};

if connected {
if !self.dropped && connected {
// Check for global waiters first.
if self
.shard
@@ -564,7 +594,7 @@ impl<T> Drop for SlotGuard<T> {
if self.shard.unlock_event.notify(1.tag_with(&mut self_as_tag)) > 0 {
return;
}
} else {
} else if !self.dropped {
if self
.shard
.global
@@ -584,7 +614,6 @@ impl<T> Drop for SlotGuard<T> {
return;
}

// If this connection is required to satisfy `min_connections`
if self.should_reconnect() {
(self.shard.global.do_reconnect)(DisconnectedSlot(self_as_tag()));
return;
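The new `dropped` flag is what makes the hand-off in `Drop` safe: when the guard re-wraps itself to offer the slot to a waiter or to the reconnect hook, the re-wrapped guard is marked `dropped: true`, so a second drop only unlocks and never notifies again. A toy illustration of that one-shot pattern — the types and printouts are invented:

```rust
// Illustrative only: a guard whose Drop hands its resource off exactly once.
struct SlotGuard {
    index: usize,
    dropped: bool,
}

impl Drop for SlotGuard {
    fn drop(&mut self) {
        if self.dropped {
            // Second drop: just release the lock; no notifications, no recursion.
            println!("slot {} unlocked quietly", self.index);
            return;
        }

        // First drop: re-wrap the slot to offer it to a waiter (or the reconnect
        // hook). The re-wrapped guard is marked `dropped`, so if nobody takes it
        // and it is dropped right here, it hits the quiet path above.
        let requeued = SlotGuard {
            index: self.index,
            dropped: true,
        };
        println!("slot {} offered to waiters", requeued.index);
    }
}

fn main() {
    let guard = SlotGuard {
        index: 7,
        dropped: false,
    };
    drop(guard); // "offered to waiters", then "unlocked quietly"
}
```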
@@ -695,7 +724,7 @@ impl Iterator for Mask {
}

let index = self.0.trailing_zeros() as usize;
self.0 &= 1 << index;
self.0 &= !(1 << index);

Some(index)
}
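The `Mask` fix is small but important: `self.0 &= 1 << index` keeps only the bit just yielded, so the iterator would return the same index forever, whereas `&= !(1 << index)` clears that bit and advances to the next one. A standalone reconstruction of the corrected iterator, assembled from the fragment above rather than copied from sqlx verbatim:

```rust
// Iterate the set bits of a word from lowest to highest, clearing each one as it
// is yielded.
struct Mask(usize);

impl Iterator for Mask {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        if self.0 == 0 {
            return None;
        }

        let index = self.0.trailing_zeros() as usize;
        // Clear the bit we just found; `&= 1 << index` would instead erase every
        // *other* bit and loop on the same index forever.
        self.0 &= !(1 << index);

        Some(index)
    }
}

fn main() {
    let indexes: Vec<usize> = Mask(0b1010_0110).collect();
    assert_eq!(indexes, vec![1, 2, 5, 7]);
    println!("{indexes:?}");
}
```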
@@ -10,6 +10,7 @@ sqlx = { default-features = false, path = ".." }
env_logger = "0.11"
dotenvy = "0.15.0"
anyhow = "1.0.26"
tracing-subscriber = { version = "0.3.20", features = ["env-filter"] }

[lints]
workspace = true
@@ -1,10 +1,14 @@
use sqlx::pool::PoolOptions;
use sqlx::{Connection, Database, Error, Pool};
use std::env;
use tracing_subscriber::EnvFilter;

pub fn setup_if_needed() {
let _ = dotenvy::dotenv();
let _ = env_logger::builder().is_test(true).try_init();
let _ = tracing_subscriber::fmt::Subscriber::builder()
.with_env_filter(EnvFilter::from_default_env())
.with_test_writer()
.finish();
}

// Make a new connection
@@ -1,6 +1,6 @@
use sqlx::any::{AnyConnectOptions, AnyPoolOptions};
use sqlx::Executor;
use sqlx_core::connection::ConnectOptions;
use sqlx_core::connection::{ConnectOptions, Connection};
use sqlx_core::pool::PoolConnectMetadata;
use sqlx_core::sql_str::AssertSqlSafe;
use std::sync::{
@@ -9,6 +9,29 @@ use std::sync::{
};
use std::time::Duration;

#[sqlx_macros::test]
async fn pool_basic_functions() -> anyhow::Result<()> {
sqlx::any::install_default_drivers();

let pool = AnyPoolOptions::new()
.max_connections(2)
.acquire_timeout(Duration::from_secs(3))
.connect(&dotenvy::var("DATABASE_URL")?)
.await?;

let mut conn = pool.acquire().await?;

conn.ping().await?;

drop(conn);

let b: bool = sqlx::query_scalar("SELECT true").fetch_one(&pool).await?;

assert!(b);

Ok(())
}

// https://github.com/launchbadge/sqlx/issues/527
#[sqlx_macros::test]
async fn pool_should_be_returned_failed_transactions() -> anyhow::Result<()> {
@@ -43,6 +66,7 @@ async fn pool_should_be_returned_failed_transactions() -> anyhow::Result<()> {
#[sqlx_macros::test]
async fn test_pool_callbacks() -> anyhow::Result<()> {
sqlx::any::install_default_drivers();
tracing_subscriber::fmt::init();

#[derive(sqlx::FromRow, Debug, PartialEq, Eq)]
struct ConnStats {
@@ -131,7 +155,9 @@ async fn test_pool_callbacks() -> anyhow::Result<()> {
id
);

conn.execute(&statement[..]).await?;
sqlx::raw_sql(AssertSqlSafe(statement))
.execute(&mut conn)
.await?;
Ok(conn)
}
});
@@ -154,6 +180,8 @@ async fn test_pool_callbacks() -> anyhow::Result<()> {
];

for (id, before_acquire_calls, after_release_calls) in pattern {
eprintln!("ID: {id}, before_acquire calls: {before_acquire_calls}, after_release calls: {after_release_calls}");

let conn_stats: ConnStats = sqlx::query_as("SELECT * FROM conn_stats")
.fetch_one(&pool)
.await?;
@@ -183,6 +211,7 @@ async fn test_connection_maintenance() -> anyhow::Result<()> {
let last_meta = Arc::new(Mutex::new(None));
let last_meta_ = last_meta.clone();
let pool = AnyPoolOptions::new()
.acquire_timeout(Duration::from_secs(1))
.max_lifetime(Duration::from_millis(400))
.min_connections(3)
.before_acquire(move |_conn, _meta| {