Commit 18435bce authored by Subodh Iyengar's avatar Subodh Iyengar Committed by facebook-github-bot-9

Add handshake and connect times

Summary: Add api to get time taken to establish
connections and to complete handshake
for clients using AsyncSocket directly.

Reviewed By: @afrind

Differential Revision: D2435074

fb-gh-sync-id: f44c336e62c426736eb5b3d88dd57a18572382e8
parent ae574fb9
......@@ -356,13 +356,10 @@ void AsyncSSLSocket::closeNow() {
DestructorGuard dg(this);
if (handshakeCallback_) {
AsyncSocketException ex(AsyncSocketException::END_OF_FILE,
"SSL connection closed locally");
HandshakeCB* callback = handshakeCallback_;
handshakeCallback_ = nullptr;
callback->handshakeErr(this, ex);
}
invokeHandshakeErr(
AsyncSocketException(
AsyncSocketException::END_OF_FILE,
"SSL connection closed locally"));
if (ssl_ != nullptr) {
SSL_free(ssl_);
......@@ -468,6 +465,7 @@ void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
AsyncSocketException ex(AsyncSocketException::INVALID_STATE,
"sslAccept() called with socket in invalid state");
handshakeEndTime_ = std::chrono::steady_clock::now();
if (callback) {
callback->handshakeErr(this, ex);
}
......@@ -490,6 +488,9 @@ void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
handshakeCallback_ != nullptr) {
return invalidState(callback);
}
handshakeStartTime_ = std::chrono::steady_clock::now();
// Make end time at least >= start time.
handshakeEndTime_ = handshakeStartTime_;
sslState_ = STATE_ACCEPTING;
handshakeCallback_ = callback;
......@@ -623,20 +624,24 @@ AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) {
void AsyncSSLSocket::failHandshake(const char* fn,
const AsyncSocketException& ex) {
startFail();
if (handshakeTimeout_.isScheduled()) {
handshakeTimeout_.cancelTimeout();
}
invokeHandshakeErr(ex);
finishFail();
}
void AsyncSSLSocket::invokeHandshakeErr(const AsyncSocketException& ex) {
handshakeEndTime_ = std::chrono::steady_clock::now();
if (handshakeCallback_ != nullptr) {
HandshakeCB* callback = handshakeCallback_;
handshakeCallback_ = nullptr;
callback->handshakeErr(this, ex);
}
finishFail();
}
void AsyncSSLSocket::invokeHandshakeCB() {
handshakeEndTime_ = std::chrono::steady_clock::now();
if (handshakeTimeout_.isScheduled()) {
handshakeTimeout_.cancelTimeout();
}
......@@ -691,6 +696,10 @@ void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
return invalidState(callback);
}
handshakeStartTime_ = std::chrono::steady_clock::now();
// Make end time at least >= start time.
handshakeEndTime_ = handshakeStartTime_;
sslState_ = STATE_CONNECTING;
handshakeCallback_ = callback;
......
......@@ -722,6 +722,13 @@ class AsyncSSLSocket : public virtual AsyncSocket {
return clientHelloInfo_.get();
}
/**
* Returns the time taken to complete a handshake.
*/
std::chrono::nanoseconds getHandshakeTime() const {
return handshakeEndTime_ - handshakeStartTime_;
}
void setMinWriteSize(size_t minWriteSize) {
minWriteSize_ = minWriteSize;
}
......@@ -813,6 +820,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
// Inherit error handling methods from AsyncSocket, plus the following.
void failHandshake(const char* fn, const AsyncSocketException& ex);
void invokeHandshakeErr(const AsyncSocketException& ex);
void invokeHandshakeCB();
static void sslInfoCallback(const SSL *ssl, int type, int val);
......@@ -860,6 +868,10 @@ class AsyncSSLSocket : public virtual AsyncSocket {
bool parseClientHello_{false};
std::unique_ptr<ClientHelloInfo> clientHelloInfo_;
// Time taken to complete the ssl handshake.
std::chrono::steady_clock::time_point handshakeStartTime_;
std::chrono::steady_clock::time_point handshakeEndTime_;
};
} // namespace
......@@ -318,6 +318,10 @@ void AsyncSocket::connect(ConnectCallback* callback,
return invalidState(callback);
}
connectStartTime_ = std::chrono::steady_clock::now();
// Make connect end time at least >= connectStartTime.
connectEndTime_ = connectStartTime_;
assert(fd_ == -1);
state_ = StateEnum::CONNECTING;
connectCallback_ = callback;
......@@ -463,10 +467,7 @@ void AsyncSocket::connect(ConnectCallback* callback,
assert(readCallback_ == nullptr);
assert(writeReqHead_ == nullptr);
state_ = StateEnum::ESTABLISHED;
if (callback) {
connectCallback_ = nullptr;
callback->connectSuccess();
}
invokeConnectSuccess();
}
void AsyncSocket::connect(ConnectCallback* callback,
......@@ -838,11 +839,7 @@ void AsyncSocket::closeNow() {
doClose();
}
if (connectCallback_) {
ConnectCallback* callback = connectCallback_;
connectCallback_ = nullptr;
callback->connectErr(socketClosedLocallyEx);
}
invokeConnectErr(socketClosedLocallyEx);
failAllWrites(socketClosedLocallyEx);
......@@ -1617,13 +1614,7 @@ void AsyncSocket::handleConnect() noexcept {
// callbacks (since the callbacks may call detachEventBase()).
EventBase* originalEventBase = eventBase_;
// Call the connect callback.
if (connectCallback_) {
ConnectCallback* callback = connectCallback_;
connectCallback_ = nullptr;
callback->connectSuccess();
}
invokeConnectSuccess();
// Note that the connect callback may have changed our state.
// (set or unset the read callback, called write(), closed the socket, etc.)
// The following code needs to handle these situations correctly.
......@@ -1805,12 +1796,7 @@ void AsyncSocket::finishFail() {
AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
withAddr("socket closing after error"));
if (connectCallback_) {
ConnectCallback* callback = connectCallback_;
connectCallback_ = nullptr;
callback->connectErr(ex);
}
invokeConnectErr(ex);
failAllWrites(ex);
if (readCallback_) {
......@@ -1836,12 +1822,7 @@ void AsyncSocket::failConnect(const char* fn, const AsyncSocketException& ex) {
<< ex.what();
startFail();
if (connectCallback_ != nullptr) {
ConnectCallback* callback = connectCallback_;
connectCallback_ = nullptr;
callback->connectErr(ex);
}
invokeConnectErr(ex);
finishFail();
}
......@@ -1931,6 +1912,7 @@ void AsyncSocket::invalidState(ConnectCallback* callback) {
AsyncSocketException ex(AsyncSocketException::ALREADY_OPEN,
"connect() called with socket in invalid state");
connectEndTime_ = std::chrono::steady_clock::now();
if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
if (callback) {
callback->connectErr(ex);
......@@ -1947,6 +1929,24 @@ void AsyncSocket::invalidState(ConnectCallback* callback) {
}
}
void AsyncSocket::invokeConnectErr(const AsyncSocketException& ex) {
connectEndTime_ = std::chrono::steady_clock::now();
if (connectCallback_) {
ConnectCallback* callback = connectCallback_;
connectCallback_ = nullptr;
callback->connectErr(ex);
}
}
void AsyncSocket::invokeConnectSuccess() {
connectEndTime_ = std::chrono::steady_clock::now();
if (connectCallback_) {
ConnectCallback* callback = connectCallback_;
connectCallback_ = nullptr;
callback->connectSuccess();
}
}
void AsyncSocket::invalidState(ReadCallback* callback) {
VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
<< "): setReadCallback(" << callback
......
......@@ -28,6 +28,7 @@
#include <folly/io/async/EventHandler.h>
#include <folly/io/async/DelayedDestruction.h>
#include <chrono>
#include <memory>
#include <map>
......@@ -395,6 +396,10 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
return getAppBytesReceived();
}
std::chrono::nanoseconds getConnectTime() const {
return connectEndTime_ - connectStartTime_;
}
// Methods controlling socket options
/**
......@@ -752,6 +757,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
const AsyncSocketException& ex);
void failWrite(const char* fn, const AsyncSocketException& ex);
void failAllWrites(const AsyncSocketException& ex);
void invokeConnectErr(const AsyncSocketException& ex);
void invokeConnectSuccess();
void invalidState(ConnectCallback* callback);
void invalidState(ReadCallback* callback);
void invalidState(WriteCallback* callback);
......@@ -783,6 +790,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
bool peek_{false}; // Peek bytes.
int8_t readErr_{READ_NO_ERROR}; ///< The read error encountered, if any.
std::chrono::steady_clock::time_point connectStartTime_;
std::chrono::steady_clock::time_point connectEndTime_;
};
......
......@@ -1054,9 +1054,11 @@ TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
EXPECT_TRUE(client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_TRUE(!client.handshakeError_);
EXPECT_LE(0, client.handshakeTime.count());
EXPECT_TRUE(!server.handshakeVerify_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_TRUE(!server.handshakeError_);
EXPECT_LE(0, server.handshakeTime.count());
}
/**
......@@ -1090,9 +1092,11 @@ TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
EXPECT_TRUE(client.handshakeVerify_);
EXPECT_TRUE(!client.handshakeSuccess_);
EXPECT_TRUE(client.handshakeError_);
EXPECT_LE(0, client.handshakeTime.count());
EXPECT_TRUE(!server.handshakeVerify_);
EXPECT_TRUE(!server.handshakeSuccess_);
EXPECT_TRUE(server.handshakeError_);
EXPECT_LE(0, server.handshakeTime.count());
}
/**
......@@ -1128,9 +1132,11 @@ TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
EXPECT_TRUE(!client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_TRUE(!client.handshakeError_);
EXPECT_LE(0, client.handshakeTime.count());
EXPECT_TRUE(!server.handshakeVerify_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_TRUE(!server.handshakeError_);
EXPECT_LE(0, server.handshakeTime.count());
}
/**
......@@ -1171,9 +1177,11 @@ TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
EXPECT_TRUE(client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_FALSE(client.handshakeError_);
EXPECT_LE(0, client.handshakeTime.count());
EXPECT_TRUE(server.handshakeVerify_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_FALSE(server.handshakeError_);
EXPECT_LE(0, server.handshakeTime.count());
}
/**
......@@ -1205,9 +1213,11 @@ TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
EXPECT_TRUE(client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_TRUE(!client.handshakeError_);
EXPECT_LE(0, client.handshakeTime.count());
EXPECT_TRUE(!server.handshakeVerify_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_TRUE(!server.handshakeError_);
EXPECT_LE(0, server.handshakeTime.count());
}
/**
......@@ -1240,9 +1250,11 @@ TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
EXPECT_TRUE(!client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_TRUE(!client.handshakeError_);
EXPECT_LE(0, client.handshakeTime.count());
EXPECT_TRUE(!server.handshakeVerify_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_TRUE(!server.handshakeError_);
EXPECT_LE(0, server.handshakeTime.count());
}
/**
......@@ -1282,9 +1294,11 @@ TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
EXPECT_TRUE(client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_FALSE(client.handshakeError_);
EXPECT_LE(0, client.handshakeTime.count());
EXPECT_TRUE(server.handshakeVerify_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_FALSE(server.handshakeError_);
EXPECT_LE(0, server.handshakeTime.count());
}
......@@ -1321,6 +1335,8 @@ TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
EXPECT_FALSE(server.handshakeVerify_);
EXPECT_FALSE(server.handshakeSuccess_);
EXPECT_TRUE(server.handshakeError_);
EXPECT_LE(0, client.handshakeTime.count());
EXPECT_LE(0, server.handshakeTime.count());
}
TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
......
......@@ -1130,6 +1130,7 @@ class SSLHandshakeBase :
bool handshakeVerify_;
bool handshakeSuccess_;
bool handshakeError_;
std::chrono::nanoseconds handshakeTime;
protected:
AsyncSSLSocket::UniquePtr socket_;
......@@ -1149,12 +1150,14 @@ class SSLHandshakeBase :
void handshakeSuc(AsyncSSLSocket*) noexcept override {
handshakeSuccess_ = true;
handshakeTime = socket_->getHandshakeTime();
}
void handshakeErr(
AsyncSSLSocket*,
const AsyncSocketException& ex) noexcept override {
handshakeError_ = true;
handshakeTime = socket_->getHandshakeTime();
}
// WriteCallback
......
......@@ -96,6 +96,7 @@ TEST(AsyncSocketTest, Connect) {
evb.loop();
CHECK_EQ(cb.state, STATE_SUCCEEDED);
EXPECT_LE(0, socket->getConnectTime().count());
}
/**
......@@ -115,6 +116,7 @@ TEST(AsyncSocketTest, ConnectRefused) {
CHECK_EQ(cb.state, STATE_FAILED);
CHECK_EQ(cb.exception.getType(), AsyncSocketException::NOT_OPEN);
EXPECT_LE(0, socket->getConnectTime().count());
}
/**
......@@ -152,6 +154,7 @@ TEST(AsyncSocketTest, ConnectTimeout) {
folly::SocketAddress peer;
socket->getPeerAddress(&peer);
CHECK_EQ(peer, addr);
EXPECT_LE(0, socket->getConnectTime().count());
}
/**
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment