Commit 8e110a76 authored by Yedidya Feldblum's avatar Yedidya Feldblum Committed by Facebook Github Bot

Destroy promise/future callback functions before waking waiters

Summary:
Code may pass a callback which captures an object with a destructor which mutates through a stored reference, triggering heap-use-after-free or stack-use-after-scope.

```lang=c++
void performDataRace() {
  auto number = std::make_unique<int>(0);
  auto guard = folly::makeGuard([&number] { *number = 1; });
  folly::via(getSomeExecutor(), [guard = std::move(guard)]() mutable {}).wait();
  // data race - we may wake and destruct number before guard is destructed on the
  // executor thread, which is both stack-use-after-scope and heap-use-after-free!
}
```

We can avoid this condition by always destructing the provided functor before setting any result on the promise.

Retry at {D4982969}.

Reviewed By: andriigrynenko

Differential Revision: D5058750

fbshipit-source-id: 4d1d878b4889e5e6474941187f03de5fa84d3061
parent 3a4a9694
...@@ -48,7 +48,76 @@ typedef folly::Baton<> FutureBatonType; ...@@ -48,7 +48,76 @@ typedef folly::Baton<> FutureBatonType;
} }
namespace detail { namespace detail {
std::shared_ptr<Timekeeper> getTimekeeperSingleton(); std::shared_ptr<Timekeeper> getTimekeeperSingleton();
// Guarantees that the stored functor is destructed before the stored promise
// may be fulfilled. Assumes the stored functor to be noexcept-destructible.
template <typename T, typename F>
class CoreCallbackState {
public:
template <typename FF>
CoreCallbackState(Promise<T>&& promise, FF&& func) noexcept(
noexcept(F(std::declval<FF>())))
: func_(std::forward<FF>(func)), promise_(std::move(promise)) {
assert(before_barrier());
}
CoreCallbackState(CoreCallbackState&& that) noexcept(
noexcept(F(std::declval<F>()))) {
if (that.before_barrier()) {
new (&func_) F(std::move(that.func_));
promise_ = that.stealPromise();
}
}
CoreCallbackState& operator=(CoreCallbackState&&) = delete;
~CoreCallbackState() {
if (before_barrier()) {
stealPromise();
}
}
template <typename... Args>
auto invoke(Args&&... args) noexcept(
noexcept(std::declval<F&&>()(std::declval<Args&&>()...))) {
assert(before_barrier());
return std::move(func_)(std::forward<Args>(args)...);
}
void setTry(Try<T>&& t) {
stealPromise().setTry(std::move(t));
}
void setException(exception_wrapper&& ew) {
stealPromise().setException(std::move(ew));
}
Promise<T> stealPromise() noexcept {
assert(before_barrier());
func_.~F();
return std::move(promise_);
}
private:
bool before_barrier() const noexcept {
return !promise_.isFulfilled();
}
union {
F func_;
};
Promise<T> promise_{detail::EmptyConstruct{}};
};
template <typename T, typename F>
inline auto makeCoreCallbackState(Promise<T>&& p, F&& f) noexcept(
noexcept(CoreCallbackState<T, _t<std::decay<F>>>(
std::declval<Promise<T>&&>(),
std::declval<F&&>()))) {
return CoreCallbackState<T, _t<std::decay<F>>>(
std::move(p), std::forward<F>(f));
}
} }
template <class T> template <class T>
...@@ -160,13 +229,13 @@ Future<T>::thenImplementation(F&& func, detail::argResult<isTry, F, Args...>) { ...@@ -160,13 +229,13 @@ Future<T>::thenImplementation(F&& func, detail::argResult<isTry, F, Args...>) {
in the destruction of the Future used to create it. in the destruction of the Future used to create it.
*/ */
setCallback_( setCallback_(
[ func = std::forward<F>(func), pm = std::move(p) ](Try<T> && t) mutable { [state = detail::makeCoreCallbackState(
std::move(p), std::forward<F>(func))](Try<T> && t) mutable {
if (!isTry && t.hasException()) { if (!isTry && t.hasException()) {
pm.setException(std::move(t.exception())); state.setException(std::move(t.exception()));
} else { } else {
pm.setWith([&]() { state.setTry(makeTryWith(
return std::move(func)(t.template get<isTry, Args>()...); [&] { return state.invoke(t.template get<isTry, Args>()...); }));
});
} }
}); });
...@@ -191,16 +260,17 @@ Future<T>::thenImplementation(F&& func, detail::argResult<isTry, F, Args...>) { ...@@ -191,16 +260,17 @@ Future<T>::thenImplementation(F&& func, detail::argResult<isTry, F, Args...>) {
auto f = p.getFuture(); auto f = p.getFuture();
f.core_->setExecutorNoLock(getExecutor()); f.core_->setExecutorNoLock(getExecutor());
setCallback_([ func = std::forward<F>(func), pm = std::move(p) ]( setCallback_(
Try<T> && t) mutable { [state = detail::makeCoreCallbackState(
std::move(p), std::forward<F>(func))](Try<T> && t) mutable {
auto ew = [&] { auto ew = [&] {
if (!isTry && t.hasException()) { if (!isTry && t.hasException()) {
return std::move(t.exception()); return std::move(t.exception());
} else { } else {
try { try {
auto f2 = std::move(func)(t.template get<isTry, Args>()...); auto f2 = state.invoke(t.template get<isTry, Args>()...);
// that didn't throw, now we can steal p // that didn't throw, now we can steal p
f2.setCallback_([p = std::move(pm)](Try<B> && b) mutable { f2.setCallback_([p = state.stealPromise()](Try<B> && b) mutable {
p.setTry(std::move(b)); p.setTry(std::move(b));
}); });
return exception_wrapper(); return exception_wrapper();
...@@ -212,7 +282,7 @@ Future<T>::thenImplementation(F&& func, detail::argResult<isTry, F, Args...>) { ...@@ -212,7 +282,7 @@ Future<T>::thenImplementation(F&& func, detail::argResult<isTry, F, Args...>) {
} }
}(); }();
if (ew) { if (ew) {
pm.setException(std::move(ew)); state.setException(std::move(ew));
} }
}); });
...@@ -266,11 +336,12 @@ Future<T>::onError(F&& func) { ...@@ -266,11 +336,12 @@ Future<T>::onError(F&& func) {
auto f = p.getFuture(); auto f = p.getFuture();
setCallback_( setCallback_(
[ func = std::forward<F>(func), pm = std::move(p) ](Try<T> && t) mutable { [state = detail::makeCoreCallbackState(
std::move(p), std::forward<F>(func))](Try<T> && t) mutable {
if (!t.template withException<Exn>([&](Exn& e) { if (!t.template withException<Exn>([&](Exn& e) {
pm.setWith([&] { return std::move(func)(e); }); state.setTry(makeTryWith([&] { return state.invoke(e); }));
})) { })) {
pm.setTry(std::move(t)); state.setTry(std::move(t));
} }
}); });
...@@ -293,15 +364,15 @@ Future<T>::onError(F&& func) { ...@@ -293,15 +364,15 @@ Future<T>::onError(F&& func) {
Promise<T> p; Promise<T> p;
auto f = p.getFuture(); auto f = p.getFuture();
setCallback_([ pm = std::move(p), func = std::forward<F>(func) ]( setCallback_(
Try<T> && t) mutable { [state = detail::makeCoreCallbackState(
std::move(p), std::forward<F>(func))](Try<T> && t) mutable {
if (!t.template withException<Exn>([&](Exn& e) { if (!t.template withException<Exn>([&](Exn& e) {
auto ew = [&] { auto ew = [&] {
try { try {
auto f2 = std::move(func)(e); auto f2 = state.invoke(e);
f2.setCallback_([pm = std::move(pm)](Try<T> && t2) mutable { f2.setCallback_([p = state.stealPromise()](
pm.setTry(std::move(t2)); Try<T> && t2) mutable { p.setTry(std::move(t2)); });
});
return exception_wrapper(); return exception_wrapper();
} catch (const std::exception& e2) { } catch (const std::exception& e2) {
return exception_wrapper(std::current_exception(), e2); return exception_wrapper(std::current_exception(), e2);
...@@ -310,10 +381,10 @@ Future<T>::onError(F&& func) { ...@@ -310,10 +381,10 @@ Future<T>::onError(F&& func) {
} }
}(); }();
if (ew) { if (ew) {
pm.setException(std::move(ew)); state.setException(std::move(ew));
} }
})) { })) {
pm.setTry(std::move(t)); state.setTry(std::move(t));
} }
}); });
...@@ -349,13 +420,14 @@ Future<T>::onError(F&& func) { ...@@ -349,13 +420,14 @@ Future<T>::onError(F&& func) {
Promise<T> p; Promise<T> p;
auto f = p.getFuture(); auto f = p.getFuture();
setCallback_( setCallback_(
[ pm = std::move(p), func = std::forward<F>(func) ](Try<T> t) mutable { [state = detail::makeCoreCallbackState(
std::move(p), std::forward<F>(func))](Try<T> t) mutable {
if (t.hasException()) { if (t.hasException()) {
auto ew = [&] { auto ew = [&] {
try { try {
auto f2 = std::move(func)(std::move(t.exception())); auto f2 = state.invoke(std::move(t.exception()));
f2.setCallback_([pm = std::move(pm)](Try<T> t2) mutable { f2.setCallback_([p = state.stealPromise()](Try<T> t2) mutable {
pm.setTry(std::move(t2)); p.setTry(std::move(t2));
}); });
return exception_wrapper(); return exception_wrapper();
} catch (const std::exception& e2) { } catch (const std::exception& e2) {
...@@ -365,10 +437,10 @@ Future<T>::onError(F&& func) { ...@@ -365,10 +437,10 @@ Future<T>::onError(F&& func) {
} }
}(); }();
if (ew) { if (ew) {
pm.setException(std::move(ew)); state.setException(std::move(ew));
} }
} else { } else {
pm.setTry(std::move(t)); state.setTry(std::move(t));
} }
}); });
...@@ -390,11 +462,13 @@ Future<T>::onError(F&& func) { ...@@ -390,11 +462,13 @@ Future<T>::onError(F&& func) {
Promise<T> p; Promise<T> p;
auto f = p.getFuture(); auto f = p.getFuture();
setCallback_( setCallback_(
[ pm = std::move(p), func = std::forward<F>(func) ](Try<T> t) mutable { [state = detail::makeCoreCallbackState(
std::move(p), std::forward<F>(func))](Try<T> t) mutable {
if (t.hasException()) { if (t.hasException()) {
pm.setWith([&] { return std::move(func)(std::move(t.exception())); }); state.setTry(makeTryWith(
[&] { return state.invoke(std::move(t.exception())); }));
} else { } else {
pm.setTry(std::move(t)); state.setTry(std::move(t));
} }
}); });
......
...@@ -59,6 +59,10 @@ void Promise<T>::throwIfRetrieved() { ...@@ -59,6 +59,10 @@ void Promise<T>::throwIfRetrieved() {
} }
} }
template <class T>
Promise<T>::Promise(detail::EmptyConstruct) noexcept
: retrieved_(false), core_(nullptr) {}
template <class T> template <class T>
Promise<T>::~Promise() { Promise<T>::~Promise() {
detach(); detach();
......
...@@ -25,6 +25,12 @@ namespace folly { ...@@ -25,6 +25,12 @@ namespace folly {
// forward declaration // forward declaration
template <class T> class Future; template <class T> class Future;
namespace detail {
struct EmptyConstruct {};
template <typename T, typename F>
class CoreCallbackState;
}
template <class T> template <class T>
class Promise { class Promise {
public: public:
...@@ -98,6 +104,8 @@ class Promise { ...@@ -98,6 +104,8 @@ class Promise {
private: private:
typedef typename Future<T>::corePtr corePtr; typedef typename Future<T>::corePtr corePtr;
template <class> friend class Future; template <class> friend class Future;
template <class, class>
friend class detail::CoreCallbackState;
// Whether the Future has been retrieved (a one-time operation). // Whether the Future has been retrieved (a one-time operation).
bool retrieved_; bool retrieved_;
...@@ -105,6 +113,8 @@ class Promise { ...@@ -105,6 +113,8 @@ class Promise {
// shared core state object // shared core state object
corePtr core_; corePtr core_;
explicit Promise(detail::EmptyConstruct) noexcept;
void throwIfFulfilled(); void throwIfFulfilled();
void throwIfRetrieved(); void throwIfRetrieved();
void detach(); void detach();
......
/*
* Copyright 2017 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/futures/Future.h>
#include <thread>
#include <folly/futures/test/TestExecutor.h>
#include <folly/portability/GTest.h>
using namespace folly;
namespace {
/***
* The basic premise is to check that the callback passed to then or onError
* is destructed before wait returns on the resulting future.
*
* The approach is to use callbacks where the destructor sleeps 500ms and then
* mutates a counter allocated on the caller stack. The caller checks the
* counter immediately after calling wait. Were the callback not destructed
* before wait returns, then we would very likely see an unchanged counter just
* after wait returns. But if, as we expect, the callback were destructed
* before wait returns, then we must be guaranteed to see a mutated counter
* just after wait returns.
*
* Note that the failure condition is not strictly guaranteed under load. :(
*/
class CallbackLifetimeTest : public testing::Test {
public:
using CounterPtr = std::unique_ptr<size_t>;
static bool kRaiseWillThrow() {
return true;
}
static constexpr auto kDelay() {
return std::chrono::milliseconds(500);
}
auto mkC() {
return std::make_unique<size_t>(0);
}
auto mkCGuard(CounterPtr& ptr) {
return makeGuard([&] {
/* sleep override */ std::this_thread::sleep_for(kDelay());
++*ptr;
});
}
static void raise() {
if (kRaiseWillThrow()) { // to avoid marking [[noreturn]]
throw std::runtime_error("raise");
}
}
static Future<Unit> raiseFut() {
raise();
return makeFuture();
}
TestExecutor executor{2}; // need at least 2 threads for internal futures
};
}
TEST_F(CallbackLifetimeTest, thenReturnsValue) {
auto c = mkC();
via(&executor).then([_ = mkCGuard(c)]{}).wait();
EXPECT_EQ(1, *c);
}
TEST_F(CallbackLifetimeTest, thenReturnsValueThrows) {
auto c = mkC();
via(&executor).then([_ = mkCGuard(c)] { raise(); }).wait();
EXPECT_EQ(1, *c);
}
TEST_F(CallbackLifetimeTest, thenReturnsFuture) {
auto c = mkC();
via(&executor).then([_ = mkCGuard(c)] { return makeFuture(); }).wait();
EXPECT_EQ(1, *c);
}
TEST_F(CallbackLifetimeTest, thenReturnsFutureThrows) {
auto c = mkC();
via(&executor).then([_ = mkCGuard(c)] { return raiseFut(); }).wait();
EXPECT_EQ(1, *c);
}
TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsValueMatch) {
auto c = mkC();
via(&executor)
.then(raise)
.onError([_ = mkCGuard(c)](std::exception&){})
.wait();
EXPECT_EQ(1, *c);
}
TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsValueMatchThrows) {
auto c = mkC();
via(&executor)
.then(raise)
.onError([_ = mkCGuard(c)](std::exception&) { raise(); })
.wait();
EXPECT_EQ(1, *c);
}
TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsValueWrong) {
auto c = mkC();
via(&executor)
.then(raise)
.onError([_ = mkCGuard(c)](std::logic_error&){})
.wait();
EXPECT_EQ(1, *c);
}
TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsValueWrongThrows) {
auto c = mkC();
via(&executor)
.then(raise)
.onError([_ = mkCGuard(c)](std::logic_error&) { raise(); })
.wait();
EXPECT_EQ(1, *c);
}
TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsFutureMatch) {
auto c = mkC();
via(&executor)
.then(raise)
.onError([_ = mkCGuard(c)](std::exception&) { return makeFuture(); })
.wait();
EXPECT_EQ(1, *c);
}
TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsFutureMatchThrows) {
auto c = mkC();
via(&executor)
.then(raise)
.onError([_ = mkCGuard(c)](std::exception&) { return raiseFut(); })
.wait();
EXPECT_EQ(1, *c);
}
TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsFutureWrong) {
auto c = mkC();
via(&executor)
.then(raise)
.onError([_ = mkCGuard(c)](std::logic_error&) { return makeFuture(); })
.wait();
EXPECT_EQ(1, *c);
}
TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsFutureWrongThrows) {
auto c = mkC();
via(&executor)
.then(raise)
.onError([_ = mkCGuard(c)](std::logic_error&) { return raiseFut(); })
.wait();
EXPECT_EQ(1, *c);
}
TEST_F(CallbackLifetimeTest, onErrorTakesWrapReturnsValue) {
auto c = mkC();
via(&executor)
.then(raise)
.onError([_ = mkCGuard(c)](exception_wrapper &&){})
.wait();
EXPECT_EQ(1, *c);
}
TEST_F(CallbackLifetimeTest, onErrorTakesWrapReturnsValueThrows) {
auto c = mkC();
via(&executor)
.then(raise)
.onError([_ = mkCGuard(c)](exception_wrapper &&) { raise(); })
.wait();
EXPECT_EQ(1, *c);
}
TEST_F(CallbackLifetimeTest, onErrorTakesWrapReturnsFuture) {
auto c = mkC();
via(&executor)
.then(raise)
.onError([_ = mkCGuard(c)](exception_wrapper &&) { return makeFuture(); })
.wait();
EXPECT_EQ(1, *c);
}
TEST_F(CallbackLifetimeTest, onErrorTakesWrapReturnsFutureThrows) {
auto c = mkC();
via(&executor)
.then(raise)
.onError([_ = mkCGuard(c)](exception_wrapper &&) { return raiseFut(); })
.wait();
EXPECT_EQ(1, *c);
}
...@@ -262,6 +262,7 @@ unit_test_LDADD = libfollytestmain.la ...@@ -262,6 +262,7 @@ unit_test_LDADD = libfollytestmain.la
TESTS += unit_test TESTS += unit_test
futures_test_SOURCES = \ futures_test_SOURCES = \
../futures/test/CallbackLifetimeTest.cpp \
../futures/test/CollectTest.cpp \ ../futures/test/CollectTest.cpp \
../futures/test/ContextTest.cpp \ ../futures/test/ContextTest.cpp \
../futures/test/CoreTest.cpp \ ../futures/test/CoreTest.cpp \
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment