From 282b00cbe888a96669877ce70662fba87e8c0e3c Mon Sep 17 00:00:00 2001 From: Jon Gjengset Date: Mon, 20 Apr 2020 19:18:47 -0400 Subject: [PATCH] Be more principled about when blocking is ok (#2410) This enables `block_in_place` to be used in more contexts. Specifically, it allows you to block whenever you are off the tokio runtime (like if you are not using tokio, are in a `spawn_blocking` closure, etc.), and in the threaded scheduler's `block_on`. Blocking in `LocalSet` and the basic scheduler's` block_on` is still disallowed. Fixes #2327. Fixes #2393. --- tokio/src/runtime/basic_scheduler.rs | 2 +- tokio/src/runtime/blocking/shutdown.rs | 4 +- tokio/src/runtime/enter.rs | 103 ++++++++++++++++++++---- tokio/src/runtime/shell.rs | 2 +- tokio/src/runtime/thread_pool/mod.rs | 2 +- tokio/src/runtime/thread_pool/worker.rs | 32 +++++++- tokio/src/task/local.rs | 2 + tokio/tests/task_blocking.rs | 51 +++++++++++- 8 files changed, 174 insertions(+), 24 deletions(-) diff --git a/tokio/src/runtime/basic_scheduler.rs b/tokio/src/runtime/basic_scheduler.rs index 301554280..7e1c257cc 100644 --- a/tokio/src/runtime/basic_scheduler.rs +++ b/tokio/src/runtime/basic_scheduler.rs @@ -121,7 +121,7 @@ where F: Future, { enter(self, |scheduler, context| { - let _enter = runtime::enter(); + let _enter = runtime::enter(false); let waker = waker_ref(&scheduler.spawner.shared); let mut cx = std::task::Context::from_waker(&waker); diff --git a/tokio/src/runtime/blocking/shutdown.rs b/tokio/src/runtime/blocking/shutdown.rs index 5ee8af0fb..f3c60ee30 100644 --- a/tokio/src/runtime/blocking/shutdown.rs +++ b/tokio/src/runtime/blocking/shutdown.rs @@ -36,12 +36,12 @@ impl Receiver { use crate::runtime::enter::{enter, try_enter}; let mut e = if std::thread::panicking() { - match try_enter() { + match try_enter(false) { Some(enter) => enter, _ => return, } } else { - enter() + enter(false) }; // The oneshot completes with an Err diff --git a/tokio/src/runtime/enter.rs b/tokio/src/runtime/enter.rs index 440941e14..9b3f2ad89 100644 --- a/tokio/src/runtime/enter.rs +++ b/tokio/src/runtime/enter.rs @@ -2,7 +2,26 @@ use std::cell::{Cell, RefCell}; use std::fmt; use std::marker::PhantomData; -thread_local!(static ENTERED: Cell = Cell::new(false)); +#[derive(Debug, Clone, Copy)] +pub(crate) enum EnterContext { + Entered { + #[allow(dead_code)] + allow_blocking: bool, + }, + NotEntered, +} + +impl EnterContext { + pub(crate) fn is_entered(self) -> bool { + if let EnterContext::Entered { .. } = self { + true + } else { + false + } + } +} + +thread_local!(static ENTERED: Cell = Cell::new(EnterContext::NotEntered)); /// Represents an executor context. pub(crate) struct Enter { @@ -11,8 +30,8 @@ pub(crate) struct Enter { /// Marks the current thread as being within the dynamic extent of an /// executor. -pub(crate) fn enter() -> Enter { - if let Some(enter) = try_enter() { +pub(crate) fn enter(allow_blocking: bool) -> Enter { + if let Some(enter) = try_enter(allow_blocking) { return enter; } @@ -26,12 +45,12 @@ pub(crate) fn enter() -> Enter { /// Tries to enter a runtime context, returns `None` if already in a runtime /// context. -pub(crate) fn try_enter() -> Option { +pub(crate) fn try_enter(allow_blocking: bool) -> Option { ENTERED.with(|c| { - if c.get() { + if c.get().is_entered() { None } else { - c.set(true); + c.set(EnterContext::Entered { allow_blocking }); Some(Enter { _p: PhantomData }) } }) @@ -47,26 +66,78 @@ pub(crate) fn try_enter() -> Option { #[cfg(all(feature = "rt-threaded", feature = "blocking"))] pub(crate) fn exit R, R>(f: F) -> R { // Reset in case the closure panics - struct Reset; + struct Reset(EnterContext); impl Drop for Reset { fn drop(&mut self) { ENTERED.with(|c| { - assert!(!c.get(), "closure claimed permanent executor"); - c.set(true); + assert!(!c.get().is_entered(), "closure claimed permanent executor"); + c.set(self.0); }); } } - ENTERED.with(|c| { - assert!(c.get(), "asked to exit when not entered"); - c.set(false); + let was = ENTERED.with(|c| { + let e = c.get(); + assert!(e.is_entered(), "asked to exit when not entered"); + c.set(EnterContext::NotEntered); + e }); - let _reset = Reset; - // dropping reset after f() will do c.set(true) + let _reset = Reset(was); + // dropping _reset after f() will reset ENTERED f() } +cfg_rt_core! { + cfg_rt_util! { + /// Disallow blocking in the current runtime context until the guard is dropped. + pub(crate) fn disallow_blocking() -> DisallowBlockingGuard { + let reset = ENTERED.with(|c| { + if let EnterContext::Entered { + allow_blocking: true, + } = c.get() + { + c.set(EnterContext::Entered { + allow_blocking: false, + }); + true + } else { + false + } + }); + DisallowBlockingGuard(reset) + } + + pub(crate) struct DisallowBlockingGuard(bool); + impl Drop for DisallowBlockingGuard { + fn drop(&mut self) { + if self.0 { + // XXX: Do we want some kind of assertion here, or is "best effort" okay? + ENTERED.with(|c| { + if let EnterContext::Entered { + allow_blocking: false, + } = c.get() + { + c.set(EnterContext::Entered { + allow_blocking: true, + }); + } + }) + } + } + } + } +} + +cfg_rt_threaded! { + cfg_blocking! { + /// Returns true if in a runtime context. + pub(crate) fn context() -> EnterContext { + ENTERED.with(|c| c.get()) + } + } +} + cfg_blocking_impl! { use crate::park::ParkError; use std::time::Duration; @@ -149,8 +220,8 @@ impl fmt::Debug for Enter { impl Drop for Enter { fn drop(&mut self) { ENTERED.with(|c| { - assert!(c.get()); - c.set(false); + assert!(c.get().is_entered()); + c.set(EnterContext::NotEntered); }); } } diff --git a/tokio/src/runtime/shell.rs b/tokio/src/runtime/shell.rs index 294f2a16d..a65869d0d 100644 --- a/tokio/src/runtime/shell.rs +++ b/tokio/src/runtime/shell.rs @@ -32,7 +32,7 @@ impl Shell { where F: Future, { - let _e = enter(); + let _e = enter(true); pin!(f); diff --git a/tokio/src/runtime/thread_pool/mod.rs b/tokio/src/runtime/thread_pool/mod.rs index 82e82d5b3..ced9712d9 100644 --- a/tokio/src/runtime/thread_pool/mod.rs +++ b/tokio/src/runtime/thread_pool/mod.rs @@ -78,7 +78,7 @@ impl ThreadPool { where F: Future, { - let mut enter = crate::runtime::enter(); + let mut enter = crate::runtime::enter(true); enter.block_on(future).expect("failed to park thread") } } diff --git a/tokio/src/runtime/thread_pool/worker.rs b/tokio/src/runtime/thread_pool/worker.rs index 2213ec6c1..e31f237cc 100644 --- a/tokio/src/runtime/thread_pool/worker.rs +++ b/tokio/src/runtime/thread_pool/worker.rs @@ -172,6 +172,8 @@ pub(super) fn create(size: usize, park: Parker) -> (Arc, Launch) { } cfg_blocking! { + use crate::runtime::enter::EnterContext; + pub(crate) fn block_in_place(f: F) -> R where F: FnOnce() -> R, @@ -203,7 +205,33 @@ cfg_blocking! { let mut had_core = false; CURRENT.with(|maybe_cx| { - let cx = maybe_cx.expect("can call blocking only when running in a spawned task on the multi-threaded runtime"); + match (crate::runtime::enter::context(), maybe_cx.is_some()) { + (EnterContext::Entered { .. }, true) => { + // We are on a thread pool runtime thread, so we just need to set up blocking. + } + (EnterContext::Entered { allow_blocking }, false) => { + // We are on an executor, but _not_ on the thread pool. + // That is _only_ okay if we are in a thread pool runtime's block_on method: + if allow_blocking { + return; + } else { + // This probably means we are on the basic_scheduler or in a LocalSet, + // where it is _not_ okay to block. + panic!("can call blocking only when running on the multi-threaded runtime"); + } + } + (EnterContext::NotEntered, true) => { + // This is a nested call to block_in_place (we already exited). + // All the necessary setup has already been done. + return; + } + (EnterContext::NotEntered, false) => { + // We are outside of the tokio runtime, so blocking is fine. + // We can also skip all of the thread pool blocking setup steps. + return; + } + } + let cx = maybe_cx.expect("no .is_some() == false cases above should lead here"); // Get the worker core. If none is set, then blocking is fine! let core = match cx.core.borrow_mut().take() { @@ -273,7 +301,7 @@ fn run(worker: Arc) { core: RefCell::new(None), }; - let _enter = crate::runtime::enter(); + let _enter = crate::runtime::enter(true); CURRENT.set(&cx, || { // This should always be an error. It only returns a `Result` to support diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index edecb04b6..9af50cee4 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -520,6 +520,8 @@ impl Future for RunUntil<'_, T> { .waker .register_by_ref(cx.waker()); + let _no_blocking = crate::runtime::enter::disallow_blocking(); + if let Poll::Ready(output) = me.future.poll(cx) { return Poll::Ready(output); } diff --git a/tokio/tests/task_blocking.rs b/tokio/tests/task_blocking.rs index edcb005dc..72fed01e9 100644 --- a/tokio/tests/task_blocking.rs +++ b/tokio/tests/task_blocking.rs @@ -1,7 +1,7 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::task; +use tokio::{runtime, task}; use tokio_test::assert_ok; use std::thread; @@ -28,6 +28,29 @@ async fn basic_blocking() { } } +#[tokio::test(threaded_scheduler)] +async fn block_in_blocking() { + // Run a few times + for _ in 0..100 { + let out = assert_ok!( + tokio::spawn(async { + assert_ok!( + task::spawn_blocking(|| { + task::block_in_place(|| { + thread::sleep(Duration::from_millis(5)); + }); + "hello" + }) + .await + ) + }) + .await + ); + + assert_eq!(out, "hello"); + } +} + #[tokio::test(threaded_scheduler)] async fn block_in_block() { // Run a few times @@ -47,3 +70,29 @@ async fn block_in_block() { assert_eq!(out, "hello"); } } + +#[tokio::test(basic_scheduler)] +#[should_panic] +async fn no_block_in_basic_scheduler() { + task::block_in_place(|| {}); +} + +#[test] +fn yes_block_in_threaded_block_on() { + let mut rt = runtime::Builder::new() + .threaded_scheduler() + .build() + .unwrap(); + rt.block_on(async { + task::block_in_place(|| {}); + }); +} + +#[test] +#[should_panic] +fn no_block_in_basic_block_on() { + let mut rt = runtime::Builder::new().basic_scheduler().build().unwrap(); + rt.block_on(async { + task::block_in_place(|| {}); + }); +}