chore: try to avoid noalias attributes on intrusive linked list (#3654)

This commit is contained in:
Alice Ryhl 2021-03-29 22:38:29 +02:00 committed by GitHub
parent 1a80d6eee5
commit fee76ea7d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -6,10 +6,11 @@
//! structure's APIs are `unsafe` as they require the caller to ensure the
//! specified node is actually contained by the list.
use core::cell::UnsafeCell;
use core::fmt;
use core::marker::PhantomData;
use core::marker::{PhantomData, PhantomPinned};
use core::mem::ManuallyDrop;
use core::ptr::NonNull;
use core::ptr::{self, NonNull};
/// An intrusive linked list.
///
@ -60,11 +61,40 @@ pub(crate) unsafe trait Link {
/// Previous / next pointers
pub(crate) struct Pointers<T> {
inner: UnsafeCell<PointersInner<T>>,
}
/// We do not want the compiler to put the `noalias` attribute on mutable
/// references to this type, so the type has been made `!Unpin` with a
/// `PhantomPinned` field.
///
/// Additionally, we never access the `prev` or `next` fields directly, as any
/// such access would implicitly involve the creation of a reference to the
/// field, which we want to avoid since the fields are not `!Unpin`, and would
/// hence be given the `noalias` attribute if we were to do such an access.
/// As an alternative to accessing the fields directly, the `Pointers` type
/// provides getters and setters for the two fields, and those are implemented
/// using raw pointer casts and offsets, which is valid since the struct is
/// #[repr(C)].
///
/// See this link for more information:
/// https://github.com/rust-lang/rust/pull/82834
#[repr(C)]
struct PointersInner<T> {
/// The previous node in the list. null if there is no previous node.
///
/// This field is accessed through pointer manipulation, so it is not dead code.
#[allow(dead_code)]
prev: Option<NonNull<T>>,
/// The next node in the list. null if there is no previous node.
///
/// This field is accessed through pointer manipulation, so it is not dead code.
#[allow(dead_code)]
next: Option<NonNull<T>>,
/// This type is !Unpin due to the heuristic from:
/// https://github.com/rust-lang/rust/pull/82834
_pin: PhantomPinned,
}
unsafe impl<T: Send> Send for Pointers<T> {}
@ -91,11 +121,11 @@ impl<L: Link> LinkedList<L, L::Target> {
let ptr = L::as_raw(&*val);
assert_ne!(self.head, Some(ptr));
unsafe {
L::pointers(ptr).as_mut().next = self.head;
L::pointers(ptr).as_mut().prev = None;
L::pointers(ptr).as_mut().set_next(self.head);
L::pointers(ptr).as_mut().set_prev(None);
if let Some(head) = self.head {
L::pointers(head).as_mut().prev = Some(ptr);
L::pointers(head).as_mut().set_prev(Some(ptr));
}
self.head = Some(ptr);
@ -111,16 +141,16 @@ impl<L: Link> LinkedList<L, L::Target> {
pub(crate) fn pop_back(&mut self) -> Option<L::Handle> {
unsafe {
let last = self.tail?;
self.tail = L::pointers(last).as_ref().prev;
self.tail = L::pointers(last).as_ref().get_prev();
if let Some(prev) = L::pointers(last).as_ref().prev {
L::pointers(prev).as_mut().next = None;
if let Some(prev) = L::pointers(last).as_ref().get_prev() {
L::pointers(prev).as_mut().set_next(None);
} else {
self.head = None
}
L::pointers(last).as_mut().prev = None;
L::pointers(last).as_mut().next = None;
L::pointers(last).as_mut().set_prev(None);
L::pointers(last).as_mut().set_next(None);
Some(L::from_raw(last))
}
@ -143,31 +173,35 @@ impl<L: Link> LinkedList<L, L::Target> {
/// The caller **must** ensure that `node` is currently contained by
/// `self` or not contained by any other list.
pub(crate) unsafe fn remove(&mut self, node: NonNull<L::Target>) -> Option<L::Handle> {
if let Some(prev) = L::pointers(node).as_ref().prev {
debug_assert_eq!(L::pointers(prev).as_ref().next, Some(node));
L::pointers(prev).as_mut().next = L::pointers(node).as_ref().next;
if let Some(prev) = L::pointers(node).as_ref().get_prev() {
debug_assert_eq!(L::pointers(prev).as_ref().get_next(), Some(node));
L::pointers(prev)
.as_mut()
.set_next(L::pointers(node).as_ref().get_next());
} else {
if self.head != Some(node) {
return None;
}
self.head = L::pointers(node).as_ref().next;
self.head = L::pointers(node).as_ref().get_next();
}
if let Some(next) = L::pointers(node).as_ref().next {
debug_assert_eq!(L::pointers(next).as_ref().prev, Some(node));
L::pointers(next).as_mut().prev = L::pointers(node).as_ref().prev;
if let Some(next) = L::pointers(node).as_ref().get_next() {
debug_assert_eq!(L::pointers(next).as_ref().get_prev(), Some(node));
L::pointers(next)
.as_mut()
.set_prev(L::pointers(node).as_ref().get_prev());
} else {
// This might be the last item in the list
if self.tail != Some(node) {
return None;
}
self.tail = L::pointers(node).as_ref().prev;
self.tail = L::pointers(node).as_ref().get_prev();
}
L::pointers(node).as_mut().next = None;
L::pointers(node).as_mut().prev = None;
L::pointers(node).as_mut().set_next(None);
L::pointers(node).as_mut().set_prev(None);
Some(L::from_raw(node))
}
@ -224,7 +258,7 @@ cfg_rt_multi_thread! {
fn next(&mut self) -> Option<&'a T::Target> {
let curr = self.curr?;
// safety: the pointer references data contained by the list
self.curr = unsafe { T::pointers(curr).as_ref() }.next;
self.curr = unsafe { T::pointers(curr).as_ref() }.get_next();
// safety: the value is still owned by the linked list.
Some(unsafe { &*curr.as_ptr() })
@ -265,7 +299,7 @@ cfg_io_readiness! {
fn next(&mut self) -> Option<Self::Item> {
while let Some(curr) = self.curr {
// safety: the pointer references data contained by the list
self.curr = unsafe { T::pointers(curr).as_ref() }.next;
self.curr = unsafe { T::pointers(curr).as_ref() }.get_next();
// safety: the value is still owned by the linked list.
if (self.filter)(unsafe { &mut *curr.as_ptr() }) {
@ -284,17 +318,58 @@ impl<T> Pointers<T> {
/// Create a new set of empty pointers
pub(crate) fn new() -> Pointers<T> {
Pointers {
prev: None,
next: None,
inner: UnsafeCell::new(PointersInner {
prev: None,
next: None,
_pin: PhantomPinned,
}),
}
}
fn get_prev(&self) -> Option<NonNull<T>> {
// SAFETY: prev is the first field in PointersInner, which is #[repr(C)].
unsafe {
let inner = self.inner.get();
let prev = inner as *const Option<NonNull<T>>;
ptr::read(prev)
}
}
fn get_next(&self) -> Option<NonNull<T>> {
// SAFETY: next is the second field in PointersInner, which is #[repr(C)].
unsafe {
let inner = self.inner.get();
let prev = inner as *const Option<NonNull<T>>;
let next = prev.add(1);
ptr::read(next)
}
}
fn set_prev(&mut self, value: Option<NonNull<T>>) {
// SAFETY: prev is the first field in PointersInner, which is #[repr(C)].
unsafe {
let inner = self.inner.get();
let prev = inner as *mut Option<NonNull<T>>;
ptr::write(prev, value);
}
}
fn set_next(&mut self, value: Option<NonNull<T>>) {
// SAFETY: next is the second field in PointersInner, which is #[repr(C)].
unsafe {
let inner = self.inner.get();
let prev = inner as *mut Option<NonNull<T>>;
let next = prev.add(1);
ptr::write(next, value);
}
}
}
impl<T> fmt::Debug for Pointers<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let prev = self.get_prev();
let next = self.get_next();
f.debug_struct("Pointers")
.field("prev", &self.prev)
.field("next", &self.next)
.field("prev", &prev)
.field("next", &next)
.finish()
}
}
@ -321,7 +396,7 @@ mod tests {
}
unsafe fn from_raw(ptr: NonNull<Entry>) -> Pin<&'a Entry> {
Pin::new(&*ptr.as_ptr())
Pin::new_unchecked(&*ptr.as_ptr())
}
unsafe fn pointers(mut target: NonNull<Entry>) -> NonNull<Pointers<Entry>> {
@ -361,8 +436,8 @@ mod tests {
macro_rules! assert_clean {
($e:ident) => {{
assert!($e.pointers.next.is_none());
assert!($e.pointers.prev.is_none());
assert!($e.pointers.get_next().is_none());
assert!($e.pointers.get_prev().is_none());
}};
}
@ -460,8 +535,8 @@ mod tests {
assert_clean!(a);
assert_ptr_eq!(b, list.head);
assert_ptr_eq!(c, b.pointers.next);
assert_ptr_eq!(b, c.pointers.prev);
assert_ptr_eq!(c, b.pointers.get_next());
assert_ptr_eq!(b, c.pointers.get_prev());
let items = collect_list(&mut list);
assert_eq!([31, 7].to_vec(), items);
@ -476,8 +551,8 @@ mod tests {
assert!(list.remove(ptr(&b)).is_some());
assert_clean!(b);
assert_ptr_eq!(c, a.pointers.next);
assert_ptr_eq!(a, c.pointers.prev);
assert_ptr_eq!(c, a.pointers.get_next());
assert_ptr_eq!(a, c.pointers.get_prev());
let items = collect_list(&mut list);
assert_eq!([31, 5].to_vec(), items);
@ -493,7 +568,7 @@ mod tests {
assert!(list.remove(ptr(&c)).is_some());
assert_clean!(c);
assert!(b.pointers.next.is_none());
assert!(b.pointers.get_next().is_none());
assert_ptr_eq!(b, list.tail);
let items = collect_list(&mut list);
@ -516,8 +591,8 @@ mod tests {
assert_ptr_eq!(b, list.head);
assert_ptr_eq!(b, list.tail);
assert!(b.pointers.next.is_none());
assert!(b.pointers.prev.is_none());
assert!(b.pointers.get_next().is_none());
assert!(b.pointers.get_prev().is_none());
let items = collect_list(&mut list);
assert_eq!([7].to_vec(), items);
@ -536,8 +611,8 @@ mod tests {
assert_ptr_eq!(a, list.head);
assert_ptr_eq!(a, list.tail);
assert!(a.pointers.next.is_none());
assert!(a.pointers.prev.is_none());
assert!(a.pointers.get_next().is_none());
assert!(a.pointers.get_prev().is_none());
let items = collect_list(&mut list);
assert_eq!([5].to_vec(), items);