Commit a73640d7 authored by Andrew Smith's avatar Andrew Smith Committed by Facebook GitHub Bot

Add optional parameter to consume() and cancel()

Summary:
This diff adds an optional parameter to consume() and cancel(). This parameter will allow callback objects to know the source bridge that is calling them back. With this parameter, we will not need to create a separate callback object for each bridge (which saves memory).

The parameter is optional. If the parameter type is set to void, no parameter will be passed. This preserves backward compatibility with existing users of AtomicQueue that don't need a parameter.

Reviewed By: iahs

Differential Revision: D28550867

fbshipit-source-id: e7d998538c880c2a5c7649d3262cb7f8913e1439
parent 187d8422
......@@ -100,7 +100,8 @@ class AtomicQueue {
AtomicQueue(const AtomicQueue&) = delete;
AtomicQueue& operator=(const AtomicQueue&) = delete;
void push(Message&& value) {
template <typename... ConsumerArgs>
void push(Message&& value, ConsumerArgs&&... consumerArgs) {
std::unique_ptr<typename MessageQueue::Node> node(
new typename MessageQueue::Node(std::move(value)));
assert(!(reinterpret_cast<intptr_t>(node.get()) & kTypeMask));
......@@ -135,7 +136,7 @@ class AtomicQueue {
std::memory_order_relaxed)) {
node.release();
auto consumer = reinterpret_cast<Consumer*>(ptr);
consumer->consume();
consumer->consume(std::forward<ConsumerArgs>(consumerArgs)...);
return;
}
break;
......@@ -145,7 +146,8 @@ class AtomicQueue {
}
}
bool wait(Consumer* consumer) {
template <typename... ConsumerArgs>
bool wait(Consumer* consumer, ConsumerArgs&&... consumerArgs) {
assert(!(reinterpret_cast<intptr_t>(consumer) & kTypeMask));
auto storage = storage_.load(std::memory_order_relaxed);
while (true) {
......@@ -162,7 +164,7 @@ class AtomicQueue {
}
break;
case Type::CLOSED:
consumer->canceled();
consumer->canceled(std::forward<ConsumerArgs>(consumerArgs)...);
return true;
case Type::TAIL:
return false;
......@@ -173,7 +175,8 @@ class AtomicQueue {
}
}
void close() {
template <typename... ConsumerArgs>
void close(ConsumerArgs&&... consumerArgs) {
auto storage = storage_.exchange(
static_cast<intptr_t>(Type::CLOSED), std::memory_order_acquire);
auto type = static_cast<Type>(storage & kTypeMask);
......@@ -186,7 +189,8 @@ class AtomicQueue {
reinterpret_cast<typename MessageQueue::Node*>(ptr));
return;
case Type::CONSUMER:
reinterpret_cast<Consumer*>(ptr)->canceled();
reinterpret_cast<Consumer*>(ptr)->canceled(
std::forward<ConsumerArgs>(consumerArgs)...);
return;
case Type::CLOSED:
default:
......@@ -199,7 +203,8 @@ class AtomicQueue {
return type == Type::CLOSED;
}
MessageQueue getMessages() {
template <typename... ConsumerArgs>
MessageQueue getMessages(ConsumerArgs&&... consumerArgs) {
auto storage = storage_.exchange(
static_cast<intptr_t>(Type::EMPTY), std::memory_order_acquire);
auto type = static_cast<Type>(storage & kTypeMask);
......@@ -214,7 +219,7 @@ class AtomicQueue {
// We accidentally re-opened the queue, so close it again.
// This is only safe to do because isClosed() can't be called
// concurrently with getMessages().
close();
close(std::forward<ConsumerArgs>(consumerArgs)...);
return MessageQueue();
case Type::CONSUMER:
default:
......
......@@ -24,13 +24,20 @@ namespace folly {
namespace channels {
namespace detail {
static int* getConsumerParam() {
return reinterpret_cast<int*>(1);
}
TEST(AtomicQueueTest, Basic) {
folly::Baton<> producerBaton;
folly::Baton<> consumerBaton;
struct Consumer {
void consume() { baton.post(); }
void canceled() { ADD_FAILURE() << "canceled() shouldn't be called"; }
void consume(int* consumerParam) {
EXPECT_EQ(consumerParam, getConsumerParam());
baton.post();
}
void canceled(int*) { ADD_FAILURE() << "canceled() shouldn't be called"; }
folly::Baton<> baton;
};
AtomicQueue<Consumer, int> atomicQueue;
......@@ -40,23 +47,23 @@ TEST(AtomicQueueTest, Basic) {
producerBaton.wait();
producerBaton.reset();
atomicQueue.push(1);
atomicQueue.push(1, getConsumerParam());
producerBaton.wait();
producerBaton.reset();
atomicQueue.push(2);
atomicQueue.push(3);
atomicQueue.push(2, getConsumerParam());
atomicQueue.push(3, getConsumerParam());
consumerBaton.post();
});
EXPECT_TRUE(atomicQueue.wait(&consumer));
EXPECT_TRUE(atomicQueue.wait(&consumer, getConsumerParam()));
producerBaton.post();
consumer.baton.wait();
consumer.baton.reset();
{
auto q = atomicQueue.getMessages();
auto q = atomicQueue.getMessages(getConsumerParam());
EXPECT_FALSE(q.empty());
EXPECT_EQ(1, q.front());
q.pop();
......@@ -67,9 +74,9 @@ TEST(AtomicQueueTest, Basic) {
consumerBaton.wait();
consumerBaton.reset();
EXPECT_FALSE(atomicQueue.wait(&consumer));
EXPECT_FALSE(atomicQueue.wait(&consumer, getConsumerParam()));
{
auto q = atomicQueue.getMessages();
auto q = atomicQueue.getMessages(getConsumerParam());
EXPECT_FALSE(q.empty());
EXPECT_EQ(2, q.front());
q.pop();
......@@ -79,10 +86,10 @@ TEST(AtomicQueueTest, Basic) {
EXPECT_TRUE(q.empty());
}
EXPECT_TRUE(atomicQueue.wait(&consumer));
EXPECT_TRUE(atomicQueue.wait(&consumer, getConsumerParam()));
EXPECT_EQ(atomicQueue.cancelCallback(), &consumer);
EXPECT_TRUE(atomicQueue.wait(&consumer));
EXPECT_TRUE(atomicQueue.wait(&consumer, getConsumerParam()));
EXPECT_EQ(atomicQueue.cancelCallback(), &consumer);
EXPECT_EQ(atomicQueue.cancelCallback(), nullptr);
......@@ -92,41 +99,47 @@ TEST(AtomicQueueTest, Basic) {
TEST(AtomicQueueTest, Canceled) {
struct Consumer {
void consume() { ADD_FAILURE() << "consume() shouldn't be called"; }
void canceled() { canceledCalled = true; }
void consume(int*) { ADD_FAILURE() << "consume() shouldn't be called"; }
void canceled(int* consumerParam) {
EXPECT_EQ(consumerParam, getConsumerParam());
canceledCalled = true;
}
bool canceledCalled{false};
};
AtomicQueue<Consumer, int> atomicQueue;
Consumer consumer;
EXPECT_TRUE(atomicQueue.wait(&consumer));
atomicQueue.close();
EXPECT_TRUE(atomicQueue.wait(&consumer, getConsumerParam()));
atomicQueue.close(getConsumerParam());
EXPECT_TRUE(consumer.canceledCalled);
EXPECT_TRUE(atomicQueue.isClosed());
EXPECT_TRUE(atomicQueue.getMessages().empty());
EXPECT_TRUE(atomicQueue.getMessages(getConsumerParam()).empty());
EXPECT_TRUE(atomicQueue.isClosed());
atomicQueue.push(42);
atomicQueue.push(42, getConsumerParam());
EXPECT_TRUE(atomicQueue.getMessages().empty());
EXPECT_TRUE(atomicQueue.getMessages(getConsumerParam()).empty());
EXPECT_TRUE(atomicQueue.isClosed());
}
TEST(AtomicQueueTest, Stress) {
struct Consumer {
void consume() { baton.post(); }
void canceled() { ADD_FAILURE() << "canceled() shouldn't be called"; }
void consume(int* consumerParam) {
EXPECT_EQ(consumerParam, getConsumerParam());
baton.post();
}
void canceled(int*) { ADD_FAILURE() << "canceled() shouldn't be called"; }
folly::Baton<> baton;
};
AtomicQueue<Consumer, int> atomicQueue;
auto getNext = [&atomicQueue, queue = Queue<int>()]() mutable {
Consumer consumer;
if (queue.empty()) {
if (atomicQueue.wait(&consumer)) {
if (atomicQueue.wait(&consumer, getConsumerParam())) {
consumer.baton.wait();
}
queue = atomicQueue.getMessages();
queue = atomicQueue.getMessages(getConsumerParam());
EXPECT_FALSE(queue.empty());
}
auto next = queue.front();
......@@ -142,7 +155,7 @@ TEST(AtomicQueueTest, Stress) {
std::thread producerThread([&] {
for (producerIndex = 1; producerIndex <= kNumIters; ++producerIndex) {
atomicQueue.push(producerIndex);
atomicQueue.push(producerIndex, getConsumerParam());
if (producerIndex % kSynchronizeEvery == 0) {
while (producerIndex > consumerIndex.load(std::memory_order_relaxed)) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment