// cs431_homework/hazard_pointer/retire.rs

1use core::marker::PhantomData;
2#[cfg(not(feature = "check-loom"))]
3use core::sync::atomic::{Ordering, fence};
4
5#[cfg(feature = "check-loom")]
6use loom::sync::atomic::{Ordering, fence};
7
8use super::{HAZARDS, HazardBag};
9
/// Type-erased reclamation routine: given the raw pointer, frees the object behind it.
type FreeFn = unsafe fn(*mut ());

/// A retired entry: the pointer as `*mut ()` together with the routine that frees it.
type Retired = (*mut (), FreeFn);
11
12/// Thread-local list of retired pointers.
13#[derive(Debug)]
14pub struct RetiredSet<'s> {
15    hazards: &'s HazardBag,
16    /// The first element of the pair is the machine representation of the pointer and the second
17    /// is the function pointer to `free::<T>` where `T` is the type of the object.
18    inner: Vec<Retired>,
19    _marker: PhantomData<*const ()>, // !Send + !Sync
20}
21
22impl<'s> RetiredSet<'s> {
23    /// The max length of retired pointer list. `collect` is triggered when `THRESHOLD` pointers
24    /// are retired.
25    const THRESHOLD: usize = 64;
26
27    /// Create a new retired pointer list protected by the given `HazardBag`.
28    pub fn new(hazards: &'s HazardBag) -> Self {
29        Self {
30            hazards,
31            inner: Vec::new(),
32            _marker: PhantomData,
33        }
34    }
35
36    /// Retires a pointer.
37    ///
38    /// # Safety
39    ///
40    /// * `pointer` must be removed from shared memory before calling this function, and must be
41    ///   valid.
42    /// * The same `pointer` should only be retired once.
43    ///
44    /// # Note
45    ///
46    /// `T: Send` is not required because the retired pointers are not sent to other threads.
47    pub unsafe fn retire<T>(&mut self, pointer: *mut T) {
48        /// Frees a pointer. This function is defined here instead of `collect()` as we know about
49        /// the type of `pointer` only at the time of retiring it.
50        ///
51        /// # Safety
52        ///
53        /// * Subsumes the safety requirements of [`Box::from_raw`]. In particular, one must have
54        ///   unique ownership to `data`.
55        ///
56        /// [`Box::from_raw`]: https://doc.rust-lang.org/std/boxed/struct.Box.html#method.from_raw
57        unsafe fn free<T>(data: *mut ()) {
58            drop(unsafe { Box::from_raw(data.cast::<T>()) })
59        }
60
61        todo!()
62    }
63
64    /// Free the pointers that are `retire`d by the current thread and not `protect`ed by any other
65    /// threads.
66    pub fn collect(&mut self) {
67        todo!()
68    }
69}
70
71impl Default for RetiredSet<'static> {
72    fn default() -> Self {
73        Self::new(&HAZARDS)
74    }
75}
76
77// this triggers loom internal bug
78#[cfg(not(feature = "check-loom"))]
79impl Drop for RetiredSet<'_> {
80    fn drop(&mut self) {
81        // In a production-quality implementation of hazard pointers, the remaining local retired
82        // pointers will be moved to a global list of retired pointers, which are then reclaimed by
83        // the other threads. For pedagogical purposes, here we simply wait for all retired pointers
84        // are no longer protected.
85        while !self.inner.is_empty() {
86            self.collect();
87        }
88    }
89}
90
#[cfg(all(test, not(feature = "check-loom")))]
mod tests {
    use std::cell::RefCell;
    use std::collections::HashSet;
    use std::rc::Rc;

    use super::{HazardBag, RetiredSet};

    /// Retiring `THRESHOLD` unprotected pointers must trigger a collection that frees them all.
    #[test]
    fn retire_threshold_collect() {
        /// Records its id in the shared set when it is dropped.
        struct Tester(Rc<RefCell<HashSet<usize>>>, usize);
        impl Drop for Tester {
            fn drop(&mut self) {
                let _ = self.0.borrow_mut().insert(self.1);
            }
        }

        let hazards = HazardBag::new();
        let mut retires = RetiredSet::new(&hazards);
        let dropped_ids = Rc::new(RefCell::new(HashSet::new()));

        for id in 0..RetiredSet::THRESHOLD {
            let raw: *mut Tester = Box::into_raw(Box::new(Tester(Rc::clone(&dropped_ids), id)));
            // SAFETY: `raw` comes fresh from `Box`, was never shared, and is retired once.
            unsafe { retires.retire(raw) };
        }

        // Every `Tester` (and thus every `Rc` clone) must already be gone for this to succeed.
        let dropped_ids = Rc::try_unwrap(dropped_ids).unwrap().into_inner();
        let expected: HashSet<usize> = (0..RetiredSet::THRESHOLD).collect();

        assert_eq!(dropped_ids, expected)
    }
}