rand/distr/weighted/
weighted_index.rs

1// Copyright 2018 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use super::{Error, Weight};
10use crate::distr::uniform::{SampleBorrow, SampleUniform, UniformSampler};
11use crate::distr::Distribution;
12use crate::Rng;
13
14// Note that this whole module is only imported if feature="alloc" is enabled.
15use alloc::vec::Vec;
16use core::fmt::{self, Debug};
17
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20
21/// A distribution using weighted sampling of discrete items.
22///
23/// Sampling a `WeightedIndex` distribution returns the index of a randomly
24/// selected element from the iterator used when the `WeightedIndex` was
25/// created. The chance of a given element being picked is proportional to the
26/// weight of the element. The weights can use any type `X` for which an
27/// implementation of [`Uniform<X>`] exists. The implementation guarantees that
28/// elements with zero weight are never picked, even when the weights are
29/// floating point numbers.
30///
31/// # Performance
32///
33/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where
34/// `N` is the number of weights.
35/// See also [`rand_distr::weighted`] for alternative implementations supporting
36/// potentially-faster sampling or a more easily modifiable tree structure.
37///
38/// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its
39/// size is the sum of the size of those objects, possibly plus some alignment.
40///
41/// Creating a `WeightedIndex<X>` will allocate enough space to hold `N - 1`
42/// weights of type `X`, where `N` is the number of weights. However, since
43/// `Vec` doesn't guarantee a particular growth strategy, additional memory
44/// might be allocated but not used. Since the `WeightedIndex` object also
45/// contains an instance of `X::Sampler`, this might cause additional allocations,
46/// though for primitive types, [`Uniform<X>`] doesn't allocate any memory.
47///
48/// Sampling from `WeightedIndex` will result in a single call to
49/// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically
50/// will request a single value from the underlying [`RngCore`], though the
51/// exact number depends on the implementation of `Uniform<X>::sample`.
52///
53/// # Example
54///
55/// ```
56/// use rand::prelude::*;
57/// use rand::distr::weighted::WeightedIndex;
58///
59/// let choices = ['a', 'b', 'c'];
60/// let weights = [2,   1,   1];
61/// let dist = WeightedIndex::new(&weights).unwrap();
62/// let mut rng = rand::rng();
63/// for _ in 0..100 {
64///     // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
65///     println!("{}", choices[dist.sample(&mut rng)]);
66/// }
67///
68/// let items = [('a', 0.0), ('b', 3.0), ('c', 7.0)];
69/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
70/// for _ in 0..100 {
71///     // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
72///     println!("{}", items[dist2.sample(&mut rng)].0);
73/// }
74/// ```
75///
76/// [`Uniform<X>`]: crate::distr::Uniform
77/// [`RngCore`]: crate::RngCore
78/// [`rand_distr::weighted`]: https://docs.rs/rand_distr/latest/rand_distr/weighted/index.html
79#[derive(Debug, Clone, PartialEq)]
80#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
81pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
82    cumulative_weights: Vec<X>,
83    total_weight: X,
84    weight_distribution: X::Sampler,
85}
86
87impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
88    /// Creates a new a `WeightedIndex` [`Distribution`] using the values
89    /// in `weights`. The weights can use any type `X` for which an
90    /// implementation of [`Uniform<X>`] exists.
91    ///
92    /// Error cases:
93    /// -   [`Error::InvalidInput`] when the iterator `weights` is empty.
94    /// -   [`Error::InvalidWeight`] when a weight is not-a-number or negative.
95    /// -   [`Error::InsufficientNonZero`] when the sum of all weights is zero.
96    /// -   [`Error::Overflow`] when the sum of all weights overflows.
97    ///
98    /// [`Uniform<X>`]: crate::distr::uniform::Uniform
99    pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, Error>
100    where
101        I: IntoIterator,
102        I::Item: SampleBorrow<X>,
103        X: Weight,
104    {
105        let mut iter = weights.into_iter();
106        let mut total_weight: X = iter.next().ok_or(Error::InvalidInput)?.borrow().clone();
107
108        let zero = X::ZERO;
109        if !(total_weight >= zero) {
110            return Err(Error::InvalidWeight);
111        }
112
113        let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
114        for w in iter {
115            // Note that `!(w >= x)` is not equivalent to `w < x` for partially
116            // ordered types due to NaNs which are equal to nothing.
117            if !(w.borrow() >= &zero) {
118                return Err(Error::InvalidWeight);
119            }
120            weights.push(total_weight.clone());
121
122            if let Err(()) = total_weight.checked_add_assign(w.borrow()) {
123                return Err(Error::Overflow);
124            }
125        }
126
127        if total_weight == zero {
128            return Err(Error::InsufficientNonZero);
129        }
130        let distr = X::Sampler::new(zero, total_weight.clone()).unwrap();
131
132        Ok(WeightedIndex {
133            cumulative_weights: weights,
134            total_weight,
135            weight_distribution: distr,
136        })
137    }
138
139    /// Update a subset of weights, without changing the number of weights.
140    ///
141    /// `new_weights` must be sorted by the index.
142    ///
143    /// Using this method instead of `new` might be more efficient if only a small number of
144    /// weights is modified. No allocations are performed, unless the weight type `X` uses
145    /// allocation internally.
146    ///
147    /// In case of error, `self` is not modified. Error cases:
148    /// -   [`Error::InvalidInput`] when `new_weights` are not ordered by
149    ///     index or an index is too large.
150    /// -   [`Error::InvalidWeight`] when a weight is not-a-number or negative.
151    /// -   [`Error::InsufficientNonZero`] when the sum of all weights is zero.
152    ///     Note that due to floating-point loss of precision, this case is not
153    ///     always correctly detected; usage of a fixed-point weight type may be
154    ///     preferred.
155    ///
156    /// Updates take `O(N)` time. If you need to frequently update weights, consider
157    /// [`rand_distr::weighted_tree`](https://docs.rs/rand_distr/*/rand_distr/weighted_tree/index.html)
158    /// as an alternative where an update is `O(log N)`.
159    pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), Error>
160    where
161        X: for<'a> core::ops::AddAssign<&'a X>
162            + for<'a> core::ops::SubAssign<&'a X>
163            + Clone
164            + Default,
165    {
166        if new_weights.is_empty() {
167            return Ok(());
168        }
169
170        let zero = <X as Default>::default();
171
172        let mut total_weight = self.total_weight.clone();
173
174        // Check for errors first, so we don't modify `self` in case something
175        // goes wrong.
176        let mut prev_i = None;
177        for &(i, w) in new_weights {
178            if let Some(old_i) = prev_i {
179                if old_i >= i {
180                    return Err(Error::InvalidInput);
181                }
182            }
183            if !(*w >= zero) {
184                return Err(Error::InvalidWeight);
185            }
186            if i > self.cumulative_weights.len() {
187                return Err(Error::InvalidInput);
188            }
189
190            let mut old_w = if i < self.cumulative_weights.len() {
191                self.cumulative_weights[i].clone()
192            } else {
193                self.total_weight.clone()
194            };
195            if i > 0 {
196                old_w -= &self.cumulative_weights[i - 1];
197            }
198
199            total_weight -= &old_w;
200            total_weight += w;
201            prev_i = Some(i);
202        }
203        if total_weight <= zero {
204            return Err(Error::InsufficientNonZero);
205        }
206
207        // Update the weights. Because we checked all the preconditions in the
208        // previous loop, this should never panic.
209        let mut iter = new_weights.iter();
210
211        let mut prev_weight = zero.clone();
212        let mut next_new_weight = iter.next();
213        let &(first_new_index, _) = next_new_weight.unwrap();
214        let mut cumulative_weight = if first_new_index > 0 {
215            self.cumulative_weights[first_new_index - 1].clone()
216        } else {
217            zero.clone()
218        };
219        for i in first_new_index..self.cumulative_weights.len() {
220            match next_new_weight {
221                Some(&(j, w)) if i == j => {
222                    cumulative_weight += w;
223                    next_new_weight = iter.next();
224                }
225                _ => {
226                    let mut tmp = self.cumulative_weights[i].clone();
227                    tmp -= &prev_weight; // We know this is positive.
228                    cumulative_weight += &tmp;
229                }
230            }
231            prev_weight = cumulative_weight.clone();
232            core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]);
233        }
234
235        self.total_weight = total_weight;
236        self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()).unwrap();
237
238        Ok(())
239    }
240}
241
242/// A lazy-loading iterator over the weights of a `WeightedIndex` distribution.
243/// This is returned by [`WeightedIndex::weights`].
244pub struct WeightedIndexIter<'a, X: SampleUniform + PartialOrd> {
245    weighted_index: &'a WeightedIndex<X>,
246    index: usize,
247}
248
249impl<X> Debug for WeightedIndexIter<'_, X>
250where
251    X: SampleUniform + PartialOrd + Debug,
252    X::Sampler: Debug,
253{
254    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255        f.debug_struct("WeightedIndexIter")
256            .field("weighted_index", &self.weighted_index)
257            .field("index", &self.index)
258            .finish()
259    }
260}
261
262impl<X> Clone for WeightedIndexIter<'_, X>
263where
264    X: SampleUniform + PartialOrd,
265{
266    fn clone(&self) -> Self {
267        WeightedIndexIter {
268            weighted_index: self.weighted_index,
269            index: self.index,
270        }
271    }
272}
273
274impl<X> Iterator for WeightedIndexIter<'_, X>
275where
276    X: for<'b> core::ops::SubAssign<&'b X> + SampleUniform + PartialOrd + Clone,
277{
278    type Item = X;
279
280    fn next(&mut self) -> Option<Self::Item> {
281        match self.weighted_index.weight(self.index) {
282            None => None,
283            Some(weight) => {
284                self.index += 1;
285                Some(weight)
286            }
287        }
288    }
289}
290
291impl<X: SampleUniform + PartialOrd + Clone> WeightedIndex<X> {
292    /// Returns the weight at the given index, if it exists.
293    ///
294    /// If the index is out of bounds, this will return `None`.
295    ///
296    /// # Example
297    ///
298    /// ```
299    /// use rand::distr::weighted::WeightedIndex;
300    ///
301    /// let weights = [0, 1, 2];
302    /// let dist = WeightedIndex::new(&weights).unwrap();
303    /// assert_eq!(dist.weight(0), Some(0));
304    /// assert_eq!(dist.weight(1), Some(1));
305    /// assert_eq!(dist.weight(2), Some(2));
306    /// assert_eq!(dist.weight(3), None);
307    /// ```
308    pub fn weight(&self, index: usize) -> Option<X>
309    where
310        X: for<'a> core::ops::SubAssign<&'a X>,
311    {
312        use core::cmp::Ordering::*;
313
314        let mut weight = match index.cmp(&self.cumulative_weights.len()) {
315            Less => self.cumulative_weights[index].clone(),
316            Equal => self.total_weight.clone(),
317            Greater => return None,
318        };
319
320        if index > 0 {
321            weight -= &self.cumulative_weights[index - 1];
322        }
323        Some(weight)
324    }
325
326    /// Returns a lazy-loading iterator containing the current weights of this distribution.
327    ///
328    /// If this distribution has not been updated since its creation, this will return the
329    /// same weights as were passed to `new`.
330    ///
331    /// # Example
332    ///
333    /// ```
334    /// use rand::distr::weighted::WeightedIndex;
335    ///
336    /// let weights = [1, 2, 3];
337    /// let mut dist = WeightedIndex::new(&weights).unwrap();
338    /// assert_eq!(dist.weights().collect::<Vec<_>>(), vec![1, 2, 3]);
339    /// dist.update_weights(&[(0, &2)]).unwrap();
340    /// assert_eq!(dist.weights().collect::<Vec<_>>(), vec![2, 2, 3]);
341    /// ```
342    pub fn weights(&self) -> WeightedIndexIter<'_, X>
343    where
344        X: for<'a> core::ops::SubAssign<&'a X>,
345    {
346        WeightedIndexIter {
347            weighted_index: self,
348            index: 0,
349        }
350    }
351
352    /// Returns the sum of all weights in this distribution.
353    pub fn total_weight(&self) -> X {
354        self.total_weight.clone()
355    }
356}
357
358impl<X> Distribution<usize> for WeightedIndex<X>
359where
360    X: SampleUniform + PartialOrd,
361{
362    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
363        let chosen_weight = self.weight_distribution.sample(rng);
364        // Find the first item which has a weight *higher* than the chosen weight.
365        self.cumulative_weights
366            .partition_point(|w| w <= &chosen_weight)
367    }
368}
369
// Unit tests for `WeightedIndex`: construction, error reporting, sampling
// frequencies, in-place weight updates, weight queries and pinned output.
#[cfg(test)]
mod test {
    use super::*;

    // Round-trips a distribution through bincode and checks that the stored
    // cumulative and total weights survive unchanged.
    #[cfg(feature = "serde")]
    #[test]
    fn test_weightedindex_serde() {
        let weighted_index = WeightedIndex::new([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap();

        let ser_weighted_index = bincode::serialize(&weighted_index).unwrap();
        let de_weighted_index: WeightedIndex<i32> =
            bincode::deserialize(&ser_weighted_index).unwrap();

        assert_eq!(
            de_weighted_index.cumulative_weights,
            weighted_index.cumulative_weights
        );
        assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight);
    }

    // A NaN weight must be rejected wherever it appears (first, only, or a
    // later element) and also when supplied through `update_weights`.
    #[test]
    fn test_accepting_nan() {
        assert_eq!(
            WeightedIndex::new([f32::NAN, 0.5]).unwrap_err(),
            Error::InvalidWeight,
        );
        assert_eq!(
            WeightedIndex::new([f32::NAN]).unwrap_err(),
            Error::InvalidWeight,
        );
        assert_eq!(
            WeightedIndex::new([0.5, f32::NAN]).unwrap_err(),
            Error::InvalidWeight,
        );

        assert_eq!(
            WeightedIndex::new([0.5, 7.0])
                .unwrap()
                .update_weights(&[(0, &f32::NAN)])
                .unwrap_err(),
            Error::InvalidWeight,
        )
    }

    // Statistical check of sampling frequencies for the three ways of passing
    // weights (Vec, slice, iterator of references), plus deterministic
    // single-non-zero cases and constructor error cases.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weightedindex() {
        let mut r = crate::test::rng(700);
        const N_REPS: u32 = 5000;
        let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
        let total_weight = weights.iter().sum::<u32>() as f32;

        // Each index's hit count must be within 25% (relative) of its
        // expected count `weight * N_REPS / total`.
        let verify = |result: [i32; 14]| {
            for (i, count) in result.iter().enumerate() {
                let exp = (weights[i] * N_REPS) as f32 / total_weight;
                let mut err = (*count as f32 - exp).abs();
                // Skip the division on an exact match so a zero-weight index
                // with zero hits does not compute 0.0 / 0.0.
                if err != 0.0 {
                    err /= exp;
                }
                assert!(err <= 0.25);
            }
        };

        // WeightedIndex from vec
        let mut chosen = [0i32; 14];
        let distr = WeightedIndex::new(weights.to_vec()).unwrap();
        for _ in 0..N_REPS {
            chosen[distr.sample(&mut r)] += 1;
        }
        verify(chosen);

        // WeightedIndex from slice
        chosen = [0i32; 14];
        let distr = WeightedIndex::new(&weights[..]).unwrap();
        for _ in 0..N_REPS {
            chosen[distr.sample(&mut r)] += 1;
        }
        verify(chosen);

        // WeightedIndex from iterator
        chosen = [0i32; 14];
        let distr = WeightedIndex::new(weights.iter()).unwrap();
        for _ in 0..N_REPS {
            chosen[distr.sample(&mut r)] += 1;
        }
        verify(chosen);

        // With exactly one non-zero weight the outcome is deterministic,
        // whether that weight is first, last, or in the middle.
        for _ in 0..5 {
            assert_eq!(WeightedIndex::new([0, 1]).unwrap().sample(&mut r), 1);
            assert_eq!(WeightedIndex::new([1, 0]).unwrap().sample(&mut r), 0);
            assert_eq!(
                WeightedIndex::new([0, 0, 0, 0, 10, 0])
                    .unwrap()
                    .sample(&mut r),
                4
            );
        }

        // Constructor errors: empty input, all-zero weights, and a negative
        // weight in a later or in the first position.
        assert_eq!(
            WeightedIndex::new(&[10][0..0]).unwrap_err(),
            Error::InvalidInput
        );
        assert_eq!(
            WeightedIndex::new([0]).unwrap_err(),
            Error::InsufficientNonZero
        );
        assert_eq!(
            WeightedIndex::new([10, 20, -1, 30]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedIndex::new([-10, 20, 1, 30]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(WeightedIndex::new([-10]).unwrap_err(), Error::InvalidWeight);
    }

    // Each case is (initial weights, updates, expected weights). The updated
    // distribution must match one built from scratch from the expected
    // weights, both in its total and in its cumulative sums.
    #[test]
    fn test_update_weights() {
        let data = [
            (
                &[10u32, 2, 3, 4][..],
                &[(1, &100), (2, &4)][..], // positive change
                &[10, 100, 4, 4][..],
            ),
            (
                &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
                &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element
                &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..],
            ),
        ];

        for (weights, update, expected_weights) in data.iter() {
            let total_weight = weights.iter().sum::<u32>();
            let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, total_weight);

            distr.update_weights(update).unwrap();
            let expected_total_weight = expected_weights.iter().sum::<u32>();
            let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, expected_total_weight);
            assert_eq!(distr.total_weight, expected_distr.total_weight);
            assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights);
        }
    }

    // Each case is (initial weights, updates, expected error); every update
    // here must be rejected with that error.
    #[test]
    fn test_update_weights_errors() {
        let data = [
            (
                &[1i32, 0, 0][..],
                &[(0, &0)][..],
                Error::InsufficientNonZero,
            ),
            (
                &[10, 10, 10, 10][..],
                &[(1, &-11)][..],
                Error::InvalidWeight, // A weight is negative
            ),
            (
                &[1, 2, 3, 4, 5][..],
                &[(1, &5), (0, &5)][..], // Wrong order
                Error::InvalidInput,
            ),
            (
                &[1][..],
                &[(1, &1)][..], // Index too large
                Error::InvalidInput,
            ),
        ];

        for (weights, update, err) in data.iter() {
            let total_weight = weights.iter().sum::<i32>();
            let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, total_weight);
            match distr.update_weights(update) {
                Ok(_) => panic!("Expected update_weights to fail, but it succeeded"),
                Err(e) => assert_eq!(e, *err),
            }
        }
    }

    // `weight(i)` recovers every original (integer) weight, including a
    // single `u32::MAX` weight, and returns `None` one past the end.
    #[test]
    fn test_weight_at() {
        let data = [
            &[1][..],
            &[10, 2, 3, 4][..],
            &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
            &[u32::MAX][..],
        ];

        for weights in data.iter() {
            let distr = WeightedIndex::new(weights.to_vec()).unwrap();
            for (i, weight) in weights.iter().enumerate() {
                assert_eq!(distr.weight(i), Some(*weight));
            }
            assert_eq!(distr.weight(weights.len()), None);
        }
    }

    // `weights()` yields exactly the original (integer) weights, in order.
    #[test]
    fn test_weights() {
        let data = [
            &[1][..],
            &[10, 2, 3, 4][..],
            &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
            &[u32::MAX][..],
        ];

        for weights in data.iter() {
            let distr = WeightedIndex::new(weights.to_vec()).unwrap();
            assert_eq!(distr.weights().collect::<Vec<_>>(), weights.to_vec());
        }
    }

    // Pins the exact sample sequence for a fixed seed, so any change to how
    // samples are produced (for integer and float weights) is noticed.
    #[test]
    fn value_stability() {
        fn test_samples<X: Weight + SampleUniform + PartialOrd, I>(
            weights: I,
            buf: &mut [usize],
            expected: &[usize],
        ) where
            I: IntoIterator,
            I::Item: SampleBorrow<X>,
        {
            assert_eq!(buf.len(), expected.len());
            let distr = WeightedIndex::new(weights).unwrap();
            let mut rng = crate::test::rng(701);
            for r in buf.iter_mut() {
                *r = rng.sample(&distr);
            }
            assert_eq!(buf, expected);
        }

        let mut buf = [0; 10];
        test_samples(
            [1i32, 1, 1, 1, 1, 1, 1, 1, 1],
            &mut buf,
            &[0, 6, 2, 6, 3, 4, 7, 8, 2, 5],
        );
        test_samples(
            [0.7f32, 0.1, 0.1, 0.1],
            &mut buf,
            &[0, 0, 0, 1, 0, 0, 2, 3, 0, 0],
        );
        test_samples(
            [1.0f64, 0.999, 0.998, 0.997],
            &mut buf,
            &[2, 2, 1, 3, 2, 1, 3, 3, 2, 1],
        );
    }

    // The derived `PartialEq` lets two identically built distributions
    // (wrapped in `Result`) compare equal.
    #[test]
    fn weighted_index_distributions_can_be_compared() {
        assert_eq!(WeightedIndex::new([1, 2]), WeightedIndex::new([1, 2]));
    }

    // Integer weights whose sum does not fit in `usize` are reported as
    // `Error::Overflow` rather than wrapping.
    #[test]
    fn overflow() {
        assert_eq!(WeightedIndex::new([2, usize::MAX]), Err(Error::Overflow));
    }
}