Commit 3bd6e353 authored by Andrew Huang's avatar Andrew Huang Committed by Facebook GitHub Bot

Handle session resumption with abstracted SSLSession

Summary: Add session resumption functionality in AsyncSSLSocket for the new abstracted SSLSession. Add logic to resume sessions correctly no matter which API is used, allowing for a piecewise transition to using the abstracted SSLSession.

Reviewed By: mingtaoy

Differential Revision: D20600380

fbshipit-source-id: bb81d5be95ccaa6486a7431391676f3e5c0b5b8e
parent 1972a4b5
......@@ -19,6 +19,7 @@
#include <folly/io/async/EventBase.h>
#include <folly/portability/Sockets.h>
#include <boost/variant.hpp>
#include <fcntl.h>
#include <sys/types.h>
#include <cerrno>
......@@ -43,6 +44,8 @@ using std::shared_ptr;
using folly::SpinLock;
using folly::io::Cursor;
using folly::ssl::SSLSessionUniquePtr;
using folly::ssl::detail::OpenSSLSession;
namespace {
using folly::AsyncSSLSocket;
......@@ -120,6 +123,48 @@ void* initsslBioMethod() {
namespace folly {
/**
* Variant visitors. Will be removed once sslSession_ is converted
* to a non-variant type.
*/
class RawSessionRetrievalVisitor : boost::static_visitor<SSLSessionUniquePtr> {
public:
SSLSessionUniquePtr operator()(const SSLSessionUniquePtr& sessionPtr) const {
if (!sessionPtr) {
return SSLSessionUniquePtr();
}
SSL_SESSION* session = sessionPtr.get();
if (session) {
SSL_SESSION_up_ref(session);
}
return SSLSessionUniquePtr(session);
}
SSLSessionUniquePtr operator()(
const shared_ptr<OpenSSLSession>& session) const {
if (!session) {
return ssl::SSLSessionUniquePtr();
}
return session->getActiveSession();
}
};
class SSLSessionRetrievalVisitor
: boost::static_visitor<shared_ptr<OpenSSLSession>> {
public:
shared_ptr<OpenSSLSession> operator()(const SSLSessionUniquePtr&) const {
return nullptr;
}
shared_ptr<OpenSSLSession> operator()(
const shared_ptr<OpenSSLSession>& session) const {
return session;
}
};
class AsyncSSLSocketConnector : public AsyncSocket::ConnectCallback,
public AsyncSSLSocket::HandshakeCB {
private:
......@@ -314,7 +359,7 @@ void AsyncSSLSocket::init() {
(void)sslBioMethodInitializer;
setup_SSL_CTX(ctx_->getSSLCtx());
sslSessionV2_ = std::make_shared<ssl::detail::OpenSSLSession>();
sslSession_ = std::make_shared<OpenSSLSession>();
}
void AsyncSSLSocket::closeNow() {
......@@ -329,11 +374,6 @@ void AsyncSSLSocket::closeNow() {
}
}
if (sslSession_ != nullptr) {
SSL_SESSION_free(sslSession_);
sslSession_ = nullptr;
}
sslState_ = STATE_CLOSED;
if (handshakeTimeout_.isScheduled()) {
......@@ -826,11 +866,10 @@ void AsyncSSLSocket::sslConn(
return failHandshake(__func__, *ex);
}
if (sslSession_ != nullptr) {
SSLSessionUniquePtr sessionPtr = getRawSSLSession();
if (sessionPtr) {
sessionResumptionAttempted_ = true;
SSL_set_session(ssl_.get(), sslSession_);
SSL_SESSION_free(sslSession_);
sslSession_ = nullptr;
SSL_set_session(ssl_.get(), sessionPtr.get());
}
#if FOLLY_OPENSSL_HAS_SNI
if (!tlsextHostname_.empty()) {
......@@ -861,11 +900,11 @@ SSL_SESSION* AsyncSSLSocket::getSSLSession() {
return SSL_get1_session(ssl_.get());
}
return sslSession_;
return getRawSSLSession().release();
}
std::shared_ptr<ssl::SSLSession> AsyncSSLSocket::getSSLSessionV2() {
return sslSessionV2_;
shared_ptr<ssl::SSLSession> AsyncSSLSocket::getSSLSessionV2() {
return getAbstractSSLSession();
}
const SSL* AsyncSSLSocket::getSSL() const {
......@@ -873,15 +912,19 @@ const SSL* AsyncSSLSocket::getSSL() const {
}
void AsyncSSLSocket::setSSLSession(SSL_SESSION* session, bool takeOwnership) {
if (sslSession_) {
SSL_SESSION_free(sslSession_);
}
sslSession_ = session;
if (!takeOwnership && session != nullptr) {
// Increment the reference count
// This API exists in BoringSSL and OpenSSL 1.1.0
SSL_SESSION_up_ref(session);
}
sslSession_ = SSLSessionUniquePtr(session);
}
void AsyncSSLSocket::setSSLSessionV2(shared_ptr<ssl::SSLSession> session) {
auto openSSLSession = std::dynamic_pointer_cast<OpenSSLSession>(session);
if (openSSLSession) {
sslSession_ = openSSLSession;
}
}
void AsyncSSLSocket::getSelectedNextProtocol(
......@@ -2082,4 +2125,14 @@ void AsyncSSLSocket::getSSLServerCiphers(std::string& serverCiphers) const {
}
}
SSLSessionUniquePtr AsyncSSLSocket::getRawSSLSession() const {
static auto visitor = RawSessionRetrievalVisitor();
return boost::apply_visitor(visitor, sslSession_);
}
shared_ptr<OpenSSLSession> AsyncSSLSocket::getAbstractSSLSession() const {
static auto visitor = SSLSessionRetrievalVisitor();
return boost::apply_visitor(visitor, sslSession_);
}
} // namespace folly
......@@ -16,6 +16,7 @@
#pragma once
#include <boost/variant.hpp>
#include <iomanip>
#include <folly/Optional.h>
......@@ -41,6 +42,12 @@ namespace folly {
class AsyncSSLSocketConnector;
namespace ssl {
namespace detail {
class OpenSSLSession;
} // namespace detail
} // namespace ssl
/**
* A class for performing asynchronous I/O on an SSL connection.
*
......@@ -485,6 +492,13 @@ class AsyncSSLSocket : public virtual AsyncSocket {
*/
void setSSLSession(SSL_SESSION* session, bool takeOwnership = false);
/**
* Currently unsupported. Eventually intended to replace setSSLSession()
* once TLS 1.3 is enabled by default.
* Set the abstracted SSL session to be used during sslConn.
*/
void setSSLSessionV2(std::shared_ptr<ssl::SSLSession> session);
/**
* Get the name of the protocol selected by the client during
* Application Layer Protocol Negotiation (ALPN)
......@@ -900,6 +914,10 @@ class AsyncSSLSocket : public virtual AsyncSocket {
static void sslInfoCallback(const SSL* ssl, int where, int ret);
folly::ssl::SSLSessionUniquePtr getRawSSLSession() const;
std::shared_ptr<folly::ssl::detail::OpenSSLSession> getAbstractSSLSession()
const;
// Whether the current write to the socket should use MSG_MORE.
bool corkCurrentWrite_{false};
// SSL related members.
......@@ -915,7 +933,13 @@ class AsyncSSLSocket : public virtual AsyncSocket {
// Callback for SSL_accept() or SSL_connect()
HandshakeCB* handshakeCallback_{nullptr};
ssl::SSLUniquePtr ssl_;
SSL_SESSION* sslSession_{nullptr};
// The SSL session. Which type the variant contains depends on the
// session API that is used. Is only intended to temporarily be a variant.
// Will be converted to a non-variant once SSL session APIs are merged.
boost::variant<
folly::ssl::SSLSessionUniquePtr,
std::shared_ptr<folly::ssl::detail::OpenSSLSession>>
sslSession_;
Timeout handshakeTimeout_;
Timeout connectionTimeout_;
......@@ -980,9 +1004,6 @@ class AsyncSSLSocket : public virtual AsyncSocket {
std::unique_ptr<ReadCallback> asyncOperationFinishCallback_;
// Whether this socket is currently waiting on SSL_accept
bool waitingOnAccept_{false};
// Unsupported. Currently used for getSSLSessionV2().
std::shared_ptr<ssl::SSLSession> sslSessionV2_{nullptr};
};
} // namespace folly
......@@ -38,7 +38,9 @@ class SimpleCallbackManager {
auto sslSession =
std::dynamic_pointer_cast<folly::ssl::detail::OpenSSLSession>(
socket->getSSLSessionV2());
sslSession->setActiveSession(std::move(sessionPtr));
if (sslSession) {
sslSession->setActiveSession(std::move(sessionPtr));
}
return 1;
}
};
......@@ -96,29 +98,130 @@ class SSLSessionTest : public testing::Test {
TEST_F(SSLSessionTest, BasicTest) {
std::shared_ptr<SSLSession> sslSession;
NetworkSocket fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
auto clientPtr = clientSock.get();
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), false, false);
SSLHandshakeServerParseClientHello server(
std::move(serverSock), false, false);
eventBase.loop();
ASSERT_TRUE(client.handshakeSuccess_);
sslSession = clientPtr->getSSLSessionV2();
ASSERT_NE(sslSession, nullptr);
// The underlying SSL_SESSION is set in the session callback
// that is attached to the SSL_CTX. The session is guaranteed to
// be resumable here in TLS 1.2, but not in TLS 1.3
auto opensslSession = std::dynamic_pointer_cast<OpenSSLSession>(sslSession);
auto sessionPtr = opensslSession->getActiveSession();
ASSERT_NE(sessionPtr.get(), nullptr);
// Full handshake
{
NetworkSocket fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
auto clientPtr = clientSock.get();
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), false, false);
SSLHandshakeServerParseClientHello server(
std::move(serverSock), false, false);
eventBase.loop();
ASSERT_TRUE(client.handshakeSuccess_);
ASSERT_FALSE(clientPtr->getSSLSessionReused());
sslSession = clientPtr->getSSLSessionV2();
ASSERT_NE(sslSession, nullptr);
// The underlying SSL_SESSION is set in the session callback
// that is attached to the SSL_CTX. The session is guaranteed to
// be resumable here in TLS 1.2, but not in TLS 1.3
auto opensslSession = std::dynamic_pointer_cast<OpenSSLSession>(sslSession);
auto sessionPtr = opensslSession->getActiveSession();
ASSERT_NE(sessionPtr.get(), nullptr);
}
// Session resumption
{
NetworkSocket fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
auto clientPtr = clientSock.get();
clientPtr->setSSLSessionV2(sslSession);
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), false, false);
SSLHandshakeServerParseClientHello server(
std::move(serverSock), false, false);
eventBase.loop();
ASSERT_TRUE(client.handshakeSuccess_);
ASSERT_TRUE(clientPtr->getSSLSessionReused());
}
}
/**
* To be removed when getSSLSessionV2() and setSSLSessionV2()
* replace getSSLSession() and setSSLSession(),
* respectively.
*/
TEST_F(SSLSessionTest, BasicRegressionTest) {
SSL_SESSION* sslSession;
// Full handshake
{
NetworkSocket fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
auto clientPtr = clientSock.get();
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), false, false);
SSLHandshakeServerParseClientHello server(
std::move(serverSock), false, false);
eventBase.loop();
ASSERT_TRUE(client.handshakeSuccess_);
ASSERT_FALSE(clientPtr->getSSLSessionReused());
sslSession = clientPtr->getSSLSession();
ASSERT_NE(sslSession, nullptr);
SSL_SESSION_free(sslSession);
}
// Session resumption
{
NetworkSocket fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
auto clientPtr = clientSock.get();
clientPtr->setSSLSession(sslSession);
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), false, false);
SSLHandshakeServerParseClientHello server(
std::move(serverSock), false, false);
eventBase.loop();
ASSERT_TRUE(client.handshakeSuccess_);
ASSERT_TRUE(clientPtr->getSSLSessionReused());
}
}
TEST_F(SSLSessionTest, NullSessionResumptionTest) {
// Set null session, should result in full handshake
{
NetworkSocket fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
auto clientPtr = clientSock.get();
clientPtr->setSSLSessionV2(nullptr);
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), false, false);
SSLHandshakeServerParseClientHello server(
std::move(serverSock), false, false);
eventBase.loop();
ASSERT_TRUE(client.handshakeSuccess_);
ASSERT_FALSE(clientPtr->getSSLSessionReused());
}
}
} // namespace folly
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment