task: add size check for user-supplied future (#6692)

This commit is contained in:
Motoyuki Kimura 2024-07-18 20:54:37 +09:00 committed by GitHub
parent f71bded943
commit da17c61464
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 96 additions and 18 deletions

View File

@ -6,7 +6,7 @@ use crate::runtime::blocking::schedule::BlockingSchedule;
use crate::runtime::blocking::{shutdown, BlockingTask};
use crate::runtime::builder::ThreadNameFn;
use crate::runtime::task::{self, JoinHandle};
use crate::runtime::{Builder, Callback, Handle};
use crate::runtime::{Builder, Callback, Handle, BOX_FUTURE_THRESHOLD};
use crate::util::metric_atomics::MetricAtomicUsize;
use std::collections::{HashMap, VecDeque};
@ -296,7 +296,7 @@ impl Spawner {
R: Send + 'static,
{
let (join_handle, spawn_result) =
if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
if cfg!(debug_assertions) && std::mem::size_of::<F>() > BOX_FUTURE_THRESHOLD {
self.spawn_blocking_inner(Box::new(func), Mandatory::NonMandatory, None, rt)
} else {
self.spawn_blocking_inner(func, Mandatory::NonMandatory, None, rt)
@ -323,7 +323,7 @@ impl Spawner {
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let (join_handle, spawn_result) = if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
let (join_handle, spawn_result) = if cfg!(debug_assertions) && std::mem::size_of::<F>() > BOX_FUTURE_THRESHOLD {
self.spawn_blocking_inner(
Box::new(func),
Mandatory::Mandatory,

View File

@ -16,6 +16,7 @@ pub struct Handle {
}
use crate::runtime::task::JoinHandle;
use crate::runtime::BOX_FUTURE_THRESHOLD;
use crate::util::error::{CONTEXT_MISSING_ERROR, THREAD_LOCAL_DESTROYED_ERROR};
use std::future::Future;
@ -188,7 +189,11 @@ impl Handle {
F: Future + Send + 'static,
F::Output: Send + 'static,
{
self.spawn_named(future, None)
if cfg!(debug_assertions) && std::mem::size_of::<F>() > BOX_FUTURE_THRESHOLD {
self.spawn_named(Box::pin(future), None)
} else {
self.spawn_named(future, None)
}
}
/// Runs the provided function on an executor dedicated to blocking
@ -291,6 +296,15 @@ impl Handle {
/// [`tokio::time`]: crate::time
#[track_caller]
pub fn block_on<F: Future>(&self, future: F) -> F::Output {
if cfg!(debug_assertions) && std::mem::size_of::<F>() > BOX_FUTURE_THRESHOLD {
self.block_on_inner(Box::pin(future))
} else {
self.block_on_inner(future)
}
}
#[track_caller]
fn block_on_inner<F: Future>(&self, future: F) -> F::Output {
#[cfg(all(
tokio_unstable,
tokio_taskdump,

View File

@ -385,6 +385,10 @@ cfg_rt! {
mod runtime;
pub use runtime::{Runtime, RuntimeFlavor};
/// Boundary value to prevent stack overflow caused by a large-sized
/// Future being placed in the stack.
pub(crate) const BOX_FUTURE_THRESHOLD: usize = 2048;
mod thread_id;
pub(crate) use thread_id::ThreadId;

View File

@ -1,3 +1,4 @@
use super::BOX_FUTURE_THRESHOLD;
use crate::runtime::blocking::BlockingPool;
use crate::runtime::scheduler::CurrentThread;
use crate::runtime::{context, EnterGuard, Handle};
@ -240,7 +241,11 @@ impl Runtime {
F: Future + Send + 'static,
F::Output: Send + 'static,
{
self.handle.spawn(future)
if cfg!(debug_assertions) && std::mem::size_of::<F>() > BOX_FUTURE_THRESHOLD {
self.handle.spawn_named(Box::pin(future), None)
} else {
self.handle.spawn_named(future, None)
}
}
/// Runs the provided function on an executor dedicated to blocking operations.
@ -324,6 +329,15 @@ impl Runtime {
/// [handle]: fn@Handle::block_on
#[track_caller]
pub fn block_on<F: Future>(&self, future: F) -> F::Output {
if cfg!(debug_assertions) && std::mem::size_of::<F>() > BOX_FUTURE_THRESHOLD {
self.block_on_inner(Box::pin(future))
} else {
self.block_on_inner(future)
}
}
#[track_caller]
fn block_on_inner<F: Future>(&self, future: F) -> F::Output {
#[cfg(all(
tokio_unstable,
tokio_taskdump,

View File

@ -1,6 +1,6 @@
#![allow(unreachable_pub)]
use crate::{
runtime::Handle,
runtime::{Handle, BOX_FUTURE_THRESHOLD},
task::{JoinHandle, LocalSet},
};
use std::{future::Future, io};
@ -88,7 +88,13 @@ impl<'a> Builder<'a> {
Fut: Future + Send + 'static,
Fut::Output: Send + 'static,
{
Ok(super::spawn::spawn_inner(future, self.name))
Ok(
if cfg!(debug_assertions) && std::mem::size_of::<Fut>() > BOX_FUTURE_THRESHOLD {
super::spawn::spawn_inner(Box::pin(future), self.name)
} else {
super::spawn::spawn_inner(future, self.name)
},
)
}
/// Spawn a task with this builder's settings on the provided [runtime
@ -104,7 +110,13 @@ impl<'a> Builder<'a> {
Fut: Future + Send + 'static,
Fut::Output: Send + 'static,
{
Ok(handle.spawn_named(future, self.name))
Ok(
if cfg!(debug_assertions) && std::mem::size_of::<Fut>() > BOX_FUTURE_THRESHOLD {
handle.spawn_named(Box::pin(future), self.name)
} else {
handle.spawn_named(future, self.name)
},
)
}
/// Spawns `!Send` a task on the current [`LocalSet`] with this builder's
@ -127,7 +139,13 @@ impl<'a> Builder<'a> {
Fut: Future + 'static,
Fut::Output: 'static,
{
Ok(super::local::spawn_local_inner(future, self.name))
Ok(
if cfg!(debug_assertions) && std::mem::size_of::<Fut>() > BOX_FUTURE_THRESHOLD {
super::local::spawn_local_inner(Box::pin(future), self.name)
} else {
super::local::spawn_local_inner(future, self.name)
},
)
}
/// Spawns `!Send` a task on the provided [`LocalSet`] with this builder's
@ -188,12 +206,22 @@ impl<'a> Builder<'a> {
Output: Send + 'static,
{
use crate::runtime::Mandatory;
let (join_handle, spawn_result) = handle.inner.blocking_spawner().spawn_blocking_inner(
function,
Mandatory::NonMandatory,
self.name,
handle,
);
let (join_handle, spawn_result) =
if cfg!(debug_assertions) && std::mem::size_of::<Function>() > BOX_FUTURE_THRESHOLD {
handle.inner.blocking_spawner().spawn_blocking_inner(
Box::new(function),
Mandatory::NonMandatory,
self.name,
handle,
)
} else {
handle.inner.blocking_spawner().spawn_blocking_inner(
function,
Mandatory::NonMandatory,
self.name,
handle,
)
};
spawn_result?;
Ok(join_handle)

View File

@ -4,7 +4,7 @@ use crate::loom::sync::{Arc, Mutex};
#[cfg(tokio_unstable)]
use crate::runtime;
use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task};
use crate::runtime::{context, ThreadId};
use crate::runtime::{context, ThreadId, BOX_FUTURE_THRESHOLD};
use crate::sync::AtomicWaker;
use crate::util::RcCell;
@ -367,7 +367,11 @@ cfg_rt! {
F: Future + 'static,
F::Output: 'static,
{
spawn_local_inner(future, None)
if cfg!(debug_assertions) && std::mem::size_of::<F>() > BOX_FUTURE_THRESHOLD {
spawn_local_inner(Box::pin(future), None)
} else {
spawn_local_inner(future, None)
}
}
@ -641,6 +645,19 @@ impl LocalSet {
future: F,
name: Option<&str>,
) -> JoinHandle<F::Output>
where
F: Future + 'static,
F::Output: 'static,
{
if cfg!(debug_assertions) && std::mem::size_of::<F>() > BOX_FUTURE_THRESHOLD {
self.spawn_named_inner(Box::pin(future), name)
} else {
self.spawn_named_inner(future, name)
}
}
#[track_caller]
fn spawn_named_inner<F>(&self, future: F, name: Option<&str>) -> JoinHandle<F::Output>
where
F: Future + 'static,
F::Output: 'static,

View File

@ -1,3 +1,4 @@
use crate::runtime::BOX_FUTURE_THRESHOLD;
use crate::task::JoinHandle;
use std::future::Future;
@ -168,7 +169,7 @@ cfg_rt! {
{
// preventing stack overflows on debug mode, by quickly sending the
// task to the heap.
if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 {
if cfg!(debug_assertions) && std::mem::size_of::<F>() > BOX_FUTURE_THRESHOLD {
spawn_inner(Box::pin(future), None)
} else {
spawn_inner(future, None)