Commit 727f779b authored by Dave Watson's avatar Dave Watson Committed by Sara Golemon

AsyncSSLSocket StartTLS

Summary:
Adds a StartTLS mode to AsyncSSLSocket.  Previously I could only find anyone doing something like this by using AsyncSocket, calling detachFd, then creating a new AsyncSSLSocket, and calling sslConn/sslAccept.

That had a couple downsides: 1) All pointers to the previous AsyncSocket become invalid and similarly 2) have to be super careful reads/writes happen on the correct socket, are flushed before changing socket types, etc.

This makes it super easy to just use the same AsyncSSLSocket for everything:
a) Create AsyncSSLSocket in StartTLS mode
b) send/recv anything
c) Call sslAccept/sslConn.  Existing writes are still flushed in the correct order, any additional writes are buffered until handshake completes
d) Start receiving encrypted data.

I made it a new mode (vs. the default), since it seems bad to unintentionally send unencrypted data.

Use case is easy secure thrift upgrade (similar to how current kerberos does it)

Test Plan: New unittest

Reviewed By: afrind@fb.com

Subscribers: doug, ssl-diffs@, folly-diffs@, yfeldblum, chalfant, haijunz, andrewcox, alandau, alikhtarov, jsedgwick, simpkins

FB internal diff: D2120114

