From 200cd93a9b2df51ec996bddbe0cdde2cd745cb72 Mon Sep 17 00:00:00 2001 From: noah Date: Sun, 4 May 2025 10:55:03 -0500 Subject: [PATCH] add actions --- .../runtime/scheduler/current_thread/mod.rs | 40 +++--- .../runtime/scheduler/multi_thread/handle.rs | 20 +-- tokio/src/runtime/task_hooks/mod.rs | 83 +++++++++++-- tokio/tests/task_hooks.rs | 114 ++++++++++++------ 4 files changed, 182 insertions(+), 75 deletions(-) diff --git a/tokio/src/runtime/scheduler/current_thread/mod.rs b/tokio/src/runtime/scheduler/current_thread/mod.rs index a42528ba7..0f018ada1 100644 --- a/tokio/src/runtime/scheduler/current_thread/mod.rs +++ b/tokio/src/runtime/scheduler/current_thread/mod.rs @@ -462,10 +462,12 @@ impl Handle { parent .map(|parent| { if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { - parent.on_child_spawn(&mut OnChildTaskSpawnContext { - id, - _phantom: Default::default(), - }) + parent + .on_child_spawn(&mut OnChildTaskSpawnContext { + id, + _phantom: Default::default(), + }) + .hooks })) { r } else { @@ -479,10 +481,12 @@ impl Handle { .or_else(|| { if let Some(hooks) = me.hooks_factory_ref() { if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { - hooks.on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { - id, - _phantom: Default::default(), - }) + hooks + .on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { + id, + _phantom: Default::default(), + }) + .hooks })) { r } else { @@ -533,10 +537,12 @@ impl Handle { parent .map(|parent| { if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { - parent.on_child_spawn(&mut OnChildTaskSpawnContext { - id, - _phantom: Default::default(), - }) + parent + .on_child_spawn(&mut OnChildTaskSpawnContext { + id, + _phantom: Default::default(), + }) + .hooks })) { r } else { @@ -550,10 +556,12 @@ impl Handle { .or_else(|| { if let Some(hooks) = me.hooks_factory_ref() { if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { - hooks.on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { - id, - _phantom: Default::default(), - }) + hooks + .on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { + id, + _phantom: Default::default(), + }) + .hooks })) { r } else { diff --git a/tokio/src/runtime/scheduler/multi_thread/handle.rs b/tokio/src/runtime/scheduler/multi_thread/handle.rs index fa8973a8f..030910d30 100644 --- a/tokio/src/runtime/scheduler/multi_thread/handle.rs +++ b/tokio/src/runtime/scheduler/multi_thread/handle.rs @@ -80,10 +80,12 @@ impl Handle { parent .map(|parent| { if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { - parent.on_child_spawn(&mut OnChildTaskSpawnContext { - id, - _phantom: Default::default(), - }) + parent + .on_child_spawn(&mut OnChildTaskSpawnContext { + id, + _phantom: Default::default(), + }) + .hooks })) { r } else { @@ -97,10 +99,12 @@ impl Handle { .or_else(|| { if let Some(hooks) = me.hooks_factory_ref() { if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { - hooks.on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { - id, - _phantom: Default::default(), - }) + hooks + .on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { + id, + _phantom: Default::default(), + }) + .hooks })) { r } else { diff --git a/tokio/src/runtime/task_hooks/mod.rs b/tokio/src/runtime/task_hooks/mod.rs index c28a0b9dd..51b70cf5a 100644 --- a/tokio/src/runtime/task_hooks/mod.rs +++ b/tokio/src/runtime/task_hooks/mod.rs @@ -8,11 +8,9 @@ use std::sync::Arc; /// spawned in "detached mode" via [`crate::task::spawn_with_hooks`], or which were spawned from outside the runtime or /// from another context where no [`TaskHookHarness`] was present. pub trait TaskHookHarnessFactory { - /// Create a new [`TaskHookHarness`] object which the runtime will attach to a given task. - fn on_top_level_spawn( - &self, - ctx: &mut OnTopLevelTaskSpawnContext<'_>, - ) -> Option>; + /// Runs a hook which may produce a new [`TaskHookHarness`] object which the runtime will attach to a given task. + fn on_top_level_spawn(&self, ctx: &mut OnTopLevelTaskSpawnContext<'_>) + -> OnTopLevelSpawnAction; } /// Trait for user-provided "harness" objects which are attached to tasks and provide hook @@ -20,24 +18,27 @@ pub trait TaskHookHarnessFactory { #[allow(unused_variables)] pub trait TaskHookHarness { /// Pre-poll task hook which runs arbitrary user logic. - fn before_poll(&mut self, ctx: &mut BeforeTaskPollContext<'_>) {} + fn before_poll(&mut self, ctx: &mut BeforeTaskPollContext<'_>) -> BeforeTaskPollAction { + BeforeTaskPollAction::default() + } /// Post-poll task hook which runs arbitrary user logic. - fn after_poll(&mut self, ctx: &mut AfterTaskPollContext<'_>) {} + fn after_poll(&mut self, ctx: &mut AfterTaskPollContext<'_>) -> AfterTaskPollAction { + AfterTaskPollAction::default() + } /// Task hook which runs when this task spawns a child, unless that child is explicitly spawned /// detached from the parent. /// /// This hook creates a harness for the child, or detaches the child from any instrumentation. - fn on_child_spawn( - &mut self, - ctx: &mut OnChildTaskSpawnContext<'_>, - ) -> Option> { - None + fn on_child_spawn(&mut self, ctx: &mut OnChildTaskSpawnContext<'_>) -> OnChildSpawnAction { + OnChildSpawnAction::default() } /// Task hook which runs on task termination. - fn on_task_terminate(&mut self, ctx: &mut OnTaskTerminateContext<'_>) {} + fn on_task_terminate(&mut self, ctx: &mut OnTaskTerminateContext<'_>) -> OnTaskTerminateAction { + OnTaskTerminateAction::default() + } } pub(crate) type OptionalTaskHooksFactory = @@ -97,3 +98,59 @@ pub struct BeforeTaskPollContext<'a> { pub struct AfterTaskPollContext<'a> { pub(crate) _phantom: PhantomData<&'a ()>, } + +#[derive(Default)] +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[non_exhaustive] +pub struct OnTopLevelSpawnAction { + pub(crate) hooks: Option>, +} + +impl OnTopLevelSpawnAction { + /// Pass in a set of task hooks for the task. + pub fn set_hooks(&mut self, hooks: T) -> &mut Self + where + T: TaskHookHarness + Send + Sync + 'static, + { + self.hooks = Some(Box::new(hooks)); + self + } +} + +#[derive(Default)] +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[non_exhaustive] +pub struct OnChildSpawnAction { + pub(crate) hooks: Option>, +} + +impl OnChildSpawnAction { + /// Pass in a set of task hooks for the child task. + pub fn set_hooks(&mut self, hooks: T) -> &mut Self + where + T: TaskHookHarness + Send + Sync + 'static, + { + self.hooks = Some(Box::new(hooks)); + self + } +} + +#[derive(Default)] +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[non_exhaustive] +pub struct OnTaskTerminateAction {} + +#[derive(Default)] +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[non_exhaustive] +pub struct BeforeTaskPollAction {} + +#[derive(Default)] +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +#[non_exhaustive] +pub struct AfterTaskPollAction {} diff --git a/tokio/tests/task_hooks.rs b/tokio/tests/task_hooks.rs index 2af3cc594..e2127388e 100644 --- a/tokio/tests/task_hooks.rs +++ b/tokio/tests/task_hooks.rs @@ -9,8 +9,9 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use tokio::runtime; use tokio::runtime::{ - AfterTaskPollContext, BeforeTaskPollContext, OnChildTaskSpawnContext, OnTaskTerminateContext, - OnTopLevelTaskSpawnContext, TaskHookHarness, TaskHookHarnessFactory, + AfterTaskPollAction, AfterTaskPollContext, BeforeTaskPollAction, BeforeTaskPollContext, + OnChildSpawnAction, OnChildTaskSpawnContext, OnTaskTerminateAction, OnTaskTerminateContext, + OnTopLevelSpawnAction, OnTopLevelTaskSpawnContext, TaskHookHarness, TaskHookHarnessFactory, }; #[test] @@ -83,9 +84,10 @@ fn run_runtime_default_factory(mut builder: runtime::Builder) { fn on_top_level_spawn( &self, _ctx: &mut OnTopLevelTaskSpawnContext<'_>, - ) -> Option> { + ) -> OnTopLevelSpawnAction { self.counter.fetch_add(1, Ordering::SeqCst); - None + + Default::default() } } @@ -138,25 +140,30 @@ fn run_parent_child_chaining(mut builder: runtime::Builder) { fn on_top_level_spawn( &self, _ctx: &mut OnTopLevelTaskSpawnContext<'_>, - ) -> Option> { + ) -> OnTopLevelSpawnAction { self.parent_spawns.fetch_add(1, Ordering::SeqCst); - Some(Box::new(TestHooks { + let mut a = OnTopLevelSpawnAction::default(); + + a.set_hooks(TestHooks { spawns: self.child_spawns.clone(), - })) + }); + + a } } impl TaskHookHarness for TestHooks { - fn on_child_spawn( - &mut self, - _ctx: &mut OnChildTaskSpawnContext<'_>, - ) -> Option> { + fn on_child_spawn(&mut self, _ctx: &mut OnChildTaskSpawnContext<'_>) -> OnChildSpawnAction { self.spawns.fetch_add(1, Ordering::SeqCst); - Some(Box::new(Self { + let mut a = OnChildSpawnAction::default(); + + a.set_hooks(Self { spawns: self.spawns.clone(), - })) + }); + + a } } @@ -195,16 +202,22 @@ fn run_before_poll(mut builder: runtime::Builder) { fn on_top_level_spawn( &self, _ctx: &mut OnTopLevelTaskSpawnContext<'_>, - ) -> Option> { - Some(Box::new(TestHooks { + ) -> OnTopLevelSpawnAction { + let mut a = OnTopLevelSpawnAction::default(); + + a.set_hooks(TestHooks { polls: self.polls.clone(), - })) + }); + + a } } impl TaskHookHarness for TestHooks { - fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) { + fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) -> BeforeTaskPollAction { self.polls.fetch_add(1, Ordering::SeqCst); + + Default::default() } } @@ -240,16 +253,22 @@ fn run_after_poll(mut builder: runtime::Builder) { fn on_top_level_spawn( &self, _ctx: &mut OnTopLevelTaskSpawnContext<'_>, - ) -> Option> { - Some(Box::new(TestHooks { + ) -> OnTopLevelSpawnAction { + let mut a = OnTopLevelSpawnAction::default(); + + a.set_hooks(TestHooks { polls: self.polls.clone(), - })) + }); + + a } } impl TaskHookHarness for TestHooks { - fn after_poll(&mut self, _ctx: &mut AfterTaskPollContext<'_>) { + fn after_poll(&mut self, _ctx: &mut AfterTaskPollContext<'_>) -> AfterTaskPollAction { self.polls.fetch_add(1, Ordering::SeqCst); + + Default::default() } } @@ -285,16 +304,25 @@ fn run_terminate(mut builder: runtime::Builder) { fn on_top_level_spawn( &self, _ctx: &mut OnTopLevelTaskSpawnContext<'_>, - ) -> Option> { - Some(Box::new(TestHooks { + ) -> OnTopLevelSpawnAction { + let mut a = OnTopLevelSpawnAction::default(); + + a.set_hooks(TestHooks { terminations: self.terminations.clone(), - })) + }); + + a } } impl TaskHookHarness for TestHooks { - fn on_task_terminate(&mut self, _ctx: &mut OnTaskTerminateContext<'_>) { + fn on_task_terminate( + &mut self, + _ctx: &mut OnTaskTerminateContext<'_>, + ) -> OnTaskTerminateAction { self.terminations.fetch_add(1, Ordering::SeqCst); + + Default::default() } } @@ -327,17 +355,23 @@ fn run_hook_switching(mut builder: runtime::Builder) { fn on_top_level_spawn( &self, _ctx: &mut OnTopLevelTaskSpawnContext<'_>, - ) -> Option> { - Some(Box::new(TestHooks { + ) -> OnTopLevelSpawnAction { + let mut a = OnTopLevelSpawnAction::default(); + + a.set_hooks(TestHooks { id: self.next_id.fetch_add(1, Ordering::SeqCst), flag: self.flag.clone(), - })) + }); + + a } } impl TaskHookHarness for TestHooks { - fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) { + fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) -> BeforeTaskPollAction { self.flag.store(self.id, Ordering::SeqCst); + + Default::default() } } @@ -371,17 +405,20 @@ fn run_override(mut builder: runtime::Builder) { } impl TaskHookHarness for TestHooks { - fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) { + fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) -> BeforeTaskPollAction { self.counter.fetch_add(1, Ordering::SeqCst); + + Default::default() } - fn on_child_spawn( - &mut self, - _ctx: &mut OnChildTaskSpawnContext<'_>, - ) -> Option> { - Some(Box::new(Self { + fn on_child_spawn(&mut self, _ctx: &mut OnChildTaskSpawnContext<'_>) -> OnChildSpawnAction { + let mut a = OnChildSpawnAction::default(); + + a.set_hooks(Self { counter: self.counter.clone(), - })) + }); + + a } } @@ -389,9 +426,10 @@ fn run_override(mut builder: runtime::Builder) { fn on_top_level_spawn( &self, _ctx: &mut OnTopLevelTaskSpawnContext<'_>, - ) -> Option> { + ) -> OnTopLevelSpawnAction { self.counter.fetch_add(1, Ordering::SeqCst); - None + + Default::default() } }