Commit cea6de83 authored by Xiao Shi's avatar Xiao Shi Committed by Facebook GitHub Bot

heterogeneous mutations for ConcurrentHashMap

Summary:
This diff adds heterogeneous mutation support to ConcurrentHashMap.
The gating of the allowed key types are done in `EnableHeterogeneousInsert`.

`CHM::insert()` does not destructure `pair` arguments and `CHM::emplace()` does
not yet attempt to identify a usable key, leaving those as future follow-ups.

Reviewed By: yfeldblum

Differential Revision: D25255528

fbshipit-source-id: e056db05c96d3bd29c8cbce562ecd2221884cd5f
parent 47a169a1
...@@ -169,6 +169,27 @@ class ConcurrentHashMap { ...@@ -169,6 +169,27 @@ class ConcurrentHashMap {
typedef KeyEqual key_equal; typedef KeyEqual key_equal;
typedef ConstIterator const_iterator; typedef ConstIterator const_iterator;
private:
template <typename K, typename T>
using EnableHeterogeneousInsert = std::enable_if_t<
::folly::detail::
EligibleForHeterogeneousInsert<KeyType, HashFn, KeyEqual, K>::value,
T>;
template <typename K>
using IsIter = std::is_same<ConstIterator, remove_cvref_t<K>>;
template <typename K, typename T>
using EnableHeterogeneousErase = std::enable_if_t<
::folly::detail::EligibleForHeterogeneousFind<
KeyType,
HashFn,
KeyEqual,
std::conditional_t<IsIter<K>::value, KeyType, K>>::value &&
!IsIter<K>::value,
T>;
public:
/* /*
* Construct a ConcurrentHashMap with 1 << ShardBits shards, size * Construct a ConcurrentHashMap with 1 << ShardBits shards, size
* and max_size given. Both size and max_size will be rounded up to * and max_size given. Both size and max_size will be rounded up to
...@@ -260,13 +281,12 @@ class ConcurrentHashMap { ...@@ -260,13 +281,12 @@ class ConcurrentHashMap {
std::pair<ConstIterator, bool> insert( std::pair<ConstIterator, bool> insert(
std::pair<key_type, mapped_type>&& foo) { std::pair<key_type, mapped_type>&& foo) {
auto segment = pickSegment(foo.first); return insertImpl(std::move(foo));
std::pair<ConstIterator, bool> res( }
std::piecewise_construct,
std::forward_as_tuple(this, segment), template <typename Key, EnableHeterogeneousInsert<Key, int> = 0>
std::forward_as_tuple(false)); std::pair<ConstIterator, bool> insert(std::pair<Key, mapped_type>&& foo) {
res.second = ensureSegment(segment)->insert(res.first.it_, std::move(foo)); return insertImpl(std::move(foo));
return res;
} }
template <typename Key, typename Value> template <typename Key, typename Value>
...@@ -375,6 +395,12 @@ class ConcurrentHashMap { ...@@ -375,6 +395,12 @@ class ConcurrentHashMap {
return item.first->second; return item.first->second;
} }
template <typename Key, EnableHeterogeneousInsert<Key, int> = 0>
const ValueType operator[](const Key& key) {
auto item = insert(key, ValueType());
return item.first->second;
}
const ValueType at(const KeyType& key) const { return atImpl(key); } const ValueType at(const KeyType& key) const { return atImpl(key); }
template <typename K, EnableHeterogeneousFind<K, int> = 0> template <typename K, EnableHeterogeneousFind<K, int> = 0>
...@@ -384,14 +410,11 @@ class ConcurrentHashMap { ...@@ -384,14 +410,11 @@ class ConcurrentHashMap {
// TODO update assign interface, operator[], at // TODO update assign interface, operator[], at
size_type erase(const key_type& k) { size_type erase(const key_type& k) { return eraseImpl(k); }
auto segment = pickSegment(k);
auto seg = segments_[segment].load(std::memory_order_acquire); template <typename K, EnableHeterogeneousErase<K, int> = 0>
if (!seg) { size_type erase(const K& k) {
return 0; return eraseImpl(k);
} else {
return seg->erase(k);
}
} }
// Calls the hash function, and therefore may throw. // Calls the hash function, and therefore may throw.
...@@ -409,15 +432,24 @@ class ConcurrentHashMap { ...@@ -409,15 +432,24 @@ class ConcurrentHashMap {
k, [&expected](const ValueType& v) { return v == expected; }); k, [&expected](const ValueType& v) { return v == expected; });
} }
template <typename K, EnableHeterogeneousErase<K, int> = 0>
size_type erase_if_equal(const K& k, const ValueType& expected) {
return erase_key_if(
k, [&expected](const ValueType& v) { return v == expected; });
}
// Erase if predicate evaluates to true on the existing value // Erase if predicate evaluates to true on the existing value
template <typename Predicate> template <typename Predicate>
size_type erase_key_if(const key_type& k, Predicate&& predicate) { size_type erase_key_if(const key_type& k, Predicate&& predicate) {
auto segment = pickSegment(k); return eraseKeyIfImpl(k, std::forward<Predicate>(predicate));
auto seg = segments_[segment].load(std::memory_order_acquire);
if (!seg) {
return 0;
} }
return seg->erase_key_if(k, std::forward<Predicate>(predicate));
template <
typename K,
typename Predicate,
EnableHeterogeneousErase<K, int> = 0>
size_type erase_key_if(const K& k, Predicate&& predicate) {
return eraseKeyIfImpl(k, std::forward<Predicate>(predicate));
} }
// NOT noexcept, initializes new shard segments vs. // NOT noexcept, initializes new shard segments vs.
...@@ -561,6 +593,38 @@ class ConcurrentHashMap { ...@@ -561,6 +593,38 @@ class ConcurrentHashMap {
return item->second; return item->second;
} }
template <typename Key>
std::pair<ConstIterator, bool> insertImpl(std::pair<Key, mapped_type>&& foo) {
auto segment = pickSegment(foo.first);
std::pair<ConstIterator, bool> res(
std::piecewise_construct,
std::forward_as_tuple(this, segment),
std::forward_as_tuple(false));
res.second = ensureSegment(segment)->insert(res.first.it_, std::move(foo));
return res;
}
template <typename K>
size_type eraseImpl(const K& k) {
auto segment = pickSegment(k);
auto seg = segments_[segment].load(std::memory_order_acquire);
if (!seg) {
return 0;
} else {
return seg->erase(k);
}
}
template <typename K, typename Predicate>
size_type eraseKeyIfImpl(const K& k, Predicate&& predicate) {
auto segment = pickSegment(k);
auto seg = segments_[segment].load(std::memory_order_acquire);
if (!seg) {
return 0;
}
return seg->erase_key_if(k, std::forward<Predicate>(predicate));
}
template <typename K> template <typename K>
uint64_t pickSegment(const K& k) const { uint64_t pickSegment(const K& k) const {
auto h = HashFn()(k); auto h = HashFn()(k);
......
...@@ -254,10 +254,10 @@ class alignas(64) BucketTable { ...@@ -254,10 +254,10 @@ class alignas(64) BucketTable {
bool empty() { return size() == 0; } bool empty() { return size() == 0; }
template <typename MatchFunc, typename... Args> template <typename MatchFunc, typename K, typename... Args>
bool insert( bool insert(
Iterator& it, Iterator& it,
const KeyType& k, const K& k,
InsertType type, InsertType type,
MatchFunc match, MatchFunc match,
hazptr_obj_cohort<Atom>* cohort, hazptr_obj_cohort<Atom>* cohort,
...@@ -266,10 +266,10 @@ class alignas(64) BucketTable { ...@@ -266,10 +266,10 @@ class alignas(64) BucketTable {
it, k, type, match, nullptr, cohort, std::forward<Args>(args)...); it, k, type, match, nullptr, cohort, std::forward<Args>(args)...);
} }
template <typename MatchFunc, typename... Args> template <typename MatchFunc, typename K, typename... Args>
bool insert( bool insert(
Iterator& it, Iterator& it,
const KeyType& k, const K& k,
InsertType type, InsertType type,
MatchFunc match, MatchFunc match,
Node* cur, Node* cur,
...@@ -357,8 +357,8 @@ class alignas(64) BucketTable { ...@@ -357,8 +357,8 @@ class alignas(64) BucketTable {
return false; return false;
} }
template <typename MatchFunc> template <typename K, typename MatchFunc>
std::size_t erase(const KeyType& key, Iterator* iter, MatchFunc match) { std::size_t erase(const K& key, Iterator* iter, MatchFunc match) {
Node* node{nullptr}; Node* node{nullptr};
auto h = HashFn()(key); auto h = HashFn()(key);
{ {
...@@ -605,10 +605,10 @@ class alignas(64) BucketTable { ...@@ -605,10 +605,10 @@ class alignas(64) BucketTable {
DCHECK(buckets); DCHECK(buckets);
} }
template <typename MatchFunc, typename... Args> template <typename MatchFunc, typename K, typename... Args>
bool doInsert( bool doInsert(
Iterator& it, Iterator& it,
const KeyType& k, const K& k,
InsertType type, InsertType type,
MatchFunc match, MatchFunc match,
Node* cur, Node* cur,
...@@ -1145,10 +1145,10 @@ class alignas(64) SIMDTable { ...@@ -1145,10 +1145,10 @@ class alignas(64) SIMDTable {
bool empty() { return size() == 0; } bool empty() { return size() == 0; }
template <typename MatchFunc, typename... Args> template <typename MatchFunc, typename K, typename... Args>
bool insert( bool insert(
Iterator& it, Iterator& it,
const KeyType& k, const K& k,
InsertType type, InsertType type,
MatchFunc match, MatchFunc match,
hazptr_obj_cohort<Atom>* cohort, hazptr_obj_cohort<Atom>* cohort,
...@@ -1198,10 +1198,10 @@ class alignas(64) SIMDTable { ...@@ -1198,10 +1198,10 @@ class alignas(64) SIMDTable {
return true; return true;
} }
template <typename MatchFunc, typename... Args> template <typename MatchFunc, typename K, typename... Args>
bool insert( bool insert(
Iterator& it, Iterator& it,
const KeyType& k, const K& k,
InsertType type, InsertType type,
MatchFunc match, MatchFunc match,
Node* cur, Node* cur,
...@@ -1287,8 +1287,8 @@ class alignas(64) SIMDTable { ...@@ -1287,8 +1287,8 @@ class alignas(64) SIMDTable {
return false; return false;
} }
template <typename MatchFunc> template <typename K, typename MatchFunc>
std::size_t erase(const KeyType& key, Iterator* iter, MatchFunc match) { std::size_t erase(const K& key, Iterator* iter, MatchFunc match) {
auto h = HashFn()(key); auto h = HashFn()(key);
const HashPair hp = splitHash(h); const HashPair hp = splitHash(h);
...@@ -1391,8 +1391,9 @@ class alignas(64) SIMDTable { ...@@ -1391,8 +1391,9 @@ class alignas(64) SIMDTable {
static size_t probeDelta(HashPair hp) { return 2 * hp.second + 1; } static size_t probeDelta(HashPair hp) { return 2 * hp.second + 1; }
// Must hold lock. // Must hold lock.
template <typename K>
Node* find_internal( Node* find_internal(
const KeyType& k, const K& k,
const HashPair& hp, const HashPair& hp,
Chunks* chunks, Chunks* chunks,
size_t ccount, size_t ccount,
...@@ -1421,10 +1422,10 @@ class alignas(64) SIMDTable { ...@@ -1421,10 +1422,10 @@ class alignas(64) SIMDTable {
return nullptr; return nullptr;
} }
template <typename MatchFunc, typename... Args> template <typename MatchFunc, typename K, typename... Args>
bool prepare_insert( bool prepare_insert(
Iterator& it, Iterator& it,
const KeyType& k, const K& k,
InsertType type, InsertType type,
MatchFunc match, MatchFunc match,
hazptr_obj_cohort<Atom>* cohort, hazptr_obj_cohort<Atom>* cohort,
...@@ -1649,7 +1650,8 @@ class alignas(64) ConcurrentHashMapSegment { ...@@ -1649,7 +1650,8 @@ class alignas(64) ConcurrentHashMapSegment {
bool empty() { return impl_.empty(); } bool empty() { return impl_.empty(); }
bool insert(Iterator& it, std::pair<key_type, mapped_type>&& foo) { template <typename Key>
bool insert(Iterator& it, std::pair<Key, mapped_type>&& foo) {
return insert(it, std::move(foo.first), std::move(foo.second)); return insert(it, std::move(foo.first), std::move(foo.second));
} }
...@@ -1749,10 +1751,10 @@ class alignas(64) ConcurrentHashMapSegment { ...@@ -1749,10 +1751,10 @@ class alignas(64) ConcurrentHashMapSegment {
return res; return res;
} }
template <typename MatchFunc, typename... Args> template <typename MatchFunc, typename K, typename... Args>
bool insert_internal( bool insert_internal(
Iterator& it, Iterator& it,
const KeyType& k, const K& k,
InsertType type, InsertType type,
MatchFunc match, MatchFunc match,
Args&&... args) { Args&&... args) {
...@@ -1760,10 +1762,10 @@ class alignas(64) ConcurrentHashMapSegment { ...@@ -1760,10 +1762,10 @@ class alignas(64) ConcurrentHashMapSegment {
it, k, type, match, cohort_, std::forward<Args>(args)...); it, k, type, match, cohort_, std::forward<Args>(args)...);
} }
template <typename MatchFunc, typename... Args> template <typename MatchFunc, typename K, typename... Args>
bool insert_internal( bool insert_internal(
Iterator& it, Iterator& it,
const KeyType& k, const K& k,
InsertType type, InsertType type,
MatchFunc match, MatchFunc match,
Node* cur) { Node* cur) {
...@@ -1779,18 +1781,18 @@ class alignas(64) ConcurrentHashMapSegment { ...@@ -1779,18 +1781,18 @@ class alignas(64) ConcurrentHashMapSegment {
} }
// Listed separately because we need a prev pointer. // Listed separately because we need a prev pointer.
size_type erase(const key_type& key) { template <typename K>
size_type erase(const K& key) {
return erase_internal(key, nullptr, [](const ValueType&) { return true; }); return erase_internal(key, nullptr, [](const ValueType&) { return true; });
} }
template <typename Predicate> template <typename K, typename Predicate>
size_type erase_key_if(const key_type& key, Predicate&& predicate) { size_type erase_key_if(const K& key, Predicate&& predicate) {
return erase_internal(key, nullptr, std::forward<Predicate>(predicate)); return erase_internal(key, nullptr, std::forward<Predicate>(predicate));
} }
template <typename MatchFunc> template <typename K, typename MatchFunc>
size_type size_type erase_internal(const K& key, Iterator* iter, MatchFunc match) {
erase_internal(const key_type& key, Iterator* iter, MatchFunc match) {
return impl_.erase(key, iter, match); return impl_.erase(key, iter, match);
} }
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <thread> #include <thread>
#include <folly/Traits.h> #include <folly/Traits.h>
#include <folly/container/test/TrackingTypes.h>
#include <folly/hash/Hash.h> #include <folly/hash/Hash.h>
#include <folly/portability/GFlags.h> #include <folly/portability/GFlags.h>
#include <folly/portability/GTest.h> #include <folly/portability/GTest.h>
...@@ -834,7 +835,10 @@ TYPED_TEST_P(ConcurrentHashMapTest, IteratorLoop) { ...@@ -834,7 +835,10 @@ TYPED_TEST_P(ConcurrentHashMapTest, IteratorLoop) {
namespace { namespace {
template <typename T, typename Arg> template <typename T, typename Arg>
using detector_find = decltype(std::declval<T>().find(std::declval<Arg>())); using detector_find = decltype(std::declval<T>().find(std::declval<Arg>()));
}
template <typename T, typename Arg>
using detector_erase = decltype(std::declval<T>().erase(std::declval<Arg>()));
} // namespace
TYPED_TEST_P(ConcurrentHashMapTest, HeterogeneousLookup) { TYPED_TEST_P(ConcurrentHashMapTest, HeterogeneousLookup) {
using Hasher = folly::transparent<folly::hasher<folly::StringPiece>>; using Hasher = folly::transparent<folly::hasher<folly::StringPiece>>;
...@@ -868,6 +872,70 @@ TYPED_TEST_P(ConcurrentHashMapTest, HeterogeneousLookup) { ...@@ -868,6 +872,70 @@ TYPED_TEST_P(ConcurrentHashMapTest, HeterogeneousLookup) {
checks(folly::as_const(map)); checks(folly::as_const(map));
} }
TYPED_TEST_P(ConcurrentHashMapTest, HeterogeneousInsert) {
using Hasher = folly::transparent<folly::hasher<folly::StringPiece>>;
using KeyEqual = folly::transparent<std::equal_to<folly::StringPiece>>;
using P = std::pair<StringPiece, std::string>;
using CP = std::pair<const StringPiece, std::string>;
ConcurrentHashMap<std::string, std::string, Hasher, KeyEqual> map;
P p{"foo", "hello"};
StringPiece foo{"foo"};
StringPiece bar{"bar"};
map.insert("foo", "hello");
map.insert(foo, "hello");
// TODO(T31574848): the list-initialization below does not work on libstdc++
// versions (e.g., GCC < 6) with no implementation of N4387 ("perfect
// initialization" for pairs and tuples).
// StringPiece sp{"foo"};
// map.insert({sp, "hello"});
map.insert({"foo", "hello"});
map.insert(P("foo", "hello"));
map.insert(CP("foo", "hello"));
map.insert(std::move(p));
map.insert_or_assign("foo", "hello");
map.insert_or_assign(StringPiece{"foo"}, "hello");
map.erase(StringPiece{"foo"});
map.erase(foo);
map.erase("");
EXPECT_TRUE(map.empty());
map.insert("foo", "hello");
map.insert("bar", "world");
map.erase_if_equal(StringPiece{"foo"}, "hello");
map.erase_key_if(bar, [](const std::string& s) { return s == "world"; });
map.erase("");
EXPECT_TRUE(map.empty());
map.insert("foo", "baz");
EXPECT_TRUE(map.assign(foo, "hello2"));
EXPECT_TRUE(map.assign_if_equal("foo", "hello2", "hello"));
EXPECT_EQ(map[foo], "hello");
auto it = map.find(foo);
map.erase(it);
EXPECT_TRUE(map.empty());
map.try_emplace(foo);
map.try_emplace(foo, "hello");
map.try_emplace(StringPiece{"foo"}, "hello");
map.try_emplace(foo, "hello");
map.try_emplace(foo);
map.try_emplace("foo");
map.try_emplace("foo", "hello");
map.try_emplace("bar", /* count */ 20, 'x');
EXPECT_EQ(map[bar], std::string(20, 'x'));
map.emplace(StringPiece{"foo"}, "hello");
map.emplace("foo", "hello");
// invocability checks
static_assert(
!is_detected_v<detector_erase, decltype(map), int>,
"there shouldn't be an erase() overload for this string map with an int param");
}
REGISTER_TYPED_TEST_CASE_P( REGISTER_TYPED_TEST_CASE_P(
ConcurrentHashMapTest, ConcurrentHashMapTest,
MapTest, MapTest,
...@@ -903,7 +971,8 @@ REGISTER_TYPED_TEST_CASE_P( ...@@ -903,7 +971,8 @@ REGISTER_TYPED_TEST_CASE_P(
insertStressTest, insertStressTest,
IteratorMove, IteratorMove,
IteratorLoop, IteratorLoop,
HeterogeneousLookup); HeterogeneousLookup,
HeterogeneousInsert);
using folly::detail::concurrenthashmap::bucket::BucketTable; using folly::detail::concurrenthashmap::bucket::BucketTable;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment