Commit da591fe6 authored by Yedidya Feldblum's avatar Yedidya Feldblum Committed by Facebook Github Bot

Simplify the StateSize helper in Random

Summary:
[Folly] Simplify the `StateSize` helper in `Random`.

* Using member type aliases rather than class constants means we can remove definitions.
* Partially specializing over all RNG types with `state_size` class constants means we can remove the `mersenne_twister` specializations, which have many template parameters and are a pain.

Reviewed By: Orvid

Differential Revision: D5525144

fbshipit-source-id: bc27f112ed0d9b55befe9dabe08c4d345a402435
parent 2562ef37
...@@ -18,8 +18,6 @@ ...@@ -18,8 +18,6 @@
#error This file may only be included from folly/Random.h #error This file may only be included from folly/Random.h
#endif #endif
#include <array>
namespace folly { namespace folly {
namespace detail { namespace detail {
...@@ -29,82 +27,35 @@ namespace detail { ...@@ -29,82 +27,35 @@ namespace detail {
// For some (mersenne_twister_engine), this is exported as a state_size static // For some (mersenne_twister_engine), this is exported as a state_size static
// data member; for others, the standard shows formulas. // data member; for others, the standard shows formulas.
template <class RNG> struct StateSize { template <class RNG, typename = void>
struct StateSize {
// A sane default. // A sane default.
static constexpr size_t value = 512; using type = std::integral_constant<size_t, 512>;
}; };
template <class RNG> template <class RNG>
constexpr size_t StateSize<RNG>::value; struct StateSize<RNG, void_t<decltype(RNG::state_size)>> {
using type = std::integral_constant<size_t, RNG::state_size>;
};
template <class UIntType, UIntType a, UIntType c, UIntType m> template <class UIntType, UIntType a, UIntType c, UIntType m>
struct StateSize<std::linear_congruential_engine<UIntType, a, c, m>> { struct StateSize<std::linear_congruential_engine<UIntType, a, c, m>> {
// From the standard [rand.eng.lcong], this is ceil(log2(m) / 32) + 3, // From the standard [rand.eng.lcong], this is ceil(log2(m) / 32) + 3,
// which is the same as ceil(ceil(log2(m) / 32) + 3, and // which is the same as ceil(ceil(log2(m) / 32) + 3, and
// ceil(log2(m)) <= std::numeric_limits<UIntType>::digits // ceil(log2(m)) <= std::numeric_limits<UIntType>::digits
static constexpr size_t value = using type = std::integral_constant<
(std::numeric_limits<UIntType>::digits + 31) / 32 + 3; size_t,
}; (std::numeric_limits<UIntType>::digits + 31) / 32 + 3>;
template <class UIntType, UIntType a, UIntType c, UIntType m>
constexpr size_t
StateSize<std::linear_congruential_engine<UIntType, a, c, m>>::value;
template <class UIntType, size_t w, size_t n, size_t m, size_t r,
UIntType a, size_t u, UIntType d, size_t s,
UIntType b, size_t t,
UIntType c, size_t l, UIntType f>
struct StateSize<std::mersenne_twister_engine<UIntType, w, n, m, r,
a, u, d, s, b, t, c, l, f>> {
static constexpr size_t value =
std::mersenne_twister_engine<UIntType, w, n, m, r,
a, u, d, s, b, t, c, l, f>::state_size;
};
template <class UIntType, size_t w, size_t n, size_t m, size_t r,
UIntType a, size_t u, UIntType d, size_t s,
UIntType b, size_t t,
UIntType c, size_t l, UIntType f>
constexpr size_t
StateSize<std::mersenne_twister_engine<UIntType, w, n, m, r,
a, u, d, s, b, t, c, l, f>>::value;
#if FOLLY_HAVE_EXTRANDOM_SFMT19937
template <class UIntType, size_t m, size_t pos1, size_t sl1, size_t sl2,
size_t sr1, size_t sr2, uint32_t msk1, uint32_t msk2, uint32_t msk3,
uint32_t msk4, uint32_t parity1, uint32_t parity2, uint32_t parity3,
uint32_t parity4>
struct StateSize<__gnu_cxx::simd_fast_mersenne_twister_engine<
UIntType, m, pos1, sl1, sl2, sr1, sr2, msk1, msk2, msk3, msk4,
parity1, parity2, parity3, parity4>> {
static constexpr size_t value =
__gnu_cxx::simd_fast_mersenne_twister_engine<
UIntType, m, pos1, sl1, sl2, sr1, sr2,
msk1, msk2, msk3, msk4,
parity1, parity2, parity3, parity4>::state_size;
}; };
template <class UIntType, size_t m, size_t pos1, size_t sl1, size_t sl2,
size_t sr1, size_t sr2, uint32_t msk1, uint32_t msk2, uint32_t msk3,
uint32_t msk4, uint32_t parity1, uint32_t parity2, uint32_t parity3,
uint32_t parity4>
constexpr size_t
StateSize<__gnu_cxx::simd_fast_mersenne_twister_engine<
UIntType, m, pos1, sl1, sl2, sr1, sr2, msk1, msk2, msk3, msk4,
parity1, parity2, parity3, parity4>>::value;
#endif
template <class UIntType, size_t w, size_t s, size_t r> template <class UIntType, size_t w, size_t s, size_t r>
struct StateSize<std::subtract_with_carry_engine<UIntType, w, s, r>> { struct StateSize<std::subtract_with_carry_engine<UIntType, w, s, r>> {
// [rand.eng.sub]: r * ceil(w / 32) // [rand.eng.sub]: r * ceil(w / 32)
static constexpr size_t value = r * ((w + 31) / 32); using type = std::integral_constant<size_t, r*((w + 31) / 32)>;
}; };
template <class UIntType, size_t w, size_t s, size_t r> template <typename RNG>
constexpr size_t using StateSizeT = _t<StateSize<RNG>>;
StateSize<std::subtract_with_carry_engine<UIntType, w, s, r>>::value;
template <class RNG> template <class RNG>
struct SeedData { struct SeedData {
...@@ -112,7 +63,7 @@ struct SeedData { ...@@ -112,7 +63,7 @@ struct SeedData {
Random::secureRandom(seedData.data(), seedData.size() * sizeof(uint32_t)); Random::secureRandom(seedData.data(), seedData.size() * sizeof(uint32_t));
} }
static constexpr size_t stateSize = StateSize<RNG>::value; static constexpr size_t stateSize = StateSizeT<RNG>::value;
std::array<uint32_t, stateSize> seedData; std::array<uint32_t, stateSize> seedData;
}; };
......
...@@ -17,11 +17,13 @@ ...@@ -17,11 +17,13 @@
#pragma once #pragma once
#define FOLLY_RANDOM_H_ #define FOLLY_RANDOM_H_
#include <array>
#include <cstdint> #include <cstdint>
#include <random> #include <random>
#include <type_traits> #include <type_traits>
#include <folly/Portability.h> #include <folly/Portability.h>
#include <folly/Traits.h>
#if FOLLY_HAVE_EXTRANDOM_SFMT19937 #if FOLLY_HAVE_EXTRANDOM_SFMT19937
#include <ext/random> #include <ext/random>
......
...@@ -32,13 +32,13 @@ TEST(Random, StateSize) { ...@@ -32,13 +32,13 @@ TEST(Random, StateSize) {
using namespace folly::detail; using namespace folly::detail;
// uint_fast32_t is uint64_t on x86_64, w00t // uint_fast32_t is uint64_t on x86_64, w00t
EXPECT_EQ(sizeof(uint_fast32_t) / 4 + 3, EXPECT_EQ(
StateSize<std::minstd_rand0>::value); sizeof(uint_fast32_t) / 4 + 3, StateSizeT<std::minstd_rand0>::value);
EXPECT_EQ(624, StateSize<std::mt19937>::value); EXPECT_EQ(624, StateSizeT<std::mt19937>::value);
#if FOLLY_HAVE_EXTRANDOM_SFMT19937 #if FOLLY_HAVE_EXTRANDOM_SFMT19937
EXPECT_EQ(624, StateSize<__gnu_cxx::sfmt19937>::value); EXPECT_EQ(624, StateSizeT<__gnu_cxx::sfmt19937>::value);
#endif #endif
EXPECT_EQ(24, StateSize<std::ranlux24_base>::value); EXPECT_EQ(24, StateSizeT<std::ranlux24_base>::value);
} }
TEST(Random, Simple) { TEST(Random, Simple) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment