rand/seq/coin_flipper.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
// Copyright 2018-2023 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
use crate::RngCore;
pub(crate) struct CoinFlipper<R: RngCore> {
pub rng: R,
chunk: u32, // TODO(opt): this should depend on RNG word size
chunk_remaining: u32,
}
impl<R: RngCore> CoinFlipper<R> {
pub fn new(rng: R) -> Self {
Self {
rng,
chunk: 0,
chunk_remaining: 0,
}
}
#[inline]
/// Returns true with a probability of 1 / d
/// Uses an expected two bits of randomness
/// Panics if d == 0
pub fn random_ratio_one_over(&mut self, d: usize) -> bool {
debug_assert_ne!(d, 0);
// This uses the same logic as `random_ratio` but is optimized for the case that
// the starting numerator is one (which it always is for `Sequence::Choose()`)
// In this case (but not `random_ratio`), this way of calculating c is always accurate
let c = (usize::BITS - 1 - d.leading_zeros()).min(32);
if self.flip_c_heads(c) {
let numerator = 1 << c;
self.random_ratio(numerator, d)
} else {
false
}
}
#[inline]
/// Returns true with a probability of n / d
/// Uses an expected two bits of randomness
fn random_ratio(&mut self, mut n: usize, d: usize) -> bool {
// Explanation:
// We are trying to return true with a probability of n / d
// If n >= d, we can just return true
// Otherwise there are two possibilities 2n < d and 2n >= d
// In either case we flip a coin.
// If 2n < d
// If it comes up tails, return false
// If it comes up heads, double n and start again
// This is fair because (0.5 * 0) + (0.5 * 2n / d) = n / d and 2n is less than d
// (if 2n was greater than d we would effectively round it down to 1
// by returning true)
// If 2n >= d
// If it comes up tails, set n to 2n - d and start again
// If it comes up heads, return true
// This is fair because (0.5 * 1) + (0.5 * (2n - d) / d) = n / d
// Note that if 2n = d and the coin comes up tails, n will be set to 0
// before restarting which is equivalent to returning false.
// As a performance optimization we can flip multiple coins at once
// This is efficient because we can use the `lzcnt` intrinsic
// We can check up to 32 flips at once but we only receive one bit of information
// - all heads or at least one tail.
// Let c be the number of coins to flip. 1 <= c <= 32
// If 2n < d, n * 2^c < d
// If the result is all heads, then set n to n * 2^c
// If there was at least one tail, return false
// If 2n >= d, the order of results matters so we flip one coin at a time so c = 1
// Ideally, c will be as high as possible within these constraints
while n < d {
// Find a good value for c by counting leading zeros
// This will either give the highest possible c, or 1 less than that
let c = n
.leading_zeros()
.saturating_sub(d.leading_zeros() + 1)
.clamp(1, 32);
if self.flip_c_heads(c) {
// All heads
// Set n to n * 2^c
// If 2n >= d, the while loop will exit and we will return `true`
// If n * 2^c > `usize::MAX` we always return `true` anyway
n = n.saturating_mul(2_usize.pow(c));
} else {
// At least one tail
if c == 1 {
// Calculate 2n - d.
// We need to use wrapping as 2n might be greater than `usize::MAX`
let next_n = n.wrapping_add(n).wrapping_sub(d);
if next_n == 0 || next_n > n {
// This will happen if 2n < d
return false;
}
n = next_n;
} else {
// c > 1 so 2n < d so we can return false
return false;
}
}
}
true
}
/// If the next `c` bits of randomness all represent heads, consume them, return true
/// Otherwise return false and consume the number of heads plus one.
/// Generates new bits of randomness when necessary (in 32 bit chunks)
/// Has a 1 in 2 to the `c` chance of returning true
/// `c` must be less than or equal to 32
fn flip_c_heads(&mut self, mut c: u32) -> bool {
debug_assert!(c <= 32);
// Note that zeros on the left of the chunk represent heads.
// It needs to be this way round because zeros are filled in when left shifting
loop {
let zeros = self.chunk.leading_zeros();
if zeros < c {
// The happy path - we found a 1 and can return false
// Note that because a 1 bit was detected,
// We cannot have run out of random bits so we don't need to check
// First consume all of the bits read
// Using shl seems to give worse performance for size-hinted iterators
self.chunk = self.chunk.wrapping_shl(zeros + 1);
self.chunk_remaining = self.chunk_remaining.saturating_sub(zeros + 1);
return false;
} else {
// The number of zeros is larger than `c`
// There are two possibilities
if let Some(new_remaining) = self.chunk_remaining.checked_sub(c) {
// Those zeroes were all part of our random chunk,
// throw away `c` bits of randomness and return true
self.chunk_remaining = new_remaining;
self.chunk <<= c;
return true;
} else {
// Some of those zeroes were part of the random chunk
// and some were part of the space behind it
// We need to take into account only the zeroes that were random
c -= self.chunk_remaining;
// Generate a new chunk
self.chunk = self.rng.next_u32();
self.chunk_remaining = 32;
// Go back to start of loop
}
}
}
}
}