// cs431_homework/elim_stack/treiber_stack.rs
use core::mem::{self, ManuallyDrop};
use core::ops::Deref;
use core::ptr;
use core::sync::atomic::Ordering;

use crossbeam_epoch::{Atomic, Guard, Owned, Shared};

use super::base::Stack;

/// A node of [`TreiberStack`]; it also serves as the stack's push request type (`PushReq`).
///
/// `data` is wrapped in [`ManuallyDrop`] because a popper moves the value out with `ptr::read`
/// while the node allocation itself is reclaimed separately. Dropping a `Node` therefore never
/// drops its `data`; whoever removes the node is responsible for the value.
#[derive(Debug)]
pub struct Node<T> {
    /// The stored value. It is not dropped by `Node`'s own drop glue.
    data: ManuallyDrop<T>,
    /// The node directly below this one, or null for the bottom node (or a detached node).
    next: *const Node<T>,
}

16unsafe impl<T: Send> Send for Node<T> {}
18unsafe impl<T: Send> Sync for Node<T> {}
19
20#[derive(Debug)]
24pub struct TreiberStack<T> {
25 head: Atomic<Node<T>>,
26}
27
28impl<T> From<T> for Node<T> {
29 fn from(t: T) -> Self {
30 Self {
31 data: ManuallyDrop::new(t),
32 next: ptr::null(),
33 }
34 }
35}
36
37impl<T> Deref for Node<T> {
38 type Target = ManuallyDrop<T>;
39
40 fn deref(&self) -> &Self::Target {
41 &self.data
42 }
43}
44
45impl<T> Default for TreiberStack<T> {
46 fn default() -> Self {
47 TreiberStack {
48 head: Atomic::null(),
49 }
50 }
51}
52
53impl<T> Stack<T> for TreiberStack<T> {
54 type PushReq = Node<T>;
55
56 fn try_push(
57 &self,
58 req: Owned<Self::PushReq>,
59 guard: &Guard,
60 ) -> Result<(), Owned<Self::PushReq>> {
61 let mut req = req;
62 let head = self.head.load(Ordering::Relaxed, guard);
63 req.next = head.as_raw();
64
65 match self
66 .head
67 .compare_exchange(head, req, Ordering::Release, Ordering::Relaxed, guard)
68 {
69 Ok(_) => Ok(()),
70 Err(e) => Err(e.new),
71 }
72 }
73
74 fn try_pop(&self, guard: &Guard) -> Result<Option<T>, ()> {
75 let head = self.head.load(Ordering::Acquire, guard);
76 let Some(head_ref) = (unsafe { head.as_ref() }) else {
77 return Ok(None);
78 };
79 let next = Shared::from(head_ref.next);
80
81 let _ = self
82 .head
83 .compare_exchange(head, next, Ordering::Relaxed, Ordering::Relaxed, guard)
84 .map_err(|_| ())?;
85
86 let data = ManuallyDrop::into_inner(unsafe { ptr::read(&head_ref.data) });
87 unsafe { guard.defer_destroy(head) };
88 Ok(Some(data))
89 }
90
91 fn is_empty(&self, guard: &Guard) -> bool {
92 self.head.load(Ordering::Acquire, guard).is_null()
93 }
94}
95
96impl<T> Drop for TreiberStack<T> {
97 fn drop(&mut self) {
98 let mut o_curr = mem::take(&mut self.head);
99 while let Some(curr) = unsafe { o_curr.try_into_owned() }.map(Owned::into_box) {
100 drop(ManuallyDrop::into_inner(curr.data));
101 o_curr = curr.next.into();
102 }
103 }
104}
105
#[cfg(test)]
mod test {
    use std::thread::scope;

    use super::*;

    /// Ten threads each alternate `push` and `pop` 10 000 times.
    ///
    /// Each thread pushes before it pops, so every `pop` must find an element. Once all threads
    /// have finished, the stack must be empty.
    #[test]
    fn push() {
        let stack = TreiberStack::default();

        // `scope` joins every spawned thread before returning and re-raises any panic, so the
        // join handles do not need to be kept.
        scope(|scope| {
            for _ in 0..10 {
                scope.spawn(|| {
                    for i in 0..10_000 {
                        stack.push(i);
                        assert!(stack.pop().is_some());
                    }
                });
            }
        });

        assert!(stack.pop().is_none());
    }
}