Commit 5eb31cb3 authored by James Sedgwick's avatar James Sedgwick Committed by Facebook Github Bot

add tryReadUntil and make fixes along the way

Summary:
this diff adds tryReadUntil, which is a mirror of tryWriteUntil in both function and implementation.
Two bugs were exposed in the process of implementing and testing tryWriteUntil; they are fixed as well and are as follows:
  1. tryObtainPromisedPopTicket didn't assign to the passed ticket return reference in the failure case
  2. TurnSequencer::tryWaitForTurn() didn't distinguish between past turns and timeouts in the failure case; they need to be
     differentiated because SingleElementQueue::tryWaitFor{De/En}queue() should only fail in the timeout case, not if the turn has passed.

The two added unit tests are admittedly clumsy, but making the obvious simplifications to them keeps them from triggering the premature timeout race caused by bug 2 above, so I kept them as is.

Reviewed By: magedm

Differential Revision: D4050515

fbshipit-source-id: b0a3dd894d502c44be62d362ea347a1837df4c2f
parent 782325fd
This diff is collapsed.
...@@ -17,13 +17,14 @@ ...@@ -17,13 +17,14 @@
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <assert.h>
#include <limits> #include <limits>
#include <folly/detail/Futex.h> #include <folly/detail/Futex.h>
#include <folly/portability/Asm.h> #include <folly/portability/Asm.h>
#include <folly/portability/Unistd.h> #include <folly/portability/Unistd.h>
#include <glog/logging.h>
namespace folly { namespace folly {
namespace detail { namespace detail {
...@@ -79,14 +80,15 @@ struct TurnSequencer { ...@@ -79,14 +80,15 @@ struct TurnSequencer {
return decodeCurrentSturn(state) == (turn << kTurnShift); return decodeCurrentSturn(state) == (turn << kTurnShift);
} }
enum class TryWaitResult { SUCCESS, PAST, TIMEDOUT };
/// See tryWaitForTurn /// See tryWaitForTurn
/// Requires that `turn` is not a turn in the past. /// Requires that `turn` is not a turn in the past.
void waitForTurn(const uint32_t turn, void waitForTurn(const uint32_t turn,
Atom<uint32_t>& spinCutoff, Atom<uint32_t>& spinCutoff,
const bool updateSpinCutoff) noexcept { const bool updateSpinCutoff) noexcept {
bool success = tryWaitForTurn(turn, spinCutoff, updateSpinCutoff); const auto ret = tryWaitForTurn(turn, spinCutoff, updateSpinCutoff);
(void)success; DCHECK(ret == TryWaitResult::SUCCESS);
assert(success);
} }
// Internally we always work with shifted turn values, which makes the // Internally we always work with shifted turn values, which makes the
...@@ -98,12 +100,14 @@ struct TurnSequencer { ...@@ -98,12 +100,14 @@ struct TurnSequencer {
/// updateSpinCutoff is true then this will spin for up to kMaxSpins tries /// updateSpinCutoff is true then this will spin for up to kMaxSpins tries
/// before blocking and will adjust spinCutoff based on the results, /// before blocking and will adjust spinCutoff based on the results,
/// otherwise it will spin for at most spinCutoff spins. /// otherwise it will spin for at most spinCutoff spins.
/// Returns true if the wait succeeded, false if the turn is in the past /// Returns SUCCESS if the wait succeeded, PAST if the turn is in the past
/// or the absTime time value is not nullptr and is reached before the turn /// or TIMEDOUT if the absTime time value is not nullptr and is reached before
/// arrives /// the turn arrives
template <class Clock = std::chrono::steady_clock, template <
class Clock = std::chrono::steady_clock,
class Duration = typename Clock::duration> class Duration = typename Clock::duration>
bool tryWaitForTurn(const uint32_t turn, TryWaitResult tryWaitForTurn(
const uint32_t turn,
Atom<uint32_t>& spinCutoff, Atom<uint32_t>& spinCutoff,
const bool updateSpinCutoff, const bool updateSpinCutoff,
const std::chrono::time_point<Clock, Duration>* absTime = const std::chrono::time_point<Clock, Duration>* absTime =
...@@ -124,7 +128,7 @@ struct TurnSequencer { ...@@ -124,7 +128,7 @@ struct TurnSequencer {
// wrap-safe version of (current_sturn >= sturn) // wrap-safe version of (current_sturn >= sturn)
if(sturn - current_sturn >= std::numeric_limits<uint32_t>::max() / 2) { if(sturn - current_sturn >= std::numeric_limits<uint32_t>::max() / 2) {
// turn is in the past // turn is in the past
return false; return TryWaitResult::PAST;
} }
// the first effectSpinCutoff tries are spins, after that we will // the first effectSpinCutoff tries are spins, after that we will
...@@ -152,7 +156,7 @@ struct TurnSequencer { ...@@ -152,7 +156,7 @@ struct TurnSequencer {
auto futexResult = auto futexResult =
state_.futexWaitUntil(new_state, *absTime, futexChannel(turn)); state_.futexWaitUntil(new_state, *absTime, futexChannel(turn));
if (futexResult == FutexResult::TIMEDOUT) { if (futexResult == FutexResult::TIMEDOUT) {
return false; return TryWaitResult::TIMEDOUT;
} }
} else { } else {
state_.futexWait(new_state, futexChannel(turn)); state_.futexWait(new_state, futexChannel(turn));
...@@ -184,14 +188,14 @@ struct TurnSequencer { ...@@ -184,14 +188,14 @@ struct TurnSequencer {
} }
} }
return true; return TryWaitResult::SUCCESS;
} }
/// Unblocks a thread running waitForTurn(turn + 1) /// Unblocks a thread running waitForTurn(turn + 1)
void completeTurn(const uint32_t turn) noexcept { void completeTurn(const uint32_t turn) noexcept {
uint32_t state = state_.load(std::memory_order_acquire); uint32_t state = state_.load(std::memory_order_acquire);
while (true) { while (true) {
assert(state == encode(turn << kTurnShift, decodeMaxWaitersDelta(state))); DCHECK(state == encode(turn << kTurnShift, decodeMaxWaitersDelta(state)));
uint32_t max_waiter_delta = decodeMaxWaitersDelta(state); uint32_t max_waiter_delta = decodeMaxWaitersDelta(state);
uint32_t new_state = uint32_t new_state =
encode((turn + 1) << kTurnShift, encode((turn + 1) << kTurnShift,
......
...@@ -204,7 +204,8 @@ public: ...@@ -204,7 +204,8 @@ public:
bool waitAndTryRead(T& dest, uint32_t turn) noexcept { bool waitAndTryRead(T& dest, uint32_t turn) noexcept {
uint32_t desired_turn = (turn + 1) * 2; uint32_t desired_turn = (turn + 1) * 2;
Atom<uint32_t> cutoff(0); Atom<uint32_t> cutoff(0);
if(!sequencer_.tryWaitForTurn(desired_turn, cutoff, false)) { if (sequencer_.tryWaitForTurn(desired_turn, cutoff, false) !=
TurnSequencer<Atom>::TryWaitResult::SUCCESS) {
return false; return false;
} }
memcpy(&dest, &data, sizeof(T)); memcpy(&dest, &data, sizeof(T));
......
...@@ -271,6 +271,13 @@ struct custom_stop_watch { ...@@ -271,6 +271,13 @@ struct custom_stop_watch {
return true; return true;
} }
/**
* Returns the current checkpoint
*/
typename clock_type::time_point getCheckpoint() const {
return checkpoint_;
}
private: private:
typename clock_type::time_point checkpoint_; typename clock_type::time_point checkpoint_;
}; };
......
...@@ -14,18 +14,20 @@ ...@@ -14,18 +14,20 @@
* limitations under the License. * limitations under the License.
*/ */
#include <folly/MPMCQueue.h>
#include <folly/Format.h> #include <folly/Format.h>
#include <folly/MPMCQueue.h>
#include <folly/Memory.h> #include <folly/Memory.h>
#include <folly/portability/GTest.h> #include <folly/portability/GTest.h>
#include <folly/portability/SysResource.h> #include <folly/portability/SysResource.h>
#include <folly/portability/SysTime.h> #include <folly/portability/SysTime.h>
#include <folly/portability/Unistd.h> #include <folly/portability/Unistd.h>
#include <folly/stop_watch.h>
#include <folly/test/DeterministicSchedule.h> #include <folly/test/DeterministicSchedule.h>
#include <boost/intrusive_ptr.hpp> #include <boost/intrusive_ptr.hpp>
#include <memory> #include <boost/thread/barrier.hpp>
#include <functional> #include <functional>
#include <memory>
#include <thread> #include <thread>
#include <utility> #include <utility>
...@@ -1158,3 +1160,92 @@ TEST(MPMCQueue, explicit_zero_capacity_fail) { ...@@ -1158,3 +1160,92 @@ TEST(MPMCQueue, explicit_zero_capacity_fail) {
using DynamicMPMCQueueInt = MPMCQueue<int, std::atomic, true>; using DynamicMPMCQueueInt = MPMCQueue<int, std::atomic, true>;
ASSERT_THROW(DynamicMPMCQueueInt cq(0), std::invalid_argument); ASSERT_THROW(DynamicMPMCQueueInt cq(0), std::invalid_argument);
} }
template <bool Dynamic>
void testTryReadUntil() {
MPMCQueue<int, std::atomic, Dynamic> q{1};
const auto wait = std::chrono::milliseconds(100);
stop_watch<> watch;
bool rets[2];
int vals[2];
std::vector<std::thread> threads;
boost::barrier b{3};
for (int i = 0; i < 2; i++) {
threads.emplace_back([&, i] {
b.wait();
rets[i] = q.tryReadUntil(watch.getCheckpoint() + wait, vals[i]);
});
}
b.wait();
EXPECT_TRUE(q.write(42));
for (int i = 0; i < 2; i++) {
threads[i].join();
}
for (int i = 0; i < 2; i++) {
int other = (i + 1) % 2;
if (rets[i]) {
EXPECT_EQ(42, vals[i]);
EXPECT_FALSE(rets[other]);
}
}
EXPECT_TRUE(watch.elapsed(wait));
}
template <bool Dynamic>
void testTryWriteUntil() {
MPMCQueue<int, std::atomic, Dynamic> q{1};
EXPECT_TRUE(q.write(42));
const auto wait = std::chrono::milliseconds(100);
stop_watch<> watch;
bool rets[2];
std::vector<std::thread> threads;
boost::barrier b{3};
for (int i = 0; i < 2; i++) {
threads.emplace_back([&, i] {
b.wait();
rets[i] = q.tryWriteUntil(watch.getCheckpoint() + wait, i);
});
}
b.wait();
int x;
EXPECT_TRUE(q.read(x));
EXPECT_EQ(42, x);
for (int i = 0; i < 2; i++) {
threads[i].join();
}
EXPECT_TRUE(q.read(x));
for (int i = 0; i < 2; i++) {
int other = (i + 1) % 2;
if (rets[i]) {
EXPECT_EQ(i, x);
EXPECT_FALSE(rets[other]);
}
}
EXPECT_TRUE(watch.elapsed(wait));
}
TEST(MPMCQueue, try_read_until) {
testTryReadUntil<false>();
}
TEST(MPMCQueue, try_read_until_dynamic) {
testTryReadUntil<true>();
}
TEST(MPMCQueue, try_write_until) {
testTryWriteUntil<false>();
}
TEST(MPMCQueue, try_write_until_dynamic) {
testTryWriteUntil<true>();
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment