mirror of
https://github.com/tokio-rs/tokio.git
synced 2025-09-28 12:10:37 +00:00
fs: guarantee that File::write
will attempt the write even if the runtime shuts down (#4316)
This commit is contained in:
parent
9e38ebcaa9
commit
7aad428994
@ -1,5 +1,11 @@
|
||||
cfg_rt! {
|
||||
pub(crate) use crate::runtime::spawn_blocking;
|
||||
|
||||
cfg_fs! {
|
||||
#[allow(unused_imports)]
|
||||
pub(crate) use crate::runtime::spawn_mandatory_blocking;
|
||||
}
|
||||
|
||||
pub(crate) use crate::task::JoinHandle;
|
||||
}
|
||||
|
||||
@ -16,7 +22,16 @@ cfg_not_rt! {
|
||||
{
|
||||
assert_send_sync::<JoinHandle<std::cell::Cell<()>>>();
|
||||
panic!("requires the `rt` Tokio feature flag")
|
||||
}
|
||||
|
||||
cfg_fs! {
|
||||
pub(crate) fn spawn_mandatory_blocking<F, R>(_f: F) -> Option<JoinHandle<R>>
|
||||
where
|
||||
F: FnOnce() -> R + Send + 'static,
|
||||
R: Send + 'static,
|
||||
{
|
||||
panic!("requires the `rt` Tokio feature flag")
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct JoinHandle<R> {
|
||||
|
@ -19,17 +19,17 @@ use std::task::Context;
|
||||
use std::task::Poll;
|
||||
use std::task::Poll::*;
|
||||
|
||||
#[cfg(test)]
|
||||
use super::mocks::spawn_blocking;
|
||||
#[cfg(test)]
|
||||
use super::mocks::JoinHandle;
|
||||
#[cfg(test)]
|
||||
use super::mocks::MockFile as StdFile;
|
||||
#[cfg(not(test))]
|
||||
use crate::blocking::spawn_blocking;
|
||||
#[cfg(test)]
|
||||
use super::mocks::{spawn_blocking, spawn_mandatory_blocking};
|
||||
#[cfg(not(test))]
|
||||
use crate::blocking::JoinHandle;
|
||||
#[cfg(not(test))]
|
||||
use crate::blocking::{spawn_blocking, spawn_mandatory_blocking};
|
||||
#[cfg(not(test))]
|
||||
use std::fs::File as StdFile;
|
||||
|
||||
/// A reference to an open file on the filesystem.
|
||||
@ -649,7 +649,7 @@ impl AsyncWrite for File {
|
||||
let n = buf.copy_from(src);
|
||||
let std = me.std.clone();
|
||||
|
||||
inner.state = Busy(spawn_blocking(move || {
|
||||
let blocking_task_join_handle = spawn_mandatory_blocking(move || {
|
||||
let res = if let Some(seek) = seek {
|
||||
(&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std))
|
||||
} else {
|
||||
@ -657,7 +657,12 @@ impl AsyncWrite for File {
|
||||
};
|
||||
|
||||
(Operation::Write(res), buf)
|
||||
}));
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::Other, "background task failed")
|
||||
})?;
|
||||
|
||||
inner.state = Busy(blocking_task_join_handle);
|
||||
|
||||
return Ready(Ok(n));
|
||||
}
|
||||
|
@ -105,6 +105,21 @@ where
|
||||
JoinHandle { rx }
|
||||
}
|
||||
|
||||
pub(super) fn spawn_mandatory_blocking<F, R>(f: F) -> Option<JoinHandle<R>>
|
||||
where
|
||||
F: FnOnce() -> R + Send + 'static,
|
||||
R: Send + 'static,
|
||||
{
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let task = Box::new(move || {
|
||||
let _ = tx.send(f());
|
||||
});
|
||||
|
||||
QUEUE.with(|cell| cell.borrow_mut().push_back(task));
|
||||
|
||||
Some(JoinHandle { rx })
|
||||
}
|
||||
|
||||
impl<T> Future for JoinHandle<T> {
|
||||
type Output = Result<T, io::Error>;
|
||||
|
||||
|
@ -4,7 +4,11 @@
|
||||
//! compilation.
|
||||
|
||||
mod pool;
|
||||
pub(crate) use pool::{spawn_blocking, BlockingPool, Spawner};
|
||||
pub(crate) use pool::{spawn_blocking, BlockingPool, Mandatory, Spawner, Task};
|
||||
|
||||
cfg_fs! {
|
||||
pub(crate) use pool::spawn_mandatory_blocking;
|
||||
}
|
||||
|
||||
mod schedule;
|
||||
mod shutdown;
|
||||
|
@ -70,11 +70,40 @@ struct Shared {
|
||||
worker_thread_index: usize,
|
||||
}
|
||||
|
||||
type Task = task::UnownedTask<NoopSchedule>;
|
||||
pub(crate) struct Task {
|
||||
task: task::UnownedTask<NoopSchedule>,
|
||||
mandatory: Mandatory,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub(crate) enum Mandatory {
|
||||
#[cfg_attr(not(fs), allow(dead_code))]
|
||||
Mandatory,
|
||||
NonMandatory,
|
||||
}
|
||||
|
||||
impl Task {
|
||||
pub(crate) fn new(task: task::UnownedTask<NoopSchedule>, mandatory: Mandatory) -> Task {
|
||||
Task { task, mandatory }
|
||||
}
|
||||
|
||||
fn run(self) {
|
||||
self.task.run();
|
||||
}
|
||||
|
||||
fn shutdown_or_run_if_mandatory(self) {
|
||||
match self.mandatory {
|
||||
Mandatory::NonMandatory => self.task.shutdown(),
|
||||
Mandatory::Mandatory => self.task.run(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const KEEP_ALIVE: Duration = Duration::from_secs(10);
|
||||
|
||||
/// Runs the provided function on an executor dedicated to blocking operations.
|
||||
/// Tasks will be scheduled as non-mandatory, meaning they may not get executed
|
||||
/// in case of runtime shutdown.
|
||||
pub(crate) fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
|
||||
where
|
||||
F: FnOnce() -> R + Send + 'static,
|
||||
@ -84,6 +113,25 @@ where
|
||||
rt.spawn_blocking(func)
|
||||
}
|
||||
|
||||
cfg_fs! {
|
||||
#[cfg_attr(any(
|
||||
all(loom, not(test)), // the function is covered by loom tests
|
||||
test
|
||||
), allow(dead_code))]
|
||||
/// Runs the provided function on an executor dedicated to blocking
|
||||
/// operations. Tasks will be scheduled as mandatory, meaning they are
|
||||
/// guaranteed to run unless a shutdown is already taking place. In case a
|
||||
/// shutdown is already taking place, `None` will be returned.
|
||||
pub(crate) fn spawn_mandatory_blocking<F, R>(func: F) -> Option<JoinHandle<R>>
|
||||
where
|
||||
F: FnOnce() -> R + Send + 'static,
|
||||
R: Send + 'static,
|
||||
{
|
||||
let rt = context::current();
|
||||
rt.spawn_mandatory_blocking(func)
|
||||
}
|
||||
}
|
||||
|
||||
// ===== impl BlockingPool =====
|
||||
|
||||
impl BlockingPool {
|
||||
@ -176,8 +224,10 @@ impl Spawner {
|
||||
let mut shared = self.inner.shared.lock();
|
||||
|
||||
if shared.shutdown {
|
||||
// Shutdown the task
|
||||
task.shutdown();
|
||||
// Shutdown the task: it's fine to shutdown this task (even if
|
||||
// mandatory) because it was scheduled after the shutdown of the
|
||||
// runtime began.
|
||||
task.task.shutdown();
|
||||
|
||||
// no need to even push this task; it would never get picked up
|
||||
return Err(());
|
||||
@ -302,7 +352,8 @@ impl Inner {
|
||||
// Drain the queue
|
||||
while let Some(task) = shared.queue.pop_front() {
|
||||
drop(shared);
|
||||
task.shutdown();
|
||||
|
||||
task.shutdown_or_run_if_mandatory();
|
||||
|
||||
shared = self.shared.lock();
|
||||
}
|
||||
|
@ -189,15 +189,56 @@ impl Handle {
|
||||
F: FnOnce() -> R + Send + 'static,
|
||||
R: Send + 'static,
|
||||
{
|
||||
if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
|
||||
self.spawn_blocking_inner(Box::new(func), None)
|
||||
} else {
|
||||
self.spawn_blocking_inner(func, None)
|
||||
let (join_handle, _was_spawned) =
|
||||
if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
|
||||
self.spawn_blocking_inner(Box::new(func), blocking::Mandatory::NonMandatory, None)
|
||||
} else {
|
||||
self.spawn_blocking_inner(func, blocking::Mandatory::NonMandatory, None)
|
||||
};
|
||||
|
||||
join_handle
|
||||
}
|
||||
|
||||
cfg_fs! {
|
||||
#[track_caller]
|
||||
#[cfg_attr(any(
|
||||
all(loom, not(test)), // the function is covered by loom tests
|
||||
test
|
||||
), allow(dead_code))]
|
||||
pub(crate) fn spawn_mandatory_blocking<F, R>(&self, func: F) -> Option<JoinHandle<R>>
|
||||
where
|
||||
F: FnOnce() -> R + Send + 'static,
|
||||
R: Send + 'static,
|
||||
{
|
||||
let (join_handle, was_spawned) = if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
|
||||
self.spawn_blocking_inner(
|
||||
Box::new(func),
|
||||
blocking::Mandatory::Mandatory,
|
||||
None
|
||||
)
|
||||
} else {
|
||||
self.spawn_blocking_inner(
|
||||
func,
|
||||
blocking::Mandatory::Mandatory,
|
||||
None
|
||||
)
|
||||
};
|
||||
|
||||
if was_spawned {
|
||||
Some(join_handle)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
pub(crate) fn spawn_blocking_inner<F, R>(&self, func: F, name: Option<&str>) -> JoinHandle<R>
|
||||
pub(crate) fn spawn_blocking_inner<F, R>(
|
||||
&self,
|
||||
func: F,
|
||||
is_mandatory: blocking::Mandatory,
|
||||
name: Option<&str>,
|
||||
) -> (JoinHandle<R>, bool)
|
||||
where
|
||||
F: FnOnce() -> R + Send + 'static,
|
||||
R: Send + 'static,
|
||||
@ -223,8 +264,10 @@ impl Handle {
|
||||
let _ = name;
|
||||
|
||||
let (task, handle) = task::unowned(fut, NoopSchedule);
|
||||
let _ = self.blocking_spawner.spawn(task, self);
|
||||
handle
|
||||
let spawned = self
|
||||
.blocking_spawner
|
||||
.spawn(blocking::Task::new(task, is_mandatory), self);
|
||||
(handle, spawned.is_ok())
|
||||
}
|
||||
|
||||
/// Runs a future to completion on this `Handle`'s associated `Runtime`.
|
||||
|
@ -201,6 +201,14 @@ cfg_rt! {
|
||||
use blocking::BlockingPool;
|
||||
pub(crate) use blocking::spawn_blocking;
|
||||
|
||||
cfg_trace! {
|
||||
pub(crate) use blocking::Mandatory;
|
||||
}
|
||||
|
||||
cfg_fs! {
|
||||
pub(crate) use blocking::spawn_mandatory_blocking;
|
||||
}
|
||||
|
||||
mod builder;
|
||||
pub use self::builder::Builder;
|
||||
|
||||
|
@ -23,6 +23,56 @@ fn blocking_shutdown() {
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spawn_mandatory_blocking_should_always_run() {
|
||||
use crate::runtime::tests::loom_oneshot;
|
||||
loom::model(|| {
|
||||
let rt = runtime::Builder::new_current_thread().build().unwrap();
|
||||
|
||||
let (tx, rx) = loom_oneshot::channel();
|
||||
let _enter = rt.enter();
|
||||
runtime::spawn_blocking(|| {});
|
||||
runtime::spawn_mandatory_blocking(move || {
|
||||
let _ = tx.send(());
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
drop(rt);
|
||||
|
||||
// This call will deadlock if `spawn_mandatory_blocking` doesn't run.
|
||||
let () = rx.recv();
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spawn_mandatory_blocking_should_run_even_when_shutting_down_from_other_thread() {
|
||||
use crate::runtime::tests::loom_oneshot;
|
||||
loom::model(|| {
|
||||
let rt = runtime::Builder::new_current_thread().build().unwrap();
|
||||
let handle = rt.handle().clone();
|
||||
|
||||
// Drop the runtime in a different thread
|
||||
{
|
||||
loom::thread::spawn(move || {
|
||||
drop(rt);
|
||||
});
|
||||
}
|
||||
|
||||
let _enter = handle.enter();
|
||||
let (tx, rx) = loom_oneshot::channel();
|
||||
let handle = runtime::spawn_mandatory_blocking(move || {
|
||||
let _ = tx.send(());
|
||||
});
|
||||
|
||||
// handle.is_some() means that `spawn_mandatory_blocking`
|
||||
// promised us to run the blocking task
|
||||
if handle.is_some() {
|
||||
// This call will deadlock if `spawn_mandatory_blocking` doesn't run.
|
||||
let () = rx.recv();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn mk_runtime(num_threads: usize) -> Runtime {
|
||||
runtime::Builder::new_multi_thread()
|
||||
.worker_threads(num_threads)
|
||||
|
@ -107,6 +107,9 @@ impl<'a> Builder<'a> {
|
||||
Function: FnOnce() -> Output + Send + 'static,
|
||||
Output: Send + 'static,
|
||||
{
|
||||
context::current().spawn_blocking_inner(function, self.name)
|
||||
use crate::runtime::Mandatory;
|
||||
let (join_handle, _was_spawned) =
|
||||
context::current().spawn_blocking_inner(function, Mandatory::NonMandatory, self.name);
|
||||
join_handle
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user