WIP refactor: replace sharding with single connection set

This commit is contained in:
Austin Bonander 2025-11-30 16:43:48 -08:00
parent d905016923
commit 0dd92b4594
13 changed files with 1185 additions and 348 deletions

View File

@ -0,0 +1,38 @@
use pin_project_lite::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
pin_project! {
#[project = RaceProject]
pub struct Race<L, R> {
#[pin]
left: L,
#[pin]
right: R,
}
}
impl<L, R> Future for Race<L, R>
where
L: Future,
R: Future,
{
type Output = Result<L::Output, R::Output>;
#[inline(always)]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
if let Poll::Ready(left) = this.left.as_mut().poll(cx) {
return Poll::Ready(Ok(left));
}
this.right.as_mut().poll(cx).map(Err)
}
}
#[inline(always)]
pub fn race<L, R>(left: L, right: R) -> Race<L, R> {
Race { left, right }
}

View File

@ -2,3 +2,5 @@ pub mod ustr;
#[macro_use]
pub mod async_stream;
pub mod future;

View File

@ -14,7 +14,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::time::Instant;
use crate::pool::shard::DisconnectedSlot;
use crate::pool::connection_set::DisconnectedSlot;
#[cfg(doc)]
use crate::pool::PoolOptions;
use crate::sync::{AsyncMutex, AsyncMutexGuard};
@ -646,7 +646,7 @@ async fn connect_with_backoff<DB: Database>(
match res {
ControlFlow::Break(Ok(conn)) => {
tracing::trace!(
tracing::debug!(
target: "sqlx::pool::connect",
%connection_id,
attempt,
@ -654,18 +654,16 @@ async fn connect_with_backoff<DB: Database>(
"connection established",
);
return Ok(PoolConnection::new(
slot.put(ConnectionInner {
raw: conn,
id: connection_id,
created_at: now,
last_released_at: now,
}),
pool.0.clone(),
));
return Ok(PoolConnection::new(slot.put(ConnectionInner {
pool: Arc::downgrade(&pool.0),
raw: conn,
id: connection_id,
created_at: now,
last_released_at: now,
})));
}
ControlFlow::Break(Err(e)) => {
tracing::warn!(
tracing::error!(
target: "sqlx::pool::connect",
%connection_id,
attempt,

View File

@ -1,33 +1,35 @@
use std::fmt::{self, Debug, Formatter};
use std::future::{self, Future};
use std::io;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::sync::{Arc, Weak};
use std::time::{Duration, Instant};
use crate::connection::Connection;
use crate::database::Database;
use crate::error::Error;
use super::inner::{is_beyond_max_lifetime, PoolInner};
use super::inner::PoolInner;
use crate::pool::connect::{ConnectPermit, ConnectTaskShared, ConnectionId};
use crate::pool::connection_set::{ConnectedSlot, DisconnectedSlot};
use crate::pool::options::PoolConnectionMetadata;
use crate::pool::shard::{ConnectedSlot, DisconnectedSlot};
use crate::pool::Pool;
use crate::pool::{Pool, PoolOptions};
use crate::rt;
const RETURN_TO_POOL_TIMEOUT: Duration = Duration::from_secs(5);
const CLOSE_ON_DROP_TIMEOUT: Duration = Duration::from_secs(5);
const CLOSE_TIMEOUT: Duration = Duration::from_secs(5);
/// A connection managed by a [`Pool`][crate::pool::Pool].
///
/// Will be returned to the pool on-drop.
pub struct PoolConnection<DB: Database> {
conn: Option<ConnectedSlot<ConnectionInner<DB>>>,
pub(crate) pool: Arc<PoolInner<DB>>,
close_on_drop: bool,
}
pub(super) struct ConnectionInner<DB: Database> {
// Note: must be `Weak` to prevent a reference cycle
pub(crate) pool: Weak<PoolInner<DB>>,
pub(super) raw: DB::Connection,
pub(super) id: ConnectionId,
pub(super) created_at: Instant,
@ -72,11 +74,10 @@ impl<DB: Database> AsMut<DB::Connection> for PoolConnection<DB> {
}
impl<DB: Database> PoolConnection<DB> {
pub(super) fn new(live: ConnectedSlot<ConnectionInner<DB>>, pool: Arc<PoolInner<DB>>) -> Self {
pub(super) fn new(live: ConnectedSlot<ConnectionInner<DB>>) -> Self {
Self {
conn: Some(live),
close_on_drop: false,
pool,
}
}
@ -140,13 +141,16 @@ impl<DB: Database> PoolConnection<DB> {
#[doc(hidden)]
pub fn return_to_pool(&mut self) -> impl Future<Output = ()> + Send + 'static {
let conn = self.conn.take();
let pool = self.pool.clone();
async move {
let Some(conn) = conn else {
return;
};
let Some(pool) = Weak::upgrade(&conn.pool) else {
return;
};
rt::timeout(RETURN_TO_POOL_TIMEOUT, return_to_pool(conn, &pool))
.await
// Dropping of the `slot` will check if the connection must be re-established
@ -161,7 +165,7 @@ impl<DB: Database> PoolConnection<DB> {
async move {
if let Some(conn) = conn {
// Don't hold the connection forever if it hangs while trying to close
rt::timeout(CLOSE_ON_DROP_TIMEOUT, close(conn)).await.ok();
rt::timeout(CLOSE_TIMEOUT, close(conn)).await.ok();
}
}
}
@ -195,7 +199,7 @@ impl<DB: Database> Drop for PoolConnection<DB> {
}
// We still need to spawn a task to maintain `min_connections`.
if self.conn.is_some() || self.pool.options.min_connections > 0 {
if self.conn.is_some() {
crate::rt::spawn(self.return_to_pool());
}
}
@ -220,6 +224,48 @@ impl<DB: Database> ConnectionInner<DB> {
idle_for: now.saturating_duration_since(self.last_released_at),
}
}
pub fn is_beyond_max_lifetime(&self, options: &PoolOptions<DB>) -> bool {
if let Some(max_lifetime) = options.max_lifetime {
let age = self.created_at.elapsed();
if age > max_lifetime {
tracing::info!(
target: "sqlx::pool",
connection_id=%self.id,
?age,
"connection is beyond `max_lifetime`, closing"
);
return true;
}
}
false
}
pub fn is_beyond_idle_timeout(&self, options: &PoolOptions<DB>) -> bool {
if let Some(idle_timeout) = options.idle_timeout {
let now = Instant::now();
let age = now.duration_since(self.created_at);
let idle_duration = now.duration_since(self.last_released_at);
if idle_duration > idle_timeout {
tracing::info!(
target: "sqlx::pool",
connection_id=%self.id,
?age,
?idle_duration,
"connection is beyond `idle_timeout`, closing"
);
return true;
}
}
false
}
}
pub(crate) async fn close<DB: Database>(
@ -231,14 +277,19 @@ pub(crate) async fn close<DB: Database>(
let (conn, slot) = ConnectedSlot::take(conn);
let res = conn.raw.close().await.inspect_err(|error| {
tracing::debug!(
target: "sqlx::pool",
%connection_id,
%error,
"error occurred while closing the pool connection"
);
});
let res = rt::timeout(CLOSE_TIMEOUT, conn.raw.close())
.await
.unwrap_or_else(|_| {
Err(io::Error::new(io::ErrorKind::TimedOut, "timed out sending close packet").into())
})
.inspect_err(|error| {
tracing::debug!(
target: "sqlx::pool",
%connection_id,
%error,
"error occurred while closing the pool connection"
);
});
(res, slot)
}
@ -255,14 +306,19 @@ pub(crate) async fn close_hard<DB: Database>(
let (conn, slot) = ConnectedSlot::take(conn);
let res = conn.raw.close_hard().await.inspect_err(|error| {
tracing::debug!(
target: "sqlx::pool",
%connection_id,
%error,
"error occurred while closing the pool connection"
);
});
let res = rt::timeout(CLOSE_TIMEOUT, conn.raw.close_hard())
.await
.unwrap_or_else(|_| {
Err(io::Error::new(io::ErrorKind::TimedOut, "timed out sending close packet").into())
})
.inspect_err(|error| {
tracing::debug!(
target: "sqlx::pool",
%connection_id,
%error,
"error occurred while closing the pool connection"
);
});
(res, slot)
}
@ -282,7 +338,7 @@ async fn return_to_pool<DB: Database>(
// If the connection is beyond max lifetime, close the connection and
// immediately create a new connection
if is_beyond_max_lifetime(&conn, &pool.options) {
if conn.is_beyond_max_lifetime(&pool.options) {
let (_res, slot) = close(conn).await;
return Err(slot);
}
@ -314,6 +370,7 @@ async fn return_to_pool<DB: Database>(
// to recover from cancellations
if let Err(error) = conn.raw.ping().await {
tracing::warn!(
target: "sqlx::pool",
%error,
"error occurred while testing the connection on-release",
);

View File

@ -0,0 +1,543 @@
use crate::ext::future::race;
use crate::rt;
use crate::sync::{AsyncMutex, AsyncMutexGuardArc};
use event_listener::{listener, Event, EventListener, IntoNotification};
use futures_core::Stream;
use futures_util::stream::FuturesUnordered;
use futures_util::{FutureExt, StreamExt};
use std::cmp;
use std::future::Future;
use std::ops::{Deref, DerefMut, RangeInclusive, RangeToInclusive};
use std::pin::{pin, Pin};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
pub struct ConnectionSet<C> {
global: Arc<Global>,
slots: Box<[Arc<Slot<C>>]>,
}
pub struct ConnectedSlot<C>(SlotGuard<C>);
pub struct DisconnectedSlot<C>(SlotGuard<C>);
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum AcquirePreference {
Connected,
Disconnected,
Either,
}
struct Global {
unlock_event: Event<usize>,
disconnect_event: Event<usize>,
num_connected: AtomicUsize,
min_connections: usize,
min_connections_event: Event<()>,
}
struct SlotGuard<C> {
slot: Arc<Slot<C>>,
// `Option` allows us to take the guard in the drop handler.
locked: Option<AsyncMutexGuardArc<Option<C>>>,
}
struct Slot<C> {
// By having each `Slot` hold its own reference to `Global`, we can avoid extra contended clones
// which would sap performance
global: Arc<Global>,
index: usize,
// I'd love to eliminate this redundant `Arc` but it's likely not possible without `unsafe`
connection: Arc<AsyncMutex<Option<C>>>,
unlock_event: Event,
disconnect_event: Event,
connected: AtomicBool,
locked: AtomicBool,
leaked: AtomicBool,
}
impl<C> ConnectionSet<C> {
pub fn new(size: RangeInclusive<usize>) -> Self {
let global = Arc::new(Global {
unlock_event: Event::with_tag(),
disconnect_event: Event::with_tag(),
num_connected: AtomicUsize::new(0),
min_connections: *size.start(),
min_connections_event: Event::with_tag(),
});
ConnectionSet {
// `vec![<expr>; size].into()` clones `<expr>` instead of repeating it,
// which is *no bueno* when wrapping something in `Arc`
slots: (0..*size.end())
.map(|index| {
Arc::new(Slot {
global: global.clone(),
index,
connection: Arc::new(AsyncMutex::new(None)),
unlock_event: Event::with_tag(),
disconnect_event: Event::with_tag(),
connected: AtomicBool::new(false),
locked: AtomicBool::new(false),
leaked: AtomicBool::new(false),
})
})
.collect(),
global,
}
}
#[inline(always)]
pub fn num_connected(&self) -> usize {
self.global.num_connected()
}
pub fn count_idle(&self) -> usize {
self.slots.iter().filter(|slot| slot.is_locked()).count()
}
pub async fn acquire_connected(&self) -> ConnectedSlot<C> {
self.acquire_inner(AcquirePreference::Connected)
.await
.assert_connected()
}
pub async fn acquire_disconnected(&self) -> DisconnectedSlot<C> {
self.acquire_inner(AcquirePreference::Disconnected)
.await
.assert_disconnected()
}
/// Attempt to acquire the connection associated with the current thread.
pub async fn acquire_any(&self) -> Result<ConnectedSlot<C>, DisconnectedSlot<C>> {
self.acquire_inner(AcquirePreference::Either)
.await
.try_connected()
}
async fn acquire_inner(&self, pref: AcquirePreference) -> SlotGuard<C> {
/// Smallest time-step supported by [`tokio::time::sleep()`].
///
/// `async-io` doesn't document a minimum time-step, instead deferring to the platform.
const STEP_INTERVAL: Duration = Duration::from_millis(1);
const SEARCH_LIMIT: usize = 5;
let preferred_slot = current_thread_id() % self.slots.len();
tracing::trace!(preferred_slot, ?pref, "acquire_inner");
// Always try to lock the connection associated with our thread ID
let mut acquire_preferred = pin!(self.slots[preferred_slot].acquire(pref));
let mut step_interval = pin!(rt::interval_after(STEP_INTERVAL));
let mut intervals_elapsed = 0usize;
let mut search_slots = FuturesUnordered::new();
let mut listen_global = pin!(self.global.listen(pref));
let mut search_slot = self.next_slot(preferred_slot);
std::future::poll_fn(|cx| loop {
if let Poll::Ready(locked) = acquire_preferred.as_mut().poll(cx) {
return Poll::Ready(locked);
}
// Don't push redundant futures for small sets.
let search_limit = cmp::min(SEARCH_LIMIT, self.slots.len());
if search_slots.len() < search_limit && step_interval.as_mut().poll_tick(cx).is_ready()
{
intervals_elapsed = intervals_elapsed.saturating_add(1);
if search_slot != preferred_slot && self.slots[search_slot].matches_pref(pref) {
search_slots.push(self.slots[search_slot].lock());
}
search_slot = self.next_slot(search_slot);
}
if let Poll::Ready(Some(locked)) = Pin::new(&mut search_slots).poll_next(cx) {
if locked.matches_pref(pref) {
return Poll::Ready(locked);
}
continue;
}
if intervals_elapsed > search_limit && search_slots.len() < search_limit {
if let Poll::Ready(slot) = listen_global.as_mut().poll(cx) {
if self.slots[slot].matches_pref(pref) {
search_slots.push(self.slots[slot].lock());
}
listen_global.as_mut().set(self.global.listen(pref));
continue;
}
}
return Poll::Pending;
})
.await
}
pub fn try_acquire_connected(&self) -> Option<ConnectedSlot<C>> {
Some(
self.try_acquire(AcquirePreference::Connected)?
.assert_connected(),
)
}
pub fn try_acquire_disconnected(&self) -> Option<DisconnectedSlot<C>> {
Some(
self.try_acquire(AcquirePreference::Disconnected)?
.assert_disconnected(),
)
}
fn try_acquire(&self, pref: AcquirePreference) -> Option<SlotGuard<C>> {
let mut search_slot = current_thread_id() % self.slots.len();
for _ in 0..self.slots.len() {
if let Some(locked) = self.slots[search_slot].try_acquire(pref) {
return Some(locked);
}
search_slot = self.next_slot(search_slot);
}
None
}
pub fn min_connections_listener(&self) -> EventListener {
self.global.min_connections_event.listen()
}
pub fn iter_idle(&self) -> impl Iterator<Item = ConnectedSlot<C>> + '_ {
self.slots.iter().filter_map(|slot| {
Some(
slot.try_acquire(AcquirePreference::Connected)?
.assert_connected(),
)
})
}
pub async fn drain(&self, ref close: impl AsyncFn(ConnectedSlot<C>) -> DisconnectedSlot<C>) {
let mut closing = FuturesUnordered::new();
// We could try to be more efficient by only populating the `FuturesUnordered` for
// connected slots, but then we'd have to handle a disconnected slot becoming connected,
// which could happen concurrently.
//
// However, we don't *need* to be efficient when shutting down the pool.
for slot in &self.slots {
closing.push(async {
let locked = slot.lock().await;
let slot = match locked.try_connected() {
Ok(connected) => close(connected).await,
Err(disconnected) => disconnected,
};
// The pool is shutting down; don't wake any tasks that might have been interested
slot.leak();
});
}
while closing.next().await.is_some() {}
}
#[inline(always)]
fn next_slot(&self, slot: usize) -> usize {
// By adding a number that is coprime to `slots.len()` before taking the modulo,
// we can visit each slot in a pseudo-random order, spreading the demand evenly.
//
// Interestingly, this pattern returns to the original slot after `slots.len()` iterations,
// because of congruence: https://en.wikipedia.org/wiki/Modular_arithmetic#Congruence
(slot + 547) % self.slots.len()
}
}
impl AcquirePreference {
#[inline(always)]
fn wants_connected(&self, is_connected: bool) -> bool {
match (self, is_connected) {
(Self::Connected, true) => true,
(Self::Disconnected, false) => true,
(Self::Either, _) => true,
_ => false,
}
}
}
impl<C> Slot<C> {
#[inline(always)]
fn matches_pref(&self, pref: AcquirePreference) -> bool {
!self.is_leaked() && pref.wants_connected(self.is_connected())
}
#[inline(always)]
fn is_connected(&self) -> bool {
self.connected.load(Ordering::Relaxed)
}
#[inline(always)]
fn is_locked(&self) -> bool {
self.locked.load(Ordering::Relaxed)
}
#[inline(always)]
fn is_leaked(&self) -> bool {
self.leaked.load(Ordering::Relaxed)
}
#[inline(always)]
fn set_is_connected(&self, connected: bool) {
let was_connected = self.connected.swap(connected, Ordering::Acquire);
match (connected, was_connected) {
(false, true) => {
// Ensure this is synchronized with `connected`
self.global.num_connected.fetch_add(1, Ordering::Release);
}
(true, false) => {
self.global.num_connected.fetch_sub(1, Ordering::Release);
}
_ => (),
}
}
async fn acquire(self: &Arc<Self>, pref: AcquirePreference) -> SlotGuard<C> {
loop {
if self.matches_pref(pref) {
tracing::trace!(slot_index=%self.index, "waiting for lock");
let locked = self.lock().await;
if locked.matches_pref(pref) {
return locked;
}
}
match pref {
AcquirePreference::Connected => {
listener!(self.unlock_event => listener);
listener.await;
}
AcquirePreference::Disconnected => {
listener!(self.disconnect_event => listener);
listener.await
}
AcquirePreference::Either => {
listener!(self.unlock_event => unlock_listener);
listener!(self.disconnect_event => disconnect_listener);
race(unlock_listener, disconnect_listener).await.ok();
}
}
}
}
fn try_acquire(self: &Arc<Self>, pref: AcquirePreference) -> Option<SlotGuard<C>> {
if self.matches_pref(pref) {
let locked = self.try_lock()?;
if locked.matches_pref(pref) {
return Some(locked);
}
}
None
}
async fn lock(self: &Arc<Self>) -> SlotGuard<C> {
let locked = crate::sync::lock_arc(&self.connection).await;
self.locked.store(true, Ordering::Relaxed);
SlotGuard {
slot: self.clone(),
locked: Some(locked),
}
}
fn try_lock(self: &Arc<Self>) -> Option<SlotGuard<C>> {
let locked = crate::sync::try_lock_arc(&self.connection)?;
self.locked.store(true, Ordering::Relaxed);
Some(SlotGuard {
slot: self.clone(),
locked: Some(locked),
})
}
}
impl<C> SlotGuard<C> {
#[inline(always)]
fn get(&self) -> &Option<C> {
self.locked.as_ref().expect(EXPECT_LOCKED)
}
#[inline(always)]
fn get_mut(&mut self) -> &mut Option<C> {
self.locked.as_mut().expect(EXPECT_LOCKED)
}
#[inline(always)]
fn matches_pref(&self, pref: AcquirePreference) -> bool {
!self.slot.is_leaked() && pref.wants_connected(self.is_connected())
}
#[inline(always)]
fn is_connected(&self) -> bool {
self.get().is_some()
}
fn try_connected(self) -> Result<ConnectedSlot<C>, DisconnectedSlot<C>> {
if self.is_connected() {
Ok(ConnectedSlot(self))
} else {
Err(DisconnectedSlot(self))
}
}
fn assert_connected(self) -> ConnectedSlot<C> {
assert!(self.is_connected());
ConnectedSlot(self)
}
fn assert_disconnected(self) -> DisconnectedSlot<C> {
assert!(!self.is_connected());
DisconnectedSlot(self)
}
/// Updates `Slot::connected` without notifying the `ConnectionSet`.
///
/// Returns `Some(connected)` or `None` if this guard was already dropped.
fn drop_without_notify(&mut self) -> Option<bool> {
self.locked.take().map(|locked| {
let connected = locked.is_some();
self.slot.set_is_connected(connected);
self.slot.locked.store(false, Ordering::Release);
connected
})
}
}
const EXPECT_LOCKED: &str = "BUG: `SlotGuard::locked` should not be `None` in normal operation";
const EXPECT_CONNECTED: &str = "BUG: `ConnectedSlot` expects `Slot::connection` to be `Some`";
impl<C> ConnectedSlot<C> {
pub fn take(mut self) -> (C, DisconnectedSlot<C>) {
let conn = self.0.get_mut().take().expect(EXPECT_CONNECTED);
(conn, self.0.assert_disconnected())
}
}
impl<C> Deref for ConnectedSlot<C> {
type Target = C;
#[inline(always)]
fn deref(&self) -> &Self::Target {
self.0.get().as_ref().expect(EXPECT_CONNECTED)
}
}
impl<C> DerefMut for ConnectedSlot<C> {
#[inline(always)]
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.get_mut().as_mut().expect(EXPECT_CONNECTED)
}
}
impl<C> DisconnectedSlot<C> {
pub fn put(mut self, conn: C) -> ConnectedSlot<C> {
*self.0.get_mut() = Some(conn);
ConnectedSlot(self.0)
}
pub fn leak(mut self) {
self.0.slot.leaked.store(true, Ordering::Release);
self.0.drop_without_notify();
}
}
impl<C> Drop for SlotGuard<C> {
fn drop(&mut self) {
let Some(connected) = self.drop_without_notify() else {
return;
};
let event = if connected {
&self.slot.global.unlock_event
} else {
&self.slot.global.disconnect_event
};
if event.notify(1.tag(self.slot.index).additional()) != 0 {
return;
}
if connected {
self.slot.unlock_event.notify(1);
return;
}
if self.slot.disconnect_event.notify(1) != 0 {
return;
}
if self.slot.global.num_connected() < self.slot.global.min_connections {
self.slot.global.min_connections_event.notify(1);
}
}
}
impl Global {
#[inline(always)]
fn num_connected(&self) -> usize {
self.num_connected.load(Ordering::Relaxed)
}
async fn listen(&self, pref: AcquirePreference) -> usize {
match pref {
AcquirePreference::Either => race(self.listen_unlocked(), self.listen_disconnected())
.await
.unwrap_or_else(|slot| slot),
AcquirePreference::Connected => self.listen_unlocked().await,
AcquirePreference::Disconnected => self.listen_disconnected().await,
}
}
async fn listen_unlocked(&self) -> usize {
listener!(self.unlock_event => listener);
listener.await
}
async fn listen_disconnected(&self) -> usize {
listener!(self.disconnect_event => listener);
listener.await
}
}
fn current_thread_id() -> usize {
// FIXME: this can be replaced when this is stabilized:
// https://doc.rust-lang.org/stable/std/thread/struct.ThreadId.html#method.as_u64
static THREAD_ID: AtomicUsize = AtomicUsize::new(0);
thread_local! {
// `SeqCst` is possibly too strong since we don't need synchronization with
// any other variable. I'm not confident enough in my understanding of atomics to be certain,
// especially with regards to weakly ordered architectures.
//
// However, this is literally only done once on each thread, so it doesn't really matter.
static CURRENT_THREAD_ID: usize = THREAD_ID.fetch_add(1, Ordering::SeqCst);
}
CURRENT_THREAD_ID.with(|i| *i)
}

View File

@ -5,33 +5,30 @@ use crate::pool::{connection, CloseEvent, Pool, PoolConnection, PoolConnector, P
use std::cmp;
use std::future::Future;
use std::ops::ControlFlow;
use std::pin::{pin, Pin};
use std::rc::Weak;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{ready, Poll};
use std::sync::{Arc, Weak};
use std::task::{Context, Poll};
use crate::connection::Connection;
use crate::ext::future::race;
use crate::logger::private_level_filter_to_trace_level;
use crate::pool::connect::{
ConnectPermit, ConnectTask, ConnectTaskShared, ConnectionCounter, ConnectionId, DynConnector,
};
use crate::pool::shard::{ConnectedSlot, DisconnectedSlot, Sharded};
use crate::rt::JoinHandle;
use crate::pool::connect::{ConnectTaskShared, ConnectionCounter, ConnectionId, DynConnector};
use crate::pool::connection_set::{ConnectedSlot, ConnectionSet, DisconnectedSlot};
use crate::{private_tracing_dynamic_event, rt};
use either::Either;
use futures_core::FusedFuture;
use futures_util::future::{self, OptionFuture};
use futures_util::{stream, FutureExt, TryStreamExt};
use event_listener::listener;
use futures_util::future::{self};
use std::time::{Duration, Instant};
use tracing::Level;
const GRACEFUL_CLOSE_TIMEOUT: Duration = Duration::from_secs(5);
const TEST_BEFORE_ACQUIRE_TIMEOUT: Duration = Duration::from_secs(60);
pub(crate) struct PoolInner<DB: Database> {
pub(super) connector: DynConnector<DB>,
pub(super) counter: ConnectionCounter,
pub(super) sharded: Sharded<ConnectionInner<DB>>,
pub(super) connections: ConnectionSet<ConnectionInner<DB>>,
is_closed: AtomicBool,
pub(super) on_closed: event_listener::Event,
pub(super) options: PoolOptions<DB>,
@ -44,39 +41,15 @@ impl<DB: Database> PoolInner<DB> {
options: PoolOptions<DB>,
connector: impl PoolConnector<DB>,
) -> Arc<Self> {
let pool = Arc::<Self>::new_cyclic(|pool_weak| {
let pool_weak = pool_weak.clone();
let reconnect = move |slot| {
let Some(pool) = pool_weak.upgrade() else {
// Prevent an infinite loop on pool drop.
DisconnectedSlot::leak(slot);
return;
};
pool.connector.connect(
Pool(pool.clone()),
ConnectionId::next(),
slot,
ConnectTaskShared::new_arc(),
);
};
Self {
connector: DynConnector::new(connector),
counter: ConnectionCounter::new(),
sharded: Sharded::new(
options.max_connections,
options.shards,
options.min_connections,
reconnect,
),
is_closed: AtomicBool::new(false),
on_closed: event_listener::Event::new(),
acquire_time_level: private_level_filter_to_trace_level(options.acquire_time_level),
acquire_slow_level: private_level_filter_to_trace_level(options.acquire_slow_level),
options,
}
let pool = Arc::new(Self {
connector: DynConnector::new(connector),
counter: ConnectionCounter::new(),
connections: ConnectionSet::new(options.min_connections..=options.max_connections),
is_closed: AtomicBool::new(false),
on_closed: event_listener::Event::new(),
acquire_time_level: private_level_filter_to_trace_level(options.acquire_time_level),
acquire_slow_level: private_level_filter_to_trace_level(options.acquire_slow_level),
options,
});
spawn_maintenance_tasks(&pool);
@ -85,11 +58,11 @@ impl<DB: Database> PoolInner<DB> {
}
pub(super) fn size(&self) -> usize {
self.sharded.count_connected()
self.connections.num_connected()
}
pub(super) fn num_idle(&self) -> usize {
self.sharded.count_unlocked(true)
self.connections.count_idle()
}
pub(super) fn is_closed(&self) -> bool {
@ -105,11 +78,8 @@ impl<DB: Database> PoolInner<DB> {
self.mark_closed();
// Keep clearing the idle queue as connections are released until the count reaches zero.
self.sharded.drain(|slot| async move {
let (conn, slot) = ConnectedSlot::take(slot);
let _ = rt::timeout(GRACEFUL_CLOSE_TIMEOUT, conn.raw.close()).await;
self.connections.drain(async |slot| {
let (_res, slot) = connection::close(slot).await;
slot
})
}
@ -130,7 +100,7 @@ impl<DB: Database> PoolInner<DB> {
return None;
}
self.sharded.try_acquire_connected()
self.connections.try_acquire_connected()
}
pub(super) async fn acquire(self: &Arc<Self>) -> Result<PoolConnection<DB>, Error> {
@ -140,74 +110,43 @@ impl<DB: Database> PoolInner<DB> {
let acquire_started_at = Instant::now();
let mut close_event = pin!(self.close_event());
let mut deadline = pin!(rt::sleep(self.options.acquire_timeout));
// Lazily allocated `Arc<ConnectTaskShared>`
let mut connect_shared = None;
let connect_shared = ConnectTaskShared::new_arc();
let res = {
// Pinned to the stack without allocating
listener!(self.on_closed => close_listener);
let mut deadline = pin!(rt::sleep(self.options.acquire_timeout));
let mut acquire_inner = pin!(self.acquire_inner(&mut connect_shared));
let mut acquire_connected = pin!(self.acquire_connected().fuse());
let mut acquire_disconnected = pin!(self.sharded.acquire_disconnected().fuse());
let mut connect = future::Fuse::terminated();
let acquired = std::future::poll_fn(|cx| loop {
if let Poll::Ready(()) = close_event.as_mut().poll(cx) {
return Poll::Ready(Err(Error::PoolClosed));
}
if let Poll::Ready(res) = acquire_connected.as_mut().poll(cx) {
match res {
Ok(conn) => {
return Poll::Ready(Ok(conn));
}
Err(slot) => {
if connect.is_terminated() {
connect = self
.connector
.connect(
Pool(self.clone()),
ConnectionId::next(),
slot,
connect_shared.clone(),
)
.fuse();
}
// Try to acquire another connected connection.
acquire_connected.set(self.acquire_connected().fuse());
continue;
}
std::future::poll_fn(|cx| {
if self.is_closed() {
return Poll::Ready(Err(Error::PoolClosed));
}
}
if let Poll::Ready(slot) = acquire_disconnected.as_mut().poll(cx) {
if connect.is_terminated() {
connect = self
.connector
.connect(
Pool(self.clone()),
ConnectionId::next(),
slot,
connect_shared.clone(),
)
.fuse();
// The result doesn't matter so much as the wakeup
let _ = Pin::new(&mut close_listener).poll(cx);
if let Poll::Ready(()) = deadline.as_mut().poll(cx) {
return Poll::Ready(Err(Error::PoolTimedOut {
last_connect_error: None,
}));
}
}
if let Poll::Ready(res) = Pin::new(&mut connect).poll(cx) {
return Poll::Ready(res);
}
acquire_inner.as_mut().poll(cx)
})
.await
};
if let Poll::Ready(()) = deadline.as_mut().poll(cx) {
return Poll::Ready(Err(Error::PoolTimedOut {
last_connect_error: connect_shared.take_error().map(Box::new),
}));
}
return Poll::Pending;
})
.await?;
let acquired = res.map_err(|e| match e {
Error::PoolTimedOut {
last_connect_error: None,
} => Error::PoolTimedOut {
last_connect_error: connect_shared
.and_then(|shared| Some(shared.take_error()?.into())),
},
e => e,
})?;
let acquired_after = acquire_started_at.elapsed();
@ -235,20 +174,36 @@ impl<DB: Database> PoolInner<DB> {
Ok(acquired)
}
async fn acquire_connected(
async fn acquire_inner(
self: &Arc<Self>,
) -> Result<PoolConnection<DB>, DisconnectedSlot<ConnectionInner<DB>>> {
let connected = self.sharded.acquire_connected().await;
connect_shared: &mut Option<Arc<ConnectTaskShared>>,
) -> Result<PoolConnection<DB>, Error> {
tracing::trace!("waiting for any connection");
tracing::debug!(
target: "sqlx::pool",
connection_id=%connected.id,
"acquired idle connection"
let disconnected = match self.connections.acquire_any().await {
Ok(conn) => match finish_acquire(self, conn).await {
Ok(conn) => return Ok(conn),
Err(slot) => slot,
},
Err(slot) => slot,
};
let mut connect_task = self.connector.connect(
Pool(self.clone()),
ConnectionId::next(),
disconnected,
connect_shared.insert(ConnectTaskShared::new_arc()).clone(),
);
match finish_acquire(self, connected) {
Either::Left(task) => task.await,
Either::Right(conn) => Ok(conn),
loop {
match race(&mut connect_task, self.connections.acquire_connected()).await {
Ok(Ok(conn)) => return Ok(conn),
Ok(Err(e)) => return Err(e),
Err(conn) => match finish_acquire(self, conn).await {
Ok(conn) => return Ok(conn),
Err(_) => continue,
},
}
}
}
@ -258,17 +213,20 @@ impl<DB: Database> PoolInner<DB> {
) -> Result<(), Error> {
let shared = ConnectTaskShared::new_arc();
let connect_min_connections =
future::try_join_all(self.sharded.iter_min_connections().map(|slot| {
self.connector.connect(
Pool(self.clone()),
ConnectionId::next(),
slot,
shared.clone(),
)
}));
let connect_min_connections = future::try_join_all(
(self.connections.num_connected()..self.options.min_connections)
.filter_map(|_| self.connections.try_acquire_disconnected())
.map(|slot| {
self.connector.connect(
Pool(self.clone()),
ConnectionId::next(),
slot,
shared.clone(),
)
}),
);
let mut conns = if let Some(deadline) = deadline {
let conns = if let Some(deadline) = deadline {
match rt::timeout_at(deadline, connect_min_connections).await {
Ok(Ok(conns)) => conns,
Err(_) | Ok(Err(Error::PoolTimedOut { .. })) => {
@ -297,144 +255,192 @@ impl<DB: Database> Drop for PoolInner<DB> {
}
}
/// Returns `true` if the connection has exceeded `options.max_lifetime` if set, `false` otherwise.
pub(super) fn is_beyond_max_lifetime<DB: Database>(
live: &ConnectionInner<DB>,
options: &PoolOptions<DB>,
) -> bool {
options
.max_lifetime
.is_some_and(|max| live.created_at.elapsed() > max)
}
/// Returns `true` if the connection has exceeded `options.idle_timeout` if set, `false` otherwise.
fn is_beyond_idle_timeout<DB: Database>(
idle: &ConnectionInner<DB>,
options: &PoolOptions<DB>,
) -> bool {
options
.idle_timeout
.is_some_and(|timeout| idle.last_released_at.elapsed() > timeout)
}
/// Execute `test_before_acquire` and/or `before_acquire` in a background task, if applicable.
///
/// Otherwise, immediately returns the connection.
fn finish_acquire<DB: Database>(
async fn finish_acquire<DB: Database>(
pool: &Arc<PoolInner<DB>>,
mut conn: ConnectedSlot<ConnectionInner<DB>>,
) -> Either<
JoinHandle<Result<PoolConnection<DB>, DisconnectedSlot<ConnectionInner<DB>>>>,
PoolConnection<DB>,
> {
if pool.options.test_before_acquire || pool.options.before_acquire.is_some() {
let pool = pool.clone();
) -> Result<PoolConnection<DB>, DisconnectedSlot<ConnectionInner<DB>>> {
struct SpawnOnDrop<F: Future + Send + 'static>(Option<Pin<Box<F>>>)
where
F::Output: Send + 'static;
// Spawn a task so the call may complete even if `acquire()` is cancelled.
return Either::Left(rt::spawn(async move {
// Check that the connection is still live
impl<F: Future + Send + 'static> Future for SpawnOnDrop<F>
where
F::Output: Send + 'static,
{
type Output = F::Output;
#[inline(always)]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0
.as_mut()
.expect("BUG: inner future taken")
.as_mut()
.poll(cx)
}
}
impl<F: Future + Send + 'static> Drop for SpawnOnDrop<F>
where
F::Output: Send + 'static,
{
fn drop(&mut self) {
rt::try_spawn(self.0.take().expect("BUG: inner future taken"));
}
}
async fn finish_inner<DB: Database>(
conn: &mut ConnectedSlot<ConnectionInner<DB>>,
pool: &PoolInner<DB>,
) -> ControlFlow<()> {
// Check that the connection is still live
if pool.options.test_before_acquire {
if let Err(error) = conn.raw.ping().await {
// an error here means the other end has hung up or we lost connectivity
// either way we're fine to just discard the connection
// the error itself here isn't necessarily unexpected so WARN is too strong
tracing::info!(%error, connection_id=%conn.id, "ping on idle connection returned error");
// connection is broken so don't try to close nicely
let (_res, slot) = connection::close_hard(conn).await;
return Err(slot);
return ControlFlow::Break(());
}
}
if let Some(test) = &pool.options.before_acquire {
let meta = conn.idle_metadata();
match test(&mut conn.raw, meta).await {
Ok(false) => {
// connection was rejected by user-defined hook, close nicely
let (_res, slot) = connection::close(conn).await;
return Err(slot);
}
Err(error) => {
tracing::warn!(%error, "error from `before_acquire`");
// connection is broken so don't try to close nicely
let (_res, slot) = connection::close_hard(conn).await;
return Err(slot);
}
Ok(true) => {}
if let Some(test) = &pool.options.before_acquire {
let meta = conn.idle_metadata();
match test(&mut conn.raw, meta).await {
Ok(false) => {
// connection was rejected by user-defined hook, close nicely
tracing::debug!(connection_id=%conn.id, "connection rejected by `before_acquire`");
return ControlFlow::Break(());
}
}
Ok(PoolConnection::new(conn, pool))
}));
Err(error) => {
tracing::warn!(%error, "error from `before_acquire`");
return ControlFlow::Break(());
}
Ok(true) => (),
}
}
// Checks passed
ControlFlow::Continue(())
}
// No checks are configured, return immediately.
Either::Right(PoolConnection::new(conn, pool.clone()))
if pool.options.test_before_acquire || pool.options.before_acquire.is_some() {
let pool = pool.clone();
// Spawn a task on-drop so the call may complete even if `acquire()` is cancelled.
conn = SpawnOnDrop(Some(Box::pin(async move {
match rt::timeout(TEST_BEFORE_ACQUIRE_TIMEOUT, finish_inner(&mut conn, &pool)).await {
Ok(ControlFlow::Continue(())) => {
Ok(conn)
}
Ok(ControlFlow::Break(())) => {
// Connection rejected by user-defined hook, attempt to close nicely
let (_res, slot) = connection::close(conn).await;
Err(slot)
}
Err(_) => {
tracing::info!(connection_id=%conn.id, "`before_acquire` checks timed out, closing connection");
let (_res, slot) = connection::close_hard(conn).await;
Err(slot)
}
}
}))).await?;
}
tracing::debug!(
target: "sqlx::pool",
connection_id=%conn.id,
"acquired idle connection"
);
Ok(PoolConnection::new(conn))
}
fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {
// NOTE: use `pool_weak` for the maintenance tasks
// so they don't keep `PoolInner` from being dropped.
let pool_weak = Arc::downgrade(pool);
if pool.options.min_connections > 0 {
// NOTE: use `pool_weak` for the maintenance tasks
// so they don't keep `PoolInner` from being dropped.
let pool_weak = Arc::downgrade(pool);
let mut close_event = pool.close_event();
let period = match (pool.options.max_lifetime, pool.options.idle_timeout) {
rt::spawn(async move {
close_event
.do_until(check_min_connections(pool_weak))
.await
.ok();
});
}
let check_interval = match (pool.options.max_lifetime, pool.options.idle_timeout) {
(Some(it), None) | (None, Some(it)) => it,
(Some(a), Some(b)) => cmp::min(a, b),
(None, None) => {
if pool.options.min_connections > 0 {
rt::spawn(async move {
if let Some(pool) = pool_weak.upgrade() {
if let Err(error) = pool.try_min_connections(None).await {
tracing::error!(
target: "sqlx::pool",
?error,
"error maintaining min_connections"
);
}
}
});
}
return;
}
(None, None) => return,
};
// Immediately cancel this task if the pool is closed.
let pool_weak = Arc::downgrade(pool);
let mut close_event = pool.close_event();
rt::spawn(async move {
let _ = close_event
.do_until(async {
// If the last handle to the pool was dropped while we were sleeping
while let Some(pool) = pool_weak.upgrade() {
if pool.is_closed() {
return;
}
let next_run = Instant::now() + period;
// Go over all idle connections, check for idleness and lifetime,
// and if we have fewer than min_connections after reaping a connection,
// open a new one immediately.
for conn in pool.sharded.iter_idle() {
if is_beyond_idle_timeout(&conn, &pool.options)
|| is_beyond_max_lifetime(&conn, &pool.options)
{
// Dropping the slot will check if the connection needs to be
// re-made.
let _ = connection::close(conn).await;
}
}
// Don't hold a reference to the pool while sleeping.
drop(pool);
rt::sleep_until(next_run).await;
}
})
.do_until(check_idle_conns(pool_weak, check_interval))
.await;
});
}
async fn check_idle_conns<DB: Database>(pool_weak: Weak<PoolInner<DB>>, check_interval: Duration) {
let mut interval = pin!(rt::interval_after(check_interval));
while let Some(pool) = pool_weak.upgrade() {
if pool.is_closed() {
return;
}
// Go over all idle connections, check for idleness and lifetime,
// and if we have fewer than min_connections after reaping a connection,
// open a new one immediately.
for conn in pool.connections.iter_idle() {
if conn.is_beyond_idle_timeout(&pool.options)
|| conn.is_beyond_max_lifetime(&pool.options)
{
// Dropping the slot will check if the connection needs to be re-made.
let _ = connection::close(conn).await;
}
}
// Don't hold a reference to the pool while sleeping.
drop(pool);
interval.as_mut().tick().await;
}
}
async fn check_min_connections<DB: Database>(pool_weak: Weak<PoolInner<DB>>) {
while let Some(pool) = pool_weak.upgrade() {
if pool.is_closed() {
return;
}
match pool.try_min_connections(None).await {
Ok(()) => {
let listener = pool.connections.min_connections_listener();
// Important: don't hold a strong ref while sleeping
drop(pool);
listener.await;
}
Err(e) => {
tracing::warn!(
target: "sqlx::pool::maintenance",
min_connections=pool.options.min_connections,
num_connected=pool.connections.num_connected(),
"unable to maintain `min_connections`: {e:?}",
);
}
}
}
}

View File

@ -56,20 +56,19 @@
use std::fmt;
use std::future::Future;
use std::pin::{pin, Pin};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Context, Poll};
use event_listener::EventListener;
use futures_core::FusedFuture;
use futures_util::FutureExt;
use crate::connection::Connection;
use crate::database::Database;
use crate::error::Error;
use crate::ext::future::race;
use crate::sql_str::SqlSafeStr;
use crate::transaction::Transaction;
use event_listener::EventListener;
use futures_core::FusedFuture;
use tracing::Instrument;
pub use self::connect::{PoolConnectMetadata, PoolConnector};
pub use self::connection::PoolConnection;
use self::inner::PoolInner;
@ -90,7 +89,9 @@ mod inner;
// mod idle;
mod options;
mod shard;
// mod shard;
mod connection_set;
/// An asynchronous pool of SQLx database connections.
///
@ -362,16 +363,21 @@ impl<DB: Database> Pool<DB> {
pub fn acquire(&self) -> impl Future<Output = Result<PoolConnection<DB>, Error>> + 'static {
let shared = self.0.clone();
async move { shared.acquire().await }
.instrument(tracing::error_span!(target: "sqlx::pool", "acquire"))
}
/// Attempts to retrieve a connection from the pool if there is one available.
///
/// Returns `None` immediately if there are no idle connections available in the pool
/// or there are tasks waiting for a connection which have yet to wake.
///
/// # Note: Bypasses `before_acquire`
/// Since this function is not `async`, it cannot await the future returned by
/// [`before_acquire`][PoolOptions::before_acquire] without blocking.
///
/// Instead, it simply returns the connection immediately.
pub fn try_acquire(&self) -> Option<PoolConnection<DB>> {
self.0
.try_acquire()
.map(|conn| PoolConnection::new(conn, self.0.clone()))
self.0.try_acquire().map(|conn| PoolConnection::new(conn))
}
/// Retrieves a connection and immediately begins a new transaction.
@ -577,42 +583,19 @@ impl CloseEvent {
///
/// Cancels the future and returns `Err(PoolClosed)` if/when the pool is closed.
/// If the pool was already closed, the future is never run.
#[inline(always)]
pub async fn do_until<Fut: Future>(&mut self, fut: Fut) -> Result<Fut::Output, Error> {
// Check that the pool wasn't closed already.
//
// We use `poll_immediate()` as it will use the correct waker instead of
// a no-op one like `.now_or_never()`, but it won't actually suspend execution here.
futures_util::future::poll_immediate(&mut *self)
.await
.map_or(Ok(()), |_| Err(Error::PoolClosed))?;
let mut fut = pin!(fut);
// I find that this is clearer in intent than `futures_util::future::select()`
// or `futures_util::select_biased!{}` (which isn't enabled anyway).
std::future::poll_fn(|cx| {
// Poll `fut` first as the wakeup event is more likely for it than `self`.
if let Poll::Ready(ret) = fut.as_mut().poll(cx) {
return Poll::Ready(Ok(ret));
}
// Can't really factor out mapping to `Err(Error::PoolClosed)` though it seems like
// we should because that results in a different `Ok` type each time.
//
// Ideally we'd map to something like `Result<!, Error>` but using `!` as a type
// is not allowed on stable Rust yet.
self.poll_unpin(cx).map(|_| Err(Error::PoolClosed))
})
.await
race(fut, self).await.map_err(|_| Error::PoolClosed)
}
}
impl Future for CloseEvent {
type Output = ();
#[inline(always)]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Some(listener) = &mut self.listener {
ready!(listener.poll_unpin(cx));
ready!(Pin::new(listener).poll(cx));
}
// `EventListener` doesn't like being polled after it yields, and even if it did it

View File

@ -1,10 +1,13 @@
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::task::{ready, Context, Poll};
use std::time::{Duration, Instant};
use cfg_if::cfg_if;
use futures_core::Stream;
use futures_util::StreamExt;
use pin_project_lite::pin_project;
#[cfg(feature = "_rt-async-io")]
pub mod rt_async_io;
@ -59,19 +62,13 @@ pub async fn timeout_at<F: Future>(deadline: Instant, f: F) -> Result<F::Output,
.map_err(|_| TimeoutError);
}
#[cfg(feature = "_rt-async-std")]
{
let Some(duration) = deadline.checked_duration_since(Instant::now()) else {
return Err(TimeoutError);
};
async_std::future::timeout(duration, f)
.await
.map_err(|_| TimeoutError)
cfg_if! {
if #[cfg(feature = "_rt-async-io")] {
rt_async_io::timeout_at(deadline, f).await
} else {
missing_rt((deadline, f))
}
}
#[cfg(not(feature = "_rt-async-std"))]
missing_rt((deadline, f))
}
pub async fn sleep(duration: Duration) {
@ -104,6 +101,135 @@ pub async fn sleep_until(instant: Instant) {
}
}
// https://github.com/taiki-e/pin-project-lite/issues/3
#[cfg(all(feature = "_rt-tokio", feature = "_rt-async-io"))]
pin_project! {
#[project = IntervalProjected]
pub enum Interval {
Tokio {
// Bespoke impl because `tokio::time::Interval` allocates when we could just pin instead
#[pin]
sleep: tokio::time::Sleep,
period: Duration,
},
AsyncIo {
#[pin]
timer: async_io::Timer,
},
}
}
#[cfg(all(feature = "_rt-tokio", not(feature = "_rt-async-io")))]
pin_project! {
#[project = IntervalProjected]
pub enum Interval {
Tokio {
#[pin]
sleep: tokio::time::Sleep,
period: Duration,
},
}
}
#[cfg(all(not(feature = "_rt-tokio"), feature = "_rt-async-io"))]
pin_project! {
#[project = IntervalProjected]
pub enum Interval {
AsyncIo {
#[pin]
timer: async_io::Timer,
},
}
}
#[cfg(not(any(feature = "_rt-tokio", feature = "_rt-async-io")))]
pub enum Interval {}
pub fn interval_after(period: Duration) -> Interval {
#[cfg(feature = "_rt-tokio")]
if rt_tokio::available() {
return Interval::Tokio {
sleep: tokio::time::sleep(period),
period,
};
}
cfg_if! {
if #[cfg(feature = "_rt-async-io")] {
Interval::AsyncIo { timer: async_io::Timer::interval(period) }
} else {
missing_rt(period)
}
}
}
impl Interval {
#[inline(always)]
pub fn tick(mut self: Pin<&mut Self>) -> impl Future<Output = Instant> + use<'_> {
std::future::poll_fn(move |cx| self.as_mut().poll_tick(cx))
}
#[inline(always)]
pub fn as_timeout<F: Future>(self: Pin<&mut Self>, fut: F) -> AsTimeout<'_, F> {
AsTimeout {
interval: self,
future: fut,
}
}
#[inline(always)]
pub fn poll_tick(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Instant> {
cfg_if! {
if #[cfg(any(feature = "_rt-tokio", feature = "_rt-async-io"))] {
match self.project() {
#[cfg(feature = "_rt-tokio")]
IntervalProjected::Tokio { mut sleep, period } => {
ready!(sleep.as_mut().poll(cx));
let now = Instant::now();
sleep.reset((now + *period).into());
Poll::Ready(now)
}
#[cfg(feature = "_rt-async-io")]
IntervalProjected::AsyncIo { mut timer } => {
Poll::Ready(ready!(timer
.as_mut()
.poll_next(cx))
.expect("BUG: `async_io::Timer::next()` should always yield"))
}
}
} else {
unreachable!()
}
}
}
}
pin_project! {
pub struct AsTimeout<'i, F> {
interval: Pin<&'i mut Interval>,
#[pin]
future: F,
}
}
impl<F> Future for AsTimeout<'_, F>
where
F: Future,
{
type Output = Option<F::Output>;
#[inline(always)]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
if let Poll::Ready(out) = this.future.poll(cx) {
return Poll::Ready(Some(out));
}
this.interval.as_mut().poll_tick(cx).map(|_| None)
}
}
#[track_caller]
pub fn spawn<F>(fut: F) -> JoinHandle<F::Output>
where
@ -128,6 +254,29 @@ where
}
}
pub fn try_spawn<F>(fut: F) -> Option<JoinHandle<F::Output>>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
#[cfg(feature = "_rt-tokio")]
if let Ok(handle) = tokio::runtime::Handle::try_current() {
return Some(JoinHandle::Tokio(handle.spawn(fut)));
}
cfg_if! {
if #[cfg(feature = "_rt-async-global-executor")] {
Some(JoinHandle::AsyncTask(Some(async_global_executor::spawn(fut))))
} else if #[cfg(feature = "_rt-smol")] {
Some(JoinHandle::AsyncTask(Some(smol::spawn(fut))))
} else if #[cfg(feature = "_rt-async-std")] {
Some(JoinHandle::AsyncStd(async_std::task::spawn(fut)))
} else {
None
}
}
}
#[track_caller]
pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R>
where

View File

@ -1,13 +1,10 @@
use crate::ext::future::race;
use crate::rt::TimeoutError;
use std::{
future::Future,
pin::pin,
time::{Duration, Instant},
};
use futures_util::future::{select, Either};
use crate::rt::TimeoutError;
pub async fn sleep(duration: Duration) {
async_io::Timer::after(duration).await;
}
@ -17,8 +14,16 @@ pub async fn sleep_until(deadline: Instant) {
}
pub async fn timeout<F: Future>(duration: Duration, future: F) -> Result<F::Output, TimeoutError> {
match select(pin!(future), pin!(sleep(duration))).await {
Either::Left((result, _)) => Ok(result),
Either::Right(_) => Err(TimeoutError),
}
race(future, sleep(duration))
.await
.map_err(|_| TimeoutError)
}
pub async fn timeout_at<F: Future>(
deadline: Instant,
future: F,
) -> Result<F::Output, TimeoutError> {
race(future, sleep_until(deadline))
.await
.map_err(|_| TimeoutError)
}

View File

@ -1,5 +1,6 @@
mod socket;
#[inline(always)]
pub fn available() -> bool {
tokio::runtime::Handle::try_current().is_ok()
}

View File

@ -4,11 +4,40 @@
// We'll generally lean towards Tokio's types as those are more featureful
// (including `tokio-console` support) and more widely deployed.
use std::sync::Arc;
#[cfg(feature = "_rt-tokio")]
pub use tokio::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, RwLock as AsyncRwLock};
pub use tokio::sync::{
Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, OwnedMutexGuard as AsyncMutexGuardArc,
RwLock as AsyncRwLock,
};
#[cfg(all(feature = "_rt-async-lock", not(feature = "_rt-tokio")))]
pub use async_lock::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, RwLock as AsyncRwLock};
pub use async_lock::{
Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, MutexGuardArc as AsyncMutexGuardArc,
RwLock as AsyncRwLock,
};
pub async fn lock_arc<T>(mutex: &Arc<AsyncMutex<T>>) -> AsyncMutexGuardArc<T> {
#[cfg(feature = "_rt-tokio")]
return mutex.clone().lock_owned().await;
#[cfg(all(feature = "_rt-async-lock", not(feature = "_rt-tokio")))]
return mutex.lock_arc().await;
#[cfg(not(any(feature = "_rt-async-lock", feature = "_rt-tokio")))]
return crate::rt::missing_rt(mutex);
}
pub fn try_lock_arc<T>(mutex: &Arc<AsyncMutex<T>>) -> Option<AsyncMutexGuardArc<T>> {
#[cfg(feature = "_rt-tokio")]
return mutex.clone().try_lock_owned().ok();
#[cfg(all(feature = "_rt-async-lock", not(feature = "_rt-tokio")))]
return mutex.try_lock_arc();
#[cfg(not(any(feature = "_rt-async-lock", feature = "_rt-tokio")))]
return crate::rt::missing_rt(mutex);
}
#[cfg(not(any(feature = "_rt-async-lock", feature = "_rt-tokio")))]
pub use noop::*;
@ -18,6 +47,7 @@ mod noop {
use crate::rt::missing_rt;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
pub struct AsyncMutex<T> {
// `Sync` if `T: Send`
@ -28,6 +58,10 @@ mod noop {
inner: &'a AsyncMutex<T>,
}
pub struct AsyncMutexGuardArc<T> {
inner: Arc<AsyncMutex<T>>,
}
impl<T> AsyncMutex<T> {
pub fn new(val: T) -> Self {
missing_rt(val)
@ -51,4 +85,18 @@ mod noop {
missing_rt(self)
}
}
impl<T> Deref for AsyncMutexGuardArc<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
missing_rt(self)
}
}
impl<T> DerefMut for AsyncMutexGuardArc<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
missing_rt(self)
}
}
}

View File

@ -2,13 +2,15 @@ use sqlx::pool::PoolOptions;
use sqlx::{Connection, Database, Error, Pool};
use std::env;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::fmt::format::FmtSpan;
pub fn setup_if_needed() {
let _ = dotenvy::dotenv();
let _ = tracing_subscriber::fmt::Subscriber::builder()
.with_env_filter(EnvFilter::from_default_env())
.with_test_writer()
.finish();
.with_span_events(FmtSpan::NEW | FmtSpan::CLOSE)
// .with_test_writer()
.try_init();
}
// Make a new connection

View File

@ -255,6 +255,10 @@ async fn it_works_with_cache_disabled() -> anyhow::Result<()> {
#[sqlx_macros::test]
async fn it_executes_with_pool() -> anyhow::Result<()> {
setup_if_needed();
tracing::info!("starting test");
let pool = sqlx_test::pool::<Postgres>().await?;
let rows = pool.fetch_all("SELECT 1; SElECT 2").await?;
@ -1146,7 +1150,7 @@ async fn test_listener_try_recv_buffered() -> anyhow::Result<()> {
assert!(listener.next_buffered().is_none());
// Activate connection.
sqlx::query!("SELECT 1 AS one")
sqlx::query("SELECT 1 AS one")
.fetch_all(&mut listener)
.await?;
@ -2086,6 +2090,7 @@ async fn test_issue_3052() {
}
#[sqlx_macros::test]
#[cfg(feature = "chrono")]
async fn test_bind_iter() -> anyhow::Result<()> {
use sqlx::postgres::PgBindIterExt;
use sqlx::types::chrono::{DateTime, Utc};