add actions

This commit is contained in:
noah 2025-05-04 10:55:03 -05:00
parent 77eb27a9fb
commit 200cd93a9b
4 changed files with 182 additions and 75 deletions

View File

@ -462,10 +462,12 @@ impl Handle {
parent
.map(|parent| {
if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| {
parent.on_child_spawn(&mut OnChildTaskSpawnContext {
id,
_phantom: Default::default(),
})
parent
.on_child_spawn(&mut OnChildTaskSpawnContext {
id,
_phantom: Default::default(),
})
.hooks
})) {
r
} else {
@ -479,10 +481,12 @@ impl Handle {
.or_else(|| {
if let Some(hooks) = me.hooks_factory_ref() {
if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| {
hooks.on_top_level_spawn(&mut OnTopLevelTaskSpawnContext {
id,
_phantom: Default::default(),
})
hooks
.on_top_level_spawn(&mut OnTopLevelTaskSpawnContext {
id,
_phantom: Default::default(),
})
.hooks
})) {
r
} else {
@ -533,10 +537,12 @@ impl Handle {
parent
.map(|parent| {
if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| {
parent.on_child_spawn(&mut OnChildTaskSpawnContext {
id,
_phantom: Default::default(),
})
parent
.on_child_spawn(&mut OnChildTaskSpawnContext {
id,
_phantom: Default::default(),
})
.hooks
})) {
r
} else {
@ -550,10 +556,12 @@ impl Handle {
.or_else(|| {
if let Some(hooks) = me.hooks_factory_ref() {
if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| {
hooks.on_top_level_spawn(&mut OnTopLevelTaskSpawnContext {
id,
_phantom: Default::default(),
})
hooks
.on_top_level_spawn(&mut OnTopLevelTaskSpawnContext {
id,
_phantom: Default::default(),
})
.hooks
})) {
r
} else {

View File

@ -80,10 +80,12 @@ impl Handle {
parent
.map(|parent| {
if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| {
parent.on_child_spawn(&mut OnChildTaskSpawnContext {
id,
_phantom: Default::default(),
})
parent
.on_child_spawn(&mut OnChildTaskSpawnContext {
id,
_phantom: Default::default(),
})
.hooks
})) {
r
} else {
@ -97,10 +99,12 @@ impl Handle {
.or_else(|| {
if let Some(hooks) = me.hooks_factory_ref() {
if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| {
hooks.on_top_level_spawn(&mut OnTopLevelTaskSpawnContext {
id,
_phantom: Default::default(),
})
hooks
.on_top_level_spawn(&mut OnTopLevelTaskSpawnContext {
id,
_phantom: Default::default(),
})
.hooks
})) {
r
} else {

View File

@ -8,11 +8,9 @@ use std::sync::Arc;
/// spawned in "detached mode" via [`crate::task::spawn_with_hooks`], or which were spawned from outside the runtime or
/// from another context where no [`TaskHookHarness`] was present.
pub trait TaskHookHarnessFactory {
/// Create a new [`TaskHookHarness`] object which the runtime will attach to a given task.
fn on_top_level_spawn(
&self,
ctx: &mut OnTopLevelTaskSpawnContext<'_>,
) -> Option<Box<dyn TaskHookHarness + Send + Sync + 'static>>;
/// Runs a hook which may produce a new [`TaskHookHarness`] object which the runtime will attach to a given task.
fn on_top_level_spawn(&self, ctx: &mut OnTopLevelTaskSpawnContext<'_>)
-> OnTopLevelSpawnAction;
}
/// Trait for user-provided "harness" objects which are attached to tasks and provide hook
@ -20,24 +18,27 @@ pub trait TaskHookHarnessFactory {
#[allow(unused_variables)]
pub trait TaskHookHarness {
/// Pre-poll task hook which runs arbitrary user logic.
fn before_poll(&mut self, ctx: &mut BeforeTaskPollContext<'_>) {}
fn before_poll(&mut self, ctx: &mut BeforeTaskPollContext<'_>) -> BeforeTaskPollAction {
BeforeTaskPollAction::default()
}
/// Post-poll task hook which runs arbitrary user logic.
fn after_poll(&mut self, ctx: &mut AfterTaskPollContext<'_>) {}
fn after_poll(&mut self, ctx: &mut AfterTaskPollContext<'_>) -> AfterTaskPollAction {
AfterTaskPollAction::default()
}
/// Task hook which runs when this task spawns a child, unless that child is explicitly spawned
/// detached from the parent.
///
/// This hook creates a harness for the child, or detaches the child from any instrumentation.
fn on_child_spawn(
&mut self,
ctx: &mut OnChildTaskSpawnContext<'_>,
) -> Option<Box<dyn TaskHookHarness + Send + Sync + 'static>> {
None
fn on_child_spawn(&mut self, ctx: &mut OnChildTaskSpawnContext<'_>) -> OnChildSpawnAction {
OnChildSpawnAction::default()
}
/// Task hook which runs on task termination.
fn on_task_terminate(&mut self, ctx: &mut OnTaskTerminateContext<'_>) {}
fn on_task_terminate(&mut self, ctx: &mut OnTaskTerminateContext<'_>) -> OnTaskTerminateAction {
OnTaskTerminateAction::default()
}
}
pub(crate) type OptionalTaskHooksFactory =
@ -97,3 +98,59 @@ pub struct BeforeTaskPollContext<'a> {
pub struct AfterTaskPollContext<'a> {
pub(crate) _phantom: PhantomData<&'a ()>,
}
#[derive(Default)]
#[allow(missing_debug_implementations, missing_docs)]
#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
#[non_exhaustive]
pub struct OnTopLevelSpawnAction {
pub(crate) hooks: Option<Box<dyn TaskHookHarness + Send + Sync + 'static>>,
}
impl OnTopLevelSpawnAction {
/// Pass in a set of task hooks for the task.
pub fn set_hooks<T>(&mut self, hooks: T) -> &mut Self
where
T: TaskHookHarness + Send + Sync + 'static,
{
self.hooks = Some(Box::new(hooks));
self
}
}
#[derive(Default)]
#[allow(missing_debug_implementations, missing_docs)]
#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
#[non_exhaustive]
pub struct OnChildSpawnAction {
pub(crate) hooks: Option<Box<dyn TaskHookHarness + Send + Sync + 'static>>,
}
impl OnChildSpawnAction {
/// Pass in a set of task hooks for the child task.
pub fn set_hooks<T>(&mut self, hooks: T) -> &mut Self
where
T: TaskHookHarness + Send + Sync + 'static,
{
self.hooks = Some(Box::new(hooks));
self
}
}
#[derive(Default)]
#[allow(missing_debug_implementations, missing_docs)]
#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
#[non_exhaustive]
pub struct OnTaskTerminateAction {}
#[derive(Default)]
#[allow(missing_debug_implementations, missing_docs)]
#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
#[non_exhaustive]
pub struct BeforeTaskPollAction {}
#[derive(Default)]
#[allow(missing_debug_implementations, missing_docs)]
#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
#[non_exhaustive]
pub struct AfterTaskPollAction {}

View File

@ -9,8 +9,9 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::runtime;
use tokio::runtime::{
AfterTaskPollContext, BeforeTaskPollContext, OnChildTaskSpawnContext, OnTaskTerminateContext,
OnTopLevelTaskSpawnContext, TaskHookHarness, TaskHookHarnessFactory,
AfterTaskPollAction, AfterTaskPollContext, BeforeTaskPollAction, BeforeTaskPollContext,
OnChildSpawnAction, OnChildTaskSpawnContext, OnTaskTerminateAction, OnTaskTerminateContext,
OnTopLevelSpawnAction, OnTopLevelTaskSpawnContext, TaskHookHarness, TaskHookHarnessFactory,
};
#[test]
@ -83,9 +84,10 @@ fn run_runtime_default_factory(mut builder: runtime::Builder) {
fn on_top_level_spawn(
&self,
_ctx: &mut OnTopLevelTaskSpawnContext<'_>,
) -> Option<Box<dyn TaskHookHarness + Send + Sync + 'static>> {
) -> OnTopLevelSpawnAction {
self.counter.fetch_add(1, Ordering::SeqCst);
None
Default::default()
}
}
@ -138,25 +140,30 @@ fn run_parent_child_chaining(mut builder: runtime::Builder) {
fn on_top_level_spawn(
&self,
_ctx: &mut OnTopLevelTaskSpawnContext<'_>,
) -> Option<Box<dyn TaskHookHarness + Send + Sync + 'static>> {
) -> OnTopLevelSpawnAction {
self.parent_spawns.fetch_add(1, Ordering::SeqCst);
Some(Box::new(TestHooks {
let mut a = OnTopLevelSpawnAction::default();
a.set_hooks(TestHooks {
spawns: self.child_spawns.clone(),
}))
});
a
}
}
impl TaskHookHarness for TestHooks {
fn on_child_spawn(
&mut self,
_ctx: &mut OnChildTaskSpawnContext<'_>,
) -> Option<Box<dyn TaskHookHarness + Send + Sync + 'static>> {
fn on_child_spawn(&mut self, _ctx: &mut OnChildTaskSpawnContext<'_>) -> OnChildSpawnAction {
self.spawns.fetch_add(1, Ordering::SeqCst);
Some(Box::new(Self {
let mut a = OnChildSpawnAction::default();
a.set_hooks(Self {
spawns: self.spawns.clone(),
}))
});
a
}
}
@ -195,16 +202,22 @@ fn run_before_poll(mut builder: runtime::Builder) {
fn on_top_level_spawn(
&self,
_ctx: &mut OnTopLevelTaskSpawnContext<'_>,
) -> Option<Box<dyn TaskHookHarness + Send + Sync + 'static>> {
Some(Box::new(TestHooks {
) -> OnTopLevelSpawnAction {
let mut a = OnTopLevelSpawnAction::default();
a.set_hooks(TestHooks {
polls: self.polls.clone(),
}))
});
a
}
}
impl TaskHookHarness for TestHooks {
fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) {
fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) -> BeforeTaskPollAction {
self.polls.fetch_add(1, Ordering::SeqCst);
Default::default()
}
}
@ -240,16 +253,22 @@ fn run_after_poll(mut builder: runtime::Builder) {
fn on_top_level_spawn(
&self,
_ctx: &mut OnTopLevelTaskSpawnContext<'_>,
) -> Option<Box<dyn TaskHookHarness + Send + Sync + 'static>> {
Some(Box::new(TestHooks {
) -> OnTopLevelSpawnAction {
let mut a = OnTopLevelSpawnAction::default();
a.set_hooks(TestHooks {
polls: self.polls.clone(),
}))
});
a
}
}
impl TaskHookHarness for TestHooks {
fn after_poll(&mut self, _ctx: &mut AfterTaskPollContext<'_>) {
fn after_poll(&mut self, _ctx: &mut AfterTaskPollContext<'_>) -> AfterTaskPollAction {
self.polls.fetch_add(1, Ordering::SeqCst);
Default::default()
}
}
@ -285,16 +304,25 @@ fn run_terminate(mut builder: runtime::Builder) {
fn on_top_level_spawn(
&self,
_ctx: &mut OnTopLevelTaskSpawnContext<'_>,
) -> Option<Box<dyn TaskHookHarness + Send + Sync + 'static>> {
Some(Box::new(TestHooks {
) -> OnTopLevelSpawnAction {
let mut a = OnTopLevelSpawnAction::default();
a.set_hooks(TestHooks {
terminations: self.terminations.clone(),
}))
});
a
}
}
impl TaskHookHarness for TestHooks {
fn on_task_terminate(&mut self, _ctx: &mut OnTaskTerminateContext<'_>) {
fn on_task_terminate(
&mut self,
_ctx: &mut OnTaskTerminateContext<'_>,
) -> OnTaskTerminateAction {
self.terminations.fetch_add(1, Ordering::SeqCst);
Default::default()
}
}
@ -327,17 +355,23 @@ fn run_hook_switching(mut builder: runtime::Builder) {
fn on_top_level_spawn(
&self,
_ctx: &mut OnTopLevelTaskSpawnContext<'_>,
) -> Option<Box<dyn TaskHookHarness + Send + Sync + 'static>> {
Some(Box::new(TestHooks {
) -> OnTopLevelSpawnAction {
let mut a = OnTopLevelSpawnAction::default();
a.set_hooks(TestHooks {
id: self.next_id.fetch_add(1, Ordering::SeqCst),
flag: self.flag.clone(),
}))
});
a
}
}
impl TaskHookHarness for TestHooks {
fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) {
fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) -> BeforeTaskPollAction {
self.flag.store(self.id, Ordering::SeqCst);
Default::default()
}
}
@ -371,17 +405,20 @@ fn run_override(mut builder: runtime::Builder) {
}
impl TaskHookHarness for TestHooks {
fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) {
fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) -> BeforeTaskPollAction {
self.counter.fetch_add(1, Ordering::SeqCst);
Default::default()
}
fn on_child_spawn(
&mut self,
_ctx: &mut OnChildTaskSpawnContext<'_>,
) -> Option<Box<dyn TaskHookHarness + Send + Sync + 'static>> {
Some(Box::new(Self {
fn on_child_spawn(&mut self, _ctx: &mut OnChildTaskSpawnContext<'_>) -> OnChildSpawnAction {
let mut a = OnChildSpawnAction::default();
a.set_hooks(Self {
counter: self.counter.clone(),
}))
});
a
}
}
@ -389,9 +426,10 @@ fn run_override(mut builder: runtime::Builder) {
fn on_top_level_spawn(
&self,
_ctx: &mut OnTopLevelTaskSpawnContext<'_>,
) -> Option<Box<dyn TaskHookHarness + Send + Sync + 'static>> {
) -> OnTopLevelSpawnAction {
self.counter.fetch_add(1, Ordering::SeqCst);
None
Default::default()
}
}