task: disallow blocking in LocalSet::{poll,drop} (#7372)

This commit is contained in:
Alice Ryhl 2025-06-08 20:56:55 +02:00 committed by GitHub
parent 38d88c6799
commit 8259133ca0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 118 additions and 22 deletions

View File

@ -32,7 +32,7 @@ pub(crate) fn try_enter_blocking_region() -> Option<BlockingRegionGuard> {
/// Disallows blocking in the current runtime context until the guard is dropped. /// Disallows blocking in the current runtime context until the guard is dropped.
pub(crate) fn disallow_block_in_place() -> DisallowBlockInPlaceGuard { pub(crate) fn disallow_block_in_place() -> DisallowBlockInPlaceGuard {
let reset = CONTEXT.with(|c| { let reset = CONTEXT.try_with(|c| {
if let EnterRuntime::Entered { if let EnterRuntime::Entered {
allow_block_in_place: true, allow_block_in_place: true,
} = c.runtime.get() } = c.runtime.get()
@ -46,7 +46,7 @@ pub(crate) fn disallow_block_in_place() -> DisallowBlockInPlaceGuard {
} }
}); });
DisallowBlockInPlaceGuard(reset) DisallowBlockInPlaceGuard(reset.unwrap_or(false))
} }
impl BlockingRegionGuard { impl BlockingRegionGuard {

View File

@ -916,6 +916,8 @@ impl Future for LocalSet {
type Output = (); type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let _no_blocking = crate::runtime::context::disallow_block_in_place();
// Register the waker before starting to work // Register the waker before starting to work
self.context.shared.waker.register_by_ref(cx.waker()); self.context.shared.waker.register_by_ref(cx.waker());
@ -948,6 +950,8 @@ impl Default for LocalSet {
impl Drop for LocalSet { impl Drop for LocalSet {
fn drop(&mut self) { fn drop(&mut self) {
self.with_if_possible(|| { self.with_if_possible(|| {
let _no_blocking = crate::runtime::context::disallow_block_in_place();
// Shut down all tasks in the LocalOwnedTasks and close it to // Shut down all tasks in the LocalOwnedTasks and close it to
// prevent new tasks from ever being added. // prevent new tasks from ever being added.
unsafe { unsafe {

View File

@ -151,31 +151,123 @@ fn enter_guard_spawn() {
}); });
} }
#[cfg(not(target_os = "wasi"))] // Wasi doesn't support panic recovery #[cfg(not(target_os = "wasi"))]
#[test] mod block_in_place_cases {
// This will panic, since the thread that calls `block_on` cannot use use super::*;
// in-place blocking inside of `block_on`. use std::future::Future;
#[should_panic] use std::pin::Pin;
fn local_threadpool_blocking_in_place() { use std::task::{Context, Poll};
thread_local! {
static ON_RT_THREAD: Cell<bool> = const { Cell::new(false) }; struct BlockInPlaceOnDrop;
impl Future for BlockInPlaceOnDrop {
type Output = ();
fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
Poll::Pending
}
}
impl Drop for BlockInPlaceOnDrop {
fn drop(&mut self) {
tokio::task::block_in_place(|| {});
}
} }
ON_RT_THREAD.with(|cell| cell.set(true)); async fn complete(jh: tokio::task::JoinHandle<()>) {
match jh.await {
Ok(()) => {}
Err(err) if err.is_panic() => std::panic::resume_unwind(err.into_panic()),
Err(err) if err.is_cancelled() => panic!("task cancelled"),
Err(err) => panic!("{:?}", err),
}
}
let rt = runtime::Builder::new_current_thread() #[test]
.enable_all() #[should_panic = "can call blocking only when running on the multi-threaded runtime"]
.build() fn local_threadpool_blocking_in_place() {
.unwrap(); thread_local! {
LocalSet::new().block_on(&rt, async { static ON_RT_THREAD: Cell<bool> = const { Cell::new(false) };
assert!(ON_RT_THREAD.with(|cell| cell.get())); }
let join = task::spawn_local(async move {
assert!(ON_RT_THREAD.with(|cell| cell.get())); ON_RT_THREAD.with(|cell| cell.set(true));
task::block_in_place(|| {});
let rt = runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
LocalSet::new().block_on(&rt, async {
assert!(ON_RT_THREAD.with(|cell| cell.get())); assert!(ON_RT_THREAD.with(|cell| cell.get()));
let join = task::spawn_local(async move {
assert!(ON_RT_THREAD.with(|cell| cell.get()));
task::block_in_place(|| {});
assert!(ON_RT_THREAD.with(|cell| cell.get()));
});
complete(join).await;
}); });
join.await.unwrap(); }
});
#[tokio::test(flavor = "multi_thread")]
#[should_panic = "can call blocking only when running on the multi-threaded runtime"]
async fn block_in_place_in_run_until_mt() {
let local_set = LocalSet::new();
local_set
.run_until(async {
tokio::task::block_in_place(|| {});
})
.await;
}
#[tokio::test(flavor = "multi_thread")]
#[should_panic = "can call blocking only when running on the multi-threaded runtime"]
async fn block_in_place_in_spawn_local_mt() {
let local_set = LocalSet::new();
let jh = local_set.spawn_local(async {
tokio::task::block_in_place(|| {});
});
local_set.await;
complete(jh).await;
}
#[tokio::test(flavor = "multi_thread")]
#[should_panic = "can call blocking only when running on the multi-threaded runtime"]
async fn block_in_place_in_spawn_local_drop_mt() {
let local_set = LocalSet::new();
let jh = local_set.spawn_local(BlockInPlaceOnDrop);
local_set.run_until(tokio::task::yield_now()).await;
drop(local_set);
complete(jh).await;
}
#[tokio::test(flavor = "current_thread")]
#[should_panic = "can call blocking only when running on the multi-threaded runtime"]
async fn block_in_place_in_run_until_ct() {
let local_set = LocalSet::new();
local_set
.run_until(async {
tokio::task::block_in_place(|| {});
})
.await;
}
#[tokio::test(flavor = "current_thread")]
#[should_panic = "can call blocking only when running on the multi-threaded runtime"]
async fn block_in_place_in_spawn_local_ct() {
let local_set = LocalSet::new();
let jh = local_set.spawn_local(async {
tokio::task::block_in_place(|| {});
});
local_set.await;
complete(jh).await;
}
#[tokio::test(flavor = "current_thread")]
#[should_panic = "can call blocking only when running on the multi-threaded runtime"]
async fn block_in_place_in_spawn_local_drop_ct() {
let local_set = LocalSet::new();
let jh = local_set.spawn_local(BlockInPlaceOnDrop);
local_set.run_until(tokio::task::yield_now()).await;
drop(local_set);
complete(jh).await;
}
} }
#[cfg(not(target_os = "wasi"))] // Wasi doesn't support threads #[cfg(not(target_os = "wasi"))] // Wasi doesn't support threads