Commit ffc3227f authored by Andrew Huang's avatar Andrew Huang Committed by Facebook GitHub Bot

Replace old AsyncSSLSocket session API with V2

Summary:
This new API has a few benefits:
1. It allows `getSSLSession` to support TLS 1.3 session resumption by returning a mutable session wrapper as opposed to the immutable `SSL_SESSION*` object.
2. OpenSSL `SSL_SESSION*` objects require the caller to keep accurate reference counts. Failure to do so can result in memory leaks or use-after-free errors.
3. This design abstracts away OpenSSL internals, which are unnecessary for the caller to perform session resumption.

Reviewed By: mingtaoy

Differential Revision: D24239802

fbshipit-source-id: cd3e90217717394f32dc6a2281e7a40c805990b2
parent 0cb5aa0f
......@@ -898,15 +898,7 @@ void AsyncSSLSocket::startSSLConnect() {
handleConnect();
}
SSL_SESSION* AsyncSSLSocket::getSSLSession() {
if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
return SSL_get1_session(ssl_.get());
}
return sslSessionManager_.getRawSession().release();
}
shared_ptr<ssl::SSLSession> AsyncSSLSocket::getSSLSessionV2() {
shared_ptr<ssl::SSLSession> AsyncSSLSocket::getSSLSession() {
return sslSessionManager_.getSession();
}
......@@ -914,17 +906,8 @@ const SSL* AsyncSSLSocket::getSSL() const {
return ssl_.get();
}
void AsyncSSLSocket::setSSLSession(SSL_SESSION* session, bool takeOwnership) {
if (!takeOwnership && session != nullptr) {
// Increment the reference count
// This API exists in BoringSSL and OpenSSL 1.1.0
SSL_SESSION_up_ref(session);
}
sslSessionManager_.setRawSession(SSLSessionUniquePtr(session));
}
void AsyncSSLSocket::setSSLSessionV2(shared_ptr<ssl::SSLSession> session) {
sslSessionManager_.setSession(session);
void AsyncSSLSocket::setSSLSession(shared_ptr<ssl::SSLSession> session) {
sslSessionManager_.setSession(std::move(session));
}
void AsyncSSLSocket::setRawSSLSession(SSLSessionUniquePtr session) {
......
......@@ -498,17 +498,13 @@ class AsyncSSLSocket : public AsyncSocket {
SSLStateEnum getSSLState() const { return sslState_; }
/**
* Get a handle to the negotiated SSL session. This increments the session
* refcount and must be deallocated by the caller.
*/
SSL_SESSION* getSSLSession();
/**
* Currently unsupported. Eventually intended to replace getSSLSession()
* once TLS 1.3 is enabled by default.
* Get an abstracted SSL Session.
* Retrieve the SSL session associated with this established connection.
*
* The SSL Session object is a copyable, opaque token that can be set on other
* unconnected AsyncSSLSockets. If AsyncSSLSocket::connect() is called with a
* previous session set, TLS resumption will be attempted.
*/
std::shared_ptr<ssl::SSLSession> getSSLSessionV2();
std::shared_ptr<ssl::SSLSession> getSSLSession();
/**
* Get a handle to the SSL struct.
......@@ -516,25 +512,13 @@ class AsyncSSLSocket : public AsyncSocket {
const SSL* getSSL() const;
/**
* DEPRECATED. Will eventually be removed. Please use setSSLSessionV2.
*
* Set the SSL session to be used during sslConn. AsyncSSLSocket will
* hold a reference to the session until it is destroyed or released by the
* underlying SSL structure.
*
* @param takeOwnership if true, AsyncSSLSocket will assume the caller's
* reference count to session.
*/
void setSSLSession(SSL_SESSION* session, bool takeOwnership = false);
/**
* Set the SSL session to be used during sslConn.
* Sets the SSL session that will be attempted for TLS resumption.
*/
void setSSLSessionV2(std::shared_ptr<ssl::SSLSession> session);
void setSSLSession(std::shared_ptr<ssl::SSLSession> session);
/**
* Note: This function exists for compatibility reasons. It is strongly
* recommended to use setSSLSessionV2 instead. After setRawSSLSession is
* recommended to use setSSLSession instead. After setRawSSLSession is
* called, subsequent calls to getSSLSession on the socket will return null.
*
* Set the SSL session to be used during sslConn.
......
......@@ -3204,7 +3204,7 @@ TEST(AsyncSSLSocketTest, TestSNIClientHelloBehavior) {
// create another client, resuming with the prior session, but under a
// different common name.
clientSock = std::move(client).moveSocket();
resumptionSession = clientSock->getSSLSessionV2();
resumptionSession = clientSock->getSSLSession();
}
{
......@@ -3216,7 +3216,7 @@ TEST(AsyncSSLSocketTest, TestSNIClientHelloBehavior) {
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
clientSock->setSSLSessionV2(resumptionSession);
clientSock->setSSLSession(resumptionSession);
clientSock->setServerName("Baz");
SSLHandshakeServerParseClientHello server(
std::move(serverSock), true, true);
......
......@@ -1168,7 +1168,7 @@ class SSLClient : public AsyncSocket::ConnectCallback,
void connect(bool writeNow = false) {
sslSocket_ = AsyncSSLSocket::newSocket(ctx_, eventBase_);
if (session_ != nullptr) {
sslSocket_->setSSLSessionV2(session_);
sslSocket_->setSSLSession(session_);
}
requests_--;
sslSocket_->connect(this, address_, timeout_);
......@@ -1184,7 +1184,7 @@ class SSLClient : public AsyncSocket::ConnectCallback,
hit_++;
} else {
miss_++;
session_ = sslSocket_->getSSLSessionV2();
session_ = sslSocket_->getSSLSession();
}
// write()
......
......@@ -85,7 +85,7 @@ TEST_F(SSLSessionTest, BasicTest) {
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
auto clientPtr = clientSock.get();
sslSession = clientPtr->getSSLSessionV2();
sslSession = clientPtr->getSSLSession();
ASSERT_NE(sslSession, nullptr);
{
auto opensslSession =
......@@ -111,57 +111,6 @@ TEST_F(SSLSessionTest, BasicTest) {
}
}
// Session resumption
{
NetworkSocket fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
auto clientPtr = clientSock.get();
clientPtr->setSSLSessionV2(sslSession);
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), false, false);
SSLHandshakeServerParseClientHello server(
std::move(serverSock), false, false);
eventBase.loop();
ASSERT_TRUE(client.handshakeSuccess_);
ASSERT_TRUE(clientPtr->getSSLSessionReused());
}
}
/**
* To be removed when getSSLSessionV2() and setSSLSessionV2()
* replace getSSLSession() and setSSLSession(),
* respectively.
*/
TEST_F(SSLSessionTest, BasicRegressionTest) {
SSL_SESSION* sslSession;
// Full handshake
{
NetworkSocket fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
auto clientPtr = clientSock.get();
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), false, false);
SSLHandshakeServerParseClientHello server(
std::move(serverSock), false, false);
eventBase.loop();
ASSERT_TRUE(client.handshakeSuccess_);
ASSERT_FALSE(clientPtr->getSSLSessionReused());
sslSession = clientPtr->getSSLSession();
ASSERT_NE(sslSession, nullptr);
}
// Session resumption
{
NetworkSocket fds[2];
......@@ -181,7 +130,6 @@ TEST_F(SSLSessionTest, BasicRegressionTest) {
eventBase.loop();
ASSERT_TRUE(client.handshakeSuccess_);
ASSERT_TRUE(clientPtr->getSSLSessionReused());
SSL_SESSION_free(sslSession);
}
}
......@@ -194,7 +142,7 @@ TEST_F(SSLSessionTest, NullSessionResumptionTest) {
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
auto clientPtr = clientSock.get();
clientPtr->setSSLSessionV2(nullptr);
clientPtr->setSSLSession(nullptr);
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment