// cs431_homework/elim_stack/treiber_stack.rs

1use core::mem::{self, ManuallyDrop};
2use core::ops::Deref;
3use core::ptr;
4use core::sync::atomic::Ordering;
5
6use crossbeam_epoch::{Atomic, Guard, Owned, Shared};
7
8use super::base::Stack;
9
/// A node in the Treiber stack's singly linked list.
#[derive(Debug)]
pub struct Node<T> {
    // Wrapped in `ManuallyDrop` because `try_pop` moves the value out with
    // `ptr::read`; the node is later freed by the epoch GC (or `Drop`), and
    // the wrapper prevents `T`'s destructor from running a second time.
    data: ManuallyDrop<T>,
    // Raw pointer to the next (older) node; null at the bottom of the stack.
    next: *const Node<T>,
}
15
// SAFETY: nodes are transferred between threads whole (pushed by one thread,
// popped by another), but any particular `T` is never accessed concurrently,
// so `T: Send` suffices and no `T: Sync` bound is needed for either impl.
unsafe impl<T: Send> Send for Node<T> {}
unsafe impl<T: Send> Sync for Node<T> {}
19
/// Treiber's lock-free stack.
///
/// Usable with any number of producers and consumers.
#[derive(Debug)]
pub struct TreiberStack<T> {
    // Top of the stack; null when empty. All mutation goes through CAS.
    head: Atomic<Node<T>>,
}
27
28impl<T> From<T> for Node<T> {
29    fn from(t: T) -> Self {
30        Self {
31            data: ManuallyDrop::new(t),
32            next: ptr::null(),
33        }
34    }
35}
36
37impl<T> Deref for Node<T> {
38    type Target = ManuallyDrop<T>;
39
40    fn deref(&self) -> &Self::Target {
41        &self.data
42    }
43}
44
45impl<T> Default for TreiberStack<T> {
46    fn default() -> Self {
47        TreiberStack {
48            head: Atomic::null(),
49        }
50    }
51}
52
53impl<T> Stack<T> for TreiberStack<T> {
54    type PushReq = Node<T>;
55
56    fn try_push(
57        &self,
58        req: Owned<Self::PushReq>,
59        guard: &Guard,
60    ) -> Result<(), Owned<Self::PushReq>> {
61        let mut req = req;
62        let head = self.head.load(Ordering::Relaxed, guard);
63        req.next = head.as_raw();
64
65        match self
66            .head
67            .compare_exchange(head, req, Ordering::Release, Ordering::Relaxed, guard)
68        {
69            Ok(_) => Ok(()),
70            Err(e) => Err(e.new),
71        }
72    }
73
74    fn try_pop(&self, guard: &Guard) -> Result<Option<T>, ()> {
75        let head = self.head.load(Ordering::Acquire, guard);
76        let Some(head_ref) = (unsafe { head.as_ref() }) else {
77            return Ok(None);
78        };
79        let next = Shared::from(head_ref.next);
80
81        let _ = self
82            .head
83            .compare_exchange(head, next, Ordering::Relaxed, Ordering::Relaxed, guard)
84            .map_err(|_| ())?;
85
86        let data = ManuallyDrop::into_inner(unsafe { ptr::read(&head_ref.data) });
87        unsafe { guard.defer_destroy(head) };
88        Ok(Some(data))
89    }
90
91    fn is_empty(&self, guard: &Guard) -> bool {
92        self.head.load(Ordering::Acquire, guard).is_null()
93    }
94}
95
impl<T> Drop for TreiberStack<T> {
    /// Drains the list, dropping every remaining element and freeing every node.
    ///
    /// `&mut self` guarantees exclusive access, so nodes can be reclaimed
    /// directly via `try_into_owned` without going through the epoch GC.
    fn drop(&mut self) {
        let mut o_curr = mem::take(&mut self.head);
        // `into_box` yields a `Box<Node<T>>` that frees the node's memory
        // when it goes out of scope at the end of each iteration.
        while let Some(curr) = unsafe { o_curr.try_into_owned() }.map(Owned::into_box) {
            // Manually run the payload's destructor, which `ManuallyDrop`
            // otherwise suppresses.
            drop(ManuallyDrop::into_inner(curr.data));
            o_curr = curr.next.into();
        }
    }
}
105
#[cfg(test)]
mod test {
    use std::thread::scope;

    use super::*;

    /// Hammers the stack from 10 threads; each pushes then immediately pops,
    /// so the stack must end up empty once all threads have joined.
    #[test]
    fn push() {
        let stack = TreiberStack::default();

        scope(|scope| {
            // `scope` joins every spawned thread when the closure returns, so
            // collecting the join handles (as the original did) is dead code.
            for _ in 0..10 {
                scope.spawn(|| {
                    for i in 0..10_000 {
                        stack.push(i);
                        assert!(stack.pop().is_some());
                    }
                });
            }
        });

        assert!(stack.pop().is_none());
    }
}
131}