crossbeam_utils/sync/
wait_group.rs

1use crate::primitive::sync::{Arc, Condvar, Mutex};
2use std::fmt;
3
4/// Enables threads to synchronize the beginning or end of some computation.
5///
6/// # Wait groups vs barriers
7///
8/// `WaitGroup` is very similar to [`Barrier`], but there are a few differences:
9///
10/// * [`Barrier`] needs to know the number of threads at construction, while `WaitGroup` is cloned to
11///   register more threads.
12///
13/// * A [`Barrier`] can be reused even after all threads have synchronized, while a `WaitGroup`
14///   synchronizes threads only once.
15///
16/// * All threads wait for others to reach the [`Barrier`]. With `WaitGroup`, each thread can choose
17///   to either wait for other threads or to continue without blocking.
18///
19/// # Examples
20///
21/// ```
22/// use crossbeam_utils::sync::WaitGroup;
23/// use std::thread;
24///
25/// // Create a new wait group.
26/// let wg = WaitGroup::new();
27///
28/// for _ in 0..4 {
29///     // Create another reference to the wait group.
30///     let wg = wg.clone();
31///
32///     thread::spawn(move || {
33///         // Do some work.
34///
35///         // Drop the reference to the wait group.
36///         drop(wg);
37///     });
38/// }
39///
40/// // Block until all threads have finished their work.
41/// wg.wait();
42/// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
43/// ```
44///
45/// [`Barrier`]: std::sync::Barrier
46pub struct WaitGroup {
47    inner: Arc<Inner>,
48}
49
/// State shared between all handles of one `WaitGroup`.
struct Inner {
    /// Signalled when `count` reaches zero so that blocked waiters can resume.
    cvar: Condvar,
    /// Number of `WaitGroup` handles that are currently alive.
    count: Mutex<usize>,
}
55
56impl Default for WaitGroup {
57    fn default() -> Self {
58        Self {
59            inner: Arc::new(Inner {
60                cvar: Condvar::new(),
61                count: Mutex::new(1),
62            }),
63        }
64    }
65}
66
67impl WaitGroup {
68    /// Creates a new wait group and returns the single reference to it.
69    ///
70    /// # Examples
71    ///
72    /// ```
73    /// use crossbeam_utils::sync::WaitGroup;
74    ///
75    /// let wg = WaitGroup::new();
76    /// ```
77    pub fn new() -> Self {
78        Self::default()
79    }
80
81    /// Drops this reference and waits until all other references are dropped.
82    ///
83    /// # Examples
84    ///
85    /// ```
86    /// use crossbeam_utils::sync::WaitGroup;
87    /// use std::thread;
88    ///
89    /// let wg = WaitGroup::new();
90    ///
91    /// thread::spawn({
92    ///     let wg = wg.clone();
93    ///     move || {
94    ///         // Block until both threads have reached `wait()`.
95    ///         wg.wait();
96    ///     }
97    /// });
98    ///
99    /// // Block until both threads have reached `wait()`.
100    /// wg.wait();
101    /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
102    /// ```
103    pub fn wait(self) {
104        if *self.inner.count.lock().unwrap() == 1 {
105            return;
106        }
107
108        let inner = self.inner.clone();
109        drop(self);
110
111        let mut count = inner.count.lock().unwrap();
112        while *count > 0 {
113            count = inner.cvar.wait(count).unwrap();
114        }
115    }
116}
117
118impl Drop for WaitGroup {
119    fn drop(&mut self) {
120        let mut count = self.inner.count.lock().unwrap();
121        *count -= 1;
122
123        if *count == 0 {
124            self.inner.cvar.notify_all();
125        }
126    }
127}
128
129impl Clone for WaitGroup {
130    fn clone(&self) -> WaitGroup {
131        let mut count = self.inner.count.lock().unwrap();
132        *count += 1;
133
134        WaitGroup {
135            inner: self.inner.clone(),
136        }
137    }
138}
139
140impl fmt::Debug for WaitGroup {
141    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
142        let count: &usize = &*self.inner.count.lock().unwrap();
143        f.debug_struct("WaitGroup").field("count", count).finish()
144    }
145}