Commit 04eafeb6 authored by Pranjal Raihan's avatar Pranjal Raihan Committed by Facebook GitHub Bot

Add API to AsyncServerSocket that allows potentially stale connections to be dropped

Summary:
Added a timestamp to `AsyncServerSocket::QueueMessage` so that `NotificationQueue` can ignore new connection messages which are deemed *expired*. Expired messages represent sockets which have *probably* timed out already.

The TTL is configured per `AsyncServerSocket` instance and is applied to all future messages that are queued by it. By default, messages do not expire. This can be configured with `AsyncServerSocket::setQueueTimeout`.

Reviewed By: andriigrynenko

Differential Revision: D24667870

fbshipit-source-id: 0f9d6c235627393d964e280a0d5956676010c7aa
parent d366ea8f
...@@ -76,8 +76,16 @@ void AsyncServerSocket::RemoteAcceptor::stop( ...@@ -76,8 +76,16 @@ void AsyncServerSocket::RemoteAcceptor::stop(
}); });
} }
void AsyncServerSocket::RemoteAcceptor::Consumer::operator()( AtomicNotificationQueueTaskStatus AsyncServerSocket::RemoteAcceptor::Consumer::
QueueMessage&& msg) noexcept { operator()(QueueMessage&& msg) noexcept {
if (msg.isExpired()) {
closeNoInt(msg.fd);
if (acceptor_.connectionEventCallback_) {
acceptor_.connectionEventCallback_->onConnectionDropped(
msg.fd, msg.address);
}
return AtomicNotificationQueueTaskStatus::DISCARD;
}
switch (msg.type) { switch (msg.type) {
case MessageType::MSG_NEW_CONN: { case MessageType::MSG_NEW_CONN: {
if (acceptor_.connectionEventCallback_) { if (acceptor_.connectionEventCallback_) {
...@@ -100,6 +108,7 @@ void AsyncServerSocket::RemoteAcceptor::Consumer::operator()( ...@@ -100,6 +108,7 @@ void AsyncServerSocket::RemoteAcceptor::Consumer::operator()(
acceptor_.callback_->acceptError(ex); acceptor_.callback_->acceptError(ex);
} }
} }
return AtomicNotificationQueueTaskStatus::CONSUMED;
} }
/* /*
...@@ -1012,6 +1021,9 @@ void AsyncServerSocket::dispatchSocket( ...@@ -1012,6 +1021,9 @@ void AsyncServerSocket::dispatchSocket(
msg.type = MessageType::MSG_NEW_CONN; msg.type = MessageType::MSG_NEW_CONN;
msg.address = std::move(address); msg.address = std::move(address);
msg.fd = socket; msg.fd = socket;
if (queueTimeout_.count() != 0) {
msg.deadline = std::chrono::steady_clock::now() + queueTimeout_;
}
// Loop until we find a free queue to write to // Loop until we find a free queue to write to
while (true) { while (true) {
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <limits.h> #include <limits.h>
#include <stddef.h> #include <stddef.h>
#include <chrono>
#include <exception> #include <exception>
#include <memory> #include <memory>
#include <vector> #include <vector>
...@@ -540,6 +541,26 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -540,6 +541,26 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
*/ */
void setMaxAcceptAtOnce(uint32_t numConns) { maxAcceptAtOnce_ = numConns; } void setMaxAcceptAtOnce(uint32_t numConns) { maxAcceptAtOnce_ = numConns; }
/**
* Get the duration after which new connection messages will be dropped from
* the NotificationQueue if it has not started processing yet.
*/
std::chrono::nanoseconds getQueueTimeout() const { return queueTimeout_; }
/**
* Set the duration after which new connection messages will be dropped from
* the NotificationQueue if it has not started processing yet.
*
* This avoids the NotificationQueue from processing messages where the client
* socket has probably timed out already, or will time out before a response
* can be sent.
*
* The default value (of 0) means that messages will never expire.
*/
void setQueueTimeout(std::chrono::nanoseconds duration) {
queueTimeout_ = duration;
}
/** /**
* Get the maximum number of unprocessed messages which a NotificationQueue * Get the maximum number of unprocessed messages which a NotificationQueue
* can hold. * can hold.
...@@ -723,6 +744,12 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -723,6 +744,12 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
int err; int err;
SocketAddress address; SocketAddress address;
std::string msg; std::string msg;
std::chrono::steady_clock::time_point deadline;
bool isExpired() const {
return deadline.time_since_epoch().count() != 0 &&
std::chrono::steady_clock::now() > deadline;
}
}; };
/** /**
...@@ -736,7 +763,7 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -736,7 +763,7 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
*/ */
class RemoteAcceptor { class RemoteAcceptor {
struct Consumer { struct Consumer {
void operator()(QueueMessage&& msg) noexcept; AtomicNotificationQueueTaskStatus operator()(QueueMessage&& msg) noexcept;
explicit Consumer(RemoteAcceptor& acceptor) : acceptor_(acceptor) {} explicit Consumer(RemoteAcceptor& acceptor) : acceptor_(acceptor) {}
RemoteAcceptor& acceptor_; RemoteAcceptor& acceptor_;
...@@ -871,6 +898,7 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -871,6 +898,7 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
ConnectionEventCallback* connectionEventCallback_{nullptr}; ConnectionEventCallback* connectionEventCallback_{nullptr};
bool tosReflect_{false}; bool tosReflect_{false};
bool zeroCopyVal_{false}; bool zeroCopyVal_{false};
std::chrono::nanoseconds queueTimeout_{0};
}; };
} // namespace folly } // namespace folly
...@@ -4196,3 +4196,55 @@ TEST(AsyncSocketTest, getBufInUse) { ...@@ -4196,3 +4196,55 @@ TEST(AsyncSocketTest, getBufInUse) {
EXPECT_GT(sendBufSize, 0); EXPECT_GT(sendBufSize, 0);
} }
#endif #endif
TEST(AsyncSocketTest, ConnectionExpiry) {
// Create a new AsyncServerSocket
EventBase eventBase;
std::shared_ptr<AsyncServerSocket> serverSocket(
AsyncServerSocket::newSocket(&eventBase));
serverSocket->bind(0);
serverSocket->listen(16);
folly::SocketAddress serverAddress;
serverSocket->getAddress(&serverAddress);
constexpr auto kConnectionExpiryDuration = milliseconds(10);
serverSocket->setQueueTimeout(kConnectionExpiryDuration);
ScopedEventBaseThread acceptThread("ioworker_test");
TestAcceptCallback acceptCb;
acceptCb.setConnectionAcceptedFn(
[&, called = false](auto&&...) mutable {
ASSERT_FALSE(called)
<< "Only the first connection should have been dequeued";
called = true;
// Allow plenty of time for the AsyncSocketServer's event loop to run.
// This should leave no doubt that the acceptor thread has enough time
// to dequeue. If the dequeue succeeds, then our expiry code is broken.
constexpr auto kEventLoopTime = kConnectionExpiryDuration * 5;
eventBase.runInEventBaseThread([&]() {
eventBase.tryRunAfterDelay(
[&]() { serverSocket->removeAcceptCallback(&acceptCb, nullptr); },
milliseconds(kEventLoopTime).count());
});
// After the first message is enqueued, sleep long enough so that the
// second message expires before it has a chance to dequeue.
std::this_thread::sleep_for(kConnectionExpiryDuration);
});
TestConnectionEventCallback connectionEventCb;
serverSocket->setConnectionEventCallback(&connectionEventCb);
serverSocket->addAcceptCallback(&acceptCb, acceptThread.getEventBase());
serverSocket->startAccepting();
std::shared_ptr<AsyncSocket> clientSocket1(
AsyncSocket::newSocket(&eventBase, serverAddress));
std::shared_ptr<AsyncSocket> clientSocket2(
AsyncSocket::newSocket(&eventBase, serverAddress));
// Loop until we are stopped
eventBase.loop();
EXPECT_EQ(connectionEventCb.getConnectionEnqueuedForAcceptCallback(), 2);
// Since the second message is expired, it should NOT be dequeued
EXPECT_EQ(connectionEventCb.getConnectionDequeuedByAcceptCallback(), 1);
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment