mirror of
https://github.com/tokio-rs/tokio.git
synced 2025-09-25 12:00:35 +00:00
task: improve LocalPoolHandle
(#4680)
This commit is contained in:
parent
c98be229ff
commit
d8fb721de2
@ -9,7 +9,44 @@ use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::task::{spawn_local, JoinHandle, LocalSet};
|
||||
|
||||
/// A handle to a local pool, used for spawning `!Send` tasks.
|
||||
/// A cloneable handle to a local pool, used for spawning `!Send` tasks.
|
||||
///
|
||||
/// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread
|
||||
/// in the pool. Consequently you can also use [`tokio::task::spawn_local`] (which will
|
||||
/// execute on the same thread) inside the Future you supply to the various spawn methods
|
||||
/// of `LocalPoolHandle`,
|
||||
///
|
||||
/// [`tokio::task::LocalSet`]: tokio::task::LocalSet
|
||||
/// [`tokio::task::spawn_local`]: tokio::task::spawn_local
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use std::rc::Rc;
|
||||
/// use tokio::{self, task };
|
||||
/// use tokio_util::task::LocalPoolHandle;
|
||||
///
|
||||
/// #[tokio::main(flavor = "current_thread")]
|
||||
/// async fn main() {
|
||||
/// let pool = LocalPoolHandle::new(5);
|
||||
///
|
||||
/// let output = pool.spawn_pinned(|| {
|
||||
/// // `data` is !Send + !Sync
|
||||
/// let data = Rc::new("local data");
|
||||
/// let data_clone = data.clone();
|
||||
///
|
||||
/// async move {
|
||||
/// task::spawn_local(async move {
|
||||
/// println!("{}", data_clone);
|
||||
/// });
|
||||
///
|
||||
/// data.to_string()
|
||||
/// }
|
||||
/// }).await.unwrap();
|
||||
/// println!("output: {}", output);
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
#[derive(Clone)]
|
||||
pub struct LocalPoolHandle {
|
||||
pool: Arc<LocalPool>,
|
||||
@ -33,6 +70,22 @@ impl LocalPoolHandle {
|
||||
LocalPoolHandle { pool }
|
||||
}
|
||||
|
||||
/// Returns the number of threads of the Pool.
|
||||
#[inline]
|
||||
pub fn num_threads(&self) -> usize {
|
||||
self.pool.workers.len()
|
||||
}
|
||||
|
||||
/// Returns the number of tasks scheduled on each worker. The indices of the
|
||||
/// worker threads correspond to the indices of the returned `Vec`.
|
||||
pub fn get_task_loads_for_each_worker(&self) -> Vec<usize> {
|
||||
self.pool
|
||||
.workers
|
||||
.iter()
|
||||
.map(|worker| worker.task_count.load(Ordering::SeqCst))
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
/// Spawn a task onto a worker thread and pin it there so it can't be moved
|
||||
/// off of the thread. Note that the future is not [`Send`], but the
|
||||
/// [`FnOnce`] which creates it is.
|
||||
@ -69,7 +122,60 @@ impl LocalPoolHandle {
|
||||
Fut: Future + 'static,
|
||||
Fut::Output: Send + 'static,
|
||||
{
|
||||
self.pool.spawn_pinned(create_task)
|
||||
self.pool
|
||||
.spawn_pinned(create_task, WorkerChoice::LeastBurdened)
|
||||
}
|
||||
|
||||
/// Differs from `spawn_pinned` only in that you can choose a specific worker thread
|
||||
/// of the pool, whereas `spawn_pinned` chooses the worker with the smallest
|
||||
/// number of tasks scheduled.
|
||||
///
|
||||
/// A worker thread is chosen by index. Indices are 0 based and the largest index
|
||||
/// is given by `num_threads() - 1`
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This method panics if the index is out of bounds.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// This method can be used to spawn a task on all worker threads of the pool:
|
||||
///
|
||||
/// ```
|
||||
/// use tokio_util::task::LocalPoolHandle;
|
||||
///
|
||||
/// #[tokio::main]
|
||||
/// async fn main() {
|
||||
/// const NUM_WORKERS: usize = 3;
|
||||
/// let pool = LocalPoolHandle::new(NUM_WORKERS);
|
||||
/// let handles = (0..pool.num_threads())
|
||||
/// .map(|worker_idx| {
|
||||
/// pool.spawn_pinned_by_idx(
|
||||
/// || {
|
||||
/// async {
|
||||
/// "test"
|
||||
/// }
|
||||
/// },
|
||||
/// worker_idx,
|
||||
/// )
|
||||
/// })
|
||||
/// .collect::<Vec<_>>();
|
||||
///
|
||||
/// for handle in handles {
|
||||
/// handle.await.unwrap();
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
pub fn spawn_pinned_by_idx<F, Fut>(&self, create_task: F, idx: usize) -> JoinHandle<Fut::Output>
|
||||
where
|
||||
F: FnOnce() -> Fut,
|
||||
F: Send + 'static,
|
||||
Fut: Future + 'static,
|
||||
Fut::Output: Send + 'static,
|
||||
{
|
||||
self.pool
|
||||
.spawn_pinned(create_task, WorkerChoice::ByIdx(idx))
|
||||
}
|
||||
}
|
||||
|
||||
@ -79,13 +185,22 @@ impl Debug for LocalPoolHandle {
|
||||
}
|
||||
}
|
||||
|
||||
enum WorkerChoice {
|
||||
LeastBurdened,
|
||||
ByIdx(usize),
|
||||
}
|
||||
|
||||
struct LocalPool {
|
||||
workers: Vec<LocalWorkerHandle>,
|
||||
}
|
||||
|
||||
impl LocalPool {
|
||||
/// Spawn a `?Send` future onto a worker
|
||||
fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
|
||||
fn spawn_pinned<F, Fut>(
|
||||
&self,
|
||||
create_task: F,
|
||||
worker_choice: WorkerChoice,
|
||||
) -> JoinHandle<Fut::Output>
|
||||
where
|
||||
F: FnOnce() -> Fut,
|
||||
F: Send + 'static,
|
||||
@ -93,8 +208,10 @@ impl LocalPool {
|
||||
Fut::Output: Send + 'static,
|
||||
{
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
|
||||
let (worker, job_guard) = self.find_and_incr_least_burdened_worker();
|
||||
let (worker, job_guard) = match worker_choice {
|
||||
WorkerChoice::LeastBurdened => self.find_and_incr_least_burdened_worker(),
|
||||
WorkerChoice::ByIdx(idx) => self.find_worker_by_idx(idx),
|
||||
};
|
||||
let worker_spawner = worker.spawner.clone();
|
||||
|
||||
// Spawn a future onto the worker's runtime so we can immediately return
|
||||
@ -206,6 +323,13 @@ impl LocalPool {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard) {
|
||||
let worker = &self.workers[idx];
|
||||
worker.task_count.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
(worker, JobCountGuard(Arc::clone(&worker.task_count)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Automatically decrements a worker's job count when a job finishes (when
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Barrier;
|
||||
use tokio_util::task;
|
||||
|
||||
/// Simple test of running a !Send future via spawn_pinned
|
||||
@ -191,3 +192,45 @@ async fn tasks_are_balanced() {
|
||||
// be on separate workers/threads.
|
||||
assert_ne!(thread_id1, thread_id2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn spawn_by_idx() {
|
||||
let pool = task::LocalPoolHandle::new(3);
|
||||
let barrier = Arc::new(Barrier::new(4));
|
||||
let barrier1 = barrier.clone();
|
||||
let barrier2 = barrier.clone();
|
||||
let barrier3 = barrier.clone();
|
||||
|
||||
let handle1 = pool.spawn_pinned_by_idx(
|
||||
|| async move {
|
||||
barrier1.wait().await;
|
||||
std::thread::current().id()
|
||||
},
|
||||
0,
|
||||
);
|
||||
let _ = pool.spawn_pinned_by_idx(
|
||||
|| async move {
|
||||
barrier2.wait().await;
|
||||
std::thread::current().id()
|
||||
},
|
||||
0,
|
||||
);
|
||||
let handle2 = pool.spawn_pinned_by_idx(
|
||||
|| async move {
|
||||
barrier3.wait().await;
|
||||
std::thread::current().id()
|
||||
},
|
||||
1,
|
||||
);
|
||||
|
||||
let loads = pool.get_task_loads_for_each_worker();
|
||||
barrier.wait().await;
|
||||
assert_eq!(loads[0], 2);
|
||||
assert_eq!(loads[1], 1);
|
||||
assert_eq!(loads[2], 0);
|
||||
|
||||
let thread_id1 = handle1.await.unwrap();
|
||||
let thread_id2 = handle2.await.unwrap();
|
||||
|
||||
assert_ne!(thread_id1, thread_id2);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user