From 9a8a4f96c2f5af150bd129ad810fefad53500c3c Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 3 Sep 2025 02:50:35 -0700 Subject: [PATCH] WIP feat: integrate sharding into pool --- sqlx-core/Cargo.toml | 18 +++++++- sqlx-core/src/pool/connect.rs | 5 +++ sqlx-core/src/pool/inner.rs | 3 ++ sqlx-core/src/pool/options.rs | 81 +++++++++++++++++++++++++++++++++++ sqlx-core/src/pool/shard.rs | 60 +++++++++++++++++++++----- 5 files changed, 154 insertions(+), 13 deletions(-) diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 61c7387a..e9f70859 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -106,8 +106,22 @@ thiserror.workspace = true ease-off = { workspace = true, features = ["futures"] } pin-project-lite = "0.2.14" -[dependencies.parking_lot] -version = "0.12.4" +# N.B. we don't actually utilize spinlocks, we just need a `Mutex` type with a few requirements: +# * Guards that are `Send` (so `parking_lot` and `std::sync` are non-starters) +# * Guards that can use `Arc` and so don't borrow (which is provided by `lock_api`) +# +# Where we actually use this (in `sqlx-core/src/pool/shard.rs`), we don't rely on the mutex itself for anything but +# safe shared mutability. The `Shard` structure has its own synchronization, and only uses `Mutex::try_lock()`. +# +# We *could* use either `tokio::sync::Mutex` or `async_lock::Mutex` for this, but those have all the code for the +# async support, which we don't need. 
+[dependencies.spin] +version = "0.10.0" +default-features = false +features = ["mutex", "lock_api", "spin_mutex"] + +[dependencies.lock_api] +version = "0.4.13" features = ["arc_lock"] [dev-dependencies] diff --git a/sqlx-core/src/pool/connect.rs b/sqlx-core/src/pool/connect.rs index ee805914..63c87987 100644 --- a/sqlx-core/src/pool/connect.rs +++ b/sqlx-core/src/pool/connect.rs @@ -220,6 +220,11 @@ pub trait PoolConnector: Send + Sync + 'static { ) -> impl Future> + Send + '_; } +/// # Note: Future Changes (FIXME) +/// This could theoretically be replaced with an impl over `AsyncFn` to allow lending closures, +/// except we have no way to put the `Send` bound on the returned future. +/// +/// We need Return Type Notation for that: https://github.com/rust-lang/rust/pull/138424 impl PoolConnector for F where DB: Database, diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index 5a37f17e..e3aee6a3 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -13,6 +13,7 @@ use std::task::ready; use crate::logger::private_level_filter_to_trace_level; use crate::pool::connect::{ConnectPermit, ConnectionCounter, ConnectionId, DynConnector}; use crate::pool::idle::IdleQueue; +use crate::pool::shard::Sharded; use crate::rt::JoinHandle; use crate::{private_tracing_dynamic_event, rt}; use either::Either; @@ -24,6 +25,7 @@ use tracing::Level; pub(crate) struct PoolInner { pub(super) connector: DynConnector, pub(super) counter: ConnectionCounter, + pub(super) sharded: Sharded, pub(super) idle: IdleQueue, is_closed: AtomicBool, pub(super) on_closed: event_listener::Event, @@ -40,6 +42,7 @@ impl PoolInner { let pool = Self { connector: DynConnector::new(connector), counter: ConnectionCounter::new(), + sharded: Sharded::new(options.max_connections, options.shards), idle: IdleQueue::new(options.fair, options.max_connections), is_closed: AtomicBool::new(false), on_closed: event_listener::Event::new(), diff --git a/sqlx-core/src/pool/options.rs 
b/sqlx-core/src/pool/options.rs index 9775799f..0e8e05b4 100644 --- a/sqlx-core/src/pool/options.rs +++ b/sqlx-core/src/pool/options.rs @@ -7,6 +7,7 @@ use crate::pool::{Pool, PoolConnector}; use futures_core::future::BoxFuture; use log::LevelFilter; use std::fmt::{self, Debug, Formatter}; +use std::num::NonZero; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -68,6 +69,7 @@ pub struct PoolOptions { >, >, pub(crate) max_connections: usize, + pub(crate) shards: NonZero<usize>, pub(crate) acquire_time_level: LevelFilter, pub(crate) acquire_slow_level: LevelFilter, pub(crate) acquire_slow_threshold: Duration, @@ -91,6 +93,7 @@ impl Clone for PoolOptions { before_acquire: self.before_acquire.clone(), after_release: self.after_release.clone(), max_connections: self.max_connections, + shards: self.shards, acquire_time_level: self.acquire_time_level, acquire_slow_threshold: self.acquire_slow_threshold, acquire_slow_level: self.acquire_slow_level, @@ -143,6 +146,7 @@ impl PoolOptions { // A production application will want to set a higher limit than this. max_connections: 10, min_connections: 0, + shards: NonZero::<usize>::MIN, // Logging all acquires is opt-in acquire_time_level: LevelFilter::Off, // Default to warning, because an acquire timeout will be an error @@ -206,6 +210,58 @@ impl PoolOptions { self.min_connections } + /// Set the number of shards to split the internal structures into. + /// + /// The default value is dynamically determined based on the configured number of worker threads + /// in the current runtime (if that information is available), + /// or [`std::thread::available_parallelism()`], + /// or 1 otherwise. + /// + /// Each shard is assigned an equal share of [`max_connections`][Self::max_connections] + /// and its own queue of tasks waiting to acquire a connection. + /// + /// Then, when accessing the pool, each thread selects a "local" shard based on its + /// [thread ID][std::thread::Thread::id]<sup>1</sup>.
+ /// + /// If the number of shards equals the number of threads (which is the default), + /// and worker threads are spawned sequentially (which they generally are), + /// each thread should access a different shard, which should significantly reduce + /// cache coherence overhead on multicore systems. + /// + /// If the number of shards does not evenly divide `max_connections`, + /// the implementation makes a best-effort to distribute them as evenly as possible + /// (if `remainder = max_connections % shards` and `remainder != 0`, + /// then `remainder` shards will get one additional connection each). + /// + /// The implementation then clamps the number of connections in a shard to the range `[1, 64]`. + /// + /// ### Details + /// When a task calls [`Pool::acquire()`] (or any other method that calls `acquire()`), + /// it will first attempt to acquire a connection from its thread-local shard, or lock an empty + /// slot to open a new connection (acquiring an idle connection and opening a new connection + /// happen concurrently to minimize acquire time). + /// + /// Failing that, it joins the wait list on the shard. Released connections are passed to + /// waiting tasks in a first-come, first-served order per shard. + /// + /// If the task cannot acquire a connection after a short delay, + /// it tries to acquire a connection from another shard. + /// + /// If the task _still_ cannot acquire a connection after a longer delay, + /// it joins a global wait list. Tasks in the global wait list are the highest priority + /// for released connections, implementing a kind of eventual fairness. + /// + /// <sup>1</sup> Because, as of writing, [`std::thread::ThreadId::as_u64`] is unstable, + /// the current implementation assigns each thread its own sequential ID in a `thread_local!()`.
+ pub fn shards(mut self, shards: NonZero<usize>) -> Self { + self.shards = shards; + self + } + + pub fn get_shards(&self) -> usize { + self.shards.get() + } + /// Enable logging of time taken to acquire a connection from the connection pool via /// [`Pool::acquire()`]. /// @@ -572,3 +628,28 @@ impl Debug for PoolOptions { .finish() } } + +fn default_shards() -> NonZero<usize> { + #[cfg(feature = "_rt-tokio")] + if let Ok(rt) = tokio::runtime::Handle::try_current() { + return rt + .metrics() + .num_workers() + .try_into() + .unwrap_or(NonZero::<usize>::MIN); + } + + #[cfg(feature = "_rt-async-std")] + if let Some(val) = std::env::var("ASYNC_STD_THREAD_COUNT") + .ok() + .and_then(|s| s.parse().ok()) + { + return val; + } + + if let Ok(val) = std::thread::available_parallelism() { + return val; + } + + NonZero::<usize>::MIN +} diff --git a/sqlx-core/src/pool/shard.rs b/sqlx-core/src/pool/shard.rs index 242635e1..a0bcee22 100644 --- a/sqlx-core/src/pool/shard.rs +++ b/sqlx-core/src/pool/shard.rs @@ -1,6 +1,6 @@ use event_listener::{Event, IntoNotification}; -use parking_lot::Mutex; use std::future::Future; +use std::num::NonZero; use std::pin::pin; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -8,6 +8,8 @@ use std::task::Poll; use std::time::Duration; use std::{array, iter}; +use spin::lock_api::Mutex; + type ShardId = usize; type ConnectionIndex = usize; @@ -15,7 +17,11 @@ type ConnectionIndex = usize; /// /// We want tasks to acquire from their local shards where possible, so they don't enter /// the global queue immediately. -const GLOBAL_QUEUE_DELAY: Duration = Duration::from_millis(5); +const GLOBAL_QUEUE_DELAY: Duration = Duration::from_millis(10); + +/// Delay before attempting to acquire from a non-local shard, +/// as well as the backoff when iterating through shards.
+const NON_LOCAL_ACQUIRE_DELAY: Duration = Duration::from_micros(100); pub struct Sharded { shards: Box<[ArcShard]>, @@ -29,11 +35,10 @@ struct Global { disconnect_event: Event>, } -type ArcMutexGuard = parking_lot::ArcMutexGuard>; +type ArcMutexGuard = lock_api::ArcMutexGuard, Option>; pub struct LockGuard { - // `Option` allows us to drop the guard before sending the notification. - // Otherwise, if the receiver wakes too quickly, it might fail to lock the mutex. + // `Option` allows us to take the guard in the drop handler. locked: Option>, shard: ArcShard, index: ConnectionIndex, @@ -73,13 +78,13 @@ const MAX_SHARD_SIZE: usize = if usize::BITS > 64 { }; impl Sharded { - pub fn new(connections: usize, shards: usize) -> Sharded { + pub fn new(connections: usize, shards: NonZero<usize>) -> Sharded { let global = Arc::new(Global { unlock_event: Event::with_tag(), disconnect_event: Event::with_tag(), }); - let shards = Params::calc(connections, shards) + let shards = Params::calc(connections, shards.get()) .shard_sizes() .enumerate() .map(|(shard_id, size)| Shard::new(shard_id, size, global.clone())) @@ -89,8 +94,28 @@ impl Sharded { } pub async fn acquire(&self, connected: bool) -> LockGuard { + if self.shards.len() == 1 { + return self.shards[0].acquire(connected).await; + } + + let thread_id = current_thread_id(); + - let mut acquire_local = - pin!(self.shards[thread_id() % self.shards.len()].acquire(connected)); + let mut acquire_local = pin!(self.shards[thread_id % self.shards.len()].acquire(connected)); + + let mut acquire_nonlocal = pin!(async { + let mut next_shard = thread_id; + + loop { + crate::rt::sleep(NON_LOCAL_ACQUIRE_DELAY).await; + + // Choose shards pseudorandomly by multiplying with a (relatively) large prime.
+ next_shard = (next_shard.wrapping_mul(547)) % self.shards.len(); + + if let Some(locked) = self.shards[next_shard].try_acquire(connected) { + return locked; + } + } + }); let mut acquire_global = pin!(async { crate::rt::sleep(GLOBAL_QUEUE_DELAY).await; @@ -113,6 +138,10 @@ impl Sharded { return Poll::Ready(locked); } + if let Poll::Ready(locked) = acquire_nonlocal.as_mut().poll(cx) { + return Poll::Ready(locked); + } + if let Poll::Ready(locked) = acquire_global.as_mut().poll(cx) { return Poll::Ready(locked); } @@ -125,6 +154,9 @@ impl Sharded { impl Shard>>]> { fn new(shard_id: ShardId, len: usize, global: Arc>) -> Arc { + // There's no way to create DSTs like this, in `std::sync::Arc`, on stable. + // + // Instead, we coerce from an array. macro_rules! make_array { ($($n:literal),+) => { match len { @@ -206,6 +238,8 @@ impl Shard>>]> { impl Params { fn calc(connections: usize, mut shards: usize) -> Params { + assert_ne!(shards, 0); + let mut shard_size = connections / shards; let mut remainder = connections % shards; @@ -217,7 +251,11 @@ impl Params { } else if shard_size >= MAX_SHARD_SIZE { let new_shards = connections.div_ceil(MAX_SHARD_SIZE); - tracing::debug!(connections, shards, "clamping shard count to {new_shards}"); + tracing::debug!( + connections, + shards, + "shard size exceeds {MAX_SHARD_SIZE}, clamping shard count to {new_shards}" + ); shards = new_shards; shard_size = connections / shards; @@ -239,7 +277,7 @@ impl Params { } } -fn thread_id() -> usize { +fn current_thread_id() -> usize { // FIXME: this can be replaced when this is stabilized: // https://doc.rust-lang.org/stable/std/thread/struct.ThreadId.html#method.as_u64 static THREAD_ID: AtomicUsize = AtomicUsize::new(0);