Signature: t1:2120114:1433798448:caeddc8feb6cc10fb34200ba97ea323bcaf09f7a
parent 2ef7ef74
......@@ -253,18 +253,22 @@ SSLException::SSLException(int sslError, int errno_copy):
* Create a client AsyncSSLSocket
*/
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
EventBase* evb) :
EventBase* evb, bool deferSecurityNegotiation) :
AsyncSocket(evb),
ctx_(ctx),
handshakeTimeout_(this, evb) {
init();
if (deferSecurityNegotiation) {
sslState_ = STATE_UNENCRYPTED;
}
}
/**
* Create a server/client AsyncSSLSocket
*/
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
EventBase* evb, int fd, bool server) :
EventBase* evb, int fd, bool server,
bool deferSecurityNegotiation) :
AsyncSocket(evb, fd),
server_(server),
ctx_(ctx),
......@@ -274,6 +278,9 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
AsyncSSLSocket::sslInfoCallback);
}
if (deferSecurityNegotiation) {
sslState_ = STATE_UNENCRYPTED;
}
}
#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
......@@ -283,8 +290,9 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
*/
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
EventBase* evb,
const std::string& serverName) :
AsyncSSLSocket(ctx, evb) {
const std::string& serverName,
bool deferSecurityNegotiation) :
AsyncSSLSocket(ctx, evb, deferSecurityNegotiation) {
tlsextHostname_ = serverName;
}
......@@ -294,8 +302,9 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
*/
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
EventBase* evb, int fd,
const std::string& serverName) :
AsyncSSLSocket(ctx, evb, fd, false) {
const std::string& serverName,
bool deferSecurityNegotiation) :
AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
tlsextHostname_ = serverName;
}
#endif
......@@ -374,7 +383,7 @@ void AsyncSSLSocket::shutdownWriteNow() {
bool AsyncSSLSocket::good() const {
return (AsyncSocket::good() &&
(sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING ||
sslState_ == STATE_ESTABLISHED));
sslState_ == STATE_ESTABLISHED || sslState_ == STATE_UNENCRYPTED));
}
// The TAsyncTransport definition of 'good' states that the transport is
......@@ -468,7 +477,9 @@ void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
verifyPeer_ = verifyPeer;
// Make sure we're in the uninitialized state
if (!server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) {
if (!server_ || (sslState_ != STATE_UNINIT &&
sslState_ != STATE_UNENCRYPTED) ||
handshakeCallback_ != nullptr) {
return invalidState(callback);
}
......@@ -674,7 +685,9 @@ void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
verifyPeer_ = verifyPeer;
// Make sure we're in the uninitialized state
if (server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) {
if (server_ || (sslState_ != STATE_UNINIT && sslState_ !=
STATE_UNENCRYPTED) ||
handshakeCallback_ != nullptr) {
return invalidState(callback);
}
......@@ -1078,6 +1091,10 @@ AsyncSSLSocket::handleRead() noexcept {
ssize_t
AsyncSSLSocket::performRead(void* buf, size_t buflen) {
if (sslState_ == STATE_UNENCRYPTED) {
return AsyncSocket::performRead(buf, buflen);
}
errno = 0;
ssize_t bytes = SSL_read(ssl_, buf, buflen);
if (server_ && renegotiateAttempted_) {
......@@ -1169,6 +1186,10 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec,
WriteFlags flags,
uint32_t* countWritten,
uint32_t* partialWritten) {
if (sslState_ == STATE_UNENCRYPTED) {
return AsyncSocket::performWrite(
vec, count, flags, countWritten, partialWritten);
}
if (sslState_ != STATE_ESTABLISHED) {
LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
<< ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
......
......@@ -162,7 +162,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
* Create a client AsyncSSLSocket
*/
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext> &ctx,
EventBase* evb);
EventBase* evb, bool deferSecurityNegotiation = false);
/**
* Create a server/client AsyncSSLSocket from an already connected
......@@ -178,9 +178,12 @@ class AsyncSSLSocket : public virtual AsyncSocket {
* @param evb EventBase that will manage this socket.
* @param fd File descriptor to take over (should be a connected socket).
* @param server Is socket in server mode?
* @param deferSecurityNegotiation
* unencrypted data can be sent before sslConn/Accept
*/
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb, int fd, bool server = true);
EventBase* evb, int fd,
bool server = true, bool deferSecurityNegotiation = false);
/**
......@@ -188,9 +191,10 @@ class AsyncSSLSocket : public virtual AsyncSocket {
*/
static std::shared_ptr<AsyncSSLSocket> newSocket(
const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb, int fd, bool server=true) {
EventBase* evb, int fd, bool server=true,
bool deferSecurityNegotiation = false) {
return std::shared_ptr<AsyncSSLSocket>(
new AsyncSSLSocket(ctx, evb, fd, server),
new AsyncSSLSocket(ctx, evb, fd, server, deferSecurityNegotiation),
Destructor());
}
......@@ -199,9 +203,9 @@ class AsyncSSLSocket : public virtual AsyncSocket {
*/
static std::shared_ptr<AsyncSSLSocket> newSocket(
const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb) {
EventBase* evb, bool deferSecurityNegotiation = false) {
return std::shared_ptr<AsyncSSLSocket>(
new AsyncSSLSocket(ctx, evb),
new AsyncSSLSocket(ctx, evb, deferSecurityNegotiation),
Destructor());
}
......@@ -213,7 +217,8 @@ class AsyncSSLSocket : public virtual AsyncSocket {
*/
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext> &ctx,
EventBase* evb,
const std::string& serverName);
const std::string& serverName,
bool deferSecurityNegotiation = false);
/**
* Create a client AsyncSSLSocket from an already connected
......@@ -233,14 +238,16 @@ class AsyncSSLSocket : public virtual AsyncSocket {
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb,
int fd,
const std::string& serverName);
const std::string& serverName,
bool deferSecurityNegotiation = false);
static std::shared_ptr<AsyncSSLSocket> newSocket(
const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb,
const std::string& serverName) {
const std::string& serverName,
bool deferSecurityNegotiation = false) {
return std::shared_ptr<AsyncSSLSocket>(
new AsyncSSLSocket(ctx, evb, serverName),
new AsyncSSLSocket(ctx, evb, serverName, deferSecurityNegotiation),
Destructor());
}
#endif
......@@ -336,6 +343,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
enum SSLStateEnum {
STATE_UNINIT,
STATE_UNENCRYPTED,
STATE_ACCEPTING,
STATE_CACHE_LOOKUP,
STATE_RSA_ASYNC_PENDING,
......
......@@ -1262,8 +1262,90 @@ TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
socket->setMinWriteSize(50000);
EXPECT_EQ(50000, socket->getMinWriteSize());
}
class ReadCallbackTerminator : public ReadCallback {
public:
ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
: ReadCallback(wcb)
, base_(base) {}
// Do not write data back, terminate the loop.
void readDataAvailable(size_t len) noexcept override {
std::cerr << "readDataAvailable, len " << len << std::endl;
currentBuffer.length = len;
buffers.push_back(currentBuffer);
currentBuffer.reset();
state = STATE_SUCCEEDED;
socket_->setReadCB(nullptr);
base_->terminateLoopSoon();
}
private:
EventBase* base_;
};
/**
* Test a full unencrypted codepath
*/
TEST(AsyncSSLSocketTest, UnencryptedTest) {
EventBase base;
auto clientCtx = std::make_shared<folly::SSLContext>();
auto serverCtx = std::make_shared<folly::SSLContext>();
int fds[2];
getfds(fds);
getctx(clientCtx, serverCtx);
auto client = AsyncSSLSocket::newSocket(
clientCtx, &base, fds[0], false, true);
auto server = AsyncSSLSocket::newSocket(
serverCtx, &base, fds[1], true, true);
ReadCallbackTerminator readCallback(&base, nullptr);
server->setReadCB(&readCallback);
readCallback.setSocket(server);
uint8_t buf[128];
memset(buf, 'a', sizeof(buf));
client->write(nullptr, buf, sizeof(buf));
// Check that bytes are unencrypted
char c;
EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
EXPECT_EQ('a', c);
EventBaseAborter eba(&base, 3000);
base.loop();
EXPECT_EQ(1, readCallback.buffers.size());
EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
server->setReadCB(&readCallback);
// Unencrypted
server->sslAccept(nullptr);
client->sslConn(nullptr);
// Do NOT wait for handshake, writing should be queued and happen after
client->write(nullptr, buf, sizeof(buf));
// Check that bytes are *not* unencrypted
char c2;
EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
EXPECT_NE('a', c2);
base.loop();
EXPECT_EQ(2, readCallback.buffers.size());
EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
}
} // namespace
///////////////////////////////////////////////////////////////////////////
// init_unit_test_suite
///////////////////////////////////////////////////////////////////////////
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment