1use alloc::vec::{self, Vec};
11use core::slice;
12use core::{hash::Hash, ops::AddAssign};
13#[cfg(feature = "std")]
15use super::WeightError;
16use crate::distr::uniform::SampleUniform;
17use crate::distr::{Distribution, Uniform};
18use crate::Rng;
19#[cfg(not(feature = "std"))]
20use alloc::collections::BTreeSet;
21#[cfg(feature = "serde")]
22use serde::{Deserialize, Serialize};
23#[cfg(feature = "std")]
24use std::collections::HashSet;
25
// This module only provides `u32` index storage plus, on 64-bit targets,
// `u64` storage, and it casts between those and `usize`. Reject any target
// whose pointer width is neither 32 nor 64 bits at compile time.
#[cfg(not(any(target_pointer_width = "32", target_pointer_width = "64")))]
compile_error!("unsupported pointer width");
28
/// A vector of indices.
///
/// Indices are stored either as `u32` values or, on targets with 64-bit
/// pointers only, as `u64` values. The variants are hidden from the generated
/// documentation.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum IndexVec {
    // Storage for indices that fit in 32 bits.
    #[doc(hidden)]
    U32(Vec<u32>),
    // Storage for 64-bit indices; only compiled on 64-bit targets.
    #[cfg(target_pointer_width = "64")]
    #[doc(hidden)]
    U64(Vec<u64>),
}
41
42impl IndexVec {
43 #[inline]
45 pub fn len(&self) -> usize {
46 match self {
47 IndexVec::U32(v) => v.len(),
48 #[cfg(target_pointer_width = "64")]
49 IndexVec::U64(v) => v.len(),
50 }
51 }
52
53 #[inline]
55 pub fn is_empty(&self) -> bool {
56 match self {
57 IndexVec::U32(v) => v.is_empty(),
58 #[cfg(target_pointer_width = "64")]
59 IndexVec::U64(v) => v.is_empty(),
60 }
61 }
62
63 #[inline]
68 pub fn index(&self, index: usize) -> usize {
69 match self {
70 IndexVec::U32(v) => v[index] as usize,
71 #[cfg(target_pointer_width = "64")]
72 IndexVec::U64(v) => v[index] as usize,
73 }
74 }
75
76 #[inline]
78 pub fn into_vec(self) -> Vec<usize> {
79 match self {
80 IndexVec::U32(v) => v.into_iter().map(|i| i as usize).collect(),
81 #[cfg(target_pointer_width = "64")]
82 IndexVec::U64(v) => v.into_iter().map(|i| i as usize).collect(),
83 }
84 }
85
86 #[inline]
88 pub fn iter(&self) -> IndexVecIter<'_> {
89 match self {
90 IndexVec::U32(v) => IndexVecIter::U32(v.iter()),
91 #[cfg(target_pointer_width = "64")]
92 IndexVec::U64(v) => IndexVecIter::U64(v.iter()),
93 }
94 }
95}
96
97impl IntoIterator for IndexVec {
98 type IntoIter = IndexVecIntoIter;
99 type Item = usize;
100
101 #[inline]
103 fn into_iter(self) -> IndexVecIntoIter {
104 match self {
105 IndexVec::U32(v) => IndexVecIntoIter::U32(v.into_iter()),
106 #[cfg(target_pointer_width = "64")]
107 IndexVec::U64(v) => IndexVecIntoIter::U64(v.into_iter()),
108 }
109 }
110}
111
112impl PartialEq for IndexVec {
113 fn eq(&self, other: &IndexVec) -> bool {
114 use self::IndexVec::*;
115 match (self, other) {
116 (U32(v1), U32(v2)) => v1 == v2,
117 #[cfg(target_pointer_width = "64")]
118 (U64(v1), U64(v2)) => v1 == v2,
119 #[cfg(target_pointer_width = "64")]
120 (U32(v1), U64(v2)) => {
121 (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x as u64 == *y))
122 }
123 #[cfg(target_pointer_width = "64")]
124 (U64(v1), U32(v2)) => {
125 (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x == *y as u64))
126 }
127 }
128 }
129}
130
131impl From<Vec<u32>> for IndexVec {
132 #[inline]
133 fn from(v: Vec<u32>) -> Self {
134 IndexVec::U32(v)
135 }
136}
137
138#[cfg(target_pointer_width = "64")]
139impl From<Vec<u64>> for IndexVec {
140 #[inline]
141 fn from(v: Vec<u64>) -> Self {
142 IndexVec::U64(v)
143 }
144}
145
/// Borrowing iterator over a sequence of indices, yielding each as `usize`.
///
/// The iterator is `Clone` (it only wraps a slice iterator), which matches the
/// owning `IndexVecIntoIter` and lets callers save and resume a position.
#[derive(Clone, Debug)]
pub enum IndexVecIter<'a> {
    // Iterates over 32-bit storage.
    #[doc(hidden)]
    U32(slice::Iter<'a, u32>),
    // Iterates over 64-bit storage; only compiled on 64-bit targets.
    #[cfg(target_pointer_width = "64")]
    #[doc(hidden)]
    U64(slice::Iter<'a, u64>),
}

impl Iterator for IndexVecIter<'_> {
    type Item = usize;

    /// Returns the next index widened to `usize`, or `None` when exhausted.
    #[inline]
    fn next(&mut self) -> Option<usize> {
        match self {
            Self::U32(inner) => inner.next().map(|&raw| raw as usize),
            #[cfg(target_pointer_width = "64")]
            Self::U64(inner) => inner.next().map(|&raw| raw as usize),
        }
    }

    /// Forwards the size hint of the underlying slice iterator.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::U32(inner) => inner.size_hint(),
            #[cfg(target_pointer_width = "64")]
            Self::U64(inner) => inner.size_hint(),
        }
    }
}

impl ExactSizeIterator for IndexVecIter<'_> {}
180
/// Owning iterator over a sequence of indices, yielding each as `usize`.
#[derive(Clone, Debug)]
pub enum IndexVecIntoIter {
    // Drains 32-bit storage.
    #[doc(hidden)]
    U32(vec::IntoIter<u32>),
    // Drains 64-bit storage; only compiled on 64-bit targets.
    #[cfg(target_pointer_width = "64")]
    #[doc(hidden)]
    U64(vec::IntoIter<u64>),
}

impl Iterator for IndexVecIntoIter {
    type Item = usize;

    /// Returns the next index widened to `usize`, or `None` when exhausted.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let widened = match self {
            Self::U32(inner) => inner.next()? as usize,
            #[cfg(target_pointer_width = "64")]
            Self::U64(inner) => inner.next()? as usize,
        };
        Some(widened)
    }

    /// Forwards the size hint of the underlying vector iterator.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::U32(inner) => inner.size_hint(),
            #[cfg(target_pointer_width = "64")]
            Self::U64(inner) => inner.size_hint(),
        }
    }
}

impl ExactSizeIterator for IndexVecIntoIter {}
216
217#[track_caller]
240pub fn sample<R>(rng: &mut R, length: usize, amount: usize) -> IndexVec
241where
242 R: Rng + ?Sized,
243{
244 if amount > length {
245 panic!("`amount` of samples must be less than or equal to `length`");
246 }
247 if length > (u32::MAX as usize) {
248 #[cfg(target_pointer_width = "32")]
249 unreachable!();
250
251 #[cfg(target_pointer_width = "64")]
254 return sample_rejection(rng, length as u64, amount as u64);
255 }
256 let amount = amount as u32;
257 let length = length as u32;
258
259 if amount < 163 {
264 const C: [[f32; 2]; 2] = [[1.6, 8.0 / 45.0], [10.0, 70.0 / 9.0]];
265 let j = usize::from(length >= 500_000);
266 let amount_fp = amount as f32;
267 let m4 = C[0][j] * amount_fp;
268 if amount > 11 && (length as f32) < (C[1][j] + m4) * amount_fp {
270 sample_inplace(rng, length, amount)
271 } else {
272 sample_floyd(rng, length, amount)
273 }
274 } else {
275 const C: [f32; 2] = [270.0, 330.0 / 9.0];
276 let j = usize::from(length >= 500_000);
277 if (length as f32) < C[j] * (amount as f32) {
278 sample_inplace(rng, length, amount)
279 } else {
280 sample_rejection(rng, length, amount)
281 }
282 }
283}
284
/// Randomly samples `amount` distinct indices from `0..length`, using
/// `weight(index)` as the relative weight of each index.
///
/// The work is delegated to `sample_efraimidis_spirakis`, instantiated with
/// `u32` indices when `length` fits in a `u32` and with `u64` indices
/// otherwise (64-bit targets only).
///
/// # Errors
///
/// Returns whatever error the delegated routine reports: it rejects weights
/// that are negative or NaN, and reports when too few positive weights exist.
///
/// # Panics
///
/// Panics if `length` fits in a `u32` but `amount` does not.
#[cfg(feature = "std")]
pub fn sample_weighted<R, F, X>(
    rng: &mut R,
    length: usize,
    weight: F,
    amount: usize,
) -> Result<IndexVec, WeightError>
where
    R: Rng + ?Sized,
    F: Fn(usize) -> X,
    X: Into<f64>,
{
    if length <= (u32::MAX as usize) {
        // Common case: everything fits in 32-bit indices.
        assert!(amount <= u32::MAX as usize);
        sample_efraimidis_spirakis(rng, length as u32, weight, amount as u32)
    } else {
        // A `usize` cannot exceed `u32::MAX` on a 32-bit target.
        #[cfg(target_pointer_width = "32")]
        unreachable!();

        #[cfg(target_pointer_width = "64")]
        {
            sample_efraimidis_spirakis(rng, length as u64, weight, amount as u64)
        }
    }
}
330
/// Weighted sampling without replacement of `amount` distinct indices from
/// `0..length`, where `weight(index)` gives each index's relative weight.
///
/// NOTE(review): going by the function name and the skip-ahead structure of
/// the second loop, this appears to be the Efraimidis–Spirakis reservoir
/// algorithm with exponential jumps ("A-ExpJ") — confirm against the paper.
///
/// Behaviour visible in the code:
///
/// - `amount == 0` returns an empty `IndexVec::U32` without calling `weight`.
/// - Indices whose weight is exactly zero are never selected.
/// - A weight that is negative or NaN yields `WeightError::InvalidWeight`.
/// - If fewer than `amount` indices have a positive weight, the result is
///   `WeightError::InsufficientNonZero`.
/// - The returned indices are in the heap's internal iteration order, not
///   sorted by index or by key.
#[cfg(feature = "std")]
fn sample_efraimidis_spirakis<R, F, X, N>(
    rng: &mut R,
    length: N,
    weight: F,
    amount: N,
) -> Result<IndexVec, WeightError>
where
    R: Rng + ?Sized,
    F: Fn(usize) -> X,
    X: Into<f64>,
    N: UInt,
    IndexVec: From<Vec<N>>,
{
    use std::{cmp::Ordering, collections::BinaryHeap};

    if amount == N::zero() {
        return Ok(IndexVec::U32(Vec::new()));
    }

    // A candidate index together with its random key. Keys are stored as
    // `ln(u) / weight`, i.e. in log space.
    struct Element<N> {
        index: N,
        key: f64,
    }

    impl<N> PartialOrd for Element<N> {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }

    impl<N> Ord for Element<N> {
        fn cmp(&self, other: &Self) -> Ordering {
            // The ordering is reversed so that the max-heap `BinaryHeap`
            // exposes the element with the SMALLEST key via `peek`/`pop`.
            // `unwrap` panics if either key is NaN.
            self.key.partial_cmp(&other.key).unwrap().reverse()
        }
    }

    impl<N> PartialEq for Element<N> {
        fn eq(&self, other: &Self) -> bool {
            self.key == other.key
        }
    }

    impl<N> Eq for Element<N> {}

    // Phase 1: fill the reservoir with the first `amount` indices that have a
    // positive weight, drawing one random key per accepted index.
    let mut candidates = BinaryHeap::with_capacity(amount.as_usize());
    let mut index = N::zero();
    while index < length && candidates.len() < amount.as_usize() {
        let weight = weight(index.as_usize()).into();
        if weight > 0.0 {
            let key = rng.random::<f64>().ln() / weight;
            candidates.push(Element { index, key });
        } else if !(weight >= 0.0) {
            // Written as a negated `>=` so that NaN is rejected as well as
            // negative values; a weight of exactly zero falls through.
            return Err(WeightError::InvalidWeight);
        }

        index += N::one();
    }

    // The range was exhausted before enough positive weights were found.
    if candidates.len() < amount.as_usize() {
        return Err(WeightError::InsufficientNonZero);
    }

    // Phase 2: `x` is the amount of weight to skip before the next
    // replacement. It is derived from a fresh random draw and the current
    // smallest key in the reservoir.
    let mut x = rng.random::<f64>().ln() / candidates.peek().unwrap().key;
    while index < length {
        let weight = weight(index.as_usize()).into();
        if weight > 0.0 {
            x -= weight;
            if x <= 0.0 {
                // This index replaces the current minimum. Its key is drawn
                // from `t..1.0` (then moved to log space), where `t` is
                // computed from the evicted minimum key and this weight.
                let min_candidate = candidates.pop().unwrap();
                let t = (min_candidate.key * weight).exp();
                let key = rng.random_range(t..1.0).ln() / weight;
                candidates.push(Element { index, key });

                // Draw the next skip distance against the new minimum key.
                x = rng.random::<f64>().ln() / candidates.peek().unwrap().key;
            }
        } else if !(weight >= 0.0) {
            // Negative or NaN weight (see phase 1).
            return Err(WeightError::InvalidWeight);
        }

        index += N::one();
    }

    Ok(IndexVec::from(
        candidates.iter().map(|elt| elt.index).collect(),
    ))
}
434
435fn sample_floyd<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec
442where
443 R: Rng + ?Sized,
444{
445 debug_assert!(amount <= length);
449 let mut indices = Vec::with_capacity(amount as usize);
450 for j in length - amount..length {
451 let t = rng.random_range(..=j);
452 if let Some(pos) = indices.iter().position(|&x| x == t) {
453 indices[pos] = j;
454 }
455 indices.push(t);
456 }
457 IndexVec::from(indices)
458}
459
460fn sample_inplace<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec
473where
474 R: Rng + ?Sized,
475{
476 debug_assert!(amount <= length);
477 let mut indices: Vec<u32> = Vec::with_capacity(length as usize);
478 indices.extend(0..length);
479 for i in 0..amount {
480 let j: u32 = rng.random_range(i..length);
481 indices.swap(i as usize, j as usize);
482 }
483 indices.truncate(amount as usize);
484 debug_assert_eq!(indices.len(), amount as usize);
485 IndexVec::from(indices)
486}
487
488trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + Hash + AddAssign {
489 fn zero() -> Self;
490 #[cfg_attr(feature = "alloc", allow(dead_code))]
491 fn one() -> Self;
492 fn as_usize(self) -> usize;
493}
494
495impl UInt for u32 {
496 #[inline]
497 fn zero() -> Self {
498 0
499 }
500
501 #[inline]
502 fn one() -> Self {
503 1
504 }
505
506 #[inline]
507 fn as_usize(self) -> usize {
508 self as usize
509 }
510}
511
512#[cfg(target_pointer_width = "64")]
513impl UInt for u64 {
514 #[inline]
515 fn zero() -> Self {
516 0
517 }
518
519 #[inline]
520 fn one() -> Self {
521 1
522 }
523
524 #[inline]
525 fn as_usize(self) -> usize {
526 self as usize
527 }
528}
529
530fn sample_rejection<X: UInt, R>(rng: &mut R, length: X, amount: X) -> IndexVec
540where
541 R: Rng + ?Sized,
542 IndexVec: From<Vec<X>>,
543{
544 debug_assert!(amount < length);
545 #[cfg(feature = "std")]
546 let mut cache = HashSet::with_capacity(amount.as_usize());
547 #[cfg(not(feature = "std"))]
548 let mut cache = BTreeSet::new();
549 let distr = Uniform::new(X::zero(), length).unwrap();
550 let mut indices = Vec::with_capacity(amount.as_usize());
551 for _ in 0..amount.as_usize() {
552 let mut pos = distr.sample(rng);
553 while !cache.insert(pos) {
554 pos = distr.sample(rng);
555 }
556 indices.push(pos);
557 }
558
559 debug_assert_eq!(indices.len(), amount.as_usize());
560 IndexVec::from(indices)
561}
562
// Unit tests for the index-sampling routines in this module.
#[cfg(test)]
mod test {
    use super::*;
    use alloc::vec;

    // An `IndexVec` must survive a serialize/deserialize round trip.
    #[test]
    #[cfg(feature = "serde")]
    fn test_serialization_index_vec() {
        let some_index_vec = IndexVec::from(vec![254_u32, 234, 2, 1]);
        let de_some_index_vec: IndexVec =
            bincode::deserialize(&bincode::serialize(&some_index_vec).unwrap()).unwrap();
        assert_eq!(some_index_vec, de_some_index_vec);
    }

    // Degenerate inputs (empty range, zero amount, single element) for each
    // routine, plus a sanity check with a large range.
    #[test]
    fn test_sample_boundaries() {
        let mut r = crate::test::rng(404);

        assert_eq!(sample_inplace(&mut r, 0, 0).len(), 0);
        assert_eq!(sample_inplace(&mut r, 1, 0).len(), 0);
        assert_eq!(sample_inplace(&mut r, 1, 1).into_vec(), vec![0]);

        assert_eq!(sample_rejection(&mut r, 1u32, 0).len(), 0);

        assert_eq!(sample_floyd(&mut r, 0, 0).len(), 0);
        assert_eq!(sample_floyd(&mut r, 1, 0).len(), 0);
        assert_eq!(sample_floyd(&mut r, 1, 1).into_vec(), vec![0]);

        // Ten samples from `0..2^25`: only a loose bound on their sum is
        // checked (strictly between 1x and 25x the range length).
        let sum: usize = sample_rejection(&mut r, 1 << 25, 10u32).into_iter().sum();
        assert!(1 << 25 < sum && sum < (1 << 25) * 25);

        let sum: usize = sample_floyd(&mut r, 1 << 25, 10).into_iter().sum();
        assert!(1 << 25 < sum && sum < (1 << 25) * 25);
    }

    // Checks which routine `sample` dispatches to, by comparing its output
    // with each routine run from an identically seeded generator.
    #[test]
    // NOTE(review): skipped under Miri, presumably because it is too slow.
    #[cfg_attr(miri, ignore)]
    fn test_sample_alg() {
        let seed_rng = crate::test::rng;

        // Short range, relatively large amount: expected to match in-place.
        let (length, amount): (usize, usize) = (100, 50);
        let v1 = sample(&mut seed_rng(420), length, amount);
        let v2 = sample_inplace(&mut seed_rng(420), length as u32, amount as u32);
        assert!(v1.iter().all(|e| e < length));
        assert_eq!(v1, v2);

        // Floyd's routine with the same seed must give a different result.
        let v3 = sample_floyd(&mut seed_rng(420), length as u32, amount as u32);
        assert!(v1 != v3);

        // Long range, small amount: expected to match Floyd's routine.
        let (length, amount): (usize, usize) = (1 << 20, 50);
        let v1 = sample(&mut seed_rng(421), length, amount);
        let v2 = sample_floyd(&mut seed_rng(421), length as u32, amount as u32);
        assert!(v1.iter().all(|e| e < length));
        assert_eq!(v1, v2);

        // Long range, larger amount: expected to match rejection sampling.
        let (length, amount): (usize, usize) = (1 << 20, 600);
        let v1 = sample(&mut seed_rng(422), length, amount);
        let v2 = sample_rejection(&mut seed_rng(422), length as u32, amount as u32);
        assert!(v1.iter().all(|e| e < length));
        assert_eq!(v1, v2);
    }

    // Weighted sampling returns the requested number of distinct, in-range
    // indices, and fails when not enough weights are positive.
    #[cfg(feature = "std")]
    #[test]
    fn test_sample_weighted() {
        let seed_rng = crate::test::rng;
        for &(amount, len) in &[(0, 10), (5, 10), (9, 10)] {
            let v = sample_weighted(&mut seed_rng(423), len, |i| i as f64, amount).unwrap();
            match v {
                IndexVec::U32(mut indices) => {
                    assert_eq!(indices.len(), amount);
                    // Sorting and de-duplicating must not shrink the list.
                    indices.sort_unstable();
                    indices.dedup();
                    assert_eq!(indices.len(), amount);
                    for &i in &indices {
                        assert!((i as usize) < len);
                    }
                }
                #[cfg(target_pointer_width = "64")]
                _ => panic!("expected `IndexVec::U32`"),
            }
        }

        // Index 0 has weight 0.0, so only 9 of the 10 weights are positive.
        let r = sample_weighted(&mut seed_rng(423), 10, |i| i as f64, 10);
        assert_eq!(r.unwrap_err(), WeightError::InsufficientNonZero);
    }

    // Pins the exact output of `sample` for a fixed seed. Only the first
    // eight values of each result are compared.
    #[test]
    fn value_stability_sample() {
        let do_test = |length, amount, values: &[u32]| {
            let mut buf = [0u32; 8];
            let mut rng = crate::test::rng(410);

            let res = sample(&mut rng, length, amount);
            let len = res.len().min(buf.len());
            for (x, y) in res.into_iter().zip(buf.iter_mut()) {
                *y = x as u32;
            }
            assert_eq!(
                &buf[0..len],
                values,
                "failed sampling {}, {}",
                length,
                amount
            );
        };

        // The trailing comment names the routine `sample` selects for each
        // `(length, amount)` pair under its current thresholds.
        do_test(10, 6, &[0, 9, 5, 4, 6, 8]); // Floyd
        do_test(25, 10, &[24, 20, 19, 9, 22, 16, 0, 14]); // Floyd
        do_test(300, 8, &[30, 283, 243, 150, 218, 240, 1, 189]); // Floyd
        do_test(300, 80, &[31, 289, 248, 154, 221, 243, 7, 192]); // in-place
        do_test(300, 180, &[31, 289, 248, 154, 221, 243, 7, 192]); // in-place
        do_test(
            1_000_000,
            8,
            &[103717, 963485, 826422, 509101, 736394, 807035, 5327, 632573],
        ); // Floyd
        do_test(
            1_000_000,
            180,
            &[103718, 963490, 826426, 509103, 736396, 807036, 5327, 632573],
        ); // rejection
    }
}