Commit 727f779b authored by Dave Watson's avatar Dave Watson Committed by Sara Golemon

AsyncSSLSocket StartTLS

Summary:
Adds a StartTLS mode to AsyncSSLSocket.  Previously I could only find anyone doing something like this by using AsyncSocket, calling detachFd, then creating a new AsyncSSLSocket, and calling sslConn/sslAccept.

That had a couple downsides: 1) All pointers to the previous AsyncSocket become invalid and similarly 2) have to be super careful reads/writes happen on the correct socket, are flushed before changing socket types, etc.

This makes it super easy to just use the same AsyncSSLSocket for everything:
a) Create AsyncSSLSocket in StartTLS mode
b) send/recv anything
c) Call sslAccept/sslConn.  Existing writes are still flushed in the correct order, any additional writes are buffered until handshake completes
d) Start receiving encrypted data.

I made it a new mode (vs. the default), since it seems bad to unintentionally send unencrypted data.

Use case is easy secure thrift upgrade (similar to how current kerberos does it)

Test Plan: New unittest

Reviewed By: afrind@fb.com

Subscribers: doug, ssl-diffs@, folly-diffs@, yfeldblum, chalfant, haijunz, andrewcox, alandau, alikhtarov, jsedgwick, simpkins

FB internal diff: D2120114

Signature: t1:2120114:1433798448:caeddc8feb6cc10fb34200ba97ea323bcaf09f7a
parent 2ef7ef74
...@@ -253,18 +253,22 @@ SSLException::SSLException(int sslError, int errno_copy): ...@@ -253,18 +253,22 @@ SSLException::SSLException(int sslError, int errno_copy):
* Create a client AsyncSSLSocket * Create a client AsyncSSLSocket
*/ */
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx, AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
EventBase* evb) : EventBase* evb, bool deferSecurityNegotiation) :
AsyncSocket(evb), AsyncSocket(evb),
ctx_(ctx), ctx_(ctx),
handshakeTimeout_(this, evb) { handshakeTimeout_(this, evb) {
init(); init();
if (deferSecurityNegotiation) {
sslState_ = STATE_UNENCRYPTED;
}
} }
/** /**
* Create a server/client AsyncSSLSocket * Create a server/client AsyncSSLSocket
*/ */
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx, AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
EventBase* evb, int fd, bool server) : EventBase* evb, int fd, bool server,
bool deferSecurityNegotiation) :
AsyncSocket(evb, fd), AsyncSocket(evb, fd),
server_(server), server_(server),
ctx_(ctx), ctx_(ctx),
...@@ -274,6 +278,9 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx, ...@@ -274,6 +278,9 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
SSL_CTX_set_info_callback(ctx_->getSSLCtx(), SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
AsyncSSLSocket::sslInfoCallback); AsyncSSLSocket::sslInfoCallback);
} }
if (deferSecurityNegotiation) {
sslState_ = STATE_UNENCRYPTED;
}
} }
#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
...@@ -283,8 +290,9 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx, ...@@ -283,8 +290,9 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
*/ */
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx, AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
EventBase* evb, EventBase* evb,
const std::string& serverName) : const std::string& serverName,
AsyncSSLSocket(ctx, evb) { bool deferSecurityNegotiation) :
AsyncSSLSocket(ctx, evb, deferSecurityNegotiation) {
tlsextHostname_ = serverName; tlsextHostname_ = serverName;
} }
...@@ -294,8 +302,9 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx, ...@@ -294,8 +302,9 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
*/ */
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx, AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
EventBase* evb, int fd, EventBase* evb, int fd,
const std::string& serverName) : const std::string& serverName,
AsyncSSLSocket(ctx, evb, fd, false) { bool deferSecurityNegotiation) :
AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
tlsextHostname_ = serverName; tlsextHostname_ = serverName;
} }
#endif #endif
...@@ -374,7 +383,7 @@ void AsyncSSLSocket::shutdownWriteNow() { ...@@ -374,7 +383,7 @@ void AsyncSSLSocket::shutdownWriteNow() {
bool AsyncSSLSocket::good() const { bool AsyncSSLSocket::good() const {
return (AsyncSocket::good() && return (AsyncSocket::good() &&
(sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING || (sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING ||
sslState_ == STATE_ESTABLISHED)); sslState_ == STATE_ESTABLISHED || sslState_ == STATE_UNENCRYPTED));
} }
// The TAsyncTransport definition of 'good' states that the transport is // The TAsyncTransport definition of 'good' states that the transport is
...@@ -468,7 +477,9 @@ void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout, ...@@ -468,7 +477,9 @@ void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
verifyPeer_ = verifyPeer; verifyPeer_ = verifyPeer;
// Make sure we're in the uninitialized state // Make sure we're in the uninitialized state
if (!server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) { if (!server_ || (sslState_ != STATE_UNINIT &&
sslState_ != STATE_UNENCRYPTED) ||
handshakeCallback_ != nullptr) {
return invalidState(callback); return invalidState(callback);
} }
...@@ -674,7 +685,9 @@ void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout, ...@@ -674,7 +685,9 @@ void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
verifyPeer_ = verifyPeer; verifyPeer_ = verifyPeer;
// Make sure we're in the uninitialized state // Make sure we're in the uninitialized state
if (server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) { if (server_ || (sslState_ != STATE_UNINIT && sslState_ !=
STATE_UNENCRYPTED) ||
handshakeCallback_ != nullptr) {
return invalidState(callback); return invalidState(callback);
} }
...@@ -1078,6 +1091,10 @@ AsyncSSLSocket::handleRead() noexcept { ...@@ -1078,6 +1091,10 @@ AsyncSSLSocket::handleRead() noexcept {
ssize_t ssize_t
AsyncSSLSocket::performRead(void* buf, size_t buflen) { AsyncSSLSocket::performRead(void* buf, size_t buflen) {
if (sslState_ == STATE_UNENCRYPTED) {
return AsyncSocket::performRead(buf, buflen);
}
errno = 0; errno = 0;
ssize_t bytes = SSL_read(ssl_, buf, buflen); ssize_t bytes = SSL_read(ssl_, buf, buflen);
if (server_ && renegotiateAttempted_) { if (server_ && renegotiateAttempted_) {
...@@ -1169,6 +1186,10 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec, ...@@ -1169,6 +1186,10 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec,
WriteFlags flags, WriteFlags flags,
uint32_t* countWritten, uint32_t* countWritten,
uint32_t* partialWritten) { uint32_t* partialWritten) {
if (sslState_ == STATE_UNENCRYPTED) {
return AsyncSocket::performWrite(
vec, count, flags, countWritten, partialWritten);
}
if (sslState_ != STATE_ESTABLISHED) { if (sslState_ != STATE_ESTABLISHED) {
LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_) LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
<< ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): " << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
......
...@@ -162,7 +162,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -162,7 +162,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
* Create a client AsyncSSLSocket * Create a client AsyncSSLSocket
*/ */
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext> &ctx, AsyncSSLSocket(const std::shared_ptr<folly::SSLContext> &ctx,
EventBase* evb); EventBase* evb, bool deferSecurityNegotiation = false);
/** /**
* Create a server/client AsyncSSLSocket from an already connected * Create a server/client AsyncSSLSocket from an already connected
...@@ -178,9 +178,12 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -178,9 +178,12 @@ class AsyncSSLSocket : public virtual AsyncSocket {
* @param evb EventBase that will manage this socket. * @param evb EventBase that will manage this socket.
* @param fd File descriptor to take over (should be a connected socket). * @param fd File descriptor to take over (should be a connected socket).
* @param server Is socket in server mode? * @param server Is socket in server mode?
* @param deferSecurityNegotiation
* unencrypted data can be sent before sslConn/Accept
*/ */
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx, AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb, int fd, bool server = true); EventBase* evb, int fd,
bool server = true, bool deferSecurityNegotiation = false);
/** /**
...@@ -188,9 +191,10 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -188,9 +191,10 @@ class AsyncSSLSocket : public virtual AsyncSocket {
*/ */
static std::shared_ptr<AsyncSSLSocket> newSocket( static std::shared_ptr<AsyncSSLSocket> newSocket(
const std::shared_ptr<folly::SSLContext>& ctx, const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb, int fd, bool server=true) { EventBase* evb, int fd, bool server=true,
bool deferSecurityNegotiation = false) {
return std::shared_ptr<AsyncSSLSocket>( return std::shared_ptr<AsyncSSLSocket>(
new AsyncSSLSocket(ctx, evb, fd, server), new AsyncSSLSocket(ctx, evb, fd, server, deferSecurityNegotiation),
Destructor()); Destructor());
} }
...@@ -199,9 +203,9 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -199,9 +203,9 @@ class AsyncSSLSocket : public virtual AsyncSocket {
*/ */
static std::shared_ptr<AsyncSSLSocket> newSocket( static std::shared_ptr<AsyncSSLSocket> newSocket(
const std::shared_ptr<folly::SSLContext>& ctx, const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb) { EventBase* evb, bool deferSecurityNegotiation = false) {
return std::shared_ptr<AsyncSSLSocket>( return std::shared_ptr<AsyncSSLSocket>(
new AsyncSSLSocket(ctx, evb), new AsyncSSLSocket(ctx, evb, deferSecurityNegotiation),
Destructor()); Destructor());
} }
...@@ -213,7 +217,8 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -213,7 +217,8 @@ class AsyncSSLSocket : public virtual AsyncSocket {
*/ */
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext> &ctx, AsyncSSLSocket(const std::shared_ptr<folly::SSLContext> &ctx,
EventBase* evb, EventBase* evb,
const std::string& serverName); const std::string& serverName,
bool deferSecurityNegotiation = false);
/** /**
* Create a client AsyncSSLSocket from an already connected * Create a client AsyncSSLSocket from an already connected
...@@ -233,14 +238,16 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -233,14 +238,16 @@ class AsyncSSLSocket : public virtual AsyncSocket {
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx, AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb, EventBase* evb,
int fd, int fd,
const std::string& serverName); const std::string& serverName,
bool deferSecurityNegotiation = false);
static std::shared_ptr<AsyncSSLSocket> newSocket( static std::shared_ptr<AsyncSSLSocket> newSocket(
const std::shared_ptr<folly::SSLContext>& ctx, const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb, EventBase* evb,
const std::string& serverName) { const std::string& serverName,
bool deferSecurityNegotiation = false) {
return std::shared_ptr<AsyncSSLSocket>( return std::shared_ptr<AsyncSSLSocket>(
new AsyncSSLSocket(ctx, evb, serverName), new AsyncSSLSocket(ctx, evb, serverName, deferSecurityNegotiation),
Destructor()); Destructor());
} }
#endif #endif
...@@ -336,6 +343,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -336,6 +343,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
enum SSLStateEnum { enum SSLStateEnum {
STATE_UNINIT, STATE_UNINIT,
STATE_UNENCRYPTED,
STATE_ACCEPTING, STATE_ACCEPTING,
STATE_CACHE_LOOKUP, STATE_CACHE_LOOKUP,
STATE_RSA_ASYNC_PENDING, STATE_RSA_ASYNC_PENDING,
......
...@@ -1262,8 +1262,90 @@ TEST(AsyncSSLSocketTest, MinWriteSizeTest) { ...@@ -1262,8 +1262,90 @@ TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
socket->setMinWriteSize(50000); socket->setMinWriteSize(50000);
EXPECT_EQ(50000, socket->getMinWriteSize()); EXPECT_EQ(50000, socket->getMinWriteSize());
} }
class ReadCallbackTerminator : public ReadCallback {
public:
ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
: ReadCallback(wcb)
, base_(base) {}
// Do not write data back, terminate the loop.
void readDataAvailable(size_t len) noexcept override {
std::cerr << "readDataAvailable, len " << len << std::endl;
currentBuffer.length = len;
buffers.push_back(currentBuffer);
currentBuffer.reset();
state = STATE_SUCCEEDED;
socket_->setReadCB(nullptr);
base_->terminateLoopSoon();
}
private:
EventBase* base_;
};
/**
* Test a full unencrypted codepath
*/
TEST(AsyncSSLSocketTest, UnencryptedTest) {
EventBase base;
auto clientCtx = std::make_shared<folly::SSLContext>();
auto serverCtx = std::make_shared<folly::SSLContext>();
int fds[2];
getfds(fds);
getctx(clientCtx, serverCtx);
auto client = AsyncSSLSocket::newSocket(
clientCtx, &base, fds[0], false, true);
auto server = AsyncSSLSocket::newSocket(
serverCtx, &base, fds[1], true, true);
ReadCallbackTerminator readCallback(&base, nullptr);
server->setReadCB(&readCallback);
readCallback.setSocket(server);
uint8_t buf[128];
memset(buf, 'a', sizeof(buf));
client->write(nullptr, buf, sizeof(buf));
// Check that bytes are unencrypted
char c;
EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
EXPECT_EQ('a', c);
EventBaseAborter eba(&base, 3000);
base.loop();
EXPECT_EQ(1, readCallback.buffers.size());
EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
server->setReadCB(&readCallback);
// Unencrypted
server->sslAccept(nullptr);
client->sslConn(nullptr);
// Do NOT wait for handshake, writing should be queued and happen after
client->write(nullptr, buf, sizeof(buf));
// Check that bytes are *not* unencrypted
char c2;
EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
EXPECT_NE('a', c2);
base.loop();
EXPECT_EQ(2, readCallback.buffers.size());
EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
} }
} // namespace
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
// init_unit_test_suite // init_unit_test_suite
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment