Make task::Builder::spawn* methods fallible (#4823)

This commit is contained in:
Ivan Petkov 2022-07-12 15:56:33 -07:00 committed by GitHub
parent de686b5355
commit 3b6c74a40a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 95 additions and 38 deletions

View File

@ -4,7 +4,7 @@
//! compilation.
mod pool;
pub(crate) use pool::{spawn_blocking, BlockingPool, Mandatory, Spawner, Task};
pub(crate) use pool::{spawn_blocking, BlockingPool, Mandatory, SpawnError, Spawner, Task};
cfg_fs! {
pub(crate) use pool::spawn_mandatory_blocking;

View File

@ -11,6 +11,7 @@ use crate::runtime::{Builder, Callback, ToHandle};
use std::collections::{HashMap, VecDeque};
use std::fmt;
use std::io;
use std::time::Duration;
pub(crate) struct BlockingPool {
@ -82,6 +83,25 @@ pub(crate) enum Mandatory {
NonMandatory,
}
pub(crate) enum SpawnError {
/// Pool is shutting down and the task was not scheduled
ShuttingDown,
/// There are no worker threads available to take the task
/// and the OS failed to spawn a new one
NoThreads(io::Error),
}
impl From<SpawnError> for io::Error {
fn from(e: SpawnError) -> Self {
match e {
SpawnError::ShuttingDown => {
io::Error::new(io::ErrorKind::Other, "blocking pool shutting down")
}
SpawnError::NoThreads(e) => e,
}
}
}
impl Task {
pub(crate) fn new(task: task::UnownedTask<NoopSchedule>, mandatory: Mandatory) -> Task {
Task { task, mandatory }
@ -221,7 +241,7 @@ impl fmt::Debug for BlockingPool {
// ===== impl Spawner =====
impl Spawner {
pub(crate) fn spawn(&self, task: Task, rt: &dyn ToHandle) -> Result<(), ()> {
pub(crate) fn spawn(&self, task: Task, rt: &dyn ToHandle) -> Result<(), SpawnError> {
let mut shared = self.inner.shared.lock();
if shared.shutdown {
@ -231,7 +251,7 @@ impl Spawner {
task.task.shutdown();
// no need to even push this task; it would never get picked up
return Err(());
return Err(SpawnError::ShuttingDown);
}
shared.queue.push_back(task);
@ -262,7 +282,7 @@ impl Spawner {
Err(e) => {
// The OS refused to spawn the thread and there is no thread
// to pick up the task that has just been pushed to the queue.
panic!("OS can't spawn worker thread: {}", e)
return Err(SpawnError::NoThreads(e));
}
}
}

View File

@ -341,7 +341,7 @@ impl HandleInner {
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let (join_handle, _was_spawned) = if cfg!(debug_assertions)
let (join_handle, spawn_result) = if cfg!(debug_assertions)
&& std::mem::size_of::<F>() > 2048
{
self.spawn_blocking_inner(Box::new(func), blocking::Mandatory::NonMandatory, None, rt)
@ -349,7 +349,14 @@ impl HandleInner {
self.spawn_blocking_inner(func, blocking::Mandatory::NonMandatory, None, rt)
};
join_handle
match spawn_result {
Ok(()) => join_handle,
// Compat: do not panic here, return the join_handle even though it will never resolve
Err(blocking::SpawnError::ShuttingDown) => join_handle,
Err(blocking::SpawnError::NoThreads(e)) => {
panic!("OS can't spawn worker thread: {}", e)
}
}
}
cfg_fs! {
@ -363,7 +370,7 @@ impl HandleInner {
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let (join_handle, was_spawned) = if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
let (join_handle, spawn_result) = if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
self.spawn_blocking_inner(
Box::new(func),
blocking::Mandatory::Mandatory,
@ -379,7 +386,7 @@ impl HandleInner {
)
};
if was_spawned {
if spawn_result.is_ok() {
Some(join_handle)
} else {
None
@ -394,7 +401,7 @@ impl HandleInner {
is_mandatory: blocking::Mandatory,
name: Option<&str>,
rt: &dyn ToHandle,
) -> (JoinHandle<R>, bool)
) -> (JoinHandle<R>, Result<(), blocking::SpawnError>)
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
@ -424,7 +431,7 @@ impl HandleInner {
let spawned = self
.blocking_spawner
.spawn(blocking::Task::new(task, is_mandatory), rt);
(handle, spawned.is_ok())
(handle, spawned)
}
}

View File

@ -3,7 +3,7 @@ use crate::{
runtime::{context, Handle},
task::{JoinHandle, LocalSet},
};
use std::future::Future;
use std::{future::Future, io};
/// Factory which is used to configure the properties of a new task.
///
@ -48,7 +48,7 @@ use std::future::Future;
/// .spawn(async move {
/// // Process each socket concurrently.
/// process(socket).await
/// });
/// })?;
/// }
/// }
/// ```
@ -83,12 +83,12 @@ impl<'a> Builder<'a> {
/// See [`task::spawn`](crate::task::spawn) for
/// more details.
#[track_caller]
pub fn spawn<Fut>(self, future: Fut) -> JoinHandle<Fut::Output>
pub fn spawn<Fut>(self, future: Fut) -> io::Result<JoinHandle<Fut::Output>>
where
Fut: Future + Send + 'static,
Fut::Output: Send + 'static,
{
super::spawn::spawn_inner(future, self.name)
Ok(super::spawn::spawn_inner(future, self.name))
}
/// Spawn a task with this builder's settings on the provided [runtime
@ -99,12 +99,16 @@ impl<'a> Builder<'a> {
/// [runtime handle]: crate::runtime::Handle
/// [`Handle::spawn`]: crate::runtime::Handle::spawn
#[track_caller]
pub fn spawn_on<Fut>(&mut self, future: Fut, handle: &Handle) -> JoinHandle<Fut::Output>
pub fn spawn_on<Fut>(
&mut self,
future: Fut,
handle: &Handle,
) -> io::Result<JoinHandle<Fut::Output>>
where
Fut: Future + Send + 'static,
Fut::Output: Send + 'static,
{
handle.spawn_named(future, self.name)
Ok(handle.spawn_named(future, self.name))
}
/// Spawns `!Send` a task on the current [`LocalSet`] with this builder's
@ -122,12 +126,12 @@ impl<'a> Builder<'a> {
/// [`task::spawn_local`]: crate::task::spawn_local
/// [`LocalSet`]: crate::task::LocalSet
#[track_caller]
pub fn spawn_local<Fut>(self, future: Fut) -> JoinHandle<Fut::Output>
pub fn spawn_local<Fut>(self, future: Fut) -> io::Result<JoinHandle<Fut::Output>>
where
Fut: Future + 'static,
Fut::Output: 'static,
{
super::local::spawn_local_inner(future, self.name)
Ok(super::local::spawn_local_inner(future, self.name))
}
/// Spawns `!Send` a task on the provided [`LocalSet`] with this builder's
@ -138,12 +142,16 @@ impl<'a> Builder<'a> {
/// [`LocalSet::spawn_local`]: crate::task::LocalSet::spawn_local
/// [`LocalSet`]: crate::task::LocalSet
#[track_caller]
pub fn spawn_local_on<Fut>(self, future: Fut, local_set: &LocalSet) -> JoinHandle<Fut::Output>
pub fn spawn_local_on<Fut>(
self,
future: Fut,
local_set: &LocalSet,
) -> io::Result<JoinHandle<Fut::Output>>
where
Fut: Future + 'static,
Fut::Output: 'static,
{
local_set.spawn_named(future, self.name)
Ok(local_set.spawn_named(future, self.name))
}
/// Spawns blocking code on the blocking threadpool.
@ -155,7 +163,10 @@ impl<'a> Builder<'a> {
/// See [`task::spawn_blocking`](crate::task::spawn_blocking)
/// for more details.
#[track_caller]
pub fn spawn_blocking<Function, Output>(self, function: Function) -> JoinHandle<Output>
pub fn spawn_blocking<Function, Output>(
self,
function: Function,
) -> io::Result<JoinHandle<Output>>
where
Function: FnOnce() -> Output + Send + 'static,
Output: Send + 'static,
@ -174,18 +185,20 @@ impl<'a> Builder<'a> {
self,
function: Function,
handle: &Handle,
) -> JoinHandle<Output>
) -> io::Result<JoinHandle<Output>>
where
Function: FnOnce() -> Output + Send + 'static,
Output: Send + 'static,
{
use crate::runtime::Mandatory;
let (join_handle, _was_spawned) = handle.as_inner().spawn_blocking_inner(
let (join_handle, spawn_result) = handle.as_inner().spawn_blocking_inner(
function,
Mandatory::NonMandatory,
self.name,
handle,
);
join_handle
spawn_result?;
Ok(join_handle)
}
}

View File

@ -101,13 +101,15 @@ impl<T: 'static> JoinSet<T> {
/// use tokio::task::JoinSet;
///
/// #[tokio::main]
/// async fn main() {
/// async fn main() -> std::io::Result<()> {
/// let mut set = JoinSet::new();
///
/// // Use the builder to configure a task's name before spawning it.
/// set.build_task()
/// .name("my_task")
/// .spawn(async { /* ... */ });
/// .spawn(async { /* ... */ })?;
///
/// Ok(())
/// }
/// ```
#[cfg(all(tokio_unstable, feature = "tracing"))]
@ -377,13 +379,13 @@ impl<'a, T: 'static> Builder<'a, T> {
///
/// [`AbortHandle`]: crate::task::AbortHandle
#[track_caller]
pub fn spawn<F>(self, future: F) -> AbortHandle
pub fn spawn<F>(self, future: F) -> std::io::Result<AbortHandle>
where
F: Future<Output = T>,
F: Send + 'static,
T: Send,
{
self.joinset.insert(self.builder.spawn(future))
Ok(self.joinset.insert(self.builder.spawn(future)?))
}
/// Spawn the provided task on the provided [runtime handle] with this
@ -397,13 +399,13 @@ impl<'a, T: 'static> Builder<'a, T> {
/// [`AbortHandle`]: crate::task::AbortHandle
/// [runtime handle]: crate::runtime::Handle
#[track_caller]
pub fn spawn_on<F>(mut self, future: F, handle: &Handle) -> AbortHandle
pub fn spawn_on<F>(mut self, future: F, handle: &Handle) -> std::io::Result<AbortHandle>
where
F: Future<Output = T>,
F: Send + 'static,
T: Send,
{
self.joinset.insert(self.builder.spawn_on(future, handle))
Ok(self.joinset.insert(self.builder.spawn_on(future, handle)?))
}
/// Spawn the provided task on the current [`LocalSet`] with this builder's
@ -420,12 +422,12 @@ impl<'a, T: 'static> Builder<'a, T> {
/// [`LocalSet`]: crate::task::LocalSet
/// [`AbortHandle`]: crate::task::AbortHandle
#[track_caller]
pub fn spawn_local<F>(self, future: F) -> AbortHandle
pub fn spawn_local<F>(self, future: F) -> std::io::Result<AbortHandle>
where
F: Future<Output = T>,
F: 'static,
{
self.joinset.insert(self.builder.spawn_local(future))
Ok(self.joinset.insert(self.builder.spawn_local(future)?))
}
/// Spawn the provided task on the provided [`LocalSet`] with this builder's
@ -438,13 +440,14 @@ impl<'a, T: 'static> Builder<'a, T> {
/// [`LocalSet`]: crate::task::LocalSet
/// [`AbortHandle`]: crate::task::AbortHandle
#[track_caller]
pub fn spawn_local_on<F>(self, future: F, local_set: &LocalSet) -> AbortHandle
pub fn spawn_local_on<F>(self, future: F, local_set: &LocalSet) -> std::io::Result<AbortHandle>
where
F: Future<Output = T>,
F: 'static,
{
self.joinset
.insert(self.builder.spawn_local_on(future, local_set))
Ok(self
.joinset
.insert(self.builder.spawn_local_on(future, local_set)?))
}
}

View File

@ -11,6 +11,7 @@ mod tests {
let result = Builder::new()
.name("name")
.spawn(async { "task executed" })
.unwrap()
.await;
assert_eq!(result.unwrap(), "task executed");
@ -21,6 +22,7 @@ mod tests {
let result = Builder::new()
.name("name")
.spawn_blocking(|| "task executed")
.unwrap()
.await;
assert_eq!(result.unwrap(), "task executed");
@ -34,6 +36,7 @@ mod tests {
Builder::new()
.name("name")
.spawn_local(async move { unsend_data })
.unwrap()
.await
})
.await;
@ -43,14 +46,20 @@ mod tests {
#[test]
async fn spawn_without_name() {
let result = Builder::new().spawn(async { "task executed" }).await;
let result = Builder::new()
.spawn(async { "task executed" })
.unwrap()
.await;
assert_eq!(result.unwrap(), "task executed");
}
#[test]
async fn spawn_blocking_without_name() {
let result = Builder::new().spawn_blocking(|| "task executed").await;
let result = Builder::new()
.spawn_blocking(|| "task executed")
.unwrap()
.await;
assert_eq!(result.unwrap(), "task executed");
}
@ -59,7 +68,12 @@ mod tests {
async fn spawn_local_without_name() {
let unsend_data = Rc::new("task executed");
let result = LocalSet::new()
.run_until(async move { Builder::new().spawn_local(async move { unsend_data }).await })
.run_until(async move {
Builder::new()
.spawn_local(async move { unsend_data })
.unwrap()
.await
})
.await;
assert_eq!(*result.unwrap(), "task executed");