rt: switch enter to an RAII guard (#2954)

This commit is contained in:
Carl Lerche 2020-10-13 15:06:22 -07:00 committed by GitHub
parent a249421abc
commit 00b6127f2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 96 additions and 99 deletions

View File

@ -34,7 +34,8 @@ impl<F: Future> Future for TokioContext<'_, F> {
let handle = me.handle; let handle = me.handle;
let fut = me.inner; let fut = me.inner;
handle.enter(|| fut.poll(cx)) let _enter = handle.enter();
fut.poll(cx)
} }
} }

View File

@ -232,11 +232,9 @@ impl Spawner {
builder builder
.spawn(move || { .spawn(move || {
// Only the reference should be moved into the closure // Only the reference should be moved into the closure
let rt = &rt; let _enter = crate::runtime::context::enter(rt.clone());
rt.enter(move || { rt.blocking_spawner.inner.run(worker_id);
rt.blocking_spawner.inner.run(worker_id); drop(shutdown_tx);
drop(shutdown_tx);
})
}) })
.unwrap() .unwrap()
} }

View File

@ -508,7 +508,8 @@ cfg_rt_multi_thread! {
}; };
// Spawn the thread pool workers // Spawn the thread pool workers
handle.enter(|| launch.launch()); let _enter = crate::runtime::context::enter(handle.clone());
launch.launch();
Ok(Runtime { Ok(Runtime {
kind: Kind::ThreadPool(scheduler), kind: Kind::ThreadPool(scheduler),

View File

@ -60,24 +60,20 @@ cfg_rt! {
/// Set this [`Handle`] as the current active [`Handle`]. /// Set this [`Handle`] as the current active [`Handle`].
/// ///
/// [`Handle`]: Handle /// [`Handle`]: Handle
pub(crate) fn enter<F, R>(new: Handle, f: F) -> R pub(crate) fn enter(new: Handle) -> EnterGuard {
where CONTEXT.with(|ctx| {
F: FnOnce() -> R,
{
struct DropGuard(Option<Handle>);
impl Drop for DropGuard {
fn drop(&mut self) {
CONTEXT.with(|ctx| {
*ctx.borrow_mut() = self.0.take();
});
}
}
let _guard = CONTEXT.with(|ctx| {
let old = ctx.borrow_mut().replace(new); let old = ctx.borrow_mut().replace(new);
DropGuard(old) EnterGuard(old)
}); })
}
f()
#[derive(Debug)]
pub(crate) struct EnterGuard(Option<Handle>);
impl Drop for EnterGuard {
fn drop(&mut self) {
CONTEXT.with(|ctx| {
*ctx.borrow_mut() = self.0.take();
});
}
} }

View File

@ -1,4 +1,4 @@
use crate::runtime::{blocking, context, driver, Spawner}; use crate::runtime::{blocking, driver, Spawner};
/// Handle to the runtime. /// Handle to the runtime.
/// ///
@ -27,13 +27,13 @@ pub(crate) struct Handle {
} }
impl Handle { impl Handle {
/// Enter the runtime context. This allows you to construct types that must // /// Enter the runtime context. This allows you to construct types that must
/// have an executor available on creation such as [`Sleep`] or [`TcpStream`]. // /// have an executor available on creation such as [`Sleep`] or [`TcpStream`].
/// It will also allow you to call methods such as [`tokio::spawn`]. // /// It will also allow you to call methods such as [`tokio::spawn`].
pub(crate) fn enter<F, R>(&self, f: F) -> R // pub(crate) fn enter<F, R>(&self, f: F) -> R
where // where
F: FnOnce() -> R, // F: FnOnce() -> R,
{ // {
context::enter(self.clone(), f) // context::enter(self.clone(), f)
} // }
} }

View File

@ -262,6 +262,16 @@ cfg_rt! {
blocking_pool: BlockingPool, blocking_pool: BlockingPool,
} }
/// Runtime context guard.
///
/// Returned by [`Runtime::enter`], the context guard exits the runtime
/// context on drop.
#[derive(Debug)]
pub struct EnterGuard<'a> {
rt: &'a Runtime,
guard: context::EnterGuard,
}
/// The runtime executor is either a thread-pool or a current-thread executor. /// The runtime executor is either a thread-pool or a current-thread executor.
#[derive(Debug)] #[derive(Debug)]
enum Kind { enum Kind {
@ -356,25 +366,26 @@ cfg_rt! {
} }
} }
/// Run a future to completion on the Tokio runtime. This is the runtime's /// Run a future to completion on the Tokio runtime. This is the
/// entry point. /// runtime's entry point.
/// ///
/// This runs the given future on the runtime, blocking until it is /// This runs the given future on the runtime, blocking until it is
/// complete, and yielding its resolved result. Any tasks or timers which /// complete, and yielding its resolved result. Any tasks or timers
/// the future spawns internally will be executed on the runtime. /// which the future spawns internally will be executed on the runtime.
/// ///
/// When this runtime is configured with `core_threads = 0`, only the first call /// When this runtime is configured with `core_threads = 0`, only the
/// to `block_on` will run the IO and timer drivers. Calls to other methods _before_ the first /// first call to `block_on` will run the IO and timer drivers. Calls to
/// `block_on` completes will just hook into the driver running on the thread /// other methods _before_ the first `block_on` completes will just hook
/// that first called `block_on`. This means that the driver may be passed /// into the driver running on the thread that first called `block_on`.
/// from thread to thread by the user between calls to `block_on`. /// This means that the driver may be passed from thread to thread by
/// the user between calls to `block_on`.
/// ///
/// This method may not be called from an asynchronous context. /// This method may not be called from an asynchronous context.
/// ///
/// # Panics /// # Panics
/// ///
/// This function panics if the provided future panics, or if called within an /// This function panics if the provided future panics, or if called
/// asynchronous execution context. /// within an asynchronous execution context.
/// ///
/// # Examples /// # Examples
/// ///
@ -392,17 +403,21 @@ cfg_rt! {
/// ///
/// [handle]: fn@Handle::block_on /// [handle]: fn@Handle::block_on
pub fn block_on<F: Future>(&self, future: F) -> F::Output { pub fn block_on<F: Future>(&self, future: F) -> F::Output {
self.handle.enter(|| match &self.kind { let _enter = self.enter();
match &self.kind {
#[cfg(feature = "rt")] #[cfg(feature = "rt")]
Kind::CurrentThread(exec) => exec.block_on(future), Kind::CurrentThread(exec) => exec.block_on(future),
#[cfg(feature = "rt-multi-thread")] #[cfg(feature = "rt-multi-thread")]
Kind::ThreadPool(exec) => exec.block_on(future), Kind::ThreadPool(exec) => exec.block_on(future),
}) }
} }
/// Enter the runtime context. This allows you to construct types that must /// Enter the runtime context.
/// have an executor available on creation such as [`Sleep`] or [`TcpStream`]. ///
/// It will also allow you to call methods such as [`tokio::spawn`]. /// This allows you to construct types that must have an executor
/// available on creation such as [`Sleep`] or [`TcpStream`]. It will
/// also allow you to call methods such as [`tokio::spawn`].
/// ///
/// [`Sleep`]: struct@crate::time::Sleep /// [`Sleep`]: struct@crate::time::Sleep
/// [`TcpStream`]: struct@crate::net::TcpStream /// [`TcpStream`]: struct@crate::net::TcpStream
@ -426,14 +441,15 @@ cfg_rt! {
/// let s = "Hello World!".to_string(); /// let s = "Hello World!".to_string();
/// ///
/// // By entering the context, we tie `tokio::spawn` to this executor. /// // By entering the context, we tie `tokio::spawn` to this executor.
/// rt.enter(|| function_that_spawns(s)); /// let _guard = rt.enter();
/// function_that_spawns(s);
/// } /// }
/// ``` /// ```
pub fn enter<F, R>(&self, f: F) -> R pub fn enter(&self) -> EnterGuard<'_> {
where EnterGuard {
F: FnOnce() -> R, rt: self,
{ guard: context::enter(self.handle.clone()),
self.handle.enter(f) }
} }
/// Shutdown the runtime, waiting for at most `duration` for all spawned /// Shutdown the runtime, waiting for at most `duration` for all spawned

View File

@ -8,14 +8,15 @@ fn blocking_shutdown() {
let v = Arc::new(()); let v = Arc::new(());
let rt = mk_runtime(1); let rt = mk_runtime(1);
rt.enter(|| { {
let _enter = rt.enter();
for _ in 0..2 { for _ in 0..2 {
let v = v.clone(); let v = v.clone();
crate::task::spawn_blocking(move || { crate::task::spawn_blocking(move || {
assert!(1 < Arc::strong_count(&v)); assert!(1 < Arc::strong_count(&v));
}); });
} }
}); }
drop(rt); drop(rt);
assert_eq!(1, Arc::strong_count(&v)); assert_eq!(1, Arc::strong_count(&v));

View File

@ -253,21 +253,20 @@ mod tests {
#[test] #[test]
fn ctrl_c() { fn ctrl_c() {
let rt = rt(); let rt = rt();
let _enter = rt.enter();
rt.enter(|| { let mut ctrl_c = task::spawn(crate::signal::ctrl_c());
let mut ctrl_c = task::spawn(crate::signal::ctrl_c());
assert_pending!(ctrl_c.poll()); assert_pending!(ctrl_c.poll());
// Windows doesn't have a good programmatic way of sending events // Windows doesn't have a good programmatic way of sending events
// like sending signals on Unix, so we'll stub out the actual OS // like sending signals on Unix, so we'll stub out the actual OS
// integration and test that our handling works. // integration and test that our handling works.
unsafe { unsafe {
super::handler(CTRL_C_EVENT); super::handler(CTRL_C_EVENT);
} }
assert_ready_ok!(ctrl_c.poll()); assert_ready_ok!(ctrl_c.poll());
});
} }
#[test] #[test]

View File

@ -67,11 +67,10 @@ fn test_drop_on_notify() {
})); }));
{ {
rt.enter(|| { let _enter = rt.enter();
let waker = waker_ref(&task); let waker = waker_ref(&task);
let mut cx = Context::from_waker(&waker); let mut cx = Context::from_waker(&waker);
assert_pending!(task.future.lock().unwrap().as_mut().poll(&mut cx)); assert_pending!(task.future.lock().unwrap().as_mut().poll(&mut cx));
});
} }
// Get the address // Get the address

View File

@ -9,10 +9,11 @@ use tokio_test::{assert_err, assert_pending, assert_ready, task};
fn tcp_doesnt_block() { fn tcp_doesnt_block() {
let rt = rt(); let rt = rt();
let listener = rt.enter(|| { let listener = {
let _enter = rt.enter();
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
TcpListener::from_std(listener).unwrap() TcpListener::from_std(listener).unwrap()
}); };
drop(rt); drop(rt);
@ -27,10 +28,11 @@ fn tcp_doesnt_block() {
fn drop_wakes() { fn drop_wakes() {
let rt = rt(); let rt = rt();
let listener = rt.enter(|| { let listener = {
let _enter = rt.enter();
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
TcpListener::from_std(listener).unwrap() TcpListener::from_std(listener).unwrap()
}); };
let mut task = task::spawn(async move { let mut task = task::spawn(async move {
assert_err!(listener.accept().await); assert_err!(listener.accept().await);

View File

@ -554,23 +554,6 @@ rt_test! {
}); });
} }
#[test]
fn spawn_blocking_after_shutdown() {
let rt = rt();
let handle = rt.clone();
// Shutdown
drop(rt);
handle.enter(|| {
let res = task::spawn_blocking(|| unreachable!());
// Avoid using a tokio runtime
let out = futures::executor::block_on(res);
assert!(out.is_err());
});
}
#[test] #[test]
fn always_active_parker() { fn always_active_parker() {
// This test it to show that we will always have // This test it to show that we will always have
@ -713,9 +696,10 @@ rt_test! {
#[test] #[test]
fn enter_and_spawn() { fn enter_and_spawn() {
let rt = rt(); let rt = rt();
let handle = rt.enter(|| { let handle = {
let _enter = rt.enter();
tokio::spawn(async {}) tokio::spawn(async {})
}); };
assert_ok!(rt.block_on(handle)); assert_ok!(rt.block_on(handle));
} }