use crate::job::{JobFifo, JobRef, StackJob};
use crate::latch::{AsCoreLatch, CoreLatch, Latch, LatchRef, LockLatch, OnceLatch, SpinLatch};
use crate::sleep::Sleep;
use crate::sync::Mutex;
use crate::unwind;
use crate::{
    ErrorKind, ExitHandler, PanicHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder,
    Yield,
};
use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use std::cell::Cell;
use std::collections::hash_map::DefaultHasher;
use std::fmt;
use std::hash::Hasher;
use std::io;
use std::mem;
use std::ptr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Once};
use std::thread;
use std::usize;

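/// Thread builder used for customization via
/// `ThreadPoolBuilder::spawn_handler`.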
pub struct ThreadBuilder {
    name: Option<String>,
    stack_size: Option<usize>,
    worker: Worker<JobRef>,
    stealer: Stealer<JobRef>,
    registry: Arc<Registry>,
    index: usize,
}

impl ThreadBuilder {
    /// Gets the index of this thread in the pool, within `0..num_threads`.
    pub fn index(&self) -> usize {
        self.index
    }

    /// Gets the string that was specified by `ThreadPoolBuilder::name()`.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Gets the value that was specified by `ThreadPoolBuilder::stack_size()`.
    pub fn stack_size(&self) -> Option<usize> {
        self.stack_size
    }

    /// Executes the main loop for this thread. This will not return until the
    /// thread pool is dropped.
    pub fn run(self) {
        unsafe { main_loop(self) }
    }
}

impl fmt::Debug for ThreadBuilder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ThreadBuilder")
            .field("pool", &self.registry.id())
            .field("index", &self.index)
            .field("name", &self.name)
            .field("stack_size", &self.stack_size)
            .finish()
    }
}

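/// Generalized trait for spawning a thread in the `Registry`.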
pub trait ThreadSpawn {
    private_decl! {}

    /// Spawn a thread with the `ThreadBuilder` parameter, and then
    /// call `ThreadBuilder::run()`.
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()>;
}

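/// Spawns a thread in the "normal" way with `std::thread::Builder`.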
#[derive(Debug, Default)]
pub struct DefaultSpawn;

impl ThreadSpawn for DefaultSpawn {
    private_impl! {}

    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        let mut b = thread::Builder::new();
        if let Some(name) = thread.name() {
            b = b.name(name.to_owned());
        }
        if let Some(stack_size) = thread.stack_size() {
            b = b.stack_size(stack_size);
        }
        b.spawn(|| thread.run())?;
        Ok(())
    }
}

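/// Spawns a thread with a user's custom callback.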
#[derive(Debug)]
pub struct CustomSpawn<F>(F);

impl<F> CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    pub(super) fn new(spawn: F) -> Self {
        CustomSpawn(spawn)
    }
}

impl<F> ThreadSpawn for CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    private_impl! {}

    #[inline]
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        (self.0)(thread)
    }
}

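/// Central data shared by all worker threads of one thread pool: per-thread
/// info (including each worker's stealer), the injector queue for jobs arriving
/// from outside the pool, the broadcast queues, the sleep/wake machinery, and
/// the user-supplied handlers.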
pub(super) struct Registry {
    thread_infos: Vec<ThreadInfo>,
    sleep: Sleep,
    injected_jobs: Injector<JobRef>,
    broadcasts: Mutex<Vec<Worker<JobRef>>>,
    panic_handler: Option<Box<PanicHandler>>,
    start_handler: Option<Box<StartHandler>>,
    exit_handler: Option<Box<ExitHandler>>,

    /// Counter of outstanding references that keep the pool alive (the
    /// `ThreadPool` itself, plus any asynchronous work such as `spawn()`).
    /// When this count reaches zero, the worker threads are signaled to
    /// terminate; see `terminate()` and `increment_terminate_count()`.
    terminate_count: AtomicUsize,
}

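/// Holds the global registry, lazily initialized at most once (guarded by
/// `THE_REGISTRY_SET`).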
static mut THE_REGISTRY: Option<Arc<Registry>> = None;
static THE_REGISTRY_SET: Once = Once::new();

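/// Starts the worker threads (if that has not already happened). If
/// initialization has not already occurred, use the default configuration.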
pub(super) fn global_registry() -> &'static Arc<Registry> {
    set_global_registry(default_global_registry)
        .or_else(|err| unsafe { THE_REGISTRY.as_ref().ok_or(err) })
        .expect("The global thread pool has not been initialized.")
}

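/// Starts the worker threads (if that has not already happened) with
/// the given builder.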
pub(super) fn init_global_registry<S>(
    builder: ThreadPoolBuilder<S>,
) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    S: ThreadSpawn,
{
    set_global_registry(|| Registry::new(builder))
}

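/// Starts the worker threads (if that has not already happened) by
/// creating a registry with the given callback, returning an error if
/// the global registry has already been set.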
fn set_global_registry<F>(registry: F) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    F: FnOnce() -> Result<Arc<Registry>, ThreadPoolBuildError>,
{
    let mut result = Err(ThreadPoolBuildError::new(
        ErrorKind::GlobalPoolAlreadyInitialized,
    ));

    THE_REGISTRY_SET.call_once(|| {
        result = registry()
            .map(|registry: Arc<Registry>| unsafe { &*THE_REGISTRY.get_or_insert(registry) })
    });

    result
}

fn default_global_registry() -> Result<Arc<Registry>, ThreadPoolBuildError> {
    let result = Registry::new(ThreadPoolBuilder::new());

    // If we're running in an environment that doesn't support threads at all,
    // fall back to using the current thread alone. This is crude, and probably
    // won't work for non-blocking calls like `spawn`, but a lot of stuff does
    // work fine.
    let unsupported = matches!(&result, Err(e) if e.is_unsupported());
    if unsupported && WorkerThread::current().is_null() {
        let builder = ThreadPoolBuilder::new().num_threads(1).use_current_thread();
        let fallback_result = Registry::new(builder);
        if fallback_result.is_ok() {
            return fallback_result;
        }
    }

    result
}

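/// Drop guard used by `Registry::new` so that already-spawned worker threads
/// are terminated if pool construction returns early or panics.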
struct Terminator<'a>(&'a Arc<Registry>);

impl<'a> Drop for Terminator<'a> {
    fn drop(&mut self) {
        self.0.terminate()
    }
}

impl Registry {
    pub(super) fn new<S>(
        mut builder: ThreadPoolBuilder<S>,
    ) -> Result<Arc<Self>, ThreadPoolBuildError>
    where
        S: ThreadSpawn,
    {
        // Soft-limit the number of threads that we can actually create.
        let n_threads = Ord::min(builder.get_num_threads(), crate::max_num_threads());

        let breadth_first = builder.get_breadth_first();

        let (workers, stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = if breadth_first {
                    Worker::new_fifo()
                } else {
                    Worker::new_lifo()
                };

                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();

        let (broadcasts, broadcast_stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = Worker::new_fifo();
                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();

        let registry = Arc::new(Registry {
            thread_infos: stealers.into_iter().map(ThreadInfo::new).collect(),
            sleep: Sleep::new(n_threads),
            injected_jobs: Injector::new(),
            broadcasts: Mutex::new(broadcasts),
            terminate_count: AtomicUsize::new(1),
            panic_handler: builder.take_panic_handler(),
            start_handler: builder.take_start_handler(),
            exit_handler: builder.take_exit_handler(),
        });

        // If we return early or panic, make sure to terminate existing threads.
        let t1000 = Terminator(&registry);

        for (index, (worker, stealer)) in workers.into_iter().zip(broadcast_stealers).enumerate() {
            let thread = ThreadBuilder {
                name: builder.get_thread_name(index),
                stack_size: builder.get_stack_size(),
                registry: Arc::clone(&registry),
                worker,
                stealer,
                index,
            };

            if index == 0 && builder.use_current_thread {
                if !WorkerThread::current().is_null() {
                    return Err(ThreadPoolBuildError::new(
                        ErrorKind::CurrentThreadAlreadyInPool,
                    ));
                }
                // Rather than starting a new thread, we take over the current
                // thread *without* running the main loop, so we can still
                // return from here. The `WorkerThread` is intentionally leaked
                // so the pointer stays valid for the thread's lifetime.
                let worker_thread = Box::into_raw(Box::new(WorkerThread::from(thread)));

                unsafe {
                    WorkerThread::set_current(worker_thread);
                    Latch::set(&registry.thread_infos[index].primed);
                }
                continue;
            }

            if let Err(e) = builder.get_spawn_handler().spawn(thread) {
                return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e)));
            }
        }

        // Returning normally now, without termination.
        mem::forget(t1000);

        Ok(registry)
    }

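    /// Returns the registry of the current worker thread, or the global
    /// registry if this is not a worker thread.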
    pub(super) fn current() -> Arc<Registry> {
        unsafe {
            let worker_thread = WorkerThread::current();
            let registry = if worker_thread.is_null() {
                global_registry()
            } else {
                &(*worker_thread).registry
            };
            Arc::clone(registry)
        }
    }

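    /// Returns the number of threads in the current registry. This is
    /// better than `Registry::current().num_threads()` because it avoids
    /// incrementing the `Arc`.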
    pub(super) fn current_num_threads() -> usize {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                global_registry().num_threads()
            } else {
                (*worker_thread).registry.num_threads()
            }
        }
    }

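    /// Returns the current `WorkerThread` if it's part of this `Registry`.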
    pub(super) fn current_thread(&self) -> Option<&WorkerThread> {
        unsafe {
            let worker = WorkerThread::current().as_ref()?;
            if worker.registry().id() == self.id() {
                Some(worker)
            } else {
                None
            }
        }
    }

    /// Returns an opaque identifier for this registry.
    pub(super) fn id(&self) -> RegistryId {
        // We can rely on `self` not to change since we only ever create
        // registries that are boxed up in an `Arc` (see `new()` above).
        RegistryId {
            addr: self as *const Self as usize,
        }
    }

    pub(super) fn num_threads(&self) -> usize {
        self.thread_infos.len()
    }

    pub(super) fn catch_unwind(&self, f: impl FnOnce()) {
        if let Err(err) = unwind::halt_unwinding(f) {
            // If there is no handler, or if the handler itself panics,
            // then we abort.
            let abort_guard = unwind::AbortIfPanic;
            if let Some(ref handler) = self.panic_handler {
                handler(err);
                mem::forget(abort_guard);
            }
        }
    }

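    /// Waits for the worker threads to get up and running. This is used
    /// for testing purposes, primarily, so that we can check their core
    /// invariants.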
    pub(super) fn wait_until_primed(&self) {
        for info in &self.thread_infos {
            info.primed.wait();
        }
    }

    /// Waits for the worker threads to stop. This is used for testing purposes.
    #[cfg(test)]
    pub(super) fn wait_until_stopped(&self) {
        for info in &self.thread_infos {
            info.stopped.wait();
        }
    }

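    /// Push a job into the given `registry`. If we are running on a
    /// worker thread for the registry, this will push onto the worker's
    /// deque. Else, it will inject from the outside (which is slower).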
    pub(super) fn inject_or_push(&self, job_ref: JobRef) {
        let worker_thread = WorkerThread::current();
        unsafe {
            if !worker_thread.is_null() && (*worker_thread).registry().id() == self.id() {
                (*worker_thread).push(job_ref);
            } else {
                self.inject(job_ref);
            }
        }
    }

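    /// Push a job into the "external jobs" queue; it will be taken by
    /// whatever worker has nothing to do.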
    pub(super) fn inject(&self, injected_job: JobRef) {
        // It should not be possible for `state.terminate` to be true here.
        // It is only set to true when the user creates (and drops) a
        // `ThreadPool`; and, in that case, they cannot be calling `inject()`
        // later, since they dropped their `ThreadPool`.
        debug_assert_ne!(
            self.terminate_count.load(Ordering::Acquire),
            0,
            "inject() sees state.terminate as true"
        );

        let queue_was_empty = self.injected_jobs.is_empty();

        self.injected_jobs.push(injected_job);
        self.sleep.new_injected_jobs(1, queue_was_empty);
    }

    fn has_injected_job(&self) -> bool {
        !self.injected_jobs.is_empty()
    }

    fn pop_injected_job(&self) -> Option<JobRef> {
        loop {
            match self.injected_jobs.steal() {
                Steal::Success(job) => return Some(job),
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

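    /// Push a job into each thread's own "external jobs" queue; it will be
    /// executed only on that thread, when it has nothing else to do locally,
    /// before it tries to steal other work.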
    pub(super) fn inject_broadcast(&self, injected_jobs: impl ExactSizeIterator<Item = JobRef>) {
        assert_eq!(self.num_threads(), injected_jobs.len());
        {
            let broadcasts = self.broadcasts.lock().unwrap();

            // It should not be possible for `state.terminate` to be true here,
            // for the same reason as in `inject()` above.
            debug_assert_ne!(
                self.terminate_count.load(Ordering::Acquire),
                0,
                "inject_broadcast() sees state.terminate as true"
            );

            assert_eq!(broadcasts.len(), injected_jobs.len());
            for (worker, job_ref) in broadcasts.iter().zip(injected_jobs) {
                worker.push(job_ref);
            }
        }
        for i in 0..self.num_threads() {
            self.sleep.notify_worker_latch_is_set(i);
        }
    }

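    /// If already in a worker thread of this registry, just execute `op`.
    /// Otherwise, inject `op` into this thread pool. Either way, block until
    /// `op` completes and return its return value. If `op` panics, that panic
    /// will be propagated as well. The second argument to `op` indicates
    /// `true` if injection was performed, `false` if executed directly.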
    pub(super) fn in_worker<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                self.in_worker_cold(op)
            } else if (*worker_thread).registry().id() != self.id() {
                self.in_worker_cross(&*worker_thread, op)
            } else {
                // Perfectly valid to give them a `&T`: this is the
                // current thread, so we know the data structure won't be
                // invalidated until we return.
                op(&*worker_thread, false)
            }
        }
    }

    #[cold]
    unsafe fn in_worker_cold<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        thread_local!(static LOCK_LATCH: LockLatch = LockLatch::new());

        LOCK_LATCH.with(|l| {
            // This thread isn't a member of *any* thread pool, so just block.
            debug_assert!(WorkerThread::current().is_null());
            let job = StackJob::new(
                |injected| {
                    let worker_thread = WorkerThread::current();
                    assert!(injected && !worker_thread.is_null());
                    op(&*worker_thread, true)
                },
                LatchRef::new(l),
            );
            self.inject(job.as_job_ref());
            job.latch.wait_and_reset(); // Make sure we can use the same latch again next time.
            job.into_result()
        })
    }

    #[cold]
    unsafe fn in_worker_cross<OP, R>(&self, current_thread: &WorkerThread, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        // This thread is a member of a different pool, so let it process
        // other work while it waits for this `op` to complete.
        debug_assert!(current_thread.registry().id() != self.id());
        let latch = SpinLatch::cross(current_thread);
        let job = StackJob::new(
            |injected| {
                let worker_thread = WorkerThread::current();
                assert!(injected && !worker_thread.is_null());
                op(&*worker_thread, true)
            },
            latch,
        );
        self.inject(job.as_job_ref());
        current_thread.wait_until(&job.latch);
        job.into_result()
    }

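    /// Increments the terminate counter. This increment should be balanced
    /// by a call to `terminate`, which will decrement. This is used when
    /// spawning asynchronous work, which needs to prevent the registry from
    /// terminating while it is in progress.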
    pub(super) fn increment_terminate_count(&self) {
        let previous = self.terminate_count.fetch_add(1, Ordering::AcqRel);
        debug_assert!(previous != 0, "registry ref count incremented from zero");
        assert!(
            previous != std::usize::MAX,
            "overflow in registry ref count"
        );
    }

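    /// Signals that the thread pool which owns this registry has been
    /// dropped. The worker threads will gradually terminate, once any
    /// extant work is completed.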
    pub(super) fn terminate(&self) {
        if self.terminate_count.fetch_sub(1, Ordering::AcqRel) == 1 {
            for (i, thread_info) in self.thread_infos.iter().enumerate() {
                unsafe { OnceLatch::set_and_tickle_one(&thread_info.terminate, self, i) };
            }
        }
    }

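    /// Notify the worker that the latch they are sleeping on has been "set".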
    pub(super) fn notify_worker_latch_is_set(&self, target_worker_index: usize) {
        self.sleep.notify_worker_latch_is_set(target_worker_index);
    }
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(super) struct RegistryId {
    addr: usize,
}

struct ThreadInfo {
    /// Latch set once thread has started and we are entering into the
    /// main loop. Used to wait for worker threads to become primed,
    /// primarily of interest for benchmarking.
    primed: LockLatch,

    /// Latch set once the worker thread has completed. Used to wait
    /// until workers have stopped; only used for tests.
    stopped: LockLatch,

    /// The latch used to signal that termination has been requested.
    /// This latch is *set* by the `terminate` method on the `Registry`,
    /// once the registry's main "terminate" counter reaches zero.
    terminate: OnceLatch,

    /// the "stealer" half of the worker's deque
    stealer: Stealer<JobRef>,
}

impl ThreadInfo {
    fn new(stealer: Stealer<JobRef>) -> ThreadInfo {
        ThreadInfo {
            primed: LockLatch::new(),
            stopped: LockLatch::new(),
            terminate: OnceLatch::new(),
            stealer,
        }
    }
}

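/// State for one worker thread. A pointer to this is stored in the
/// thread-local `WORKER_THREAD_STATE` while the worker is running.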
pub(super) struct WorkerThread {
    /// the "worker" half of our local deque
    worker: Worker<JobRef>,

    /// the "stealer" half of the worker's broadcast deque
    stealer: Stealer<JobRef>,

    /// local queue used for `spawn_fifo` indirection
    fifo: JobFifo,

    index: usize,

    /// A weak random number generator.
    rng: XorShift64Star,

    registry: Arc<Registry>,
}

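// This is a bit sketchy, but basically: the WorkerThread is allocated on the
// stack of the worker thread (or leaked, for `use_current_thread`) and remains
// valid until that thread exits; storing a raw pointer in a thread-local
// avoids the need for a `RefCell`.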
thread_local! {
    static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const { Cell::new(ptr::null()) };
}

impl From<ThreadBuilder> for WorkerThread {
    fn from(thread: ThreadBuilder) -> Self {
        Self {
            worker: thread.worker,
            stealer: thread.stealer,
            fifo: JobFifo::new(),
            index: thread.index,
            rng: XorShift64Star::new(),
            registry: thread.registry,
        }
    }
}

impl Drop for WorkerThread {
    fn drop(&mut self) {
        // Undo `set_current`
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().eq(&(self as *const _)));
            t.set(ptr::null());
        });
    }
}

impl WorkerThread {
    /// Gets the `WorkerThread` for the current thread; returns null if this
    /// is not a worker thread. This pointer is valid anywhere on the current
    /// thread.
    #[inline]
    pub(super) fn current() -> *const WorkerThread {
        WORKER_THREAD_STATE.with(Cell::get)
    }

    /// Sets `thread` as the worker thread pointer for the current thread.
    /// This is done during worker thread startup.
    unsafe fn set_current(thread: *const WorkerThread) {
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().is_null());
            t.set(thread);
        });
    }

    /// Returns the registry that owns this worker thread.
    #[inline]
    pub(super) fn registry(&self) -> &Arc<Registry> {
        &self.registry
    }

    /// Our index amongst the worker threads (ranges from `0..self.num_threads()`).
    #[inline]
    pub(super) fn index(&self) -> usize {
        self.index
    }

    #[inline]
    pub(super) unsafe fn push(&self, job: JobRef) {
        let queue_was_empty = self.worker.is_empty();
        self.worker.push(job);
        self.registry.sleep.new_internal_jobs(1, queue_was_empty);
    }

    #[inline]
    pub(super) unsafe fn push_fifo(&self, job: JobRef) {
        self.push(self.fifo.push(job));
    }

    #[inline]
    pub(super) fn local_deque_is_empty(&self) -> bool {
        self.worker.is_empty()
    }

    /// Attempts to obtain a "local" job -- typically this means popping from
    /// the top of the stack, though if we are configured for breadth-first
    /// execution, it would mean dequeuing from the bottom.
    #[inline]
    pub(super) fn take_local_job(&self) -> Option<JobRef> {
        let popped_job = self.worker.pop();

        if popped_job.is_some() {
            return popped_job;
        }

        loop {
            match self.stealer.steal() {
                Steal::Success(job) => return Some(job),
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

    fn has_injected_job(&self) -> bool {
        !self.stealer.is_empty() || self.registry.has_injected_job()
    }

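    /// Wait until the latch is set. Try to keep busy by popping and
    /// stealing tasks as necessary.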
    #[inline]
    pub(super) unsafe fn wait_until<L: AsCoreLatch + ?Sized>(&self, latch: &L) {
        let latch = latch.as_core_latch();
        if !latch.probe() {
            self.wait_until_cold(latch);
        }
    }

    #[cold]
    unsafe fn wait_until_cold(&self, latch: &CoreLatch) {
        // The code below should swallow all panics and hence never unwind;
        // but if something does go wrong, we want to abort, because otherwise
        // other code in rayon may assume that the latch has been signaled,
        // and that can lead to random memory accesses, which would be
        // *very bad*.
        let abort_guard = unwind::AbortIfPanic;

        'outer: while !latch.probe() {
            // Check for local work *before* we start marking ourself idle,
            // especially to avoid modifying shared sleep state.
            if let Some(job) = self.take_local_job() {
                self.execute(job);
                continue;
            }

            let mut idle_state = self.registry.sleep.start_looking(self.index);
            while !latch.probe() {
                if let Some(job) = self.find_work() {
                    self.registry.sleep.work_found();
                    self.execute(job);
                    // The job might have pushed local work, so go back to the
                    // outer loop and check for it before going idle again.
                    continue 'outer;
                } else {
                    self.registry
                        .sleep
                        .no_work_found(&mut idle_state, latch, || self.has_injected_job())
                }
            }

            // If we were sleepy, we are not anymore. We "found work" --
            // whatever the surrounding thread was doing before it had to wait.
            self.registry.sleep.work_found();
            break;
        }

        mem::forget(abort_guard); // successful execution, do not abort
    }

    unsafe fn wait_until_out_of_work(&self) {
        debug_assert_eq!(self as *const _, WorkerThread::current());
        let registry = &*self.registry;
        let index = self.index;

        self.wait_until(&registry.thread_infos[index].terminate);

        // should not be any work left in our queue
        debug_assert!(self.take_local_job().is_none());

        // let registry know we are done
        Latch::set(&registry.thread_infos[index].stopped);
    }

    fn find_work(&self) -> Option<JobRef> {
        // Try to find some work to do. We give preference first
        // to things in our local deque, then in other workers'
        // deques, and finally to injected jobs from the outside.
        // The idea is to finish what we started before we take on
        // something new.
        self.take_local_job()
            .or_else(|| self.steal())
            .or_else(|| self.registry.pop_injected_job())
    }

    pub(super) fn yield_now(&self) -> Yield {
        match self.find_work() {
            Some(job) => unsafe {
                self.execute(job);
                Yield::Executed
            },
            None => Yield::Idle,
        }
    }

    pub(super) fn yield_local(&self) -> Yield {
        match self.take_local_job() {
            Some(job) => unsafe {
                self.execute(job);
                Yield::Executed
            },
            None => Yield::Idle,
        }
    }

    #[inline]
    pub(super) unsafe fn execute(&self, job: JobRef) {
        job.execute();
    }

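    /// Try to steal a single job and return it.
    ///
    /// This should only be done as a last resort, when there is no
    /// local work to do.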
    fn steal(&self) -> Option<JobRef> {
        // we only steal when we don't have any work to do locally
        debug_assert!(self.local_deque_is_empty());

        // otherwise, try to steal
        let thread_infos = &self.registry.thread_infos.as_slice();
        let num_threads = thread_infos.len();
        if num_threads <= 1 {
            return None;
        }

        loop {
            let mut retry = false;
            let start = self.rng.next_usize(num_threads);
            let job = (start..num_threads)
                .chain(0..start)
                .filter(move |&i| i != self.index)
                .find_map(|victim_index| {
                    let victim = &thread_infos[victim_index];
                    match victim.stealer.steal() {
                        Steal::Success(job) => Some(job),
                        Steal::Empty => None,
                        Steal::Retry => {
                            retry = true;
                            None
                        }
                    }
                });
            if job.is_some() || !retry {
                return job;
            }
        }
    }
}

unsafe fn main_loop(thread: ThreadBuilder) {
    let worker_thread = &WorkerThread::from(thread);
    WorkerThread::set_current(worker_thread);
    let registry = &*worker_thread.registry;
    let index = worker_thread.index;

    // let registry know we are ready to do work
    Latch::set(&registry.thread_infos[index].primed);

    // Worker threads should not panic. If they do, just abort, as the
    // internal state of the thread pool is corrupted. Note that if
    // **user code** panics, we should catch that and redirect.
    let abort_guard = unwind::AbortIfPanic;

    // Inform a user callback that we started a thread.
    if let Some(ref handler) = registry.start_handler {
        registry.catch_unwind(|| handler(index));
    }

    worker_thread.wait_until_out_of_work();

    // Normal termination, do not abort.
    mem::forget(abort_guard);

    // Inform a user callback that we exited a thread.
    if let Some(ref handler) = registry.exit_handler {
        registry.catch_unwind(|| handler(index));
        // We're already exiting the thread, there's nothing else to do.
    }
}

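/// If already in a worker thread, just execute `op`. Otherwise, execute `op`
/// in the default (global) thread pool. Either way, block until `op` completes
/// and return its return value. If `op` panics, that panic will be propagated
/// as well. The second argument to `op` indicates `true` if injection was
/// performed, `false` if executed directly.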
pub(super) fn in_worker<OP, R>(op: OP) -> R
where
    OP: FnOnce(&WorkerThread, bool) -> R + Send,
    R: Send,
{
    unsafe {
        let owner_thread = WorkerThread::current();
        if !owner_thread.is_null() {
            // Perfectly valid to give them a `&T`: this is the
            // current thread, so we know the data structure won't be
            // invalidated until we return.
            op(&*owner_thread, false)
        } else {
            global_registry().in_worker(op)
        }
    }
}

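/// xorshift* is a fast pseudorandom number generator that tolerates weak
/// seeding, as long as the seed is not zero.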
struct XorShift64Star {
    state: Cell<u64>,
}

impl XorShift64Star {
    fn new() -> Self {
        // Any non-zero seed will do -- this uses the hash of a global counter.
        let mut seed = 0;
        while seed == 0 {
            let mut hasher = DefaultHasher::new();
            static COUNTER: AtomicUsize = AtomicUsize::new(0);
            hasher.write_usize(COUNTER.fetch_add(1, Ordering::Relaxed));
            seed = hasher.finish();
        }

        XorShift64Star {
            state: Cell::new(seed),
        }
    }

    fn next(&self) -> u64 {
        let mut x = self.state.get();
        debug_assert_ne!(x, 0);
        x ^= x >> 12;
        x ^= x << 25;
        x ^= x >> 27;
        self.state.set(x);
        x.wrapping_mul(0x2545_f491_4f6c_dd1d)
    }

    /// Return a value from `0..n`.
    fn next_usize(&self, n: usize) -> usize {
        (self.next() % n as u64) as usize
    }
}