Commit 12ace861 authored by Subodh Iyengar's avatar Subodh Iyengar Committed by Facebook Github Bot 2

Invoking correct callback during TFO fallback

Summary:
If we fallback from SSL to TFO and the connection times
out, invokeConnectSuccess tries to deliver the connectError,
however we've already delivered the connect callback to the user.

This is bad because we have no way of reporting an error back.
This changes it so that when using SSL and we're scheduling a timeout
when we're falling back, we will schedule a timeout of our own which
will invoke AsyncSSLSocket's timeoutExpired. This will return a handshakeError
instead to the client.

Reviewed By: yfeldblum

Differential Revision: D3708699

fbshipit-source-id: 41fe668f00972c0875bb0318c6a6de863d3ab8f9
parent 457fa717
...@@ -253,7 +253,8 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx, ...@@ -253,7 +253,8 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
EventBase* evb, bool deferSecurityNegotiation) : EventBase* evb, bool deferSecurityNegotiation) :
AsyncSocket(evb), AsyncSocket(evb),
ctx_(ctx), ctx_(ctx),
handshakeTimeout_(this, evb) { handshakeTimeout_(this, evb),
connectionTimeout_(this, evb) {
init(); init();
if (deferSecurityNegotiation) { if (deferSecurityNegotiation) {
sslState_ = STATE_UNENCRYPTED; sslState_ = STATE_UNENCRYPTED;
...@@ -269,7 +270,8 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx, ...@@ -269,7 +270,8 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
AsyncSocket(evb, fd), AsyncSocket(evb, fd),
server_(server), server_(server),
ctx_(ctx), ctx_(ctx),
handshakeTimeout_(this, evb) { handshakeTimeout_(this, evb),
connectionTimeout_(this, evb) {
init(); init();
if (server) { if (server) {
SSL_CTX_set_info_callback(ctx_->getSSLCtx(), SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
...@@ -587,6 +589,12 @@ void AsyncSSLSocket::timeoutExpired() noexcept { ...@@ -587,6 +589,12 @@ void AsyncSSLSocket::timeoutExpired() noexcept {
// We are expecting a callback in restartSSLAccept. The cache lookup // We are expecting a callback in restartSSLAccept. The cache lookup
// and rsa-call necessarily have pointers to this ssl socket, so delay // and rsa-call necessarily have pointers to this ssl socket, so delay
// the cleanup until he calls us back. // the cleanup until he calls us back.
} else if (state_ == StateEnum::CONNECTING) {
assert(sslState_ == STATE_CONNECTING);
DestructorGuard dg(this);
AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
"Fallback connect timed out during TFO");
failHandshake(__func__, ex);
} else { } else {
assert(state_ == StateEnum::ESTABLISHED && assert(state_ == StateEnum::ESTABLISHED &&
(sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING)); (sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING));
...@@ -1157,15 +1165,45 @@ AsyncSSLSocket::handleConnect() noexcept { ...@@ -1157,15 +1165,45 @@ AsyncSSLSocket::handleConnect() noexcept {
AsyncSocket::handleInitialReadWrite(); AsyncSocket::handleInitialReadWrite();
} }
void AsyncSSLSocket::invokeConnectErr(const AsyncSocketException& ex) {
connectionTimeout_.cancelTimeout();
AsyncSocket::invokeConnectErr(ex);
}
void AsyncSSLSocket::invokeConnectSuccess() { void AsyncSSLSocket::invokeConnectSuccess() {
connectionTimeout_.cancelTimeout();
if (sslState_ == SSLStateEnum::STATE_CONNECTING) { if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
// If we failed TFO, we'd fall back to trying to connect the socket, // If we failed TFO, we'd fall back to trying to connect the socket,
// to setup things like timeouts. // to setup things like timeouts.
startSSLConnect(); startSSLConnect();
} }
// still invoke the base class since it re-sets the connect time.
AsyncSocket::invokeConnectSuccess(); AsyncSocket::invokeConnectSuccess();
} }
void AsyncSSLSocket::scheduleConnectTimeout() {
if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
// We fell back from TFO, and need to set the timeouts.
// We will not have a connect callback in this case, thus if the timer
// expires we would have no-one to notify.
// Thus we should reset even the connect timers to point to the handshake
// timeouts.
assert(connectCallback_ == nullptr);
// We use a different connect timeout here than the handshake timeout, so
// that we can disambiguate the 2 timers.
int timeout = connectTimeout_.count();
if (timeout > 0) {
if (!connectionTimeout_.scheduleTimeout(timeout)) {
throw AsyncSocketException(
AsyncSocketException::INTERNAL_ERROR,
withAddr("failed to schedule AsyncSSLSocket connect timeout"));
}
}
return;
}
AsyncSocket::scheduleConnectTimeout();
}
void AsyncSSLSocket::setReadCB(ReadCallback *callback) { void AsyncSSLSocket::setReadCB(ReadCallback *callback) {
#ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
// turn on the buffer movable in openssl // turn on the buffer movable in openssl
......
...@@ -136,6 +136,20 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -136,6 +136,20 @@ class AsyncSSLSocket : public virtual AsyncSocket {
AsyncSSLSocket* sslSocket_; AsyncSSLSocket* sslSocket_;
}; };
// Timer for if we fallback from SSL connects to TCP connects
class ConnectionTimeout : public AsyncTimeout {
public:
ConnectionTimeout(AsyncSSLSocket* sslSocket, EventBase* eventBase)
: AsyncTimeout(eventBase), sslSocket_(sslSocket) {}
virtual void timeoutExpired() noexcept override {
sslSocket_->timeoutExpired();
}
private:
AsyncSSLSocket* sslSocket_;
};
/** /**
* Create a client AsyncSSLSocket * Create a client AsyncSSLSocket
*/ */
...@@ -811,7 +825,9 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -811,7 +825,9 @@ class AsyncSSLSocket : public virtual AsyncSocket {
void invokeHandshakeErr(const AsyncSocketException& ex); void invokeHandshakeErr(const AsyncSocketException& ex);
void invokeHandshakeCB(); void invokeHandshakeCB();
void invokeConnectErr(const AsyncSocketException& ex) override;
void invokeConnectSuccess() override; void invokeConnectSuccess() override;
void scheduleConnectTimeout() override;
void cacheLocalPeerAddr(); void cacheLocalPeerAddr();
...@@ -836,6 +852,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -836,6 +852,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
SSL* ssl_{nullptr}; SSL* ssl_{nullptr};
SSL_SESSION *sslSession_{nullptr}; SSL_SESSION *sslSession_{nullptr};
HandshakeTimeout handshakeTimeout_; HandshakeTimeout handshakeTimeout_;
ConnectionTimeout connectionTimeout_;
// whether the SSL session was resumed using session ID or not // whether the SSL session was resumed using session ID or not
bool sessionIDResumed_{false}; bool sessionIDResumed_{false};
......
...@@ -472,7 +472,8 @@ int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) { ...@@ -472,7 +472,8 @@ int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
if (rv < 0) { if (rv < 0) {
auto errnoCopy = errno; auto errnoCopy = errno;
if (errnoCopy == EINPROGRESS) { if (errnoCopy == EINPROGRESS) {
scheduleConnectTimeoutAndRegisterForEvents(); scheduleConnectTimeout();
registerForConnectEvents();
} else { } else {
throw AsyncSocketException( throw AsyncSocketException(
AsyncSocketException::NOT_OPEN, AsyncSocketException::NOT_OPEN,
...@@ -483,7 +484,7 @@ int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) { ...@@ -483,7 +484,7 @@ int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
return rv; return rv;
} }
void AsyncSocket::scheduleConnectTimeoutAndRegisterForEvents() { void AsyncSocket::scheduleConnectTimeout() {
// Connection in progress. // Connection in progress.
int timeout = connectTimeout_.count(); int timeout = connectTimeout_.count();
if (timeout > 0) { if (timeout > 0) {
...@@ -494,7 +495,9 @@ void AsyncSocket::scheduleConnectTimeoutAndRegisterForEvents() { ...@@ -494,7 +495,9 @@ void AsyncSocket::scheduleConnectTimeoutAndRegisterForEvents() {
withAddr("failed to schedule AsyncSocket connect timeout")); withAddr("failed to schedule AsyncSocket connect timeout"));
} }
} }
}
void AsyncSocket::registerForConnectEvents() {
// Register for write events, so we'll // Register for write events, so we'll
// be notified when the connection finishes/fails. // be notified when the connection finishes/fails.
// Note that we don't register for a persistent event here. // Note that we don't register for a persistent event here.
...@@ -1781,7 +1784,8 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) { ...@@ -1781,7 +1784,8 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
// cookie. // cookie.
state_ = StateEnum::CONNECTING; state_ = StateEnum::CONNECTING;
try { try {
scheduleConnectTimeoutAndRegisterForEvents(); scheduleConnectTimeout();
registerForConnectEvents();
} catch (const AsyncSocketException& ex) { } catch (const AsyncSocketException& ex) {
return WriteResult( return WriteResult(
WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex)); WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
......
...@@ -838,7 +838,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -838,7 +838,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
int socketConnect(const struct sockaddr* addr, socklen_t len); int socketConnect(const struct sockaddr* addr, socklen_t len);
void scheduleConnectTimeoutAndRegisterForEvents(); virtual void scheduleConnectTimeout();
void registerForConnectEvents();
bool updateEventRegistration(); bool updateEventRegistration();
...@@ -869,7 +870,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -869,7 +870,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
const AsyncSocketException& ex); const AsyncSocketException& ex);
void failWrite(const char* fn, const AsyncSocketException& ex); void failWrite(const char* fn, const AsyncSocketException& ex);
void failAllWrites(const AsyncSocketException& ex); void failAllWrites(const AsyncSocketException& ex);
void invokeConnectErr(const AsyncSocketException& ex); virtual void invokeConnectErr(const AsyncSocketException& ex);
virtual void invokeConnectSuccess(); virtual void invokeConnectSuccess();
void invalidState(ConnectCallback* callback); void invalidState(ConnectCallback* callback);
void invalidState(ReadCallback* callback); void invalidState(ReadCallback* callback);
......
...@@ -1788,13 +1788,15 @@ class ConnCallback : public AsyncSocket::ConnectCallback { ...@@ -1788,13 +1788,15 @@ class ConnCallback : public AsyncSocket::ConnectCallback {
state = State::SUCCESS; state = State::SUCCESS;
} }
virtual void connectErr(const AsyncSocketException&) noexcept override { virtual void connectErr(const AsyncSocketException& ex) noexcept override {
state = State::ERROR; state = State::ERROR;
error = ex.what();
} }
enum class State { WAITING, SUCCESS, ERROR }; enum class State { WAITING, SUCCESS, ERROR };
State state{State::WAITING}; State state{State::WAITING};
std::string error;
}; };
template <class Cardinality> template <class Cardinality>
...@@ -1869,7 +1871,7 @@ TEST(AsyncSSLSocketTest, ConnectTFOTimeout) { ...@@ -1869,7 +1871,7 @@ TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
std::make_shared<BlockingSocket>(server.getAddress(), sslContext); std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket->enableTFO(); socket->enableTFO();
EXPECT_THROW( EXPECT_THROW(
socket->open(std::chrono::milliseconds(1)), AsyncSocketException); socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
} }
TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) { TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
...@@ -1888,6 +1890,25 @@ TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) { ...@@ -1888,6 +1890,25 @@ TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
EXPECT_EQ(ConnCallback::State::ERROR, ccb.state); EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
} }
TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
// Start listening on a local port
EmptyReadCallback readCallback;
HandshakeCallback handshakeCallback(
&readCallback, HandshakeCallback::EXPECT_ERROR);
HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback, true);
EventBase evb;
auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 100);
evb.loop();
EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
}
#endif #endif
} // namespace } // namespace
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment