diff --git a/src/ring_buffer/spsc.rs b/src/ring_buffer/spsc.rs index 98f18b01..3026c51f 100644 --- a/src/ring_buffer/spsc.rs +++ b/src/ring_buffer/spsc.rs @@ -72,21 +72,37 @@ macro_rules! impl_ { { /// Returns the item in the front of the queue, or `None` if the queue is empty pub fn dequeue(&mut self) -> Option { - let rb = unsafe { self.rb.as_ref() }; + let tail = unsafe { self.rb.as_ref().tail.load_acquire() }; + let head = unsafe { self.rb.as_ref().head.load_relaxed() }; - let n = rb.capacity() + 1; - let buffer: &[T] = unsafe { rb.buffer.as_ref() }; - - let tail = rb.tail.load_acquire(); - let head = rb.head.load_relaxed(); if head != tail { - let item = unsafe { ptr::read(buffer.get_unchecked(usize::from(head))) }; - rb.head.store_release((head + 1) % n); - Some(item) + Some(unsafe { self._dequeue(head) }) } else { None } } + + /// Returns the item in the front of the queue, without checking if it's empty + /// + /// # Unsafety + /// + /// If the queue is empty this is equivalent to calling `mem::uninitialized` + pub unsafe fn dequeue_unchecked(&mut self) -> T { + let head = self.rb.as_ref().head.load_relaxed(); + debug_assert_ne!(head, self.rb.as_ref().tail.load_acquire()); + self._dequeue(head) + } + + unsafe fn _dequeue(&mut self, head: $uxx) -> T { + let rb = self.rb.as_ref(); + + let n = rb.capacity() + 1; + let buffer: &[T] = rb.buffer.as_ref(); + + let item = ptr::read(buffer.get_unchecked(usize::from(head))); + rb.head.store_release((head + 1) % n); + item + } } impl<'a, T, A> Producer<'a, T, A, $uxx> @@ -97,30 +113,52 @@ macro_rules! impl_ { /// /// Returns `BufferFullError` if the queue is full pub fn enqueue(&mut self, item: T) -> Result<(), BufferFullError> { - let rb = unsafe { self.rb.as_mut() }; - - let n = rb.capacity() + 1; - let buffer: &mut [T] = unsafe { rb.buffer.as_mut() }; - - let tail = rb.tail.load_relaxed(); + let n = unsafe { self.rb.as_ref().capacity() + 1 }; + let tail = unsafe { self.rb.as_ref().tail.load_relaxed() }; // NOTE we could replace this `load_acquire` with a `load_relaxed` and this method // would be sound on most architectures but that change would result in UB according // to the C++ memory model, which is what Rust currently uses, so we err on the side // of caution and stick to `load_acquire`. Check issue google#sanitizers#882 for // more details. - let head = rb.head.load_acquire(); + let head = unsafe { self.rb.as_ref().head.load_acquire() }; let next_tail = (tail + 1) % n; if next_tail != head { - // NOTE(ptr::write) the memory slot that we are about to write to is - // uninitialized. We use `ptr::write` to avoid running `T`'s destructor on the - // uninitialized memory - unsafe { ptr::write(buffer.get_unchecked_mut(usize::from(tail)), item) } - rb.tail.store_release(next_tail); + unsafe { self._enqueue(tail, item) }; Ok(()) } else { Err(BufferFullError) } } + + /// Adds an `item` to the end of the queue without checking if it's full + /// + /// **WARNING** If the queue is full this operation will make the queue appear empty to + /// the `Consumer`, thus *leaking* (destructors won't run) all the elements that were in + /// the queue. + pub fn enqueue_unchecked(&mut self, item: T) { + unsafe { + let tail = self.rb.as_ref().tail.load_relaxed(); + debug_assert_ne!( + (tail + 1) % (self.rb.as_ref().capacity() + 1), + self.rb.as_ref().head.load_acquire() + ); + self._enqueue(tail, item); + } + } + + unsafe fn _enqueue(&mut self, tail: $uxx, item: T) { + let rb = self.rb.as_mut(); + + let n = rb.capacity() + 1; + let buffer: &mut [T] = rb.buffer.as_mut(); + + let next_tail = (tail + 1) % n; + // NOTE(ptr::write) the memory slot that we are about to write to is + // uninitialized. We use `ptr::write` to avoid running `T`'s destructor on the + // uninitialized memory + ptr::write(buffer.get_unchecked_mut(usize::from(tail)), item); + rb.tail.store_release(next_tail); + } } }; } diff --git a/tests/tsan.rs b/tests/tsan.rs index 0449e230..e9d09ad0 100644 --- a/tests/tsan.rs +++ b/tests/tsan.rs @@ -118,3 +118,38 @@ fn contention() { assert!(rb.is_empty()); } + +#[test] +fn unchecked() { + const N: usize = 1024; + + let mut rb: RingBuffer = RingBuffer::new(); + + for _ in 0..N { + rb.enqueue(1).unwrap(); + } + + { + let (mut p, mut c) = rb.split(); + + Pool::new(2).scoped(move |scope| { + scope.execute(move || { + for _ in 0..N { + p.enqueue_unchecked(2); + } + }); + + scope.execute(move || { + let mut sum: usize = 0; + + for _ in 0..N { + sum = sum.wrapping_add(usize::from(unsafe { c.dequeue_unchecked() })); + } + + assert_eq!(sum, N); + }); + }); + } + + assert_eq!(rb.len(), N); +}