Make task::Builder::spawn* methods fallible (#4823)

This commit is contained in:
Ivan Petkov 2022-07-12 15:56:33 -07:00 committed by GitHub
parent de686b5355
commit 3b6c74a40a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 95 additions and 38 deletions

View File

@ -4,7 +4,7 @@
//! compilation. //! compilation.
mod pool; mod pool;
pub(crate) use pool::{spawn_blocking, BlockingPool, Mandatory, Spawner, Task}; pub(crate) use pool::{spawn_blocking, BlockingPool, Mandatory, SpawnError, Spawner, Task};
cfg_fs! { cfg_fs! {
pub(crate) use pool::spawn_mandatory_blocking; pub(crate) use pool::spawn_mandatory_blocking;

View File

@ -11,6 +11,7 @@ use crate::runtime::{Builder, Callback, ToHandle};
use std::collections::{HashMap, VecDeque}; use std::collections::{HashMap, VecDeque};
use std::fmt; use std::fmt;
use std::io;
use std::time::Duration; use std::time::Duration;
pub(crate) struct BlockingPool { pub(crate) struct BlockingPool {
@ -82,6 +83,25 @@ pub(crate) enum Mandatory {
NonMandatory, NonMandatory,
} }
pub(crate) enum SpawnError {
/// Pool is shutting down and the task was not scheduled
ShuttingDown,
/// There are no worker threads available to take the task
/// and the OS failed to spawn a new one
NoThreads(io::Error),
}
impl From<SpawnError> for io::Error {
fn from(e: SpawnError) -> Self {
match e {
SpawnError::ShuttingDown => {
io::Error::new(io::ErrorKind::Other, "blocking pool shutting down")
}
SpawnError::NoThreads(e) => e,
}
}
}
impl Task { impl Task {
pub(crate) fn new(task: task::UnownedTask<NoopSchedule>, mandatory: Mandatory) -> Task { pub(crate) fn new(task: task::UnownedTask<NoopSchedule>, mandatory: Mandatory) -> Task {
Task { task, mandatory } Task { task, mandatory }
@ -221,7 +241,7 @@ impl fmt::Debug for BlockingPool {
// ===== impl Spawner ===== // ===== impl Spawner =====
impl Spawner { impl Spawner {
pub(crate) fn spawn(&self, task: Task, rt: &dyn ToHandle) -> Result<(), ()> { pub(crate) fn spawn(&self, task: Task, rt: &dyn ToHandle) -> Result<(), SpawnError> {
let mut shared = self.inner.shared.lock(); let mut shared = self.inner.shared.lock();
if shared.shutdown { if shared.shutdown {
@ -231,7 +251,7 @@ impl Spawner {
task.task.shutdown(); task.task.shutdown();
// no need to even push this task; it would never get picked up // no need to even push this task; it would never get picked up
return Err(()); return Err(SpawnError::ShuttingDown);
} }
shared.queue.push_back(task); shared.queue.push_back(task);
@ -262,7 +282,7 @@ impl Spawner {
Err(e) => { Err(e) => {
// The OS refused to spawn the thread and there is no thread // The OS refused to spawn the thread and there is no thread
// to pick up the task that has just been pushed to the queue. // to pick up the task that has just been pushed to the queue.
panic!("OS can't spawn worker thread: {}", e) return Err(SpawnError::NoThreads(e));
} }
} }
} }

View File

@ -341,7 +341,7 @@ impl HandleInner {
F: FnOnce() -> R + Send + 'static, F: FnOnce() -> R + Send + 'static,
R: Send + 'static, R: Send + 'static,
{ {
let (join_handle, _was_spawned) = if cfg!(debug_assertions) let (join_handle, spawn_result) = if cfg!(debug_assertions)
&& std::mem::size_of::<F>() > 2048 && std::mem::size_of::<F>() > 2048
{ {
self.spawn_blocking_inner(Box::new(func), blocking::Mandatory::NonMandatory, None, rt) self.spawn_blocking_inner(Box::new(func), blocking::Mandatory::NonMandatory, None, rt)
@ -349,7 +349,14 @@ impl HandleInner {
self.spawn_blocking_inner(func, blocking::Mandatory::NonMandatory, None, rt) self.spawn_blocking_inner(func, blocking::Mandatory::NonMandatory, None, rt)
}; };
join_handle match spawn_result {
Ok(()) => join_handle,
// Compat: do not panic here, return the join_handle even though it will never resolve
Err(blocking::SpawnError::ShuttingDown) => join_handle,
Err(blocking::SpawnError::NoThreads(e)) => {
panic!("OS can't spawn worker thread: {}", e)
}
}
} }
cfg_fs! { cfg_fs! {
@ -363,7 +370,7 @@ impl HandleInner {
F: FnOnce() -> R + Send + 'static, F: FnOnce() -> R + Send + 'static,
R: Send + 'static, R: Send + 'static,
{ {
let (join_handle, was_spawned) = if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 { let (join_handle, spawn_result) = if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
self.spawn_blocking_inner( self.spawn_blocking_inner(
Box::new(func), Box::new(func),
blocking::Mandatory::Mandatory, blocking::Mandatory::Mandatory,
@ -379,7 +386,7 @@ impl HandleInner {
) )
}; };
if was_spawned { if spawn_result.is_ok() {
Some(join_handle) Some(join_handle)
} else { } else {
None None
@ -394,7 +401,7 @@ impl HandleInner {
is_mandatory: blocking::Mandatory, is_mandatory: blocking::Mandatory,
name: Option<&str>, name: Option<&str>,
rt: &dyn ToHandle, rt: &dyn ToHandle,
) -> (JoinHandle<R>, bool) ) -> (JoinHandle<R>, Result<(), blocking::SpawnError>)
where where
F: FnOnce() -> R + Send + 'static, F: FnOnce() -> R + Send + 'static,
R: Send + 'static, R: Send + 'static,
@ -424,7 +431,7 @@ impl HandleInner {
let spawned = self let spawned = self
.blocking_spawner .blocking_spawner
.spawn(blocking::Task::new(task, is_mandatory), rt); .spawn(blocking::Task::new(task, is_mandatory), rt);
(handle, spawned.is_ok()) (handle, spawned)
} }
} }

View File

@ -3,7 +3,7 @@ use crate::{
runtime::{context, Handle}, runtime::{context, Handle},
task::{JoinHandle, LocalSet}, task::{JoinHandle, LocalSet},
}; };
use std::future::Future; use std::{future::Future, io};
/// Factory which is used to configure the properties of a new task. /// Factory which is used to configure the properties of a new task.
/// ///
@ -48,7 +48,7 @@ use std::future::Future;
/// .spawn(async move { /// .spawn(async move {
/// // Process each socket concurrently. /// // Process each socket concurrently.
/// process(socket).await /// process(socket).await
/// }); /// })?;
/// } /// }
/// } /// }
/// ``` /// ```
@ -83,12 +83,12 @@ impl<'a> Builder<'a> {
/// See [`task::spawn`](crate::task::spawn) for /// See [`task::spawn`](crate::task::spawn) for
/// more details. /// more details.
#[track_caller] #[track_caller]
pub fn spawn<Fut>(self, future: Fut) -> JoinHandle<Fut::Output> pub fn spawn<Fut>(self, future: Fut) -> io::Result<JoinHandle<Fut::Output>>
where where
Fut: Future + Send + 'static, Fut: Future + Send + 'static,
Fut::Output: Send + 'static, Fut::Output: Send + 'static,
{ {
super::spawn::spawn_inner(future, self.name) Ok(super::spawn::spawn_inner(future, self.name))
} }
/// Spawn a task with this builder's settings on the provided [runtime /// Spawn a task with this builder's settings on the provided [runtime
@ -99,12 +99,16 @@ impl<'a> Builder<'a> {
/// [runtime handle]: crate::runtime::Handle /// [runtime handle]: crate::runtime::Handle
/// [`Handle::spawn`]: crate::runtime::Handle::spawn /// [`Handle::spawn`]: crate::runtime::Handle::spawn
#[track_caller] #[track_caller]
pub fn spawn_on<Fut>(&mut self, future: Fut, handle: &Handle) -> JoinHandle<Fut::Output> pub fn spawn_on<Fut>(
&mut self,
future: Fut,
handle: &Handle,
) -> io::Result<JoinHandle<Fut::Output>>
where where
Fut: Future + Send + 'static, Fut: Future + Send + 'static,
Fut::Output: Send + 'static, Fut::Output: Send + 'static,
{ {
handle.spawn_named(future, self.name) Ok(handle.spawn_named(future, self.name))
} }
/// Spawns `!Send` a task on the current [`LocalSet`] with this builder's /// Spawns `!Send` a task on the current [`LocalSet`] with this builder's
@ -122,12 +126,12 @@ impl<'a> Builder<'a> {
/// [`task::spawn_local`]: crate::task::spawn_local /// [`task::spawn_local`]: crate::task::spawn_local
/// [`LocalSet`]: crate::task::LocalSet /// [`LocalSet`]: crate::task::LocalSet
#[track_caller] #[track_caller]
pub fn spawn_local<Fut>(self, future: Fut) -> JoinHandle<Fut::Output> pub fn spawn_local<Fut>(self, future: Fut) -> io::Result<JoinHandle<Fut::Output>>
where where
Fut: Future + 'static, Fut: Future + 'static,
Fut::Output: 'static, Fut::Output: 'static,
{ {
super::local::spawn_local_inner(future, self.name) Ok(super::local::spawn_local_inner(future, self.name))
} }
/// Spawns `!Send` a task on the provided [`LocalSet`] with this builder's /// Spawns `!Send` a task on the provided [`LocalSet`] with this builder's
@ -138,12 +142,16 @@ impl<'a> Builder<'a> {
/// [`LocalSet::spawn_local`]: crate::task::LocalSet::spawn_local /// [`LocalSet::spawn_local`]: crate::task::LocalSet::spawn_local
/// [`LocalSet`]: crate::task::LocalSet /// [`LocalSet`]: crate::task::LocalSet
#[track_caller] #[track_caller]
pub fn spawn_local_on<Fut>(self, future: Fut, local_set: &LocalSet) -> JoinHandle<Fut::Output> pub fn spawn_local_on<Fut>(
self,
future: Fut,
local_set: &LocalSet,
) -> io::Result<JoinHandle<Fut::Output>>
where where
Fut: Future + 'static, Fut: Future + 'static,
Fut::Output: 'static, Fut::Output: 'static,
{ {
local_set.spawn_named(future, self.name) Ok(local_set.spawn_named(future, self.name))
} }
/// Spawns blocking code on the blocking threadpool. /// Spawns blocking code on the blocking threadpool.
@ -155,7 +163,10 @@ impl<'a> Builder<'a> {
/// See [`task::spawn_blocking`](crate::task::spawn_blocking) /// See [`task::spawn_blocking`](crate::task::spawn_blocking)
/// for more details. /// for more details.
#[track_caller] #[track_caller]
pub fn spawn_blocking<Function, Output>(self, function: Function) -> JoinHandle<Output> pub fn spawn_blocking<Function, Output>(
self,
function: Function,
) -> io::Result<JoinHandle<Output>>
where where
Function: FnOnce() -> Output + Send + 'static, Function: FnOnce() -> Output + Send + 'static,
Output: Send + 'static, Output: Send + 'static,
@ -174,18 +185,20 @@ impl<'a> Builder<'a> {
self, self,
function: Function, function: Function,
handle: &Handle, handle: &Handle,
) -> JoinHandle<Output> ) -> io::Result<JoinHandle<Output>>
where where
Function: FnOnce() -> Output + Send + 'static, Function: FnOnce() -> Output + Send + 'static,
Output: Send + 'static, Output: Send + 'static,
{ {
use crate::runtime::Mandatory; use crate::runtime::Mandatory;
let (join_handle, _was_spawned) = handle.as_inner().spawn_blocking_inner( let (join_handle, spawn_result) = handle.as_inner().spawn_blocking_inner(
function, function,
Mandatory::NonMandatory, Mandatory::NonMandatory,
self.name, self.name,
handle, handle,
); );
join_handle
spawn_result?;
Ok(join_handle)
} }
} }

View File

@ -101,13 +101,15 @@ impl<T: 'static> JoinSet<T> {
/// use tokio::task::JoinSet; /// use tokio::task::JoinSet;
/// ///
/// #[tokio::main] /// #[tokio::main]
/// async fn main() { /// async fn main() -> std::io::Result<()> {
/// let mut set = JoinSet::new(); /// let mut set = JoinSet::new();
/// ///
/// // Use the builder to configure a task's name before spawning it. /// // Use the builder to configure a task's name before spawning it.
/// set.build_task() /// set.build_task()
/// .name("my_task") /// .name("my_task")
/// .spawn(async { /* ... */ }); /// .spawn(async { /* ... */ })?;
///
/// Ok(())
/// } /// }
/// ``` /// ```
#[cfg(all(tokio_unstable, feature = "tracing"))] #[cfg(all(tokio_unstable, feature = "tracing"))]
@ -377,13 +379,13 @@ impl<'a, T: 'static> Builder<'a, T> {
/// ///
/// [`AbortHandle`]: crate::task::AbortHandle /// [`AbortHandle`]: crate::task::AbortHandle
#[track_caller] #[track_caller]
pub fn spawn<F>(self, future: F) -> AbortHandle pub fn spawn<F>(self, future: F) -> std::io::Result<AbortHandle>
where where
F: Future<Output = T>, F: Future<Output = T>,
F: Send + 'static, F: Send + 'static,
T: Send, T: Send,
{ {
self.joinset.insert(self.builder.spawn(future)) Ok(self.joinset.insert(self.builder.spawn(future)?))
} }
/// Spawn the provided task on the provided [runtime handle] with this /// Spawn the provided task on the provided [runtime handle] with this
@ -397,13 +399,13 @@ impl<'a, T: 'static> Builder<'a, T> {
/// [`AbortHandle`]: crate::task::AbortHandle /// [`AbortHandle`]: crate::task::AbortHandle
/// [runtime handle]: crate::runtime::Handle /// [runtime handle]: crate::runtime::Handle
#[track_caller] #[track_caller]
pub fn spawn_on<F>(mut self, future: F, handle: &Handle) -> AbortHandle pub fn spawn_on<F>(mut self, future: F, handle: &Handle) -> std::io::Result<AbortHandle>
where where
F: Future<Output = T>, F: Future<Output = T>,
F: Send + 'static, F: Send + 'static,
T: Send, T: Send,
{ {
self.joinset.insert(self.builder.spawn_on(future, handle)) Ok(self.joinset.insert(self.builder.spawn_on(future, handle)?))
} }
/// Spawn the provided task on the current [`LocalSet`] with this builder's /// Spawn the provided task on the current [`LocalSet`] with this builder's
@ -420,12 +422,12 @@ impl<'a, T: 'static> Builder<'a, T> {
/// [`LocalSet`]: crate::task::LocalSet /// [`LocalSet`]: crate::task::LocalSet
/// [`AbortHandle`]: crate::task::AbortHandle /// [`AbortHandle`]: crate::task::AbortHandle
#[track_caller] #[track_caller]
pub fn spawn_local<F>(self, future: F) -> AbortHandle pub fn spawn_local<F>(self, future: F) -> std::io::Result<AbortHandle>
where where
F: Future<Output = T>, F: Future<Output = T>,
F: 'static, F: 'static,
{ {
self.joinset.insert(self.builder.spawn_local(future)) Ok(self.joinset.insert(self.builder.spawn_local(future)?))
} }
/// Spawn the provided task on the provided [`LocalSet`] with this builder's /// Spawn the provided task on the provided [`LocalSet`] with this builder's
@ -438,13 +440,14 @@ impl<'a, T: 'static> Builder<'a, T> {
/// [`LocalSet`]: crate::task::LocalSet /// [`LocalSet`]: crate::task::LocalSet
/// [`AbortHandle`]: crate::task::AbortHandle /// [`AbortHandle`]: crate::task::AbortHandle
#[track_caller] #[track_caller]
pub fn spawn_local_on<F>(self, future: F, local_set: &LocalSet) -> AbortHandle pub fn spawn_local_on<F>(self, future: F, local_set: &LocalSet) -> std::io::Result<AbortHandle>
where where
F: Future<Output = T>, F: Future<Output = T>,
F: 'static, F: 'static,
{ {
self.joinset Ok(self
.insert(self.builder.spawn_local_on(future, local_set)) .joinset
.insert(self.builder.spawn_local_on(future, local_set)?))
} }
} }

View File

@ -11,6 +11,7 @@ mod tests {
let result = Builder::new() let result = Builder::new()
.name("name") .name("name")
.spawn(async { "task executed" }) .spawn(async { "task executed" })
.unwrap()
.await; .await;
assert_eq!(result.unwrap(), "task executed"); assert_eq!(result.unwrap(), "task executed");
@ -21,6 +22,7 @@ mod tests {
let result = Builder::new() let result = Builder::new()
.name("name") .name("name")
.spawn_blocking(|| "task executed") .spawn_blocking(|| "task executed")
.unwrap()
.await; .await;
assert_eq!(result.unwrap(), "task executed"); assert_eq!(result.unwrap(), "task executed");
@ -34,6 +36,7 @@ mod tests {
Builder::new() Builder::new()
.name("name") .name("name")
.spawn_local(async move { unsend_data }) .spawn_local(async move { unsend_data })
.unwrap()
.await .await
}) })
.await; .await;
@ -43,14 +46,20 @@ mod tests {
#[test] #[test]
async fn spawn_without_name() { async fn spawn_without_name() {
let result = Builder::new().spawn(async { "task executed" }).await; let result = Builder::new()
.spawn(async { "task executed" })
.unwrap()
.await;
assert_eq!(result.unwrap(), "task executed"); assert_eq!(result.unwrap(), "task executed");
} }
#[test] #[test]
async fn spawn_blocking_without_name() { async fn spawn_blocking_without_name() {
let result = Builder::new().spawn_blocking(|| "task executed").await; let result = Builder::new()
.spawn_blocking(|| "task executed")
.unwrap()
.await;
assert_eq!(result.unwrap(), "task executed"); assert_eq!(result.unwrap(), "task executed");
} }
@ -59,7 +68,12 @@ mod tests {
async fn spawn_local_without_name() { async fn spawn_local_without_name() {
let unsend_data = Rc::new("task executed"); let unsend_data = Rc::new("task executed");
let result = LocalSet::new() let result = LocalSet::new()
.run_until(async move { Builder::new().spawn_local(async move { unsend_data }).await }) .run_until(async move {
Builder::new()
.spawn_local(async move { unsend_data })
.unwrap()
.await
})
.await; .await;
assert_eq!(*result.unwrap(), "task executed"); assert_eq!(*result.unwrap(), "task executed");