Commit 87d9f1da authored by Nick Terrell's avatar Nick Terrell Committed by Facebook GitHub Bot

Reset contexts when returning to pool

Summary:
Currently, the first thing you want to do when getting a context from the pool is
reset it, because you don't know the state the last user left it in. This is bug
prone, because you may forget to reset the context. Instead, the pool should
handle the resetting logic, so it cannot be forgotten.

This diff adds a `Resetter` to the pool, and the reset function is called every
time a context is returned to the pool. I chose to reset when returning to the
pool instead of when retrieving a context because that gives the context a
chance to free any resources it doesn't expect to need in the future before it
potentially becomres idle.

Reviewed By: Cyan4973

Differential Revision: D26439710

fbshipit-source-id: 8b9afa3db2fda4d167f0fd5307791e8a8a4f283c
parent 7dd7df55
...@@ -24,14 +24,14 @@ ...@@ -24,14 +24,14 @@
namespace folly { namespace folly {
namespace compression { namespace compression {
template <typename T, typename Creator, typename Deleter> template <typename T, typename Creator, typename Deleter, typename Resetter>
class CompressionContextPool { class CompressionContextPool {
private: private:
using InternalRef = std::unique_ptr<T, Deleter>; using InternalRef = std::unique_ptr<T, Deleter>;
class ReturnToPoolDeleter { class ReturnToPoolDeleter {
public: public:
using Pool = CompressionContextPool<T, Creator, Deleter>; using Pool = CompressionContextPool<T, Creator, Deleter, Resetter>;
explicit ReturnToPoolDeleter(Pool* pool) : pool_(pool) { DCHECK(pool); } explicit ReturnToPoolDeleter(Pool* pool) : pool_(pool) { DCHECK(pool); }
...@@ -49,8 +49,12 @@ class CompressionContextPool { ...@@ -49,8 +49,12 @@ class CompressionContextPool {
using Ref = std::unique_ptr<T, ReturnToPoolDeleter>; using Ref = std::unique_ptr<T, ReturnToPoolDeleter>;
explicit CompressionContextPool( explicit CompressionContextPool(
Creator creator = Creator(), Deleter deleter = Deleter()) Creator creator = Creator(),
: creator_(std::move(creator)), deleter_(std::move(deleter)) {} Deleter deleter = Deleter(),
Resetter resetter = Resetter())
: creator_(std::move(creator)),
deleter_(std::move(deleter)),
resetter_(std::move(resetter)) {}
Ref get() { Ref get() {
auto stack = stack_.wlock(); auto stack = stack_.wlock();
...@@ -74,14 +78,18 @@ class CompressionContextPool { ...@@ -74,14 +78,18 @@ class CompressionContextPool {
ReturnToPoolDeleter get_deleter() { return ReturnToPoolDeleter(this); } ReturnToPoolDeleter get_deleter() { return ReturnToPoolDeleter(this); }
Resetter& get_resetter() { return resetter_; }
private: private:
void add(InternalRef ptr) { void add(InternalRef ptr) {
DCHECK(ptr); DCHECK(ptr);
resetter_(ptr.get());
stack_.wlock()->push_back(std::move(ptr)); stack_.wlock()->push_back(std::move(ptr));
} }
Creator creator_; Creator creator_;
Deleter deleter_; Deleter deleter_;
Resetter resetter_;
folly::Synchronized<std::vector<InternalRef>> stack_; folly::Synchronized<std::vector<InternalRef>> stack_;
}; };
......
...@@ -43,6 +43,18 @@ void ZSTD_DCtx_Deleter::operator()(ZSTD_DCtx* ctx) const noexcept { ...@@ -43,6 +43,18 @@ void ZSTD_DCtx_Deleter::operator()(ZSTD_DCtx* ctx) const noexcept {
ZSTD_freeDCtx(ctx); ZSTD_freeDCtx(ctx);
} }
void ZSTD_CCtx_Resetter::operator()(ZSTD_CCtx* ctx) const noexcept {
size_t const err = ZSTD_CCtx_reset(ctx, ZSTD_reset_session_and_parameters);
assert(!ZSTD_isError(err)); // This function doesn't actually fail
(void)err;
}
void ZSTD_DCtx_Resetter::operator()(ZSTD_DCtx* ctx) const noexcept {
size_t const err = ZSTD_DCtx_reset(ctx, ZSTD_reset_session_and_parameters);
assert(!ZSTD_isError(err)); // This function doesn't actually fail
(void)err;
}
ZSTD_CCtx_Pool::Ref getZSTD_CCtx() { ZSTD_CCtx_Pool::Ref getZSTD_CCtx() {
return zstd_cctx_pool_singleton.get(); return zstd_cctx_pool_singleton.get();
} }
......
...@@ -52,19 +52,35 @@ struct ZSTD_DCtx_Deleter { ...@@ -52,19 +52,35 @@ struct ZSTD_DCtx_Deleter {
void operator()(ZSTD_DCtx* ctx) const noexcept; void operator()(ZSTD_DCtx* ctx) const noexcept;
}; };
struct ZSTD_CCtx_Resetter {
void operator()(ZSTD_CCtx* ctx) const noexcept;
};
struct ZSTD_DCtx_Resetter {
void operator()(ZSTD_DCtx* ctx) const noexcept;
};
using ZSTD_CCtx_Pool = CompressionCoreLocalContextPool< using ZSTD_CCtx_Pool = CompressionCoreLocalContextPool<
ZSTD_CCtx, ZSTD_CCtx,
ZSTD_CCtx_Creator, ZSTD_CCtx_Creator,
ZSTD_CCtx_Deleter, ZSTD_CCtx_Deleter,
ZSTD_CCtx_Resetter,
4>; 4>;
using ZSTD_DCtx_Pool = CompressionCoreLocalContextPool< using ZSTD_DCtx_Pool = CompressionCoreLocalContextPool<
ZSTD_DCtx, ZSTD_DCtx,
ZSTD_DCtx_Creator, ZSTD_DCtx_Creator,
ZSTD_DCtx_Deleter, ZSTD_DCtx_Deleter,
ZSTD_DCtx_Resetter,
4>; 4>;
/**
* Returns a clean ZSTD_CCtx.
*/
ZSTD_CCtx_Pool::Ref getZSTD_CCtx(); ZSTD_CCtx_Pool::Ref getZSTD_CCtx();
/**
* Returns a clean ZSTD_DCtx.
*/
ZSTD_DCtx_Pool::Ref getZSTD_DCtx(); ZSTD_DCtx_Pool::Ref getZSTD_DCtx();
ZSTD_CCtx_Pool::Ref getNULL_ZSTD_CCtx(); ZSTD_CCtx_Pool::Ref getNULL_ZSTD_CCtx();
......
...@@ -36,7 +36,12 @@ namespace compression { ...@@ -36,7 +36,12 @@ namespace compression {
* make for less contention, but mean that a context is less likely to be hot * make for less contention, but mean that a context is less likely to be hot
* in cache. * in cache.
*/ */
template <typename T, typename Creator, typename Deleter, size_t NumStripes = 8> template <
typename T,
typename Creator,
typename Deleter,
typename Resetter,
size_t NumStripes = 8>
class CompressionCoreLocalContextPool { class CompressionCoreLocalContextPool {
private: private:
/** /**
...@@ -51,8 +56,12 @@ class CompressionCoreLocalContextPool { ...@@ -51,8 +56,12 @@ class CompressionCoreLocalContextPool {
class ReturnToPoolDeleter { class ReturnToPoolDeleter {
public: public:
using Pool = using Pool = CompressionCoreLocalContextPool<
CompressionCoreLocalContextPool<T, Creator, Deleter, NumStripes>; T,
Creator,
Deleter,
Resetter,
NumStripes>;
explicit ReturnToPoolDeleter(Pool* pool) : pool_(pool) { DCHECK(pool_); } explicit ReturnToPoolDeleter(Pool* pool) : pool_(pool) { DCHECK(pool_); }
...@@ -62,7 +71,7 @@ class CompressionCoreLocalContextPool { ...@@ -62,7 +71,7 @@ class CompressionCoreLocalContextPool {
Pool* pool_; Pool* pool_;
}; };
using BackingPool = CompressionContextPool<T, Creator, Deleter>; using BackingPool = CompressionContextPool<T, Creator, Deleter, Resetter>;
using BackingPoolRef = typename BackingPool::Ref; using BackingPoolRef = typename BackingPool::Ref;
public: public:
...@@ -70,8 +79,11 @@ class CompressionCoreLocalContextPool { ...@@ -70,8 +79,11 @@ class CompressionCoreLocalContextPool {
using Ref = std::unique_ptr<T, ReturnToPoolDeleter>; using Ref = std::unique_ptr<T, ReturnToPoolDeleter>;
explicit CompressionCoreLocalContextPool( explicit CompressionCoreLocalContextPool(
Creator creator = Creator(), Deleter deleter = Deleter()) Creator creator = Creator(),
: pool_(std::move(creator), std::move(deleter)), caches_() {} Deleter deleter = Deleter(),
Resetter resetter = Resetter())
: pool_(std::move(creator), std::move(deleter), std::move(resetter)),
caches_() {}
~CompressionCoreLocalContextPool() { ~CompressionCoreLocalContextPool() {
for (auto& cache : caches_) { for (auto& cache : caches_) {
...@@ -98,6 +110,7 @@ class CompressionCoreLocalContextPool { ...@@ -98,6 +110,7 @@ class CompressionCoreLocalContextPool {
void store(T* ptr) { void store(T* ptr) {
DCHECK(ptr); DCHECK(ptr);
pool_.get_resetter()(ptr);
T* expected = nullptr; T* expected = nullptr;
const bool stored = local().ptr.compare_exchange_weak(expected, ptr); const bool stored = local().ptr.compare_exchange_weak(expected, ptr);
if (!stored) { if (!stored) {
......
...@@ -57,24 +57,6 @@ namespace { ...@@ -57,24 +57,6 @@ namespace {
#define ZSTD_c_compressionLevel ZSTD_p_compressionLevel #define ZSTD_c_compressionLevel ZSTD_p_compressionLevel
#define ZSTD_c_contentSizeFlag ZSTD_p_contentSizeFlag #define ZSTD_c_contentSizeFlag ZSTD_p_contentSizeFlag
void resetCCtxSessionAndParameters(ZSTD_CCtx* cctx) {
ZSTD_CCtx_reset(cctx);
}
void resetDCtxSessionAndParameters(ZSTD_DCtx* dctx) {
ZSTD_DCtx_reset(dctx);
}
#else
void resetCCtxSessionAndParameters(ZSTD_CCtx* cctx) {
ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters);
}
void resetDCtxSessionAndParameters(ZSTD_DCtx* dctx) {
ZSTD_DCtx_reset(dctx, ZSTD_reset_session_and_parameters);
}
#endif #endif
size_t zstdThrowIfError(size_t rc) { size_t zstdThrowIfError(size_t rc) {
...@@ -182,9 +164,8 @@ void ZSTDStreamCodec::doResetStream() { ...@@ -182,9 +164,8 @@ void ZSTDStreamCodec::doResetStream() {
void ZSTDStreamCodec::resetCCtx() { void ZSTDStreamCodec::resetCCtx() {
DCHECK(cctx_ == nullptr); DCHECK(cctx_ == nullptr);
cctx_ = getZSTD_CCtx(); cctx_ = getZSTD_CCtx(); // Gives us a clean context
DCHECK(cctx_ != nullptr); DCHECK(cctx_ != nullptr);
resetCCtxSessionAndParameters(cctx_.get());
zstdThrowIfError( zstdThrowIfError(
ZSTD_CCtx_setParametersUsingCCtxParams(cctx_.get(), options_.params())); ZSTD_CCtx_setParametersUsingCCtxParams(cctx_.get(), options_.params()));
zstdThrowIfError(ZSTD_CCtx_setPledgedSrcSize( zstdThrowIfError(ZSTD_CCtx_setPledgedSrcSize(
...@@ -222,9 +203,8 @@ bool ZSTDStreamCodec::doCompressStream( ...@@ -222,9 +203,8 @@ bool ZSTDStreamCodec::doCompressStream(
void ZSTDStreamCodec::resetDCtx() { void ZSTDStreamCodec::resetDCtx() {
DCHECK(dctx_ == nullptr); DCHECK(dctx_ == nullptr);
dctx_ = getZSTD_DCtx(); dctx_ = getZSTD_DCtx(); // Gives us a clean context
DCHECK(dctx_ != nullptr); DCHECK(dctx_ != nullptr);
resetDCtxSessionAndParameters(dctx_.get());
if (options_.maxWindowSize() != 0) { if (options_.maxWindowSize() != 0) {
zstdThrowIfError( zstdThrowIfError(
ZSTD_DCtx_setMaxWindowSize(dctx_.get(), options_.maxWindowSize())); ZSTD_DCtx_setMaxWindowSize(dctx_.get(), options_.maxWindowSize()));
......
...@@ -30,7 +30,18 @@ namespace { ...@@ -30,7 +30,18 @@ namespace {
std::atomic<size_t> numFoos{0}; std::atomic<size_t> numFoos{0};
std::atomic<size_t> numDeleted{0}; std::atomic<size_t> numDeleted{0};
class Foo {}; class Foo {
public:
void use() {
EXPECT_FALSE(used_);
used_ = true;
}
void reset() { used_ = false; }
private:
bool used_{false};
};
struct FooCreator { struct FooCreator {
Foo* operator()() { Foo* operator()() {
...@@ -53,8 +64,13 @@ struct FooDeleter { ...@@ -53,8 +64,13 @@ struct FooDeleter {
} }
}; };
using Pool = CompressionContextPool<Foo, FooCreator, FooDeleter>; struct FooResetter {
using BadPool = CompressionContextPool<Foo, BadFooCreator, FooDeleter>; void operator()(Foo* f) { f->reset(); }
};
using Pool = CompressionContextPool<Foo, FooCreator, FooDeleter, FooResetter>;
using BadPool =
CompressionContextPool<Foo, BadFooCreator, FooDeleter, FooResetter>;
} // anonymous namespace } // anonymous namespace
...@@ -175,9 +191,28 @@ TEST_F(CompressionContextPoolTest, testBadCreate) { ...@@ -175,9 +191,28 @@ TEST_F(CompressionContextPoolTest, testBadCreate) {
EXPECT_THROW(pool.get(), std::bad_alloc); EXPECT_THROW(pool.get(), std::bad_alloc);
} }
TEST_F(CompressionContextPoolTest, testReset) {
Pool::Object* tmp;
{
auto ptr = pool_->get();
ptr->use();
tmp = ptr.get();
}
{
auto ptr = pool_->get();
ptr->use();
EXPECT_EQ(ptr.get(), tmp);
}
}
class CompressionCoreLocalContextPoolTest : public testing::Test { class CompressionCoreLocalContextPoolTest : public testing::Test {
protected: protected:
using Pool = CompressionCoreLocalContextPool<Foo, FooCreator, FooDeleter, 8>; using Pool = CompressionCoreLocalContextPool<
Foo,
FooCreator,
FooDeleter,
FooResetter,
8>;
void SetUp() override { pool_ = std::make_unique<Pool>(); } void SetUp() override { pool_ = std::make_unique<Pool>(); }
...@@ -250,6 +285,32 @@ TEST_F(CompressionCoreLocalContextPoolTest, testMultithread) { ...@@ -250,6 +285,32 @@ TEST_F(CompressionCoreLocalContextPoolTest, testMultithread) {
EXPECT_LE(numFoos.load(), numThreads); EXPECT_LE(numFoos.load(), numThreads);
} }
TEST_F(CompressionCoreLocalContextPoolTest, testReset) {
Pool::Object* tmp1;
Pool::Object* tmp2;
{
auto ptr1 = pool_->get();
ptr1->use();
tmp1 = ptr1.get();
}
{
auto ptr1 = pool_->get();
auto ptr2 = pool_->get();
ptr1->use();
ptr2->use();
EXPECT_EQ(ptr1.get(), tmp1);
tmp2 = ptr2.get();
}
{
auto ptr1 = pool_->get();
auto ptr2 = pool_->get();
ptr1->use();
ptr2->use();
EXPECT_EQ(ptr1.get(), tmp2);
EXPECT_EQ(ptr2.get(), tmp1);
}
}
#ifdef FOLLY_COMPRESSION_HAS_ZSTD_CONTEXT_POOL_SINGLETONS #ifdef FOLLY_COMPRESSION_HAS_ZSTD_CONTEXT_POOL_SINGLETONS
TEST(CompressionContextPoolSingletonsTest, testSingletons) { TEST(CompressionContextPoolSingletonsTest, testSingletons) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment