diff --git a/sqlx-core/src/ext/future.rs b/sqlx-core/src/ext/future.rs new file mode 100644 index 00000000..138f8001 --- /dev/null +++ b/sqlx-core/src/ext/future.rs @@ -0,0 +1,38 @@ +use pin_project_lite::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + #[project = RaceProject] + pub struct Race { + #[pin] + left: L, + #[pin] + right: R, + } +} + +impl Future for Race +where + L: Future, + R: Future, +{ + type Output = Result; + + #[inline(always)] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + if let Poll::Ready(left) = this.left.as_mut().poll(cx) { + return Poll::Ready(Ok(left)); + } + + this.right.as_mut().poll(cx).map(Err) + } +} + +#[inline(always)] +pub fn race(left: L, right: R) -> Race { + Race { left, right } +} diff --git a/sqlx-core/src/ext/mod.rs b/sqlx-core/src/ext/mod.rs index 98059f8c..167c6856 100644 --- a/sqlx-core/src/ext/mod.rs +++ b/sqlx-core/src/ext/mod.rs @@ -2,3 +2,5 @@ pub mod ustr; #[macro_use] pub mod async_stream; + +pub mod future; diff --git a/sqlx-core/src/pool/connect.rs b/sqlx-core/src/pool/connect.rs index 52920c6a..5adf82bf 100644 --- a/sqlx-core/src/pool/connect.rs +++ b/sqlx-core/src/pool/connect.rs @@ -14,7 +14,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, Mutex, RwLock}; use std::time::Instant; -use crate::pool::shard::DisconnectedSlot; +use crate::pool::connection_set::DisconnectedSlot; #[cfg(doc)] use crate::pool::PoolOptions; use crate::sync::{AsyncMutex, AsyncMutexGuard}; @@ -646,7 +646,7 @@ async fn connect_with_backoff( match res { ControlFlow::Break(Ok(conn)) => { - tracing::trace!( + tracing::debug!( target: "sqlx::pool::connect", %connection_id, attempt, @@ -654,18 +654,16 @@ async fn connect_with_backoff( "connection established", ); - return Ok(PoolConnection::new( - slot.put(ConnectionInner { - raw: conn, - id: connection_id, - created_at: now, - last_released_at: now, - }), - pool.0.clone(), - )); + return Ok(PoolConnection::new(slot.put(ConnectionInner { + pool: Arc::downgrade(&pool.0), + raw: conn, + id: connection_id, + created_at: now, + last_released_at: now, + }))); } ControlFlow::Break(Err(e)) => { - tracing::warn!( + tracing::error!( target: "sqlx::pool::connect", %connection_id, attempt, diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index 8d115818..1103374c 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -1,33 +1,35 @@ use std::fmt::{self, Debug, Formatter}; use std::future::{self, Future}; +use std::io; use std::ops::{Deref, DerefMut}; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use std::time::{Duration, Instant}; use crate::connection::Connection; use crate::database::Database; use crate::error::Error; -use super::inner::{is_beyond_max_lifetime, PoolInner}; +use super::inner::PoolInner; use crate::pool::connect::{ConnectPermit, ConnectTaskShared, ConnectionId}; +use crate::pool::connection_set::{ConnectedSlot, DisconnectedSlot}; use crate::pool::options::PoolConnectionMetadata; -use crate::pool::shard::{ConnectedSlot, DisconnectedSlot}; -use crate::pool::Pool; +use crate::pool::{Pool, PoolOptions}; use crate::rt; const RETURN_TO_POOL_TIMEOUT: Duration = Duration::from_secs(5); -const CLOSE_ON_DROP_TIMEOUT: Duration = Duration::from_secs(5); +const CLOSE_TIMEOUT: Duration = Duration::from_secs(5); /// A connection managed by a [`Pool`][crate::pool::Pool]. /// /// Will be returned to the pool on-drop. pub struct PoolConnection { conn: Option>>, - pub(crate) pool: Arc>, close_on_drop: bool, } pub(super) struct ConnectionInner { + // Note: must be `Weak` to prevent a reference cycle + pub(crate) pool: Weak>, pub(super) raw: DB::Connection, pub(super) id: ConnectionId, pub(super) created_at: Instant, @@ -72,11 +74,10 @@ impl AsMut for PoolConnection { } impl PoolConnection { - pub(super) fn new(live: ConnectedSlot>, pool: Arc>) -> Self { + pub(super) fn new(live: ConnectedSlot>) -> Self { Self { conn: Some(live), close_on_drop: false, - pool, } } @@ -140,13 +141,16 @@ impl PoolConnection { #[doc(hidden)] pub fn return_to_pool(&mut self) -> impl Future + Send + 'static { let conn = self.conn.take(); - let pool = self.pool.clone(); async move { let Some(conn) = conn else { return; }; + let Some(pool) = Weak::upgrade(&conn.pool) else { + return; + }; + rt::timeout(RETURN_TO_POOL_TIMEOUT, return_to_pool(conn, &pool)) .await // Dropping of the `slot` will check if the connection must be re-established @@ -161,7 +165,7 @@ impl PoolConnection { async move { if let Some(conn) = conn { // Don't hold the connection forever if it hangs while trying to close - rt::timeout(CLOSE_ON_DROP_TIMEOUT, close(conn)).await.ok(); + rt::timeout(CLOSE_TIMEOUT, close(conn)).await.ok(); } } } @@ -195,7 +199,7 @@ impl Drop for PoolConnection { } // We still need to spawn a task to maintain `min_connections`. - if self.conn.is_some() || self.pool.options.min_connections > 0 { + if self.conn.is_some() { crate::rt::spawn(self.return_to_pool()); } } @@ -220,6 +224,48 @@ impl ConnectionInner { idle_for: now.saturating_duration_since(self.last_released_at), } } + + pub fn is_beyond_max_lifetime(&self, options: &PoolOptions) -> bool { + if let Some(max_lifetime) = options.max_lifetime { + let age = self.created_at.elapsed(); + + if age > max_lifetime { + tracing::info!( + target: "sqlx::pool", + connection_id=%self.id, + ?age, + "connection is beyond `max_lifetime`, closing" + ); + + return true; + } + } + + false + } + + pub fn is_beyond_idle_timeout(&self, options: &PoolOptions) -> bool { + if let Some(idle_timeout) = options.idle_timeout { + let now = Instant::now(); + + let age = now.duration_since(self.created_at); + let idle_duration = now.duration_since(self.last_released_at); + + if idle_duration > idle_timeout { + tracing::info!( + target: "sqlx::pool", + connection_id=%self.id, + ?age, + ?idle_duration, + "connection is beyond `idle_timeout`, closing" + ); + + return true; + } + } + + false + } } pub(crate) async fn close( @@ -231,14 +277,19 @@ pub(crate) async fn close( let (conn, slot) = ConnectedSlot::take(conn); - let res = conn.raw.close().await.inspect_err(|error| { - tracing::debug!( - target: "sqlx::pool", - %connection_id, - %error, - "error occurred while closing the pool connection" - ); - }); + let res = rt::timeout(CLOSE_TIMEOUT, conn.raw.close()) + .await + .unwrap_or_else(|_| { + Err(io::Error::new(io::ErrorKind::TimedOut, "timed out sending close packet").into()) + }) + .inspect_err(|error| { + tracing::debug!( + target: "sqlx::pool", + %connection_id, + %error, + "error occurred while closing the pool connection" + ); + }); (res, slot) } @@ -255,14 +306,19 @@ pub(crate) async fn close_hard( let (conn, slot) = ConnectedSlot::take(conn); - let res = conn.raw.close_hard().await.inspect_err(|error| { - tracing::debug!( - target: "sqlx::pool", - %connection_id, - %error, - "error occurred while closing the pool connection" - ); - }); + let res = rt::timeout(CLOSE_TIMEOUT, conn.raw.close_hard()) + .await + .unwrap_or_else(|_| { + Err(io::Error::new(io::ErrorKind::TimedOut, "timed out sending close packet").into()) + }) + .inspect_err(|error| { + tracing::debug!( + target: "sqlx::pool", + %connection_id, + %error, + "error occurred while closing the pool connection" + ); + }); (res, slot) } @@ -282,7 +338,7 @@ async fn return_to_pool( // If the connection is beyond max lifetime, close the connection and // immediately create a new connection - if is_beyond_max_lifetime(&conn, &pool.options) { + if conn.is_beyond_max_lifetime(&pool.options) { let (_res, slot) = close(conn).await; return Err(slot); } @@ -314,6 +370,7 @@ async fn return_to_pool( // to recover from cancellations if let Err(error) = conn.raw.ping().await { tracing::warn!( + target: "sqlx::pool", %error, "error occurred while testing the connection on-release", ); diff --git a/sqlx-core/src/pool/connection_set.rs b/sqlx-core/src/pool/connection_set.rs new file mode 100644 index 00000000..8683f8a9 --- /dev/null +++ b/sqlx-core/src/pool/connection_set.rs @@ -0,0 +1,543 @@ +use crate::ext::future::race; +use crate::rt; +use crate::sync::{AsyncMutex, AsyncMutexGuardArc}; +use event_listener::{listener, Event, EventListener, IntoNotification}; +use futures_core::Stream; +use futures_util::stream::FuturesUnordered; +use futures_util::{FutureExt, StreamExt}; +use std::cmp; +use std::future::Future; +use std::ops::{Deref, DerefMut, RangeInclusive, RangeToInclusive}; +use std::pin::{pin, Pin}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; +use std::task::Poll; +use std::time::Duration; + +pub struct ConnectionSet { + global: Arc, + slots: Box<[Arc>]>, +} + +pub struct ConnectedSlot(SlotGuard); + +pub struct DisconnectedSlot(SlotGuard); + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +enum AcquirePreference { + Connected, + Disconnected, + Either, +} + +struct Global { + unlock_event: Event, + disconnect_event: Event, + num_connected: AtomicUsize, + min_connections: usize, + min_connections_event: Event<()>, +} + +struct SlotGuard { + slot: Arc>, + // `Option` allows us to take the guard in the drop handler. + locked: Option>>, +} + +struct Slot { + // By having each `Slot` hold its own reference to `Global`, we can avoid extra contended clones + // which would sap performance + global: Arc, + index: usize, + // I'd love to eliminate this redundant `Arc` but it's likely not possible without `unsafe` + connection: Arc>>, + unlock_event: Event, + disconnect_event: Event, + connected: AtomicBool, + locked: AtomicBool, + leaked: AtomicBool, +} + +impl ConnectionSet { + pub fn new(size: RangeInclusive) -> Self { + let global = Arc::new(Global { + unlock_event: Event::with_tag(), + disconnect_event: Event::with_tag(), + num_connected: AtomicUsize::new(0), + min_connections: *size.start(), + min_connections_event: Event::with_tag(), + }); + + ConnectionSet { + // `vec![; size].into()` clones `` instead of repeating it, + // which is *no bueno* when wrapping something in `Arc` + slots: (0..*size.end()) + .map(|index| { + Arc::new(Slot { + global: global.clone(), + index, + connection: Arc::new(AsyncMutex::new(None)), + unlock_event: Event::with_tag(), + disconnect_event: Event::with_tag(), + connected: AtomicBool::new(false), + locked: AtomicBool::new(false), + leaked: AtomicBool::new(false), + }) + }) + .collect(), + global, + } + } + + #[inline(always)] + pub fn num_connected(&self) -> usize { + self.global.num_connected() + } + + pub fn count_idle(&self) -> usize { + self.slots.iter().filter(|slot| slot.is_locked()).count() + } + + pub async fn acquire_connected(&self) -> ConnectedSlot { + self.acquire_inner(AcquirePreference::Connected) + .await + .assert_connected() + } + + pub async fn acquire_disconnected(&self) -> DisconnectedSlot { + self.acquire_inner(AcquirePreference::Disconnected) + .await + .assert_disconnected() + } + + /// Attempt to acquire the connection associated with the current thread. + pub async fn acquire_any(&self) -> Result, DisconnectedSlot> { + self.acquire_inner(AcquirePreference::Either) + .await + .try_connected() + } + + async fn acquire_inner(&self, pref: AcquirePreference) -> SlotGuard { + /// Smallest time-step supported by [`tokio::time::sleep()`]. + /// + /// `async-io` doesn't document a minimum time-step, instead deferring to the platform. + const STEP_INTERVAL: Duration = Duration::from_millis(1); + + const SEARCH_LIMIT: usize = 5; + + let preferred_slot = current_thread_id() % self.slots.len(); + + tracing::trace!(preferred_slot, ?pref, "acquire_inner"); + + // Always try to lock the connection associated with our thread ID + let mut acquire_preferred = pin!(self.slots[preferred_slot].acquire(pref)); + + let mut step_interval = pin!(rt::interval_after(STEP_INTERVAL)); + + let mut intervals_elapsed = 0usize; + + let mut search_slots = FuturesUnordered::new(); + + let mut listen_global = pin!(self.global.listen(pref)); + + let mut search_slot = self.next_slot(preferred_slot); + + std::future::poll_fn(|cx| loop { + if let Poll::Ready(locked) = acquire_preferred.as_mut().poll(cx) { + return Poll::Ready(locked); + } + + // Don't push redundant futures for small sets. + let search_limit = cmp::min(SEARCH_LIMIT, self.slots.len()); + + if search_slots.len() < search_limit && step_interval.as_mut().poll_tick(cx).is_ready() + { + intervals_elapsed = intervals_elapsed.saturating_add(1); + + if search_slot != preferred_slot && self.slots[search_slot].matches_pref(pref) { + search_slots.push(self.slots[search_slot].lock()); + } + + search_slot = self.next_slot(search_slot); + } + + if let Poll::Ready(Some(locked)) = Pin::new(&mut search_slots).poll_next(cx) { + if locked.matches_pref(pref) { + return Poll::Ready(locked); + } + + continue; + } + + if intervals_elapsed > search_limit && search_slots.len() < search_limit { + if let Poll::Ready(slot) = listen_global.as_mut().poll(cx) { + if self.slots[slot].matches_pref(pref) { + search_slots.push(self.slots[slot].lock()); + } + + listen_global.as_mut().set(self.global.listen(pref)); + continue; + } + } + + return Poll::Pending; + }) + .await + } + + pub fn try_acquire_connected(&self) -> Option> { + Some( + self.try_acquire(AcquirePreference::Connected)? + .assert_connected(), + ) + } + + pub fn try_acquire_disconnected(&self) -> Option> { + Some( + self.try_acquire(AcquirePreference::Disconnected)? + .assert_disconnected(), + ) + } + + fn try_acquire(&self, pref: AcquirePreference) -> Option> { + let mut search_slot = current_thread_id() % self.slots.len(); + + for _ in 0..self.slots.len() { + if let Some(locked) = self.slots[search_slot].try_acquire(pref) { + return Some(locked); + } + + search_slot = self.next_slot(search_slot); + } + + None + } + + pub fn min_connections_listener(&self) -> EventListener { + self.global.min_connections_event.listen() + } + + pub fn iter_idle(&self) -> impl Iterator> + '_ { + self.slots.iter().filter_map(|slot| { + Some( + slot.try_acquire(AcquirePreference::Connected)? + .assert_connected(), + ) + }) + } + + pub async fn drain(&self, ref close: impl AsyncFn(ConnectedSlot) -> DisconnectedSlot) { + let mut closing = FuturesUnordered::new(); + + // We could try to be more efficient by only populating the `FuturesUnordered` for + // connected slots, but then we'd have to handle a disconnected slot becoming connected, + // which could happen concurrently. + // + // However, we don't *need* to be efficient when shutting down the pool. + for slot in &self.slots { + closing.push(async { + let locked = slot.lock().await; + + let slot = match locked.try_connected() { + Ok(connected) => close(connected).await, + Err(disconnected) => disconnected, + }; + + // The pool is shutting down; don't wake any tasks that might have been interested + slot.leak(); + }); + } + + while closing.next().await.is_some() {} + } + + #[inline(always)] + fn next_slot(&self, slot: usize) -> usize { + // By adding a number that is coprime to `slots.len()` before taking the modulo, + // we can visit each slot in a pseudo-random order, spreading the demand evenly. + // + // Interestingly, this pattern returns to the original slot after `slots.len()` iterations, + // because of congruence: https://en.wikipedia.org/wiki/Modular_arithmetic#Congruence + (slot + 547) % self.slots.len() + } +} + +impl AcquirePreference { + #[inline(always)] + fn wants_connected(&self, is_connected: bool) -> bool { + match (self, is_connected) { + (Self::Connected, true) => true, + (Self::Disconnected, false) => true, + (Self::Either, _) => true, + _ => false, + } + } +} + +impl Slot { + #[inline(always)] + fn matches_pref(&self, pref: AcquirePreference) -> bool { + !self.is_leaked() && pref.wants_connected(self.is_connected()) + } + + #[inline(always)] + fn is_connected(&self) -> bool { + self.connected.load(Ordering::Relaxed) + } + + #[inline(always)] + fn is_locked(&self) -> bool { + self.locked.load(Ordering::Relaxed) + } + + #[inline(always)] + fn is_leaked(&self) -> bool { + self.leaked.load(Ordering::Relaxed) + } + + #[inline(always)] + fn set_is_connected(&self, connected: bool) { + let was_connected = self.connected.swap(connected, Ordering::Acquire); + + match (connected, was_connected) { + (false, true) => { + // Ensure this is synchronized with `connected` + self.global.num_connected.fetch_add(1, Ordering::Release); + } + (true, false) => { + self.global.num_connected.fetch_sub(1, Ordering::Release); + } + _ => (), + } + } + + async fn acquire(self: &Arc, pref: AcquirePreference) -> SlotGuard { + loop { + if self.matches_pref(pref) { + tracing::trace!(slot_index=%self.index, "waiting for lock"); + + let locked = self.lock().await; + + if locked.matches_pref(pref) { + return locked; + } + } + + match pref { + AcquirePreference::Connected => { + listener!(self.unlock_event => listener); + listener.await; + } + AcquirePreference::Disconnected => { + listener!(self.disconnect_event => listener); + listener.await + } + AcquirePreference::Either => { + listener!(self.unlock_event => unlock_listener); + listener!(self.disconnect_event => disconnect_listener); + race(unlock_listener, disconnect_listener).await.ok(); + } + } + } + } + + fn try_acquire(self: &Arc, pref: AcquirePreference) -> Option> { + if self.matches_pref(pref) { + let locked = self.try_lock()?; + + if locked.matches_pref(pref) { + return Some(locked); + } + } + + None + } + + async fn lock(self: &Arc) -> SlotGuard { + let locked = crate::sync::lock_arc(&self.connection).await; + + self.locked.store(true, Ordering::Relaxed); + + SlotGuard { + slot: self.clone(), + locked: Some(locked), + } + } + + fn try_lock(self: &Arc) -> Option> { + let locked = crate::sync::try_lock_arc(&self.connection)?; + + self.locked.store(true, Ordering::Relaxed); + + Some(SlotGuard { + slot: self.clone(), + locked: Some(locked), + }) + } +} + +impl SlotGuard { + #[inline(always)] + fn get(&self) -> &Option { + self.locked.as_ref().expect(EXPECT_LOCKED) + } + + #[inline(always)] + fn get_mut(&mut self) -> &mut Option { + self.locked.as_mut().expect(EXPECT_LOCKED) + } + + #[inline(always)] + fn matches_pref(&self, pref: AcquirePreference) -> bool { + !self.slot.is_leaked() && pref.wants_connected(self.is_connected()) + } + + #[inline(always)] + fn is_connected(&self) -> bool { + self.get().is_some() + } + + fn try_connected(self) -> Result, DisconnectedSlot> { + if self.is_connected() { + Ok(ConnectedSlot(self)) + } else { + Err(DisconnectedSlot(self)) + } + } + + fn assert_connected(self) -> ConnectedSlot { + assert!(self.is_connected()); + ConnectedSlot(self) + } + + fn assert_disconnected(self) -> DisconnectedSlot { + assert!(!self.is_connected()); + + DisconnectedSlot(self) + } + + /// Updates `Slot::connected` without notifying the `ConnectionSet`. + /// + /// Returns `Some(connected)` or `None` if this guard was already dropped. + fn drop_without_notify(&mut self) -> Option { + self.locked.take().map(|locked| { + let connected = locked.is_some(); + self.slot.set_is_connected(connected); + self.slot.locked.store(false, Ordering::Release); + connected + }) + } +} + +const EXPECT_LOCKED: &str = "BUG: `SlotGuard::locked` should not be `None` in normal operation"; +const EXPECT_CONNECTED: &str = "BUG: `ConnectedSlot` expects `Slot::connection` to be `Some`"; + +impl ConnectedSlot { + pub fn take(mut self) -> (C, DisconnectedSlot) { + let conn = self.0.get_mut().take().expect(EXPECT_CONNECTED); + (conn, self.0.assert_disconnected()) + } +} + +impl Deref for ConnectedSlot { + type Target = C; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + self.0.get().as_ref().expect(EXPECT_CONNECTED) + } +} + +impl DerefMut for ConnectedSlot { + #[inline(always)] + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.get_mut().as_mut().expect(EXPECT_CONNECTED) + } +} + +impl DisconnectedSlot { + pub fn put(mut self, conn: C) -> ConnectedSlot { + *self.0.get_mut() = Some(conn); + ConnectedSlot(self.0) + } + + pub fn leak(mut self) { + self.0.slot.leaked.store(true, Ordering::Release); + self.0.drop_without_notify(); + } +} + +impl Drop for SlotGuard { + fn drop(&mut self) { + let Some(connected) = self.drop_without_notify() else { + return; + }; + + let event = if connected { + &self.slot.global.unlock_event + } else { + &self.slot.global.disconnect_event + }; + + if event.notify(1.tag(self.slot.index).additional()) != 0 { + return; + } + + if connected { + self.slot.unlock_event.notify(1); + return; + } + + if self.slot.disconnect_event.notify(1) != 0 { + return; + } + + if self.slot.global.num_connected() < self.slot.global.min_connections { + self.slot.global.min_connections_event.notify(1); + } + } +} + +impl Global { + #[inline(always)] + fn num_connected(&self) -> usize { + self.num_connected.load(Ordering::Relaxed) + } + + async fn listen(&self, pref: AcquirePreference) -> usize { + match pref { + AcquirePreference::Either => race(self.listen_unlocked(), self.listen_disconnected()) + .await + .unwrap_or_else(|slot| slot), + AcquirePreference::Connected => self.listen_unlocked().await, + AcquirePreference::Disconnected => self.listen_disconnected().await, + } + } + + async fn listen_unlocked(&self) -> usize { + listener!(self.unlock_event => listener); + listener.await + } + + async fn listen_disconnected(&self) -> usize { + listener!(self.disconnect_event => listener); + listener.await + } +} + +fn current_thread_id() -> usize { + // FIXME: this can be replaced when this is stabilized: + // https://doc.rust-lang.org/stable/std/thread/struct.ThreadId.html#method.as_u64 + static THREAD_ID: AtomicUsize = AtomicUsize::new(0); + + thread_local! { + // `SeqCst` is possibly too strong since we don't need synchronization with + // any other variable. I'm not confident enough in my understanding of atomics to be certain, + // especially with regards to weakly ordered architectures. + // + // However, this is literally only done once on each thread, so it doesn't really matter. + static CURRENT_THREAD_ID: usize = THREAD_ID.fetch_add(1, Ordering::SeqCst); + } + + CURRENT_THREAD_ID.with(|i| *i) +} diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index 046834a4..1ae687f1 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -5,33 +5,30 @@ use crate::pool::{connection, CloseEvent, Pool, PoolConnection, PoolConnector, P use std::cmp; use std::future::Future; +use std::ops::ControlFlow; use std::pin::{pin, Pin}; -use std::rc::Weak; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; -use std::task::{ready, Poll}; +use std::sync::{Arc, Weak}; +use std::task::{Context, Poll}; use crate::connection::Connection; +use crate::ext::future::race; use crate::logger::private_level_filter_to_trace_level; -use crate::pool::connect::{ - ConnectPermit, ConnectTask, ConnectTaskShared, ConnectionCounter, ConnectionId, DynConnector, -}; -use crate::pool::shard::{ConnectedSlot, DisconnectedSlot, Sharded}; -use crate::rt::JoinHandle; +use crate::pool::connect::{ConnectTaskShared, ConnectionCounter, ConnectionId, DynConnector}; +use crate::pool::connection_set::{ConnectedSlot, ConnectionSet, DisconnectedSlot}; use crate::{private_tracing_dynamic_event, rt}; -use either::Either; -use futures_core::FusedFuture; -use futures_util::future::{self, OptionFuture}; -use futures_util::{stream, FutureExt, TryStreamExt}; +use event_listener::listener; +use futures_util::future::{self}; use std::time::{Duration, Instant}; use tracing::Level; const GRACEFUL_CLOSE_TIMEOUT: Duration = Duration::from_secs(5); +const TEST_BEFORE_ACQUIRE_TIMEOUT: Duration = Duration::from_secs(60); pub(crate) struct PoolInner { pub(super) connector: DynConnector, pub(super) counter: ConnectionCounter, - pub(super) sharded: Sharded>, + pub(super) connections: ConnectionSet>, is_closed: AtomicBool, pub(super) on_closed: event_listener::Event, pub(super) options: PoolOptions, @@ -44,39 +41,15 @@ impl PoolInner { options: PoolOptions, connector: impl PoolConnector, ) -> Arc { - let pool = Arc::::new_cyclic(|pool_weak| { - let pool_weak = pool_weak.clone(); - - let reconnect = move |slot| { - let Some(pool) = pool_weak.upgrade() else { - // Prevent an infinite loop on pool drop. - DisconnectedSlot::leak(slot); - return; - }; - - pool.connector.connect( - Pool(pool.clone()), - ConnectionId::next(), - slot, - ConnectTaskShared::new_arc(), - ); - }; - - Self { - connector: DynConnector::new(connector), - counter: ConnectionCounter::new(), - sharded: Sharded::new( - options.max_connections, - options.shards, - options.min_connections, - reconnect, - ), - is_closed: AtomicBool::new(false), - on_closed: event_listener::Event::new(), - acquire_time_level: private_level_filter_to_trace_level(options.acquire_time_level), - acquire_slow_level: private_level_filter_to_trace_level(options.acquire_slow_level), - options, - } + let pool = Arc::new(Self { + connector: DynConnector::new(connector), + counter: ConnectionCounter::new(), + connections: ConnectionSet::new(options.min_connections..=options.max_connections), + is_closed: AtomicBool::new(false), + on_closed: event_listener::Event::new(), + acquire_time_level: private_level_filter_to_trace_level(options.acquire_time_level), + acquire_slow_level: private_level_filter_to_trace_level(options.acquire_slow_level), + options, }); spawn_maintenance_tasks(&pool); @@ -85,11 +58,11 @@ impl PoolInner { } pub(super) fn size(&self) -> usize { - self.sharded.count_connected() + self.connections.num_connected() } pub(super) fn num_idle(&self) -> usize { - self.sharded.count_unlocked(true) + self.connections.count_idle() } pub(super) fn is_closed(&self) -> bool { @@ -105,11 +78,8 @@ impl PoolInner { self.mark_closed(); // Keep clearing the idle queue as connections are released until the count reaches zero. - self.sharded.drain(|slot| async move { - let (conn, slot) = ConnectedSlot::take(slot); - - let _ = rt::timeout(GRACEFUL_CLOSE_TIMEOUT, conn.raw.close()).await; - + self.connections.drain(async |slot| { + let (_res, slot) = connection::close(slot).await; slot }) } @@ -130,7 +100,7 @@ impl PoolInner { return None; } - self.sharded.try_acquire_connected() + self.connections.try_acquire_connected() } pub(super) async fn acquire(self: &Arc) -> Result, Error> { @@ -140,74 +110,43 @@ impl PoolInner { let acquire_started_at = Instant::now(); - let mut close_event = pin!(self.close_event()); - let mut deadline = pin!(rt::sleep(self.options.acquire_timeout)); + // Lazily allocated `Arc` + let mut connect_shared = None; - let connect_shared = ConnectTaskShared::new_arc(); + let res = { + // Pinned to the stack without allocating + listener!(self.on_closed => close_listener); + let mut deadline = pin!(rt::sleep(self.options.acquire_timeout)); + let mut acquire_inner = pin!(self.acquire_inner(&mut connect_shared)); - let mut acquire_connected = pin!(self.acquire_connected().fuse()); - - let mut acquire_disconnected = pin!(self.sharded.acquire_disconnected().fuse()); - - let mut connect = future::Fuse::terminated(); - - let acquired = std::future::poll_fn(|cx| loop { - if let Poll::Ready(()) = close_event.as_mut().poll(cx) { - return Poll::Ready(Err(Error::PoolClosed)); - } - - if let Poll::Ready(res) = acquire_connected.as_mut().poll(cx) { - match res { - Ok(conn) => { - return Poll::Ready(Ok(conn)); - } - Err(slot) => { - if connect.is_terminated() { - connect = self - .connector - .connect( - Pool(self.clone()), - ConnectionId::next(), - slot, - connect_shared.clone(), - ) - .fuse(); - } - - // Try to acquire another connected connection. - acquire_connected.set(self.acquire_connected().fuse()); - continue; - } + std::future::poll_fn(|cx| { + if self.is_closed() { + return Poll::Ready(Err(Error::PoolClosed)); } - } - if let Poll::Ready(slot) = acquire_disconnected.as_mut().poll(cx) { - if connect.is_terminated() { - connect = self - .connector - .connect( - Pool(self.clone()), - ConnectionId::next(), - slot, - connect_shared.clone(), - ) - .fuse(); + // The result doesn't matter so much as the wakeup + let _ = Pin::new(&mut close_listener).poll(cx); + + if let Poll::Ready(()) = deadline.as_mut().poll(cx) { + return Poll::Ready(Err(Error::PoolTimedOut { + last_connect_error: None, + })); } - } - if let Poll::Ready(res) = Pin::new(&mut connect).poll(cx) { - return Poll::Ready(res); - } + acquire_inner.as_mut().poll(cx) + }) + .await + }; - if let Poll::Ready(()) = deadline.as_mut().poll(cx) { - return Poll::Ready(Err(Error::PoolTimedOut { - last_connect_error: connect_shared.take_error().map(Box::new), - })); - } - - return Poll::Pending; - }) - .await?; + let acquired = res.map_err(|e| match e { + Error::PoolTimedOut { + last_connect_error: None, + } => Error::PoolTimedOut { + last_connect_error: connect_shared + .and_then(|shared| Some(shared.take_error()?.into())), + }, + e => e, + })?; let acquired_after = acquire_started_at.elapsed(); @@ -235,20 +174,36 @@ impl PoolInner { Ok(acquired) } - async fn acquire_connected( + async fn acquire_inner( self: &Arc, - ) -> Result, DisconnectedSlot>> { - let connected = self.sharded.acquire_connected().await; + connect_shared: &mut Option>, + ) -> Result, Error> { + tracing::trace!("waiting for any connection"); - tracing::debug!( - target: "sqlx::pool", - connection_id=%connected.id, - "acquired idle connection" + let disconnected = match self.connections.acquire_any().await { + Ok(conn) => match finish_acquire(self, conn).await { + Ok(conn) => return Ok(conn), + Err(slot) => slot, + }, + Err(slot) => slot, + }; + + let mut connect_task = self.connector.connect( + Pool(self.clone()), + ConnectionId::next(), + disconnected, + connect_shared.insert(ConnectTaskShared::new_arc()).clone(), ); - match finish_acquire(self, connected) { - Either::Left(task) => task.await, - Either::Right(conn) => Ok(conn), + loop { + match race(&mut connect_task, self.connections.acquire_connected()).await { + Ok(Ok(conn)) => return Ok(conn), + Ok(Err(e)) => return Err(e), + Err(conn) => match finish_acquire(self, conn).await { + Ok(conn) => return Ok(conn), + Err(_) => continue, + }, + } } } @@ -258,17 +213,20 @@ impl PoolInner { ) -> Result<(), Error> { let shared = ConnectTaskShared::new_arc(); - let connect_min_connections = - future::try_join_all(self.sharded.iter_min_connections().map(|slot| { - self.connector.connect( - Pool(self.clone()), - ConnectionId::next(), - slot, - shared.clone(), - ) - })); + let connect_min_connections = future::try_join_all( + (self.connections.num_connected()..self.options.min_connections) + .filter_map(|_| self.connections.try_acquire_disconnected()) + .map(|slot| { + self.connector.connect( + Pool(self.clone()), + ConnectionId::next(), + slot, + shared.clone(), + ) + }), + ); - let mut conns = if let Some(deadline) = deadline { + let conns = if let Some(deadline) = deadline { match rt::timeout_at(deadline, connect_min_connections).await { Ok(Ok(conns)) => conns, Err(_) | Ok(Err(Error::PoolTimedOut { .. })) => { @@ -297,144 +255,192 @@ impl Drop for PoolInner { } } -/// Returns `true` if the connection has exceeded `options.max_lifetime` if set, `false` otherwise. -pub(super) fn is_beyond_max_lifetime( - live: &ConnectionInner, - options: &PoolOptions, -) -> bool { - options - .max_lifetime - .is_some_and(|max| live.created_at.elapsed() > max) -} - -/// Returns `true` if the connection has exceeded `options.idle_timeout` if set, `false` otherwise. -fn is_beyond_idle_timeout( - idle: &ConnectionInner, - options: &PoolOptions, -) -> bool { - options - .idle_timeout - .is_some_and(|timeout| idle.last_released_at.elapsed() > timeout) -} - /// Execute `test_before_acquire` and/or `before_acquire` in a background task, if applicable. /// /// Otherwise, immediately returns the connection. -fn finish_acquire( +async fn finish_acquire( pool: &Arc>, mut conn: ConnectedSlot>, -) -> Either< - JoinHandle, DisconnectedSlot>>>, - PoolConnection, -> { - if pool.options.test_before_acquire || pool.options.before_acquire.is_some() { - let pool = pool.clone(); +) -> Result, DisconnectedSlot>> { + struct SpawnOnDrop(Option>>) + where + F::Output: Send + 'static; - // Spawn a task so the call may complete even if `acquire()` is cancelled. - return Either::Left(rt::spawn(async move { - // Check that the connection is still live + impl Future for SpawnOnDrop + where + F::Output: Send + 'static, + { + type Output = F::Output; + + #[inline(always)] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.0 + .as_mut() + .expect("BUG: inner future taken") + .as_mut() + .poll(cx) + } + } + + impl Drop for SpawnOnDrop + where + F::Output: Send + 'static, + { + fn drop(&mut self) { + rt::try_spawn(self.0.take().expect("BUG: inner future taken")); + } + } + + async fn finish_inner( + conn: &mut ConnectedSlot>, + pool: &PoolInner, + ) -> ControlFlow<()> { + // Check that the connection is still live + if pool.options.test_before_acquire { if let Err(error) = conn.raw.ping().await { // an error here means the other end has hung up or we lost connectivity // either way we're fine to just discard the connection // the error itself here isn't necessarily unexpected so WARN is too strong tracing::info!(%error, connection_id=%conn.id, "ping on idle connection returned error"); - - // connection is broken so don't try to close nicely - let (_res, slot) = connection::close_hard(conn).await; - return Err(slot); + return ControlFlow::Break(()); } + } - if let Some(test) = &pool.options.before_acquire { - let meta = conn.idle_metadata(); - match test(&mut conn.raw, meta).await { - Ok(false) => { - // connection was rejected by user-defined hook, close nicely - let (_res, slot) = connection::close(conn).await; - return Err(slot); - } - - Err(error) => { - tracing::warn!(%error, "error from `before_acquire`"); - - // connection is broken so don't try to close nicely - let (_res, slot) = connection::close_hard(conn).await; - return Err(slot); - } - - Ok(true) => {} + if let Some(test) = &pool.options.before_acquire { + let meta = conn.idle_metadata(); + match test(&mut conn.raw, meta).await { + Ok(false) => { + // connection was rejected by user-defined hook, close nicely + tracing::debug!(connection_id=%conn.id, "connection rejected by `before_acquire`"); + return ControlFlow::Break(()); } - } - Ok(PoolConnection::new(conn, pool)) - })); + Err(error) => { + tracing::warn!(%error, "error from `before_acquire`"); + return ControlFlow::Break(()); + } + + Ok(true) => (), + } + } + + // Checks passed + ControlFlow::Continue(()) } - // No checks are configured, return immediately. - Either::Right(PoolConnection::new(conn, pool.clone())) + if pool.options.test_before_acquire || pool.options.before_acquire.is_some() { + let pool = pool.clone(); + + // Spawn a task on-drop so the call may complete even if `acquire()` is cancelled. + conn = SpawnOnDrop(Some(Box::pin(async move { + match rt::timeout(TEST_BEFORE_ACQUIRE_TIMEOUT, finish_inner(&mut conn, &pool)).await { + Ok(ControlFlow::Continue(())) => { + Ok(conn) + } + Ok(ControlFlow::Break(())) => { + // Connection rejected by user-defined hook, attempt to close nicely + let (_res, slot) = connection::close(conn).await; + Err(slot) + } + Err(_) => { + tracing::info!(connection_id=%conn.id, "`before_acquire` checks timed out, closing connection"); + let (_res, slot) = connection::close_hard(conn).await; + Err(slot) + } + } + }))).await?; + } + + tracing::debug!( + target: "sqlx::pool", + connection_id=%conn.id, + "acquired idle connection" + ); + + Ok(PoolConnection::new(conn)) } fn spawn_maintenance_tasks(pool: &Arc>) { - // NOTE: use `pool_weak` for the maintenance tasks - // so they don't keep `PoolInner` from being dropped. - let pool_weak = Arc::downgrade(pool); + if pool.options.min_connections > 0 { + // NOTE: use `pool_weak` for the maintenance tasks + // so they don't keep `PoolInner` from being dropped. + let pool_weak = Arc::downgrade(pool); + let mut close_event = pool.close_event(); - let period = match (pool.options.max_lifetime, pool.options.idle_timeout) { + rt::spawn(async move { + close_event + .do_until(check_min_connections(pool_weak)) + .await + .ok(); + }); + } + + let check_interval = match (pool.options.max_lifetime, pool.options.idle_timeout) { (Some(it), None) | (None, Some(it)) => it, - (Some(a), Some(b)) => cmp::min(a, b), - - (None, None) => { - if pool.options.min_connections > 0 { - rt::spawn(async move { - if let Some(pool) = pool_weak.upgrade() { - if let Err(error) = pool.try_min_connections(None).await { - tracing::error!( - target: "sqlx::pool", - ?error, - "error maintaining min_connections" - ); - } - } - }); - } - - return; - } + (None, None) => return, }; - // Immediately cancel this task if the pool is closed. + let pool_weak = Arc::downgrade(pool); let mut close_event = pool.close_event(); rt::spawn(async move { let _ = close_event - .do_until(async { - // If the last handle to the pool was dropped while we were sleeping - while let Some(pool) = pool_weak.upgrade() { - if pool.is_closed() { - return; - } - - let next_run = Instant::now() + period; - - // Go over all idle connections, check for idleness and lifetime, - // and if we have fewer than min_connections after reaping a connection, - // open a new one immediately. - for conn in pool.sharded.iter_idle() { - if is_beyond_idle_timeout(&conn, &pool.options) - || is_beyond_max_lifetime(&conn, &pool.options) - { - // Dropping the slot will check if the connection needs to be - // re-made. - let _ = connection::close(conn).await; - } - } - - // Don't hold a reference to the pool while sleeping. - drop(pool); - - rt::sleep_until(next_run).await; - } - }) + .do_until(check_idle_conns(pool_weak, check_interval)) .await; }); } + +async fn check_idle_conns(pool_weak: Weak>, check_interval: Duration) { + let mut interval = pin!(rt::interval_after(check_interval)); + + while let Some(pool) = pool_weak.upgrade() { + if pool.is_closed() { + return; + } + + // Go over all idle connections, check for idleness and lifetime, + // and if we have fewer than min_connections after reaping a connection, + // open a new one immediately. + for conn in pool.connections.iter_idle() { + if conn.is_beyond_idle_timeout(&pool.options) + || conn.is_beyond_max_lifetime(&pool.options) + { + // Dropping the slot will check if the connection needs to be re-made. + let _ = connection::close(conn).await; + } + } + + // Don't hold a reference to the pool while sleeping. + drop(pool); + + interval.as_mut().tick().await; + } +} + +async fn check_min_connections(pool_weak: Weak>) { + while let Some(pool) = pool_weak.upgrade() { + if pool.is_closed() { + return; + } + + match pool.try_min_connections(None).await { + Ok(()) => { + let listener = pool.connections.min_connections_listener(); + + // Important: don't hold a strong ref while sleeping + drop(pool); + + listener.await; + } + Err(e) => { + tracing::warn!( + target: "sqlx::pool::maintenance", + min_connections=pool.options.min_connections, + num_connected=pool.connections.num_connected(), + "unable to maintain `min_connections`: {e:?}", + ); + } + } + } +} diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index 7d2b18ed..224ee8ff 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -56,20 +56,19 @@ use std::fmt; use std::future::Future; -use std::pin::{pin, Pin}; +use std::pin::Pin; use std::sync::Arc; use std::task::{ready, Context, Poll}; -use event_listener::EventListener; -use futures_core::FusedFuture; -use futures_util::FutureExt; - use crate::connection::Connection; use crate::database::Database; use crate::error::Error; +use crate::ext::future::race; use crate::sql_str::SqlSafeStr; use crate::transaction::Transaction; - +use event_listener::EventListener; +use futures_core::FusedFuture; +use tracing::Instrument; pub use self::connect::{PoolConnectMetadata, PoolConnector}; pub use self::connection::PoolConnection; use self::inner::PoolInner; @@ -90,7 +89,9 @@ mod inner; // mod idle; mod options; -mod shard; +// mod shard; + +mod connection_set; /// An asynchronous pool of SQLx database connections. /// @@ -362,16 +363,21 @@ impl Pool { pub fn acquire(&self) -> impl Future, Error>> + 'static { let shared = self.0.clone(); async move { shared.acquire().await } + .instrument(tracing::error_span!(target: "sqlx::pool", "acquire")) } /// Attempts to retrieve a connection from the pool if there is one available. /// /// Returns `None` immediately if there are no idle connections available in the pool /// or there are tasks waiting for a connection which have yet to wake. + /// + /// # Note: Bypasses `before_acquire` + /// Since this function is not `async`, it cannot await the future returned by + /// [`before_acquire`][PoolOptions::before_acquire] without blocking. + /// + /// Instead, it simply returns the connection immediately. pub fn try_acquire(&self) -> Option> { - self.0 - .try_acquire() - .map(|conn| PoolConnection::new(conn, self.0.clone())) + self.0.try_acquire().map(|conn| PoolConnection::new(conn)) } /// Retrieves a connection and immediately begins a new transaction. @@ -577,42 +583,19 @@ impl CloseEvent { /// /// Cancels the future and returns `Err(PoolClosed)` if/when the pool is closed. /// If the pool was already closed, the future is never run. + #[inline(always)] pub async fn do_until(&mut self, fut: Fut) -> Result { - // Check that the pool wasn't closed already. - // - // We use `poll_immediate()` as it will use the correct waker instead of - // a no-op one like `.now_or_never()`, but it won't actually suspend execution here. - futures_util::future::poll_immediate(&mut *self) - .await - .map_or(Ok(()), |_| Err(Error::PoolClosed))?; - - let mut fut = pin!(fut); - - // I find that this is clearer in intent than `futures_util::future::select()` - // or `futures_util::select_biased!{}` (which isn't enabled anyway). - std::future::poll_fn(|cx| { - // Poll `fut` first as the wakeup event is more likely for it than `self`. - if let Poll::Ready(ret) = fut.as_mut().poll(cx) { - return Poll::Ready(Ok(ret)); - } - - // Can't really factor out mapping to `Err(Error::PoolClosed)` though it seems like - // we should because that results in a different `Ok` type each time. - // - // Ideally we'd map to something like `Result` but using `!` as a type - // is not allowed on stable Rust yet. - self.poll_unpin(cx).map(|_| Err(Error::PoolClosed)) - }) - .await + race(fut, self).await.map_err(|_| Error::PoolClosed) } } impl Future for CloseEvent { type Output = (); + #[inline(always)] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if let Some(listener) = &mut self.listener { - ready!(listener.poll_unpin(cx)); + ready!(Pin::new(listener).poll(cx)); } // `EventListener` doesn't like being polled after it yields, and even if it did it diff --git a/sqlx-core/src/rt/mod.rs b/sqlx-core/src/rt/mod.rs index 985d9bb6..862cf6ce 100644 --- a/sqlx-core/src/rt/mod.rs +++ b/sqlx-core/src/rt/mod.rs @@ -1,10 +1,13 @@ use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; use std::time::{Duration, Instant}; use cfg_if::cfg_if; +use futures_core::Stream; +use futures_util::StreamExt; +use pin_project_lite::pin_project; #[cfg(feature = "_rt-async-io")] pub mod rt_async_io; @@ -59,19 +62,13 @@ pub async fn timeout_at(deadline: Instant, f: F) -> Result Interval { + #[cfg(feature = "_rt-tokio")] + if rt_tokio::available() { + return Interval::Tokio { + sleep: tokio::time::sleep(period), + period, + }; + } + + cfg_if! { + if #[cfg(feature = "_rt-async-io")] { + Interval::AsyncIo { timer: async_io::Timer::interval(period) } + } else { + missing_rt(period) + } + } +} + +impl Interval { + #[inline(always)] + pub fn tick(mut self: Pin<&mut Self>) -> impl Future + use<'_> { + std::future::poll_fn(move |cx| self.as_mut().poll_tick(cx)) + } + + #[inline(always)] + pub fn as_timeout(self: Pin<&mut Self>, fut: F) -> AsTimeout<'_, F> { + AsTimeout { + interval: self, + future: fut, + } + } + + #[inline(always)] + pub fn poll_tick(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + cfg_if! { + if #[cfg(any(feature = "_rt-tokio", feature = "_rt-async-io"))] { + match self.project() { + #[cfg(feature = "_rt-tokio")] + IntervalProjected::Tokio { mut sleep, period } => { + ready!(sleep.as_mut().poll(cx)); + let now = Instant::now(); + sleep.reset((now + *period).into()); + Poll::Ready(now) + } + #[cfg(feature = "_rt-async-io")] + IntervalProjected::AsyncIo { mut timer } => { + Poll::Ready(ready!(timer + .as_mut() + .poll_next(cx)) + .expect("BUG: `async_io::Timer::next()` should always yield")) + } + } + } else { + unreachable!() + } + } + } +} + +pin_project! { + pub struct AsTimeout<'i, F> { + interval: Pin<&'i mut Interval>, + #[pin] + future: F, + } +} + +impl Future for AsTimeout<'_, F> +where + F: Future, +{ + type Output = Option; + + #[inline(always)] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + if let Poll::Ready(out) = this.future.poll(cx) { + return Poll::Ready(Some(out)); + } + + this.interval.as_mut().poll_tick(cx).map(|_| None) + } +} + #[track_caller] pub fn spawn(fut: F) -> JoinHandle where @@ -128,6 +254,29 @@ where } } +pub fn try_spawn(fut: F) -> Option> +where + F: Future + Send + 'static, + F::Output: Send + 'static, +{ + #[cfg(feature = "_rt-tokio")] + if let Ok(handle) = tokio::runtime::Handle::try_current() { + return Some(JoinHandle::Tokio(handle.spawn(fut))); + } + + cfg_if! { + if #[cfg(feature = "_rt-async-global-executor")] { + Some(JoinHandle::AsyncTask(Some(async_global_executor::spawn(fut)))) + } else if #[cfg(feature = "_rt-smol")] { + Some(JoinHandle::AsyncTask(Some(smol::spawn(fut)))) + } else if #[cfg(feature = "_rt-async-std")] { + Some(JoinHandle::AsyncStd(async_std::task::spawn(fut))) + } else { + None + } + } +} + #[track_caller] pub fn spawn_blocking(f: F) -> JoinHandle where diff --git a/sqlx-core/src/rt/rt_async_io/time.rs b/sqlx-core/src/rt/rt_async_io/time.rs index 039610b7..dbe1d8f7 100644 --- a/sqlx-core/src/rt/rt_async_io/time.rs +++ b/sqlx-core/src/rt/rt_async_io/time.rs @@ -1,13 +1,10 @@ +use crate::ext::future::race; +use crate::rt::TimeoutError; use std::{ future::Future, - pin::pin, time::{Duration, Instant}, }; -use futures_util::future::{select, Either}; - -use crate::rt::TimeoutError; - pub async fn sleep(duration: Duration) { async_io::Timer::after(duration).await; } @@ -17,8 +14,16 @@ pub async fn sleep_until(deadline: Instant) { } pub async fn timeout(duration: Duration, future: F) -> Result { - match select(pin!(future), pin!(sleep(duration))).await { - Either::Left((result, _)) => Ok(result), - Either::Right(_) => Err(TimeoutError), - } + race(future, sleep(duration)) + .await + .map_err(|_| TimeoutError) +} + +pub async fn timeout_at( + deadline: Instant, + future: F, +) -> Result { + race(future, sleep_until(deadline)) + .await + .map_err(|_| TimeoutError) } diff --git a/sqlx-core/src/rt/rt_tokio/mod.rs b/sqlx-core/src/rt/rt_tokio/mod.rs index ce699456..364ce3bf 100644 --- a/sqlx-core/src/rt/rt_tokio/mod.rs +++ b/sqlx-core/src/rt/rt_tokio/mod.rs @@ -1,5 +1,6 @@ mod socket; +#[inline(always)] pub fn available() -> bool { tokio::runtime::Handle::try_current().is_ok() } diff --git a/sqlx-core/src/sync.rs b/sqlx-core/src/sync.rs index 2fd51445..bce8d60c 100644 --- a/sqlx-core/src/sync.rs +++ b/sqlx-core/src/sync.rs @@ -4,11 +4,40 @@ // We'll generally lean towards Tokio's types as those are more featureful // (including `tokio-console` support) and more widely deployed. +use std::sync::Arc; #[cfg(feature = "_rt-tokio")] -pub use tokio::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, RwLock as AsyncRwLock}; +pub use tokio::sync::{ + Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, OwnedMutexGuard as AsyncMutexGuardArc, + RwLock as AsyncRwLock, +}; #[cfg(all(feature = "_rt-async-lock", not(feature = "_rt-tokio")))] -pub use async_lock::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, RwLock as AsyncRwLock}; +pub use async_lock::{ + Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, MutexGuardArc as AsyncMutexGuardArc, + RwLock as AsyncRwLock, +}; + +pub async fn lock_arc(mutex: &Arc>) -> AsyncMutexGuardArc { + #[cfg(feature = "_rt-tokio")] + return mutex.clone().lock_owned().await; + + #[cfg(all(feature = "_rt-async-lock", not(feature = "_rt-tokio")))] + return mutex.lock_arc().await; + + #[cfg(not(any(feature = "_rt-async-lock", feature = "_rt-tokio")))] + return crate::rt::missing_rt(mutex); +} + +pub fn try_lock_arc(mutex: &Arc>) -> Option> { + #[cfg(feature = "_rt-tokio")] + return mutex.clone().try_lock_owned().ok(); + + #[cfg(all(feature = "_rt-async-lock", not(feature = "_rt-tokio")))] + return mutex.try_lock_arc(); + + #[cfg(not(any(feature = "_rt-async-lock", feature = "_rt-tokio")))] + return crate::rt::missing_rt(mutex); +} #[cfg(not(any(feature = "_rt-async-lock", feature = "_rt-tokio")))] pub use noop::*; @@ -18,6 +47,7 @@ mod noop { use crate::rt::missing_rt; use std::marker::PhantomData; use std::ops::{Deref, DerefMut}; + use std::sync::Arc; pub struct AsyncMutex { // `Sync` if `T: Send` @@ -28,6 +58,10 @@ mod noop { inner: &'a AsyncMutex, } + pub struct AsyncMutexGuardArc { + inner: Arc>, + } + impl AsyncMutex { pub fn new(val: T) -> Self { missing_rt(val) @@ -51,4 +85,18 @@ mod noop { missing_rt(self) } } + + impl Deref for AsyncMutexGuardArc { + type Target = T; + + fn deref(&self) -> &Self::Target { + missing_rt(self) + } + } + + impl DerefMut for AsyncMutexGuardArc { + fn deref_mut(&mut self) -> &mut Self::Target { + missing_rt(self) + } + } } diff --git a/sqlx-test/src/lib.rs b/sqlx-test/src/lib.rs index 01cdc297..6a8b9d11 100644 --- a/sqlx-test/src/lib.rs +++ b/sqlx-test/src/lib.rs @@ -2,13 +2,15 @@ use sqlx::pool::PoolOptions; use sqlx::{Connection, Database, Error, Pool}; use std::env; use tracing_subscriber::EnvFilter; +use tracing_subscriber::fmt::format::FmtSpan; pub fn setup_if_needed() { let _ = dotenvy::dotenv(); let _ = tracing_subscriber::fmt::Subscriber::builder() .with_env_filter(EnvFilter::from_default_env()) - .with_test_writer() - .finish(); + .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) + // .with_test_writer() + .try_init(); } // Make a new connection diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 06adf0ca..d6bbebf2 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -255,6 +255,10 @@ async fn it_works_with_cache_disabled() -> anyhow::Result<()> { #[sqlx_macros::test] async fn it_executes_with_pool() -> anyhow::Result<()> { + setup_if_needed(); + + tracing::info!("starting test"); + let pool = sqlx_test::pool::().await?; let rows = pool.fetch_all("SELECT 1; SElECT 2").await?; @@ -1146,7 +1150,7 @@ async fn test_listener_try_recv_buffered() -> anyhow::Result<()> { assert!(listener.next_buffered().is_none()); // Activate connection. - sqlx::query!("SELECT 1 AS one") + sqlx::query("SELECT 1 AS one") .fetch_all(&mut listener) .await?; @@ -2086,6 +2090,7 @@ async fn test_issue_3052() { } #[sqlx_macros::test] +#[cfg(feature = "chrono")] async fn test_bind_iter() -> anyhow::Result<()> { use sqlx::postgres::PgBindIterExt; use sqlx::types::chrono::{DateTime, Utc};