Commit 14a19db2 authored by Ranjeeth Dasineni's avatar Ranjeeth Dasineni Committed by facebook-github-bot-9

add callback to specify a client next protocol filter

Summary: From the client perspective, we set the list in order of
preference once and call into openssl to do the selection. This adds
a little more flexibility in that client optionally can customize the
selection for each negotiation. added tests for the no-op case and the
customized case. Feel free to suggest improvements.

Reviewed By: @afrind

Differential Revision: D2489142
parent 1f46d8c5
......@@ -513,13 +513,21 @@ int SSLContext::selectNextProtocolCallback(
}
unsigned char *client;
int client_len;
if (ctx->advertisedNextProtocols_.empty()) {
client = (unsigned char *) "";
client_len = 0;
} else {
client = ctx->advertisedNextProtocols_[0].protocols;
client_len = ctx->advertisedNextProtocols_[0].length;
unsigned int client_len;
bool filtered = false;
auto cpf = ctx->getClientProtocolFilterCallback();
if (cpf) {
filtered = (*cpf)(&client, &client_len, server, server_len);
}
if (!filtered) {
if (ctx->advertisedNextProtocols_.empty()) {
client = (unsigned char *) "";
client_len = 0;
} else {
client = ctx->advertisedNextProtocols_[0].protocols;
client_len = ctx->advertisedNextProtocols_[0].length;
}
}
int retval = SSL_select_next_proto(out, outlen, server, server_len,
......
......@@ -93,6 +93,10 @@ class SSLContext {
double probability;
};
// Function that selects a client protocol given the server's list
using ClientProtocolFilterCallback = bool (*)(unsigned char**, unsigned int*,
const unsigned char*, unsigned int);
/**
* Convenience function to call getErrors() with the current errno value.
*
......@@ -327,6 +331,13 @@ class SSLContext {
bool setRandomizedAdvertisedNextProtocols(
const std::list<NextProtocolsItem>& items);
void setClientProtocolFilterCallback(ClientProtocolFilterCallback cb) {
clientProtoFilter_ = cb;
}
ClientProtocolFilterCallback getClientProtocolFilterCallback() {
return clientProtoFilter_;
}
/**
* Disables NPN on this SSL context.
*/
......@@ -431,6 +442,8 @@ class SSLContext {
std::vector<ClientHelloCallback> clientHelloCbs_;
#endif
ClientProtocolFilterCallback clientProtoFilter_{nullptr};
static bool initialized_;
#ifdef OPENSSL_NPN_NEGOTIATED
......
......@@ -127,6 +127,21 @@ void sslsocketpair(
// (*serverSock)->setSendTimeout(100);
}
// client protocol filters
bool clientProtoFilterPickPony(unsigned char** client,
unsigned int* client_len, const unsigned char*, unsigned int ) {
//the protocol string in length prefixed byte string. the
//length byte is not included in the length
static unsigned char p[7] = {6,'p','o','n','i','e','s'};
*client = p;
*client_len = 7;
return true;
}
bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
const unsigned char*, unsigned int) {
return false;
}
/**
* Test connecting to, writing to, reading from, and closing the
......@@ -387,6 +402,64 @@ TEST(AsyncSSLSocketTest, NpnTestNoOverlap) {
EXPECT_EQ(selected.compare("blub"), 0);
}
TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterHit) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
int fds[2];
getfds(fds);
getctx(clientCtx, serverCtx);
clientCtx->setAdvertisedNextProtocols({"blub"});
clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
NpnClient client(std::move(clientSock));
NpnServer server(std::move(serverSock));
eventBase.loop();
EXPECT_TRUE(client.nextProtoLength != 0);
EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
server.nextProtoLength), 0);
string selected((const char*)client.nextProto, client.nextProtoLength);
EXPECT_EQ(selected.compare("ponies"), 0);
}
TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterMiss) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
int fds[2];
getfds(fds);
getctx(clientCtx, serverCtx);
clientCtx->setAdvertisedNextProtocols({"blub"});
clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
NpnClient client(std::move(clientSock));
NpnServer server(std::move(serverSock));
eventBase.loop();
EXPECT_TRUE(client.nextProtoLength != 0);
EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
server.nextProtoLength), 0);
string selected((const char*)client.nextProto, client.nextProtoLength);
EXPECT_EQ(selected.compare("blub"), 0);
}
TEST(AsyncSSLSocketTest, RandomizedNpnTest) {
// Probability that this test will fail is 2^-64, which could be considered
// as negligible.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment