Commit 87d9f1da authored by Nick Terrell's avatar Nick Terrell Committed by Facebook GitHub Bot

Reset contexts when returning to pool

Summary:
Currently, the first thing you want to do when getting a context from the pool is
reset it, because you don't know the state the last user left it in. This is bug
prone, because you may forget to reset the context. Instead, the pool should
handle the resetting logic, so it cannot be forgotten.

This diff adds a `Resetter` to the pool, and the reset function is called every
time a context is returned to the pool. I chose to reset when returning to the
pool instead of when retrieving a context because that gives the context a
chance to free any resources it doesn't expect to need in the future before it
potentially becomres idle.

Reviewed By: Cyan4973

Differential Revision: D26439710

fbshipit-source-id: 8b9afa3db2fda4d167f0fd5307791e8a8a4f283c
parent 7dd7df55
......@@ -24,14 +24,14 @@
namespace folly {
namespace compression {
template <typename T, typename Creator, typename Deleter>
template <typename T, typename Creator, typename Deleter, typename Resetter>
class CompressionContextPool {
private:
using InternalRef = std::unique_ptr<T, Deleter>;
class ReturnToPoolDeleter {
public:
using Pool = CompressionContextPool<T, Creator, Deleter>;
using Pool = CompressionContextPool<T, Creator, Deleter, Resetter>;
explicit ReturnToPoolDeleter(Pool* pool) : pool_(pool) { DCHECK(pool); }
......@@ -49,8 +49,12 @@ class CompressionContextPool {
using Ref = std::unique_ptr<T, ReturnToPoolDeleter>;
explicit CompressionContextPool(
Creator creator = Creator(), Deleter deleter = Deleter())
: creator_(std::move(creator)), deleter_(std::move(deleter)) {}
Creator creator = Creator(),
Deleter deleter = Deleter(),
Resetter resetter = Resetter())
: creator_(std::move(creator)),
deleter_(std::move(deleter)),
resetter_(std::move(resetter)) {}
Ref get() {
auto stack = stack_.wlock();
......@@ -74,14 +78,18 @@ class CompressionContextPool {
ReturnToPoolDeleter get_deleter() { return ReturnToPoolDeleter(this); }
Resetter& get_resetter() { return resetter_; }
private:
void add(InternalRef ptr) {
DCHECK(ptr);
resetter_(ptr.get());
stack_.wlock()->push_back(std::move(ptr));
}
Creator creator_;
Deleter deleter_;
Resetter resetter_;
folly::Synchronized<std::vector<InternalRef>> stack_;
};
......
......@@ -43,6 +43,18 @@ void ZSTD_DCtx_Deleter::operator()(ZSTD_DCtx* ctx) const noexcept {
ZSTD_freeDCtx(ctx);
}
void ZSTD_CCtx_Resetter::operator()(ZSTD_CCtx* ctx) const noexcept {
size_t const err = ZSTD_CCtx_reset(ctx, ZSTD_reset_session_and_parameters);
assert(!ZSTD_isError(err)); // This function doesn't actually fail
(void)err;
}
void ZSTD_DCtx_Resetter::operator()(ZSTD_DCtx* ctx) const noexcept {
size_t const err = ZSTD_DCtx_reset(ctx, ZSTD_reset_session_and_parameters);
assert(!ZSTD_isError(err)); // This function doesn't actually fail
(void)err;
}
ZSTD_CCtx_Pool::Ref getZSTD_CCtx() {
return zstd_cctx_pool_singleton.get();
}
......
......@@ -52,19 +52,35 @@ struct ZSTD_DCtx_Deleter {
void operator()(ZSTD_DCtx* ctx) const noexcept;
};
struct ZSTD_CCtx_Resetter {
void operator()(ZSTD_CCtx* ctx) const noexcept;
};
struct ZSTD_DCtx_Resetter {
void operator()(ZSTD_DCtx* ctx) const noexcept;
};
using ZSTD_CCtx_Pool = CompressionCoreLocalContextPool<
ZSTD_CCtx,
ZSTD_CCtx_Creator,
ZSTD_CCtx_Deleter,
ZSTD_CCtx_Resetter,
4>;
using ZSTD_DCtx_Pool = CompressionCoreLocalContextPool<
ZSTD_DCtx,
ZSTD_DCtx_Creator,
ZSTD_DCtx_Deleter,
ZSTD_DCtx_Resetter,
4>;
/**
* Returns a clean ZSTD_CCtx.
*/
ZSTD_CCtx_Pool::Ref getZSTD_CCtx();
/**
* Returns a clean ZSTD_DCtx.
*/
ZSTD_DCtx_Pool::Ref getZSTD_DCtx();
ZSTD_CCtx_Pool::Ref getNULL_ZSTD_CCtx();
......
......@@ -36,7 +36,12 @@ namespace compression {
* make for less contention, but mean that a context is less likely to be hot
* in cache.
*/
template <typename T, typename Creator, typename Deleter, size_t NumStripes = 8>
template <
typename T,
typename Creator,
typename Deleter,
typename Resetter,
size_t NumStripes = 8>
class CompressionCoreLocalContextPool {
private:
/**
......@@ -51,8 +56,12 @@ class CompressionCoreLocalContextPool {
class ReturnToPoolDeleter {
public:
using Pool =
CompressionCoreLocalContextPool<T, Creator, Deleter, NumStripes>;
using Pool = CompressionCoreLocalContextPool<
T,
Creator,
Deleter,
Resetter,
NumStripes>;
explicit ReturnToPoolDeleter(Pool* pool) : pool_(pool) { DCHECK(pool_); }
......@@ -62,7 +71,7 @@ class CompressionCoreLocalContextPool {
Pool* pool_;
};
using BackingPool = CompressionContextPool<T, Creator, Deleter>;
using BackingPool = CompressionContextPool<T, Creator, Deleter, Resetter>;
using BackingPoolRef = typename BackingPool::Ref;
public:
......@@ -70,8 +79,11 @@ class CompressionCoreLocalContextPool {
using Ref = std::unique_ptr<T, ReturnToPoolDeleter>;
explicit CompressionCoreLocalContextPool(
Creator creator = Creator(), Deleter deleter = Deleter())
: pool_(std::move(creator), std::move(deleter)), caches_() {}
Creator creator = Creator(),
Deleter deleter = Deleter(),
Resetter resetter = Resetter())
: pool_(std::move(creator), std::move(deleter), std::move(resetter)),
caches_() {}
~CompressionCoreLocalContextPool() {
for (auto& cache : caches_) {
......@@ -98,6 +110,7 @@ class CompressionCoreLocalContextPool {
void store(T* ptr) {
DCHECK(ptr);
pool_.get_resetter()(ptr);
T* expected = nullptr;
const bool stored = local().ptr.compare_exchange_weak(expected, ptr);
if (!stored) {
......
......@@ -57,24 +57,6 @@ namespace {
#define ZSTD_c_compressionLevel ZSTD_p_compressionLevel
#define ZSTD_c_contentSizeFlag ZSTD_p_contentSizeFlag
void resetCCtxSessionAndParameters(ZSTD_CCtx* cctx) {
ZSTD_CCtx_reset(cctx);
}
void resetDCtxSessionAndParameters(ZSTD_DCtx* dctx) {
ZSTD_DCtx_reset(dctx);
}
#else
void resetCCtxSessionAndParameters(ZSTD_CCtx* cctx) {
ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters);
}
void resetDCtxSessionAndParameters(ZSTD_DCtx* dctx) {
ZSTD_DCtx_reset(dctx, ZSTD_reset_session_and_parameters);
}
#endif
size_t zstdThrowIfError(size_t rc) {
......@@ -182,9 +164,8 @@ void ZSTDStreamCodec::doResetStream() {
void ZSTDStreamCodec::resetCCtx() {
DCHECK(cctx_ == nullptr);
cctx_ = getZSTD_CCtx();
cctx_ = getZSTD_CCtx(); // Gives us a clean context
DCHECK(cctx_ != nullptr);
resetCCtxSessionAndParameters(cctx_.get());
zstdThrowIfError(
ZSTD_CCtx_setParametersUsingCCtxParams(cctx_.get(), options_.params()));
zstdThrowIfError(ZSTD_CCtx_setPledgedSrcSize(
......@@ -222,9 +203,8 @@ bool ZSTDStreamCodec::doCompressStream(
void ZSTDStreamCodec::resetDCtx() {
DCHECK(dctx_ == nullptr);
dctx_ = getZSTD_DCtx();
dctx_ = getZSTD_DCtx(); // Gives us a clean context
DCHECK(dctx_ != nullptr);
resetDCtxSessionAndParameters(dctx_.get());
if (options_.maxWindowSize() != 0) {
zstdThrowIfError(
ZSTD_DCtx_setMaxWindowSize(dctx_.get(), options_.maxWindowSize()));
......
......@@ -30,7 +30,18 @@ namespace {
std::atomic<size_t> numFoos{0};
std::atomic<size_t> numDeleted{0};
class Foo {};
class Foo {
public:
void use() {
EXPECT_FALSE(used_);
used_ = true;
}
void reset() { used_ = false; }
private:
bool used_{false};
};
struct FooCreator {
Foo* operator()() {
......@@ -53,8 +64,13 @@ struct FooDeleter {
}
};
using Pool = CompressionContextPool<Foo, FooCreator, FooDeleter>;
using BadPool = CompressionContextPool<Foo, BadFooCreator, FooDeleter>;
struct FooResetter {
void operator()(Foo* f) { f->reset(); }
};
using Pool = CompressionContextPool<Foo, FooCreator, FooDeleter, FooResetter>;
using BadPool =
CompressionContextPool<Foo, BadFooCreator, FooDeleter, FooResetter>;
} // anonymous namespace
......@@ -175,9 +191,28 @@ TEST_F(CompressionContextPoolTest, testBadCreate) {
EXPECT_THROW(pool.get(), std::bad_alloc);
}
TEST_F(CompressionContextPoolTest, testReset) {
Pool::Object* tmp;
{
auto ptr = pool_->get();
ptr->use();
tmp = ptr.get();
}
{
auto ptr = pool_->get();
ptr->use();
EXPECT_EQ(ptr.get(), tmp);
}
}
class CompressionCoreLocalContextPoolTest : public testing::Test {
protected:
using Pool = CompressionCoreLocalContextPool<Foo, FooCreator, FooDeleter, 8>;
using Pool = CompressionCoreLocalContextPool<
Foo,
FooCreator,
FooDeleter,
FooResetter,
8>;
void SetUp() override { pool_ = std::make_unique<Pool>(); }
......@@ -250,6 +285,32 @@ TEST_F(CompressionCoreLocalContextPoolTest, testMultithread) {
EXPECT_LE(numFoos.load(), numThreads);
}
TEST_F(CompressionCoreLocalContextPoolTest, testReset) {
Pool::Object* tmp1;
Pool::Object* tmp2;
{
auto ptr1 = pool_->get();
ptr1->use();
tmp1 = ptr1.get();
}
{
auto ptr1 = pool_->get();
auto ptr2 = pool_->get();
ptr1->use();
ptr2->use();
EXPECT_EQ(ptr1.get(), tmp1);
tmp2 = ptr2.get();
}
{
auto ptr1 = pool_->get();
auto ptr2 = pool_->get();
ptr1->use();
ptr2->use();
EXPECT_EQ(ptr1.get(), tmp2);
EXPECT_EQ(ptr2.get(), tmp1);
}
}
#ifdef FOLLY_COMPRESSION_HAS_ZSTD_CONTEXT_POOL_SINGLETONS
TEST(CompressionContextPoolSingletonsTest, testSingletons) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment