1use super::{Error, Weight};
10use crate::distr::uniform::{SampleBorrow, SampleUniform, UniformSampler};
11use crate::distr::Distribution;
12use crate::Rng;
13
14use alloc::vec::Vec;
16use core::fmt::{self, Debug};
17
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20
/// A distribution over indices `0..n` that samples index `i` with
/// probability proportional to the `i`-th weight supplied to
/// [`WeightedIndex::new`].
///
/// Internally the weights are stored as prefix sums. A sample draws a point
/// uniformly below the total weight and locates its bucket by binary search.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
    // Prefix sums: `cumulative_weights[k]` is the sum of weights `0..=k`.
    // The prefix sum through the last weight is not stored here (it is
    // `total_weight`), so this vector is one entry shorter than the number
    // of weights.
    cumulative_weights: Vec<X>,
    // Sum of all weights.
    total_weight: X,
    // Sampler created with `X::Sampler::new(zero, total_weight)`; its samples
    // are mapped to an index in `Distribution::sample`.
    weight_distribution: X::Sampler,
}
86
87impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
88 pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, Error>
100 where
101 I: IntoIterator,
102 I::Item: SampleBorrow<X>,
103 X: Weight,
104 {
105 let mut iter = weights.into_iter();
106 let mut total_weight: X = iter.next().ok_or(Error::InvalidInput)?.borrow().clone();
107
108 let zero = X::ZERO;
109 if !(total_weight >= zero) {
110 return Err(Error::InvalidWeight);
111 }
112
113 let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
114 for w in iter {
115 if !(w.borrow() >= &zero) {
118 return Err(Error::InvalidWeight);
119 }
120 weights.push(total_weight.clone());
121
122 if let Err(()) = total_weight.checked_add_assign(w.borrow()) {
123 return Err(Error::Overflow);
124 }
125 }
126
127 if total_weight == zero {
128 return Err(Error::InsufficientNonZero);
129 }
130 let distr = X::Sampler::new(zero, total_weight.clone()).unwrap();
131
132 Ok(WeightedIndex {
133 cumulative_weights: weights,
134 total_weight,
135 weight_distribution: distr,
136 })
137 }
138
139 pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), Error>
160 where
161 X: for<'a> core::ops::AddAssign<&'a X>
162 + for<'a> core::ops::SubAssign<&'a X>
163 + Clone
164 + Default,
165 {
166 if new_weights.is_empty() {
167 return Ok(());
168 }
169
170 let zero = <X as Default>::default();
171
172 let mut total_weight = self.total_weight.clone();
173
174 let mut prev_i = None;
177 for &(i, w) in new_weights {
178 if let Some(old_i) = prev_i {
179 if old_i >= i {
180 return Err(Error::InvalidInput);
181 }
182 }
183 if !(*w >= zero) {
184 return Err(Error::InvalidWeight);
185 }
186 if i > self.cumulative_weights.len() {
187 return Err(Error::InvalidInput);
188 }
189
190 let mut old_w = if i < self.cumulative_weights.len() {
191 self.cumulative_weights[i].clone()
192 } else {
193 self.total_weight.clone()
194 };
195 if i > 0 {
196 old_w -= &self.cumulative_weights[i - 1];
197 }
198
199 total_weight -= &old_w;
200 total_weight += w;
201 prev_i = Some(i);
202 }
203 if total_weight <= zero {
204 return Err(Error::InsufficientNonZero);
205 }
206
207 let mut iter = new_weights.iter();
210
211 let mut prev_weight = zero.clone();
212 let mut next_new_weight = iter.next();
213 let &(first_new_index, _) = next_new_weight.unwrap();
214 let mut cumulative_weight = if first_new_index > 0 {
215 self.cumulative_weights[first_new_index - 1].clone()
216 } else {
217 zero.clone()
218 };
219 for i in first_new_index..self.cumulative_weights.len() {
220 match next_new_weight {
221 Some(&(j, w)) if i == j => {
222 cumulative_weight += w;
223 next_new_weight = iter.next();
224 }
225 _ => {
226 let mut tmp = self.cumulative_weights[i].clone();
227 tmp -= &prev_weight; cumulative_weight += &tmp;
229 }
230 }
231 prev_weight = cumulative_weight.clone();
232 core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]);
233 }
234
235 self.total_weight = total_weight;
236 self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()).unwrap();
237
238 Ok(())
239 }
240}
241
/// Iterator over the weights of a [`WeightedIndex`], in index order.
///
/// Returned by [`WeightedIndex::weights`].
pub struct WeightedIndexIter<'a, X: SampleUniform + PartialOrd> {
    // The distribution whose weights are being yielded.
    weighted_index: &'a WeightedIndex<X>,
    // Index of the next weight to yield.
    index: usize,
}
248
249impl<X> Debug for WeightedIndexIter<'_, X>
250where
251 X: SampleUniform + PartialOrd + Debug,
252 X::Sampler: Debug,
253{
254 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255 f.debug_struct("WeightedIndexIter")
256 .field("weighted_index", &self.weighted_index)
257 .field("index", &self.index)
258 .finish()
259 }
260}
261
262impl<X> Clone for WeightedIndexIter<'_, X>
263where
264 X: SampleUniform + PartialOrd,
265{
266 fn clone(&self) -> Self {
267 WeightedIndexIter {
268 weighted_index: self.weighted_index,
269 index: self.index,
270 }
271 }
272}
273
274impl<X> Iterator for WeightedIndexIter<'_, X>
275where
276 X: for<'b> core::ops::SubAssign<&'b X> + SampleUniform + PartialOrd + Clone,
277{
278 type Item = X;
279
280 fn next(&mut self) -> Option<Self::Item> {
281 match self.weighted_index.weight(self.index) {
282 None => None,
283 Some(weight) => {
284 self.index += 1;
285 Some(weight)
286 }
287 }
288 }
289}
290
291impl<X: SampleUniform + PartialOrd + Clone> WeightedIndex<X> {
292 pub fn weight(&self, index: usize) -> Option<X>
309 where
310 X: for<'a> core::ops::SubAssign<&'a X>,
311 {
312 use core::cmp::Ordering::*;
313
314 let mut weight = match index.cmp(&self.cumulative_weights.len()) {
315 Less => self.cumulative_weights[index].clone(),
316 Equal => self.total_weight.clone(),
317 Greater => return None,
318 };
319
320 if index > 0 {
321 weight -= &self.cumulative_weights[index - 1];
322 }
323 Some(weight)
324 }
325
326 pub fn weights(&self) -> WeightedIndexIter<'_, X>
343 where
344 X: for<'a> core::ops::SubAssign<&'a X>,
345 {
346 WeightedIndexIter {
347 weighted_index: self,
348 index: 0,
349 }
350 }
351
352 pub fn total_weight(&self) -> X {
354 self.total_weight.clone()
355 }
356}
357
358impl<X> Distribution<usize> for WeightedIndex<X>
359where
360 X: SampleUniform + PartialOrd,
361{
362 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
363 let chosen_weight = self.weight_distribution.sample(rng);
364 self.cumulative_weights
366 .partition_point(|w| w <= &chosen_weight)
367 }
368}
369
#[cfg(test)]
mod test {
    use super::*;

    // Serializing and deserializing must preserve the stored prefix sums and
    // the total weight.
    #[cfg(feature = "serde")]
    #[test]
    fn test_weightedindex_serde() {
        let weighted_index = WeightedIndex::new([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap();

        let ser_weighted_index = bincode::serialize(&weighted_index).unwrap();
        let de_weighted_index: WeightedIndex<i32> =
            bincode::deserialize(&ser_weighted_index).unwrap();

        assert_eq!(
            de_weighted_index.cumulative_weights,
            weighted_index.cumulative_weights
        );
        assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight);
    }

    // A NaN weight is rejected with `InvalidWeight` whether it is the first,
    // the only, or a later weight, and also when passed to `update_weights`.
    #[test]
    fn test_accepting_nan() {
        assert_eq!(
            WeightedIndex::new([f32::NAN, 0.5]).unwrap_err(),
            Error::InvalidWeight,
        );
        assert_eq!(
            WeightedIndex::new([f32::NAN]).unwrap_err(),
            Error::InvalidWeight,
        );
        assert_eq!(
            WeightedIndex::new([0.5, f32::NAN]).unwrap_err(),
            Error::InvalidWeight,
        );

        assert_eq!(
            WeightedIndex::new([0.5, 7.0])
                .unwrap()
                .update_weights(&[(0, &f32::NAN)])
                .unwrap_err(),
            Error::InvalidWeight,
        )
    }

    // Statistical check. Sample frequencies must be within 25% (relative) of
    // the expected frequencies, for weights given as a Vec, a slice and an
    // iterator of references. Also covers deterministic single-nonzero-weight
    // cases and the constructor's error cases.
    #[test]
    #[cfg_attr(miri, ignore)] // presumably skipped because 3 x 5000 samples is slow under Miri
    fn test_weightedindex() {
        let mut r = crate::test::rng(700);
        const N_REPS: u32 = 5000;
        let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
        let total_weight = weights.iter().sum::<u32>() as f32;

        let verify = |result: [i32; 14]| {
            for (i, count) in result.iter().enumerate() {
                let exp = (weights[i] * N_REPS) as f32 / total_weight;
                let mut err = (*count as f32 - exp).abs();
                // Only normalize a non-zero error. For a zero weight `exp` is
                // 0 and the count must be exactly 0; otherwise `err / 0` is
                // infinite and the assertion fails.
                if err != 0.0 {
                    err /= exp;
                }
                assert!(err <= 0.25);
            }
        };

        let mut chosen = [0i32; 14];
        // WeightedIndex from a Vec of values
        let distr = WeightedIndex::new(weights.to_vec()).unwrap();
        for _ in 0..N_REPS {
            chosen[distr.sample(&mut r)] += 1;
        }
        verify(chosen);

        chosen = [0i32; 14];
        // WeightedIndex from a slice
        let distr = WeightedIndex::new(&weights[..]).unwrap();
        for _ in 0..N_REPS {
            chosen[distr.sample(&mut r)] += 1;
        }
        verify(chosen);

        chosen = [0i32; 14];
        // WeightedIndex from an iterator of references
        let distr = WeightedIndex::new(weights.iter()).unwrap();
        for _ in 0..N_REPS {
            chosen[distr.sample(&mut r)] += 1;
        }
        verify(chosen);

        // With a single non-zero weight, that index is always chosen.
        for _ in 0..5 {
            assert_eq!(WeightedIndex::new([0, 1]).unwrap().sample(&mut r), 1);
            assert_eq!(WeightedIndex::new([1, 0]).unwrap().sample(&mut r), 0);
            assert_eq!(
                WeightedIndex::new([0, 0, 0, 0, 10, 0])
                    .unwrap()
                    .sample(&mut r),
                4
            );
        }

        // Empty input, all-zero weights, and negative weights are errors.
        assert_eq!(
            WeightedIndex::new(&[10][0..0]).unwrap_err(),
            Error::InvalidInput
        );
        assert_eq!(
            WeightedIndex::new([0]).unwrap_err(),
            Error::InsufficientNonZero
        );
        assert_eq!(
            WeightedIndex::new([10, 20, -1, 30]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedIndex::new([-10, 20, 1, 30]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(WeightedIndex::new([-10]).unwrap_err(), Error::InvalidWeight);
    }

    #[test]
    fn test_update_weights() {
        // Each case: (initial weights, updates, expected weights afterwards).
        let data = [
            (
                &[10u32, 2, 3, 4][..],
                &[(1, &100), (2, &4)][..], // updates at positions 1 and 2
                &[10, 100, 4, 4][..],
            ),
            (
                &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
                &[(2, &1), (5, &1), (13, &100)][..], // includes the last index
                &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..],
            ),
        ];

        // The updated distribution must match one built from scratch with the
        // expected weights.
        for (weights, update, expected_weights) in data.iter() {
            let total_weight = weights.iter().sum::<u32>();
            let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, total_weight);

            distr.update_weights(update).unwrap();
            let expected_total_weight = expected_weights.iter().sum::<u32>();
            let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, expected_total_weight);
            assert_eq!(distr.total_weight, expected_distr.total_weight);
            assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights);
        }
    }

    #[test]
    fn test_update_weights_errors() {
        // Each case: (initial weights, invalid update, expected error).
        let data = [
            (
                // the only non-zero weight is set to zero
                &[1i32, 0, 0][..],
                &[(0, &0)][..],
                Error::InsufficientNonZero,
            ),
            (
                &[10, 10, 10, 10][..],
                &[(1, &-11)][..],
                Error::InvalidWeight, // negative weight
            ),
            (
                &[1, 2, 3, 4, 5][..],
                &[(1, &5), (0, &5)][..], // indices not strictly increasing
                Error::InvalidInput,
            ),
            (
                &[1][..],
                &[(1, &1)][..], // index out of bounds
                Error::InvalidInput,
            ),
        ];

        for (weights, update, err) in data.iter() {
            let total_weight = weights.iter().sum::<i32>();
            let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, total_weight);
            match distr.update_weights(update) {
                Ok(_) => panic!("Expected update_weights to fail, but it succeeded"),
                Err(e) => assert_eq!(e, *err),
            }
        }
    }

    // `weight(i)` must return each original weight, and `None` one past the end.
    #[test]
    fn test_weight_at() {
        let data = [
            &[1][..],
            &[10, 2, 3, 4][..],
            &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
            &[u32::MAX][..],
        ];

        for weights in data.iter() {
            let distr = WeightedIndex::new(weights.to_vec()).unwrap();
            for (i, weight) in weights.iter().enumerate() {
                assert_eq!(distr.weight(i), Some(*weight));
            }
            assert_eq!(distr.weight(weights.len()), None);
        }
    }

    // `weights()` must yield exactly the original weights, in order.
    #[test]
    fn test_weights() {
        let data = [
            &[1][..],
            &[10, 2, 3, 4][..],
            &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
            &[u32::MAX][..],
        ];

        for weights in data.iter() {
            let distr = WeightedIndex::new(weights.to_vec()).unwrap();
            assert_eq!(distr.weights().collect::<Vec<_>>(), weights.to_vec());
        }
    }

    // Pins the exact sampled index sequence for a fixed seed, so that any
    // change to the sampling algorithm's output is detected.
    #[test]
    fn value_stability() {
        fn test_samples<X: Weight + SampleUniform + PartialOrd, I>(
            weights: I,
            buf: &mut [usize],
            expected: &[usize],
        ) where
            I: IntoIterator,
            I::Item: SampleBorrow<X>,
        {
            assert_eq!(buf.len(), expected.len());
            let distr = WeightedIndex::new(weights).unwrap();
            let mut rng = crate::test::rng(701);
            for r in buf.iter_mut() {
                *r = rng.sample(&distr);
            }
            assert_eq!(buf, expected);
        }

        let mut buf = [0; 10];
        test_samples(
            [1i32, 1, 1, 1, 1, 1, 1, 1, 1],
            &mut buf,
            &[0, 6, 2, 6, 3, 4, 7, 8, 2, 5],
        );
        test_samples(
            [0.7f32, 0.1, 0.1, 0.1],
            &mut buf,
            &[0, 0, 0, 1, 0, 0, 2, 3, 0, 0],
        );
        test_samples(
            [1.0f64, 0.999, 0.998, 0.997],
            &mut buf,
            &[2, 2, 1, 3, 2, 1, 3, 3, 2, 1],
        );
    }

    // Exercises the derived `PartialEq`.
    #[test]
    fn weighted_index_distributions_can_be_compared() {
        assert_eq!(WeightedIndex::new([1, 2]), WeightedIndex::new([1, 2]));
    }

    // An integer total that overflows `usize` is reported as `Overflow`.
    #[test]
    fn overflow() {
        assert_eq!(WeightedIndex::new([2, usize::MAX]), Err(Error::Overflow));
    }
}