Commit f2987ecd authored by Christopher Small's avatar Christopher Small Committed by Facebook Github Bot

pass RNG by reference so state is updated on each call

Summary: folly::Random was taking the RNG by value (not reference) so it was not updating the RNG's state on each invocation -- so the RNG would not advance to the next value in the sequence.

Reviewed By: yfeldblum, nbronson

Differential Revision: D4362999

fbshipit-source-id: f93fc11911b92e230ac0cc2406151474d15f85af
parent d9817812
......@@ -133,25 +133,22 @@ class Random {
* Returns a random uint32_t
*/
static uint32_t rand32() {
ThreadLocalPRNG prng;
return rand32(prng);
return rand32(ThreadLocalPRNG());
}
/**
* Returns a random uint32_t given a specific RNG
*/
template <class RNG, class /* EnableIf */ = ValidRNG<RNG>>
static uint32_t rand32(RNG rng) {
uint32_t r = rng.operator()();
return r;
static uint32_t rand32(RNG&& rng) {
return rng();
}
/**
* Returns a random uint32_t in [0, max). If max == 0, returns 0.
*/
static uint32_t rand32(uint32_t max) {
ThreadLocalPRNG prng;
return rand32(max, prng);
return rand32(0, max, ThreadLocalPRNG());
}
/**
......@@ -159,84 +156,123 @@ class Random {
* If max == 0, returns 0.
*/
template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
static uint32_t rand32(uint32_t max, RNG rng = RNG()) {
if (max == 0) {
return 0;
}
return std::uniform_int_distribution<uint32_t>(0, max - 1)(rng);
static uint32_t rand32(uint32_t max, RNG&& rng) {
return rand32(0, max, rng);
}
/**
* Returns a random uint32_t in [min, max). If min == max, returns 0.
*/
static uint32_t rand32(uint32_t min, uint32_t max) {
return rand32(min, max, ThreadLocalPRNG());
}
/**
* Returns a random uint32_t in [min, max) given a specific RNG.
* If min == max, returns 0.
*/
template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
static uint32_t rand32(uint32_t min, uint32_t max, RNG rng = RNG()) {
static uint32_t rand32(uint32_t min, uint32_t max, RNG&& rng) {
if (min == max) {
return 0;
}
return std::uniform_int_distribution<uint32_t>(min, max - 1)(rng);
}
/**
* Returns a random uint64_t
*/
static uint64_t rand64() {
return rand64(ThreadLocalPRNG());
}
/**
* Returns a random uint64_t
*/
template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
static uint64_t rand64(RNG rng = RNG()) {
return ((uint64_t) rng() << 32) | rng();
static uint64_t rand64(RNG&& rng) {
return ((uint64_t)rng() << 32) | rng();
}
/**
* Returns a random uint64_t in [0, max). If max == 0, returns 0.
*/
static uint64_t rand64(uint64_t max) {
return rand64(0, max, ThreadLocalPRNG());
}
/**
* Returns a random uint64_t in [0, max). If max == 0, returns 0.
*/
template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
static uint64_t rand64(uint64_t max, RNG rng = RNG()) {
if (max == 0) {
return 0;
}
static uint64_t rand64(uint64_t max, RNG&& rng) {
return rand64(0, max, rng);
}
return std::uniform_int_distribution<uint64_t>(0, max - 1)(rng);
/**
* Returns a random uint64_t in [min, max). If min == max, returns 0.
*/
static uint64_t rand64(uint64_t min, uint64_t max) {
return rand64(min, max, ThreadLocalPRNG());
}
/**
* Returns a random uint64_t in [min, max). If min == max, returns 0.
*/
template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
static uint64_t rand64(uint64_t min, uint64_t max, RNG rng = RNG()) {
static uint64_t rand64(uint64_t min, uint64_t max, RNG&& rng) {
if (min == max) {
return 0;
}
return std::uniform_int_distribution<uint64_t>(min, max - 1)(rng);
}
/**
* Returns true 1/n of the time. If n == 0, always returns false
*/
static bool oneIn(uint32_t n) {
return oneIn(n, ThreadLocalPRNG());
}
/**
* Returns true 1/n of the time. If n == 0, always returns false
*/
template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
static bool oneIn(uint32_t n, ValidRNG<RNG> rng = RNG()) {
static bool oneIn(uint32_t n, RNG&& rng) {
if (n == 0) {
return false;
}
return rand32(0, n, rng) == 0;
}
return rand32(n, rng) == 0;
/**
* Returns a double in [0, 1)
*/
static double randDouble01() {
return randDouble01(ThreadLocalPRNG());
}
/**
* Returns a double in [0, 1)
*/
template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
static double randDouble01(RNG rng = RNG()) {
return std::generate_canonical<double, std::numeric_limits<double>::digits>
(rng);
static double randDouble01(RNG&& rng) {
return std::generate_canonical<double, std::numeric_limits<double>::digits>(
rng);
}
/**
* Returns a double in [min, max), if min == max, returns 0.
*/
static double randDouble(double min, double max) {
return randDouble(min, max, ThreadLocalPRNG());
}
/**
* Returns a double in [min, max), if min == max, returns 0.
*/
template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
static double randDouble(double min, double max, RNG rng = RNG()) {
static double randDouble(double min, double max, RNG&& rng) {
if (std::fabs(max - min) < std::numeric_limits<double>::epsilon()) {
return 0;
}
......
......@@ -22,6 +22,7 @@
#include <thread>
#include <vector>
#include <random>
#include <unordered_set>
#include <folly/portability/GTest.h>
......@@ -94,3 +95,46 @@ TEST(Random, MultiThreaded) {
EXPECT_LT(seeds[i], seeds[i+1]);
}
}
TEST(Random, sanity) {
// edge cases
EXPECT_EQ(folly::Random::rand32(0), 0);
EXPECT_EQ(folly::Random::rand32(12, 12), 0);
EXPECT_EQ(folly::Random::rand64(0), 0);
EXPECT_EQ(folly::Random::rand64(12, 12), 0);
// 32-bit repeatability, uniqueness
constexpr int kTestSize = 1000;
{
std::vector<uint32_t> vals;
folly::Random::DefaultGenerator rng;
rng.seed(0xdeadbeef);
for (int i = 0; i < kTestSize; ++i) {
vals.push_back(folly::Random::rand32(rng));
}
rng.seed(0xdeadbeef);
for (int i = 0; i < kTestSize; ++i) {
EXPECT_EQ(vals[i], folly::Random::rand32(rng));
}
EXPECT_EQ(
vals.size(),
std::unordered_set<uint32_t>(vals.begin(), vals.end()).size());
}
// 64-bit repeatability, uniqueness
{
std::vector<uint64_t> vals;
folly::Random::DefaultGenerator rng;
rng.seed(0xdeadbeef);
for (int i = 0; i < kTestSize; ++i) {
vals.push_back(folly::Random::rand64(rng));
}
rng.seed(0xdeadbeef);
for (int i = 0; i < kTestSize; ++i) {
EXPECT_EQ(vals[i], folly::Random::rand64(rng));
}
EXPECT_EQ(
vals.size(),
std::unordered_set<uint32_t>(vals.begin(), vals.end()).size());
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment