//! Reader-writer spinlock (`rw_lock.rs`, part of `miden_utils_sync`).

1#[cfg(not(loom))]
2use core::{
3    hint,
4    sync::atomic::{AtomicUsize, Ordering},
5};
6
7use lock_api::RawRwLock;
8#[cfg(loom)]
9use loom::{
10    hint,
11    sync::atomic::{AtomicUsize, Ordering},
12};
13
/// An implementation of a reader-writer lock, based on a spinlock primitive, no-std compatible
///
/// See [lock_api::RwLock] for usage.
pub type RwLock<T> = lock_api::RwLock<Spinlock, T>;

/// Guard providing shared (read) access to the data of an [RwLock].
///
/// See [lock_api::RwLockReadGuard] for usage.
pub type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, Spinlock, T>;

/// Guard providing exclusive (write) access to the data of an [RwLock].
///
/// See [lock_api::RwLockWriteGuard] for usage.
pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, Spinlock, T>;
24
/// The underlying raw reader-writer primitive that implements [lock_api::RawRwLock]
///
/// This is fundamentally a spinlock, in that blocking operations on the lock will spin until
/// they succeed in acquiring/releasing the lock.
///
/// To achieve the ability to share the underlying data with multiple readers, or hold
/// exclusive access for one writer, the lock state is based on a "locked" count, where shared
/// access increments the count by 2 (keeping it even), and acquiring exclusive access relies
/// on the use of the lowest order bit to stop further shared acquisition, and indicate that
/// the lock is exclusively held (the difference between the two is irrelevant from the
/// perspective of a thread attempting to acquire the lock, but internally the state uses
/// `usize::MAX` as the "exclusively locked" sentinel).
///
/// This mechanism gets us the following:
///
/// * Whether the lock has been acquired (shared or exclusive)
/// * Whether the lock is being exclusively acquired
/// * How many times the lock has been acquired
/// * Whether the acquisition(s) are exclusive or shared
///
/// Further implementation details, such as how we manage draining readers once an attempt to
/// exclusively acquire the lock occurs, are described below.
///
/// NOTE: This is a simple implementation, meant for use in no-std environments; there are much
/// more robust/performant implementations available when OS primitives can be used.
pub struct Spinlock {
    /// The state of the lock, primarily representing the acquisition count, but relying on
    /// the distinction between even and odd values to indicate whether or not exclusive access
    /// is being acquired.
    state: AtomicUsize,
    /// A counter used to wake a parked writer once the last shared lock is released during
    /// acquisition of an exclusive lock. The actual count is not actually important, and
    /// simply wraps around on overflow, but what is important is that when the value changes,
    /// the writer will wake and resume attempting to acquire the exclusive lock.
    writer_wake_counter: AtomicUsize,
}
61
62impl Default for Spinlock {
63    #[inline(always)]
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69impl Spinlock {
70    #[cfg(not(loom))]
71    pub const fn new() -> Self {
72        Self {
73            state: AtomicUsize::new(0),
74            writer_wake_counter: AtomicUsize::new(0),
75        }
76    }
77
78    #[cfg(loom)]
79    pub fn new() -> Self {
80        Self {
81            state: AtomicUsize::new(0),
82            writer_wake_counter: AtomicUsize::new(0),
83        }
84    }
85}
86
unsafe impl RawRwLock for Spinlock {
    // Loom atomics cannot be constructed in a `const` context, so no usable `INIT` can
    // exist under `cfg(loom)`; loom builds must construct the lock at runtime instead.
    #[cfg(loom)]
    const INIT: Spinlock = unimplemented!();

    #[cfg(not(loom))]
    // This is intentional on the part of the [RawRwLock] API, basically a hack to provide
    // initial values as static items.
    const INIT: Spinlock = Spinlock::new();

    type GuardMarker = lock_api::GuardSend;

    /// The operation invoked when calling `RwLock::read`, blocks the caller until acquired
    ///
    /// Spins until the exclusive bit (lowest-order bit of `state`) is clear, then bumps the
    /// shared count by 2 with an `Acquire` CAS so the protected data is visible to us.
    fn lock_shared(&self) {
        let mut s = self.state.load(Ordering::Relaxed);
        loop {
            // If the exclusive bit is unset, attempt to acquire a read lock by incrementing
            // the count by 2, which preserves the low bit for writers. A weak CAS is fine
            // here: spurious failure just takes another trip around the loop.
            if s & 1 == 0 {
                match self.state.compare_exchange_weak(
                    s,
                    s + 2,
                    Ordering::Acquire,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => return,
                    // Someone else beat us to the punch and acquired a lock; retry with
                    // the freshly-observed state value
                    Err(e) => s = e,
                }
            }
            // If an exclusive lock is held/being acquired, loop until the lock state changes
            // at which point, try to acquire the lock again
            if s & 1 == 1 {
                loop {
                    let next = self.state.load(Ordering::Relaxed);
                    if s == next {
                        hint::spin_loop();
                    } else {
                        s = next;
                        break;
                    }
                }
            }
        }
    }

    /// The operation invoked when calling `RwLock::try_read`, returns whether or not the
    /// lock was acquired
    ///
    /// NOTE: a weak CAS is used, so this may spuriously return `false` even when the lock
    /// was available — permissible for a `try_` operation.
    fn try_lock_shared(&self) -> bool {
        let s = self.state.load(Ordering::Relaxed);
        if s & 1 == 0 {
            self.state
                .compare_exchange_weak(s, s + 2, Ordering::Acquire, Ordering::Relaxed)
                .is_ok()
        } else {
            false
        }
    }

    /// The operation invoked when dropping a `RwLockReadGuard`
    ///
    /// Decrements the shared count by 2; a pre-decrement value of 3 means we were the last
    /// reader (count of 2) while a writer holds the exclusive bit (the 1).
    unsafe fn unlock_shared(&self) {
        if self.state.fetch_sub(2, Ordering::Release) == 3 {
            // The lock is being exclusively acquired, and we're the last shared acquisition
            // to be released, so wake the writer by incrementing the wake counter
            self.writer_wake_counter.fetch_add(1, Ordering::Release);
        }
    }

    /// The operation invoked when calling `RwLock::write`, blocks the caller until acquired
    ///
    /// Exclusive acquisition proceeds in two phases: first set the low bit (blocking any new
    /// readers), then wait for the existing readers to drain before swapping the state to the
    /// `usize::MAX` sentinel.
    fn lock_exclusive(&self) {
        let mut s = self.state.load(Ordering::Relaxed);
        loop {
            // Attempt to acquire the lock immediately, or complete acquisition of the lock
            // if we're continuing the loop after acquiring the exclusive bit. If another
            // thread acquired it first, we race to be the first thread to acquire it once
            // released, by busy looping here. (`s <= 1` means no readers remain: state is
            // either 0, unlocked, or 1, just our exclusive bit.)
            if s <= 1 {
                match self.state.compare_exchange(
                    s,
                    usize::MAX,
                    Ordering::Acquire,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => return,
                    Err(e) => {
                        s = e;
                        hint::spin_loop();
                        continue;
                    },
                }
            }

            // Only shared locks have been acquired, attempt to acquire the exclusive bit,
            // which will prevent further shared locks from being acquired. It does not
            // in and of itself grant us exclusive access however.
            if s & 1 == 0
                && let Err(e) =
                    self.state.compare_exchange(s, s + 1, Ordering::Relaxed, Ordering::Relaxed)
            {
                // The lock state has changed before we could acquire the exclusive bit,
                // update our view of the lock state and try again
                s = e;
                continue;
            }

            // We've acquired the exclusive bit, now we need to busy wait until all shared
            // acquisitions are released.
            let w = self.writer_wake_counter.load(Ordering::Acquire);
            s = self.state.load(Ordering::Relaxed);

            // "Park" the thread here (by busy looping), until the release of the last shared
            // lock, which is communicated to us by it incrementing the wake counter.
            // (Reading the wake counter BEFORE re-reading the state avoids missing a wake
            // that lands between the two loads.)
            if s >= 2 {
                while self.writer_wake_counter.load(Ordering::Acquire) == w {
                    hint::spin_loop();
                }
                s = self.state.load(Ordering::Relaxed);
            }

            // All shared locks have been released, go back to the top and try to complete
            // acquisition of exclusive access.
        }
    }

    /// The operation invoked when calling `RwLock::try_write`, returns whether or not the
    /// lock was acquired
    ///
    /// Unlike the shared path, a strong CAS is used, so this only fails when the lock is
    /// genuinely contended (`s <= 1` means unlocked, or only the exclusive bit is set).
    fn try_lock_exclusive(&self) -> bool {
        let s = self.state.load(Ordering::Relaxed);
        if s <= 1 {
            self.state
                .compare_exchange(s, usize::MAX, Ordering::Acquire, Ordering::Relaxed)
                .is_ok()
        } else {
            false
        }
    }

    /// The operation invoked when dropping a `RwLockWriteGuard`
    unsafe fn unlock_exclusive(&self) {
        // Infallible, as we hold an exclusive lock
        //
        // Note the use of `Release` ordering here, which ensures any loads of the lock state
        // by other threads, are ordered after this store.
        self.state.store(0, Ordering::Release);
        // This fetch_add isn't important for signaling purposes, however it serves a key
        // purpose, in that it imposes a memory ordering on any loads of this field that
        // have an `Acquire` ordering, i.e. they will read the value stored here. Without
        // a `Release` store, loads/stores of this field could be reordered relative to
        // each other.
        self.writer_wake_counter.fetch_add(1, Ordering::Release);
    }
}
237
// Loom-based model-checking tests: only compiled when building with `--cfg loom` AND in
// test mode. Loom exhaustively explores thread interleavings of the operations below.
#[cfg(all(loom, test))]
mod test {
    use alloc::vec::Vec;

    use loom::{model::Builder, sync::Arc};

    use super::{RwLock, Spinlock};

    // Spawns a mix of reader and writer threads over a shared counter that writers always
    // return to zero before releasing; any torn/unsynchronized access would surface as a
    // reader or writer observing a non-zero value.
    #[test]
    fn test_rwlock_loom() {
        let mut builder = Builder::default();
        // Cap the model-checking run so pathological interleaving explosion can't hang CI
        builder.max_duration = Some(std::time::Duration::from_secs(60));
        builder.log = true;
        builder.check(|| {
            let raw_rwlock = Spinlock::new();
            let n = Arc::new(RwLock::from_raw(raw_rwlock, 0usize));
            let mut readers = Vec::new();
            let mut writers = Vec::new();

            // Kept small deliberately: loom's state space grows combinatorially with
            // thread count and iteration count
            let num_readers = 2;
            let num_writers = 2;
            let num_iterations = 2;

            // Readers should never observe a non-zero value
            for _ in 0..num_readers {
                let n0 = n.clone();
                let t = loom::thread::spawn(move || {
                    for _ in 0..num_iterations {
                        let guard = n0.read();
                        assert_eq!(*guard, 0);
                    }
                });

                readers.push(t);
            }

            // Writers should never observe a non-zero value once they've
            // acquired the lock, and should never observe a value > 1
            // while holding the lock
            for _ in 0..num_writers {
                let n0 = n.clone();
                let t = loom::thread::spawn(move || {
                    for _ in 0..num_iterations {
                        let mut guard = n0.write();
                        assert_eq!(*guard, 0);
                        *guard += 1;
                        assert_eq!(*guard, 1);
                        *guard -= 1;
                        assert_eq!(*guard, 0);
                    }
                });

                writers.push(t);
            }

            // Join all threads so any assertion failure propagates to the test harness
            for t in readers {
                t.join().unwrap();
            }

            for t in writers {
                t.join().unwrap();
            }
        })
    }
}