Commit 92c9ccb6 authored by Neel Goyal's avatar Neel Goyal Committed by facebook-github-bot-4

Update SSLContext to use discrete_distribution

Summary: Update the protocol pick logic to use discrete_distribution

Reviewed By: siyengar

Differential Revision: D2741855

fb-gh-sync-id: 244bd087124a7a9584a1108fe8f8150093275878
parent 71c140ed
...@@ -84,6 +84,10 @@ SSLContext::SSLContext(SSLVersion version) { ...@@ -84,6 +84,10 @@ SSLContext::SSLContext(SSLVersion version) {
SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback); SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
SSL_CTX_set_tlsext_servername_arg(ctx_, this); SSL_CTX_set_tlsext_servername_arg(ctx_, this);
#endif #endif
#ifdef OPENSSL_NPN_NEGOTIATED
Random::seed(nextProtocolPicker_);
#endif
} }
SSLContext::~SSLContext() { SSLContext::~SSLContext() {
...@@ -374,16 +378,16 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols( ...@@ -374,16 +378,16 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols(
dst += protoLength; dst += protoLength;
} }
total_weight += item.weight; total_weight += item.weight;
advertised_item.probability = item.weight;
advertisedNextProtocols_.push_back(advertised_item); advertisedNextProtocols_.push_back(advertised_item);
advertisedNextProtocolWeights_.push_back(item.weight);
} }
if (total_weight == 0) { if (total_weight == 0) {
deleteNextProtocolsStrings(); deleteNextProtocolsStrings();
return false; return false;
} }
for (auto &advertised_item : advertisedNextProtocols_) { nextProtocolDistribution_ =
advertised_item.probability /= total_weight; std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
} advertisedNextProtocolWeights_.end());
if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) { if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
SSL_CTX_set_next_protos_advertised_cb( SSL_CTX_set_next_protos_advertised_cb(
ctx_, advertisedNextProtocolCallback, this); ctx_, advertisedNextProtocolCallback, this);
...@@ -406,6 +410,7 @@ void SSLContext::deleteNextProtocolsStrings() { ...@@ -406,6 +410,7 @@ void SSLContext::deleteNextProtocolsStrings() {
delete[] protocols.protocols; delete[] protocols.protocols;
} }
advertisedNextProtocols_.clear(); advertisedNextProtocols_.clear();
advertisedNextProtocolWeights_.clear();
} }
void SSLContext::unsetNextProtocols() { void SSLContext::unsetNextProtocols() {
...@@ -419,18 +424,8 @@ void SSLContext::unsetNextProtocols() { ...@@ -419,18 +424,8 @@ void SSLContext::unsetNextProtocols() {
} }
size_t SSLContext::pickNextProtocols() { size_t SSLContext::pickNextProtocols() {
unsigned char random_byte; CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
RAND_bytes(&random_byte, 1); return nextProtocolDistribution_(nextProtocolPicker_);
double random_value = random_byte / 255.0;
double sum = 0;
for (size_t i = 0; i < advertisedNextProtocols_.size(); ++i) {
sum += advertisedNextProtocols_[i].probability;
if (sum < random_value && i + 1 < advertisedNextProtocols_.size()) {
continue;
}
return i;
}
CHECK(false) << "Failed to pickNextProtocols";
} }
int SSLContext::advertisedNextProtocolCallback(SSL* ssl, int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <string> #include <string>
#include <random>
#include <openssl/ssl.h> #include <openssl/ssl.h>
#include <openssl/tls1.h> #include <openssl/tls1.h>
...@@ -35,6 +36,8 @@ ...@@ -35,6 +36,8 @@
#include <folly/folly-config.h> #include <folly/folly-config.h>
#endif #endif
#include <folly/Random.h>
namespace folly { namespace folly {
/** /**
...@@ -87,12 +90,6 @@ class SSLContext { ...@@ -87,12 +90,6 @@ class SSLContext {
std::list<std::string> protocols; std::list<std::string> protocols;
}; };
struct AdvertisedNextProtocolsItem {
unsigned char *protocols;
unsigned length;
double probability;
};
// Function that selects a client protocol given the server's list // Function that selects a client protocol given the server's list
using ClientProtocolFilterCallback = bool (*)(unsigned char**, unsigned int*, using ClientProtocolFilterCallback = bool (*)(unsigned char**, unsigned int*,
const unsigned char*, unsigned int); const unsigned char*, unsigned int);
...@@ -458,10 +455,20 @@ class SSLContext { ...@@ -458,10 +455,20 @@ class SSLContext {
static bool initialized_; static bool initialized_;
#ifdef OPENSSL_NPN_NEGOTIATED #ifdef OPENSSL_NPN_NEGOTIATED
struct AdvertisedNextProtocolsItem {
unsigned char* protocols;
unsigned length;
};
/** /**
* Wire-format list of advertised protocols for use in NPN. * Wire-format list of advertised protocols for use in NPN.
*/ */
std::vector<AdvertisedNextProtocolsItem> advertisedNextProtocols_; std::vector<AdvertisedNextProtocolsItem> advertisedNextProtocols_;
std::vector<int> advertisedNextProtocolWeights_;
std::discrete_distribution<int> nextProtocolDistribution_;
Random::DefaultGenerator nextProtocolPicker_;
static int sNextProtocolsExDataIndex_; static int sNextProtocolsExDataIndex_;
static int advertisedNextProtocolCallback(SSL* ssl, static int advertisedNextProtocolCallback(SSL* ssl,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment