sync: fix mark_changed when version overflows (#6017)

This commit is contained in:
Uwe Klotz 2023-09-19 17:44:08 +02:00 committed by GitHub
parent 9d51b76d01
commit ad7f988da3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -359,7 +359,10 @@ mod state {
use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::atomic::Ordering::SeqCst; use crate::loom::sync::atomic::Ordering::SeqCst;
const CLOSED: usize = 1; const CLOSED_BIT: usize = 1;
// Using 2 as the step size preserves the `CLOSED_BIT`.
const STEP_SIZE: usize = 2;
/// The version part of the state. The lowest bit is always zero. /// The version part of the state. The lowest bit is always zero.
#[derive(Copy, Clone, Debug, Eq, PartialEq)] #[derive(Copy, Clone, Debug, Eq, PartialEq)]
@ -378,31 +381,26 @@ mod state {
pub(super) struct AtomicState(AtomicUsize); pub(super) struct AtomicState(AtomicUsize);
impl Version { impl Version {
/// Get the initial version when creating the channel.
pub(super) fn initial() -> Self {
// The initial version is 1 so that `mark_changed` can decrement by one.
// (The value is 2 due to the closed bit.)
Version(2)
}
/// Decrements the version. /// Decrements the version.
pub(super) fn decrement(&mut self) { pub(super) fn decrement(&mut self) {
// Decrement by two to avoid touching the CLOSED bit. // Using a wrapping decrement here is required to ensure that the
if self.0 >= 2 { // operation is consistent with `std::sync::atomic::AtomicUsize::fetch_add()`
self.0 -= 2; // which wraps on overflow.
} self.0 = self.0.wrapping_sub(STEP_SIZE);
} }
pub(super) const INITIAL: Self = Version(0);
} }
impl StateSnapshot { impl StateSnapshot {
/// Extract the version from the state. /// Extract the version from the state.
pub(super) fn version(self) -> Version { pub(super) fn version(self) -> Version {
Version(self.0 & !CLOSED) Version(self.0 & !CLOSED_BIT)
} }
/// Is the closed bit set? /// Is the closed bit set?
pub(super) fn is_closed(self) -> bool { pub(super) fn is_closed(self) -> bool {
(self.0 & CLOSED) == CLOSED (self.0 & CLOSED_BIT) == CLOSED_BIT
} }
} }
@ -410,7 +408,7 @@ mod state {
/// Create a new `AtomicState` that is not closed and which has the /// Create a new `AtomicState` that is not closed and which has the
/// version set to `Version::initial()`. /// version set to `Version::initial()`.
pub(super) fn new() -> Self { pub(super) fn new() -> Self {
AtomicState(AtomicUsize::new(2)) AtomicState(AtomicUsize::new(Version::INITIAL.0))
} }
/// Load the current value of the state. /// Load the current value of the state.
@ -420,13 +418,12 @@ mod state {
/// Increment the version counter. /// Increment the version counter.
pub(super) fn increment_version(&self) { pub(super) fn increment_version(&self) {
// Increment by two to avoid touching the CLOSED bit. self.0.fetch_add(STEP_SIZE, SeqCst);
self.0.fetch_add(2, SeqCst);
} }
/// Set the closed bit in the state. /// Set the closed bit in the state.
pub(super) fn set_closed(&self) { pub(super) fn set_closed(&self) {
self.0.fetch_or(CLOSED, SeqCst); self.0.fetch_or(CLOSED_BIT, SeqCst);
} }
} }
} }
@ -482,7 +479,7 @@ pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) {
let rx = Receiver { let rx = Receiver {
shared, shared,
version: Version::initial(), version: Version::INITIAL,
}; };
(tx, rx) (tx, rx)