WIP refactor: replace sharding with single connection set (6)

This commit is contained in:
Austin Bonander 2026-01-23 13:58:53 -08:00
parent 824e27b506
commit cd8e35546e

View File

@ -16,7 +16,7 @@ use std::time::Duration;
use tracing::Instrument;
pub struct ConnectionSet<C> {
global: Arc<Global>,
global: Arc<Global<C>>,
slots: Box<[Arc<Slot<C>>]>,
}
@ -31,9 +31,9 @@ enum AcquirePreference {
Either,
}
struct Global {
unlock_event: Event<usize>,
disconnect_event: Event<usize>,
struct Global<C> {
unlock_event: Event<ReleaseWithoutNotify<C>>,
disconnect_event: Event<ReleaseWithoutNotify<C>>,
locked_set: Box<[AtomicBool]>,
num_connected: AtomicUsize,
min_connections: usize,
@ -46,10 +46,12 @@ struct SlotGuard<C> {
locked: Option<AsyncMutexGuardArc<Option<C>>>,
}
struct ReleaseWithoutNotify<C>(SlotGuard<C>);
struct Slot<C> {
// By having each `Slot` hold its own reference to `Global`, we can avoid extra contended clones
// which would sap performance
global: Arc<Global>,
global: Arc<Global<C>>,
index: usize,
// I'd love to eliminate this redundant `Arc` but it's likely not possible without `unsafe`
connection: Arc<AsyncMutex<Option<C>>>,
@ -129,40 +131,49 @@ impl<C> ConnectionSet<C> {
);
if self.slots.len() == 1 {
span.record("alternate_slot", 0usize);
span.record("preferred_slot", 0usize);
return self.slots[0].acquire(pref).instrument(span).await;
}
// Always try to lock the connection associated with our thread ID first
let preferred_slot = current_thread_id() % self.slots.len();
let preferred_slot = self.choose_preferred_slot();
span.record("preferred_slot", preferred_slot);
// The number of tasks currently interested in this slot. Always at least 1.
let search_offset = Arc::strong_count(&self.slots[preferred_slot].connection);
let acquire_preferred = self.slots[preferred_slot].acquire(pref);
let acquire_global = async {
// Yielding actually improves performance here.
rt::yield_now().await;
// Since we know `preferred_slot` is locked, we offset our search by the number
// of tasks interested in this slot, which is always at least 1.
let search_offset = Arc::strong_count(&self.slots[preferred_slot]);
let acquire_global = pin!(async {
if let Some(locked) = self.try_acquire(pref, preferred_slot.wrapping_add(search_offset))
{
tracing::trace!(
search_offset,
slot = locked.slot.index,
"acquired from try_acquire"
);
return locked;
}
loop {
let slot = self.global.listen(pref).await;
// Since `acquire_global` is fair, we wait
//rt::sleep(Duration::from_millis(50)).await;
if let Some(locked) = self.try_acquire(pref, slot) {
return locked;
}
}
});
rt::yield_now().await;
let res = race(self.slots[preferred_slot].acquire(pref), acquire_global)
self.global.listen(pref).await
};
let res = race(acquire_preferred, acquire_global)
.instrument(span.clone())
.await;
let _span = span.enter();
match res {
Ok(preferred) => {
tracing::trace!("acquired from preferred_slot");
tracing::trace!(slot = preferred_slot, "acquired from acquire_preferred");
preferred
}
Err(global) => {
@ -186,6 +197,26 @@ impl<C> ConnectionSet<C> {
)
}
/// Find a non-leaked slot starting with the one associated with [`current_thread_id()`].
fn choose_preferred_slot(&self) -> usize {
// Always try to lock the connection associated with our thread ID first
let starting_slot = current_thread_id() % self.slots.len();
let search_slots = (starting_slot..self.slots.len()).chain(0..starting_slot);
for slot in search_slots {
if !self.slots[slot].is_leaked() {
return slot;
}
}
tracing::warn!(
num_slots = self.slots.len(),
"all slots have been leaked! all acquires will time out"
);
starting_slot
}
fn try_acquire(&self, pref: AcquirePreference, starting_slot: usize) -> Option<SlotGuard<C>> {
let starting_slot = starting_slot % self.slots.len();
@ -255,6 +286,46 @@ impl<C> ConnectionSet<C> {
}
}
const EXPECT_LOCKED: &str = "BUG: `SlotGuard::locked` should not be `None` in normal operation";
const EXPECT_CONNECTED: &str = "BUG: `ConnectedSlot` expects `Slot::connection` to be `Some`";
impl<C> ConnectedSlot<C> {
pub fn take(mut self) -> (C, DisconnectedSlot<C>) {
let conn = self.0.get_mut().take().expect(EXPECT_CONNECTED);
(conn, self.0.assert_disconnected())
}
}
impl<C> Deref for ConnectedSlot<C> {
type Target = C;
#[inline(always)]
fn deref(&self) -> &Self::Target {
self.0.get().as_ref().expect(EXPECT_CONNECTED)
}
}
impl<C> DerefMut for ConnectedSlot<C> {
#[inline(always)]
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.get_mut().as_mut().expect(EXPECT_CONNECTED)
}
}
impl<C> DisconnectedSlot<C> {
pub fn put(mut self, conn: C) -> ConnectedSlot<C> {
*self.0.get_mut() = Some(conn);
ConnectedSlot(self.0)
}
pub fn leak(mut self) {
self.0.slot.connected.store(false, Ordering::Relaxed);
self.0.slot.leaked.store(true, Ordering::Release);
// Drop the guard without marking the connection as unlocked
self.0.locked = None;
}
}
impl AcquirePreference {
#[inline(always)]
fn wants_connected(&self, is_connected: bool) -> bool {
@ -411,71 +482,35 @@ impl<C> SlotGuard<C> {
DisconnectedSlot(self)
}
/// Updates `Slot::connected` without notifying the `ConnectionSet`.
///
/// Returns `Some(connected)` or `None` if this guard was already dropped.
fn drop_without_notify(&mut self) -> Option<bool> {
fn release_without_notify(&mut self) -> Option<ReleaseWithoutNotify<C>> {
self.locked.take().map(|locked| {
let connected = locked.is_some();
self.slot.set_is_connected(connected);
self.slot.locked.store(false, Ordering::Release);
self.slot.global.locked_set[self.slot.index].store(false, Ordering::Relaxed);
connected
ReleaseWithoutNotify(SlotGuard {
slot: self.slot.clone(),
locked: Some(locked),
})
})
}
}
const EXPECT_LOCKED: &str = "BUG: `SlotGuard::locked` should not be `None` in normal operation";
const EXPECT_CONNECTED: &str = "BUG: `ConnectedSlot` expects `Slot::connection` to be `Some`";
impl<C> ConnectedSlot<C> {
pub fn take(mut self) -> (C, DisconnectedSlot<C>) {
let conn = self.0.get_mut().take().expect(EXPECT_CONNECTED);
(conn, self.0.assert_disconnected())
}
}
impl<C> Deref for ConnectedSlot<C> {
type Target = C;
#[inline(always)]
fn deref(&self) -> &Self::Target {
self.0.get().as_ref().expect(EXPECT_CONNECTED)
}
}
impl<C> DerefMut for ConnectedSlot<C> {
#[inline(always)]
fn deref_mut(&mut self) -> &mut Self::Target {
self.0.get_mut().as_mut().expect(EXPECT_CONNECTED)
}
}
impl<C> DisconnectedSlot<C> {
pub fn put(mut self, conn: C) -> ConnectedSlot<C> {
*self.0.get_mut() = Some(conn);
ConnectedSlot(self.0)
}
pub fn leak(mut self) {
self.0.slot.leaked.store(true, Ordering::Release);
self.0.drop_without_notify();
}
}
impl<C> Drop for SlotGuard<C> {
fn drop(&mut self) {
let Some(connected) = self.drop_without_notify() else {
let Some(mut guard) = self.release_without_notify() else {
return;
};
let event = if connected {
let connected = guard.is_connected();
let event = if guard.is_connected() {
&self.slot.global.unlock_event
} else {
&self.slot.global.disconnect_event
};
if event.notify(1.tag(self.slot.index).additional()) != 0 {
if event.notify(
1.tag_with(|| ReleaseWithoutNotify(guard.take()))
.additional(),
) != 0
{
return;
}
@ -494,13 +529,47 @@ impl<C> Drop for SlotGuard<C> {
}
}
impl Global {
impl<C> ReleaseWithoutNotify<C> {
fn take(&mut self) -> SlotGuard<C> {
SlotGuard {
slot: self.0.slot.clone(),
locked: Some(
self.0
.locked
.take()
.expect("BUG: `SlotGuard.locked` should not be `None` here"),
),
}
}
fn is_connected(&self) -> bool {
self.0
.locked
.as_ref()
.expect("BUG: `SlotGuard.locked` should not be `None` here")
.is_some()
}
}
impl<C> Drop for ReleaseWithoutNotify<C> {
fn drop(&mut self) {
let Some(locked) = self.0.locked.take() else {
return;
};
self.0.slot.set_is_connected(locked.is_some());
self.0.slot.locked.store(false, Ordering::Release);
self.0.slot.global.locked_set[self.0.slot.index].store(false, Ordering::Relaxed);
}
}
impl<C> Global<C> {
#[inline(always)]
fn num_connected(&self) -> usize {
self.num_connected.load(Ordering::Relaxed)
}
async fn listen(&self, pref: AcquirePreference) -> usize {
async fn listen(&self, pref: AcquirePreference) -> SlotGuard<C> {
match pref {
AcquirePreference::Either => race(self.listen_unlocked(), self.listen_disconnected())
.await
@ -510,14 +579,14 @@ impl Global {
}
}
async fn listen_unlocked(&self) -> usize {
async fn listen_unlocked(&self) -> SlotGuard<C> {
listener!(self.unlock_event => listener);
listener.await
listener.await.take()
}
async fn listen_disconnected(&self) -> usize {
async fn listen_disconnected(&self) -> SlotGuard<C> {
listener!(self.disconnect_event => listener);
listener.await
listener.await.take()
}
}