WIP refactor: replace sharding with single connection set (5)

This commit is contained in:
Austin Bonander 2026-01-04 19:24:44 -08:00
parent 54e842376e
commit 824e27b506

View File

@ -2,7 +2,7 @@ use crate::ext::future::race;
use crate::rt;
use crate::sync::{AsyncMutex, AsyncMutexGuardArc};
use event_listener::{listener, Event, EventListener, IntoNotification};
use futures_core::Stream;
use futures_util::future::{Fuse, FusedFuture};
use futures_util::stream::FuturesUnordered;
use futures_util::{FutureExt, StreamExt};
use std::cmp;
@ -11,7 +11,7 @@ use std::ops::{Deref, DerefMut, RangeInclusive, RangeToInclusive};
use std::pin::{pin, Pin};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::Poll;
use std::task::{ready, Poll};
use std::time::Duration;
use tracing::Instrument;
@ -121,97 +121,84 @@ impl<C> ConnectionSet<C> {
}
async fn acquire_inner(&self, pref: AcquirePreference) -> SlotGuard<C> {
let span = tracing::trace_span!(
target: "sqlx::pool::connection_set",
"acquire_inner",
preferred_slot = tracing::field::Empty,
?pref,
);
if self.slots.len() == 1 {
span.record("alternate_slot", 0usize);
return self.slots[0].acquire(pref).instrument(span).await;
}
// Always try to lock the connection associated with our thread ID first
let preferred_slot = current_thread_id() % self.slots.len();
span.record("preferred_slot", preferred_slot);
// Always try to lock the connection associated with our thread ID
let mut acquire_preferred = pin!(self.slots[preferred_slot].acquire(pref));
// The number of tasks currently interested in this slot. Always at least 1.
let search_offset = Arc::strong_count(&self.slots[preferred_slot].connection);
let alternate_slot = (preferred_slot + 547usize.wrapping_mul(
Arc::strong_count(&self.slots[preferred_slot].connection)
)) % self.slots.len();
let mut acquire_alternate = pin!(self.slots[alternate_slot].acquire(pref));
let mut listen_global = pin!(self.global.listen(pref));
let mut yielded_1 = false;
let mut yielded_2 = false;
std::future::poll_fn(|cx| {
if let Poll::Ready(locked) = acquire_preferred.as_mut().poll(cx) {
return Poll::Ready(locked);
let acquire_global = pin!(async {
if let Some(locked) = self.try_acquire(pref, preferred_slot.wrapping_add(search_offset))
{
return locked;
}
if let Poll::Ready(locked) = acquire_alternate.as_mut().poll(cx) {
return Poll::Ready(locked);
}
loop {
let slot = self.global.listen(pref).await;
// if !yielded_1 {
// cx.waker().wake_by_ref();
// yielded_1 = true;
// return Poll::Pending;
// }
if let Poll::Ready(slot) = listen_global.as_mut().poll(cx) {
if let Some(locked) = self.slots[slot].try_acquire(pref) {
return Poll::Ready(locked);
if let Some(locked) = self.try_acquire(pref, slot) {
return locked;
}
listen_global.as_mut().set(self.global.listen(pref));
}
});
if !yielded_2 {
cx.waker().wake_by_ref();
yielded_2 = true;
return Poll::Pending;
let res = race(self.slots[preferred_slot].acquire(pref), acquire_global)
.instrument(span.clone())
.await;
let _span = span.enter();
match res {
Ok(preferred) => {
tracing::trace!("acquired from preferred_slot");
preferred
}
if let Some(locked) = self.try_acquire(pref) {
return Poll::Ready(locked);
Err(global) => {
tracing::trace!(slot = global.slot.index, "acquired from acquire_global");
global
}
Poll::Pending
})
.instrument(tracing::trace_span!(
target: "sqlx::pool::connection_set",
"acquire_inner",
preferred_slot,
?pref,
))
.await
}
}
pub fn try_acquire_connected(&self) -> Option<ConnectedSlot<C>> {
Some(
self.try_acquire(AcquirePreference::Connected)?
self.try_acquire(AcquirePreference::Connected, current_thread_id())?
.assert_connected(),
)
}
pub fn try_acquire_disconnected(&self) -> Option<DisconnectedSlot<C>> {
Some(
self.try_acquire(AcquirePreference::Disconnected)?
self.try_acquire(AcquirePreference::Disconnected, current_thread_id())?
.assert_disconnected(),
)
}
fn try_acquire(&self, pref: AcquirePreference) -> Option<SlotGuard<C>> {
let preferred_slot = current_thread_id() % self.slots.len();
fn try_acquire(&self, pref: AcquirePreference, starting_slot: usize) -> Option<SlotGuard<C>> {
let starting_slot = starting_slot % self.slots.len();
let (slots_before, slots_after) = self.slots.split_at(preferred_slot);
let (slots_before, slots_after) = self.global.locked_set.split_at(starting_slot);
let (preferred_slot, slots_after) = slots_after.split_first().unwrap();
if let Some(locked) = preferred_slot.try_acquire(pref) {
return Some(locked);
}
for slot in slots_before.iter().chain(slots_after).rev() {
if self.global.locked_set[slot.index].load(Ordering::Relaxed) {
for (index, locked) in slots_after.iter().chain(slots_before).enumerate() {
if locked.load(Ordering::Relaxed) {
continue;
}
if let Some(locked) = slot.try_acquire(pref) {
let slot = (starting_slot + index) % self.slots.len();
if let Some(locked) = self.slots[slot].try_acquire(pref) {
return Some(locked);
}
}
@ -363,6 +350,7 @@ impl<C> Slot<C> {
let locked = crate::sync::lock_arc(&self.connection).await;
self.locked.store(true, Ordering::Relaxed);
self.global.locked_set[self.index].store(true, Ordering::Relaxed);
SlotGuard {
slot: self.clone(),
@ -374,6 +362,7 @@ impl<C> Slot<C> {
let locked = crate::sync::try_lock_arc(&self.connection)?;
self.locked.store(true, Ordering::Relaxed);
self.global.locked_set[self.index].store(true, Ordering::Relaxed);
Some(SlotGuard {
slot: self.clone(),