rt: fix bug in work-stealing queue (#2387)

Fixes a couple bugs in the work-stealing queue introduced as
part of #2315. First, the cursor needs to be able to represent more
values than the size of the buffer. This is to be able to track if
`tail` is ahead of `head` or if they are identical. This bug resulted in
the "overflow" path being taken before the buffer was full.

The second bug can happen when a queue is being stolen from concurrently
with stealing into. In this case, it is possible for buffer slots to be
overwritten before they are released by the stealer. This is harder to
happen in practice due to the first bug preventing the queue from
filling up 100%, but could still happen. It triggered an assertion in
`steal_into`. This bug slipped through due to a bug in loom not
correctly catching the case. The loom bug is fixed as part of
tokio-rs/loom#119.

Fixes: #2382
This commit is contained in:
Carl Lerche 2020-04-09 11:35:16 -07:00 committed by GitHub
parent de8326a5a4
commit 58ba45a38c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 387 additions and 44 deletions

View File

@ -129,7 +129,7 @@ tempfile = "3.1.0"
# loom is currently not compiling on windows.
# See: https://github.com/Xudong-Huang/generator-rs/issues/19
[target.'cfg(not(windows))'.dev-dependencies]
loom = { version = "0.3.0", features = ["futures", "checkpoint"] }
loom = { version = "0.3.1", features = ["futures", "checkpoint"] }
[package.metadata.docs.rs]
all-features = true

View File

@ -0,0 +1,44 @@
use std::cell::UnsafeCell;
use std::fmt;
use std::ops::Deref;
/// `AtomicU16` providing an additional `load_unsync` function.
pub(crate) struct AtomicU16 {
inner: UnsafeCell<std::sync::atomic::AtomicU16>,
}
unsafe impl Send for AtomicU16 {}
unsafe impl Sync for AtomicU16 {}
impl AtomicU16 {
pub(crate) fn new(val: u16) -> AtomicU16 {
let inner = UnsafeCell::new(std::sync::atomic::AtomicU16::new(val));
AtomicU16 { inner }
}
/// Performs an unsynchronized load.
///
/// # Safety
///
/// All mutations must have happened before the unsynchronized load.
/// Additionally, there must be no concurrent mutations.
pub(crate) unsafe fn unsync_load(&self) -> u16 {
*(*self.inner.get()).get_mut()
}
}
impl Deref for AtomicU16 {
type Target = std::sync::atomic::AtomicU16;
fn deref(&self) -> &Self::Target {
// safety: it is always safe to access `&self` fns on the inner value as
// we never perform unsafe mutations.
unsafe { &*self.inner.get() }
}
}
impl fmt::Debug for AtomicU16 {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
self.deref().fmt(fmt)
}
}

View File

@ -15,16 +15,6 @@ impl AtomicU32 {
let inner = UnsafeCell::new(std::sync::atomic::AtomicU32::new(val));
AtomicU32 { inner }
}
/// Performs an unsynchronized load.
///
/// # Safety
///
/// All mutations must have happened before the unsynchronized load.
/// Additionally, there must be no concurrent mutations.
pub(crate) unsafe fn unsync_load(&self) -> u32 {
*(*self.inner.get()).get_mut()
}
}
impl Deref for AtomicU32 {

View File

@ -15,16 +15,6 @@ impl AtomicU8 {
let inner = UnsafeCell::new(std::sync::atomic::AtomicU8::new(val));
AtomicU8 { inner }
}
/// Performs an unsynchronized load.
///
/// # Safety
///
/// All mutations must have happened before the unsynchronized load.
/// Additionally, there must be no concurrent mutations.
pub(crate) unsafe fn unsync_load(&self) -> u8 {
*(*self.inner.get()).get_mut()
}
}
impl Deref for AtomicU8 {

View File

@ -1,6 +1,8 @@
#![cfg_attr(any(not(feature = "full"), loom), allow(unused_imports, dead_code))]
mod atomic_ptr;
mod atomic_u16;
mod atomic_u32;
mod atomic_u64;
mod atomic_u8;
mod atomic_usize;
@ -60,11 +62,12 @@ pub(crate) mod sync {
pub(crate) mod atomic {
pub(crate) use crate::loom::std::atomic_ptr::AtomicPtr;
pub(crate) use crate::loom::std::atomic_u16::AtomicU16;
pub(crate) use crate::loom::std::atomic_u32::AtomicU32;
pub(crate) use crate::loom::std::atomic_u64::AtomicU64;
pub(crate) use crate::loom::std::atomic_u8::AtomicU8;
pub(crate) use crate::loom::std::atomic_usize::AtomicUsize;
pub(crate) use std::sync::atomic::AtomicU16;
pub(crate) use std::sync::atomic::{spin_loop_hint, AtomicBool};
}
}

View File

@ -1,7 +1,7 @@
//! Run-queue structures to support a work-stealing scheduler
use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::{AtomicU16, AtomicU8, AtomicUsize};
use crate::loom::sync::atomic::{AtomicU16, AtomicU32, AtomicUsize};
use crate::loom::sync::{Arc, Mutex};
use crate::runtime::task;
@ -34,17 +34,19 @@ pub(super) struct Inject<T: 'static> {
pub(super) struct Inner<T: 'static> {
/// Concurrently updated by many threads.
///
/// Contains two `u8` values. The LSB byte is the "real" head of the queue.
/// The `u8` in the MSB is set by a stealer in process of stealing values.
/// It represents the first value being stolen in the batch.
/// Contains two `u16` values. The LSB byte is the "real" head of the queue.
/// The `u16` in the MSB is set by a stealer in process of stealing values.
/// It represents the first value being stolen in the batch. `u16` is used
/// in order to distinguish between `head == tail` and `head == tail -
/// capacity`.
///
/// When both `u8` values are the same, there is no active stealer.
/// When both `u16` values are the same, there is no active stealer.
///
/// Tracking an in-progress stealer prevents a wrapping scenario.
head: AtomicU16,
head: AtomicU32,
/// Only updated by producer thread but read by many threads.
tail: AtomicU8,
tail: AtomicU16,
/// Elements
buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>]>,
@ -86,8 +88,8 @@ pub(super) fn local<T: 'static>() -> (Steal<T>, Local<T>) {
}
let inner = Arc::new(Inner {
head: AtomicU16::new(0),
tail: AtomicU8::new(0),
head: AtomicU32::new(0),
tail: AtomicU16::new(0),
buffer: buffer.into(),
});
@ -115,7 +117,7 @@ impl<T> Local<T> {
// safety: this is the **only** thread that updates this cell.
let tail = unsafe { self.inner.tail.unsync_load() };
if steal as usize & MASK != tail.wrapping_add(1) as usize & MASK {
if tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY as u16 {
// There is capacity for the task
break tail;
} else if steal != real {
@ -165,16 +167,16 @@ impl<T> Local<T> {
fn push_overflow(
&mut self,
task: task::Notified<T>,
head: u8,
tail: u8,
head: u16,
tail: u16,
inject: &Inject<T>,
) -> Result<(), task::Notified<T>> {
const BATCH_LEN: usize = LOCAL_QUEUE_CAPACITY / 2 + 1;
let n = (LOCAL_QUEUE_CAPACITY / 2) as u8;
let n = (LOCAL_QUEUE_CAPACITY / 2) as u16;
assert_eq!(
tail.wrapping_sub(head) as usize,
LOCAL_QUEUE_CAPACITY - 1,
LOCAL_QUEUE_CAPACITY,
"queue is not full; tail = {}; head = {}",
tail,
head
@ -261,10 +263,12 @@ impl<T> Local<T> {
let next_real = real.wrapping_add(1);
// Only update `steal` component if it differs from `real`.
// If `steal == real` there are no concurrent stealers. Both `steal`
// and `real` are updated.
let next = if steal == real {
pack(next_real, next_real)
} else {
assert_ne!(steal, next_real);
pack(steal, next_real)
};
@ -295,6 +299,17 @@ impl<T> Steal<T> {
// holds a mutable reference.
let dst_tail = unsafe { dst.inner.tail.unsync_load() };
// To the caller, `dst` may **look** empty but still have values
// contained in the buffer. If another thread is concurrently stealing
// from `dst` there may not be enough capacity to steal.
let (steal, _) = unpack(dst.inner.head.load(Acquire));
if dst_tail.wrapping_sub(steal) > LOCAL_QUEUE_CAPACITY as u16 / 2 {
// we *could* try to steal less here, but for simplicity, we're just
// going to abort.
return None;
}
// Steal the tasks into `dst`'s buffer. This does not yet expose the
// tasks in `dst`.
let mut n = self.steal_into2(dst, dst_tail);
@ -327,7 +342,7 @@ impl<T> Steal<T> {
// Steal tasks from `self`, placing them into `dst`. Returns the number of
// tasks that were stolen.
fn steal_into2(&self, dst: &mut Local<T>, dst_tail: u8) -> u8 {
fn steal_into2(&self, dst: &mut Local<T>, dst_tail: u16) -> u16 {
let mut prev_packed = self.0.head.load(Acquire);
let mut next_packed;
@ -352,6 +367,7 @@ impl<T> Steal<T> {
// Update the real head index to acquire the tasks.
let steal_to = src_head_real.wrapping_add(n);
assert_ne!(src_head_steal, steal_to);
next_packed = pack(src_head_steal, steal_to);
// Claim all those tasks. This is done by incrementing the "real"
@ -368,6 +384,8 @@ impl<T> Steal<T> {
}
};
assert!(n <= LOCAL_QUEUE_CAPACITY as u16 / 2, "actual = {}", n);
let (first, _) = unpack(next_packed);
// Take all the tasks
@ -594,16 +612,16 @@ fn set_next(header: NonNull<task::Header>, val: Option<NonNull<task::Header>>) {
/// Split the head value into the real head and the index a stealer is working
/// on.
fn unpack(n: u16) -> (u8, u8) {
let real = n & u8::max_value() as u16;
let steal = n >> 8;
fn unpack(n: u32) -> (u16, u16) {
let real = n & u16::max_value() as u32;
let steal = n >> 16;
(steal as u8, real as u8)
(steal as u16, real as u16)
}
/// Join the two head values
fn pack(steal: u8, real: u8) -> u16 {
(real as u16) | ((steal as u16) << 8)
fn pack(steal: u16, real: u16) -> u32 {
(real as u32) | ((steal as u32) << 16)
}
#[test]

View File

@ -3,6 +3,110 @@ use crate::runtime::task::{self, Schedule, Task};
use loom::thread;
#[test]
fn basic() {
loom::model(|| {
let (steal, mut local) = queue::local();
let inject = queue::Inject::new();
let th = thread::spawn(move || {
let (_, mut local) = queue::local();
let mut n = 0;
for _ in 0..3 {
if steal.steal_into(&mut local).is_some() {
n += 1;
}
while local.pop().is_some() {
n += 1;
}
}
n
});
let mut n = 0;
for _ in 0..2 {
for _ in 0..2 {
let (task, _) = task::joinable::<_, Runtime>(async {});
local.push_back(task, &inject);
}
if local.pop().is_some() {
n += 1;
}
// Push another task
let (task, _) = task::joinable::<_, Runtime>(async {});
local.push_back(task, &inject);
while local.pop().is_some() {
n += 1;
}
}
while inject.pop().is_some() {
n += 1;
}
n += th.join().unwrap();
assert_eq!(6, n);
});
}
#[test]
fn steal_overflow() {
loom::model(|| {
let (steal, mut local) = queue::local();
let inject = queue::Inject::new();
let th = thread::spawn(move || {
let (_, mut local) = queue::local();
let mut n = 0;
if steal.steal_into(&mut local).is_some() {
n += 1;
}
while local.pop().is_some() {
n += 1;
}
n
});
let mut n = 0;
// push a task, pop a task
let (task, _) = task::joinable::<_, Runtime>(async {});
local.push_back(task, &inject);
if local.pop().is_some() {
n += 1;
}
for _ in 0..6 {
let (task, _) = task::joinable::<_, Runtime>(async {});
local.push_back(task, &inject);
}
n += th.join().unwrap();
while local.pop().is_some() {
n += 1;
}
while inject.pop().is_some() {
n += 1;
}
assert_eq!(7, n);
});
}
#[test]
fn multi_stealer() {
const NUM_TASKS: usize = 5;
@ -57,6 +161,43 @@ fn multi_stealer() {
});
}
#[test]
fn chained_steal() {
loom::model(|| {
let (s1, mut l1) = queue::local();
let (s2, mut l2) = queue::local();
let inject = queue::Inject::new();
// Load up some tasks
for _ in 0..4 {
let (task, _) = task::joinable::<_, Runtime>(async {});
l1.push_back(task, &inject);
let (task, _) = task::joinable::<_, Runtime>(async {});
l2.push_back(task, &inject);
}
// Spawn a task to steal from **our** queue
let th = thread::spawn(move || {
let (_, mut local) = queue::local();
s1.steal_into(&mut local);
while local.pop().is_some() {}
});
// Drain our tasks, then attempt to steal
while l1.pop().is_some() {}
s2.steal_into(&mut l1);
th.join().unwrap();
while l1.pop().is_some() {}
while l2.pop().is_some() {}
while inject.pop().is_some() {}
});
}
struct Runtime;
impl Schedule for Runtime {

View File

@ -1,6 +1,47 @@
use crate::runtime::queue;
use crate::runtime::task::{self, Schedule, Task};
use std::thread;
use std::time::Duration;
#[test]
fn fits_256() {
let (_, mut local) = queue::local();
let inject = queue::Inject::new();
for _ in 0..256 {
let (task, _) = task::joinable::<_, Runtime>(async {});
local.push_back(task, &inject);
}
assert!(inject.pop().is_none());
while local.pop().is_some() {}
}
#[test]
fn overflow() {
let (_, mut local) = queue::local();
let inject = queue::Inject::new();
for _ in 0..257 {
let (task, _) = task::joinable::<_, Runtime>(async {});
local.push_back(task, &inject);
}
let mut n = 0;
while inject.pop().is_some() {
n += 1;
}
while local.pop().is_some() {
n += 1;
}
assert_eq!(n, 257);
}
#[test]
fn steal_batch() {
let (steal1, mut local1) = queue::local();
@ -27,6 +68,122 @@ fn steal_batch() {
assert!(local1.pop().is_none());
}
#[test]
fn stress1() {
const NUM_ITER: usize = 1;
const NUM_STEAL: usize = 1_000;
const NUM_LOCAL: usize = 1_000;
const NUM_PUSH: usize = 500;
const NUM_POP: usize = 250;
for _ in 0..NUM_ITER {
let (steal, mut local) = queue::local();
let inject = queue::Inject::new();
let th = thread::spawn(move || {
let (_, mut local) = queue::local();
let mut n = 0;
for _ in 0..NUM_STEAL {
if steal.steal_into(&mut local).is_some() {
n += 1;
}
while local.pop().is_some() {
n += 1;
}
thread::yield_now();
}
n
});
let mut n = 0;
for _ in 0..NUM_LOCAL {
for _ in 0..NUM_PUSH {
let (task, _) = task::joinable::<_, Runtime>(async {});
local.push_back(task, &inject);
}
for _ in 0..NUM_POP {
if local.pop().is_some() {
n += 1;
} else {
break;
}
}
}
while inject.pop().is_some() {
n += 1;
}
n += th.join().unwrap();
assert_eq!(n, NUM_LOCAL * NUM_PUSH);
}
}
#[test]
fn stress2() {
const NUM_ITER: usize = 1;
const NUM_TASKS: usize = 1_000_000;
const NUM_STEAL: usize = 1_000;
for _ in 0..NUM_ITER {
let (steal, mut local) = queue::local();
let inject = queue::Inject::new();
let th = thread::spawn(move || {
let (_, mut local) = queue::local();
let mut n = 0;
for _ in 0..NUM_STEAL {
if steal.steal_into(&mut local).is_some() {
n += 1;
}
while local.pop().is_some() {
n += 1;
}
thread::sleep(Duration::from_micros(10));
}
n
});
let mut num_pop = 0;
for i in 0..NUM_TASKS {
let (task, _) = task::joinable::<_, Runtime>(async {});
local.push_back(task, &inject);
if i % 128 == 0 && local.pop().is_some() {
num_pop += 1;
}
while inject.pop().is_some() {
num_pop += 1;
}
}
num_pop += th.join().unwrap();
while local.pop().is_some() {
num_pop += 1;
}
while inject.pop().is_some() {
num_pop += 1;
}
assert_eq!(num_pop, NUM_TASKS);
}
}
struct Runtime;
impl Schedule for Runtime {