Commit 6d3bf646 authored by Mohammad Husain's avatar Mohammad Husain Committed by facebook-github-bot-4

Add connection event callback to AsyncServerSocket

Summary: Adding a callback to AsyncServerSocket to get notified of client connection events. This can be used for example to record stats about these events.

Reviewed By: @afrind

Differential Revision: D2544776

fb-gh-sync-id: 20d22cfc939c5b937abec2b600c10b7228923ff3
parent dd631eb7
......@@ -91,6 +91,10 @@ void AsyncServerSocket::RemoteAcceptor::messageAvailable(
switch (msg.type) {
case MessageType::MSG_NEW_CONN:
{
if (connectionEventCallback_) {
connectionEventCallback_->onConnectionDequeuedByAcceptCallback(
msg.fd, msg.address);
}
callback_->connectionAccepted(msg.fd, msg.address);
break;
}
......@@ -515,7 +519,7 @@ void AsyncServerSocket::addAcceptCallback(AcceptCallback *callback,
// callback more efficiently without having to use a notification queue.
RemoteAcceptor* acceptor = nullptr;
try {
acceptor = new RemoteAcceptor(callback);
acceptor = new RemoteAcceptor(callback, connectionEventCallback_);
acceptor->start(eventBase, maxAtOnce, maxNumMsgsInQueue_);
} catch (...) {
callbacks_.pop_back();
......@@ -722,6 +726,10 @@ void AsyncServerSocket::handlerReady(
address.setFromSockaddr(saddr, addrLen);
if (clientSocket >= 0 && connectionEventCallback_) {
connectionEventCallback_->onConnectionAccepted(clientSocket, address);
}
std::chrono::time_point<std::chrono::steady_clock> nowMs =
std::chrono::steady_clock::now();
auto timeSinceLastAccept = std::max<int64_t>(
......@@ -737,6 +745,10 @@ void AsyncServerSocket::handlerReady(
++numDroppedConnections_;
if (clientSocket >= 0) {
closeNoInt(clientSocket);
if (connectionEventCallback_) {
connectionEventCallback_->onConnectionDropped(clientSocket,
address);
}
}
continue;
}
......@@ -760,6 +772,9 @@ void AsyncServerSocket::handlerReady(
} else {
dispatchError("accept() failed", errno);
}
if (connectionEventCallback_) {
connectionEventCallback_->onConnectionAcceptError(errno);
}
return;
}
......@@ -769,6 +784,9 @@ void AsyncServerSocket::handlerReady(
closeNoInt(clientSocket);
dispatchError("failed to set accepted socket to non-blocking mode",
errno);
if (connectionEventCallback_) {
connectionEventCallback_->onConnectionDropped(clientSocket, address);
}
return;
}
#endif
......@@ -795,6 +813,7 @@ void AsyncServerSocket::dispatchSocket(int socket,
return;
}
const SocketAddress addr(address);
// Create a message to send over the notification queue
QueueMessage msg;
msg.type = MessageType::MSG_NEW_CONN;
......@@ -804,9 +823,13 @@ void AsyncServerSocket::dispatchSocket(int socket,
// Loop until we find a free queue to write to
while (true) {
if (info->consumer->getQueue()->tryPutMessageNoThrow(std::move(msg))) {
if (connectionEventCallback_) {
connectionEventCallback_->onConnectionEnqueuedForAcceptCallback(socket,
addr);
}
// Success! return.
return;
}
}
// We couldn't add to queue. Fall through to below
......@@ -831,6 +854,9 @@ void AsyncServerSocket::dispatchSocket(int socket,
LOG(ERROR) << "failed to dispatch newly accepted socket:"
<< " all accept callback queues are full";
closeNoInt(socket);
if (connectionEventCallback_) {
connectionEventCallback_->onConnectionDropped(socket, addr);
}
return;
}
......@@ -886,6 +912,9 @@ void AsyncServerSocket::enterBackoff() {
// since we won't be able to re-enable ourselves later.
LOG(ERROR) << "failed to allocate AsyncServerSocket backoff"
<< " timer; unable to temporarly pause accepting";
if (connectionEventCallback_) {
connectionEventCallback_->onBackoffError();
}
return;
}
}
......@@ -903,6 +932,9 @@ void AsyncServerSocket::enterBackoff() {
if (!backoffTimeout_->scheduleTimeout(timeoutMS)) {
LOG(ERROR) << "failed to schedule AsyncServerSocket backoff timer;"
<< "unable to temporarly pause accepting";
if (connectionEventCallback_) {
connectionEventCallback_->onBackoffError();
}
return;
}
......@@ -912,6 +944,9 @@ void AsyncServerSocket::enterBackoff() {
for (auto& handler : sockets_) {
handler.unregisterHandler();
}
if (connectionEventCallback_) {
connectionEventCallback_->onBackoffStarted();
}
}
void AsyncServerSocket::backoffTimeoutExpired() {
......@@ -924,6 +959,9 @@ void AsyncServerSocket::backoffTimeoutExpired() {
// If all of the callbacks were removed, we shouldn't re-enable accepts
if (callbacks_.empty()) {
if (connectionEventCallback_) {
connectionEventCallback_->onBackoffEnded();
}
return;
}
......@@ -942,6 +980,9 @@ void AsyncServerSocket::backoffTimeoutExpired() {
abort();
}
}
if (connectionEventCallback_) {
connectionEventCallback_->onBackoffEnded();
}
}
......
......@@ -64,6 +64,71 @@ class AsyncServerSocket : public DelayedDestruction
// Disallow copy, move, and default construction.
AsyncServerSocket(AsyncServerSocket&&) = delete;
/**
* A callback interface to get notified of client socket events.
*
* The ConnectionEventCallback implementations need to be thread-safe as the
* callbacks may be called from different threads.
*/
class ConnectionEventCallback {
public:
virtual ~ConnectionEventCallback() = default;
/**
* onConnectionAccepted() is called right after a client connection
* is accepted using the system accept()/accept4() APIs.
*/
virtual void onConnectionAccepted(const int socket,
const SocketAddress& addr) noexcept = 0;
/**
* onConnectionAcceptError() is called when an error occurred accepting
* a connection.
*/
virtual void onConnectionAcceptError(const int err) noexcept = 0;
/**
* onConnectionDropped() is called when a connection is dropped,
* probably because of some error encountered.
*/
virtual void onConnectionDropped(const int socket,
const SocketAddress& addr) noexcept = 0;
/**
* onConnectionEnqueuedForAcceptCallback() is called when the
* connection is successfully enqueued for an AcceptCallback to pick up.
*/
virtual void onConnectionEnqueuedForAcceptCallback(
const int socket,
const SocketAddress& addr) noexcept = 0;
/**
* onConnectionDequeuedByAcceptCallback() is called when the
* connection is successfully dequeued by an AcceptCallback.
*/
virtual void onConnectionDequeuedByAcceptCallback(
const int socket,
const SocketAddress& addr) noexcept = 0;
/**
* onBackoffStarted is called when the socket has successfully started
* backing off accepting new client sockets.
*/
virtual void onBackoffStarted() noexcept = 0;
/**
* onBackoffEnded is called when the backoff period has ended and the socket
* has successfully resumed accepting new connections if there is any
* AcceptCallback registered.
*/
virtual void onBackoffEnded() noexcept = 0;
/**
* onBackoffError is called when there is an error entering backoff
*/
virtual void onBackoffError() noexcept = 0;
};
class AcceptCallback {
public:
virtual ~AcceptCallback() = default;
......@@ -320,8 +385,8 @@ class AsyncServerSocket : public DelayedDestruction
*
* When a new socket is accepted, one of the AcceptCallbacks will be invoked
* with the new socket. The AcceptCallbacks are invoked in a round-robin
* fashion. This allows the accepted sockets to distributed among a pool of
* threads, each running its own EventBase object. This is a common model,
* fashion. This allows the accepted sockets to be distributed among a pool
* of threads, each running its own EventBase object. This is a common model,
* since most asynchronous-style servers typically run one EventBase thread
* per CPU.
*
......@@ -584,6 +649,21 @@ class AsyncServerSocket : public DelayedDestruction
return accepting_;
}
/**
* Set the ConnectionEventCallback
*/
void setConnectionEventCallback(
ConnectionEventCallback* const connectionEventCallback) {
connectionEventCallback_ = connectionEventCallback;
}
/**
* Get the ConnectionEventCallback
*/
ConnectionEventCallback* getConnectionEventCallback() const {
return connectionEventCallback_;
}
protected:
/**
* Protected destructor.
......@@ -618,8 +698,10 @@ class AsyncServerSocket : public DelayedDestruction
class RemoteAcceptor
: private NotificationQueue<QueueMessage>::Consumer {
public:
explicit RemoteAcceptor(AcceptCallback *callback)
: callback_(callback) {}
explicit RemoteAcceptor(AcceptCallback *callback,
ConnectionEventCallback *connectionEventCallback)
: callback_(callback),
connectionEventCallback_(connectionEventCallback) {}
~RemoteAcceptor() = default;
......@@ -634,6 +716,7 @@ class AsyncServerSocket : public DelayedDestruction
private:
AcceptCallback *callback_;
ConnectionEventCallback* connectionEventCallback_;
NotificationQueue<QueueMessage> queue_;
};
......@@ -738,6 +821,7 @@ class AsyncServerSocket : public DelayedDestruction
bool reusePortEnabled_{false};
bool closeOnExec_;
ShutdownSocketSet* shutdownSocketSet_;
ConnectionEventCallback* connectionEventCallback_{nullptr};
};
} // folly
......@@ -17,6 +17,7 @@
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/EventBase.h>
#include <folly/RWSpinLock.h>
#include <folly/SocketAddress.h>
#include <folly/io/IOBuf.h>
......@@ -1452,6 +1453,113 @@ TEST(AsyncSocket, ConnectReadUninstallRead) {
///////////////////////////////////////////////////////////////////////////
// AsyncServerSocket tests
///////////////////////////////////////////////////////////////////////////
namespace {
/**
* Helper ConnectionEventCallback class for the test code.
* It maintains counters protected by a spin lock.
*/
class TestConnectionEventCallback :
public AsyncServerSocket::ConnectionEventCallback {
public:
virtual void onConnectionAccepted(
const int socket,
const SocketAddress& addr) noexcept override {
folly::RWSpinLock::WriteHolder holder(spinLock_);
connectionAccepted_++;
}
virtual void onConnectionAcceptError(const int err) noexcept override {
folly::RWSpinLock::WriteHolder holder(spinLock_);
connectionAcceptedError_++;
}
virtual void onConnectionDropped(
const int socket,
const SocketAddress& addr) noexcept override {
folly::RWSpinLock::WriteHolder holder(spinLock_);
connectionDropped_++;
}
virtual void onConnectionEnqueuedForAcceptCallback(
const int socket,
const SocketAddress& addr) noexcept override {
folly::RWSpinLock::WriteHolder holder(spinLock_);
connectionEnqueuedForAcceptCallback_++;
}
virtual void onConnectionDequeuedByAcceptCallback(
const int socket,
const SocketAddress& addr) noexcept override {
folly::RWSpinLock::WriteHolder holder(spinLock_);
connectionDequeuedByAcceptCallback_++;
}
virtual void onBackoffStarted() noexcept override {
folly::RWSpinLock::WriteHolder holder(spinLock_);
backoffStarted_++;
}
virtual void onBackoffEnded() noexcept override {
folly::RWSpinLock::WriteHolder holder(spinLock_);
backoffEnded_++;
}
virtual void onBackoffError() noexcept override {
folly::RWSpinLock::WriteHolder holder(spinLock_);
backoffError_++;
}
unsigned int getConnectionAccepted() const {
folly::RWSpinLock::ReadHolder holder(spinLock_);
return connectionAccepted_;
}
unsigned int getConnectionAcceptedError() const {
folly::RWSpinLock::ReadHolder holder(spinLock_);
return connectionAcceptedError_;
}
unsigned int getConnectionDropped() const {
folly::RWSpinLock::ReadHolder holder(spinLock_);
return connectionDropped_;
}
unsigned int getConnectionEnqueuedForAcceptCallback() const {
folly::RWSpinLock::ReadHolder holder(spinLock_);
return connectionEnqueuedForAcceptCallback_;
}
unsigned int getConnectionDequeuedByAcceptCallback() const {
folly::RWSpinLock::ReadHolder holder(spinLock_);
return connectionDequeuedByAcceptCallback_;
}
unsigned int getBackoffStarted() const {
folly::RWSpinLock::ReadHolder holder(spinLock_);
return backoffStarted_;
}
unsigned int getBackoffEnded() const {
folly::RWSpinLock::ReadHolder holder(spinLock_);
return backoffEnded_;
}
unsigned int getBackoffError() const {
folly::RWSpinLock::ReadHolder holder(spinLock_);
return backoffError_;
}
private:
mutable folly::RWSpinLock spinLock_;
unsigned int connectionAccepted_{0};
unsigned int connectionAcceptedError_{0};
unsigned int connectionDropped_{0};
unsigned int connectionEnqueuedForAcceptCallback_{0};
unsigned int connectionDequeuedByAcceptCallback_{0};
unsigned int backoffStarted_{0};
unsigned int backoffEnded_{0};
unsigned int backoffError_{0};
};
/**
* Helper AcceptCallback class for the test code
......@@ -1552,6 +1660,7 @@ class TestAcceptCallback : public AsyncServerSocket::AcceptCallback {
std::deque<EventInfo> events_;
};
}
/**
* Make sure accepted sockets have O_NONBLOCK and TCP_NODELAY set
......@@ -2043,3 +2152,46 @@ TEST(AsyncSocketTest, UnixDomainSocketTest) {
int flags = fcntl(fd, F_GETFL, 0);
CHECK_EQ(flags & O_NONBLOCK, O_NONBLOCK);
}
TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
EventBase eventBase;
TestConnectionEventCallback connectionEventCallback;
// Create a server socket
std::shared_ptr<AsyncServerSocket> serverSocket(
AsyncServerSocket::newSocket(&eventBase));
serverSocket->setConnectionEventCallback(&connectionEventCallback);
serverSocket->bind(0);
serverSocket->listen(16);
folly::SocketAddress serverAddress;
serverSocket->getAddress(&serverAddress);
// Add a callback to accept one connection then stop the loop
TestAcceptCallback acceptCallback;
acceptCallback.setConnectionAcceptedFn(
[&](int fd, const folly::SocketAddress& addr) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
});
acceptCallback.setAcceptErrorFn([&](const std::exception& ex) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
});
serverSocket->addAcceptCallback(&acceptCallback, nullptr);
serverSocket->startAccepting();
// Connect to the server socket
std::shared_ptr<AsyncSocket> socket(
AsyncSocket::newSocket(&eventBase, serverAddress));
eventBase.loop();
// Validate the connection event counters
ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
ASSERT_EQ(
connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 1);
ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 1);
ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment