Commit b6a27e14 authored by Sarang Masti's avatar Sarang Masti Committed by Facebook Github Bot

Support custom comparator in EvictingCacheMap

Summary: Allow passing in a custom comparator to compare keys

Reviewed By: yfeldblum, aary

Differential Revision: D7175777

fbshipit-source-id: e1e8d836a908b595a74b172b04ca847a5c5eb435
parent 417035de
...@@ -90,13 +90,23 @@ namespace folly { ...@@ -90,13 +90,23 @@ namespace folly {
* unless evictions of LRU items are triggered by calling prune() by clients * unless evictions of LRU items are triggered by calling prune() by clients
* (using their own eviction criteria). * (using their own eviction criteria).
*/ */
template <class TKey, class TValue, class THash = std::hash<TKey>> template <
class TKey,
class TValue,
class THash = std::hash<TKey>,
class TKeyEqual = std::equal_to<TKey>>
class EvictingCacheMap { class EvictingCacheMap {
private: private:
// typedefs for brevity // typedefs for brevity
struct Node; struct Node;
struct KeyHasher;
struct KeyValueEqual;
typedef boost::intrusive::link_mode<boost::intrusive::safe_link> link_mode; typedef boost::intrusive::link_mode<boost::intrusive::safe_link> link_mode;
typedef boost::intrusive::unordered_set<Node> NodeMap; typedef boost::intrusive::unordered_set<
Node,
boost::intrusive::hash<KeyHasher>,
boost::intrusive::equal<KeyValueEqual>>
NodeMap;
typedef boost::intrusive::list<Node> NodeList; typedef boost::intrusive::list<Node> NodeList;
typedef std::pair<const TKey, TValue> TPair; typedef std::pair<const TKey, TValue> TPair;
...@@ -144,13 +154,19 @@ class EvictingCacheMap { ...@@ -144,13 +154,19 @@ class EvictingCacheMap {
* @param clearSize the number of elements to clear at a time when the * @param clearSize the number of elements to clear at a time when the
* eviction size is reached. * eviction size is reached.
*/ */
explicit EvictingCacheMap(std::size_t maxSize, std::size_t clearSize = 1) explicit EvictingCacheMap(
std::size_t maxSize,
std::size_t clearSize = 1,
const THash& keyHash = THash(),
const TKeyEqual& keyEqual = TKeyEqual())
: nIndexBuckets_(std::max(maxSize / 2, std::size_t(kMinNumIndexBuckets))), : nIndexBuckets_(std::max(maxSize / 2, std::size_t(kMinNumIndexBuckets))),
indexBuckets_(new typename NodeMap::bucket_type[nIndexBuckets_]), indexBuckets_(new typename NodeMap::bucket_type[nIndexBuckets_]),
indexTraits_(indexBuckets_.get(), nIndexBuckets_), indexTraits_(indexBuckets_.get(), nIndexBuckets_),
index_(indexTraits_), keyHash_(keyHash),
keyEqual_(keyEqual),
index_(indexTraits_, keyHash_, keyEqual_),
maxSize_(maxSize), maxSize_(maxSize),
clearSize_(clearSize) { } clearSize_(clearSize) {}
EvictingCacheMap(const EvictingCacheMap&) = delete; EvictingCacheMap(const EvictingCacheMap&) = delete;
EvictingCacheMap& operator=(const EvictingCacheMap&) = delete; EvictingCacheMap& operator=(const EvictingCacheMap&) = delete;
...@@ -412,37 +428,36 @@ class EvictingCacheMap { ...@@ -412,37 +428,36 @@ class EvictingCacheMap {
} }
private: private:
struct Node struct Node : public boost::intrusive::unordered_set_base_hook<link_mode>,
: public boost::intrusive::unordered_set_base_hook<link_mode>,
public boost::intrusive::list_base_hook<link_mode> { public boost::intrusive::list_base_hook<link_mode> {
Node(const TKey& key, TValue&& value) Node(const TKey& key, TValue&& value)
: pr(std::make_pair(key, std::move(value))) { : pr(std::make_pair(key, std::move(value))) {}
}
TPair pr; TPair pr;
friend bool operator==(const Node& lhs, const Node& rhs) {
return lhs.pr.first == rhs.pr.first;
}
friend std::size_t hash_value(const Node& node) {
return THash()(node.pr.first);
}
}; };
struct KeyHasher { struct KeyHasher {
std::size_t operator()(const Node& node) { KeyHasher(const THash& keyHash) : hash(keyHash) {}
return THash()(node.pr.first); std::size_t operator()(const Node& node) const {
return hash(node.pr.first);
} }
std::size_t operator()(const TKey& key) { std::size_t operator()(const TKey& key) const {
return THash()(key); return hash(key);
} }
THash hash;
}; };
struct KeyValueEqual { struct KeyValueEqual {
bool operator()(const TKey& lhs, const Node& rhs) { KeyValueEqual(const TKeyEqual& keyEqual) : equal(keyEqual) {}
return lhs == rhs.pr.first; bool operator()(const TKey& lhs, const Node& rhs) const {
return equal(lhs, rhs.pr.first);
}
bool operator()(const Node& lhs, const TKey& rhs) const {
return equal(lhs.pr.first, rhs);
} }
bool operator()(const Node& lhs, const TKey& rhs) { bool operator()(const Node& lhs, const Node& rhs) const {
return lhs.pr.first == rhs; return equal(lhs.pr.first, rhs.pr.first);
} }
TKeyEqual equal;
}; };
/** /**
...@@ -453,11 +468,11 @@ class EvictingCacheMap { ...@@ -453,11 +468,11 @@ class EvictingCacheMap {
* (a std::pair of const TKey, TValue) or index_.end() if it does not exist * (a std::pair of const TKey, TValue) or index_.end() if it does not exist
*/ */
typename NodeMap::iterator findInIndex(const TKey& key) { typename NodeMap::iterator findInIndex(const TKey& key) {
return index_.find(key, KeyHasher(), KeyValueEqual()); return index_.find(key, KeyHasher(keyHash_), KeyValueEqual(keyEqual_));
} }
typename NodeMap::const_iterator findInIndex(const TKey& key) const { typename NodeMap::const_iterator findInIndex(const TKey& key) const {
return index_.find(key, KeyHasher(), KeyValueEqual()); return index_.find(key, KeyHasher(keyHash_), KeyValueEqual(keyEqual_));
} }
/** /**
...@@ -493,6 +508,8 @@ class EvictingCacheMap { ...@@ -493,6 +508,8 @@ class EvictingCacheMap {
std::size_t nIndexBuckets_; std::size_t nIndexBuckets_;
std::unique_ptr<typename NodeMap::bucket_type[]> indexBuckets_; std::unique_ptr<typename NodeMap::bucket_type[]> indexBuckets_;
typename NodeMap::bucket_traits indexTraits_; typename NodeMap::bucket_traits indexTraits_;
THash keyHash_;
TKeyEqual keyEqual_;
NodeMap index_; NodeMap index_;
NodeList lru_; NodeList lru_;
std::size_t maxSize_; std::size_t maxSize_;
......
...@@ -633,3 +633,28 @@ TEST(EvictingCacheMap, MoveTest) { ...@@ -633,3 +633,28 @@ TEST(EvictingCacheMap, MoveTest) {
EXPECT_EQ(i, map2.get(i)); EXPECT_EQ(i, map2.get(i));
} }
} }
TEST(EvictingCacheMap, CustomKeyEqual) {
const int nItems = 100;
struct Eq {
bool operator()(const int& a, const int& b) const {
return (a % mod) == (b % mod);
}
int mod;
};
struct Hash {
size_t operator()(const int& a) const {
return std::hash<int>()(a % mod);
}
int mod;
};
EvictingCacheMap<int, int, Hash, Eq> map(
nItems, 1 /* clearSize */, Hash{nItems}, Eq{nItems});
for (int i = 0; i < nItems; i++) {
map.set(i, i);
EXPECT_TRUE(map.exists(i));
EXPECT_EQ(i, map.get(i));
EXPECT_TRUE(map.exists(i + nItems));
EXPECT_EQ(i, map.get(i + nItems));
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment