Commit 8aac0e33 authored by Subodh Iyengar's avatar Subodh Iyengar Committed by Facebook Github Bot 7

Always override write bio method

Summary:
Always overriding write bio method
allows us to more cleanly implement
features like eor tracking, support
multiple ssl libraries, and also TFO

Reviewed By: anirudhvr

Differential Revision: D3350482

fbshipit-source-id: ddd2333431f9d636d69c8325b2c18d7cc043b848
parent 9fc18f0a
...@@ -223,16 +223,16 @@ void setup_SSL_CTX(SSL_CTX *ctx) { ...@@ -223,16 +223,16 @@ void setup_SSL_CTX(SSL_CTX *ctx) {
} }
BIO_METHOD eorAwareBioMethod; BIO_METHOD sslWriteBioMethod;
void* initEorBioMethod(void) { void* initsslWriteBioMethod(void) {
memcpy(&eorAwareBioMethod, BIO_s_socket(), sizeof(eorAwareBioMethod)); memcpy(&sslWriteBioMethod, BIO_s_socket(), sizeof(sslWriteBioMethod));
// override the bwrite method for MSG_EOR support // override the bwrite method for MSG_EOR support
eorAwareBioMethod.bwrite = AsyncSSLSocket::eorAwareBioWrite; sslWriteBioMethod.bwrite = AsyncSSLSocket::bioWrite;
// Note that the eorAwareBioMethod.type and eorAwareBioMethod.name are not // Note that the sslWriteBioMethod.type and sslWriteBioMethod.name are not
// set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
// then have specific handlings. The eorAwareBioWrite should be compatible // then have specific handlings. The sslWriteBioWrite should be compatible
// with the one in openssl. // with the one in openssl.
// Return something here to enable AsyncSSLSocket to call this method using // Return something here to enable AsyncSSLSocket to call this method using
...@@ -314,8 +314,8 @@ AsyncSSLSocket::~AsyncSSLSocket() { ...@@ -314,8 +314,8 @@ AsyncSSLSocket::~AsyncSSLSocket() {
void AsyncSSLSocket::init() { void AsyncSSLSocket::init() {
// Do this here to ensure we initialize this once before any use of // Do this here to ensure we initialize this once before any use of
// AsyncSSLSocket instances and not as part of library load. // AsyncSSLSocket instances and not as part of library load.
static const auto eorAwareBioMethodInitializer = initEorBioMethod(); static const auto sslWriteBioMethodInitializer = initsslWriteBioMethod();
(void)eorAwareBioMethodInitializer; (void)sslWriteBioMethodInitializer;
setup_SSL_CTX(ctx_->getSSLCtx()); setup_SSL_CTX(ctx_->getSSLCtx());
} }
...@@ -401,37 +401,15 @@ std::string AsyncSSLSocket::getApplicationProtocol() noexcept { ...@@ -401,37 +401,15 @@ std::string AsyncSSLSocket::getApplicationProtocol() noexcept {
} }
bool AsyncSSLSocket::isEorTrackingEnabled() const { bool AsyncSSLSocket::isEorTrackingEnabled() const {
if (ssl_ == nullptr) { return trackEor_;
return false;
}
const BIO *wb = SSL_get_wbio(ssl_);
return wb && wb->method == &eorAwareBioMethod;
} }
void AsyncSSLSocket::setEorTracking(bool track) { void AsyncSSLSocket::setEorTracking(bool track) {
BIO *wb = SSL_get_wbio(ssl_); if (trackEor_ != track) {
if (!wb) { trackEor_ = track;
throw AsyncSocketException(AsyncSocketException::INVALID_STATE,
"setting EOR tracking without an initialized "
"BIO");
}
if (track) {
if (wb->method != &eorAwareBioMethod) {
// only do this if we didn't
wb->method = &eorAwareBioMethod;
BIO_set_app_data(wb, this);
appEorByteNo_ = 0; appEorByteNo_ = 0;
minEorRawByteNo_ = 0; minEorRawByteNo_ = 0;
} }
} else if (wb->method == &eorAwareBioMethod) {
wb->method = BIO_s_socket();
BIO_set_app_data(wb, nullptr);
appEorByteNo_ = 0;
minEorRawByteNo_ = 0;
} else {
CHECK(wb->method == BIO_s_socket());
}
} }
size_t AsyncSSLSocket::getRawBytesWritten() const { size_t AsyncSSLSocket::getRawBytesWritten() const {
...@@ -703,6 +681,19 @@ void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) { ...@@ -703,6 +681,19 @@ void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
} }
} }
bool AsyncSSLSocket::setupSSLBio() {
auto wb = BIO_new(&sslWriteBioMethod);
if (!wb) {
return false;
}
BIO_set_app_data(wb, this);
BIO_set_fd(wb, fd_, BIO_NOCLOSE);
SSL_set_bio(ssl_, wb, wb);
return true;
}
void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout, void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
const SSLContext::SSLVerifyPeerEnum& verifyPeer) { const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
DestructorGuard dg(this); DestructorGuard dg(this);
...@@ -741,9 +732,15 @@ void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout, ...@@ -741,9 +732,15 @@ void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
return failHandshake(__func__, ex); return failHandshake(__func__, ex);
} }
if (!setupSSLBio()) {
sslState_ = STATE_ERROR;
AsyncSocketException ex(
AsyncSocketException::INTERNAL_ERROR, "error creating SSL bio");
return failHandshake(__func__, ex);
}
applyVerificationOptions(ssl_); applyVerificationOptions(ssl_);
SSL_set_fd(ssl_, fd_);
if (sslSession_ != nullptr) { if (sslSession_ != nullptr) {
SSL_set_session(ssl_, sslSession_); SSL_set_session(ssl_, sslSession_);
SSL_SESSION_free(sslSession_); SSL_SESSION_free(sslSession_);
...@@ -1010,7 +1007,14 @@ AsyncSSLSocket::handleAccept() noexcept { ...@@ -1010,7 +1007,14 @@ AsyncSSLSocket::handleAccept() noexcept {
<< ", fd=" << fd_ << "): " << e.what(); << ", fd=" << fd_ << "): " << e.what();
return failHandshake(__func__, ex); return failHandshake(__func__, ex);
} }
SSL_set_fd(ssl_, fd_);
if (!setupSSLBio()) {
sslState_ = STATE_ERROR;
AsyncSocketException ex(
AsyncSocketException::INTERNAL_ERROR, "error creating write bio");
return failHandshake(__func__, ex);
}
SSL_set_ex_data(ssl_, getSSLExDataIndex(), this); SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
applyVerificationOptions(ssl_); applyVerificationOptions(ssl_);
...@@ -1448,7 +1452,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite( ...@@ -1448,7 +1452,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n, int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
bool eor) { bool eor) {
if (eor && SSL_get_wbio(ssl)->method == &eorAwareBioMethod) { if (eor && trackEor_) {
if (appEorByteNo_) { if (appEorByteNo_) {
// cannot track for more than one app byte EOR // cannot track for more than one app byte EOR
CHECK(appEorByteNo_ == appBytesWritten_ + n); CHECK(appEorByteNo_ == appBytesWritten_ + n);
...@@ -1493,34 +1497,37 @@ void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) { ...@@ -1493,34 +1497,37 @@ void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
} }
} }
int AsyncSSLSocket::eorAwareBioWrite(BIO *b, const char *in, int inl) { int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
int ret; int ret;
struct msghdr msg; struct msghdr msg;
struct iovec iov; struct iovec iov;
int flags = 0; int flags = 0;
AsyncSSLSocket *tsslSock; AsyncSSLSocket* tsslSock;
iov.iov_base = const_cast<char *>(in); iov.iov_base = const_cast<char*>(in);
iov.iov_len = inl; iov.iov_len = inl;
memset(&msg, 0, sizeof(msg)); memset(&msg, 0, sizeof(msg));
msg.msg_iov = &iov; msg.msg_iov = &iov;
msg.msg_iovlen = 1; msg.msg_iovlen = 1;
tsslSock = auto appData = BIO_get_app_data(b);
reinterpret_cast<AsyncSSLSocket*>(BIO_get_app_data(b)); CHECK(appData);
if (tsslSock &&
tsslSock->minEorRawByteNo_ && tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
CHECK(tsslSock);
if (tsslSock->trackEor_ && tsslSock->minEorRawByteNo_ &&
tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) { tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
flags = MSG_EOR; flags = MSG_EOR;
} }
ret = sendmsg(b->num, &msg, flags); ret = sendmsg(BIO_get_fd(b, nullptr), &msg, flags);
BIO_clear_retry_flags(b); BIO_clear_retry_flags(b);
if (ret <= 0) { if (ret <= 0) {
if (BIO_sock_should_retry(ret)) if (BIO_sock_should_retry(ret))
BIO_set_retry_write(b); BIO_set_retry_write(b);
} }
return(ret); return ret;
} }
int AsyncSSLSocket::sslVerifyCallback(int preverifyOk, int AsyncSSLSocket::sslVerifyCallback(int preverifyOk,
......
...@@ -652,7 +652,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -652,7 +652,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
static int getSSLExDataIndex(); static int getSSLExDataIndex();
static AsyncSSLSocket* getFromSSL(const SSL *ssl); static AsyncSSLSocket* getFromSSL(const SSL *ssl);
static int eorAwareBioWrite(BIO *b, const char *in, int inl); static int bioWrite(BIO* b, const char* in, int inl);
void resetClientHelloParsing(SSL *ssl); void resetClientHelloParsing(SSL *ssl);
static void clientHelloParsingCallback(int write_p, int version, static void clientHelloParsingCallback(int write_p, int version,
int content_type, const void *buf, size_t len, SSL *ssl, void *arg); int content_type, const void *buf, size_t len, SSL *ssl, void *arg);
...@@ -774,6 +774,13 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -774,6 +774,13 @@ class AsyncSSLSocket : public virtual AsyncSocket {
*/ */
void applyVerificationOptions(SSL * ssl); void applyVerificationOptions(SSL * ssl);
/**
* Sets up SSL with a custom write bio which intercepts all writes.
*
* @return true, if succeeds and false if there is an error creating the bio.
*/
bool setupSSLBio();
/** /**
* A SSL_write wrapper that understand EOR * A SSL_write wrapper that understand EOR
* *
...@@ -815,6 +822,9 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -815,6 +822,9 @@ class AsyncSSLSocket : public virtual AsyncSocket {
// whether the SSL session was resumed using session ID or not // whether the SSL session was resumed using session ID or not
bool sessionIDResumed_{false}; bool sessionIDResumed_{false};
// Whether to track EOR or not.
bool trackEor_{false};
// The app byte num that we are tracking for the MSG_EOR // The app byte num that we are tracking for the MSG_EOR
// Only one app EOR byte can be tracked. // Only one app EOR byte can be tracked.
size_t appEorByteNo_{0}; size_t appEorByteNo_{0};
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment