Commit da591fe6 authored by Yedidya Feldblum's avatar Yedidya Feldblum Committed by Facebook Github Bot

Simplify the StateSize helper in Random

Summary:
[Folly] Simplify the `StateSize` helper in `Random`.

* Using member type aliases rather than class constants means we can remove definitions.
* Partially specializing over all RNG types with `state_size` class constants means we can remove the `mersenne_twister` specializations, which have many template parameters and are a pain.

Reviewed By: Orvid

Differential Revision: D5525144

fbshipit-source-id: bc27f112ed0d9b55befe9dabe08c4d345a402435
parent 2562ef37
......@@ -18,8 +18,6 @@
#error This file may only be included from folly/Random.h
#endif
#include <array>
namespace folly {
namespace detail {
......@@ -29,82 +27,35 @@ namespace detail {
// For some (mersenne_twister_engine), this is exported as a state_size static
// data member; for others, the standard shows formulas.
template <class RNG> struct StateSize {
template <class RNG, typename = void>
struct StateSize {
// A sane default.
static constexpr size_t value = 512;
using type = std::integral_constant<size_t, 512>;
};
template <class RNG>
constexpr size_t StateSize<RNG>::value;
struct StateSize<RNG, void_t<decltype(RNG::state_size)>> {
using type = std::integral_constant<size_t, RNG::state_size>;
};
template <class UIntType, UIntType a, UIntType c, UIntType m>
struct StateSize<std::linear_congruential_engine<UIntType, a, c, m>> {
// From the standard [rand.eng.lcong], this is ceil(log2(m) / 32) + 3,
// which is the same as ceil(ceil(log2(m) / 32) + 3, and
// ceil(log2(m)) <= std::numeric_limits<UIntType>::digits
static constexpr size_t value =
(std::numeric_limits<UIntType>::digits + 31) / 32 + 3;
};
template <class UIntType, UIntType a, UIntType c, UIntType m>
constexpr size_t
StateSize<std::linear_congruential_engine<UIntType, a, c, m>>::value;
template <class UIntType, size_t w, size_t n, size_t m, size_t r,
UIntType a, size_t u, UIntType d, size_t s,
UIntType b, size_t t,
UIntType c, size_t l, UIntType f>
struct StateSize<std::mersenne_twister_engine<UIntType, w, n, m, r,
a, u, d, s, b, t, c, l, f>> {
static constexpr size_t value =
std::mersenne_twister_engine<UIntType, w, n, m, r,
a, u, d, s, b, t, c, l, f>::state_size;
};
template <class UIntType, size_t w, size_t n, size_t m, size_t r,
UIntType a, size_t u, UIntType d, size_t s,
UIntType b, size_t t,
UIntType c, size_t l, UIntType f>
constexpr size_t
StateSize<std::mersenne_twister_engine<UIntType, w, n, m, r,
a, u, d, s, b, t, c, l, f>>::value;
#if FOLLY_HAVE_EXTRANDOM_SFMT19937
template <class UIntType, size_t m, size_t pos1, size_t sl1, size_t sl2,
size_t sr1, size_t sr2, uint32_t msk1, uint32_t msk2, uint32_t msk3,
uint32_t msk4, uint32_t parity1, uint32_t parity2, uint32_t parity3,
uint32_t parity4>
struct StateSize<__gnu_cxx::simd_fast_mersenne_twister_engine<
UIntType, m, pos1, sl1, sl2, sr1, sr2, msk1, msk2, msk3, msk4,
parity1, parity2, parity3, parity4>> {
static constexpr size_t value =
__gnu_cxx::simd_fast_mersenne_twister_engine<
UIntType, m, pos1, sl1, sl2, sr1, sr2,
msk1, msk2, msk3, msk4,
parity1, parity2, parity3, parity4>::state_size;
using type = std::integral_constant<
size_t,
(std::numeric_limits<UIntType>::digits + 31) / 32 + 3>;
};
template <class UIntType, size_t m, size_t pos1, size_t sl1, size_t sl2,
size_t sr1, size_t sr2, uint32_t msk1, uint32_t msk2, uint32_t msk3,
uint32_t msk4, uint32_t parity1, uint32_t parity2, uint32_t parity3,
uint32_t parity4>
constexpr size_t
StateSize<__gnu_cxx::simd_fast_mersenne_twister_engine<
UIntType, m, pos1, sl1, sl2, sr1, sr2, msk1, msk2, msk3, msk4,
parity1, parity2, parity3, parity4>>::value;
#endif
template <class UIntType, size_t w, size_t s, size_t r>
struct StateSize<std::subtract_with_carry_engine<UIntType, w, s, r>> {
// [rand.eng.sub]: r * ceil(w / 32)
static constexpr size_t value = r * ((w + 31) / 32);
using type = std::integral_constant<size_t, r*((w + 31) / 32)>;
};
template <class UIntType, size_t w, size_t s, size_t r>
constexpr size_t
StateSize<std::subtract_with_carry_engine<UIntType, w, s, r>>::value;
template <typename RNG>
using StateSizeT = _t<StateSize<RNG>>;
template <class RNG>
struct SeedData {
......@@ -112,7 +63,7 @@ struct SeedData {
Random::secureRandom(seedData.data(), seedData.size() * sizeof(uint32_t));
}
static constexpr size_t stateSize = StateSize<RNG>::value;
static constexpr size_t stateSize = StateSizeT<RNG>::value;
std::array<uint32_t, stateSize> seedData;
};
......
......@@ -17,11 +17,13 @@
#pragma once
#define FOLLY_RANDOM_H_
#include <array>
#include <cstdint>
#include <random>
#include <type_traits>
#include <folly/Portability.h>
#include <folly/Traits.h>
#if FOLLY_HAVE_EXTRANDOM_SFMT19937
#include <ext/random>
......
......@@ -32,13 +32,13 @@ TEST(Random, StateSize) {
using namespace folly::detail;
// uint_fast32_t is uint64_t on x86_64, w00t
EXPECT_EQ(sizeof(uint_fast32_t) / 4 + 3,
StateSize<std::minstd_rand0>::value);
EXPECT_EQ(624, StateSize<std::mt19937>::value);
EXPECT_EQ(
sizeof(uint_fast32_t) / 4 + 3, StateSizeT<std::minstd_rand0>::value);
EXPECT_EQ(624, StateSizeT<std::mt19937>::value);
#if FOLLY_HAVE_EXTRANDOM_SFMT19937
EXPECT_EQ(624, StateSize<__gnu_cxx::sfmt19937>::value);
EXPECT_EQ(624, StateSizeT<__gnu_cxx::sfmt19937>::value);
#endif
EXPECT_EQ(24, StateSize<std::ranlux24_base>::value);
EXPECT_EQ(24, StateSizeT<std::ranlux24_base>::value);
}
TEST(Random, Simple) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment