diff --git a/Cargo.lock b/Cargo.lock index d4b7cdc2..7df78920 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -272,9 +272,9 @@ dependencies = [ [[package]] name = "async-lock" -version = "3.4.0" +version = "3.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18" +checksum = "5fd03604047cee9b6ce9de9f70c6cd540a0520c813cbd49bae61f33ab80ed1dc" dependencies = [ "event-listener 5.4.0", "event-listener-strategy", @@ -1392,7 +1392,7 @@ checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" dependencies = [ "futures-core", "futures-sink", - "spin", + "spin 0.9.8", ] [[package]] @@ -2083,7 +2083,7 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" dependencies = [ - "spin", + "spin 0.9.8", ] [[package]] @@ -3465,6 +3465,15 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spin" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.7.3" @@ -3537,6 +3546,7 @@ dependencies = [ "async-fs", "async-global-executor 3.1.0", "async-io", + "async-lock", "async-std", "async-task", "base64 0.22.1", @@ -3559,13 +3569,14 @@ dependencies = [ "indexmap 2.10.0", "ipnet", "ipnetwork", + "lock_api", "log", "mac_address", "memchr", "native-tls", - "parking_lot", "percent-encoding", "pin-project-lite", + "rand", "rust_decimal", "rustls", "rustls-native-certs", @@ -3574,6 +3585,7 @@ dependencies = [ "sha2", "smallvec", "smol", + "spin 0.10.0", "sqlx", "thiserror 2.0.17", "time", diff --git a/Cargo.toml b/Cargo.toml index 25284eec..c3cb86c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -192,6 +192,7 @@ cfg-if = "1.0.0" thiserror = { version = "2.0.17", default-features = false, features = ["std"] } dotenvy = { version = "0.15.7", default-features = false } ease-off = "0.1.6" +rand = "0.8.5" # Runtimes [workspace.dependencies.async-global-executor] diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index e9f70859..dc03c192 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -20,11 +20,12 @@ any = [] json = ["serde", "serde_json"] # for conditional compilation -_rt-async-global-executor = ["async-global-executor", "_rt-async-io", "_rt-async-task"] -_rt-async-io = ["async-io", "async-fs"] # see note at async-fs declaration -_rt-async-std = ["async-std", "_rt-async-io", "ease-off/async-io-2"] +_rt-async-global-executor = ["async-global-executor", "_rt-async-io", "_rt-async-lock", "_rt-async-task"] +_rt-async-io = ["async-io", "async-fs", "ease-off/async-io-2"] # see note at async-fs declaration +_rt-async-lock = ["async-lock"] +_rt-async-std = ["async-std", "_rt-async-io", "_rt-async-lock"] _rt-async-task = ["async-task"] -_rt-smol = ["smol", "_rt-async-io", "_rt-async-task"] +_rt-smol = ["smol", "_rt-async-io", "_rt-async-lock", "_rt-async-task"] _rt-tokio = ["tokio", "tokio-stream", "ease-off/tokio"] _tls-native-tls = ["native-tls"] _tls-rustls-aws-lc-rs = ["_tls-rustls", "rustls/aws-lc-rs", "webpki-roots"] @@ -72,6 +73,7 @@ uuid = { workspace = true, optional = true } # work around bug in async-fs 2.0.0, which references futures-lite dependency wrongly, see https://github.com/launchbadge/sqlx/pull/3791#issuecomment-3043363281 async-fs = { version = "2.1", optional = true } async-io = { version = 
"2.4.1", optional = true } +async-lock = { version = "3.4.1", optional = true } async-task = { version = "4.7.1", optional = true } base64 = { version = "0.22.0", default-features = false, features = ["std"] } @@ -101,9 +103,10 @@ indexmap = "2.0" event-listener = "5.2.0" hashbrown = "0.16.0" +rand.workspace = true thiserror.workspace = true -ease-off = { workspace = true, features = ["futures"] } +ease-off = { workspace = true, default-features = false } pin-project-lite = "0.2.14" # N.B. we don't actually utilize spinlocks, we just need a `Mutex` type with a few requirements: diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 00b1a640..8dfcc92a 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -1,12 +1,12 @@ //! Types for working with errors produced by SQLx. +use crate::database::Database; use std::any::type_name; use std::borrow::Cow; use std::error::Error as StdError; use std::fmt::Display; use std::io; - -use crate::database::Database; +use std::sync::Arc; use crate::type_info::TypeInfo; use crate::types::Type; @@ -104,7 +104,10 @@ pub enum Error { /// /// [`Pool::acquire`]: crate::pool::Pool::acquire #[error("pool timed out while waiting for an open connection")] - PoolTimedOut, + PoolTimedOut { + #[source] + last_connect_error: Option>, + }, /// [`Pool::close`] was called while we were waiting in [`Pool::acquire`]. /// diff --git a/sqlx-core/src/pool/connect.rs b/sqlx-core/src/pool/connect.rs index 63c87987..52920c6a 100644 --- a/sqlx-core/src/pool/connect.rs +++ b/sqlx-core/src/pool/connect.rs @@ -1,20 +1,30 @@ use crate::connection::{ConnectOptions, Connection}; use crate::database::Database; -use crate::pool::connection::Floating; +use crate::pool::connection::ConnectionInner; use crate::pool::inner::PoolInner; -use crate::pool::PoolConnection; +use crate::pool::{Pool, PoolConnection}; use crate::rt::JoinHandle; -use crate::Error; +use crate::{rt, Error}; use ease_off::EaseOff; -use event_listener::{listener, Event}; +use event_listener::{listener, Event, EventListener}; use std::fmt::{Display, Formatter}; use std::future::Future; use std::ptr; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; +use std::sync::{Arc, Mutex, RwLock}; use std::time::Instant; +use crate::pool::shard::DisconnectedSlot; +#[cfg(doc)] +use crate::pool::PoolOptions; +use crate::sync::{AsyncMutex, AsyncMutexGuard}; +use ease_off::core::EaseOffCore; use std::io; +use std::ops::ControlFlow; +use std::pin::{pin, Pin}; +use std::task::{ready, Context, Poll}; + +const EASE_OFF: EaseOffCore = ease_off::Options::new().into_core(); /// Custom connect callback for [`Pool`][crate::pool::Pool]. /// @@ -197,7 +207,7 @@ pub trait PoolConnector: Send + Sync + 'static { /// If this method returns an error that is known to be retryable, it is called again /// in an exponential backoff loop. Retryable errors include, but are not limited to: /// - /// * [`io::ErrorKind::ConnectionRefused`] + /// * [`io::Error`] /// * Database errors for which /// [`is_retryable_connect_error`][crate::error::DatabaseError::is_retryable_connect_error] /// returns `true`. @@ -205,6 +215,8 @@ pub trait PoolConnector: Send + Sync + 'static { /// This error kind is not returned internally and is designed to allow this method to return /// arbitrary error types not otherwise supported. /// + /// This behavior may be customized by overriding [`Self::connect_with_control_flow()`]. 
+    ///
     /// Manual implementations of this method may also use the signature:
     /// ```rust,ignore
     /// async fn connect(
@@ -218,6 +230,54 @@ pub trait PoolConnector<DB: Database>: Send + Sync + 'static {
         &self,
         meta: PoolConnectMetadata,
     ) -> impl Future<Output = crate::Result<DB::Connection>> + Send + '_;
+
+    /// Open a connection for the pool, or indicate what to do on an error.
+    ///
+    /// This method may return one of the following:
+    ///
+    /// * `ControlFlow::Break(Ok(_))` with a successfully established connection.
+    /// * `ControlFlow::Break(Err(_))` with an error to immediately return to the caller.
+    /// * `ControlFlow::Continue(_)` with a retryable error.
+    ///   The pool will call this method again in an exponential backoff loop until it succeeds,
+    ///   or the [connect timeout][PoolOptions::connect_timeout]
+    ///   or [acquire timeout][PoolOptions::acquire_timeout] is reached.
+    ///
+    /// # Default Implementation
+    /// This method has a provided implementation by default which calls [`Self::connect()`]
+    /// and then returns `ControlFlow::Continue` if the error is any of the following:
+    ///
+    /// * [`io::Error`]
+    /// * Database errors for which
+    ///   [`is_retryable_connect_error`][crate::error::DatabaseError::is_retryable_connect_error]
+    ///   returns `true`.
+    /// * [`Error::PoolConnector`] with `retryable: true`.
+    ///   This error kind is not returned internally and is designed to allow this method to
+    ///   return arbitrary error types not otherwise supported.
+    ///
+    /// A custom backoff loop may be implemented by overriding this method and retrying internally,
+    /// only returning `ControlFlow::Break` if/when an error should be propagated out to the caller.
+    ///
+    /// If this method is overridden and does not call [`Self::connect()`], then the implementation
+    /// of the latter can be a stub. It is not called internally.
+    fn connect_with_control_flow(
+        &self,
+        meta: PoolConnectMetadata,
+    ) -> impl Future<Output = ControlFlow<crate::Result<DB::Connection>, Error>> + Send + '_ {
+        async move {
+            match self.connect(meta).await {
+                Err(err @ Error::Io(_)) => ControlFlow::Continue(err),
+                Err(Error::Database(dbe)) if dbe.is_retryable_connect_error() => {
+                    ControlFlow::Continue(Error::Database(dbe))
+                }
+                Err(
+                    err @ Error::PoolConnector {
+                        retryable: true, ..
+                    },
+                ) => ControlFlow::Continue(err),
+                res => ControlFlow::Break(res),
+            }
+        }
+    }
 }

 /// # Note: Future Changes (FIXME)
@@ -260,8 +320,12 @@ pub struct PoolConnectMetadata {
     ///
     /// May be used for reporting purposes, or to implement a custom backoff.
     pub start: Instant,
+
+    /// The deadline (`start` plus the [connect timeout][PoolOptions::connect_timeout], if set).
+    pub deadline: Option<Instant>,
+
     /// The number of attempts that have occurred so far.
-    pub num_attempts: usize,
+    pub num_attempts: u32,
     /// The current size of the pool.
     pub pool_size: usize,
     /// The ID of the connection, unique for the pool.
@@ -271,7 +335,12 @@ pub struct DynConnector<DB: Database> {
     // We want to spawn the connection attempt as a task anyway
     connect: Box<
-        dyn Fn(ConnectionId, ConnectPermit<DB>) -> JoinHandle<crate::Result<PoolConnection<DB>>>
+        dyn Fn(
+                Pool<DB>,
+                ConnectionId,
+                DisconnectedSlot<ConnectionInner<DB>>,
+                Arc<ConnectTaskShared>,
+            ) -> ConnectTask<DB>
             + Send
             + Sync
             + 'static,
@@ -283,18 +352,90 @@ impl<DB: Database> DynConnector<DB> {
         let connector = Arc::new(connector);

         Self {
-            connect: Box::new(move |id, permit| {
-                crate::rt::spawn(connect_with_backoff(id, permit, connector.clone()))
+            connect: Box::new(move |pool, id, guard, shared| {
+                ConnectTask::spawn(pool, id, guard, connector.clone(), shared)
             }),
         }
     }

     pub fn connect(
         &self,
+        pool: Pool<DB>,
         id: ConnectionId,
-        permit: ConnectPermit<DB>,
-    ) -> JoinHandle<crate::Result<PoolConnection<DB>>> {
-        (self.connect)(id, permit)
+        slot: DisconnectedSlot<ConnectionInner<DB>>,
+        shared: Arc<ConnectTaskShared>,
+    ) -> ConnectTask<DB> {
+        (self.connect)(pool, id, slot, shared)
+    }
+}
+
+pub struct ConnectTask<DB: Database> {
+    handle: JoinHandle<crate::Result<PoolConnection<DB>>>,
+    shared: Arc<ConnectTaskShared>,
+}
+
+pub struct ConnectTaskShared {
+    cancel_event: Event,
+    // Using the normal `std::sync::Mutex` because the critical sections are very short;
+    // we only hold the lock long enough to insert or take the value.
+    last_error: Mutex<Option<Error>>,
+}
+
+impl<DB: Database> Future for ConnectTask<DB> {
+    type Output = crate::Result<PoolConnection<DB>>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        Pin::new(&mut self.handle).poll(cx)
+    }
+}
+
+impl<DB: Database> ConnectTask<DB> {
+    fn spawn(
+        pool: Pool<DB>,
+        id: ConnectionId,
+        guard: DisconnectedSlot<ConnectionInner<DB>>,
+        connector: Arc<impl PoolConnector<DB>>,
+        shared: Arc<ConnectTaskShared>,
+    ) -> Self {
+        let handle = crate::rt::spawn(connect_with_backoff(
+            pool,
+            id,
+            connector,
+            guard,
+            shared.clone(),
+        ));
+
+        Self { handle, shared }
+    }
+
+    pub fn cancel(&self) -> Option<Error> {
+        self.shared.cancel_event.notify(1);
+
+        self.shared
+            .last_error
+            .lock()
+            .unwrap_or_else(|e| e.into_inner())
+            .take()
+    }
+}
+
+impl ConnectTaskShared {
+    pub fn new_arc() -> Arc<Self> {
+        Arc::new(Self {
+            cancel_event: Event::new(),
+            last_error: Mutex::new(None),
+        })
+    }
+
+    pub fn take_error(&self) -> Option<Error> {
+        self.last_error
+            .lock()
+            .unwrap_or_else(|e| e.into_inner())
+            .take()
+    }
+
+    fn put_error(&self, error: Error) {
+        *self.last_error.lock().unwrap_or_else(|e| e.into_inner()) = Some(error);
+    }
 }

@@ -308,6 +449,14 @@ pub struct ConnectionCounter {
 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
 pub struct ConnectionId(usize);

+impl ConnectionId {
+    pub(super) fn next() -> ConnectionId {
+        static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
+
+        ConnectionId(NEXT_ID.fetch_add(1, Ordering::AcqRel))
+    }
+}
+
 impl ConnectionCounter {
     pub fn new() -> Self {
         Self {
@@ -456,41 +605,131 @@ impl Display for ConnectionId {
     err
 )]
 async fn connect_with_backoff<DB: Database>(
+    pool: Pool<DB>,
     connection_id: ConnectionId,
-    permit: ConnectPermit<DB>,
     connector: Arc<impl PoolConnector<DB>>,
+    slot: DisconnectedSlot<ConnectionInner<DB>>,
+    shared: Arc<ConnectTaskShared>,
 ) -> crate::Result<PoolConnection<DB>> {
-    if permit.pool().is_closed() {
-        return Err(Error::PoolClosed);
-    }
+    listener!(pool.0.on_closed => closed);
+    listener!(shared.cancel_event => cancelled);

-    let mut ease_off = EaseOff::start_timeout(permit.pool().options.connect_timeout);
+    let start = Instant::now();
+    let deadline = pool
+        .0
+        .options
+        .connect_timeout
+        .and_then(|timeout| start.checked_add(timeout));

-    for attempt in 1usize.. {
+    for attempt in 1u32..
{ let meta = PoolConnectMetadata { - start: ease_off.started_at(), + start, + deadline, num_attempts: attempt, - pool_size: permit.pool().size(), + pool_size: pool.size(), connection_id, }; - let conn = ease_off - .try_async(connector.connect(meta)) - .await - .or_retry_if(|e| can_retry_error(e.inner()))?; + tracing::trace!( + target: "sqlx::pool::connect", + %connection_id, + attempt, + elapsed_seconds=start.elapsed().as_secs_f64(), + "beginning connection attempt" + ); - if let Some(conn) = conn { - return Ok(Floating::new_live(conn, connection_id, permit).reattach()); + let res = connector.connect_with_control_flow(meta).await; + + let now = Instant::now(); + let elapsed = now.duration_since(start); + let elapsed_seconds = elapsed.as_secs_f64(); + + match res { + ControlFlow::Break(Ok(conn)) => { + tracing::trace!( + target: "sqlx::pool::connect", + %connection_id, + attempt, + elapsed_seconds, + "connection established", + ); + + return Ok(PoolConnection::new( + slot.put(ConnectionInner { + raw: conn, + id: connection_id, + created_at: now, + last_released_at: now, + }), + pool.0.clone(), + )); + } + ControlFlow::Break(Err(e)) => { + tracing::warn!( + target: "sqlx::pool::connect", + %connection_id, + attempt, + elapsed_seconds, + error=?e, + "error connecting to database", + ); + + return Err(e); + } + ControlFlow::Continue(e) => { + tracing::warn!( + target: "sqlx::pool::connect", + %connection_id, + attempt, + elapsed_seconds, + error=?e, + "error connecting to database; retrying", + ); + + shared.put_error(e); + } + } + + let wait = EASE_OFF + .nth_retry_at(attempt, now, deadline, &mut rand::thread_rng()) + .map_err(|_| { + Error::PoolTimedOut { + // This should be populated by the caller + last_connect_error: None, + } + })?; + + if let Some(wait) = wait { + tracing::trace!( + target: "sqlx::pool::connect", + %connection_id, + attempt, + elapsed_seconds, + "waiting for {:?}", + wait.duration_since(now), + ); + + let mut sleep = pin!(rt::sleep_until(wait)); + + std::future::poll_fn(|cx| { + if let Poll::Ready(()) = Pin::new(&mut closed).poll(cx) { + return Poll::Ready(Err(Error::PoolClosed)); + } + + if let Poll::Ready(()) = Pin::new(&mut cancelled).poll(cx) { + return Poll::Ready(Err(Error::PoolTimedOut { + last_connect_error: None, + })); + } + + ready!(sleep.as_mut().poll(cx)); + Poll::Ready(Ok(())) + }) + .await?; } } - Err(Error::PoolTimedOut) -} - -fn can_retry_error(e: &Error) -> bool { - match e { - Error::Io(e) if e.kind() == io::ErrorKind::ConnectionRefused => true, - Error::Database(e) => e.is_retryable_connect_error(), - _ => false, - } + Err(Error::PoolTimedOut { + last_connect_error: None, + }) } diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index 2ab315fa..8d115818 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -9,8 +9,10 @@ use crate::database::Database; use crate::error::Error; use super::inner::{is_beyond_max_lifetime, PoolInner}; -use crate::pool::connect::{ConnectPermit, ConnectionId}; +use crate::pool::connect::{ConnectPermit, ConnectTaskShared, ConnectionId}; use crate::pool::options::PoolConnectionMetadata; +use crate::pool::shard::{ConnectedSlot, DisconnectedSlot}; +use crate::pool::Pool; use crate::rt; const RETURN_TO_POOL_TIMEOUT: Duration = Duration::from_secs(5); @@ -20,26 +22,16 @@ const CLOSE_ON_DROP_TIMEOUT: Duration = Duration::from_secs(5); /// /// Will be returned to the pool on-drop. 
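+///
+/// A `PoolConnection` dereferences to the underlying database connection, so it can be
+/// passed directly to the query API; a sketch of typical usage:
+///
+/// ```rust,ignore
+/// let mut conn = pool.acquire().await?;
+/// sqlx::query("SELECT 1").execute(&mut *conn).await?;
+/// // Dropped here: the connection is returned to the pool by a background task.
+/// ```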
 pub struct PoolConnection<DB: Database> {
-    live: Option<Live<DB>>,
-    close_on_drop: bool,
+    conn: Option<ConnectedSlot<ConnectionInner<DB>>>,
     pub(crate) pool: Arc<PoolInner<DB>>,
+    close_on_drop: bool,
 }

-pub(super) struct Live<DB: Database> {
+pub(super) struct ConnectionInner<DB: Database> {
     pub(super) raw: DB::Connection,
     pub(super) id: ConnectionId,
     pub(super) created_at: Instant,
-}
-
-pub(super) struct Idle<DB: Database> {
-    pub(super) live: Live<DB>,
-    pub(super) idle_since: Instant,
-}
-
-/// RAII wrapper for connections being handled by functions that may drop them
-pub(super) struct Floating<DB: Database, C> {
-    pub(super) inner: C,
-    pub(super) permit: ConnectPermit<DB>,
+    pub(super) last_released_at: Instant,
 }

 const EXPECT_MSG: &str = "BUG: inner connection already taken!";

@@ -48,7 +40,7 @@ impl<DB: Database> Debug for PoolConnection<DB> {
     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
         f.debug_struct("PoolConnection")
             .field("database", &DB::NAME)
-            .field("id", &self.live.as_ref().map(|live| live.id))
+            .field("id", &self.conn.as_ref().map(|live| live.id))
             .finish()
     }
 }

@@ -57,13 +49,13 @@ impl<DB: Database> Deref for PoolConnection<DB> {
     type Target = DB::Connection;

     fn deref(&self) -> &Self::Target {
-        &self.live.as_ref().expect(EXPECT_MSG).raw
+        &self.conn.as_ref().expect(EXPECT_MSG).raw
     }
 }

 impl<DB: Database> DerefMut for PoolConnection<DB> {
     fn deref_mut(&mut self) -> &mut Self::Target {
-        &mut self.live.as_mut().expect(EXPECT_MSG).raw
+        &mut self.conn.as_mut().expect(EXPECT_MSG).raw
     }
 }

@@ -80,6 +72,14 @@ impl<DB: Database> AsMut<DB::Connection> for PoolConnection<DB> {
 }

 impl<DB: Database> PoolConnection<DB> {
+    pub(super) fn new(live: ConnectedSlot<ConnectionInner<DB>>, pool: Arc<PoolInner<DB>>) -> Self {
+        Self {
+            conn: Some(live),
+            close_on_drop: false,
+            pool,
+        }
+    }
+
     /// Close this connection, allowing the pool to open a replacement.
     ///
     /// Equivalent to calling [`.detach()`] then [`.close()`], but the connection permit is retained
@@ -88,8 +88,8 @@ impl<DB: Database> PoolConnection<DB> {
     /// [`.detach()`]: PoolConnection::detach
     /// [`.close()`]: Connection::close
     pub async fn close(mut self) -> Result<(), Error> {
-        let floating = self.take_live().float(self.pool.clone());
-        floating.inner.raw.close().await
+        let (res, _slot) = close(self.take_conn()).await;
+        res
     }

     /// Close this connection on-drop, instead of returning it to the pool.
@@ -115,7 +116,8 @@ impl<DB: Database> PoolConnection<DB> {
     /// [`max_connections`]: crate::pool::PoolOptions::max_connections
     /// [`min_connections`]: crate::pool::PoolOptions::min_connections
     pub fn detach(mut self) -> DB::Connection {
-        self.take_live().float(self.pool.clone()).detach()
+        let (conn, _slot) = ConnectedSlot::take(self.take_conn());
+        conn.raw
     }

     /// Detach this connection from the pool, treating it as permanently checked-out.
@@ -124,15 +125,13 @@ impl<DB: Database> PoolConnection<DB> {
     ///
     /// If you don't want to impact the pool's capacity, use [`.detach()`][Self::detach] instead.
     pub fn leak(mut self) -> DB::Connection {
-        self.take_live().raw
+        let (conn, slot) = ConnectedSlot::take(self.take_conn());
+        DisconnectedSlot::leak(slot);
+        conn.raw
     }

-    fn take_live(&mut self) -> Live<DB> {
-        self.live.take().expect(EXPECT_MSG)
-    }
-
-    pub(super) fn into_floating(mut self) -> Floating<DB, Live<DB>> {
-        self.take_live().float(self.pool.clone())
+    fn take_conn(&mut self) -> ConnectedSlot<ConnectionInner<DB>> {
+        self.conn.take().expect(EXPECT_MSG)
     }

     /// Test the connection to make sure it is still live before returning it to the pool.
@@ -140,48 +139,30 @@ impl<DB: Database> PoolConnection<DB> {
     /// This effectively runs the drop handler eagerly instead of spawning a task to do it.
#[doc(hidden)] pub fn return_to_pool(&mut self) -> impl Future + Send + 'static { - // float the connection in the pool before we move into the task - // in case the returned `Future` isn't executed, like if it's spawned into a dying runtime - // https://github.com/launchbadge/sqlx/issues/1396 - // Type hints seem to be broken by `Option` combinators in IntelliJ Rust right now (6/22). - let floating: Option>> = - self.live.take().map(|live| live.float(self.pool.clone())); - + let conn = self.conn.take(); let pool = self.pool.clone(); async move { - let returned_to_pool = if let Some(floating) = floating { - rt::timeout(RETURN_TO_POOL_TIMEOUT, floating.return_to_pool()) - .await - .unwrap_or(false) - } else { - false + let Some(conn) = conn else { + return; }; - if !returned_to_pool { - pool.min_connections_maintenance(None).await; - } + rt::timeout(RETURN_TO_POOL_TIMEOUT, return_to_pool(conn, &pool)) + .await + // Dropping of the `slot` will check if the connection must be re-established + // but only after trying to pass it to a task that needs it. + .ok(); } } fn take_and_close(&mut self) -> impl Future + Send + 'static { - // float the connection in the pool before we move into the task - // in case the returned `Future` isn't executed, like if it's spawned into a dying runtime - // https://github.com/launchbadge/sqlx/issues/1396 - // Type hints seem to be broken by `Option` combinators in IntelliJ Rust right now (6/22). - let floating = self.live.take().map(|live| live.float(self.pool.clone())); - - let pool = self.pool.clone(); + let conn = self.conn.take(); async move { - if let Some(floating) = floating { + if let Some(conn) = conn { // Don't hold the connection forever if it hangs while trying to close - crate::rt::timeout(CLOSE_ON_DROP_TIMEOUT, floating.close()) - .await - .ok(); + rt::timeout(CLOSE_ON_DROP_TIMEOUT, close(conn)).await.ok(); } - - pool.min_connections_maintenance(None).await; } } } @@ -214,205 +195,21 @@ impl Drop for PoolConnection { } // We still need to spawn a task to maintain `min_connections`. - if self.live.is_some() || self.pool.options.min_connections > 0 { + if self.conn.is_some() || self.pool.options.min_connections > 0 { crate::rt::spawn(self.return_to_pool()); } } } -impl Live { - pub fn float(self, pool: Arc>) -> Floating { - Floating { - inner: self, - // create a new guard from a previously leaked permit - permit: ConnectPermit::float_existing(pool), - } - } - - pub fn into_idle(self) -> Idle { - Idle { - live: self, - idle_since: Instant::now(), - } - } -} - -impl Deref for Idle { - type Target = Live; - - fn deref(&self) -> &Self::Target { - &self.live - } -} - -impl DerefMut for Idle { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.live - } -} - -impl Floating> { - pub fn new_live(conn: DB::Connection, id: ConnectionId, permit: ConnectPermit) -> Self { - Self { - inner: Live { - raw: conn, - id, - created_at: Instant::now(), - }, - permit, - } - } - - pub fn reattach(self) -> PoolConnection { - let Floating { inner, permit } = self; - - let pool = Arc::clone(permit.pool()); - - permit.consume(); - PoolConnection { - live: Some(inner), - close_on_drop: false, - pool, - } - } - - pub fn release(self) { - self.permit.pool().clone().release(self); - } - - /// Return the connection to the pool. - /// - /// Returns `true` if the connection was successfully returned, `false` if it was closed. - async fn return_to_pool(mut self) -> bool { - // Immediately close the connection. 
- if self.permit.pool().is_closed() { - self.close().await; - return false; - } - - // If the connection is beyond max lifetime, close the connection and - // immediately create a new connection - if is_beyond_max_lifetime(&self.inner, &self.permit.pool().options) { - self.close().await; - return false; - } - - if let Some(test) = &self.permit.pool().options.after_release { - let meta = self.metadata(); - match (test)(&mut self.inner.raw, meta).await { - Ok(true) => (), - Ok(false) => { - self.close().await; - return false; - } - Err(error) => { - tracing::warn!(%error, "error from `after_release`"); - // Connection is broken, don't try to gracefully close as - // something weird might happen. - self.close_hard().await; - return false; - } - } - } - - // test the connection on-release to ensure it is still viable, - // and flush anything time-sensitive like transaction rollbacks - // if an Executor future/stream is dropped during an `.await` call, the connection - // is likely to be left in an inconsistent state, in which case it should not be - // returned to the pool; also of course, if it was dropped due to an error - // this is simply a band-aid as SQLx-next connections should be able - // to recover from cancellations - if let Err(error) = self.raw.ping().await { - tracing::warn!( - %error, - "error occurred while testing the connection on-release", - ); - - // Connection is broken, don't try to gracefully close. - self.close_hard().await; - false - } else { - // if the connection is still viable, release it to the pool - self.release(); - true - } - } - - pub async fn close(self) { - // This isn't used anywhere that we care about the return value - let _ = self.inner.raw.close().await; - - // `guard` is dropped as intended - } - - pub async fn close_hard(self) { - let _ = self.inner.raw.close_hard().await; - } - - pub fn detach(self) -> DB::Connection { - self.inner.raw - } - - pub fn into_idle(self) -> Floating> { - Floating { - inner: self.inner.into_idle(), - permit: self.permit, - } - } - +impl ConnectionInner { pub fn metadata(&self) -> PoolConnectionMetadata { PoolConnectionMetadata { age: self.created_at.elapsed(), idle_for: Duration::ZERO, } } -} -impl Floating> { - pub fn from_idle(idle: Idle, pool: Arc>) -> Self { - Self { - inner: idle, - permit: ConnectPermit::float_existing(pool), - } - } - - pub async fn ping(&mut self) -> Result<(), Error> { - self.live.raw.ping().await - } - - pub fn into_live(self) -> Floating> { - Floating { - inner: self.inner.live, - permit: self.permit, - } - } - - pub async fn close(self) -> (ConnectionId, ConnectPermit) { - let connection_id = self.inner.live.id; - - tracing::debug!(%connection_id, "closing connection (gracefully)"); - - if let Err(error) = self.inner.live.raw.close().await { - tracing::debug!( - %connection_id, - %error, - "error occurred while closing the pool connection" - ); - } - (connection_id, self.permit) - } - - pub async fn close_hard(self) -> (ConnectionId, ConnectPermit) { - let connection_id = self.inner.live.id; - - tracing::debug!(%connection_id, "closing connection (hard)"); - - let _ = self.inner.live.raw.close_hard().await; - - (connection_id, self.permit) - } - - pub fn metadata(&self) -> PoolConnectionMetadata { + pub fn idle_metadata(&self) -> PoolConnectionMetadata { // Use a single `now` value for consistency. 
let now = Instant::now(); @@ -420,21 +217,113 @@ impl Floating> { // NOTE: the receiver is the later `Instant` and the arg is the earlier // https://github.com/launchbadge/sqlx/issues/1912 age: now.saturating_duration_since(self.created_at), - idle_for: now.saturating_duration_since(self.idle_since), + idle_for: now.saturating_duration_since(self.last_released_at), } } } -impl Deref for Floating { - type Target = C; +pub(crate) async fn close( + conn: ConnectedSlot>, +) -> (Result<(), Error>, DisconnectedSlot>) { + let connection_id = conn.id; - fn deref(&self) -> &Self::Target { - &self.inner - } + tracing::debug!(target: "sqlx::pool", %connection_id, "closing connection (gracefully)"); + + let (conn, slot) = ConnectedSlot::take(conn); + + let res = conn.raw.close().await.inspect_err(|error| { + tracing::debug!( + target: "sqlx::pool", + %connection_id, + %error, + "error occurred while closing the pool connection" + ); + }); + + (res, slot) +} +pub(crate) async fn close_hard( + conn: ConnectedSlot>, +) -> (Result<(), Error>, DisconnectedSlot>) { + let connection_id = conn.id; + + tracing::debug!( + target: "sqlx::pool", + %connection_id, + "closing connection (forcefully)" + ); + + let (conn, slot) = ConnectedSlot::take(conn); + + let res = conn.raw.close_hard().await.inspect_err(|error| { + tracing::debug!( + target: "sqlx::pool", + %connection_id, + %error, + "error occurred while closing the pool connection" + ); + }); + + (res, slot) } -impl DerefMut for Floating { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.inner +/// Return the connection to the pool. +/// +/// Returns `true` if the connection was successfully returned, `false` if it was closed. +async fn return_to_pool( + mut conn: ConnectedSlot>, + pool: &PoolInner, +) -> Result<(), DisconnectedSlot>> { + // Immediately close the connection. + if pool.is_closed() { + let (_res, slot) = close(conn).await; + return Err(slot); + } + + // If the connection is beyond max lifetime, close the connection and + // immediately create a new connection + if is_beyond_max_lifetime(&conn, &pool.options) { + let (_res, slot) = close(conn).await; + return Err(slot); + } + + if let Some(test) = &pool.options.after_release { + let meta = conn.metadata(); + match (test)(&mut conn.raw, meta).await { + Ok(true) => (), + Ok(false) => { + let (_res, slot) = close(conn).await; + return Err(slot); + } + Err(error) => { + tracing::warn!(%error, "error from `after_release`"); + // Connection is broken, don't try to gracefully close as + // something weird might happen. + let (_res, slot) = close_hard(conn).await; + return Err(slot); + } + } + } + + // test the connection on-release to ensure it is still viable, + // and flush anything time-sensitive like transaction rollbacks + // if an Executor future/stream is dropped during an `.await` call, the connection + // is likely to be left in an inconsistent state, in which case it should not be + // returned to the pool; also of course, if it was dropped due to an error + // this is simply a band-aid as SQLx-next connections should be able + // to recover from cancellations + if let Err(error) = conn.raw.ping().await { + tracing::warn!( + %error, + "error occurred while testing the connection on-release", + ); + + // Connection is broken, don't try to gracefully close. 
+ let (_res, slot) = close_hard(conn).await; + Err(slot) + } else { + // if the connection is still viable, release it to the pool + drop(conn); + Ok(()) } } diff --git a/sqlx-core/src/pool/idle.rs b/sqlx-core/src/pool/idle.rs index 8b07b8e7..602ed3c5 100644 --- a/sqlx-core/src/pool/idle.rs +++ b/sqlx-core/src/pool/idle.rs @@ -1,6 +1,6 @@ use crate::connection::Connection; use crate::database::Database; -use crate::pool::connection::{Floating, Idle, Live}; +use crate::pool::connection::{Floating, Idle, ConnectionInner}; use crate::pool::inner::PoolInner; use crossbeam_queue::ArrayQueue; use event_listener::Event; @@ -71,7 +71,7 @@ impl IdleQueue { }) } - pub fn release(&self, conn: Floating>) { + pub fn release(&self, conn: Floating>) { let Floating { inner: conn, permit, diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index e3aee6a3..af9229d4 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -1,32 +1,35 @@ -use super::connection::{Floating, Idle, Live}; +use super::connection::ConnectionInner; use crate::database::Database; use crate::error::Error; -use crate::pool::{CloseEvent, Pool, PoolConnection, PoolConnector, PoolOptions}; +use crate::pool::{connection, CloseEvent, Pool, PoolConnection, PoolConnector, PoolOptions}; use std::cmp; use std::future::Future; -use std::pin::pin; +use std::pin::{pin, Pin}; +use std::rc::Weak; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use std::task::ready; +use std::task::{ready, Poll}; +use crate::connection::Connection; use crate::logger::private_level_filter_to_trace_level; -use crate::pool::connect::{ConnectPermit, ConnectionCounter, ConnectionId, DynConnector}; -use crate::pool::idle::IdleQueue; -use crate::pool::shard::Sharded; +use crate::pool::connect::{ + ConnectPermit, ConnectTask, ConnectTaskShared, ConnectionCounter, ConnectionId, DynConnector, +}; +use crate::pool::shard::{ConnectedSlot, DisconnectedSlot, Sharded}; use crate::rt::JoinHandle; use crate::{private_tracing_dynamic_event, rt}; use either::Either; +use futures_core::FusedFuture; use futures_util::future::{self, OptionFuture}; -use futures_util::FutureExt; +use futures_util::{stream, FutureExt, TryStreamExt}; use std::time::{Duration, Instant}; use tracing::Level; pub(crate) struct PoolInner { pub(super) connector: DynConnector, pub(super) counter: ConnectionCounter, - pub(super) sharded: Sharded, - pub(super) idle: IdleQueue, + pub(super) sharded: Sharded>, is_closed: AtomicBool, pub(super) on_closed: event_listener::Event, pub(super) options: PoolOptions, @@ -39,19 +42,38 @@ impl PoolInner { options: PoolOptions, connector: impl PoolConnector, ) -> Arc { - let pool = Self { - connector: DynConnector::new(connector), - counter: ConnectionCounter::new(), - sharded: Sharded::new(options.max_connections, options.shards), - idle: IdleQueue::new(options.fair, options.max_connections), - is_closed: AtomicBool::new(false), - on_closed: event_listener::Event::new(), - acquire_time_level: private_level_filter_to_trace_level(options.acquire_time_level), - acquire_slow_level: private_level_filter_to_trace_level(options.acquire_slow_level), - options, - }; + let pool = Arc::::new_cyclic(|pool_weak| { + let pool_weak = pool_weak.clone(); - let pool = Arc::new(pool); + let reconnect = move |slot| { + let Some(pool) = pool_weak.upgrade() else { + return; + }; + + pool.connector.connect( + Pool(pool.clone()), + ConnectionId::next(), + slot, + ConnectTaskShared::new_arc(), + ); + }; + + Self { + connector: 
DynConnector::new(connector), + counter: ConnectionCounter::new(), + sharded: Sharded::new( + options.max_connections, + options.shards, + options.min_connections, + reconnect, + ), + is_closed: AtomicBool::new(false), + on_closed: event_listener::Event::new(), + acquire_time_level: private_level_filter_to_trace_level(options.acquire_time_level), + acquire_slow_level: private_level_filter_to_trace_level(options.acquire_slow_level), + options, + } + }); spawn_maintenance_tasks(&pool); @@ -59,11 +81,11 @@ impl PoolInner { } pub(super) fn size(&self) -> usize { - self.counter.connections() + self.sharded.count_connected() } pub(super) fn num_idle(&self) -> usize { - self.idle.len() + self.sharded.count_unlocked(true) } pub(super) fn is_closed(&self) -> bool { @@ -79,23 +101,13 @@ impl PoolInner { self.mark_closed(); // Keep clearing the idle queue as connections are released until the count reaches zero. - async move { - let mut drained = pin!(self.counter.drain()); + self.sharded.drain(|slot| async move { + let (conn, slot) = ConnectedSlot::take(slot); - loop { - let mut acquire_idle = pin!(self.idle.acquire(self)); + let _ = conn.raw.close().await; - // Not using `futures::select!{}` here because it requires a proc-macro dep, - // and frankly it's a little broken. - match future::select(drained.as_mut(), acquire_idle.as_mut()).await { - // *not* `either::Either`; they rolled their own - future::Either::Left(_) => break, - future::Either::Right((idle, _)) => { - idle.close().await; - } - } - } - } + slot + }) } pub(crate) fn close_event(&self) -> CloseEvent { @@ -109,17 +121,12 @@ impl PoolInner { } #[inline] - pub(super) fn try_acquire(self: &Arc) -> Option>> { + pub(super) fn try_acquire(self: &Arc) -> Option>> { if self.is_closed() { return None; } - self.idle.try_acquire(self) - } - - pub(super) fn release(&self, floating: Floating>) { - // `options.after_release` and other checks are in `PoolConnection::return_to_pool()`. - self.idle.release(floating); + self.sharded.try_acquire_connected() } pub(super) async fn acquire(self: &Arc) -> Result, Error> { @@ -131,91 +138,70 @@ impl PoolInner { let mut close_event = pin!(self.close_event()); let mut deadline = pin!(rt::sleep(self.options.acquire_timeout)); - let mut acquire_idle = pin!(self.idle.acquire(self).fuse()); - let mut before_acquire = OptionFuture::from(None); - let mut acquire_connect_permit = pin!(OptionFuture::from(Some( - self.counter.acquire_permit(self).fuse() - ))); - let mut connect = OptionFuture::from(None); - // The internal state machine of `acquire()`. - // - // * The initial state is racing to acquire either an idle connection or a new `ConnectPermit`. - // * If we acquire a `ConnectPermit`, we begin the connection loop (with backoff) - // as implemented by `DynConnector`. - // * If we acquire an idle connection, we then start polling `check_idle_conn()`. - // - // This doesn't quite fit into `select!{}` because the set of futures that may be polled - // at a given time is dynamic, so it's actually simpler to hand-roll it. - let acquired = future::poll_fn(|cx| { - use std::task::Poll::*; + let connect_shared = ConnectTaskShared::new_arc(); - // First check if the pool is already closed, - // or register for a wakeup if it gets closed. 
- if let Ready(()) = close_event.poll_unpin(cx) { - return Ready(Err(Error::PoolClosed)); + let mut acquire_connected = pin!(self.acquire_connected().fuse()); + + let mut acquire_disconnected = pin!(self.sharded.acquire_disconnected().fuse()); + + let mut connect = future::Fuse::terminated(); + + let acquired = std::future::poll_fn(|cx| loop { + if let Poll::Ready(()) = close_event.as_mut().poll(cx) { + return Poll::Ready(Err(Error::PoolClosed)); } - // Then check if our deadline has elapsed, or schedule a wakeup for when that happens. - if let Ready(()) = deadline.poll_unpin(cx) { - return Ready(Err(Error::PoolTimedOut)); + if let Poll::Ready(()) = deadline.as_mut().poll(cx) { + return Poll::Ready(Err(Error::PoolTimedOut { + last_connect_error: connect_shared.take_error().map(Box::new), + })); } - // Attempt to acquire a connection from the idle queue. - if let Ready(idle) = acquire_idle.poll_unpin(cx) { - // If we acquired an idle connection, run any checks that need to be done. - // - // Includes `test_on_acquire` and the `before_acquire` callback, if set. - match finish_acquire(idle) { - // There are checks needed to be done, so they're spawned as a task - // to be cancellation-safe. - Either::Left(check_task) => { - before_acquire = Some(check_task).into(); + if let Poll::Ready(res) = acquire_connected.as_mut().poll(cx) { + match res { + Ok(conn) => { + return Poll::Ready(Ok(conn)); } - // The connection is ready to go. - Either::Right(conn) => { - return Ready(Ok(conn)); - } - } - } - - // Poll the task returned by `finish_acquire` - match ready!(before_acquire.poll_unpin(cx)) { - Some(Ok(conn)) => return Ready(Ok(conn)), - Some(Err((id, permit))) => { - // We don't strictly need to poll `connect` here; all we really want to do - // is to check if it is `None`. But since currently there's no getter for that, - // it doesn't really hurt to just poll it here. - match connect.poll_unpin(cx) { - Ready(None) => { - // If we're not already attempting to connect, - // take the permit returned from closing the connection and - // attempt to open a new one. - connect = Some(self.connector.connect(id, permit)).into(); + Err(slot) => { + if connect.is_terminated() { + connect = self + .connector + .connect( + Pool(self.clone()), + ConnectionId::next(), + slot, + connect_shared.clone(), + ) + .fuse(); } - // `permit` is dropped in these branches, allowing another task to use it - Ready(Some(res)) => return Ready(res), - Pending => (), + + // Try to acquire another connected connection. + acquire_connected.set(self.acquire_connected().fuse()); + continue; } - - // Attempt to acquire another idle connection concurrently to opening a new one. - acquire_idle.set(self.idle.acquire(self).fuse()); - // Annoyingly, `OptionFuture` doesn't fuse to `None` on its own - before_acquire = None.into(); } - None => (), } - if let Ready(Some((id, permit))) = acquire_connect_permit.poll_unpin(cx) { - connect = Some(self.connector.connect(id, permit)).into(); + if let Poll::Ready(slot) = acquire_disconnected.as_mut().poll(cx) { + if connect.is_terminated() { + connect = self + .connector + .connect( + Pool(self.clone()), + ConnectionId::next(), + slot, + connect_shared.clone(), + ) + .fuse(); + } } - if let Ready(Some(res)) = connect.poll_unpin(cx) { - // RFC: suppress errors here? 
- return Ready(res); + if let Poll::Ready(res) = Pin::new(&mut connect).poll(cx) { + return Poll::Ready(res); } - Pending + return Poll::Pending; }) .await?; @@ -245,59 +231,66 @@ impl PoolInner { Ok(acquired) } - /// Try to maintain `min_connections`, returning any errors (including `PoolTimedOut`). - pub async fn try_min_connections(self: &Arc, deadline: Instant) -> Result<(), Error> { - rt::timeout_at(deadline, async { - while self.size() < self.options.min_connections { - // Don't wait for a connect permit. - // - // If no extra permits are available then we shouldn't be trying to spin up - // connections anyway. - let Some((id, permit)) = self.counter.try_acquire_permit(self) else { - return Ok(()); - }; + async fn acquire_connected( + self: &Arc, + ) -> Result, DisconnectedSlot>> { + let connected = self.sharded.acquire_connected().await; - let conn = self.connector.connect(id, permit).await?; + tracing::debug!( + target: "sqlx::pool", + connection_id=%connected.id, + "acquired idle connection" + ); - // We skip `after_release` since the connection was never provided to user code - // besides inside `PollConnector::connect()`, if they override it. - self.release(conn.into_floating()); - } - - Ok(()) - }) - .await - .unwrap_or_else(|_| Err(Error::PoolTimedOut)) + match finish_acquire(self, connected) { + Either::Left(task) => task.await, + Either::Right(conn) => Ok(conn), + } } - /// Attempt to maintain `min_connections`, logging if unable. - pub async fn min_connections_maintenance(self: &Arc, deadline: Option) { - let deadline = deadline.unwrap_or_else(|| { - // Arbitrary default deadline if the caller doesn't care. - Instant::now() + Duration::from_secs(300) - }); + pub(crate) async fn try_min_connections(self: &Arc) -> Result<(), Error> { + stream::iter( + self.sharded + .iter_min_connections() + .map(Result::<_, Error>::Ok), + ) + .try_for_each_concurrent(None, |slot| async move { + let shared = ConnectTaskShared::new_arc(); - match self.try_min_connections(deadline).await { - Ok(()) => (), - Err(Error::PoolClosed) => (), - Err(Error::PoolTimedOut) => { - tracing::debug!("unable to complete `min_connections` maintenance before deadline") + let res = self + .connector + .connect( + Pool(self.clone()), + ConnectionId::next(), + slot, + shared.clone(), + ) + .await; + + match res { + Ok(conn) => { + drop(conn); + Ok(()) + } + Err(Error::PoolTimedOut { .. }) => Err(Error::PoolTimedOut { + last_connect_error: shared.take_error().map(Box::new), + }), + Err(other) => Err(other), } - Err(error) => tracing::debug!(%error, "error while maintaining min_connections"), - } + }) + .await } } impl Drop for PoolInner { fn drop(&mut self) { self.mark_closed(); - self.idle.drain(self); } } /// Returns `true` if the connection has exceeded `options.max_lifetime` if set, `false` otherwise. pub(super) fn is_beyond_max_lifetime( - live: &Live, + live: &ConnectionInner, options: &PoolOptions, ) -> bool { options @@ -306,60 +299,69 @@ pub(super) fn is_beyond_max_lifetime( } /// Returns `true` if the connection has exceeded `options.idle_timeout` if set, `false` otherwise. -fn is_beyond_idle_timeout(idle: &Idle, options: &PoolOptions) -> bool { +fn is_beyond_idle_timeout( + idle: &ConnectionInner, + options: &PoolOptions, +) -> bool { options .idle_timeout - .is_some_and(|timeout| idle.idle_since.elapsed() > timeout) + .is_some_and(|timeout| idle.last_released_at.elapsed() > timeout) } /// Execute `test_before_acquire` and/or `before_acquire` in a background task, if applicable. 
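+/// (Running the checks in a spawned task keeps them cancellation-safe: even if `acquire()` is
+/// dropped mid-check, the check still completes and the connection slot is not lost.)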
/// /// Otherwise, immediately returns the connection. fn finish_acquire( - mut conn: Floating>, + pool: &Arc>, + mut conn: ConnectedSlot>, ) -> Either< - JoinHandle, (ConnectionId, ConnectPermit)>>, + JoinHandle, DisconnectedSlot>>>, PoolConnection, > { - let pool = conn.permit.pool(); - if pool.options.test_before_acquire || pool.options.before_acquire.is_some() { + let pool = pool.clone(); + // Spawn a task so the call may complete even if `acquire()` is cancelled. return Either::Left(rt::spawn(async move { // Check that the connection is still live - if let Err(error) = conn.ping().await { + if let Err(error) = conn.raw.ping().await { // an error here means the other end has hung up or we lost connectivity // either way we're fine to just discard the connection // the error itself here isn't necessarily unexpected so WARN is too strong - tracing::info!(%error, "ping on idle connection returned error"); + tracing::info!(%error, connection_id=%conn.id, "ping on idle connection returned error"); + // connection is broken so don't try to close nicely - return Err(conn.close_hard().await); + let (_res, slot) = connection::close_hard(conn).await; + return Err(slot); } - if let Some(test) = &conn.permit.pool().options.before_acquire { - let meta = conn.metadata(); - match test(&mut conn.inner.live.raw, meta).await { + if let Some(test) = &pool.options.before_acquire { + let meta = conn.idle_metadata(); + match test(&mut conn.raw, meta).await { Ok(false) => { // connection was rejected by user-defined hook, close nicely - return Err(conn.close().await); + let (_res, slot) = connection::close(conn).await; + return Err(slot); } Err(error) => { tracing::warn!(%error, "error from `before_acquire`"); + // connection is broken so don't try to close nicely - return Err(conn.close_hard().await); + let (_res, slot) = connection::close_hard(conn).await; + return Err(slot); } Ok(true) => {} } } - Ok(conn.into_live().reattach()) + Ok(PoolConnection::new(conn, pool)) })); } // No checks are configured, return immediately. - Either::Right(conn.into_live().reattach()) + Either::Right(PoolConnection::new(conn, pool.clone())) } fn spawn_maintenance_tasks(pool: &Arc>) { @@ -376,7 +378,13 @@ fn spawn_maintenance_tasks(pool: &Arc>) { if pool.options.min_connections > 0 { rt::spawn(async move { if let Some(pool) = pool_weak.upgrade() { - pool.min_connections_maintenance(None).await; + if let Err(error) = pool.try_min_connections().await { + tracing::error!( + target: "sqlx::pool", + ?error, + "error maintaining min_connections" + ); + } } }); } @@ -401,31 +409,21 @@ fn spawn_maintenance_tasks(pool: &Arc>) { // Go over all idle connections, check for idleness and lifetime, // and if we have fewer than min_connections after reaping a connection, - // open a new one immediately. Note that other connections may be popped from - // the queue in the meantime - that's fine, there is no harm in checking more - for _ in 0..pool.num_idle() { - if let Some(conn) = pool.try_acquire() { - if is_beyond_idle_timeout(&conn, &pool.options) - || is_beyond_max_lifetime(&conn, &pool.options) - { - let _ = conn.close().await; - pool.min_connections_maintenance(Some(next_run)).await; - } else { - pool.release(conn.into_live()); - } + // open a new one immediately. + for conn in pool.sharded.iter_idle() { + if is_beyond_idle_timeout(&conn, &pool.options) + || is_beyond_max_lifetime(&conn, &pool.options) + { + // Dropping the slot will check if the connection needs to be + // re-made. 
+ let _ = connection::close(conn).await; } } // Don't hold a reference to the pool while sleeping. drop(pool); - if let Some(duration) = next_run.checked_duration_since(Instant::now()) { - // `async-std` doesn't have a `sleep_until()` - rt::sleep(duration).await; - } else { - // `next_run` is in the past, just yield. - rt::yield_now().await; - } + rt::sleep_until(next_run).await; } }) .await; diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index 84776d0e..0b8d9452 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -87,7 +87,7 @@ mod connect; mod connection; mod inner; -mod idle; +// mod idle; mod options; mod shard; @@ -369,7 +369,7 @@ impl Pool { /// Returns `None` immediately if there are no idle connections available in the pool /// or there are tasks waiting for a connection which have yet to wake. pub fn try_acquire(&self) -> Option> { - self.0.try_acquire().map(|conn| conn.into_live().reattach()) + self.0.try_acquire().map(|conn| PoolConnection::new(conn, self.0.clone())) } /// Retrieves a connection and immediately begins a new transaction. diff --git a/sqlx-core/src/pool/options.rs b/sqlx-core/src/pool/options.rs index 0e8e05b4..e3469561 100644 --- a/sqlx-core/src/pool/options.rs +++ b/sqlx-core/src/pool/options.rs @@ -1,10 +1,11 @@ use crate::connection::Connection; use crate::database::Database; use crate::error::Error; -use crate::pool::connect::DefaultConnector; +use crate::pool::connect::{ConnectTaskShared, ConnectionId, DefaultConnector}; use crate::pool::inner::PoolInner; use crate::pool::{Pool, PoolConnector}; use futures_core::future::BoxFuture; +use futures_util::{stream, TryStreamExt}; use log::LevelFilter; use std::fmt::{self, Debug, Formatter}; use std::num::NonZero; @@ -74,7 +75,7 @@ pub struct PoolOptions { pub(crate) acquire_slow_level: LevelFilter, pub(crate) acquire_slow_threshold: Duration, pub(crate) acquire_timeout: Duration, - pub(crate) connect_timeout: Duration, + pub(crate) connect_timeout: Option, pub(crate) min_connections: usize, pub(crate) max_lifetime: Option, pub(crate) idle_timeout: Option, @@ -155,7 +156,7 @@ impl PoolOptions { // to not flag typical time to add a new connection to a pool. acquire_slow_threshold: Duration::from_secs(2), acquire_timeout: Duration::from_secs(30), - connect_timeout: Duration::from_secs(2 * 60), + connect_timeout: None, idle_timeout: Some(Duration::from_secs(10 * 60)), max_lifetime: Some(Duration::from_secs(30 * 60)), fair: true, @@ -323,15 +324,15 @@ impl PoolOptions { /// This timeout happens independently of [`acquire_timeout`][Self::acquire_timeout]. /// /// If shorter than `acquire_timeout`, this will cause the last connec - pub fn connect_timeout(mut self, timeout: Duration) -> Self { - self.connect_timeout = timeout; + pub fn connect_timeout(mut self, timeout: impl Into>) -> Self { + self.connect_timeout = timeout.into(); self } /// Get the maximum amount of time to spend attempting to open a connection. /// /// This timeout happens independently of [`acquire_timeout`][Self::acquire_timeout]. - pub fn get_connect_timeout(&self) -> Duration { + pub fn get_connect_timeout(&self) -> Option { self.connect_timeout } @@ -573,17 +574,6 @@ impl PoolOptions { let inner = PoolInner::new_arc(self, connector); - if inner.options.min_connections > 0 { - // If the idle reaper is spawned then this will race with the call from that task - // and may not report any connection errors. 
- inner.try_min_connections(deadline).await?; - } - - // If `min_connections` is nonzero then we'll likely just pull a connection - // from the idle queue here, but it should at least get tested first. - let conn = inner.acquire().await?; - inner.release(conn.into_floating()); - Ok(Pool(inner)) } @@ -642,7 +632,7 @@ fn default_shards() -> NonZero { #[cfg(feature = "_rt-async-std")] if let Some(val) = std::env::var("ASYNC_STD_THREAD_COUNT") .ok() - .and_then(|s| s.parse()) + .and_then(|s| s.parse().ok()) { return val; } diff --git a/sqlx-core/src/pool/shard.rs b/sqlx-core/src/pool/shard.rs index a0bcee22..24750e0a 100644 --- a/sqlx-core/src/pool/shard.rs +++ b/sqlx-core/src/pool/shard.rs @@ -1,15 +1,17 @@ -use event_listener::{Event, IntoNotification}; +use crate::rt; +use event_listener::{listener, Event, IntoNotification}; +use futures_util::{future, stream, StreamExt}; +use spin::lock_api::Mutex; use std::future::Future; use std::num::NonZero; +use std::ops::{Deref, DerefMut}; use std::pin::pin; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; -use std::task::Poll; +use std::sync::{atomic, Arc}; +use std::task::{ready, Poll}; use std::time::Duration; use std::{array, iter}; -use spin::lock_api::Mutex; - type ShardId = usize; type ConnectionIndex = usize; @@ -17,7 +19,7 @@ type ConnectionIndex = usize; /// /// We want tasks to acquire from their local shards where possible, so they don't enter /// the global queue immediately. -const GLOBAL_QUEUE_DELAY: Duration = Duration::from_millis(10); +const GLOBAL_ACQUIRE_DELAY: Duration = Duration::from_millis(10); /// Delay before attempting to acquire from a non-local shard, /// as well as the backoff when iterating through shards. @@ -30,20 +32,27 @@ pub struct Sharded { type ArcShard = Arc>>]>>; -struct Global { - unlock_event: Event>, - disconnect_event: Event>, +struct Global) + Send + Sync + 'static> { + unlock_event: Event>, + disconnect_event: Event>, + min_connections: usize, + num_shards: usize, + do_reconnect: F, } type ArcMutexGuard = lock_api::ArcMutexGuard, Option>; -pub struct LockGuard { +struct SlotGuard { // `Option` allows us to take the guard in the drop handler. locked: Option>, shard: ArcShard, index: ConnectionIndex, } +pub struct ConnectedSlot(SlotGuard); + +pub struct DisconnectedSlot(SlotGuard); + // Align to cache lines. // Simplified from https://docs.rs/crossbeam-utils/0.8.21/src/crossbeam_utils/cache_padded.rs.html#80 // @@ -54,12 +63,15 @@ pub struct LockGuard { #[cfg_attr(not(target_pointer_width = "64"), repr(align(64)))] struct Shard { shard_id: ShardId, - /// Bitset for all connection indexes that are currently in-use. + /// Bitset for all connection indices that are currently in-use. locked_set: AtomicUsize, - /// Bitset for all connection indexes that are currently connected. + /// Bitset for all connection indices that are currently connected. connected_set: AtomicUsize, - unlock_event: Event>, - disconnect_event: Event>, + /// Bitset for all connection indices that have been explicitly leaked. 
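+    ///
+    /// For example, with slots 0 and 3 leaked this holds `0b1001`; `drain()` completes once
+    /// every bit below the shard size is set.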
+    leaked_set: AtomicUsize,
+    unlock_event: Event<ConnectionIndex>,
+    disconnect_event: Event<ConnectionIndex>,
+    leak_event: Event<ConnectionIndex>,
     global: Arc<Global<T>>,
     connections: Ts,
 }

@@ -78,13 +90,23 @@ const MAX_SHARD_SIZE: usize = if usize::BITS > 64 {
 };

 impl<T> Sharded<T> {
-    pub fn new(connections: usize, shards: NonZero<usize>) -> Sharded<T> {
+    pub fn new(
+        connections: usize,
+        shards: NonZero<usize>,
+        min_connections: usize,
+        do_reconnect: impl Fn(DisconnectedSlot<T>) + Send + Sync + 'static,
+    ) -> Sharded<T> {
+        let params = Params::calc(connections, shards.get());
+
         let global = Arc::new(Global {
             unlock_event: Event::with_tag(),
             disconnect_event: Event::with_tag(),
+            num_shards: params.shards,
+            min_connections,
+            do_reconnect,
         });

-        let shards = Params::calc(connections, shards.get())
+        let shards = params
             .shard_sizes()
             .enumerate()
             .map(|(shard_id, size)| Shard::new(shard_id, size, global.clone()))
@@ -93,7 +115,60 @@ impl<T> Sharded<T> {
             .collect();

         Sharded { shards, global }
     }

-    pub async fn acquire(&self, connected: bool) -> LockGuard<T> {
+    #[inline]
+    pub fn num_shards(&self) -> usize {
+        self.shards.len()
+    }
+
+    #[allow(clippy::cast_possible_truncation)] // This is only informational
+    pub fn count_connected(&self) -> usize {
+        atomic::fence(Ordering::Acquire);
+
+        self.shards
+            .iter()
+            .map(|shard| shard.connected_set.load(Ordering::Relaxed).count_ones() as usize)
+            .sum()
+    }
+
+    #[allow(clippy::cast_possible_truncation)] // This is only informational
+    pub fn count_unlocked(&self, connected: bool) -> usize {
+        self.shards
+            .iter()
+            .map(|shard| shard.unlocked_mask(connected).count_ones() as usize)
+            .sum()
+    }
+
+    pub async fn acquire_connected(&self) -> ConnectedSlot<T> {
+        let guard = self.acquire(true).await;
+
+        assert!(
+            guard.get().is_some(),
+            "BUG: expected slot {}/{} to be connected but it wasn't",
+            guard.shard.shard_id,
+            guard.index
+        );
+
+        ConnectedSlot(guard)
+    }
+
+    pub fn try_acquire_connected(&self) -> Option<ConnectedSlot<T>> {
+        todo!()
+    }
+
+    pub async fn acquire_disconnected(&self) -> DisconnectedSlot<T> {
+        let guard = self.acquire(false).await;
+
+        assert!(
+            guard.get().is_none(),
+            "BUG: expected slot {}/{} NOT to be connected but it WAS",
+            guard.shard.shard_id,
+            guard.index
+        );
+
+        DisconnectedSlot(guard)
+    }
+
+    async fn acquire(&self, connected: bool) -> SlotGuard<T> {
         if self.shards.len() == 1 {
             return self.shards[0].acquire(connected).await;
         }
@@ -106,7 +181,7 @@ impl<T> Sharded<T> {
         let mut next_shard = thread_id;

         loop {
-            crate::rt::sleep(NON_LOCAL_ACQUIRE_DELAY).await;
+            rt::sleep(NON_LOCAL_ACQUIRE_DELAY).await;

             // Choose shards pseudorandomly by multiplying with a (relatively) large prime.
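+            // (A prime multiplier scatters successive probes across the shards
+            // instead of walking them in order.)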
next_shard = (next_shard.wrapping_mul(547)) % self.shards.len(); @@ -118,7 +193,7 @@ impl Sharded { }); let mut acquire_global = pin!(async { - crate::rt::sleep(GLOBAL_QUEUE_DELAY).await; + rt::sleep(GLOBAL_ACQUIRE_DELAY).await; let event_to_listen = if connected { &self.global.unlock_event @@ -150,6 +225,36 @@ impl Sharded { }) .await } + + pub fn iter_min_connections(&self) -> impl Iterator> + '_ { + self.shards + .iter() + .flat_map(|shard| shard.iter_min_connections()) + } + + pub fn iter_idle(&self) -> impl Iterator> + '_ { + self.shards.iter().flat_map(|shard| shard.iter_idle()) + } + + pub async fn drain(&self, close: F) + where + F: Fn(ConnectedSlot) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + T: Send + 'static, + { + let close = Arc::new(close); + + stream::iter(self.shards.iter()) + .for_each_concurrent(None, |shard| { + let shard = shard.clone(); + let close = close.clone(); + + rt::spawn(async move { + shard.drain(&*close).await; + }) + }) + .await; + } } impl Shard>>]> { @@ -163,9 +268,11 @@ impl Shard>>]> { $($n => Arc::new(Shard { shard_id, locked_set: AtomicUsize::new(0), - unlock_event: Event::with_tag(), connected_set: AtomicUsize::new(0), + leaked_set: AtomicUsize::new(0), + unlock_event: Event::with_tag(), disconnect_event: Event::with_tag(), + leak_event: Event::with_tag(), global, connections: array::from_fn::<_, $n, _>(|_| Arc::new(Mutex::new(None))) }),)* @@ -181,7 +288,27 @@ impl Shard>>]> { ) } - async fn acquire(self: &Arc, connected: bool) -> LockGuard { + #[inline] + fn unlocked_mask(&self, connected: bool) -> Mask { + let locked_set = self.locked_set.load(Ordering::Acquire); + let connected_set = self.connected_set.load(Ordering::Relaxed); + + let connected_mask = if connected { + connected_set + } else { + !connected_set + }; + + Mask(!locked_set & connected_mask) + } + + /// Choose the first index that is unlocked with bit `connected` + #[inline] + fn next_unlocked(&self, connected: bool) -> Option { + self.unlocked_mask(connected).next() + } + + async fn acquire(self: &Arc, connected: bool) -> SlotGuard { // Attempt an unfair acquire first, before we modify the waitlist. if let Some(locked) = self.try_acquire(connected) { return locked; @@ -205,35 +332,286 @@ impl Shard>>]> { listener.await } - fn try_acquire(self: &Arc, connected: bool) -> Option> { - let locked_set = self.locked_set.load(Ordering::Acquire); - let connected_set = self.connected_set.load(Ordering::Relaxed); - - let connected_mask = if connected { - connected_set - } else { - !connected_set - }; - - // Choose the first index that is unlocked with bit `connected` - let index = (!locked_set & connected_mask).leading_zeros() as usize; - - self.try_lock(index) + fn try_acquire(self: &Arc, connected: bool) -> Option> { + self.try_lock(self.next_unlocked(connected)?) } - fn try_lock(self: &Arc, index: ConnectionIndex) -> Option> { - let locked = self.connections[index].try_lock_arc()?; + fn try_lock(self: &Arc, index: ConnectionIndex) -> Option> { + let locked = self.connections.get(index)?.try_lock_arc()?; // The locking of the connection itself must use an `Acquire` fence, // so additional synchronization is unnecessary. 
         atomic_set(&self.locked_set, index, true, Ordering::Relaxed);
 
-        Some(LockGuard {
+        Some(SlotGuard {
             locked: Some(locked),
             shard: self.clone(),
             index,
         })
     }
+
+    fn iter_min_connections(self: &Arc<Self>) -> impl Iterator<Item = DisconnectedSlot<T>> + '_ {
+        (0..self.connections.len())
+            .filter_map(|index| {
+                let slot = self.try_lock(index)?;
+
+                // Guard against some weird bug causing this to already be connected
+                slot.get().is_none().then_some(DisconnectedSlot(slot))
+            })
+            .take(self.global.shard_min_connections(self.shard_id))
+    }
+
+    fn iter_idle(self: &Arc<Self>) -> impl Iterator<Item = ConnectedSlot<T>> + '_ {
+        self.unlocked_mask(true).filter_map(|index| {
+            let slot = self.try_lock(index)?;
+
+            // Guard against the slot having been disconnected since we read the mask
+            slot.get().is_some().then_some(ConnectedSlot(slot))
+        })
+    }
+
+    async fn drain<F, Fut>(self: &Arc<Self>, close: F)
+    where
+        F: Fn(ConnectedSlot<T>) -> Fut,
+        Fut: Future<Output = DisconnectedSlot<T>>,
+    {
+        let mut drain_connected = pin!(async {
+            loop {
+                let connected = self.acquire(true).await;
+                DisconnectedSlot::leak(close(ConnectedSlot(connected)).await);
+            }
+        });
+
+        let mut drain_disconnected = pin!(async {
+            loop {
+                let disconnected = DisconnectedSlot(self.acquire(false).await);
+                DisconnectedSlot::leak(disconnected);
+            }
+        });
+
+        let mut drain_leaked = pin!(async {
+            loop {
+                listener!(self.leak_event => leaked);
+                leaked.await;
+            }
+        });
+
+        // `(1 << len) - 1` would overflow when a shard spans a full `usize` worth
+        // of slots, so build the all-ones mask by shifting `usize::MAX` down instead.
+        let finished_mask = usize::MAX >> (usize::BITS as usize - self.connections.len());
+
+        std::future::poll_fn(|cx| {
+            // The connection set is drained once all slots are leaked.
+            if self.leaked_set.load(Ordering::Acquire) == finished_mask {
+                return Poll::Ready(());
+            }
+
+            // These futures shouldn't return `Ready`
+            let _ = drain_connected.as_mut().poll(cx);
+            let _ = drain_disconnected.as_mut().poll(cx);
+            let _ = drain_leaked.as_mut().poll(cx);
+
+            Poll::Pending
+        })
+        .await;
+    }
+}
+
+impl<T> Deref for ConnectedSlot<T> {
+    type Target = T;
+
+    fn deref(&self) -> &Self::Target {
+        self.0
+            .get()
+            .as_ref()
+            .expect("BUG: expected slot to be populated, but it wasn't")
+    }
+}
+
+impl<T> DerefMut for ConnectedSlot<T> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.0
+            .get_mut()
+            .as_mut()
+            .expect("BUG: expected slot to be populated, but it wasn't")
+    }
+}
+
+impl<T> ConnectedSlot<T> {
+    pub fn take(mut this: Self) -> (T, DisconnectedSlot<T>) {
+        let conn = this
+            .0
+            .get_mut()
+            .take()
+            .expect("BUG: expected slot to be populated, but it wasn't");
+
+        (conn, DisconnectedSlot(this.0))
+    }
+}
+
+impl<T> DisconnectedSlot<T> {
+    pub fn put(mut self, connection: T) -> ConnectedSlot<T> {
+        *self.0.get_mut() = Some(connection);
+        ConnectedSlot(self.0)
+    }
+
+    pub fn leak(mut self) {
+        self.0.locked = None;
+
+        atomic_set(
+            &self.0.shard.connected_set,
+            self.0.index,
+            false,
+            Ordering::Relaxed,
+        );
+        atomic_set(
+            &self.0.shard.leaked_set,
+            self.0.index,
+            true,
+            Ordering::Release,
+        );
+
+        self.0.shard.leak_event.notify(usize::MAX.tag(self.0.index));
+    }
+
+    pub fn should_reconnect(&self) -> bool {
+        self.0.should_reconnect()
+    }
+}
+
+impl<T> SlotGuard<T> {
+    fn get(&self) -> &Option<T> {
+        self.locked
+            .as_deref()
+            .expect("BUG: `SlotGuard.locked` taken")
+    }
+
+    fn get_mut(&mut self) -> &mut Option<T> {
+        self.locked
+            .as_deref_mut()
+            .expect("BUG: `SlotGuard.locked` taken")
+    }
+
+    fn should_reconnect(&self) -> bool {
+        let min_connections = self.shard.global.shard_min_connections(self.shard.shard_id);
+
+        let num_connected = self
+            .shard
+            .connected_set
+            .load(Ordering::Acquire)
+            .count_ones() as usize;
+
+        num_connected < min_connections
+    }
+}
+
+impl<T> Drop for SlotGuard<T> {
+    fn drop(&mut self) {
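+        // Drop proceeds in three steps: record whether the slot still holds a
+        // connection, try to hand the slot directly to a waiting task (global
+        // waiters first, then shard-local), and only if nobody is waiting,
+        // release the mutex and clear this slot's bit in `locked_set`.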
+        let Some(locked) = self.locked.take() else {
+            return;
+        };
+
+        let connected = locked.is_some();
+
+        // Updating the connected flag shouldn't require a fence.
+        atomic_set(
+            &self.shard.connected_set,
+            self.index,
+            connected,
+            Ordering::Relaxed,
+        );
+
+        // We don't actually unlock the connection unless there are no receivers
+        // waiting to accept it. If another receiver is waiting for a connection,
+        // we can directly pass them the lock.
+        //
+        // This prevents drive-by tasks from acquiring connections before waiting tasks
+        // at high contention, while requiring little synchronization otherwise.
+        //
+        // We *could* just pass them the shard ID and/or index, but then we have to handle
+        // the situation when a receiver was passed a connection that was still marked as locked,
+        // but was cancelled before it could complete the acquisition. Otherwise, the connection
+        // would be marked as locked forever, effectively being leaked.
+        let mut locked = Some(locked);
+
+        // This is a code smell, but it's necessary because `event-listener` has no way to specify
+        // that a message should *only* be sent once. This means tags either need to be `Clone`
+        // or provided by a `FnMut()` closure.
+        //
+        // Note that there's no guarantee that this closure won't be called more than once,
+        // but the implementation as of writing does not do so.
+        let mut self_as_tag = || {
+            let locked = locked
+                .take()
+                .expect("BUG: notification sent more than once");
+
+            SlotGuard {
+                locked: Some(locked),
+                shard: self.shard.clone(),
+                index: self.index,
+            }
+        };
+
+        if connected {
+            // Check for global waiters first.
+            if self
+                .shard
+                .global
+                .unlock_event
+                .notify(1.tag_with(&mut self_as_tag))
+                > 0
+            {
+                return;
+            }
+
+            if self.shard.unlock_event.notify(1.tag_with(&mut self_as_tag)) > 0 {
+                return;
+            }
+        } else {
+            if self
+                .shard
+                .global
+                .disconnect_event
+                .notify(1.tag_with(&mut self_as_tag))
+                > 0
+            {
+                return;
+            }
+
+            if self
+                .shard
+                .disconnect_event
+                .notify(1.tag_with(&mut self_as_tag))
+                > 0
+            {
+                return;
+            }
+
+            // If this connection is required to satisfy `min_connections`,
+            // hand the empty slot to the reconnect callback instead of releasing it.
+            if self.should_reconnect() {
+                (self.shard.global.do_reconnect)(DisconnectedSlot(self_as_tag()));
+                return;
+            }
+        }
+
+        // Be sure to drop the lock guard if it's still held,
+        // *before* we semantically release the lock in the bitset.
+        //
+        // Otherwise, another task could check and see the connection is free,
+        // but then fail to lock the mutex for it.
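+        // Concretely: if the `locked_set` bit were cleared first, a concurrent
+        // `try_lock()` could observe the slot as free, fail `try_lock_arc()`
+        // because this guard still holds the slot's mutex, and give up on a
+        // slot that was just about to become available.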
+        drop(locked);
+
+        atomic_set(&self.shard.locked_set, self.index, false, Ordering::Release);
+    }
+}
+
+impl<T> Global<T> {
+    fn shard_min_connections(&self, shard_id: ShardId) -> usize {
+        let min_connections_per_shard = self.min_connections / self.num_shards;
+
+        // The first `min_connections % num_shards` shards each take one of the
+        // leftover connections, so the per-shard minimums sum to `min_connections`.
+        if shard_id < (self.min_connections % self.num_shards) {
+            min_connections_per_shard + 1
+        } else {
+            min_connections_per_shard
+        }
+    }
+}
 
 impl Params {
@@ -277,6 +655,16 @@ impl Params {
     }
 }
 
+fn atomic_set(atomic: &AtomicUsize, index: usize, value: bool, ordering: Ordering) {
+    if value {
+        let bit = 1 << index;
+        atomic.fetch_or(bit, ordering);
+    } else {
+        let bit = !(1 << index);
+        atomic.fetch_and(bit, ordering);
+    }
+}
+
 fn current_thread_id() -> usize {
     // FIXME: this can be replaced when this is stabilized:
     // https://doc.rust-lang.org/stable/std/thread/struct.ThreadId.html#method.as_u64
@@ -289,106 +677,32 @@ fn current_thread_id() -> usize {
     CURRENT_THREAD_ID.with(|i| *i)
 }
 
-impl<T> Drop for LockGuard<T> {
-    fn drop(&mut self) {
-        let Some(locked) = self.locked.take() else {
-            return;
-        };
+#[derive(Clone, Debug, PartialEq, Eq)]
+struct Mask(usize);
 
-        let connected = locked.is_some();
-
-        // Updating the connected flag shouldn't require a fence.
-        atomic_set(
-            &self.shard.connected_set,
-            self.index,
-            connected,
-            Ordering::Relaxed,
-        );
-
-        // If another receiver is waiting for a connection, we can directly pass them the lock.
-        //
-        // This prevents drive-by tasks from acquiring connections before waiting tasks
-        // at high contention, while requiring little synchronization otherwise.
-        //
-        // We *could* just pass them the shard ID and/or index, but then we have to handle
-        // the situation when a receiver was passed a connection that was still marked as locked,
-        // but was cancelled before it could complete the acquisition. Otherwise, the connection
-        // would be marked as locked forever, effectively being leaked.
-
-        let mut locked = Some(locked);
-
-        // This is a code smell, but it's necessary because `event-listener` has no way to specify
-        // that a message should *only* be sent once. This means tags either need to be `Clone`
-        // or provided by a `FnMut()` closure.
-        //
-        // Note that there's no guarantee that this closure won't be called more than once by the
-        // implementation, but the code as of writing should not.
-        let mut self_as_tag = || {
-            let locked = locked
-                .take()
-                .expect("BUG: notification sent more than once");
-
-            LockGuard {
-                locked: Some(locked),
-                shard: self.shard.clone(),
-                index: self.index,
-            }
-        };
-
-        if connected {
-            // Check for global waiters first.
-            if self
-                .shard
-                .global
-                .unlock_event
-                .notify(1.tag_with(&mut self_as_tag))
-                > 0
-            {
-                return;
-            }
-
-            if self.shard.unlock_event.notify(1.tag_with(&mut self_as_tag)) > 0 {
-                return;
-            }
-        } else {
-            if self
-                .shard
-                .global
-                .disconnect_event
-                .notify(1.tag_with(&mut self_as_tag))
-                > 0
-            {
-                return;
-            }
-
-            if self
-                .shard
-                .disconnect_event
-                .notify(1.tag_with(&mut self_as_tag))
-                > 0
-            {
-                return;
-            }
-        }
-
-        // Be sure to drop the lock guard if it's still held,
-        // *before* we semantically release the lock in the bitset.
-        //
-        // Otherwise, another task could check and see the connection is free,
-        // but then fail to lock the mutex for it.
-        drop(locked);
-
-        atomic_set(&self.shard.locked_set, self.index, false, Ordering::Release);
+impl Mask {
+    pub fn count_ones(&self) -> usize {
+        self.0.count_ones() as usize
     }
 }
 
-fn atomic_set(atomic: &AtomicUsize, index: usize, value: bool, ordering: Ordering) {
-    if value {
-        let bit = 1 >> index;
-        atomic.fetch_or(bit, ordering);
-    } else {
-        let bit = !(1 >> index);
-        atomic.fetch_and(bit, ordering);
+impl Iterator for Mask {
+    type Item = usize;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.0 == 0 {
+            return None;
+        }
+
+        // Yield the lowest set bit, then clear it so iteration advances.
+        let index = self.0.trailing_zeros() as usize;
+        self.0 &= !(1 << index);
+
+        Some(index)
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let count = self.0.count_ones() as usize;
+        (count, Some(count))
     }
 }
diff --git a/sqlx-core/src/rt/mod.rs b/sqlx-core/src/rt/mod.rs
index 0044139f..985d9bb6 100644
--- a/sqlx-core/src/rt/mod.rs
+++ b/sqlx-core/src/rt/mod.rs
@@ -56,18 +56,18 @@ pub async fn timeout_at<F: Future>(deadline: Instant, f: F) -> Result<F::Output, TimeoutError>
pub fn spawn<F>(fut: F) -> JoinHandle<F::Output>
where
@@ -186,7 +201,7 @@ pub fn test_block_on<F: Future>(f: F) -> F::Output {
 #[track_caller]
 pub const fn missing_rt<T>(_unused: T) -> ! {
     if cfg!(feature = "_rt-tokio") {
-        panic!("this functionality requires a Tokio context")
+        panic!("this functionality requires an active Tokio runtime")
     }
 
     panic!("one of the `runtime` features of SQLx must be enabled")
diff --git a/sqlx-core/src/rt/rt_async_io/mod.rs b/sqlx-core/src/rt/rt_async_io/mod.rs
index 5e4d7074..70d01fbe 100644
--- a/sqlx-core/src/rt/rt_async_io/mod.rs
+++ b/sqlx-core/src/rt/rt_async_io/mod.rs
@@ -1,4 +1,4 @@
 mod socket;
 
-mod timeout;
-pub use timeout::*;
+mod time;
+pub use time::*;
diff --git a/sqlx-core/src/rt/rt_async_io/timeout.rs b/sqlx-core/src/rt/rt_async_io/time.rs
similarity index 54%
rename from sqlx-core/src/rt/rt_async_io/timeout.rs
rename to sqlx-core/src/rt/rt_async_io/time.rs
index b4a77907..039610b7 100644
--- a/sqlx-core/src/rt/rt_async_io/timeout.rs
+++ b/sqlx-core/src/rt/rt_async_io/time.rs
@@ -1,20 +1,24 @@
-use std::{future::Future, pin::pin, time::Duration};
+use std::{
+    future::Future,
+    pin::pin,
+    time::{Duration, Instant},
+};
 
 use futures_util::future::{select, Either};
 
 use crate::rt::TimeoutError;
 
 pub async fn sleep(duration: Duration) {
-    timeout_future(duration).await;
+    async_io::Timer::after(duration).await;
+}
+
+pub async fn sleep_until(deadline: Instant) {
+    async_io::Timer::at(deadline).await;
 }
 
 pub async fn timeout<F: Future>(duration: Duration, future: F) -> Result<F::Output, TimeoutError> {
-    match select(pin!(future), timeout_future(duration)).await {
+    match select(pin!(future), pin!(sleep(duration))).await {
         Either::Left((result, _)) => Ok(result),
         Either::Right(_) => Err(TimeoutError),
     }
 }
-
-fn timeout_future(duration: Duration) -> impl Future<Output = std::time::Instant> {
-    async_io::Timer::after(duration)
-}
diff --git a/sqlx-core/src/sync.rs b/sqlx-core/src/sync.rs
index 971752f8..2fd51445 100644
--- a/sqlx-core/src/sync.rs
+++ b/sqlx-core/src/sync.rs
@@ -4,8 +4,51 @@
 // We'll generally lean towards Tokio's types as those are more featureful
 // (including `tokio-console` support) and more widely deployed.
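+//
+// When neither Tokio nor `async-lock` is available, the `noop` module below
+// supplies stand-ins that unconditionally panic through `missing_rt`.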
-#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
-pub use async_std::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard};
-
 #[cfg(feature = "_rt-tokio")]
-pub use tokio::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard};
+pub use tokio::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, RwLock as AsyncRwLock};
+
+#[cfg(all(feature = "_rt-async-lock", not(feature = "_rt-tokio")))]
+pub use async_lock::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, RwLock as AsyncRwLock};
+
+#[cfg(not(any(feature = "_rt-async-lock", feature = "_rt-tokio")))]
+pub use noop::*;
+
+#[cfg(not(any(feature = "_rt-async-lock", feature = "_rt-tokio")))]
+mod noop {
+    use crate::rt::missing_rt;
+    use std::marker::PhantomData;
+    use std::ops::{Deref, DerefMut};
+
+    pub struct AsyncMutex<T> {
+        // `Sync` if `T: Send`
+        _marker: PhantomData<std::sync::Mutex<T>>,
+    }
+
+    pub struct AsyncMutexGuard<'a, T> {
+        inner: &'a AsyncMutex<T>,
+    }
+
+    impl<T> AsyncMutex<T> {
+        pub fn new(val: T) -> Self {
+            missing_rt(val)
+        }
+
+        pub fn lock(&self) -> AsyncMutexGuard<'_, T> {
+            missing_rt(self)
+        }
+    }
+
+    impl<T> Deref for AsyncMutexGuard<'_, T> {
+        type Target = T;
+
+        fn deref(&self) -> &Self::Target {
+            missing_rt(self)
+        }
+    }
+
+    impl<T> DerefMut for AsyncMutexGuard<'_, T> {
+        fn deref_mut(&mut self) -> &mut Self::Target {
+            missing_rt(self)
+        }
+    }
+}
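+
+// Illustrative usage (hypothetical, for clarity): with a runtime feature
+// enabled, `AsyncMutex` is the Tokio or `async-lock` mutex re-exported above:
+//
+//     let mutex = AsyncMutex::new(0u32);
+//     let mut guard = mutex.lock().await;
+//     *guard += 1;
+//
+// Under the `noop` fallback, `AsyncMutex::new` panics via `missing_rt`
+// before a lock can ever be taken.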