task: improve LocalPoolHandle (#4680)

b-naber 2022-06-17 18:37:00 +02:00 committed by GitHub
parent c98be229ff
commit d8fb721de2
2 changed files with 172 additions and 5 deletions


@@ -9,7 +9,44 @@ use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
use tokio::task::{spawn_local, JoinHandle, LocalSet};
/// A handle to a local pool, used for spawning `!Send` tasks.
/// A cloneable handle to a local pool, used for spawning `!Send` tasks.
///
/// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread
/// in the pool. Consequently you can also use [`tokio::task::spawn_local`] (which will
/// execute on the same thread) inside the Future you supply to the various spawn methods
/// of `LocalPoolHandle`.
///
/// [`tokio::task::LocalSet`]: tokio::task::LocalSet
/// [`tokio::task::spawn_local`]: tokio::task::spawn_local
///
/// # Examples
///
/// ```
/// use std::rc::Rc;
/// use tokio::{self, task};
/// use tokio_util::task::LocalPoolHandle;
///
/// #[tokio::main(flavor = "current_thread")]
/// async fn main() {
/// let pool = LocalPoolHandle::new(5);
///
/// let output = pool.spawn_pinned(|| {
/// // `data` is !Send + !Sync
/// let data = Rc::new("local data");
/// let data_clone = data.clone();
///
/// async move {
/// task::spawn_local(async move {
/// println!("{}", data_clone);
/// });
///
/// data.to_string()
/// }
/// }).await.unwrap();
/// println!("output: {}", output);
/// }
/// ```
///
#[derive(Clone)]
pub struct LocalPoolHandle {
pool: Arc<LocalPool>,
@@ -33,6 +70,22 @@ impl LocalPoolHandle {
LocalPoolHandle { pool }
}
/// Returns the number of threads in the pool.
#[inline]
pub fn num_threads(&self) -> usize {
self.pool.workers.len()
}
/// Returns the number of tasks scheduled on each worker. The indices of the
/// worker threads correspond to the indices of the returned `Vec`.
pub fn get_task_loads_for_each_worker(&self) -> Vec<usize> {
self.pool
.workers
.iter()
.map(|worker| worker.task_count.load(Ordering::SeqCst))
.collect::<Vec<_>>()
}
/// Spawn a task onto a worker thread and pin it there so it can't be moved
/// off of the thread. Note that the future is not [`Send`], but the
/// [`FnOnce`] which creates it is.
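
The hunk above adds two introspection methods, `num_threads` and `get_task_loads_for_each_worker`. A minimal usage sketch (not part of this diff, assuming only the behaviour documented here): spawn a couple of tasks and read back a snapshot of each worker's load.

```rust
use tokio_util::task::LocalPoolHandle;

#[tokio::main]
async fn main() {
    let pool = LocalPoolHandle::new(4);

    // Tasks go to the least burdened workers (see spawn_pinned below).
    let _h1 = pool.spawn_pinned(|| async { 1 });
    let _h2 = pool.spawn_pinned(|| async { 2 });

    // Index i of the returned Vec is the current task count of worker i.
    // This is only a snapshot; counts drop again as tasks finish.
    let loads = pool.get_task_loads_for_each_worker();
    assert_eq!(loads.len(), pool.num_threads());
    for (idx, load) in loads.iter().enumerate() {
        println!("worker {idx}: {load} task(s)");
    }
}
```
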
@@ -69,7 +122,60 @@ impl LocalPoolHandle {
Fut: Future + 'static,
Fut::Output: Send + 'static,
{
self.pool.spawn_pinned(create_task)
self.pool
.spawn_pinned(create_task, WorkerChoice::LeastBurdened)
}
/// Differs from `spawn_pinned` only in that you can choose a specific worker thread
/// of the pool, whereas `spawn_pinned` chooses the worker with the smallest
/// number of tasks scheduled.
///
/// A worker thread is chosen by index. Indices are 0-based and the largest index
/// is given by `num_threads() - 1`.
///
/// # Panics
///
/// This method panics if the index is out of bounds.
///
/// # Examples
///
/// This method can be used to spawn a task on all worker threads of the pool:
///
/// ```
/// use tokio_util::task::LocalPoolHandle;
///
/// #[tokio::main]
/// async fn main() {
/// const NUM_WORKERS: usize = 3;
/// let pool = LocalPoolHandle::new(NUM_WORKERS);
/// let handles = (0..pool.num_threads())
/// .map(|worker_idx| {
/// pool.spawn_pinned_by_idx(
/// || {
/// async {
/// "test"
/// }
/// },
/// worker_idx,
/// )
/// })
/// .collect::<Vec<_>>();
///
/// for handle in handles {
/// handle.await.unwrap();
/// }
/// }
/// ```
///
pub fn spawn_pinned_by_idx<F, Fut>(&self, create_task: F, idx: usize) -> JoinHandle<Fut::Output>
where
F: FnOnce() -> Fut,
F: Send + 'static,
Fut: Future + 'static,
Fut::Output: Send + 'static,
{
self.pool
.spawn_pinned(create_task, WorkerChoice::ByIdx(idx))
}
}
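
Since the index fully determines the worker, `spawn_pinned_by_idx` can also be used to deliberately co-locate tasks on one thread. A small sketch (not part of this diff) that pins two tasks to the same worker and checks that they report the same thread id:

```rust
use tokio_util::task::LocalPoolHandle;

#[tokio::main]
async fn main() {
    let pool = LocalPoolHandle::new(2);

    // Both tasks are pinned to worker 0, so they run on the same OS thread.
    let id_a = pool
        .spawn_pinned_by_idx(|| async { std::thread::current().id() }, 0)
        .await
        .unwrap();
    let id_b = pool
        .spawn_pinned_by_idx(|| async { std::thread::current().id() }, 0)
        .await
        .unwrap();
    assert_eq!(id_a, id_b);

    // Passing an index >= pool.num_threads() would panic, as documented above.
}
```
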
@@ -79,13 +185,22 @@ impl Debug for LocalPoolHandle {
}
}
enum WorkerChoice {
LeastBurdened,
ByIdx(usize),
}
struct LocalPool {
workers: Vec<LocalWorkerHandle>,
}
impl LocalPool {
/// Spawn a `?Send` future onto a worker
fn spawn_pinned<F, Fut>(&self, create_task: F) -> JoinHandle<Fut::Output>
fn spawn_pinned<F, Fut>(
&self,
create_task: F,
worker_choice: WorkerChoice,
) -> JoinHandle<Fut::Output>
where
F: FnOnce() -> Fut,
F: Send + 'static,
@@ -93,8 +208,10 @@ impl LocalPool {
Fut::Output: Send + 'static,
{
let (sender, receiver) = oneshot::channel();
let (worker, job_guard) = self.find_and_incr_least_burdened_worker();
let (worker, job_guard) = match worker_choice {
WorkerChoice::LeastBurdened => self.find_and_incr_least_burdened_worker(),
WorkerChoice::ByIdx(idx) => self.find_worker_by_idx(idx),
};
let worker_spawner = worker.spawner.clone();
// Spawn a future onto the worker's runtime so we can immediately return
@@ -206,6 +323,13 @@ impl LocalPool {
}
}
}
fn find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard) {
let worker = &self.workers[idx];
worker.task_count.fetch_add(1, Ordering::SeqCst);
(worker, JobCountGuard(Arc::clone(&worker.task_count)))
}
}
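
`find_worker_by_idx` bumps the chosen worker's task count and hands back a `JobCountGuard`, whose definition lies outside this diff; presumably it decrements the count when dropped, so the load numbers stay accurate even if a job panics or is cancelled. A minimal sketch of that RAII pattern, using a hypothetical `CountGuard` stand-in:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

// Hypothetical stand-in for JobCountGuard: holds the worker's shared task
// counter and decrements it when dropped.
struct CountGuard(Arc<AtomicUsize>);

impl Drop for CountGuard {
    fn drop(&mut self) {
        self.0.fetch_sub(1, Ordering::SeqCst);
    }
}

fn main() {
    let count = Arc::new(AtomicUsize::new(0));

    count.fetch_add(1, Ordering::SeqCst); // mirrors find_worker_by_idx
    {
        let _guard = CountGuard(Arc::clone(&count));
        assert_eq!(count.load(Ordering::SeqCst), 1);
    } // guard dropped here, count returns to 0

    assert_eq!(count.load(Ordering::SeqCst), 0);
}
```
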
/// Automatically decrements a worker's job count when a job finishes (when


@@ -2,6 +2,7 @@
use std::rc::Rc;
use std::sync::Arc;
use tokio::sync::Barrier;
use tokio_util::task;
/// Simple test of running a !Send future via spawn_pinned
@@ -191,3 +192,45 @@ async fn tasks_are_balanced() {
// be on separate workers/threads.
assert_ne!(thread_id1, thread_id2);
}
#[tokio::test]
async fn spawn_by_idx() {
let pool = task::LocalPoolHandle::new(3);
let barrier = Arc::new(Barrier::new(4));
let barrier1 = barrier.clone();
let barrier2 = barrier.clone();
let barrier3 = barrier.clone();
let handle1 = pool.spawn_pinned_by_idx(
|| async move {
barrier1.wait().await;
std::thread::current().id()
},
0,
);
let _ = pool.spawn_pinned_by_idx(
|| async move {
barrier2.wait().await;
std::thread::current().id()
},
0,
);
let handle2 = pool.spawn_pinned_by_idx(
|| async move {
barrier3.wait().await;
std::thread::current().id()
},
1,
);
let loads = pool.get_task_loads_for_each_worker();
barrier.wait().await;
assert_eq!(loads[0], 2);
assert_eq!(loads[1], 1);
assert_eq!(loads[2], 0);
let thread_id1 = handle1.await.unwrap();
let thread_id2 = handle2.await.unwrap();
assert_ne!(thread_id1, thread_id2);
}