task: disallow blocking in LocalSet::{poll,drop} (#7372)

This commit is contained in:
Alice Ryhl 2025-06-08 20:56:55 +02:00 committed by GitHub
parent 38d88c6799
commit 8259133ca0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 118 additions and 22 deletions

View File

@ -32,7 +32,7 @@ pub(crate) fn try_enter_blocking_region() -> Option<BlockingRegionGuard> {
/// Disallows blocking in the current runtime context until the guard is dropped.
pub(crate) fn disallow_block_in_place() -> DisallowBlockInPlaceGuard {
let reset = CONTEXT.with(|c| {
let reset = CONTEXT.try_with(|c| {
if let EnterRuntime::Entered {
allow_block_in_place: true,
} = c.runtime.get()
@ -46,7 +46,7 @@ pub(crate) fn disallow_block_in_place() -> DisallowBlockInPlaceGuard {
}
});
DisallowBlockInPlaceGuard(reset)
DisallowBlockInPlaceGuard(reset.unwrap_or(false))
}
impl BlockingRegionGuard {

View File

@ -916,6 +916,8 @@ impl Future for LocalSet {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let _no_blocking = crate::runtime::context::disallow_block_in_place();
// Register the waker before starting to work
self.context.shared.waker.register_by_ref(cx.waker());
@ -948,6 +950,8 @@ impl Default for LocalSet {
impl Drop for LocalSet {
fn drop(&mut self) {
self.with_if_possible(|| {
let _no_blocking = crate::runtime::context::disallow_block_in_place();
// Shut down all tasks in the LocalOwnedTasks and close it to
// prevent new tasks from ever being added.
unsafe {

View File

@ -151,31 +151,123 @@ fn enter_guard_spawn() {
});
}
#[cfg(not(target_os = "wasi"))] // Wasi doesn't support panic recovery
#[test]
// This will panic, since the thread that calls `block_on` cannot use
// in-place blocking inside of `block_on`.
#[should_panic]
fn local_threadpool_blocking_in_place() {
thread_local! {
static ON_RT_THREAD: Cell<bool> = const { Cell::new(false) };
#[cfg(not(target_os = "wasi"))]
mod block_in_place_cases {
use super::*;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
struct BlockInPlaceOnDrop;
impl Future for BlockInPlaceOnDrop {
type Output = ();
fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
Poll::Pending
}
}
impl Drop for BlockInPlaceOnDrop {
fn drop(&mut self) {
tokio::task::block_in_place(|| {});
}
}
ON_RT_THREAD.with(|cell| cell.set(true));
async fn complete(jh: tokio::task::JoinHandle<()>) {
match jh.await {
Ok(()) => {}
Err(err) if err.is_panic() => std::panic::resume_unwind(err.into_panic()),
Err(err) if err.is_cancelled() => panic!("task cancelled"),
Err(err) => panic!("{:?}", err),
}
}
let rt = runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
LocalSet::new().block_on(&rt, async {
assert!(ON_RT_THREAD.with(|cell| cell.get()));
let join = task::spawn_local(async move {
assert!(ON_RT_THREAD.with(|cell| cell.get()));
task::block_in_place(|| {});
#[test]
#[should_panic = "can call blocking only when running on the multi-threaded runtime"]
fn local_threadpool_blocking_in_place() {
thread_local! {
static ON_RT_THREAD: Cell<bool> = const { Cell::new(false) };
}
ON_RT_THREAD.with(|cell| cell.set(true));
let rt = runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
LocalSet::new().block_on(&rt, async {
assert!(ON_RT_THREAD.with(|cell| cell.get()));
let join = task::spawn_local(async move {
assert!(ON_RT_THREAD.with(|cell| cell.get()));
task::block_in_place(|| {});
assert!(ON_RT_THREAD.with(|cell| cell.get()));
});
complete(join).await;
});
join.await.unwrap();
});
}
#[tokio::test(flavor = "multi_thread")]
#[should_panic = "can call blocking only when running on the multi-threaded runtime"]
async fn block_in_place_in_run_until_mt() {
let local_set = LocalSet::new();
local_set
.run_until(async {
tokio::task::block_in_place(|| {});
})
.await;
}
#[tokio::test(flavor = "multi_thread")]
#[should_panic = "can call blocking only when running on the multi-threaded runtime"]
async fn block_in_place_in_spawn_local_mt() {
let local_set = LocalSet::new();
let jh = local_set.spawn_local(async {
tokio::task::block_in_place(|| {});
});
local_set.await;
complete(jh).await;
}
#[tokio::test(flavor = "multi_thread")]
#[should_panic = "can call blocking only when running on the multi-threaded runtime"]
async fn block_in_place_in_spawn_local_drop_mt() {
let local_set = LocalSet::new();
let jh = local_set.spawn_local(BlockInPlaceOnDrop);
local_set.run_until(tokio::task::yield_now()).await;
drop(local_set);
complete(jh).await;
}
#[tokio::test(flavor = "current_thread")]
#[should_panic = "can call blocking only when running on the multi-threaded runtime"]
async fn block_in_place_in_run_until_ct() {
let local_set = LocalSet::new();
local_set
.run_until(async {
tokio::task::block_in_place(|| {});
})
.await;
}
#[tokio::test(flavor = "current_thread")]
#[should_panic = "can call blocking only when running on the multi-threaded runtime"]
async fn block_in_place_in_spawn_local_ct() {
let local_set = LocalSet::new();
let jh = local_set.spawn_local(async {
tokio::task::block_in_place(|| {});
});
local_set.await;
complete(jh).await;
}
#[tokio::test(flavor = "current_thread")]
#[should_panic = "can call blocking only when running on the multi-threaded runtime"]
async fn block_in_place_in_spawn_local_drop_ct() {
let local_set = LocalSet::new();
let jh = local_set.spawn_local(BlockInPlaceOnDrop);
local_set.run_until(tokio::task::yield_now()).await;
drop(local_set);
complete(jh).await;
}
}
#[cfg(not(target_os = "wasi"))] // Wasi doesn't support threads