Commit 3caa3408 authored by Christopher Dykes's avatar Christopher Dykes Committed by Facebook Github Bot

Use std::chrono for timeouts to sslAccept and sslConn in AsyncSSLSocket

Summary: Because `std::chrono` makes it clear exactly what unit of time is in use.

Reviewed By: yfeldblum

Differential Revision: D4363560

fbshipit-source-id: 47aeef21f842f39d8e886bec441897ecf1f3761b
parent ecd7292d
...@@ -110,7 +110,7 @@ class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback, ...@@ -110,7 +110,7 @@ class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
return; return;
} }
} }
sslSocket_->sslConn(this, timeoutLeft); sslSocket_->sslConn(this, std::chrono::milliseconds(timeoutLeft));
} }
void connectErr(const AsyncSocketException& ex) noexcept override { void connectErr(const AsyncSocketException& ex) noexcept override {
...@@ -417,8 +417,10 @@ void AsyncSSLSocket::invalidState(HandshakeCB* callback) { ...@@ -417,8 +417,10 @@ void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
} }
} }
void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout, void AsyncSSLSocket::sslAccept(
const SSLContext::SSLVerifyPeerEnum& verifyPeer) { HandshakeCB* callback,
std::chrono::milliseconds timeout,
const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
DestructorGuard dg(this); DestructorGuard dg(this);
assert(eventBase_->isInEventBaseThread()); assert(eventBase_->isInEventBaseThread());
verifyPeer_ = verifyPeer; verifyPeer_ = verifyPeer;
...@@ -443,7 +445,7 @@ void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout, ...@@ -443,7 +445,7 @@ void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
sslState_ = STATE_ACCEPTING; sslState_ = STATE_ACCEPTING;
handshakeCallback_ = callback; handshakeCallback_ = callback;
if (timeout > 0) { if (timeout > std::chrono::milliseconds::zero()) {
handshakeTimeout_.scheduleTimeout(timeout); handshakeTimeout_.scheduleTimeout(timeout);
} }
...@@ -680,8 +682,10 @@ bool AsyncSSLSocket::setupSSLBio() { ...@@ -680,8 +682,10 @@ bool AsyncSSLSocket::setupSSLBio() {
return true; return true;
} }
void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout, void AsyncSSLSocket::sslConn(
const SSLContext::SSLVerifyPeerEnum& verifyPeer) { HandshakeCB* callback,
std::chrono::milliseconds timeout,
const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
DestructorGuard dg(this); DestructorGuard dg(this);
assert(eventBase_->isInEventBaseThread()); assert(eventBase_->isInEventBaseThread());
...@@ -747,9 +751,8 @@ void AsyncSSLSocket::startSSLConnect() { ...@@ -747,9 +751,8 @@ void AsyncSSLSocket::startSSLConnect() {
handshakeStartTime_ = std::chrono::steady_clock::now(); handshakeStartTime_ = std::chrono::steady_clock::now();
// Make end time at least >= start time. // Make end time at least >= start time.
handshakeEndTime_ = handshakeStartTime_; handshakeEndTime_ = handshakeStartTime_;
if (handshakeConnectTimeout_ > 0) { if (handshakeConnectTimeout_ > std::chrono::milliseconds::zero()) {
handshakeTimeout_.scheduleTimeout( handshakeTimeout_.scheduleTimeout(handshakeConnectTimeout_);
std::chrono::milliseconds(handshakeConnectTimeout_));
} }
handleConnect(); handleConnect();
} }
......
...@@ -292,9 +292,11 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -292,9 +292,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
* context by default, can be set explcitly to override the * context by default, can be set explcitly to override the
* method in the context * method in the context
*/ */
virtual void sslAccept(HandshakeCB* callback, uint32_t timeout = 0, virtual void sslAccept(
HandshakeCB* callback,
std::chrono::milliseconds timeout = std::chrono::milliseconds::zero(),
const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer = const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer =
folly::SSLContext::SSLVerifyPeerEnum::USE_CTX); folly::SSLContext::SSLVerifyPeerEnum::USE_CTX);
/** /**
* Invoke SSL accept following an asynchronous session cache lookup * Invoke SSL accept following an asynchronous session cache lookup
...@@ -332,9 +334,11 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -332,9 +334,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
* SSL_VERIFY_PEER and invokes * SSL_VERIFY_PEER and invokes
* HandshakeCB::handshakeVer(). * HandshakeCB::handshakeVer().
*/ */
virtual void sslConn(HandshakeCB *callback, uint64_t timeout = 0, virtual void sslConn(
const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer = HandshakeCB* callback,
folly::SSLContext::SSLVerifyPeerEnum::USE_CTX); std::chrono::milliseconds timeout = std::chrono::milliseconds::zero(),
const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer =
folly::SSLContext::SSLVerifyPeerEnum::USE_CTX);
enum SSLStateEnum { enum SSLStateEnum {
STATE_UNINIT, STATE_UNINIT,
...@@ -810,7 +814,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -810,7 +814,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
// Time taken to complete the ssl handshake. // Time taken to complete the ssl handshake.
std::chrono::steady_clock::time_point handshakeStartTime_; std::chrono::steady_clock::time_point handshakeStartTime_;
std::chrono::steady_clock::time_point handshakeEndTime_; std::chrono::steady_clock::time_point handshakeEndTime_;
uint64_t handshakeConnectTimeout_{0}; std::chrono::milliseconds handshakeConnectTimeout_{0};
bool sessionResumptionAttempted_{false}; bool sessionResumptionAttempted_{false};
}; };
......
...@@ -451,7 +451,7 @@ public: ...@@ -451,7 +451,7 @@ public:
std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl; std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
hcb_->setSocket(sock); hcb_->setSocket(sock);
sock->sslAccept(hcb_, timeout_); sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
EXPECT_EQ(sock->getSSLState(), EXPECT_EQ(sock->getSSLState(),
AsyncSSLSocket::STATE_ACCEPTING); AsyncSSLSocket::STATE_ACCEPTING);
...@@ -515,7 +515,7 @@ public: ...@@ -515,7 +515,7 @@ public:
std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl; std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
hcb_->setSocket(sock); hcb_->setSocket(sock);
sock->sslAccept(hcb_, timeout_); sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
ASSERT_TRUE((sock->getSSLState() == ASSERT_TRUE((sock->getSSLState() ==
AsyncSSLSocket::STATE_ACCEPTING) || AsyncSSLSocket::STATE_ACCEPTING) ||
(sock->getSSLState() == (sock->getSSLState() ==
...@@ -748,7 +748,7 @@ class BlockingWriteClient : ...@@ -748,7 +748,7 @@ class BlockingWriteClient :
} }
} }
socket_->sslConn(this, 100); socket_->sslConn(this, std::chrono::milliseconds(100));
} }
struct iovec* getIovec() const { struct iovec* getIovec() const {
...@@ -794,7 +794,7 @@ class BlockingWriteServer : ...@@ -794,7 +794,7 @@ class BlockingWriteServer :
bufSize_(2500 * 2000), bufSize_(2500 * 2000),
bytesRead_(0) { bytesRead_(0) {
buf_.reset(new uint8_t[bufSize_]); buf_.reset(new uint8_t[bufSize_]);
socket_->sslAccept(this, 100); socket_->sslAccept(this, std::chrono::milliseconds(100));
} }
void checkBuffer(struct iovec* iov, uint32_t count) const { void checkBuffer(struct iovec* iov, uint32_t count) const {
...@@ -1293,7 +1293,7 @@ class SSLHandshakeClient : public SSLHandshakeBase { ...@@ -1293,7 +1293,7 @@ class SSLHandshakeClient : public SSLHandshakeBase {
bool preverifyResult, bool preverifyResult,
bool verifyResult) : bool verifyResult) :
SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) { SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
socket_->sslConn(this, 0); socket_->sslConn(this, std::chrono::milliseconds::zero());
} }
}; };
...@@ -1304,8 +1304,10 @@ class SSLHandshakeClientNoVerify : public SSLHandshakeBase { ...@@ -1304,8 +1304,10 @@ class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
bool preverifyResult, bool preverifyResult,
bool verifyResult) : bool verifyResult) :
SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) { SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
socket_->sslConn(this, 0, socket_->sslConn(
folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY); this,
std::chrono::milliseconds::zero(),
folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
} }
}; };
...@@ -1316,8 +1318,10 @@ class SSLHandshakeClientDoVerify : public SSLHandshakeBase { ...@@ -1316,8 +1318,10 @@ class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
bool preverifyResult, bool preverifyResult,
bool verifyResult) : bool verifyResult) :
SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) { SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
socket_->sslConn(this, 0, socket_->sslConn(
folly::SSLContext::SSLVerifyPeerEnum::VERIFY); this,
std::chrono::milliseconds::zero(),
folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
} }
}; };
...@@ -1328,7 +1332,7 @@ class SSLHandshakeServer : public SSLHandshakeBase { ...@@ -1328,7 +1332,7 @@ class SSLHandshakeServer : public SSLHandshakeBase {
bool preverifyResult, bool preverifyResult,
bool verifyResult) bool verifyResult)
: SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) { : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
socket_->sslAccept(this, 0); socket_->sslAccept(this, std::chrono::milliseconds::zero());
} }
}; };
...@@ -1340,7 +1344,7 @@ class SSLHandshakeServerParseClientHello : public SSLHandshakeBase { ...@@ -1340,7 +1344,7 @@ class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
bool verifyResult) bool verifyResult)
: SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) { : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
socket_->enableClientHelloParsing(); socket_->enableClientHelloParsing();
socket_->sslAccept(this, 0); socket_->sslAccept(this, std::chrono::milliseconds::zero());
} }
std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_; std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
...@@ -1363,8 +1367,10 @@ class SSLHandshakeServerNoVerify : public SSLHandshakeBase { ...@@ -1363,8 +1367,10 @@ class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
bool preverifyResult, bool preverifyResult,
bool verifyResult) bool verifyResult)
: SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) { : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
socket_->sslAccept(this, 0, socket_->sslAccept(
folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY); this,
std::chrono::milliseconds::zero(),
folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
} }
}; };
...@@ -1375,8 +1381,10 @@ class SSLHandshakeServerDoVerify : public SSLHandshakeBase { ...@@ -1375,8 +1381,10 @@ class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
bool preverifyResult, bool preverifyResult,
bool verifyResult) bool verifyResult)
: SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) { : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
socket_->sslAccept(this, 0, socket_->sslAccept(
folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT); this,
std::chrono::milliseconds::zero(),
folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
} }
}; };
......
...@@ -54,12 +54,11 @@ class MockAsyncSSLSocket : public AsyncSSLSocket { ...@@ -54,12 +54,11 @@ class MockAsyncSSLSocket : public AsyncSSLSocket {
MOCK_METHOD1(setReadCB, void(ReadCallback*)); MOCK_METHOD1(setReadCB, void(ReadCallback*));
void sslConn( void sslConn(
AsyncSSLSocket::HandshakeCB* cb, AsyncSSLSocket::HandshakeCB* cb,
uint64_t timeout, std::chrono::milliseconds timeout,
const SSLContext::SSLVerifyPeerEnum& verify) const SSLContext::SSLVerifyPeerEnum& verify) override {
override { if (timeout > std::chrono::milliseconds::zero()) {
if (timeout > 0) { handshakeTimeout_.scheduleTimeout(timeout);
handshakeTimeout_.scheduleTimeout((uint32_t)timeout);
} }
state_ = StateEnum::ESTABLISHED; state_ = StateEnum::ESTABLISHED;
...@@ -70,11 +69,10 @@ class MockAsyncSSLSocket : public AsyncSSLSocket { ...@@ -70,11 +69,10 @@ class MockAsyncSSLSocket : public AsyncSSLSocket {
} }
void sslAccept( void sslAccept(
AsyncSSLSocket::HandshakeCB* cb, AsyncSSLSocket::HandshakeCB* cb,
uint32_t timeout, std::chrono::milliseconds timeout,
const SSLContext::SSLVerifyPeerEnum& verify) const SSLContext::SSLVerifyPeerEnum& verify) override {
override { if (timeout > std::chrono::milliseconds::zero()) {
if (timeout > 0) {
handshakeTimeout_.scheduleTimeout(timeout); handshakeTimeout_.scheduleTimeout(timeout);
} }
...@@ -86,14 +84,18 @@ class MockAsyncSSLSocket : public AsyncSSLSocket { ...@@ -86,14 +84,18 @@ class MockAsyncSSLSocket : public AsyncSSLSocket {
} }
MOCK_METHOD3( MOCK_METHOD3(
sslConnectMockable, sslConnectMockable,
void(AsyncSSLSocket::HandshakeCB*, uint64_t, void(
const SSLContext::SSLVerifyPeerEnum&)); AsyncSSLSocket::HandshakeCB*,
std::chrono::milliseconds,
const SSLContext::SSLVerifyPeerEnum&));
MOCK_METHOD3( MOCK_METHOD3(
sslAcceptMockable, sslAcceptMockable,
void(AsyncSSLSocket::HandshakeCB*, uint32_t, void(
const SSLContext::SSLVerifyPeerEnum&)); AsyncSSLSocket::HandshakeCB*,
std::chrono::milliseconds,
const SSLContext::SSLVerifyPeerEnum&));
}; };
}} }}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment