chore: replace manual vtable definitions with Wake (#7342)

This commit is contained in:
Tim Vilgot Mikael Fredenberg 2025-05-27 19:28:21 +02:00 committed by GitHub
parent 98f527f42d
commit 4380de9fe9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 57 additions and 106 deletions

View File

@ -26,11 +26,10 @@
//! ```
use std::future::Future;
use std::mem;
use std::ops;
use std::pin::Pin;
use std::sync::{Arc, Condvar, Mutex};
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use std::task::{Context, Poll, Wake, Waker};
use tokio_stream::Stream;
@ -171,7 +170,7 @@ impl MockTask {
F: FnOnce(&mut Context<'_>) -> R,
{
self.waker.clear();
let waker = self.waker();
let waker = self.clone().into_waker();
let mut cx = Context::from_waker(&waker);
f(&mut cx)
@ -190,11 +189,8 @@ impl MockTask {
Arc::strong_count(&self.waker)
}
fn waker(&self) -> Waker {
unsafe {
let raw = to_raw(self.waker.clone());
Waker::from_raw(raw)
}
fn into_waker(self) -> Waker {
self.waker.into()
}
}
@ -226,8 +222,14 @@ impl ThreadWaker {
_ => unreachable!(),
}
}
}
fn wake(&self) {
impl Wake for ThreadWaker {
fn wake(self: Arc<Self>) {
self.wake_by_ref();
}
fn wake_by_ref(self: &Arc<Self>) {
// First, try transitioning from IDLE -> NOTIFY, this does not require a lock.
let mut state = self.state.lock().unwrap();
let prev = *state;
@ -247,39 +249,3 @@ impl ThreadWaker {
self.condvar.notify_one();
}
}
static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker);
unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker {
RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE)
}
unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> {
Arc::from_raw(raw as *const ThreadWaker)
}
unsafe fn clone(raw: *const ()) -> RawWaker {
let waker = from_raw(raw);
// Increment the ref count
mem::forget(waker.clone());
to_raw(waker)
}
unsafe fn wake(raw: *const ()) {
let waker = from_raw(raw);
waker.wake();
}
unsafe fn wake_by_ref(raw: *const ()) {
let waker = from_raw(raw);
waker.wake();
// We don't actually own a reference to the unparker
mem::forget(waker);
}
unsafe fn drop_waker(raw: *const ()) {
let _ = from_raw(raw);
}

View File

@ -2,6 +2,7 @@
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::{Arc, Condvar, Mutex};
use crate::util::{waker, Wake};
use std::sync::atomic::Ordering::SeqCst;
use std::time::Duration;
@ -226,7 +227,7 @@ use crate::loom::thread::AccessError;
use std::future::Future;
use std::marker::PhantomData;
use std::rc::Rc;
use std::task::{RawWaker, RawWakerVTable, Waker};
use std::task::Waker;
/// Blocks the current thread using a condition variable.
#[derive(Debug)]
@ -292,50 +293,20 @@ impl CachedParkThread {
impl UnparkThread {
pub(crate) fn into_waker(self) -> Waker {
unsafe {
let raw = unparker_to_raw_waker(self.inner);
Waker::from_raw(raw)
}
waker(self.inner)
}
}
impl Inner {
#[allow(clippy::wrong_self_convention)]
fn into_raw(this: Arc<Inner>) -> *const () {
Arc::into_raw(this) as *const ()
impl Wake for Inner {
fn wake(arc_self: Arc<Self>) {
arc_self.unpark();
}
unsafe fn from_raw(ptr: *const ()) -> Arc<Inner> {
Arc::from_raw(ptr as *const Inner)
fn wake_by_ref(arc_self: &Arc<Self>) {
arc_self.unpark();
}
}
unsafe fn unparker_to_raw_waker(unparker: Arc<Inner>) -> RawWaker {
RawWaker::new(
Inner::into_raw(unparker),
&RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker),
)
}
unsafe fn clone(raw: *const ()) -> RawWaker {
Arc::increment_strong_count(raw as *const Inner);
unparker_to_raw_waker(Inner::from_raw(raw))
}
unsafe fn drop_waker(raw: *const ()) {
drop(Inner::from_raw(raw));
}
unsafe fn wake(raw: *const ()) {
let unparker = Inner::from_raw(raw);
unparker.unpark();
}
unsafe fn wake_by_ref(raw: *const ()) {
let raw = raw as *const Inner;
(*raw).unpark();
}
#[cfg(loom)]
pub(crate) fn current_thread_park_count() -> usize {
CURRENT_THREAD_PARK_COUNT.with(|count| count.load(SeqCst))

View File

@ -16,6 +16,9 @@ pub(crate) use blocking_check::check_socket_for_blocking;
pub(crate) mod metric_atomics;
mod wake;
pub(crate) use wake::{waker, Wake};
#[cfg(any(
// io driver uses `WakeList` directly
feature = "net",
@ -66,9 +69,7 @@ cfg_rt! {
pub(crate) use self::rand::RngSeedGenerator;
mod wake;
pub(crate) use wake::WakerRef;
pub(crate) use wake::{waker_ref, Wake};
pub(crate) use wake::{waker_ref, WakerRef};
mod sync_wrapper;
pub(crate) use sync_wrapper::SyncWrapper;

View File

@ -1,8 +1,6 @@
use crate::loom::sync::Arc;
use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::ops::Deref;
use std::task::{RawWaker, RawWakerVTable, Waker};
/// Simplified waking interface based on Arcs.
@ -14,30 +12,45 @@ pub(crate) trait Wake: Send + Sync + Sized + 'static {
fn wake_by_ref(arc_self: &Arc<Self>);
}
/// A `Waker` that is only valid for a given lifetime.
#[derive(Debug)]
pub(crate) struct WakerRef<'a> {
waker: ManuallyDrop<Waker>,
_p: PhantomData<&'a ()>,
}
cfg_rt! {
use std::marker::PhantomData;
use std::ops::Deref;
impl Deref for WakerRef<'_> {
type Target = Waker;
/// A `Waker` that is only valid for a given lifetime.
#[derive(Debug)]
pub(crate) struct WakerRef<'a> {
waker: ManuallyDrop<Waker>,
_p: PhantomData<&'a ()>,
}
fn deref(&self) -> &Waker {
&self.waker
impl Deref for WakerRef<'_> {
type Target = Waker;
fn deref(&self) -> &Waker {
&self.waker
}
}
/// Creates a reference to a `Waker` from a reference to `Arc<impl Wake>`.
pub(crate) fn waker_ref<W: Wake>(wake: &Arc<W>) -> WakerRef<'_> {
let ptr = Arc::as_ptr(wake).cast::<()>();
let waker = unsafe { Waker::from_raw(RawWaker::new(ptr, waker_vtable::<W>())) };
WakerRef {
waker: ManuallyDrop::new(waker),
_p: PhantomData,
}
}
}
/// Creates a reference to a `Waker` from a reference to `Arc<impl Wake>`.
pub(crate) fn waker_ref<W: Wake>(wake: &Arc<W>) -> WakerRef<'_> {
let ptr = Arc::as_ptr(wake).cast::<()>();
let waker = unsafe { Waker::from_raw(RawWaker::new(ptr, waker_vtable::<W>())) };
WakerRef {
waker: ManuallyDrop::new(waker),
_p: PhantomData,
/// Creates a waker from a `Arc<impl Wake>`.
pub(crate) fn waker<W: Wake>(wake: Arc<W>) -> Waker {
unsafe {
Waker::from_raw(RawWaker::new(
Arc::into_raw(wake).cast(),
waker_vtable::<W>(),
))
}
}