Commit dd280141 authored by Dave Watson's avatar Dave Watson

Move Acceptor to wangle

Summary:
Initial pass at moving acceptor to wangle.  Involves moving most of the config stuff from proxygen/lib/services, and *all* of the ssl stuff from proxygen/lib/ssl.

Only minor changes:
* Acceptor can be overriden to use thrift socket types, so I don't have to change TTransportException everywhere just yet
* proxygen::Exception to std::runtime_exception in a few spots - looks like it is entirely bad config exceptions, so it should be okay
* Just used std::chrono directly instead of stuff in Time.h (which is just typedefs and simple helpers)

Test Plan:
used in D1539327

fbconfig -r proxygen/httpserver; fbmake runtests

Probably other projects are broken, will iterate to fix

None of the failling tests look related

Reviewed By: dcsommer@fb.com

Subscribers: oleksandr, netego-diffs@, hphp-diffs@, ps, trunkagent, doug, fugalh, alandau, bmatheny, njormrod, mshneer, folly-diffs@

FB internal diff: D1638358

Tasks: 5002353

Signature: t1:1638358:1414526683:87a405e3c24711078707c00b62a50b0e960bf126
parent 81fa7fd7
......@@ -88,6 +88,26 @@ nobase_follyinclude_HEADERS = \
experimental/wangle/rx/types.h \
experimental/wangle/ConnectionManager.h \
experimental/wangle/ManagedConnection.h \
experimental/wangle/acceptor/Acceptor.h \
experimental/wangle/acceptor/ConnectionCounter.h \
experimental/wangle/acceptor/SocketOptions.h \
experimental/wangle/acceptor/DomainNameMisc.h \
experimental/wangle/acceptor/LoadShedConfiguration.h \
experimental/wangle/acceptor/NetworkAddress.h \
experimental/wangle/acceptor/ServerSocketConfig.h \
experimental/wangle/acceptor/TransportInfo.h \
experimental/wangle/ssl/ClientHelloExtStats.h \
experimental/wangle/ssl/DHParam.h \
experimental/wangle/ssl/PasswordInFile.h \
experimental/wangle/ssl/SSLCacheOptions.h \
experimental/wangle/ssl/SSLCacheProvider.h \
experimental/wangle/ssl/SSLContextConfig.h \
experimental/wangle/ssl/SSLContextManager.h \
experimental/wangle/ssl/SSLSessionCacheManager.h \
experimental/wangle/ssl/SSLStats.h \
experimental/wangle/ssl/SSLUtil.h \
experimental/wangle/ssl/TLSTicketKeyManager.h \
experimental/wangle/ssl/TLSTicketKeySeeds.h \
FBString.h \
FBVector.h \
File.h \
......@@ -301,7 +321,16 @@ libfolly_la_SOURCES = \
experimental/wangle/concurrent/IOThreadPoolExecutor.cpp \
experimental/wangle/concurrent/ThreadPoolExecutor.cpp \
experimental/wangle/ConnectionManager.cpp \
experimental/wangle/ManagedConnection.cpp
experimental/wangle/ManagedConnection.cpp \
experimental/wangle/acceptor/Acceptor.cpp \
experimental/wangle/acceptor/SocketOptions.cpp \
experimental/wangle/acceptor/LoadShedConfiguration.cpp \
experimental/wangle/acceptor/TransportInfo.cpp \
experimental/wangle/ssl/PasswordInFile.cpp \
experimental/wangle/ssl/SSLContextManager.cpp \
experimental/wangle/ssl/SSLSessionCacheManager.cpp \
experimental/wangle/ssl/SSLUtil.cpp \
experimental/wangle/ssl/TLSTicketKeyManager.cpp
if HAVE_LINUX
nobase_follyinclude_HEADERS += \
......
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#include <folly/experimental/wangle/acceptor/Acceptor.h>
#include <folly/experimental/wangle/ManagedConnection.h>
#include <folly/experimental/wangle/ssl/SSLContextManager.h>
#include <boost/cast.hpp>
#include <fcntl.h>
#include <folly/ScopeGuard.h>
#include <folly/experimental/wangle/ManagedConnection.h>
#include <folly/io/async/EventBase.h>
#include <fstream>
#include <sys/socket.h>
#include <sys/types.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/EventBase.h>
#include <unistd.h>
using folly::wangle::ConnectionManager;
using folly::wangle::ManagedConnection;
using std::chrono::microseconds;
using std::chrono::milliseconds;
using std::filebuf;
using std::ifstream;
using std::ios;
using std::shared_ptr;
using std::string;
namespace folly {
#ifndef NO_LIB_GFLAGS
DEFINE_int32(shutdown_idle_grace_ms, 5000, "milliseconds to wait before "
"closing idle conns");
#else
const int32_t FLAGS_shutdown_idle_grace_ms = 5000;
#endif
static const std::string empty_string;
std::atomic<uint64_t> Acceptor::totalNumPendingSSLConns_{0};
/**
* Lightweight wrapper class to keep track of a newly
* accepted connection during SSL handshaking.
*/
class AcceptorHandshakeHelper :
public AsyncSSLSocket::HandshakeCB,
public ManagedConnection {
public:
AcceptorHandshakeHelper(AsyncSSLSocket::UniquePtr socket,
Acceptor* acceptor,
const SocketAddress& clientAddr,
std::chrono::steady_clock::time_point acceptTime)
: socket_(std::move(socket)), acceptor_(acceptor),
acceptTime_(acceptTime), clientAddr_(clientAddr) {
acceptor_->downstreamConnectionManager_->addConnection(this, true);
if(acceptor_->parseClientHello_) {
socket_->enableClientHelloParsing();
}
socket_->sslAccept(this);
}
virtual void timeoutExpired() noexcept {
VLOG(4) << "SSL handshake timeout expired";
sslError_ = SSLErrorEnum::TIMEOUT;
dropConnection();
}
virtual void describe(std::ostream& os) const {
os << "pending handshake on " << clientAddr_;
}
virtual bool isBusy() const {
return true;
}
virtual void notifyPendingShutdown() {}
virtual void closeWhenIdle() {}
virtual void dropConnection() {
VLOG(10) << "Dropping in progress handshake for " << clientAddr_;
socket_->closeNow();
}
virtual void dumpConnectionState(uint8_t loglevel) {
}
private:
// AsyncSSLSocket::HandshakeCallback API
virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept {
const unsigned char* nextProto = nullptr;
unsigned nextProtoLength = 0;
sock->getSelectedNextProtocol(&nextProto, &nextProtoLength);
if (VLOG_IS_ON(3)) {
if (nextProto) {
VLOG(3) << "Client selected next protocol " <<
string((const char*)nextProto, nextProtoLength);
} else {
VLOG(3) << "Client did not select a next protocol";
}
}
// fill in SSL-related fields from TransportInfo
// the other fields like RTT are filled in the Acceptor
TransportInfo tinfo;
tinfo.ssl = true;
tinfo.acceptTime = acceptTime_;
tinfo.sslSetupTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
tinfo.sslSetupBytesRead = sock->getRawBytesReceived();
tinfo.sslSetupBytesWritten = sock->getRawBytesWritten();
tinfo.sslServerName = sock->getSSLServerName();
tinfo.sslCipher = sock->getNegotiatedCipherName();
tinfo.sslVersion = sock->getSSLVersion();
tinfo.sslCertSize = sock->getSSLCertSize();
tinfo.sslResume = SSLUtil::getResumeState(sock);
sock->getSSLClientCiphers(tinfo.sslClientCiphers);
sock->getSSLServerCiphers(tinfo.sslServerCiphers);
tinfo.sslClientComprMethods = sock->getSSLClientComprMethods();
tinfo.sslClientExts = sock->getSSLClientExts();
tinfo.sslNextProtocol.assign(
reinterpret_cast<const char*>(nextProto),
nextProtoLength);
acceptor_->updateSSLStats(sock, tinfo.sslSetupTime, SSLErrorEnum::NO_ERROR);
acceptor_->downstreamConnectionManager_->removeConnection(this);
acceptor_->sslConnectionReady(std::move(socket_), clientAddr_,
nextProto ? string((const char*)nextProto, nextProtoLength) :
empty_string, tinfo);
delete this;
}
virtual void handshakeErr(AsyncSSLSocket* sock,
const AsyncSocketException& ex) noexcept {
auto elapsedTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
VLOG(3) << "SSL handshake error after " << elapsedTime.count() <<
" ms; " << sock->getRawBytesReceived() << " bytes received & " <<
sock->getRawBytesWritten() << " bytes sent: " <<
ex.what();
acceptor_->updateSSLStats(sock, elapsedTime, sslError_);
acceptor_->sslConnectionError();
delete this;
}
AsyncSSLSocket::UniquePtr socket_;
Acceptor* acceptor_;
std::chrono::steady_clock::time_point acceptTime_;
SocketAddress clientAddr_;
SSLErrorEnum sslError_{SSLErrorEnum::NO_ERROR};
};
Acceptor::Acceptor(const ServerSocketConfig& accConfig) :
accConfig_(accConfig),
socketOptions_(accConfig.getSocketOptions()) {
}
void
Acceptor::init(AsyncServerSocket* serverSocket,
EventBase* eventBase) {
CHECK(nullptr == this->base_);
if (accConfig_.isSSL()) {
if (!sslCtxManager_) {
sslCtxManager_ = folly::make_unique<SSLContextManager>(
eventBase,
"vip_" + getName(),
accConfig_.strictSSL, nullptr);
}
for (const auto& sslCtxConfig : accConfig_.sslContextConfigs) {
sslCtxManager_->addSSLContextConfig(
sslCtxConfig,
accConfig_.sslCacheOptions,
&accConfig_.initialTicketSeeds,
accConfig_.bindAddress,
cacheProvider_);
parseClientHello_ |= sslCtxConfig.clientHelloParsingEnabled;
}
CHECK(sslCtxManager_->getDefaultSSLCtx());
}
base_ = eventBase;
state_ = State::kRunning;
downstreamConnectionManager_ = ConnectionManager::makeUnique(
eventBase, accConfig_.connectionIdleTimeout, this);
serverSocket->addAcceptCallback(this, eventBase);
// SO_KEEPALIVE is the only setting that is inherited by accepted
// connections so only apply this setting
for (const auto& option: socketOptions_) {
if (option.first.level == SOL_SOCKET &&
option.first.optname == SO_KEEPALIVE && option.second == 1) {
serverSocket->setKeepAliveEnabled(true);
break;
}
}
}
Acceptor::~Acceptor(void) {
}
void Acceptor::addSSLContextConfig(const SSLContextConfig& sslCtxConfig) {
sslCtxManager_->addSSLContextConfig(sslCtxConfig,
accConfig_.sslCacheOptions,
&accConfig_.initialTicketSeeds,
accConfig_.bindAddress,
cacheProvider_);
}
void
Acceptor::drainAllConnections() {
if (downstreamConnectionManager_) {
downstreamConnectionManager_->initiateGracefulShutdown(
std::chrono::milliseconds(FLAGS_shutdown_idle_grace_ms));
}
}
void Acceptor::setLoadShedConfig(const LoadShedConfiguration& from,
IConnectionCounter* counter) {
loadShedConfig_ = from;
connectionCounter_ = counter;
}
bool Acceptor::canAccept(const SocketAddress& address) {
if (!connectionCounter_) {
return true;
}
uint64_t maxConnections = connectionCounter_->getMaxConnections();
if (maxConnections == 0) {
return true;
}
uint64_t currentConnections = connectionCounter_->getNumConnections();
if (currentConnections < maxConnections) {
return true;
}
if (loadShedConfig_.isWhitelisted(address)) {
return true;
}
// Take care of comparing connection count against max connections across
// all acceptors. Expensive since a lock must be taken to get the counter.
auto connectionCountForLoadShedding = getConnectionCountForLoadShedding();
if (connectionCountForLoadShedding < loadShedConfig_.getMaxConnections()) {
return true;
}
VLOG(4) << address.describe() << " not whitelisted";
return false;
}
void
Acceptor::connectionAccepted(
int fd, const SocketAddress& clientAddr) noexcept {
if (!canAccept(clientAddr)) {
close(fd);
return;
}
auto acceptTime = std::chrono::steady_clock::now();
for (const auto& opt: socketOptions_) {
opt.first.apply(fd, opt.second);
}
onDoneAcceptingConnection(fd, clientAddr, acceptTime);
}
void Acceptor::onDoneAcceptingConnection(
int fd,
const SocketAddress& clientAddr,
std::chrono::steady_clock::time_point acceptTime) noexcept {
processEstablishedConnection(fd, clientAddr, acceptTime);
}
void
Acceptor::processEstablishedConnection(
int fd,
const SocketAddress& clientAddr,
std::chrono::steady_clock::time_point acceptTime) noexcept {
if (accConfig_.isSSL()) {
CHECK(sslCtxManager_);
AsyncSSLSocket::UniquePtr sslSock(
makeNewAsyncSSLSocket(
sslCtxManager_->getDefaultSSLCtx(), base_, fd));
++numPendingSSLConns_;
++totalNumPendingSSLConns_;
if (totalNumPendingSSLConns_ > accConfig_.maxConcurrentSSLHandshakes) {
VLOG(2) << "dropped SSL handshake on " << accConfig_.name <<
" too many handshakes in progress";
updateSSLStats(sslSock.get(), std::chrono::milliseconds(0),
SSLErrorEnum::DROPPED);
sslConnectionError();
return;
}
new AcceptorHandshakeHelper(
std::move(sslSock), this, clientAddr, acceptTime);
} else {
TransportInfo tinfo;
tinfo.ssl = false;
tinfo.acceptTime = acceptTime;
AsyncSocket::UniquePtr sock(makeNewAsyncSocket(base_, fd));
connectionReady(std::move(sock), clientAddr, empty_string, tinfo);
}
}
void
Acceptor::connectionReady(
AsyncSocket::UniquePtr sock,
const SocketAddress& clientAddr,
const string& nextProtocolName,
TransportInfo& tinfo) {
// Limit the number of reads from the socket per poll loop iteration,
// both to keep memory usage under control and to prevent one fast-
// writing client from starving other connections.
sock->setMaxReadsPerEvent(16);
tinfo.initWithSocket(sock.get());
onNewConnection(std::move(sock), &clientAddr, nextProtocolName, tinfo);
}
void
Acceptor::sslConnectionReady(AsyncSocket::UniquePtr sock,
const SocketAddress& clientAddr,
const string& nextProtocol,
TransportInfo& tinfo) {
CHECK(numPendingSSLConns_ > 0);
connectionReady(std::move(sock), clientAddr, nextProtocol, tinfo);
--numPendingSSLConns_;
--totalNumPendingSSLConns_;
if (state_ == State::kDraining) {
checkDrained();
}
}
void
Acceptor::sslConnectionError() {
CHECK(numPendingSSLConns_ > 0);
--numPendingSSLConns_;
--totalNumPendingSSLConns_;
if (state_ == State::kDraining) {
checkDrained();
}
}
void
Acceptor::acceptError(const std::exception& ex) noexcept {
// An error occurred.
// The most likely error is out of FDs. AsyncServerSocket will back off
// briefly if we are out of FDs, then continue accepting later.
// Just log a message here.
LOG(ERROR) << "error accepting on acceptor socket: " << ex.what();
}
void
Acceptor::acceptStopped() noexcept {
VLOG(3) << "Acceptor " << this << " acceptStopped()";
// Drain the open client connections
drainAllConnections();
// If we haven't yet finished draining, begin doing so by marking ourselves
// as in the draining state. We must be sure to hit checkDrained() here, as
// if we're completely idle, we can should consider ourself drained
// immediately (as there is no outstanding work to complete to cause us to
// re-evaluate this).
if (state_ != State::kDone) {
state_ = State::kDraining;
checkDrained();
}
}
void
Acceptor::onEmpty(const ConnectionManager& cm) {
VLOG(3) << "Acceptor=" << this << " onEmpty()";
if (state_ == State::kDraining) {
checkDrained();
}
}
void
Acceptor::checkDrained() {
CHECK(state_ == State::kDraining);
if (forceShutdownInProgress_ ||
(downstreamConnectionManager_->getNumConnections() != 0) ||
(numPendingSSLConns_ != 0)) {
return;
}
VLOG(2) << "All connections drained from Acceptor=" << this << " in thread "
<< base_;
downstreamConnectionManager_.reset();
state_ = State::kDone;
onConnectionsDrained();
}
milliseconds
Acceptor::getConnTimeout() const {
return accConfig_.connectionIdleTimeout;
}
void Acceptor::addConnection(ManagedConnection* conn) {
// Add the socket to the timeout manager so that it can be cleaned
// up after being left idle for a long time.
downstreamConnectionManager_->addConnection(conn, true);
}
void
Acceptor::forceStop() {
base_->runInEventBaseThread([&] { dropAllConnections(); });
}
void
Acceptor::dropAllConnections() {
if (downstreamConnectionManager_) {
LOG(INFO) << "Dropping all connections from Acceptor=" << this <<
" in thread " << base_;
assert(base_->isInEventBaseThread());
forceShutdownInProgress_ = true;
downstreamConnectionManager_->dropAllConnections();
CHECK(downstreamConnectionManager_->getNumConnections() == 0);
downstreamConnectionManager_.reset();
}
CHECK(numPendingSSLConns_ == 0);
state_ = State::kDone;
onConnectionsDrained();
}
} // namespace
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
#include "folly/experimental/wangle/acceptor/ServerSocketConfig.h"
#include "folly/experimental/wangle/acceptor/ConnectionCounter.h"
#include <folly/experimental/wangle/ConnectionManager.h>
#include "folly/experimental/wangle/acceptor/LoadShedConfiguration.h"
#include "folly/experimental/wangle/ssl/SSLCacheProvider.h"
#include "folly/experimental/wangle/acceptor/TransportInfo.h"
#include <chrono>
#include <event.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/AsyncServerSocket.h>
namespace folly { namespace wangle {
class ManagedConnection;
}}
namespace folly {
class SocketAddress;
class SSLContext;
class AsyncTransport;
class SSLContextManager;
/**
* An abstract acceptor for TCP-based network services.
*
* There is one acceptor object per thread for each listening socket. When a
* new connection arrives on the listening socket, it is accepted by one of the
* acceptor objects. From that point on the connection will be processed by
* that acceptor's thread.
*
* The acceptor will call the abstract onNewConnection() method to create
* a new ManagedConnection object for each accepted socket. The acceptor
* also tracks all outstanding connections that it has accepted.
*/
class Acceptor :
public folly::AsyncServerSocket::AcceptCallback,
public folly::wangle::ConnectionManager::Callback {
public:
enum class State : uint32_t {
kInit, // not yet started
kRunning, // processing requests normally
kDraining, // processing outstanding conns, but not accepting new ones
kDone, // no longer accepting, and all connections finished
};
explicit Acceptor(const ServerSocketConfig& accConfig);
virtual ~Acceptor();
/**
* Supply an SSL cache provider
* @note Call this before init()
*/
virtual void setSSLCacheProvider(
const std::shared_ptr<SSLCacheProvider>& cacheProvider) {
cacheProvider_ = cacheProvider;
}
/**
* Initialize the Acceptor to run in the specified EventBase
* thread, receiving connections from the specified AsyncServerSocket.
*
* This method will be called from the AsyncServerSocket's primary thread,
* not the specified EventBase thread.
*/
virtual void init(AsyncServerSocket* serverSocket,
EventBase* eventBase);
/**
* Dynamically add a new SSLContextConfig
*/
void addSSLContextConfig(const SSLContextConfig& sslCtxConfig);
SSLContextManager* getSSLContextManager() const {
return sslCtxManager_.get();
}
/**
* Return the number of outstanding connections in this service instance.
*/
uint32_t getNumConnections() const {
return downstreamConnectionManager_ ?
downstreamConnectionManager_->getNumConnections() : 0;
}
/**
* Access the Acceptor's event base.
*/
EventBase* getEventBase() { return base_; }
/**
* Access the Acceptor's downstream (client-side) ConnectionManager
*/
virtual folly::wangle::ConnectionManager* getConnectionManager() {
return downstreamConnectionManager_.get();
}
/**
* Invoked when a new ManagedConnection is created.
*
* This allows the Acceptor to track the outstanding connections,
* for tracking timeouts and for ensuring that all connections have been
* drained on shutdown.
*/
void addConnection(folly::wangle::ManagedConnection* connection);
/**
* Get this acceptor's current state.
*/
State getState() const {
return state_;
}
/**
* Get the current connection timeout.
*/
std::chrono::milliseconds getConnTimeout() const;
/**
* Returns the name of this VIP.
*
* Will return an empty string if no name has been configured.
*/
const std::string& getName() const {
return accConfig_.name;
}
/**
* Force the acceptor to drop all connections and stop processing.
*
* This function may be called from any thread. The acceptor will not
* necessarily stop before this function returns: the stop will be scheduled
* to run in the acceptor's thread.
*/
virtual void forceStop();
bool isSSL() const { return accConfig_.isSSL(); }
const ServerSocketConfig& getConfig() const { return accConfig_; }
static uint64_t getTotalNumPendingSSLConns() {
return totalNumPendingSSLConns_.load();
}
/**
* Called right when the TCP connection has been accepted, before processing
* the first HTTP bytes (HTTP) or the SSL handshake (HTTPS)
*/
virtual void onDoneAcceptingConnection(
int fd,
const SocketAddress& clientAddr,
std::chrono::steady_clock::time_point acceptTime
) noexcept;
/**
* Begins either processing HTTP bytes (HTTP) or the SSL handshake (HTTPS)
*/
void processEstablishedConnection(
int fd,
const SocketAddress& clientAddr,
std::chrono::steady_clock::time_point acceptTime
) noexcept;
protected:
friend class AcceptorHandshakeHelper;
/**
* Our event loop.
*
* Probably needs to be used to pass to a ManagedConnection
* implementation. Also visible in case a subclass wishes to do additional
* things w/ the event loop (e.g. in attach()).
*/
EventBase* base_{nullptr};
virtual uint64_t getConnectionCountForLoadShedding(void) const { return 0; }
/**
* Hook for subclasses to drop newly accepted connections prior
* to handshaking.
*/
virtual bool canAccept(const folly::SocketAddress&);
/**
* Invoked when a new connection is created. This is where application starts
* processing a new downstream connection.
*
* NOTE: Application should add the new connection to
* downstreamConnectionManager so that it can be garbage collected after
* certain period of idleness.
*
* @param sock the socket connected to the client
* @param address the address of the client
* @param nextProtocolName the name of the L6 or L7 protocol to be
* spoken on the connection, if known (e.g.,
* from TLS NPN during secure connection setup),
* or an empty string if unknown
*/
virtual void onNewConnection(
AsyncSocket::UniquePtr sock,
const folly::SocketAddress* address,
const std::string& nextProtocolName,
const TransportInfo& tinfo) = 0;
virtual AsyncSocket::UniquePtr makeNewAsyncSocket(EventBase* base, int fd) {
return AsyncSocket::UniquePtr(new AsyncSocket(base, fd));
}
virtual AsyncSSLSocket::UniquePtr makeNewAsyncSSLSocket(
const std::shared_ptr<SSLContext>& ctx, EventBase* base, int fd) {
return AsyncSSLSocket::UniquePtr(new AsyncSSLSocket(ctx, base, fd));
}
/**
* Hook for subclasses to record stats about SSL connection establishment.
*/
virtual void updateSSLStats(
const AsyncSSLSocket* sock,
std::chrono::milliseconds acceptLatency,
SSLErrorEnum error) noexcept {}
/**
* Drop all connections.
*
* forceStop() schedules dropAllConnections() to be called in the acceptor's
* thread.
*/
void dropAllConnections();
/**
* Drains all open connections of their outstanding transactions. When
* a connection's transaction count reaches zero, the connection closes.
*/
void drainAllConnections();
/**
* onConnectionsDrained() will be called once all connections have been
* drained while the acceptor is stopping.
*
* Subclasses can override this method to perform any subclass-specific
* cleanup.
*/
virtual void onConnectionsDrained() {}
// AsyncServerSocket::AcceptCallback methods
void connectionAccepted(int fd,
const folly::SocketAddress& clientAddr)
noexcept;
void acceptError(const std::exception& ex) noexcept;
void acceptStopped() noexcept;
// ConnectionManager::Callback methods
void onEmpty(const folly::wangle::ConnectionManager& cm);
void onConnectionAdded(const folly::wangle::ConnectionManager& cm) {}
void onConnectionRemoved(const folly::wangle::ConnectionManager& cm) {}
/**
* Process a connection that is to ready to receive L7 traffic.
* This method is called immediately upon accept for plaintext
* connections and upon completion of SSL handshaking or resumption
* for SSL connections.
*/
void connectionReady(
AsyncSocket::UniquePtr sock,
const folly::SocketAddress& clientAddr,
const std::string& nextProtocolName,
TransportInfo& tinfo);
const LoadShedConfiguration& getLoadShedConfiguration() const {
return loadShedConfig_;
}
protected:
const ServerSocketConfig accConfig_;
void setLoadShedConfig(const LoadShedConfiguration& from,
IConnectionCounter* counter);
/**
* Socket options to apply to the client socket
*/
AsyncSocket::OptionMap socketOptions_;
std::unique_ptr<SSLContextManager> sslCtxManager_;
/**
* Whether we want to enable client hello parsing in the handshake helper
* to get list of supported client ciphers.
*/
bool parseClientHello_{false};
folly::wangle::ConnectionManager::UniquePtr downstreamConnectionManager_;
private:
// Forbidden copy constructor and assignment opererator
Acceptor(Acceptor const &) = delete;
Acceptor& operator=(Acceptor const &) = delete;
/**
* Wrapper for connectionReady() that decrements the count of
* pending SSL connections.
*/
void sslConnectionReady(AsyncSocket::UniquePtr sock,
const folly::SocketAddress& clientAddr,
const std::string& nextProtocol,
TransportInfo& tinfo);
/**
* Notification callback for SSL handshake failures.
*/
void sslConnectionError();
void checkDrained();
State state_{State::kInit};
uint64_t numPendingSSLConns_{0};
static std::atomic<uint64_t> totalNumPendingSSLConns_;
bool forceShutdownInProgress_{false};
LoadShedConfiguration loadShedConfig_;
IConnectionCounter* connectionCounter_{nullptr};
std::shared_ptr<SSLCacheProvider> cacheProvider_;
};
} // namespace
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
namespace folly {
class IConnectionCounter {
public:
virtual uint64_t getNumConnections() const = 0;
/**
* Get the maximum number of non-whitelisted client-side connections
* across all Acceptors managed by this. A value
* of zero means "unlimited."
*/
virtual uint64_t getMaxConnections() const = 0;
/**
* Increment the count of client-side connections.
*/
virtual void onConnectionAdded() = 0;
/**
* Decrement the count of client-side connections.
*/
virtual void onConnectionRemoved() = 0;
virtual ~IConnectionCounter() {}
};
class SimpleConnectionCounter: public IConnectionCounter {
public:
uint64_t getNumConnections() const override { return numConnections_; }
uint64_t getMaxConnections() const override { return maxConnections_; }
void setMaxConnections(uint64_t maxConnections) {
maxConnections_ = maxConnections;
}
void onConnectionAdded() override { numConnections_++; }
void onConnectionRemoved() override { numConnections_--; }
virtual ~SimpleConnectionCounter() {}
protected:
uint64_t maxConnections_{0};
uint64_t numConnections_{0};
};
}
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
#include <string>
namespace folly {
struct dn_char_traits : public std::char_traits<char> {
static bool eq(char c1, char c2) {
return ::tolower(c1) == ::tolower(c2);
}
static bool ne(char c1, char c2) {
return ::tolower(c1) != ::tolower(c2);
}
static bool lt(char c1, char c2) {
return ::tolower(c1) < ::tolower(c2);
}
static int compare(const char* s1, const char* s2, size_t n) {
while (n--) {
if(::tolower(*s1) < ::tolower(*s2) ) {
return -1;
}
if(::tolower(*s1) > ::tolower(*s2) ) {
return 1;
}
++s1;
++s2;
}
return 0;
}
static const char* find(const char* s, size_t n, char a) {
char la = ::tolower(a);
while (n--) {
if(::tolower(*s) == la) {
return s;
} else {
++s;
}
}
return nullptr;
}
};
// Case insensitive string
typedef std::basic_string<char, dn_char_traits> DNString;
struct DNStringHash : public std::hash<std::string> {
size_t operator()(const DNString& s) const noexcept {
size_t h = static_cast<size_t>(0xc70f6907UL);
const char* d = s.data();
for (size_t i = 0; i < s.length(); ++i) {
char a = ::tolower(*d++);
h = std::_Hash_impl::hash(&a, sizeof(a), h);
}
return h;
}
};
} // namespace
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#include <folly/experimental/wangle/acceptor/LoadShedConfiguration.h>
#include <folly/Conv.h>
#include <openssl/ssl.h>
using std::string;
namespace folly {
void LoadShedConfiguration::addWhitelistAddr(folly::StringPiece input) {
auto addr = input.str();
size_t separator = addr.find_first_of('/');
if (separator == string::npos) {
whitelistAddrs_.insert(SocketAddress(addr, 0));
} else {
unsigned prefixLen = folly::to<unsigned>(addr.substr(separator + 1));
addr.erase(separator);
whitelistNetworks_.insert(NetworkAddress(SocketAddress(addr, 0), prefixLen));
}
}
bool LoadShedConfiguration::isWhitelisted(const SocketAddress& address) const {
if (whitelistAddrs_.find(address) != whitelistAddrs_.end()) {
return true;
}
for (auto& network : whitelistNetworks_) {
if (network.contains(address)) {
return true;
}
}
return false;
}
}
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
#include <chrono>
#include <folly/Range.h>
#include <folly/SocketAddress.h>
#include <glog/logging.h>
#include <list>
#include <set>
#include <string>
#include <folly/experimental/wangle/acceptor/NetworkAddress.h>
namespace folly {
/**
* Class that holds an LoadShed configuration for a service
*/
class LoadShedConfiguration {
public:
// Comparison function for SocketAddress that disregards the port
struct AddressOnlyCompare {
bool operator()(
const SocketAddress& addr1,
const SocketAddress& addr2) const {
return addr1.getIPAddress() < addr2.getIPAddress();
}
};
typedef std::set<SocketAddress, AddressOnlyCompare> AddressSet;
typedef std::set<NetworkAddress> NetworkSet;
LoadShedConfiguration() {}
~LoadShedConfiguration() {}
void addWhitelistAddr(folly::StringPiece);
/**
* Set/get the set of IPs that should be whitelisted through even when we're
* trying to shed load.
*/
void setWhitelistAddrs(const AddressSet& addrs) { whitelistAddrs_ = addrs; }
const AddressSet& getWhitelistAddrs() const { return whitelistAddrs_; }
/**
* Set/get the set of networks that should be whitelisted through even
* when we're trying to shed load.
*/
void setWhitelistNetworks(const NetworkSet& networks) {
whitelistNetworks_ = networks;
}
const NetworkSet& getWhitelistNetworks() const { return whitelistNetworks_; }
/**
* Set/get the maximum number of downstream connections across all VIPs.
*/
void setMaxConnections(uint64_t maxConns) { maxConnections_ = maxConns; }
uint64_t getMaxConnections() const { return maxConnections_; }
/**
* Set/get the maximum cpu usage.
*/
void setMaxMemUsage(double max) {
CHECK(max >= 0);
CHECK(max <= 1);
maxMemUsage_ = max;
}
double getMaxMemUsage() const { return maxMemUsage_; }
/**
* Set/get the maximum memory usage.
*/
void setMaxCpuUsage(double max) {
CHECK(max >= 0);
CHECK(max <= 1);
maxCpuUsage_ = max;
}
double getMaxCpuUsage() const { return maxCpuUsage_; }
void setLoadUpdatePeriod(std::chrono::milliseconds period) {
period_ = period;
}
std::chrono::milliseconds getLoadUpdatePeriod() const { return period_; }
bool isWhitelisted(const SocketAddress& addr) const;
private:
AddressSet whitelistAddrs_;
NetworkSet whitelistNetworks_;
uint64_t maxConnections_{0};
double maxMemUsage_;
double maxCpuUsage_;
std::chrono::milliseconds period_;
};
}
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
#include <folly/SocketAddress.h>
namespace folly {
/**
* A simple wrapper around SocketAddress that represents
* a network in CIDR notation
*/
class NetworkAddress {
public:
/**
* Create a NetworkAddress for an addr/prefixLen
* @param addr IPv4 or IPv6 address of the network
* @param prefixLen Prefix length, in bits
*/
NetworkAddress(const folly::SocketAddress& addr,
unsigned prefixLen):
addr_(addr), prefixLen_(prefixLen) {}
/** Get the network address */
const folly::SocketAddress& getAddress() const {
return addr_;
}
/** Get the prefix length in bits */
unsigned getPrefixLength() const { return prefixLen_; }
/** Check whether a given address lies within the network */
bool contains(const folly::SocketAddress& addr) const {
return addr_.prefixMatch(addr, prefixLen_);
}
/** Comparison operator to enable use in ordered collections */
bool operator<(const NetworkAddress& other) const {
if (addr_ < other.addr_) {
return true;
} else if (other.addr_ < addr_) {
return false;
} else {
return (prefixLen_ < other.prefixLen_);
}
}
private:
folly::SocketAddress addr_;
unsigned prefixLen_;
};
} // namespace
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
#include <folly/experimental/wangle/ssl/SSLCacheOptions.h>
#include <folly/experimental/wangle/ssl/SSLContextConfig.h>
#include <folly/experimental/wangle/ssl/TLSTicketKeySeeds.h>
#include <folly/experimental/wangle/ssl/SSLUtil.h>
#include <folly/experimental/wangle/acceptor/SocketOptions.h>
#include <boost/optional.hpp>
#include <chrono>
#include <fcntl.h>
#include <folly/Random.h>
#include <folly/SocketAddress.h>
#include <folly/String.h>
#include <folly/io/async/SSLContext.h>
#include <list>
#include <string>
#include <sys/stat.h>
#include <sys/types.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/SSLContext.h>
#include <folly/SocketAddress.h>
namespace folly {
/**
* Configuration for a single Acceptor.
*
* This configures not only accept behavior, but also some types of SSL
* behavior that may make sense to configure on a per-VIP basis (e.g. which
* cert(s) we use, etc).
*/
struct ServerSocketConfig {
ServerSocketConfig() {
// generate a single random current seed
uint8_t seed[32];
folly::Random::secureRandom(seed, sizeof(seed));
initialTicketSeeds.currentSeeds.push_back(
SSLUtil::hexlify(std::string((char *)seed, sizeof(seed))));
}
bool isSSL() const { return !(sslContextConfigs.empty()); }
/**
* Set/get the socket options to apply on all downstream connections.
*/
void setSocketOptions(
const AsyncSocket::OptionMap& opts) {
socketOptions_ = filterIPSocketOptions(opts, bindAddress.getFamily());
}
AsyncSocket::OptionMap&
getSocketOptions() {
return socketOptions_;
}
const AsyncSocket::OptionMap&
getSocketOptions() const {
return socketOptions_;
}
bool hasExternalPrivateKey() const {
for (const auto& cfg : sslContextConfigs) {
if (!cfg.isLocalPrivateKey) {
return true;
}
}
return false;
}
/**
* The name of this acceptor; used for stats/reporting purposes.
*/
std::string name;
/**
* The depth of the accept queue backlog.
*/
uint32_t acceptBacklog{1024};
/**
* The number of milliseconds a connection can be idle before we close it.
*/
std::chrono::milliseconds connectionIdleTimeout{600000};
/**
* The address to bind to.
*/
SocketAddress bindAddress;
/**
* Options for controlling the SSL cache.
*/
SSLCacheOptions sslCacheOptions{std::chrono::seconds(600), 20480, 200};
/**
* The initial TLS ticket seeds.
*/
TLSTicketKeySeeds initialTicketSeeds;
/**
* The configs for all the SSL_CTX for use by this Acceptor.
*/
std::vector<SSLContextConfig> sslContextConfigs;
/**
* Determines if the Acceptor does strict checking when loading the SSL
* contexts.
*/
bool strictSSL{true};
/**
* Maximum number of concurrent pending SSL handshakes
*/
uint32_t maxConcurrentSSLHandshakes{30720};
private:
AsyncSocket::OptionMap socketOptions_;
};
} // folly
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#include <folly/experimental/wangle/acceptor/SocketOptions.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
namespace folly {
AsyncSocket::OptionMap filterIPSocketOptions(
const AsyncSocket::OptionMap& allOptions,
const int addrFamily) {
AsyncSocket::OptionMap opts;
int exclude;
if (addrFamily == AF_INET) {
exclude = IPPROTO_IPV6;
} else if (addrFamily == AF_INET6) {
exclude = IPPROTO_IP;
} else {
LOG(FATAL) << "Address family " << addrFamily << " was not IPv4 or IPv6";
return opts;
}
for (const auto& opt: allOptions) {
if (opt.first.level != exclude) {
opts[opt.first] = opt.second;
}
}
return opts;
}
}
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
#include <folly/io/async/AsyncSocket.h>
namespace folly {
/**
* Returns a copy of the socket options excluding options with the given
* level.
*/
AsyncSocket::OptionMap filterIPSocketOptions(
const AsyncSocket::OptionMap& allOptions,
const int addrFamily);
}
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#include <folly/experimental/wangle/acceptor/TransportInfo.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <folly/io/async/AsyncSocket.h>
using std::chrono::microseconds;
using std::map;
using std::string;
namespace folly {
bool TransportInfo::initWithSocket(const AsyncSocket* sock) {
#if defined(__linux__) || defined(__FreeBSD__)
if (!TransportInfo::readTcpInfo(&tcpinfo, sock)) {
tcpinfoErrno = errno;
return false;
}
rtt = microseconds(tcpinfo.tcpi_rtt);
validTcpinfo = true;
#else
tcpinfoErrno = EINVAL;
rtt = microseconds(-1);
#endif
return true;
}
int64_t TransportInfo::readRTT(const AsyncSocket* sock) {
#if defined(__linux__) || defined(__FreeBSD__)
struct tcp_info tcpinfo;
if (!TransportInfo::readTcpInfo(&tcpinfo, sock)) {
return -1;
}
return tcpinfo.tcpi_rtt;
#else
return -1;
#endif
}
#if defined(__linux__) || defined(__FreeBSD__)
bool TransportInfo::readTcpInfo(struct tcp_info* tcpinfo,
const AsyncSocket* sock) {
socklen_t len = sizeof(struct tcp_info);
if (!sock) {
return false;
}
if (getsockopt(sock->getFd(), IPPROTO_TCP,
TCP_INFO, (void*) tcpinfo, &len) < 0) {
VLOG(4) << "Error calling getsockopt(): " << strerror(errno);
return false;
}
return true;
}
#endif
} // folly
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
#include <folly/experimental/wangle/ssl/SSLUtil.h>
#include <chrono>
#include <netinet/tcp.h>
#include <string>
namespace folly {
class AsyncSocket;
/**
* A structure that encapsulates byte counters related to the HTTP headers.
*/
struct HTTPHeaderSize {
/**
* The number of bytes used to represent the header after compression or
* before decompression. If header compression is not supported, the value
* is set to 0.
*/
uint32_t compressed{0};
/**
* The number of bytes used to represent the serialized header before
* compression or after decompression, in plain-text format.
*/
uint32_t uncompressed{0};
};
struct TransportInfo {
/*
* timestamp of when the connection handshake was completed
*/
std::chrono::steady_clock::time_point acceptTime{};
/*
* connection RTT (Round-Trip Time)
*/
std::chrono::microseconds rtt{0};
#if defined(__linux__) || defined(__FreeBSD__)
/*
* TCP information as fetched from getsockopt(2)
*/
tcp_info tcpinfo {
#if __GLIBC__ >= 2 && __GLIBC_MINOR__ >= 17
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 // 32
#else
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 // 29
#endif // __GLIBC__ >= 2 && __GLIBC_MINOR__ >= 17
};
#endif // defined(__linux__) || defined(__FreeBSD__)
/*
* time for setting the connection, from the moment in was accepted until it
* is established.
*/
std::chrono::milliseconds setupTime{0};
/*
* time for setting up the SSL connection or SSL handshake
*/
std::chrono::milliseconds sslSetupTime{0};
/*
* The name of the SSL ciphersuite used by the transaction's
* transport. Returns null if the transport is not SSL.
*/
const char* sslCipher{nullptr};
/*
* The SSL server name used by the transaction's
* transport. Returns null if the transport is not SSL.
*/
const char* sslServerName{nullptr};
/*
* list of ciphers sent by the client
*/
std::string sslClientCiphers{};
/*
* list of compression methods sent by the client
*/
std::string sslClientComprMethods{};
/*
* list of TLS extensions sent by the client
*/
std::string sslClientExts{};
/*
* hash of all the SSL parameters sent by the client
*/
std::string sslSignature{};
/*
* list of ciphers supported by the server
*/
std::string sslServerCiphers{};
/*
* guessed "(os) (browser)" based on SSL Signature
*/
std::string guessedUserAgent{};
/**
* The result of SSL NPN negotiation.
*/
std::string sslNextProtocol{};
/*
* total number of bytes sent over the connection
*/
int64_t totalBytes{0};
/**
* header bytes read
*/
HTTPHeaderSize ingressHeader;
/*
* header bytes written
*/
HTTPHeaderSize egressHeader;
/*
* Here is how the timeToXXXByte variables are planned out:
* 1. All timeToXXXByte variables are measuring the ByteEvent from reqStart_
* 2. You can get the timing between two ByteEvents by calculating their
* differences. For example:
* timeToLastBodyByteAck - timeToFirstByte
* => Total time to deliver the body
* 3. The calculation in point (2) is typically done outside acceptor
*
* Future plan:
* We should log the timestamps (TimePoints) and allow
* the consumer to calculate the latency whatever it
* wants instead of calculating them in wangle, for the sake of flexibility.
* For example:
* 1. TimePoint reqStartTimestamp;
* 2. TimePoint firstHeaderByteSentTimestamp;
* 3. TimePoint firstBodyByteTimestamp;
* 3. TimePoint lastBodyByteTimestamp;
* 4. TimePoint lastBodyByteAckTimestamp;
*/
/*
* time to first header byte written to the kernel send buffer
* NOTE: It is not 100% accurate since TAsyncSocket does not do
* do callback on partial write.
*/
int32_t timeToFirstHeaderByte{-1};
/*
* time to first body byte written to the kernel send buffer
*/
int32_t timeToFirstByte{-1};
/*
* time to last body byte written to the kernel send buffer
*/
int32_t timeToLastByte{-1};
/*
* time to TCP Ack received for the last written body byte
*/
int32_t timeToLastBodyByteAck{-1};
/*
* time it took the client to ACK the last byte, from the moment when the
* kernel sent the last byte to the client and until it received the ACK
* for that byte
*/
int32_t lastByteAckLatency{-1};
/*
* time spent inside wangle
*/
int32_t proxyLatency{-1};
/*
* time between connection accepted and client message headers completed
*/
int32_t clientLatency{-1};
/*
* latency for communication with the server
*/
int32_t serverLatency{-1};
/*
* time used to get a usable connection.
*/
int32_t connectLatency{-1};
/*
* body bytes written
*/
uint32_t egressBodySize{0};
/*
* value of errno in case of getsockopt() error
*/
int tcpinfoErrno{0};
/*
* bytes read & written during SSL Setup
*/
uint32_t sslSetupBytesWritten{0};
uint32_t sslSetupBytesRead{0};
/**
* SSL error detail
*/
uint32_t sslError{0};
/**
* body bytes read
*/
uint32_t ingressBodySize{0};
/*
* The SSL version used by the transaction's transport, in
* OpenSSL's format: 4 bits for the major version, followed by 4 bits
* for the minor version. Returns zero for non-SSL.
*/
uint16_t sslVersion{0};
/*
* The SSL certificate size.
*/
uint16_t sslCertSize{0};
/**
* response status code
*/
uint16_t statusCode{0};
/*
* The SSL mode for the transaction's transport: new session,
* resumed session, or neither (non-SSL).
*/
SSLResumeEnum sslResume{SSLResumeEnum::NA};
/*
* true if the tcpinfo was successfully read from the kernel
*/
bool validTcpinfo{false};
/*
* true if the connection is SSL, false otherwise
*/
bool ssl{false};
/*
* get the RTT value in milliseconds
*/
std::chrono::milliseconds getRttMs() const {
return std::chrono::duration_cast<std::chrono::milliseconds>(rtt);
}
/*
* initialize the fields related with tcp_info
*/
bool initWithSocket(const AsyncSocket* sock);
/*
* Get the kernel's estimate of round-trip time (RTT) to the transport's peer
* in microseconds. Returns -1 on error.
*/
static int64_t readRTT(const AsyncSocket* sock);
#if defined(__linux__) || defined(__FreeBSD__)
/*
* perform the getsockopt(2) syscall to fetch TCP info for a given socket
*/
static bool readTcpInfo(struct tcp_info* tcpinfo,
const AsyncSocket* sock);
#endif
};
} // folly
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
namespace folly {
class ClientHelloExtStats {
public:
virtual ~ClientHelloExtStats() noexcept {}
// client hello
virtual void recordAbsentHostname() noexcept = 0;
virtual void recordMatch() noexcept = 0;
virtual void recordNotMatch() noexcept = 0;
};
}
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
#include <openssl/dh.h>
// The following was auto-generated by
// openssl dhparam -C 2048
DH *get_dh2048()
{
static unsigned char dh2048_p[]={
0xF8,0x87,0xA5,0x15,0x98,0x35,0x20,0x1E,0xF5,0x81,0xE5,0x95,
0x1B,0xE4,0x54,0xEA,0x53,0xF5,0xE7,0x26,0x30,0x03,0x06,0x79,
0x3C,0xC1,0x0B,0xAD,0x3B,0x59,0x3C,0x61,0x13,0x03,0x7B,0x02,
0x70,0xDE,0xC1,0x20,0x11,0x9E,0x94,0x13,0x50,0xF7,0x62,0xFC,
0x99,0x0D,0xC1,0x12,0x6E,0x03,0x95,0xA3,0x57,0xC7,0x3C,0xB8,
0x6B,0x40,0x56,0x65,0x70,0xFB,0x7A,0xE9,0x02,0xEC,0xD2,0xB6,
0x54,0xD7,0x34,0xAD,0x3D,0x9E,0x11,0x61,0x53,0xBE,0xEA,0xB8,
0x17,0x48,0xA8,0xDC,0x70,0xAE,0x65,0x99,0x3F,0x82,0x4C,0xFF,
0x6A,0xC9,0xFA,0xB1,0xFA,0xE4,0x4F,0x5D,0xA4,0x05,0xC2,0x8E,
0x55,0xC0,0xB1,0x1D,0xCC,0x17,0xF3,0xFA,0x65,0xD8,0x6B,0x09,
0x13,0x01,0x2A,0x39,0xF1,0x86,0x73,0xE3,0x7A,0xC8,0xDB,0x7D,
0xDA,0x1C,0xA1,0x2D,0xBA,0x2C,0x00,0x6B,0x2C,0x55,0x28,0x2B,
0xD5,0xF5,0x3C,0x9F,0x50,0xA7,0xB7,0x28,0x9F,0x22,0xD5,0x3A,
0xC4,0x53,0x01,0xC9,0xF3,0x69,0xB1,0x8D,0x01,0x36,0xF8,0xA8,
0x89,0xCA,0x2E,0x72,0xBC,0x36,0x3A,0x42,0xC1,0x06,0xD6,0x0E,
0xCB,0x4D,0x5C,0x1F,0xE4,0xA1,0x17,0xBF,0x55,0x64,0x1B,0xB4,
0x52,0xEC,0x15,0xED,0x32,0xB1,0x81,0x07,0xC9,0x71,0x25,0xF9,
0x4D,0x48,0x3D,0x18,0xF4,0x12,0x09,0x32,0xC4,0x0B,0x7A,0x4E,
0x83,0xC3,0x10,0x90,0x51,0x2E,0xBE,0x87,0xF9,0xDE,0xB4,0xE6,
0x3C,0x29,0xB5,0x32,0x01,0x9D,0x95,0x04,0xBD,0x42,0x89,0xFD,
0x21,0xEB,0xE9,0x88,0x5A,0x27,0xBB,0x31,0xC4,0x26,0x99,0xAB,
0x8C,0xA1,0x76,0xDB,
};
static unsigned char dh2048_g[]={
0x02,
};
DH *dh;
if ((dh=DH_new()) == nullptr) return(nullptr);
dh->p=BN_bin2bn(dh2048_p,(int)sizeof(dh2048_p),nullptr);
dh->g=BN_bin2bn(dh2048_g,(int)sizeof(dh2048_g),nullptr);
if ((dh->p == nullptr) || (dh->g == nullptr))
{ DH_free(dh); return(nullptr); }
return(dh);
}
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#include <folly/experimental/wangle/ssl/PasswordInFile.h>
#include <folly/FileUtil.h>
using namespace std;
namespace folly {
PasswordInFile::PasswordInFile(const string& file)
: fileName_(file) {
folly::readFile(file.c_str(), password_);
auto p = password_.find('\0');
if (p != std::string::npos) {
password_.erase(p);
}
}
PasswordInFile::~PasswordInFile() {
OPENSSL_cleanse((char *)password_.data(), password_.length());
}
}
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
#include <folly/io/async/SSLContext.h> // PasswordCollector
namespace folly {
class PasswordInFile: public folly::PasswordCollector {
public:
explicit PasswordInFile(const std::string& file);
~PasswordInFile();
void getPassword(std::string& password, int size) override {
password = password_;
}
const char* getPasswordStr() const {
return password_.c_str();
}
std::string describe() const override {
return fileName_;
}
protected:
std::string fileName_;
std::string password_;
};
}
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
#include <chrono>
#include <cstdint>
namespace folly {
struct SSLCacheOptions {
std::chrono::seconds sslCacheTimeout;
uint64_t maxSSLCacheSize;
uint64_t sslCacheFlushSize;
};
}
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
#include <folly/io/async/AsyncSSLSocket.h>
namespace folly {
class SSLSessionCacheManager;
/**
* Interface to be implemented by providers of external session caches
*/
class SSLCacheProvider {
public:
/**
* Context saved during an external cache request that is used to
* resume the waiting client.
*/
typedef struct {
std::string sessionId;
SSL_SESSION* session;
SSLSessionCacheManager* manager;
AsyncSSLSocket* sslSocket;
std::unique_ptr<
folly::DelayedDestruction::DestructorGuard> guard;
} CacheContext;
virtual ~SSLCacheProvider() {}
/**
* Store a session in the external cache.
* @param sessionId Identifier that can be used later to fetch the
* session with getAsync()
* @param value Serialized session to store
* @param expiration Relative expiration time: seconds from now
* @return true if the storing of the session is initiated successfully
* (though not necessarily completed; the completion may
* happen either before or after this method returns), or
* false if the storing cannot be initiated due to an error.
*/
virtual bool setAsync(const std::string& sessionId,
const std::string& value,
std::chrono::seconds expiration) = 0;
/**
* Retrieve a session from the external cache. When done, call
* the cache manager's onGetSuccess() or onGetFailure() callback.
* @param sessionId Session ID to fetch
* @param context Data to pass back to the SSLSessionCacheManager
* in the completion callback
* @return true if the lookup of the session is initiated successfully
* (though not necessarily completed; the completion may
* happen either before or after this method returns), or
* false if the lookup cannot be initiated due to an error.
*/
virtual bool getAsync(const std::string& sessionId,
CacheContext* context) = 0;
};
}
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
#include <string>
#include <folly/io/async/SSLContext.h>
#include <vector>
/**
* SSLContextConfig helps to describe the configs/options for
* a SSL_CTX. For example:
*
* 1. Filename of X509, private key and its password.
* 2. ciphers list
* 3. NPN list
* 4. Is session cache enabled?
* 5. Is it the default X509 in SNI operation?
* 6. .... and a few more
*/
namespace folly {
struct SSLContextConfig {
SSLContextConfig() {}
~SSLContextConfig() {}
struct CertificateInfo {
std::string certPath;
std::string keyPath;
std::string passwordPath;
};
/**
* Helpers to set/add a certificate
*/
void setCertificate(const std::string& certPath,
const std::string& keyPath,
const std::string& passwordPath) {
certificates.clear();
addCertificate(certPath, keyPath, passwordPath);
}
void addCertificate(const std::string& certPath,
const std::string& keyPath,
const std::string& passwordPath) {
certificates.emplace_back(CertificateInfo{certPath, keyPath, passwordPath});
}
/**
* Set the optional list of protocols to advertise via TLS
* Next Protocol Negotiation. An empty list means NPN is not enabled.
*/
void setNextProtocols(const std::list<std::string>& inNextProtocols) {
nextProtocols.clear();
nextProtocols.push_back({1, inNextProtocols});
}
typedef std::function<bool(char const* server_name)> SNINoMatchFn;
std::vector<CertificateInfo> certificates;
folly::SSLContext::SSLVersion sslVersion{
folly::SSLContext::TLSv1};
bool sessionCacheEnabled{true};
bool sessionTicketEnabled{true};
bool clientHelloParsingEnabled{false};
std::string sslCiphers{
"ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:"
"ECDHE-ECDSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES128-GCM-SHA256:"
"ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-SHA:ECDHE-RSA-AES256-SHA:"
"AES128-GCM-SHA256:AES256-GCM-SHA384:AES128-SHA:AES256-SHA:"
"ECDHE-ECDSA-RC4-SHA:ECDHE-RSA-RC4-SHA:RC4-SHA:RC4-MD5:"
"ECDHE-RSA-DES-CBC3-SHA:DES-CBC3-SHA"};
std::string eccCurveName;
// Ciphers to negotiate if TLS version >= 1.1
std::string tls11Ciphers{""};
// Weighted lists of NPN strings to advertise
std::list<folly::SSLContext::NextProtocolsItem>
nextProtocols;
bool isLocalPrivateKey{true};
// Should this SSLContextConfig be the default for SNI purposes
bool isDefault{false};
// Callback function to invoke when there are no matching certificates
// (will only be invoked once)
SNINoMatchFn sniNoMatchFn;
// File containing trusted CA's to validate client certificates
std::string clientCAFile;
};
}
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#include <folly/experimental/wangle/ssl/SSLContextManager.h>
#include <folly/experimental/wangle/ssl/ClientHelloExtStats.h>
#include <folly/experimental/wangle/ssl/DHParam.h>
#include <folly/experimental/wangle/ssl/PasswordInFile.h>
#include <folly/experimental/wangle/ssl/SSLCacheOptions.h>
#include <folly/experimental/wangle/ssl/SSLSessionCacheManager.h>
#include <folly/experimental/wangle/ssl/SSLUtil.h>
#include <folly/experimental/wangle/ssl/TLSTicketKeyManager.h>
#include <folly/experimental/wangle/ssl/TLSTicketKeySeeds.h>
#include <folly/Conv.h>
#include <folly/ScopeGuard.h>
#include <folly/String.h>
#include <functional>
#include <openssl/asn1.h>
#include <openssl/ssl.h>
#include <string>
#include <folly/io/async/EventBase.h>
#define OPENSSL_MISSING_FEATURE(name) \
do { \
throw std::runtime_error("missing " #name " support in openssl"); \
} while(0)
using std::string;
using std::shared_ptr;
/**
* SSLContextManager helps to create and manage all SSL_CTX,
* SSLSessionCacheManager and TLSTicketManager for a listening
* VIP:PORT. (Note, in SNI, a listening VIP:PORT can have >1 SSL_CTX(s)).
*
* Other responsibilities:
* 1. It also handles the SSL_CTX selection after getting the tlsext_hostname
* in the client hello message.
*
* Usage:
* 1. Each listening VIP:PORT serving SSL should have one SSLContextManager.
* It maps to Acceptor in the wangle vocabulary.
*
* 2. Create a SSLContextConfig object (e.g. by parsing the JSON config).
*
* 3. Call SSLContextManager::addSSLContextConfig() which will
* then create and configure the SSL_CTX
*
* Note: Each Acceptor, with SSL support, should have one SSLContextManager to
* manage all SSL_CTX for the VIP:PORT.
*/
namespace folly {
namespace {
X509* getX509(SSL_CTX* ctx) {
SSL* ssl = SSL_new(ctx);
SSL_set_connect_state(ssl);
X509* x509 = SSL_get_certificate(ssl);
CRYPTO_add(&x509->references, 1, CRYPTO_LOCK_X509);
SSL_free(ssl);
return x509;
}
void set_key_from_curve(SSL_CTX* ctx, const std::string& curveName) {
#if OPENSSL_VERSION_NUMBER >= 0x0090800fL
#ifndef OPENSSL_NO_ECDH
EC_KEY* ecdh = nullptr;
int nid;
/*
* Elliptic-Curve Diffie-Hellman parameters are either "named curves"
* from RFC 4492 section 5.1.1, or explicitly described curves over
* binary fields. OpenSSL only supports the "named curves", which provide
* maximum interoperability.
*/
nid = OBJ_sn2nid(curveName.c_str());
if (nid == 0) {
LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
return;
}
ecdh = EC_KEY_new_by_curve_name(nid);
if (ecdh == nullptr) {
LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
return;
}
SSL_CTX_set_tmp_ecdh(ctx, ecdh);
EC_KEY_free(ecdh);
#endif
#endif
}
// Helper to create TLSTicketKeyManger and aware of the needed openssl
// version/feature.
std::unique_ptr<TLSTicketKeyManager> createTicketManagerHelper(
std::shared_ptr<folly::SSLContext> ctx,
const TLSTicketKeySeeds* ticketSeeds,
const SSLContextConfig& ctxConfig,
SSLStats* stats) {
std::unique_ptr<TLSTicketKeyManager> ticketManager;
#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB
if (ticketSeeds && ctxConfig.sessionTicketEnabled) {
ticketManager = folly::make_unique<TLSTicketKeyManager>(ctx.get(), stats);
ticketManager->setTLSTicketKeySeeds(
ticketSeeds->oldSeeds,
ticketSeeds->currentSeeds,
ticketSeeds->newSeeds);
} else {
ctx->setOptions(SSL_OP_NO_TICKET);
}
#else
if (ticketSeeds && ctxConfig.sessionTicketEnabled) {
OPENSSL_MISSING_FEATURE(TLSTicket);
}
#endif
return ticketManager;
}
std::string flattenList(const std::list<std::string>& list) {
std::string s;
bool first = true;
for (auto& item : list) {
if (first) {
first = false;
} else {
s.append(", ");
}
s.append(item);
}
return s;
}
}
SSLContextManager::~SSLContextManager() {}
SSLContextManager::SSLContextManager(
EventBase* eventBase,
const std::string& vipName,
bool strict,
SSLStats* stats) :
stats_(stats),
eventBase_(eventBase),
strict_(strict) {
}
void SSLContextManager::addSSLContextConfig(
const SSLContextConfig& ctxConfig,
const SSLCacheOptions& cacheOptions,
const TLSTicketKeySeeds* ticketSeeds,
const folly::SocketAddress& vipAddress,
const std::shared_ptr<SSLCacheProvider>& externalCache) {
unsigned numCerts = 0;
std::string commonName;
std::string lastCertPath;
std::unique_ptr<std::list<std::string>> subjectAltName;
auto sslCtx = std::make_shared<SSLContext>(ctxConfig.sslVersion);
for (const auto& cert : ctxConfig.certificates) {
try {
sslCtx->loadCertificate(cert.certPath.c_str());
} catch (const std::exception& ex) {
// The exception isn't very useful without the certificate path name,
// so throw a new exception that includes the path to the certificate.
string msg = folly::to<string>("error loading SSL certificate ",
cert.certPath, ": ",
folly::exceptionStr(ex));
LOG(ERROR) << msg;
throw std::runtime_error(msg);
}
// Verify that the Common Name and (if present) Subject Alternative Names
// are the same for all the certs specified for the SSL context.
numCerts++;
X509* x509 = getX509(sslCtx->getSSLCtx());
auto guard = folly::makeGuard([x509] { X509_free(x509); });
auto cn = SSLUtil::getCommonName(x509);
if (!cn) {
throw std::runtime_error(folly::to<string>("Cannot get CN for X509 ",
cert.certPath));
}
auto altName = SSLUtil::getSubjectAltName(x509);
VLOG(2) << "cert " << cert.certPath << " CN: " << *cn;
if (altName) {
altName->sort();
VLOG(2) << "cert " << cert.certPath << " SAN: " << flattenList(*altName);
} else {
VLOG(2) << "cert " << cert.certPath << " SAN: " << "{none}";
}
if (numCerts == 1) {
commonName = *cn;
subjectAltName = std::move(altName);
} else {
if (commonName != *cn) {
throw std::runtime_error(folly::to<string>("X509 ", cert.certPath,
" does not have same CN as ",
lastCertPath));
}
if (altName == nullptr) {
if (subjectAltName != nullptr) {
throw std::runtime_error(folly::to<string>("X509 ", cert.certPath,
" does not have same SAN as ",
lastCertPath));
}
} else {
if ((subjectAltName == nullptr) || (*altName != *subjectAltName)) {
throw std::runtime_error(folly::to<string>("X509 ", cert.certPath,
" does not have same SAN as ",
lastCertPath));
}
}
}
lastCertPath = cert.certPath;
// TODO t4438250 - Add ECDSA support to the crypto_ssl offload server
// so we can avoid storing the ECDSA private key in the
// address space of the Internet-facing process. For
// now, if cert name includes "-EC" to denote elliptic
// curve, we load its private key even if the server as
// a whole has been configured for async crypto.
if (ctxConfig.isLocalPrivateKey ||
(cert.certPath.find("-EC") != std::string::npos)) {
// The private key lives in the same process
// This needs to be called before loadPrivateKey().
if (!cert.passwordPath.empty()) {
auto sslPassword = std::make_shared<PasswordInFile>(cert.passwordPath);
sslCtx->passwordCollector(sslPassword);
}
try {
sslCtx->loadPrivateKey(cert.keyPath.c_str());
} catch (const std::exception& ex) {
// Throw an error that includes the key path, so the user can tell
// which key had a problem.
string msg = folly::to<string>("error loading private SSL key ",
cert.keyPath, ": ",
folly::exceptionStr(ex));
LOG(ERROR) << msg;
throw std::runtime_error(msg);
}
}
}
if (!ctxConfig.isLocalPrivateKey) {
enableAsyncCrypto(sslCtx);
}
// Let the server pick the highest performing cipher from among the client's
// choices.
//
// Let's use a unique private key for all DH key exchanges.
//
// Because some old implementations choke on empty fragments, most SSL
// applications disable them (it's part of SSL_OP_ALL). This
// will improve performance and decrease write buffer fragmentation.
sslCtx->setOptions(SSL_OP_CIPHER_SERVER_PREFERENCE |
SSL_OP_SINGLE_DH_USE |
SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS);
// Configure SSL ciphers list
if (!ctxConfig.tls11Ciphers.empty()) {
// FIXME: create a dummy SSL_CTX for cipher testing purpose? It can
// remove the ordering dependency
// Test to see if the specified TLS1.1 ciphers are valid. Note that
// these will be overwritten by the ciphers() call below.
sslCtx->setCiphersOrThrow(ctxConfig.tls11Ciphers);
}
// Important that we do this *after* checking the TLS1.1 ciphers above,
// since we test their validity by actually setting them.
sslCtx->ciphers(ctxConfig.sslCiphers);
// Use a fix DH param
DH* dh = get_dh2048();
SSL_CTX_set_tmp_dh(sslCtx->getSSLCtx(), dh);
DH_free(dh);
const string& curve = ctxConfig.eccCurveName;
if (!curve.empty()) {
set_key_from_curve(sslCtx->getSSLCtx(), curve);
}
if (!ctxConfig.clientCAFile.empty()) {
try {
sslCtx->setVerificationOption(SSLContext::VERIFY_REQ_CLIENT_CERT);
sslCtx->loadTrustedCertificates(ctxConfig.clientCAFile.c_str());
sslCtx->loadClientCAList(ctxConfig.clientCAFile.c_str());
} catch (const std::exception& ex) {
string msg = folly::to<string>("error loading client CA",
ctxConfig.clientCAFile, ": ",
folly::exceptionStr(ex));
LOG(ERROR) << msg;
throw std::runtime_error(msg);
}
}
// - start - SSL session cache config
// the internal cache never does what we want (per-thread-per-vip).
// Disable it. SSLSessionCacheManager will set it appropriately.
SSL_CTX_set_session_cache_mode(sslCtx->getSSLCtx(), SSL_SESS_CACHE_OFF);
SSL_CTX_set_timeout(sslCtx->getSSLCtx(),
cacheOptions.sslCacheTimeout.count());
std::unique_ptr<SSLSessionCacheManager> sessionCacheManager;
if (ctxConfig.sessionCacheEnabled &&
cacheOptions.maxSSLCacheSize > 0 &&
cacheOptions.sslCacheFlushSize > 0) {
sessionCacheManager =
folly::make_unique<SSLSessionCacheManager>(
cacheOptions.maxSSLCacheSize,
cacheOptions.sslCacheFlushSize,
sslCtx.get(),
vipAddress,
commonName,
eventBase_,
stats_,
externalCache);
}
// - end - SSL session cache config
std::unique_ptr<TLSTicketKeyManager> ticketManager =
createTicketManagerHelper(sslCtx, ticketSeeds, ctxConfig, stats_);
// finalize sslCtx setup by the individual features supported by openssl
ctxSetupByOpensslFeature(sslCtx, ctxConfig);
try {
insert(sslCtx,
std::move(sessionCacheManager),
std::move(ticketManager),
ctxConfig.isDefault);
} catch (const std::exception& ex) {
string msg = folly::to<string>("Error adding certificate : ",
folly::exceptionStr(ex));
LOG(ERROR) << msg;
throw std::runtime_error(msg);
}
}
#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK
SSLContext::ServerNameCallbackResult
SSLContextManager::serverNameCallback(SSL* ssl) {
shared_ptr<SSLContext> ctx;
const char* sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
if (!sn) {
VLOG(6) << "Server Name (tlsext_hostname) is missing";
if (clientHelloTLSExtStats_) {
clientHelloTLSExtStats_->recordAbsentHostname();
}
return SSLContext::SERVER_NAME_NOT_FOUND;
}
size_t snLen = strlen(sn);
VLOG(6) << "Server Name (SNI TLS extension): '" << sn << "' ";
// FIXME: This code breaks the abstraction. Suggestion?
AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl);
CHECK(sslSocket);
DNString dnstr(sn, snLen);
uint32_t count = 0;
do {
// Try exact match first
ctx = getSSLCtx(dnstr);
if (ctx) {
sslSocket->switchServerSSLContext(ctx);
if (clientHelloTLSExtStats_) {
clientHelloTLSExtStats_->recordMatch();
}
return SSLContext::SERVER_NAME_FOUND;
}
ctx = getSSLCtxBySuffix(dnstr);
if (ctx) {
sslSocket->switchServerSSLContext(ctx);
if (clientHelloTLSExtStats_) {
clientHelloTLSExtStats_->recordMatch();
}
return SSLContext::SERVER_NAME_FOUND;
}
// Give the noMatchFn one chance to add the correct cert
}
while (count++ == 0 && noMatchFn_ && noMatchFn_(sn));
VLOG(6) << folly::stringPrintf("Cannot find a SSL_CTX for \"%s\"", sn);
if (clientHelloTLSExtStats_) {
clientHelloTLSExtStats_->recordNotMatch();
}
return SSLContext::SERVER_NAME_NOT_FOUND;
}
#endif
// Consolidate all SSL_CTX setup which depends on openssl version/feature
void
SSLContextManager::ctxSetupByOpensslFeature(
shared_ptr<folly::SSLContext> sslCtx,
const SSLContextConfig& ctxConfig) {
// Disable compression - profiling shows this to be very expensive in
// terms of CPU and memory consumption.
//
#ifdef SSL_OP_NO_COMPRESSION
sslCtx->setOptions(SSL_OP_NO_COMPRESSION);
#endif
// Enable early release of SSL buffers to reduce the memory footprint
#ifdef SSL_MODE_RELEASE_BUFFERS
sslCtx->getSSLCtx()->mode |= SSL_MODE_RELEASE_BUFFERS;
#endif
#ifdef SSL_MODE_EARLY_RELEASE_BBIO
sslCtx->getSSLCtx()->mode |= SSL_MODE_EARLY_RELEASE_BBIO;
#endif
// This number should (probably) correspond to HTTPSession::kMaxReadSize
// For now, this number must also be large enough to accommodate our
// largest certificate, because some older clients (IE6/7) require the
// cert to be in a single fragment.
#ifdef SSL_CTRL_SET_MAX_SEND_FRAGMENT
SSL_CTX_set_max_send_fragment(sslCtx->getSSLCtx(), 8000);
#endif
// Specify cipher(s) to be used for TLS1.1 client
if (!ctxConfig.tls11Ciphers.empty()) {
#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK
// Specified TLS1.1 ciphers are valid
sslCtx->addClientHelloCallback(
std::bind(
&SSLContext::switchCiphersIfTLS11,
sslCtx.get(),
std::placeholders::_1,
ctxConfig.tls11Ciphers
)
);
#else
OPENSSL_MISSING_FEATURE(SNI);
#endif
}
// NPN (Next Protocol Negotiation)
if (!ctxConfig.nextProtocols.empty()) {
#ifdef OPENSSL_NPN_NEGOTIATED
sslCtx->setRandomizedAdvertisedNextProtocols(ctxConfig.nextProtocols);
#else
OPENSSL_MISSING_FEATURE(NPN);
#endif
}
// SNI
#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK
noMatchFn_ = ctxConfig.sniNoMatchFn;
if (ctxConfig.isDefault) {
if (defaultCtx_) {
throw std::runtime_error(">1 X509 is set as default");
}
defaultCtx_ = sslCtx;
defaultCtx_->setServerNameCallback(
std::bind(&SSLContextManager::serverNameCallback, this,
std::placeholders::_1));
}
#else
if (ctxs_.size() > 1) {
OPENSSL_MISSING_FEATURE(SNI);
}
#endif
}
void
SSLContextManager::insert(shared_ptr<SSLContext> sslCtx,
std::unique_ptr<SSLSessionCacheManager> smanager,
std::unique_ptr<TLSTicketKeyManager> tmanager,
bool defaultFallback) {
X509* x509 = getX509(sslCtx->getSSLCtx());
auto guard = folly::makeGuard([x509] { X509_free(x509); });
auto cn = SSLUtil::getCommonName(x509);
if (!cn) {
throw std::runtime_error("Cannot get CN");
}
/**
* Some notes from RFC 2818. Only for future quick references in case of bugs
*
* RFC 2818 section 3.1:
* "......
* If a subjectAltName extension of type dNSName is present, that MUST
* be used as the identity. Otherwise, the (most specific) Common Name
* field in the Subject field of the certificate MUST be used. Although
* the use of the Common Name is existing practice, it is deprecated and
* Certification Authorities are encouraged to use the dNSName instead.
* ......
* In some cases, the URI is specified as an IP address rather than a
* hostname. In this case, the iPAddress subjectAltName must be present
* in the certificate and must exactly match the IP in the URI.
* ......"
*/
// Not sure if we ever get this kind of X509...
// If we do, assume '*' is always in the CN and ignore all subject alternative
// names.
if (cn->length() == 1 && (*cn)[0] == '*') {
if (!defaultFallback) {
throw std::runtime_error("STAR X509 is not the default");
}
ctxs_.emplace_back(sslCtx);
sessionCacheManagers_.emplace_back(std::move(smanager));
ticketManagers_.emplace_back(std::move(tmanager));
return;
}
// Insert by CN
insertSSLCtxByDomainName(cn->c_str(), cn->length(), sslCtx);
// Insert by subject alternative name(s)
auto altNames = SSLUtil::getSubjectAltName(x509);
if (altNames) {
for (auto& name : *altNames) {
insertSSLCtxByDomainName(name.c_str(), name.length(), sslCtx);
}
}
ctxs_.emplace_back(sslCtx);
sessionCacheManagers_.emplace_back(std::move(smanager));
ticketManagers_.emplace_back(std::move(tmanager));
}
void
SSLContextManager::insertSSLCtxByDomainName(const char* dn, size_t len,
shared_ptr<SSLContext> sslCtx) {
try {
insertSSLCtxByDomainNameImpl(dn, len, sslCtx);
} catch (const std::runtime_error& ex) {
if (strict_) {
throw ex;
} else {
LOG(ERROR) << ex.what() << " DN=" << dn;
}
}
}
void
SSLContextManager::insertSSLCtxByDomainNameImpl(const char* dn, size_t len,
shared_ptr<SSLContext> sslCtx)
{
VLOG(4) <<
folly::stringPrintf("Adding CN/Subject-alternative-name \"%s\" for "
"SNI search", dn);
// Only support wildcard domains which are prefixed exactly by "*." .
// "*" appearing at other locations is not accepted.
if (len > 2 && dn[0] == '*') {
if (dn[1] == '.') {
// skip the first '*'
dn++;
len--;
} else {
throw std::runtime_error(
"Invalid wildcard CN/subject-alternative-name \"" + std::string(dn) + "\" "
"(only allow character \".\" after \"*\"");
}
}
if (len == 1 && *dn == '.') {
throw std::runtime_error("X509 has only '.' in the CN or subject alternative name "
"(after removing any preceding '*')");
}
if (strchr(dn, '*')) {
throw std::runtime_error("X509 has '*' in the the CN or subject alternative name "
"(after removing any preceding '*')");
}
DNString dnstr(dn, len);
const auto v = dnMap_.find(dnstr);
if (v == dnMap_.end()) {
dnMap_.emplace(dnstr, sslCtx);
} else if (v->second == sslCtx) {
VLOG(6)<< "Duplicate CN or subject alternative name found in the same X509."
" Ignore the later name.";
} else {
throw std::runtime_error("Duplicate CN or subject alternative name found: \"" +
std::string(dnstr.c_str()) + "\"");
}
}
shared_ptr<SSLContext>
SSLContextManager::getSSLCtxBySuffix(const DNString& dnstr) const
{
size_t dot;
if ((dot = dnstr.find_first_of(".")) != DNString::npos) {
DNString suffixDNStr(dnstr, dot);
const auto v = dnMap_.find(suffixDNStr);
if (v != dnMap_.end()) {
VLOG(6) << folly::stringPrintf("\"%s\" is a willcard match to \"%s\"",
dnstr.c_str(), suffixDNStr.c_str());
return v->second;
}
}
VLOG(6) << folly::stringPrintf("\"%s\" is not a wildcard match",
dnstr.c_str());
return shared_ptr<SSLContext>();
}
shared_ptr<SSLContext>
SSLContextManager::getSSLCtx(const DNString& dnstr) const
{
const auto v = dnMap_.find(dnstr);
if (v == dnMap_.end()) {
VLOG(6) << folly::stringPrintf("\"%s\" is not an exact match",
dnstr.c_str());
return shared_ptr<SSLContext>();
} else {
VLOG(6) << folly::stringPrintf("\"%s\" is an exact match", dnstr.c_str());
return v->second;
}
}
shared_ptr<SSLContext>
SSLContextManager::getDefaultSSLCtx() const {
return defaultCtx_;
}
void
SSLContextManager::reloadTLSTicketKeys(
const std::vector<std::string>& oldSeeds,
const std::vector<std::string>& currentSeeds,
const std::vector<std::string>& newSeeds) {
#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB
for (auto& tmgr: ticketManagers_) {
tmgr->setTLSTicketKeySeeds(oldSeeds, currentSeeds, newSeeds);
}
#endif
}
} // namespace
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
#include <folly/io/async/EventBase.h>
#include <folly/io/async/SSLContext.h>
#include <glog/logging.h>
#include <list>
#include <memory>
#include <folly/experimental/wangle/ssl/SSLContextConfig.h>
#include <folly/experimental/wangle/ssl/SSLSessionCacheManager.h>
#include <folly/experimental/wangle/ssl/TLSTicketKeySeeds.h>
#include <folly/experimental/wangle/acceptor/DomainNameMisc.h>
#include <vector>
namespace folly {
class SocketAddress;
class SSLContext;
class ClientHelloExtStats;
class SSLCacheOptions;
class SSLStats;
class TLSTicketKeyManager;
class TLSTicketKeySeeds;
class SSLContextManager {
public:
explicit SSLContextManager(EventBase* eventBase,
const std::string& vipName, bool strict,
SSLStats* stats);
virtual ~SSLContextManager();
/**
* Add a new X509 to SSLContextManager. The details of a X509
* is passed as a SSLContextConfig object.
*
* @param ctxConfig Details of a X509, its private key, password, etc.
* @param cacheOptions Options for how to do session caching.
* @param ticketSeeds If non-null, the initial ticket key seeds to use.
* @param vipAddress Which VIP are the X509(s) used for? It is only for
* for user friendly log message
* @param externalCache Optional external provider for the session cache;
* may be null
*/
void addSSLContextConfig(
const SSLContextConfig& ctxConfig,
const SSLCacheOptions& cacheOptions,
const TLSTicketKeySeeds* ticketSeeds,
const folly::SocketAddress& vipAddress,
const std::shared_ptr<SSLCacheProvider> &externalCache);
/**
* Get the default SSL_CTX for a VIP
*/
std::shared_ptr<SSLContext>
getDefaultSSLCtx() const;
/**
* Search by the _one_ level up subdomain
*/
std::shared_ptr<SSLContext>
getSSLCtxBySuffix(const DNString& dnstr) const;
/**
* Search by the full-string domain name
*/
std::shared_ptr<SSLContext>
getSSLCtx(const DNString& dnstr) const;
/**
* Insert a SSLContext by domain name.
*/
void insertSSLCtxByDomainName(
const char* dn,
size_t len,
std::shared_ptr<SSLContext> sslCtx);
void insertSSLCtxByDomainNameImpl(
const char* dn,
size_t len,
std::shared_ptr<SSLContext> sslCtx);
void reloadTLSTicketKeys(const std::vector<std::string>& oldSeeds,
const std::vector<std::string>& currentSeeds,
const std::vector<std::string>& newSeeds);
/**
* SSLContextManager only collects SNI stats now
*/
void setClientHelloExtStats(ClientHelloExtStats* stats) {
clientHelloTLSExtStats_ = stats;
}
protected:
virtual void enableAsyncCrypto(
const std::shared_ptr<SSLContext>& sslCtx) {
LOG(FATAL) << "Unsupported in base SSLContextManager";
}
SSLStats* stats_{nullptr};
private:
SSLContextManager(const SSLContextManager&) = delete;
void ctxSetupByOpensslFeature(
std::shared_ptr<SSLContext> sslCtx,
const SSLContextConfig& ctxConfig);
/**
* Callback function from openssl to find the right X509 to
* use during SSL handshake
*/
#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && \
!defined(OPENSSL_NO_TLSEXT) && \
defined(SSL_CTRL_SET_TLSEXT_SERVERNAME_CB)
# define PROXYGEN_HAVE_SERVERNAMECALLBACK
SSLContext::ServerNameCallbackResult
serverNameCallback(SSL* ssl);
#endif
/**
* The following functions help to maintain the data structure for
* domain name matching in SNI. Some notes:
*
* 1. It is a best match.
*
* 2. It allows wildcard CN and wildcard subject alternative name in a X509.
* The wildcard name must be _prefixed_ by '*.'. It errors out whenever
* it sees '*' in any other locations.
*
* 3. It uses one std::unordered_map<DomainName, SSL_CTX> object to
* do this. For wildcard name like "*.facebook.com", ".facebook.com"
* is used as the key.
*
* 4. After getting tlsext_hostname from the client hello message, it
* will do a full string search first and then try one level up to
* match any wildcard name (if any) in the X509.
* [Note, browser also only looks one level up when matching the requesting
* domain name with the wildcard name in the server X509].
*/
void insert(
std::shared_ptr<SSLContext> sslCtx,
std::unique_ptr<SSLSessionCacheManager> cmanager,
std::unique_ptr<TLSTicketKeyManager> tManager,
bool defaultFallback);
/**
* Container to own the SSLContext, SSLSessionCacheManager and
* TLSTicketKeyManager.
*/
std::vector<std::shared_ptr<SSLContext>> ctxs_;
std::vector<std::unique_ptr<SSLSessionCacheManager>>
sessionCacheManagers_;
std::vector<std::unique_ptr<TLSTicketKeyManager>> ticketManagers_;
std::shared_ptr<SSLContext> defaultCtx_;
/**
* Container to store the (DomainName -> SSL_CTX) mapping
*/
std::unordered_map<
DNString,
std::shared_ptr<SSLContext>,
DNStringHash> dnMap_;
EventBase* eventBase_;
ClientHelloExtStats* clientHelloTLSExtStats_{nullptr};
SSLContextConfig::SNINoMatchFn noMatchFn_;
bool strict_{true};
};
} // namespace
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#include <folly/experimental/wangle/ssl/SSLSessionCacheManager.h>
#include <folly/experimental/wangle/ssl/SSLCacheProvider.h>
#include <folly/experimental/wangle/ssl/SSLStats.h>
#include <folly/experimental/wangle/ssl/SSLUtil.h>
#include <folly/io/async/EventBase.h>
using std::string;
using std::shared_ptr;
namespace {
const uint32_t NUM_CACHE_BUCKETS = 16;
// We use the default ID generator which fills the maximum ID length
// for the protocol. 16 bytes for SSLv2 or 32 for SSLv3+
const int MIN_SESSION_ID_LENGTH = 16;
}
#ifndef NO_LIB_GFLAGS
DEFINE_bool(dcache_unit_test, false, "All VIPs share one session cache");
#else
const bool FLAGS_dcache_unit_test = false;
#endif
namespace folly {
int SSLSessionCacheManager::sExDataIndex_ = -1;
shared_ptr<ShardedLocalSSLSessionCache> SSLSessionCacheManager::sCache_;
std::mutex SSLSessionCacheManager::sCacheLock_;
LocalSSLSessionCache::LocalSSLSessionCache(uint32_t maxCacheSize,
uint32_t cacheCullSize)
: sessionCache(maxCacheSize, cacheCullSize) {
sessionCache.setPruneHook(std::bind(
&LocalSSLSessionCache::pruneSessionCallback,
this, std::placeholders::_1,
std::placeholders::_2));
}
void LocalSSLSessionCache::pruneSessionCallback(const string& sessionId,
SSL_SESSION* session) {
VLOG(4) << "Free SSL session from local cache; id="
<< SSLUtil::hexlify(sessionId);
SSL_SESSION_free(session);
++removedSessions_;
}
// SSLSessionCacheManager implementation
SSLSessionCacheManager::SSLSessionCacheManager(
uint32_t maxCacheSize,
uint32_t cacheCullSize,
SSLContext* ctx,
const folly::SocketAddress& sockaddr,
const string& context,
EventBase* eventBase,
SSLStats* stats,
const std::shared_ptr<SSLCacheProvider>& externalCache):
ctx_(ctx),
stats_(stats),
externalCache_(externalCache) {
SSL_CTX* sslCtx = ctx->getSSLCtx();
SSLUtil::getSSLCtxExIndex(&sExDataIndex_);
SSL_CTX_set_ex_data(sslCtx, sExDataIndex_, this);
SSL_CTX_sess_set_new_cb(sslCtx, SSLSessionCacheManager::newSessionCallback);
SSL_CTX_sess_set_get_cb(sslCtx, SSLSessionCacheManager::getSessionCallback);
SSL_CTX_sess_set_remove_cb(sslCtx,
SSLSessionCacheManager::removeSessionCallback);
if (!FLAGS_dcache_unit_test && !context.empty()) {
// Use the passed in context
SSL_CTX_set_session_id_context(sslCtx, (const uint8_t *)context.data(),
std::min((int)context.length(),
SSL_MAX_SSL_SESSION_ID_LENGTH));
}
SSL_CTX_set_session_cache_mode(sslCtx, SSL_SESS_CACHE_NO_INTERNAL
| SSL_SESS_CACHE_SERVER);
localCache_ = SSLSessionCacheManager::getLocalCache(maxCacheSize,
cacheCullSize);
VLOG(2) << "On VipID=" << sockaddr.describe() << " context=" << context;
}
SSLSessionCacheManager::~SSLSessionCacheManager() {
}
void SSLSessionCacheManager::shutdown() {
std::lock_guard<std::mutex> g(sCacheLock_);
sCache_.reset();
}
shared_ptr<ShardedLocalSSLSessionCache> SSLSessionCacheManager::getLocalCache(
uint32_t maxCacheSize,
uint32_t cacheCullSize) {
std::lock_guard<std::mutex> g(sCacheLock_);
if (!sCache_) {
sCache_.reset(new ShardedLocalSSLSessionCache(NUM_CACHE_BUCKETS,
maxCacheSize,
cacheCullSize));
}
return sCache_;
}
int SSLSessionCacheManager::newSessionCallback(SSL* ssl, SSL_SESSION* session) {
SSLSessionCacheManager* manager = nullptr;
SSL_CTX* ctx = SSL_get_SSL_CTX(ssl);
manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_);
if (manager == nullptr) {
LOG(FATAL) << "Null SSLSessionCacheManager in callback";
return -1;
}
return manager->newSession(ssl, session);
}
int SSLSessionCacheManager::newSession(SSL* ssl, SSL_SESSION* session) {
string sessionId((char*)session->session_id, session->session_id_length);
VLOG(4) << "New SSL session; id=" << SSLUtil::hexlify(sessionId);
if (stats_) {
stats_->recordSSLSession(true /* new session */, false, false);
}
localCache_->storeSession(sessionId, session, stats_);
if (externalCache_) {
VLOG(4) << "New SSL session: send session to external cache; id=" <<
SSLUtil::hexlify(sessionId);
storeCacheRecord(sessionId, session);
}
return 1;
}
void SSLSessionCacheManager::removeSessionCallback(SSL_CTX* ctx,
SSL_SESSION* session) {
SSLSessionCacheManager* manager = nullptr;
manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_);
if (manager == nullptr) {
LOG(FATAL) << "Null SSLSessionCacheManager in callback";
return;
}
return manager->removeSession(ctx, session);
}
void SSLSessionCacheManager::removeSession(SSL_CTX* ctx,
SSL_SESSION* session) {
string sessionId((char*)session->session_id, session->session_id_length);
// This hook is only called from SSL when the internal session cache needs to
// flush sessions. Since we run with the internal cache disabled, this should
// never be called
VLOG(3) << "Remove SSL session; id=" << SSLUtil::hexlify(sessionId);
localCache_->removeSession(sessionId);
if (stats_) {
stats_->recordSSLSessionRemove();
}
}
SSL_SESSION* SSLSessionCacheManager::getSessionCallback(SSL* ssl,
unsigned char* sess_id,
int id_len,
int* copyflag) {
SSLSessionCacheManager* manager = nullptr;
SSL_CTX* ctx = SSL_get_SSL_CTX(ssl);
manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_);
if (manager == nullptr) {
LOG(FATAL) << "Null SSLSessionCacheManager in callback";
return nullptr;
}
return manager->getSession(ssl, sess_id, id_len, copyflag);
}
SSL_SESSION* SSLSessionCacheManager::getSession(SSL* ssl,
unsigned char* session_id,
int id_len,
int* copyflag) {
VLOG(7) << "SSL get session callback";
SSL_SESSION* session = nullptr;
bool foreign = false;
char const* missReason = nullptr;
if (id_len < MIN_SESSION_ID_LENGTH) {
// We didn't generate this session so it's going to be a miss.
// This doesn't get logged or counted in the stats.
return nullptr;
}
string sessionId((char*)session_id, id_len);
AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl);
assert(sslSocket != nullptr);
// look it up in the local cache first
session = localCache_->lookupSession(sessionId);
#ifdef SSL_SESSION_CB_WOULD_BLOCK
if (session == nullptr && externalCache_) {
// external cache might have the session
foreign = true;
if (!SSL_want_sess_cache_lookup(ssl)) {
missReason = "reason: No async cache support;";
} else {
PendingLookupMap::iterator pit = pendingLookups_.find(sessionId);
if (pit == pendingLookups_.end()) {
auto result = pendingLookups_.emplace(sessionId, PendingLookup());
// initiate fetch
VLOG(4) << "Get SSL session [Pending]: Initiate Fetch; fd=" <<
sslSocket->getFd() << " id=" << SSLUtil::hexlify(sessionId);
if (lookupCacheRecord(sessionId, sslSocket)) {
// response is pending
*copyflag = SSL_SESSION_CB_WOULD_BLOCK;
return nullptr;
} else {
missReason = "reason: failed to send lookup request;";
pendingLookups_.erase(result.first);
}
} else {
// A lookup was already initiated from this thread
if (pit->second.request_in_progress) {
// Someone else initiated the request, attach
VLOG(4) << "Get SSL session [Pending]: Request in progess: attach; "
"fd=" << sslSocket->getFd() << " id=" <<
SSLUtil::hexlify(sessionId);
std::unique_ptr<DelayedDestruction::DestructorGuard> dg(
new DelayedDestruction::DestructorGuard(sslSocket));
pit->second.waiters.push_back(
std::make_pair(sslSocket, std::move(dg)));
*copyflag = SSL_SESSION_CB_WOULD_BLOCK;
return nullptr;
}
// request is complete
session = pit->second.session; // nullptr if our friend didn't have it
if (session != nullptr) {
CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
}
}
}
}
#endif
bool hit = (session != nullptr);
if (stats_) {
stats_->recordSSLSession(false, hit, foreign);
}
if (hit) {
sslSocket->setSessionIDResumed(true);
}
VLOG(4) << "Get SSL session [" <<
((hit) ? "Hit" : "Miss") << "]: " <<
((foreign) ? "external" : "local") << " cache; " <<
((missReason != nullptr) ? missReason : "") << "fd=" <<
sslSocket->getFd() << " id=" << SSLUtil::hexlify(sessionId);
// We already bumped the refcount
*copyflag = 0;
return session;
}
bool SSLSessionCacheManager::storeCacheRecord(const string& sessionId,
SSL_SESSION* session) {
std::string sessionString;
uint32_t sessionLen = i2d_SSL_SESSION(session, nullptr);
sessionString.resize(sessionLen);
uint8_t* cp = (uint8_t *)sessionString.data();
i2d_SSL_SESSION(session, &cp);
size_t expiration = SSL_CTX_get_timeout(ctx_->getSSLCtx());
return externalCache_->setAsync(sessionId, sessionString,
std::chrono::seconds(expiration));
}
bool SSLSessionCacheManager::lookupCacheRecord(const string& sessionId,
AsyncSSLSocket* sslSocket) {
auto cacheCtx = new SSLCacheProvider::CacheContext();
cacheCtx->sessionId = sessionId;
cacheCtx->session = nullptr;
cacheCtx->sslSocket = sslSocket;
cacheCtx->guard.reset(
new DelayedDestruction::DestructorGuard(cacheCtx->sslSocket));
cacheCtx->manager = this;
bool res = externalCache_->getAsync(sessionId, cacheCtx);
if (!res) {
delete cacheCtx;
}
return res;
}
void SSLSessionCacheManager::restartSSLAccept(
const SSLCacheProvider::CacheContext* cacheCtx) {
PendingLookupMap::iterator pit = pendingLookups_.find(cacheCtx->sessionId);
CHECK(pit != pendingLookups_.end());
pit->second.request_in_progress = false;
pit->second.session = cacheCtx->session;
VLOG(7) << "Restart SSL accept";
cacheCtx->sslSocket->restartSSLAccept();
for (const auto& attachedLookup: pit->second.waiters) {
// Wake up anyone else who was waiting for this session
VLOG(4) << "Restart SSL accept (waiters) for fd=" <<
attachedLookup.first->getFd();
attachedLookup.first->restartSSLAccept();
}
pendingLookups_.erase(pit);
}
void SSLSessionCacheManager::onGetSuccess(
SSLCacheProvider::CacheContext* cacheCtx,
const std::string& value) {
const uint8_t* cp = (uint8_t*)value.data();
cacheCtx->session = d2i_SSL_SESSION(nullptr, &cp, value.length());
restartSSLAccept(cacheCtx);
/* Insert in the LRU after restarting all clients. The stats logic
* in getSession would treat this as a local hit otherwise.
*/
localCache_->storeSession(cacheCtx->sessionId, cacheCtx->session, stats_);
delete cacheCtx;
}
void SSLSessionCacheManager::onGetFailure(
SSLCacheProvider::CacheContext* cacheCtx) {
restartSSLAccept(cacheCtx);
delete cacheCtx;
}
} // namespace
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
#include <folly/experimental/wangle/ssl/SSLCacheProvider.h>
#include <folly/experimental/wangle/ssl/SSLStats.h>
#include <folly/EvictingCacheMap.h>
#include <mutex>
#include <folly/io/async/AsyncSSLSocket.h>
namespace folly {
class SSLStats;
/**
* Basic SSL session cache map: Maps session id -> session
*/
typedef folly::EvictingCacheMap<std::string, SSL_SESSION*> SSLSessionCacheMap;
/**
* Holds an SSLSessionCacheMap and associated lock
*/
class LocalSSLSessionCache: private boost::noncopyable {
public:
LocalSSLSessionCache(uint32_t maxCacheSize, uint32_t cacheCullSize);
~LocalSSLSessionCache() {
std::lock_guard<std::mutex> g(lock);
// EvictingCacheMap dtor doesn't free values
sessionCache.clear();
}
SSLSessionCacheMap sessionCache;
std::mutex lock;
uint32_t removedSessions_{0};
private:
void pruneSessionCallback(const std::string& sessionId,
SSL_SESSION* session);
};
/**
* A sharded LRU for SSL sessions. The sharding is inteneded to reduce
* contention for the LRU locks. Assuming uniform distribution, two workers
* will contend for the same lock with probability 1 / n_buckets^2.
*/
class ShardedLocalSSLSessionCache : private boost::noncopyable {
public:
ShardedLocalSSLSessionCache(uint32_t n_buckets, uint32_t maxCacheSize,
uint32_t cacheCullSize) {
CHECK(n_buckets > 0);
maxCacheSize = (uint32_t)(((double)maxCacheSize) / n_buckets);
cacheCullSize = (uint32_t)(((double)cacheCullSize) / n_buckets);
if (maxCacheSize == 0) {
maxCacheSize = 1;
}
if (cacheCullSize == 0) {
cacheCullSize = 1;
}
for (uint32_t i = 0; i < n_buckets; i++) {
caches_.push_back(
std::unique_ptr<LocalSSLSessionCache>(
new LocalSSLSessionCache(maxCacheSize, cacheCullSize)));
}
}
SSL_SESSION* lookupSession(const std::string& sessionId) {
size_t bucket = hash(sessionId);
SSL_SESSION* session = nullptr;
std::lock_guard<std::mutex> g(caches_[bucket]->lock);
auto itr = caches_[bucket]->sessionCache.find(sessionId);
if (itr != caches_[bucket]->sessionCache.end()) {
session = itr->second;
}
if (session) {
CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
}
return session;
}
void storeSession(const std::string& sessionId, SSL_SESSION* session,
SSLStats* stats) {
size_t bucket = hash(sessionId);
SSL_SESSION* oldSession = nullptr;
std::lock_guard<std::mutex> g(caches_[bucket]->lock);
auto itr = caches_[bucket]->sessionCache.find(sessionId);
if (itr != caches_[bucket]->sessionCache.end()) {
oldSession = itr->second;
}
if (oldSession) {
// LRUCacheMap doesn't free on overwrite, so 2x the work for us
// This can happen in race conditions
SSL_SESSION_free(oldSession);
}
caches_[bucket]->removedSessions_ = 0;
caches_[bucket]->sessionCache.set(sessionId, session, true);
if (stats) {
stats->recordSSLSessionFree(caches_[bucket]->removedSessions_);
}
}
void removeSession(const std::string& sessionId) {
size_t bucket = hash(sessionId);
std::lock_guard<std::mutex> g(caches_[bucket]->lock);
caches_[bucket]->sessionCache.erase(sessionId);
}
private:
/* SSL session IDs are 32 bytes of random data, hash based on first 16 bits */
size_t hash(const std::string& key) {
CHECK(key.length() >= 2);
return (key[0] << 8 | key[1]) % caches_.size();
}
std::vector< std::unique_ptr<LocalSSLSessionCache> > caches_;
};
/* A socket/DestructorGuard pair */
typedef std::pair<AsyncSSLSocket *,
std::unique_ptr<DelayedDestruction::DestructorGuard>>
AttachedLookup;
/**
* PendingLookup structure
*
* Keeps track of clients waiting for an SSL session to be retrieved from
* the external cache provider.
*/
struct PendingLookup {
bool request_in_progress;
SSL_SESSION* session;
std::list<AttachedLookup> waiters;
PendingLookup() {
request_in_progress = true;
session = nullptr;
}
};
/* Maps SSL session id to a PendingLookup structure */
typedef std::map<std::string, PendingLookup> PendingLookupMap;
/**
* SSLSessionCacheManager handles all stateful session caching. There is an
* instance of this object per SSL VIP per thread, with a 1:1 correlation with
* SSL_CTX. The cache can work locally or in concert with an external cache
* to share sessions across instances.
*
* There is a single in memory session cache shared by all VIPs. The cache is
* split into N buckets (currently 16) with a separate lock per bucket. The
* VIP ID is hashed and stored as part of the session to handle the
* (very unlikely) case of session ID collision.
*
* When a new SSL session is created, it is added to the LRU cache and
* sent to the external cache to be stored. The external cache
* expiration is equal to the SSL session's expiration.
*
* When a resume request is received, SSLSessionCacheManager first looks in the
* local LRU cache for the VIP. If there is a miss there, an asynchronous
* request for this session is dispatched to the external cache. When the
* external cache query returns, the LRU cache is updated if the session was
* found, and the SSL_accept call is resumed.
*
* If additional resume requests for the same session ID arrive in the same
* thread while the request is pending, the 2nd - Nth callers attach to the
* original external cache requests and are resumed when it comes back. No
* attempt is made to coalesce external cache requests for the same session
* ID in different worker threads. Previous work did this, but the
* complexity was deemed to outweigh the potential savings.
*
*/
class SSLSessionCacheManager : private boost::noncopyable {
public:
/**
* Constructor. SSL session related callbacks will be set on the underlying
* SSL_CTX. vipId is assumed to a unique string identifying the VIP and must
* be the same on all servers that wish to share sessions via the same
* external cache.
*/
SSLSessionCacheManager(
uint32_t maxCacheSize,
uint32_t cacheCullSize,
SSLContext* ctx,
const folly::SocketAddress& sockaddr,
const std::string& context,
EventBase* eventBase,
SSLStats* stats,
const std::shared_ptr<SSLCacheProvider>& externalCache);
virtual ~SSLSessionCacheManager();
/**
* Call this on shutdown to release the global instance of the
* ShardedLocalSSLSessionCache.
*/
static void shutdown();
/**
* Callback for ExternalCache to call when an async get succeeds
* @param context The context that was passed to the async get request
* @param value Serialized session
*/
void onGetSuccess(SSLCacheProvider::CacheContext* context,
const std::string& value);
/**
* Callback for ExternalCache to call when an async get fails, either
* because the requested session is not in the external cache or because
* of an error.
* @param context The context that was passed to the async get request
*/
void onGetFailure(SSLCacheProvider::CacheContext* context);
private:
SSLContext* ctx_;
std::shared_ptr<ShardedLocalSSLSessionCache> localCache_;
PendingLookupMap pendingLookups_;
SSLStats* stats_{nullptr};
std::shared_ptr<SSLCacheProvider> externalCache_;
/**
* Invoked by openssl when a new SSL session is created
*/
int newSession(SSL* ssl, SSL_SESSION* session);
/**
* Invoked by openssl when an SSL session is ejected from its internal cache.
* This can't be invoked in the current implementation because SSL's internal
* caching is disabled.
*/
void removeSession(SSL_CTX* ctx, SSL_SESSION* session);
/**
* Invoked by openssl when a client requests a stateful session resumption.
* Triggers a lookup in our local cache and potentially an asynchronous
* request to an external cache.
*/
SSL_SESSION* getSession(SSL* ssl, unsigned char* session_id,
int id_len, int* copyflag);
/**
* Store a new session record in the external cache
*/
bool storeCacheRecord(const std::string& sessionId, SSL_SESSION* session);
/**
* Lookup a session in the external cache for the specified SSL socket.
*/
bool lookupCacheRecord(const std::string& sessionId,
AsyncSSLSocket* sslSock);
/**
* Restart all clients waiting for the answer to an external cache query
*/
void restartSSLAccept(const SSLCacheProvider::CacheContext* cacheCtx);
/**
* Get or create the LRU cache for the given VIP ID
*/
static std::shared_ptr<ShardedLocalSSLSessionCache> getLocalCache(
uint32_t maxCacheSize, uint32_t cacheCullSize);
/**
* static functions registered as callbacks to openssl via
* SSL_CTX_sess_set_new/get/remove_cb
*/
static int newSessionCallback(SSL* ssl, SSL_SESSION* session);
static void removeSessionCallback(SSL_CTX* ctx, SSL_SESSION* session);
static SSL_SESSION* getSessionCallback(SSL* ssl, unsigned char* session_id,
int id_len, int* copyflag);
static int32_t sExDataIndex_;
static std::shared_ptr<ShardedLocalSSLSessionCache> sCache_;
static std::mutex sCacheLock_;
};
}
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
namespace folly {
class SSLStats {
public:
virtual ~SSLStats() noexcept {}
// downstream
virtual void recordSSLAcceptLatency(int64_t latency) noexcept = 0;
virtual void recordTLSTicket(bool ticketNew, bool ticketHit) noexcept = 0;
virtual void recordSSLSession(bool sessionNew, bool sessionHit, bool foreign)
noexcept = 0;
virtual void recordSSLSessionRemove() noexcept = 0;
virtual void recordSSLSessionFree(uint32_t freed) noexcept = 0;
virtual void recordSSLSessionSetError(uint32_t err) noexcept = 0;
virtual void recordSSLSessionGetError(uint32_t err) noexcept = 0;
virtual void recordClientRenegotiation() noexcept = 0;
// upstream
virtual void recordSSLUpstreamConnection(bool handshake) noexcept = 0;
virtual void recordSSLUpstreamConnectionError(bool verifyError) noexcept = 0;
virtual void recordCryptoSSLExternalAttempt() noexcept = 0;
virtual void recordCryptoSSLExternalConnAlreadyClosed() noexcept = 0;
virtual void recordCryptoSSLExternalApplicationException() noexcept = 0;
virtual void recordCryptoSSLExternalSuccess() noexcept = 0;
virtual void recordCryptoSSLExternalDuration(uint64_t duration) noexcept = 0;
virtual void recordCryptoSSLLocalAttempt() noexcept = 0;
virtual void recordCryptoSSLLocalSuccess() noexcept = 0;
};
}
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#include <folly/experimental/wangle/ssl/SSLUtil.h>
#include <folly/Memory.h>
#if OPENSSL_VERSION_NUMBER >= 0x1000105fL
#define OPENSSL_GE_101 1
#include <openssl/asn1.h>
#include <openssl/x509v3.h>
#else
#undef OPENSSL_GE_101
#endif
namespace folly {
std::mutex SSLUtil::sIndexLock_;
std::unique_ptr<std::string> SSLUtil::getCommonName(const X509* cert) {
X509_NAME* subject = X509_get_subject_name((X509*)cert);
if (!subject) {
return nullptr;
}
char cn[ub_common_name + 1];
int res = X509_NAME_get_text_by_NID(subject, NID_commonName,
cn, ub_common_name);
if (res <= 0) {
return nullptr;
} else {
cn[ub_common_name] = '\0';
return folly::make_unique<std::string>(cn);
}
}
std::unique_ptr<std::list<std::string>> SSLUtil::getSubjectAltName(
const X509* cert) {
#ifdef OPENSSL_GE_101
auto nameList = folly::make_unique<std::list<std::string>>();
GENERAL_NAMES* names = (GENERAL_NAMES*)X509_get_ext_d2i(
(X509*)cert, NID_subject_alt_name, nullptr, nullptr);
if (names) {
auto guard = folly::makeGuard([names] { GENERAL_NAMES_free(names); });
size_t count = sk_GENERAL_NAME_num(names);
CHECK(count < std::numeric_limits<int>::max());
for (int i = 0; i < (int)count; ++i) {
GENERAL_NAME* generalName = sk_GENERAL_NAME_value(names, i);
if (generalName->type == GEN_DNS) {
ASN1_STRING* s = generalName->d.dNSName;
const char* name = (const char*)ASN1_STRING_data(s);
// I can't find any docs on what a negative return value here
// would mean, so I'm going to ignore it.
auto len = ASN1_STRING_length(s);
DCHECK(len >= 0);
if (size_t(len) != strlen(name)) {
// Null byte(s) in the name; return an error rather than depending on
// the caller to safely handle this case.
return nullptr;
}
nameList->emplace_back(name);
}
}
}
return nameList;
#else
return nullptr;
#endif
}
}
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
#include <folly/String.h>
#include <mutex>
#include <folly/io/async/AsyncSSLSocket.h>
namespace folly {
/**
* SSL session establish/resume status
*
* changing these values will break logging pipelines
*/
enum class SSLResumeEnum : uint8_t {
HANDSHAKE = 0,
RESUME_SESSION_ID = 1,
RESUME_TICKET = 3,
NA = 2
};
enum class SSLErrorEnum {
NO_ERROR,
TIMEOUT,
DROPPED
};
class SSLUtil {
private:
static std::mutex sIndexLock_;
public:
/**
* Ensures only one caller will allocate an ex_data index for a given static
* or global.
*/
static void getSSLCtxExIndex(int* pindex) {
std::lock_guard<std::mutex> g(sIndexLock_);
if (*pindex < 0) {
*pindex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
}
}
static void getRSAExIndex(int* pindex) {
std::lock_guard<std::mutex> g(sIndexLock_);
if (*pindex < 0) {
*pindex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
}
}
static inline std::string hexlify(const std::string& binary) {
std::string hex;
folly::hexlify<std::string, std::string>(binary, hex);
return hex;
}
static inline const std::string& hexlify(const std::string& binary,
std::string& hex) {
folly::hexlify<std::string, std::string>(binary, hex);
return hex;
}
/**
* Return the SSL resume type for the given socket.
*/
static inline SSLResumeEnum getResumeState(
AsyncSSLSocket* sslSocket) {
return sslSocket->getSSLSessionReused() ?
(sslSocket->sessionIDResumed() ?
SSLResumeEnum::RESUME_SESSION_ID :
SSLResumeEnum::RESUME_TICKET) :
SSLResumeEnum::HANDSHAKE;
}
/**
* Get the Common Name from an X.509 certificate
* @param cert certificate to inspect
* @return common name, or null if an error occurs
*/
static std::unique_ptr<std::string> getCommonName(const X509* cert);
/**
* Get the Subject Alternative Name value(s) from an X.509 certificate
* @param cert certificate to inspect
* @return set of zero or more alternative names, or null if
* an error occurs
*/
static std::unique_ptr<std::list<std::string>> getSubjectAltName(
const X509* cert);
};
}
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#include <folly/experimental/wangle/ssl/TLSTicketKeyManager.h>
#include <folly/experimental/wangle/ssl/SSLStats.h>
#include <folly/experimental/wangle/ssl/SSLUtil.h>
#include <folly/String.h>
#include <openssl/aes.h>
#include <openssl/rand.h>
#include <openssl/ssl.h>
#include <folly/io/async/AsyncTimeout.h>
#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB
using std::string;
namespace {
const int kTLSTicketKeyNameLen = 4;
const int kTLSTicketKeySaltLen = 12;
}
namespace folly {
// TLSTicketKeyManager Implementation
int32_t TLSTicketKeyManager::sExDataIndex_ = -1;
TLSTicketKeyManager::TLSTicketKeyManager(SSLContext* ctx, SSLStats* stats)
: ctx_(ctx),
randState_(0),
stats_(stats) {
SSLUtil::getSSLCtxExIndex(&sExDataIndex_);
SSL_CTX_set_ex_data(ctx_->getSSLCtx(), sExDataIndex_, this);
}
TLSTicketKeyManager::~TLSTicketKeyManager() {
}
int
TLSTicketKeyManager::callback(SSL* ssl, unsigned char* keyName,
unsigned char* iv,
EVP_CIPHER_CTX* cipherCtx,
HMAC_CTX* hmacCtx, int encrypt) {
TLSTicketKeyManager* manager = nullptr;
SSL_CTX* ctx = SSL_get_SSL_CTX(ssl);
manager = (TLSTicketKeyManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_);
if (manager == nullptr) {
LOG(FATAL) << "Null TLSTicketKeyManager in callback" ;
return -1;
}
return manager->processTicket(ssl, keyName, iv, cipherCtx, hmacCtx, encrypt);
}
int
TLSTicketKeyManager::processTicket(SSL* ssl, unsigned char* keyName,
unsigned char* iv,
EVP_CIPHER_CTX* cipherCtx,
HMAC_CTX* hmacCtx, int encrypt) {
uint8_t salt[kTLSTicketKeySaltLen];
uint8_t* saltptr = nullptr;
uint8_t output[SHA256_DIGEST_LENGTH];
uint8_t* hmacKey = nullptr;
uint8_t* aesKey = nullptr;
TLSTicketKeySource* key = nullptr;
int result = 0;
if (encrypt) {
key = findEncryptionKey();
if (key == nullptr) {
// no keys available to encrypt
VLOG(2) << "No TLS ticket key found";
return -1;
}
VLOG(4) << "Encrypting new ticket with key name=" <<
SSLUtil::hexlify(key->keyName_);
// Get a random salt and write out key name
RAND_pseudo_bytes(salt, (int)sizeof(salt));
memcpy(keyName, key->keyName_.data(), kTLSTicketKeyNameLen);
memcpy(keyName + kTLSTicketKeyNameLen, salt, kTLSTicketKeySaltLen);
// Create the unique keys by hashing with the salt
makeUniqueKeys(key->keySource_, sizeof(key->keySource_), salt, output);
// This relies on the fact that SHA256 has 32 bytes of output
// and that AES-128 keys are 16 bytes
hmacKey = output;
aesKey = output + SHA256_DIGEST_LENGTH / 2;
// Initialize iv and cipher/mac CTX
RAND_pseudo_bytes(iv, AES_BLOCK_SIZE);
HMAC_Init_ex(hmacCtx, hmacKey, SHA256_DIGEST_LENGTH / 2,
EVP_sha256(), nullptr);
EVP_EncryptInit_ex(cipherCtx, EVP_aes_128_cbc(), nullptr, aesKey, iv);
result = 1;
} else {
key = findDecryptionKey(keyName);
if (key == nullptr) {
// no ticket found for decryption - will issue a new ticket
if (VLOG_IS_ON(4)) {
string skeyName((char *)keyName, kTLSTicketKeyNameLen);
VLOG(4) << "Can't find ticket key with name=" <<
SSLUtil::hexlify(skeyName)<< ", will generate new ticket";
}
result = 0;
} else {
VLOG(4) << "Decrypting ticket with key name=" <<
SSLUtil::hexlify(key->keyName_);
// Reconstruct the unique key via the salt
saltptr = keyName + kTLSTicketKeyNameLen;
makeUniqueKeys(key->keySource_, sizeof(key->keySource_), saltptr, output);
hmacKey = output;
aesKey = output + SHA256_DIGEST_LENGTH / 2;
// Initialize cipher/mac CTX
HMAC_Init_ex(hmacCtx, hmacKey, SHA256_DIGEST_LENGTH / 2,
EVP_sha256(), nullptr);
EVP_DecryptInit_ex(cipherCtx, EVP_aes_128_cbc(), nullptr, aesKey, iv);
result = 1;
}
}
// result records whether a ticket key was found to decrypt this ticket,
// not wether the session was re-used.
if (stats_) {
stats_->recordTLSTicket(encrypt, result);
}
return result;
}
bool
TLSTicketKeyManager::setTLSTicketKeySeeds(
const std::vector<std::string>& oldSeeds,
const std::vector<std::string>& currentSeeds,
const std::vector<std::string>& newSeeds) {
bool result = true;
activeKeys_.clear();
ticketKeys_.clear();
ticketSeeds_.clear();
const std::vector<string> *seedList = &oldSeeds;
for (uint32_t i = 0; i < 3; i++) {
TLSTicketSeedType type = (TLSTicketSeedType)i;
if (type == SEED_CURRENT) {
seedList = &currentSeeds;
} else if (type == SEED_NEW) {
seedList = &newSeeds;
}
for (const auto& seedInput: *seedList) {
TLSTicketSeed* seed = insertSeed(seedInput, type);
if (seed == nullptr) {
result = false;
continue;
}
insertNewKey(seed, 1, nullptr);
}
}
if (!result) {
VLOG(2) << "One or more seeds failed to decode";
}
if (ticketKeys_.size() == 0 || activeKeys_.size() == 0) {
LOG(WARNING) << "No keys configured, falling back to default";
SSL_CTX_set_tlsext_ticket_key_cb(ctx_->getSSLCtx(), nullptr);
return false;
}
SSL_CTX_set_tlsext_ticket_key_cb(ctx_->getSSLCtx(),
TLSTicketKeyManager::callback);
return true;
}
string
TLSTicketKeyManager::makeKeyName(TLSTicketSeed* seed, uint32_t n,
unsigned char* nameBuf) {
SHA256_CTX ctx;
SHA256_Init(&ctx);
SHA256_Update(&ctx, seed->seedName_, sizeof(seed->seedName_));
SHA256_Update(&ctx, &n, sizeof(n));
SHA256_Final(nameBuf, &ctx);
return string((char *)nameBuf, kTLSTicketKeyNameLen);
}
TLSTicketKeyManager::TLSTicketKeySource*
TLSTicketKeyManager::insertNewKey(TLSTicketSeed* seed, uint32_t hashCount,
TLSTicketKeySource* prevKey) {
unsigned char nameBuf[SHA256_DIGEST_LENGTH];
std::unique_ptr<TLSTicketKeySource> newKey(new TLSTicketKeySource);
// This function supports hash chaining but it is not currently used.
if (prevKey != nullptr) {
hashNth(prevKey->keySource_, sizeof(prevKey->keySource_),
newKey->keySource_, 1);
} else {
// can't go backwards or the current is missing, start from the beginning
hashNth((unsigned char *)seed->seed_.data(), seed->seed_.length(),
newKey->keySource_, hashCount);
}
newKey->hashCount_ = hashCount;
newKey->keyName_ = makeKeyName(seed, hashCount, nameBuf);
newKey->type_ = seed->type_;
auto it = ticketKeys_.insert(std::make_pair(newKey->keyName_,
std::move(newKey)));
auto key = it.first->second.get();
if (key->type_ == SEED_CURRENT) {
activeKeys_.push_back(key);
}
VLOG(4) << "Adding key for " << hashCount << " type=" <<
(uint32_t)key->type_ << " Name=" << SSLUtil::hexlify(key->keyName_);
return key;
}
void
TLSTicketKeyManager::hashNth(const unsigned char* input, size_t input_len,
unsigned char* output, uint32_t n) {
assert(n > 0);
for (uint32_t i = 0; i < n; i++) {
SHA256(input, input_len, output);
input = output;
input_len = SHA256_DIGEST_LENGTH;
}
}
TLSTicketKeyManager::TLSTicketSeed *
TLSTicketKeyManager::insertSeed(const string& seedInput,
TLSTicketSeedType type) {
TLSTicketSeed* seed = nullptr;
string seedOutput;
if (!folly::unhexlify<string, string>(seedInput, seedOutput)) {
LOG(WARNING) << "Failed to decode seed type=" << (uint32_t)type <<
" seed=" << seedInput;
return seed;
}
seed = new TLSTicketSeed();
seed->seed_ = seedOutput;
seed->type_ = type;
SHA256((unsigned char *)seedOutput.data(), seedOutput.length(),
seed->seedName_);
ticketSeeds_.push_back(std::unique_ptr<TLSTicketSeed>(seed));
return seed;
}
TLSTicketKeyManager::TLSTicketKeySource *
TLSTicketKeyManager::findEncryptionKey() {
TLSTicketKeySource* result = nullptr;
// call to rand here is a bit hokey since it's not cryptographically
// random, and is predictably seeded with 0. However, activeKeys_
// is probably not going to have very many keys in it, and most
// likely only 1.
size_t numKeys = activeKeys_.size();
if (numKeys > 0) {
result = activeKeys_[rand_r(&randState_) % numKeys];
}
return result;
}
TLSTicketKeyManager::TLSTicketKeySource *
TLSTicketKeyManager::findDecryptionKey(unsigned char* keyName) {
string name((char *)keyName, kTLSTicketKeyNameLen);
TLSTicketKeySource* key = nullptr;
TLSTicketKeyMap::iterator mapit = ticketKeys_.find(name);
if (mapit != ticketKeys_.end()) {
key = mapit->second.get();
}
return key;
}
void
TLSTicketKeyManager::makeUniqueKeys(unsigned char* parentKey,
size_t keyLen,
unsigned char* salt,
unsigned char* output) {
SHA256_CTX hash_ctx;
SHA256_Init(&hash_ctx);
SHA256_Update(&hash_ctx, parentKey, keyLen);
SHA256_Update(&hash_ctx, salt, kTLSTicketKeySaltLen);
SHA256_Final(output, &hash_ctx);
}
} // namespace
#endif
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
#include <folly/io/async/SSLContext.h>
#include <folly/io/async/EventBase.h>
namespace folly {
#ifndef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB
class TLSTicketKeyManager {};
#else
class SSLStats;
/**
* The TLSTicketKeyManager handles TLS ticket key encryption and decryption in
* a way that facilitates sharing the ticket keys across a range of servers.
* Hash chaining is employed to achieve frequent key rotation with minimal
* configuration change. The scheme is as follows:
*
* The manager is supplied with three lists of seeds (old, current and new).
* The config should be updated with new seeds periodically (e.g., daily).
* 3 config changes are recommended to achieve the smoothest seed rotation
* eg:
* 1. Introduce new seed in the push prior to rotation
* 2. Rotation push
* 3. Remove old seeds in the push following rotation
*
* Multiple seeds are supported but only a single seed is required.
*
* Generating encryption keys from the seed works as follows. For a given
* seed, hash forward N times where N is currently the constant 1.
* This is the base key. The name of the base key is the first 4
* bytes of hash(hash(seed), N). This is copied into the first 4 bytes of the
* TLS ticket key name field.
*
* For each new ticket encryption, the manager generates a random 12 byte salt.
* Hash the salt and the base key together to form the encryption key for
* that ticket. The salt is included in the ticket's 'key name' field so it
* can be used to derive the decryption key. The salt is copied into the second
* 8 bytes of the TLS ticket key name field.
*
* A key is valid for decryption for the lifetime of the instance.
* Sessions will be valid for less time than that, which results in an extra
* symmetric decryption to discover the session is expired.
*
* A TLSTicketKeyManager should be used in only one thread, and should have
* a 1:1 relationship with the SSLContext provided.
*
*/
class TLSTicketKeyManager : private boost::noncopyable {
public:
explicit TLSTicketKeyManager(folly::SSLContext* ctx,
SSLStats* stats);
virtual ~TLSTicketKeyManager();
/**
* SSL callback to set up encryption/decryption context for a TLS Ticket Key.
*
* This will be supplied to the SSL library via
* SSL_CTX_set_tlsext_ticket_key_cb.
*/
static int callback(SSL* ssl, unsigned char* keyName,
unsigned char* iv,
EVP_CIPHER_CTX* cipherCtx,
HMAC_CTX* hmacCtx, int encrypt);
/**
* Initialize the manager with three sets of seeds. There must be at least
* one current seed, or the manager will revert to the default SSL behavior.
*
* @param oldSeeds Seeds previously used which can still decrypt.
* @param currentSeeds Seeds to use for new ticket encryptions.
* @param newSeeds Seeds which will be used soon, can be used to decrypt
* in case some servers in the cluster have already rotated.
*/
bool setTLSTicketKeySeeds(const std::vector<std::string>& oldSeeds,
const std::vector<std::string>& currentSeeds,
const std::vector<std::string>& newSeeds);
private:
enum TLSTicketSeedType {
SEED_OLD = 0,
SEED_CURRENT,
SEED_NEW
};
/* The seeds supplied by the configuration */
struct TLSTicketSeed {
std::string seed_;
TLSTicketSeedType type_;
unsigned char seedName_[SHA256_DIGEST_LENGTH];
};
struct TLSTicketKeySource {
int32_t hashCount_;
std::string keyName_;
TLSTicketSeedType type_;
unsigned char keySource_[SHA256_DIGEST_LENGTH];
};
/**
* Method to setup encryption/decryption context for a TLS Ticket Key
*
* OpenSSL documentation is thin on the return value semantics.
*
* For encrypt=1, return < 0 on error, >= 0 for successfully initialized
* For encrypt=0, return < 0 on error, 0 on key not found
* 1 on key found, 2 renew_ticket
*
* renew_ticket means a new ticket will be issued. We could return this value
* when receiving a ticket encrypted with a key derived from an OLD seed.
* However, session_timeout seconds after deploying with a seed
* rotated from CURRENT -> OLD, there will be no valid tickets outstanding
* encrypted with the old key. This grace period means no unnecessary
* handshakes will be performed. If the seed is believed compromised, it
* should NOT be configured as an OLD seed.
*/
int processTicket(SSL* ssl, unsigned char* keyName,
unsigned char* iv,
EVP_CIPHER_CTX* cipherCtx,
HMAC_CTX* hmacCtx, int encrypt);
// Creates the name for the nth key generated from seed
std::string makeKeyName(TLSTicketSeed* seed, uint32_t n,
unsigned char* nameBuf);
/**
* Creates the key hashCount hashes from the given seed and inserts it in
* ticketKeys. A naked pointer to the key is returned for additional
* processing if needed.
*/
TLSTicketKeySource* insertNewKey(TLSTicketSeed* seed, uint32_t hashCount,
TLSTicketKeySource* prevKeySource);
/**
* hashes input N times placing result in output, which must be at least
* SHA256_DIGEST_LENGTH long.
*/
void hashNth(const unsigned char* input, size_t input_len,
unsigned char* output, uint32_t n);
/**
* Adds the given seed to the manager
*/
TLSTicketSeed* insertSeed(const std::string& seedInput,
TLSTicketSeedType type);
/**
* Locate a key for encrypting a new ticket
*/
TLSTicketKeySource* findEncryptionKey();
/**
* Locate a key for decrypting a ticket with the given keyName
*/
TLSTicketKeySource* findDecryptionKey(unsigned char* keyName);
/**
* Derive a unique key from the parent key and the salt via hashing
*/
void makeUniqueKeys(unsigned char* parentKey, size_t keyLen,
unsigned char* salt, unsigned char* output);
/**
* For standalone decryption utility
*/
friend int decrypt_fb_ticket(folly::TLSTicketKeyManager* manager,
const std::string& testTicket,
SSL_SESSION **psess);
typedef std::vector<std::unique_ptr<TLSTicketSeed>> TLSTicketSeedList;
typedef std::map<std::string, std::unique_ptr<TLSTicketKeySource> >
TLSTicketKeyMap;
typedef std::vector<TLSTicketKeySource *> TLSActiveKeyList;
TLSTicketSeedList ticketSeeds_;
// All key sources that can be used for decryption
TLSTicketKeyMap ticketKeys_;
// Key sources that can be used for encryption
TLSActiveKeyList activeKeys_;
folly::SSLContext* ctx_;
uint32_t randState_;
SSLStats* stats_{nullptr};
static int32_t sExDataIndex_;
};
#endif
}
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#pragma once
namespace folly {
struct TLSTicketKeySeeds {
std::vector<std::string> oldSeeds;
std::vector<std::string> currentSeeds;
std::vector<std::string> newSeeds;
};
}
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#include <folly/Portability.h>
#include <folly/io/async/EventBase.h>
#include <gflags/gflags.h>
#include <iostream>
#include <thread>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/AsyncSocket.h>
#include <vector>
using namespace std;
using namespace folly;
DEFINE_int32(clients, 1, "Number of simulated SSL clients");
DEFINE_int32(threads, 1, "Number of threads to spread clients across");
DEFINE_int32(requests, 2, "Total number of requests per client");
DEFINE_int32(port, 9423, "Server port");
DEFINE_bool(sticky, false, "A given client sends all reqs to one "
"(random) server");
DEFINE_bool(global, false, "All clients in a thread use the same SSL session");
DEFINE_bool(handshakes, false, "Force 100% handshakes");
string f_servers[10];
int f_num_servers = 0;
int tnum = 0;
class ClientRunner {
public:
ClientRunner(): reqs(0), hits(0), miss(0), num(tnum++) {}
void run();
int reqs;
int hits;
int miss;
int num;
};
class SSLCacheClient : public AsyncSocket::ConnectCallback,
public AsyncSSLSocket::HandshakeCB
{
private:
EventBase* eventBase_;
int currReq_;
int serverIdx_;
AsyncSocket* socket_;
AsyncSSLSocket* sslSocket_;
SSL_SESSION* session_;
SSL_SESSION **pSess_;
std::shared_ptr<SSLContext> ctx_;
ClientRunner* cr_;
public:
SSLCacheClient(EventBase* eventBase, SSL_SESSION **pSess, ClientRunner* cr);
~SSLCacheClient() {
if (session_ && !FLAGS_global)
SSL_SESSION_free(session_);
if (socket_ != nullptr) {
if (sslSocket_ != nullptr) {
sslSocket_->destroy();
sslSocket_ = nullptr;
}
socket_->destroy();
socket_ = nullptr;
}
};
void start();
virtual void connectSuccess() noexcept;
virtual void connectErr(const AsyncSocketException& ex)
noexcept ;
virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept;
virtual void handshakeErr(
AsyncSSLSocket* sock,
const AsyncSocketException& ex) noexcept;
};
int
main(int argc, char* argv[])
{
gflags::SetUsageMessage(std::string("\n\n"
"usage: sslcachetest [options] -c <clients> -t <threads> servers\n"
));
gflags::ParseCommandLineFlags(&argc, &argv, true);
int reqs = 0;
int hits = 0;
int miss = 0;
struct timeval start;
struct timeval end;
struct timeval result;
srand((unsigned int)time(nullptr));
for (int i = 1; i < argc; i++) {
f_servers[f_num_servers++] = argv[i];
}
if (f_num_servers == 0) {
cout << "require at least one server\n";
return 1;
}
gettimeofday(&start, nullptr);
if (FLAGS_threads == 1) {
ClientRunner r;
r.run();
gettimeofday(&end, nullptr);
reqs = r.reqs;
hits = r.hits;
miss = r.miss;
}
else {
std::vector<ClientRunner> clients;
std::vector<std::thread> threads;
for (int t = 0; t < FLAGS_threads; t++) {
threads.emplace_back([&] {
clients[t].run();
});
}
for (auto& thr: threads) {
thr.join();
}
gettimeofday(&end, nullptr);
for (const auto& client: clients) {
reqs += client.reqs;
hits += client.hits;
miss += client.miss;
}
}
timersub(&end, &start, &result);
cout << "Requests: " << reqs << endl;
cout << "Handshakes: " << miss << endl;
cout << "Resumes: " << hits << endl;
cout << "Runtime(ms): " << result.tv_sec << "." << result.tv_usec / 1000 <<
endl;
cout << "ops/sec: " << (reqs * 1.0) /
((double)result.tv_sec * 1.0 + (double)result.tv_usec / 1000000.0) << endl;
return 0;
}
void
ClientRunner::run()
{
EventBase eb;
std::list<SSLCacheClient *> clients;
SSL_SESSION* session = nullptr;
for (int i = 0; i < FLAGS_clients; i++) {
SSLCacheClient* c = new SSLCacheClient(&eb, &session, this);
c->start();
clients.push_back(c);
}
eb.loop();
for (auto it = clients.begin(); it != clients.end(); it++) {
delete* it;
}
reqs += hits + miss;
}
SSLCacheClient::SSLCacheClient(EventBase* eb,
SSL_SESSION **pSess,
ClientRunner* cr)
: eventBase_(eb),
currReq_(0),
serverIdx_(0),
socket_(nullptr),
sslSocket_(nullptr),
session_(nullptr),
pSess_(pSess),
cr_(cr)
{
ctx_.reset(new SSLContext());
ctx_->setOptions(SSL_OP_NO_TICKET);
}
void
SSLCacheClient::start()
{
if (currReq_ >= FLAGS_requests) {
cout << "+";
return;
}
if (currReq_ == 0 || !FLAGS_sticky) {
serverIdx_ = rand() % f_num_servers;
}
if (socket_ != nullptr) {
if (sslSocket_ != nullptr) {
sslSocket_->destroy();
sslSocket_ = nullptr;
}
socket_->destroy();
socket_ = nullptr;
}
socket_ = new AsyncSocket(eventBase_);
socket_->connect(this, f_servers[serverIdx_], (uint16_t)FLAGS_port);
}
void
SSLCacheClient::connectSuccess() noexcept
{
sslSocket_ = new AsyncSSLSocket(ctx_, eventBase_, socket_->detachFd(),
false);
if (!FLAGS_handshakes) {
if (session_ != nullptr)
sslSocket_->setSSLSession(session_);
else if (FLAGS_global && pSess_ != nullptr)
sslSocket_->setSSLSession(*pSess_);
}
sslSocket_->sslConn(this);
}
void
SSLCacheClient::connectErr(const AsyncSocketException& ex)
noexcept
{
cout << "connectError: " << ex.what() << endl;
}
void
SSLCacheClient::handshakeSuc(AsyncSSLSocket* socket) noexcept
{
if (sslSocket_->getSSLSessionReused()) {
cr_->hits++;
} else {
cr_->miss++;
if (session_ != nullptr) {
SSL_SESSION_free(session_);
}
session_ = sslSocket_->getSSLSession();
if (FLAGS_global && pSess_ != nullptr && *pSess_ == nullptr) {
*pSess_ = session_;
}
}
if ( ((cr_->hits + cr_->miss) % 100) == ((100 / FLAGS_threads) * cr_->num)) {
cout << ".";
cout.flush();
}
sslSocket_->closeNow();
currReq_++;
this->start();
}
void
SSLCacheClient::handshakeErr(
AsyncSSLSocket* sock,
const AsyncSocketException& ex)
noexcept
{
cout << "handshakeError: " << ex.what() << endl;
}
/*
* Copyright (c) 2014, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#include <folly/io/async/EventBase.h>
#include <folly/io/async/SSLContext.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <folly/experimental/wangle/ssl/SSLContextManager.h>
#include <folly/experimental/wangle/acceptor/DomainNameMisc.h>
using std::shared_ptr;
namespace folly {
TEST(SSLContextManagerTest, Test1)
{
EventBase eventBase;
SSLContextManager sslCtxMgr(&eventBase, "vip_ssl_context_manager_test_",
true, nullptr);
auto www_facebook_com_ctx = std::make_shared<SSLContext>();
auto start_facebook_com_ctx = std::make_shared<SSLContext>();
auto start_abc_facebook_com_ctx = std::make_shared<SSLContext>();
sslCtxMgr.insertSSLCtxByDomainName(
"www.facebook.com",
strlen("www.facebook.com"),
www_facebook_com_ctx);
sslCtxMgr.insertSSLCtxByDomainName(
"www.facebook.com",
strlen("www.facebook.com"),
www_facebook_com_ctx);
try {
sslCtxMgr.insertSSLCtxByDomainName(
"www.facebook.com",
strlen("www.facebook.com"),
std::make_shared<SSLContext>());
} catch (const std::exception& ex) {
}
sslCtxMgr.insertSSLCtxByDomainName(
"*.facebook.com",
strlen("*.facebook.com"),
start_facebook_com_ctx);
sslCtxMgr.insertSSLCtxByDomainName(
"*.abc.facebook.com",
strlen("*.abc.facebook.com"),
start_abc_facebook_com_ctx);
try {
sslCtxMgr.insertSSLCtxByDomainName(
"*.abc.facebook.com",
strlen("*.abc.facebook.com"),
std::make_shared<SSLContext>());
FAIL();
} catch (const std::exception& ex) {
}
shared_ptr<SSLContext> retCtx;
retCtx = sslCtxMgr.getSSLCtx(DNString("www.facebook.com"));
EXPECT_EQ(retCtx, www_facebook_com_ctx);
retCtx = sslCtxMgr.getSSLCtx(DNString("WWW.facebook.com"));
EXPECT_EQ(retCtx, www_facebook_com_ctx);
EXPECT_FALSE(sslCtxMgr.getSSLCtx(DNString("xyz.facebook.com")));
retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("xyz.facebook.com"));
EXPECT_EQ(retCtx, start_facebook_com_ctx);
retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("XYZ.facebook.com"));
EXPECT_EQ(retCtx, start_facebook_com_ctx);
retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("www.abc.facebook.com"));
EXPECT_EQ(retCtx, start_abc_facebook_com_ctx);
// ensure "facebook.com" does not match "*.facebook.com"
EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("facebook.com")));
// ensure "Xfacebook.com" does not match "*.facebook.com"
EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("Xfacebook.com")));
// ensure wildcard name only matches one domain up
EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("abc.xyz.facebook.com")));
eventBase.loop(); // Clean up events before SSLContextManager is destructed
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment