Commit 015f5dc5 authored by David Goldblatt's avatar David Goldblatt Committed by Facebook Github Bot 9

Correctly deduce RNG type in folly::Random

Summary:Fix a bug in folly where it can't infer any RNG type except the
default-supplied parameter.

Differential Revision: D3163421

fb-gh-sync-id: 23d3963ba19dac93fa3407f3a2dfd1d9aa39ea44
fbshipit-source-id: 23d3963ba19dac93fa3407f3a2dfd1d9aa39ea44
parent caba7148
......@@ -118,15 +118,15 @@ struct SeedData {
} // namespace detail
template <class RNG>
void Random::seed(ValidRNG<RNG>& rng) {
template <class RNG, class /* EnableIf */>
void Random::seed(RNG& rng) {
detail::SeedData<RNG> sd;
std::seed_seq s(std::begin(sd.seedData), std::end(sd.seedData));
rng.seed(s);
}
template <class RNG>
auto Random::create() -> ValidRNG<RNG> {
template <class RNG, class /* EnableIf */>
auto Random::create() -> RNG {
detail::SeedData<RNG> sd;
std::seed_seq s(std::begin(sd.seedData), std::end(sd.seedData));
return RNG(s);
......
......@@ -76,10 +76,10 @@ class ThreadLocalPRNG {
class Random {
private:
template<class RNG>
template <class RNG>
using ValidRNG = typename std::enable_if<
std::is_unsigned<typename std::result_of<RNG&()>::type>::value,
RNG>::type;
std::is_unsigned<typename std::result_of<RNG&()>::type>::value,
RNG>::type;
public:
// Default generator type.
......@@ -115,8 +115,8 @@ class Random {
* to create a RNG with a good seed in production, and seed it yourself
* in test.
*/
template <class RNG = DefaultGenerator>
static void seed(ValidRNG<RNG>& rng);
template <class RNG = DefaultGenerator, class /* EnableIf */ = ValidRNG<RNG>>
static void seed(RNG& rng);
/**
* Create a new RNG, seeded with a good seed.
......@@ -126,14 +126,22 @@ class Random {
* to create a RNG with a good seed in production, and seed it yourself
* in test.
*/
template <class RNG = DefaultGenerator>
static ValidRNG<RNG> create();
template <class RNG = DefaultGenerator, class /* EnableIf */ = ValidRNG<RNG>>
static RNG create();
/**
* Returns a random uint32_t
*/
template<class RNG = ThreadLocalPRNG>
static uint32_t rand32(ValidRNG<RNG> rng = RNG()) {
static uint32_t rand32() {
ThreadLocalPRNG prng;
return rand32(prng);
}
/**
* Returns a random uint32_t given a specific RNG
*/
template <class RNG, class /* EnableIf */ = ValidRNG<RNG>>
static uint32_t rand32(RNG rng) {
uint32_t r = rng.operator()();
return r;
}
......@@ -141,8 +149,17 @@ class Random {
/**
* Returns a random uint32_t in [0, max). If max == 0, returns 0.
*/
template<class RNG = ThreadLocalPRNG>
static uint32_t rand32(uint32_t max, ValidRNG<RNG> rng = RNG()) {
static uint32_t rand32(uint32_t max) {
ThreadLocalPRNG prng;
return rand32(max, prng);
}
/**
* Returns a random uint32_t in [0, max) given a specific RNG.
* If max == 0, returns 0.
*/
template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
static uint32_t rand32(uint32_t max, RNG rng = RNG()) {
if (max == 0) {
return 0;
}
......@@ -153,10 +170,8 @@ class Random {
/**
* Returns a random uint32_t in [min, max). If min == max, returns 0.
*/
template<class RNG = ThreadLocalPRNG>
static uint32_t rand32(uint32_t min,
uint32_t max,
ValidRNG<RNG> rng = RNG()) {
template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
static uint32_t rand32(uint32_t min, uint32_t max, RNG rng = RNG()) {
if (min == max) {
return 0;
}
......@@ -167,16 +182,16 @@ class Random {
/**
* Returns a random uint64_t
*/
template<class RNG = ThreadLocalPRNG>
static uint64_t rand64(ValidRNG<RNG> rng = RNG()) {
template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
static uint64_t rand64(RNG rng = RNG()) {
return ((uint64_t) rng() << 32) | rng();
}
/**
* Returns a random uint64_t in [0, max). If max == 0, returns 0.
*/
template<class RNG = ThreadLocalPRNG>
static uint64_t rand64(uint64_t max, ValidRNG<RNG> rng = RNG()) {
template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
static uint64_t rand64(uint64_t max, RNG rng = RNG()) {
if (max == 0) {
return 0;
}
......@@ -187,10 +202,8 @@ class Random {
/**
* Returns a random uint64_t in [min, max). If min == max, returns 0.
*/
template<class RNG = ThreadLocalPRNG>
static uint64_t rand64(uint64_t min,
uint64_t max,
ValidRNG<RNG> rng = RNG()) {
template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
static uint64_t rand64(uint64_t min, uint64_t max, RNG rng = RNG()) {
if (min == max) {
return 0;
}
......@@ -201,7 +214,7 @@ class Random {
/**
* Returns true 1/n of the time. If n == 0, always returns false
*/
template<class RNG = ThreadLocalPRNG>
template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
static bool oneIn(uint32_t n, ValidRNG<RNG> rng = RNG()) {
if (n == 0) {
return false;
......@@ -213,8 +226,8 @@ class Random {
/**
* Returns a double in [0, 1)
*/
template<class RNG = ThreadLocalPRNG>
static double randDouble01(ValidRNG<RNG> rng = RNG()) {
template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
static double randDouble01(RNG rng = RNG()) {
return std::generate_canonical<double, std::numeric_limits<double>::digits>
(rng);
}
......@@ -222,8 +235,8 @@ class Random {
/**
* Returns a double in [min, max), if min == max, returns 0.
*/
template<class RNG = ThreadLocalPRNG>
static double randDouble(double min, double max, ValidRNG<RNG> rng = RNG()) {
template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
static double randDouble(double min, double max, RNG rng = RNG()) {
if (std::fabs(max - min) < std::numeric_limits<double>::epsilon()) {
return 0;
}
......
......@@ -47,6 +47,33 @@ TEST(Random, Simple) {
}
}
TEST(Random, FixedSeed) {
// clang-format off
struct ConstantRNG {
typedef uint32_t result_type;
result_type operator()() {
return 4; // chosen by fair dice roll.
// guaranteed to be random.
}
result_type min() {
return std::numeric_limits<result_type>::min();
}
result_type max() {
return std::numeric_limits<result_type>::max();
}
};
// clang-format on
ConstantRNG gen;
// Loop to make sure it really is constant.
for (int i = 0; i < 1024; ++i) {
auto result = Random::rand32(10, gen);
// TODO: This is a little bit brittle; standard library changes could break
// it, if it starts implementing distribution types differently.
EXPECT_EQ(0, result);
}
}
TEST(Random, MultiThreaded) {
const int n = 100;
std::vector<uint32_t> seeds(n);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment