rand/distr/weighted/mod.rs
1// Copyright 2018 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! Weighted (index) sampling
10//!
11//! Primarily, this module houses the [`WeightedIndex`] distribution.
12//! See also [`rand_distr::weighted`] for alternative implementations supporting
13//! potentially-faster sampling or a more easily modifiable tree structure.
14//!
15//! [`rand_distr::weighted`]: https://docs.rs/rand_distr/latest/rand_distr/weighted/index.html
16
17use core::fmt;
18mod weighted_index;
19
20pub use weighted_index::WeightedIndex;
21
22/// Bounds on a weight
23///
24/// See usage in [`WeightedIndex`].
25pub trait Weight: Clone {
26 /// Representation of 0
27 const ZERO: Self;
28
29 /// Checked addition
30 ///
31 /// - `Result::Ok`: On success, `v` is added to `self`
32 /// - `Result::Err`: Returns an error when `Self` cannot represent the
33 /// result of `self + v` (i.e. overflow). The value of `self` should be
34 /// discarded.
35 #[allow(clippy::result_unit_err)]
36 fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()>;
37}
38
39macro_rules! impl_weight_int {
40 ($t:ty) => {
41 impl Weight for $t {
42 const ZERO: Self = 0;
43 fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> {
44 match self.checked_add(*v) {
45 Some(sum) => {
46 *self = sum;
47 Ok(())
48 }
49 None => Err(()),
50 }
51 }
52 }
53 };
54 ($t:ty, $($tt:ty),*) => {
55 impl_weight_int!($t);
56 impl_weight_int!($($tt),*);
57 }
58}
59impl_weight_int!(i8, i16, i32, i64, i128, isize);
60impl_weight_int!(u8, u16, u32, u64, u128, usize);
61
62macro_rules! impl_weight_float {
63 ($t:ty) => {
64 impl Weight for $t {
65 const ZERO: Self = 0.0;
66
67 fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> {
68 // Floats have an explicit representation for overflow
69 *self += *v;
70 Ok(())
71 }
72 }
73 };
74}
75impl_weight_float!(f32);
76impl_weight_float!(f64);
77
78/// Invalid weight errors
79///
80/// This type represents errors from [`WeightedIndex::new`],
81/// [`WeightedIndex::update_weights`] and other weighted distributions.
82#[derive(Debug, Clone, Copy, PartialEq, Eq)]
83// Marked non_exhaustive to allow a new error code in the solution to #1476.
84#[non_exhaustive]
85pub enum Error {
86 /// The input weight sequence is empty, too long, or wrongly ordered
87 InvalidInput,
88
89 /// A weight is negative, too large for the distribution, or not a valid number
90 InvalidWeight,
91
92 /// Not enough non-zero weights are available to sample values
93 ///
94 /// When attempting to sample a single value this implies that all weights
95 /// are zero. When attempting to sample `amount` values this implies that
96 /// less than `amount` weights are greater than zero.
97 InsufficientNonZero,
98
99 /// Overflow when calculating the sum of weights
100 Overflow,
101}
102
103#[cfg(feature = "std")]
104impl std::error::Error for Error {}
105
106impl fmt::Display for Error {
107 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
108 f.write_str(match *self {
109 Error::InvalidInput => "Weights sequence is empty/too long/unordered",
110 Error::InvalidWeight => "A weight is negative, too large or not a valid number",
111 Error::InsufficientNonZero => "Not enough weights > zero",
112 Error::Overflow => "Overflow when summing weights",
113 })
114 }
115}