mirror of
https://github.com/launchbadge/sqlx.git
synced 2026-05-01 06:24:31 +00:00
WIP feat: integrate sharding into pool
This commit is contained in:
@@ -106,8 +106,22 @@ thiserror.workspace = true
|
|||||||
ease-off = { workspace = true, features = ["futures"] }
|
ease-off = { workspace = true, features = ["futures"] }
|
||||||
pin-project-lite = "0.2.14"
|
pin-project-lite = "0.2.14"
|
||||||
|
|
||||||
[dependencies.parking_lot]
|
# N.B. we don't actually utilize spinlocks, we just need a `Mutex` type with a few requirements:
|
||||||
version = "0.12.4"
|
# * Guards that are `Send` (so `parking_lot` and `std::sync` are non-starters)
|
||||||
|
# * Guards that can use `Arc` and so don't borrow (which is provided by `lock_api`)
|
||||||
|
#
|
||||||
|
# Where we actually use this (in `sqlx-core/src/pool/shard.rs`), we don't rely on the mutex itself for anything but
|
||||||
|
# safe shared mutability. The `Shard` structure has its own synchronization, and only uses `Mutex::try_lock()`.
|
||||||
|
#
|
||||||
|
# We *could* use either `tokio::sync::Mutex` or `async_lock::Mutex` for this, but those have all the code for the
|
||||||
|
# async support, which we don't need.
|
||||||
|
[dependencies.spin]
|
||||||
|
version = "0.10.0"
|
||||||
|
default-features = false
|
||||||
|
features = ["mutex", "lock_api", "spin_mutex"]
|
||||||
|
|
||||||
|
[dependencies.lock_api]
|
||||||
|
version = "0.4.13"
|
||||||
features = ["arc_lock"]
|
features = ["arc_lock"]
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
|||||||
@@ -220,6 +220,11 @@ pub trait PoolConnector<DB: Database>: Send + Sync + 'static {
|
|||||||
) -> impl Future<Output = crate::Result<DB::Connection>> + Send + '_;
|
) -> impl Future<Output = crate::Result<DB::Connection>> + Send + '_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// # Note: Future Changes (FIXME)
|
||||||
|
/// This could theoretically be replaced with an impl over `AsyncFn` to allow lending closures,
|
||||||
|
/// except we have no way to put the `Send` bound on the returned future.
|
||||||
|
///
|
||||||
|
/// We need Return Type Notation for that: https://github.com/rust-lang/rust/pull/138424
|
||||||
impl<DB, F, Fut> PoolConnector<DB> for F
|
impl<DB, F, Fut> PoolConnector<DB> for F
|
||||||
where
|
where
|
||||||
DB: Database,
|
DB: Database,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ use std::task::ready;
|
|||||||
use crate::logger::private_level_filter_to_trace_level;
|
use crate::logger::private_level_filter_to_trace_level;
|
||||||
use crate::pool::connect::{ConnectPermit, ConnectionCounter, ConnectionId, DynConnector};
|
use crate::pool::connect::{ConnectPermit, ConnectionCounter, ConnectionId, DynConnector};
|
||||||
use crate::pool::idle::IdleQueue;
|
use crate::pool::idle::IdleQueue;
|
||||||
|
use crate::pool::shard::Sharded;
|
||||||
use crate::rt::JoinHandle;
|
use crate::rt::JoinHandle;
|
||||||
use crate::{private_tracing_dynamic_event, rt};
|
use crate::{private_tracing_dynamic_event, rt};
|
||||||
use either::Either;
|
use either::Either;
|
||||||
@@ -24,6 +25,7 @@ use tracing::Level;
|
|||||||
pub(crate) struct PoolInner<DB: Database> {
|
pub(crate) struct PoolInner<DB: Database> {
|
||||||
pub(super) connector: DynConnector<DB>,
|
pub(super) connector: DynConnector<DB>,
|
||||||
pub(super) counter: ConnectionCounter,
|
pub(super) counter: ConnectionCounter,
|
||||||
|
pub(super) sharded: Sharded<DB::Connection>,
|
||||||
pub(super) idle: IdleQueue<DB>,
|
pub(super) idle: IdleQueue<DB>,
|
||||||
is_closed: AtomicBool,
|
is_closed: AtomicBool,
|
||||||
pub(super) on_closed: event_listener::Event,
|
pub(super) on_closed: event_listener::Event,
|
||||||
@@ -40,6 +42,7 @@ impl<DB: Database> PoolInner<DB> {
|
|||||||
let pool = Self {
|
let pool = Self {
|
||||||
connector: DynConnector::new(connector),
|
connector: DynConnector::new(connector),
|
||||||
counter: ConnectionCounter::new(),
|
counter: ConnectionCounter::new(),
|
||||||
|
sharded: Sharded::new(options.max_connections, options.shards),
|
||||||
idle: IdleQueue::new(options.fair, options.max_connections),
|
idle: IdleQueue::new(options.fair, options.max_connections),
|
||||||
is_closed: AtomicBool::new(false),
|
is_closed: AtomicBool::new(false),
|
||||||
on_closed: event_listener::Event::new(),
|
on_closed: event_listener::Event::new(),
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ use crate::pool::{Pool, PoolConnector};
|
|||||||
use futures_core::future::BoxFuture;
|
use futures_core::future::BoxFuture;
|
||||||
use log::LevelFilter;
|
use log::LevelFilter;
|
||||||
use std::fmt::{self, Debug, Formatter};
|
use std::fmt::{self, Debug, Formatter};
|
||||||
|
use std::num::NonZero;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
@@ -68,6 +69,7 @@ pub struct PoolOptions<DB: Database> {
|
|||||||
>,
|
>,
|
||||||
>,
|
>,
|
||||||
pub(crate) max_connections: usize,
|
pub(crate) max_connections: usize,
|
||||||
|
pub(crate) shards: NonZero<usize>,
|
||||||
pub(crate) acquire_time_level: LevelFilter,
|
pub(crate) acquire_time_level: LevelFilter,
|
||||||
pub(crate) acquire_slow_level: LevelFilter,
|
pub(crate) acquire_slow_level: LevelFilter,
|
||||||
pub(crate) acquire_slow_threshold: Duration,
|
pub(crate) acquire_slow_threshold: Duration,
|
||||||
@@ -91,6 +93,7 @@ impl<DB: Database> Clone for PoolOptions<DB> {
|
|||||||
before_acquire: self.before_acquire.clone(),
|
before_acquire: self.before_acquire.clone(),
|
||||||
after_release: self.after_release.clone(),
|
after_release: self.after_release.clone(),
|
||||||
max_connections: self.max_connections,
|
max_connections: self.max_connections,
|
||||||
|
shards: self.shards,
|
||||||
acquire_time_level: self.acquire_time_level,
|
acquire_time_level: self.acquire_time_level,
|
||||||
acquire_slow_threshold: self.acquire_slow_threshold,
|
acquire_slow_threshold: self.acquire_slow_threshold,
|
||||||
acquire_slow_level: self.acquire_slow_level,
|
acquire_slow_level: self.acquire_slow_level,
|
||||||
@@ -143,6 +146,7 @@ impl<DB: Database> PoolOptions<DB> {
|
|||||||
// A production application will want to set a higher limit than this.
|
// A production application will want to set a higher limit than this.
|
||||||
max_connections: 10,
|
max_connections: 10,
|
||||||
min_connections: 0,
|
min_connections: 0,
|
||||||
|
shards: NonZero::<usize>::MIN,
|
||||||
// Logging all acquires is opt-in
|
// Logging all acquires is opt-in
|
||||||
acquire_time_level: LevelFilter::Off,
|
acquire_time_level: LevelFilter::Off,
|
||||||
// Default to warning, because an acquire timeout will be an error
|
// Default to warning, because an acquire timeout will be an error
|
||||||
@@ -206,6 +210,58 @@ impl<DB: Database> PoolOptions<DB> {
|
|||||||
self.min_connections
|
self.min_connections
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set the number of shards to split the internal structures into.
|
||||||
|
///
|
||||||
|
/// The default value is dynamically determined based on the configured number of worker threads
|
||||||
|
/// in the current runtime (if that information is available),
|
||||||
|
/// or [`std::thread::available_parallelism()`],
|
||||||
|
/// or 1 otherwise.
|
||||||
|
///
|
||||||
|
/// Each shard is assigned an equal share of [`max_connections`][Self::max_connections]
|
||||||
|
/// and its own queue of tasks waiting to acquire a connection.
|
||||||
|
///
|
||||||
|
/// Then, when accessing the pool, each thread selects a "local" shard based on its
|
||||||
|
/// [thread ID][std::thread::Thread::id]<sup>1</sup>.
|
||||||
|
///
|
||||||
|
/// If the number of shards equals the number of threads (which they do by default),
|
||||||
|
/// and worker threads are spawned sequentially (which they generally are),
|
||||||
|
/// each thread should access a different shard, which should significantly reduce
|
||||||
|
/// cache coherence overhead on multicore systems.
|
||||||
|
///
|
||||||
|
/// If the number of shards does not evenly divide `max_connections`,
|
||||||
|
/// the implementation makes a best-effort to distribute them as evenly as possible
|
||||||
|
/// (if `remainder = max_connections % shards` and `remainder != 0`,
|
||||||
|
/// then `remainder` shards will get one additional connection each).
|
||||||
|
///
|
||||||
|
/// The implementation then clamps the number of connections in a shard to the range `[1, 64]`.
|
||||||
|
///
|
||||||
|
/// ### Details
|
||||||
|
/// When a task calls [`Pool::acquire()`] (or any other method that calls `acquire()`),
|
||||||
|
/// it will first attempt to acquire a connection from its thread-local shard, or lock an empty
|
||||||
|
/// slot to open a new connection (acquiring an idle connection and opening a new connection
|
||||||
|
/// happen concurrently to minimize acquire time).
|
||||||
|
///
|
||||||
|
/// Failing that, it joins the wait list on the shard. Released connections are passed to
|
||||||
|
/// waiting tasks in a first-come, first-serve order per shard.
|
||||||
|
///
|
||||||
|
/// If the task cannot acquire a connection after a short delay,
|
||||||
|
/// it tries to acquire a connection from another shard.
|
||||||
|
///
|
||||||
|
/// If the task _still_ cannot acquire a connection after a longer delay,
|
||||||
|
/// it joins a global wait list. Tasks in the global wait list are the highest priority
|
||||||
|
/// for released connections, implementing a kind of eventual fairness.
|
||||||
|
///
|
||||||
|
/// <sup>1</sup> because, as of writing, [`std::thread::ThreadId::as_u64`] is unstable,
|
||||||
|
/// the current implementation assigns each thread its own sequential ID in a `thread_local!()`.
|
||||||
|
pub fn shards(mut self, shards: NonZero<usize>) -> Self {
|
||||||
|
self.shards = shards;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_shards(&self) -> usize {
|
||||||
|
self.shards.get()
|
||||||
|
}
|
||||||
|
|
||||||
/// Enable logging of time taken to acquire a connection from the connection pool via
|
/// Enable logging of time taken to acquire a connection from the connection pool via
|
||||||
/// [`Pool::acquire()`].
|
/// [`Pool::acquire()`].
|
||||||
///
|
///
|
||||||
@@ -572,3 +628,28 @@ impl<DB: Database> Debug for PoolOptions<DB> {
|
|||||||
.finish()
|
.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_shards() -> NonZero<usize> {
|
||||||
|
#[cfg(feature = "_rt-tokio")]
|
||||||
|
if let Ok(rt) = tokio::runtime::Handle::try_current() {
|
||||||
|
return rt
|
||||||
|
.metrics()
|
||||||
|
.num_workers()
|
||||||
|
.try_into()
|
||||||
|
.unwrap_or(NonZero::<usize>::MIN);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "_rt-async-std")]
|
||||||
|
if let Some(val) = std::env::var("ASYNC_STD_THREAD_COUNT")
|
||||||
|
.ok()
|
||||||
|
.and_then(|s| s.parse())
|
||||||
|
{
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(val) = std::thread::available_parallelism() {
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
|
||||||
|
NonZero::<usize>::MIN
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use event_listener::{Event, IntoNotification};
|
use event_listener::{Event, IntoNotification};
|
||||||
use parking_lot::Mutex;
|
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
|
use std::num::NonZero;
|
||||||
use std::pin::pin;
|
use std::pin::pin;
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@@ -8,6 +8,8 @@ use std::task::Poll;
|
|||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use std::{array, iter};
|
use std::{array, iter};
|
||||||
|
|
||||||
|
use spin::lock_api::Mutex;
|
||||||
|
|
||||||
type ShardId = usize;
|
type ShardId = usize;
|
||||||
type ConnectionIndex = usize;
|
type ConnectionIndex = usize;
|
||||||
|
|
||||||
@@ -15,7 +17,11 @@ type ConnectionIndex = usize;
|
|||||||
///
|
///
|
||||||
/// We want tasks to acquire from their local shards where possible, so they don't enter
|
/// We want tasks to acquire from their local shards where possible, so they don't enter
|
||||||
/// the global queue immediately.
|
/// the global queue immediately.
|
||||||
const GLOBAL_QUEUE_DELAY: Duration = Duration::from_millis(5);
|
const GLOBAL_QUEUE_DELAY: Duration = Duration::from_millis(10);
|
||||||
|
|
||||||
|
/// Delay before attempting to acquire from a non-local shard,
|
||||||
|
/// as well as the backoff when iterating through shards.
|
||||||
|
const NON_LOCAL_ACQUIRE_DELAY: Duration = Duration::from_micros(100);
|
||||||
|
|
||||||
pub struct Sharded<T> {
|
pub struct Sharded<T> {
|
||||||
shards: Box<[ArcShard<T>]>,
|
shards: Box<[ArcShard<T>]>,
|
||||||
@@ -29,11 +35,10 @@ struct Global<T> {
|
|||||||
disconnect_event: Event<LockGuard<T>>,
|
disconnect_event: Event<LockGuard<T>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
type ArcMutexGuard<T> = parking_lot::ArcMutexGuard<parking_lot::RawMutex, Option<T>>;
|
type ArcMutexGuard<T> = lock_api::ArcMutexGuard<spin::Mutex<()>, Option<T>>;
|
||||||
|
|
||||||
pub struct LockGuard<T> {
|
pub struct LockGuard<T> {
|
||||||
// `Option` allows us to drop the guard before sending the notification.
|
// `Option` allows us to take the guard in the drop handler.
|
||||||
// Otherwise, if the receiver wakes too quickly, it might fail to lock the mutex.
|
|
||||||
locked: Option<ArcMutexGuard<T>>,
|
locked: Option<ArcMutexGuard<T>>,
|
||||||
shard: ArcShard<T>,
|
shard: ArcShard<T>,
|
||||||
index: ConnectionIndex,
|
index: ConnectionIndex,
|
||||||
@@ -73,13 +78,13 @@ const MAX_SHARD_SIZE: usize = if usize::BITS > 64 {
|
|||||||
};
|
};
|
||||||
|
|
||||||
impl<T> Sharded<T> {
|
impl<T> Sharded<T> {
|
||||||
pub fn new(connections: usize, shards: usize) -> Sharded<T> {
|
pub fn new(connections: usize, shards: NonZero<usize>) -> Sharded<T> {
|
||||||
let global = Arc::new(Global {
|
let global = Arc::new(Global {
|
||||||
unlock_event: Event::with_tag(),
|
unlock_event: Event::with_tag(),
|
||||||
disconnect_event: Event::with_tag(),
|
disconnect_event: Event::with_tag(),
|
||||||
});
|
});
|
||||||
|
|
||||||
let shards = Params::calc(connections, shards)
|
let shards = Params::calc(connections, shards.get())
|
||||||
.shard_sizes()
|
.shard_sizes()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(shard_id, size)| Shard::new(shard_id, size, global.clone()))
|
.map(|(shard_id, size)| Shard::new(shard_id, size, global.clone()))
|
||||||
@@ -89,8 +94,28 @@ impl<T> Sharded<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn acquire(&self, connected: bool) -> LockGuard<T> {
|
pub async fn acquire(&self, connected: bool) -> LockGuard<T> {
|
||||||
let mut acquire_local =
|
if self.shards.len() == 1 {
|
||||||
pin!(self.shards[thread_id() % self.shards.len()].acquire(connected));
|
return self.shards[0].acquire(connected).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let thread_id = current_thread_id();
|
||||||
|
|
||||||
|
let mut acquire_local = pin!(self.shards[thread_id % self.shards.len()].acquire(connected));
|
||||||
|
|
||||||
|
let mut acquire_nonlocal = pin!(async {
|
||||||
|
let mut next_shard = thread_id;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
crate::rt::sleep(NON_LOCAL_ACQUIRE_DELAY).await;
|
||||||
|
|
||||||
|
// Choose shards pseudorandomly by multiplying with a (relatively) large prime.
|
||||||
|
next_shard = (next_shard.wrapping_mul(547)) % self.shards.len();
|
||||||
|
|
||||||
|
if let Some(locked) = self.shards[next_shard].try_acquire(connected) {
|
||||||
|
return locked;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
let mut acquire_global = pin!(async {
|
let mut acquire_global = pin!(async {
|
||||||
crate::rt::sleep(GLOBAL_QUEUE_DELAY).await;
|
crate::rt::sleep(GLOBAL_QUEUE_DELAY).await;
|
||||||
@@ -113,6 +138,10 @@ impl<T> Sharded<T> {
|
|||||||
return Poll::Ready(locked);
|
return Poll::Ready(locked);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Poll::Ready(locked) = acquire_nonlocal.as_mut().poll(cx) {
|
||||||
|
return Poll::Ready(locked);
|
||||||
|
}
|
||||||
|
|
||||||
if let Poll::Ready(locked) = acquire_global.as_mut().poll(cx) {
|
if let Poll::Ready(locked) = acquire_global.as_mut().poll(cx) {
|
||||||
return Poll::Ready(locked);
|
return Poll::Ready(locked);
|
||||||
}
|
}
|
||||||
@@ -125,6 +154,9 @@ impl<T> Sharded<T> {
|
|||||||
|
|
||||||
impl<T> Shard<T, [Arc<Mutex<Option<T>>>]> {
|
impl<T> Shard<T, [Arc<Mutex<Option<T>>>]> {
|
||||||
fn new(shard_id: ShardId, len: usize, global: Arc<Global<T>>) -> Arc<Self> {
|
fn new(shard_id: ShardId, len: usize, global: Arc<Global<T>>) -> Arc<Self> {
|
||||||
|
// There's no way to create DSTs like this, in `std::sync::Arc`, on stable.
|
||||||
|
//
|
||||||
|
// Instead, we coerce from an array.
|
||||||
macro_rules! make_array {
|
macro_rules! make_array {
|
||||||
($($n:literal),+) => {
|
($($n:literal),+) => {
|
||||||
match len {
|
match len {
|
||||||
@@ -206,6 +238,8 @@ impl<T> Shard<T, [Arc<Mutex<Option<T>>>]> {
|
|||||||
|
|
||||||
impl Params {
|
impl Params {
|
||||||
fn calc(connections: usize, mut shards: usize) -> Params {
|
fn calc(connections: usize, mut shards: usize) -> Params {
|
||||||
|
assert_ne!(shards, 0);
|
||||||
|
|
||||||
let mut shard_size = connections / shards;
|
let mut shard_size = connections / shards;
|
||||||
let mut remainder = connections % shards;
|
let mut remainder = connections % shards;
|
||||||
|
|
||||||
@@ -217,7 +251,11 @@ impl Params {
|
|||||||
} else if shard_size >= MAX_SHARD_SIZE {
|
} else if shard_size >= MAX_SHARD_SIZE {
|
||||||
let new_shards = connections.div_ceil(MAX_SHARD_SIZE);
|
let new_shards = connections.div_ceil(MAX_SHARD_SIZE);
|
||||||
|
|
||||||
tracing::debug!(connections, shards, "clamping shard count to {new_shards}");
|
tracing::debug!(
|
||||||
|
connections,
|
||||||
|
shards,
|
||||||
|
"shard size exceeds {MAX_SHARD_SIZE}, clamping shard count to {new_shards}"
|
||||||
|
);
|
||||||
|
|
||||||
shards = new_shards;
|
shards = new_shards;
|
||||||
shard_size = connections / shards;
|
shard_size = connections / shards;
|
||||||
@@ -239,7 +277,7 @@ impl Params {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn thread_id() -> usize {
|
fn current_thread_id() -> usize {
|
||||||
// FIXME: this can be replaced when this is stabilized:
|
// FIXME: this can be replaced when this is stabilized:
|
||||||
// https://doc.rust-lang.org/stable/std/thread/struct.ThreadId.html#method.as_u64
|
// https://doc.rust-lang.org/stable/std/thread/struct.ThreadId.html#method.as_u64
|
||||||
static THREAD_ID: AtomicUsize = AtomicUsize::new(0);
|
static THREAD_ID: AtomicUsize = AtomicUsize::new(0);
|
||||||
|
|||||||
Reference in New Issue
Block a user