Commit 4e0e47bb authored by Kyle Nekritz's avatar Kyle Nekritz Committed by Facebook Github Bot 9

Fix AsyncSSLSocket handshake error reporting.

Summary:https://www.openssl.org/docs/manmaster/ssl/SSL_get_error.html
OpenSSL errors are a pain to deal with and we were handling several cases
incorrectly, resulting in a ton of "DH lib" errors when none were likely
actually DH lib errors.

Reviewed By: siyengar

Differential Revision: D2999084

fb-gh-sync-id: b3182be2c199f79ed341af7dbf7524197a838584
shipit-source-id: b3182be2c199f79ed341af7dbf7524197a838584
parent 4f782bbf
...@@ -246,15 +246,38 @@ void* initEorBioMethod(void) { ...@@ -246,15 +246,38 @@ void* initEorBioMethod(void) {
return nullptr; return nullptr;
} }
std::string decodeOpenSSLError(int sslError,
unsigned long errError,
int sslOperationReturnValue) {
if (sslError == SSL_ERROR_SYSCALL && errError == 0) {
if (sslOperationReturnValue == 0) {
return "SSL_ERROR_SYSCALL: EOF";
} else {
// In this case errno is set, AsyncSocketException will add it.
return "SSL_ERROR_SYSCALL";
}
} else if (sslError == SSL_ERROR_ZERO_RETURN) {
// This signifies a TLS closure alert.
return "SSL_ERROR_ZERO_RETURN";
} else {
char buf[256];
std::string msg(ERR_error_string(errError, buf));
return msg;
}
}
} // anonymous namespace } // anonymous namespace
namespace folly { namespace folly {
SSLException::SSLException(int sslError, int errno_copy): SSLException::SSLException(int sslError,
AsyncSocketException( unsigned long errError,
int sslOperationReturnValue,
int errno_copy)
: AsyncSocketException(
AsyncSocketException::SSL_ERROR, AsyncSocketException::SSL_ERROR,
ERR_error_string(sslError, msg_), decodeOpenSSLError(sslError, errError, sslOperationReturnValue),
sslError == SSL_ERROR_SYSCALL ? errno_copy : 0), error_(sslError) {} sslError == SSL_ERROR_SYSCALL ? errno_copy : 0) {}
/** /**
* Create a client AsyncSSLSocket * Create a client AsyncSSLSocket
...@@ -889,8 +912,11 @@ int AsyncSSLSocket::getSSLCertSize() const { ...@@ -889,8 +912,11 @@ int AsyncSSLSocket::getSSLCertSize() const {
return certSize; return certSize;
} }
bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept { bool AsyncSSLSocket::willBlock(int ret,
int error = *errorOut = SSL_get_error(ssl_, ret); int* sslErrorOut,
unsigned long* errErrorOut) noexcept {
*errErrorOut = 0;
int error = *sslErrorOut = SSL_get_error(ssl_, ret);
if (error == SSL_ERROR_WANT_READ) { if (error == SSL_ERROR_WANT_READ) {
// Register for read event if not already. // Register for read event if not already.
updateEventRegistration(EventHandler::READ, EventHandler::WRITE); updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
...@@ -943,7 +969,7 @@ bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept { ...@@ -943,7 +969,7 @@ bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept {
} else { } else {
// SSL_ERROR_ZERO_RETURN is processed here so we can get some detail // SSL_ERROR_ZERO_RETURN is processed here so we can get some detail
// in the log // in the log
long lastError = ERR_get_error(); unsigned long lastError = *errErrorOut = ERR_get_error();
VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", " VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
<< "state=" << state_ << ", " << "state=" << state_ << ", "
<< "sslState=" << sslState_ << ", " << "sslState=" << sslState_ << ", "
...@@ -955,16 +981,6 @@ bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept { ...@@ -955,16 +981,6 @@ bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept {
<< "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", " << "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", "
<< "func: " << ERR_func_error_string(lastError) << ", " << "func: " << ERR_func_error_string(lastError) << ", "
<< "reason: " << ERR_reason_error_string(lastError); << "reason: " << ERR_reason_error_string(lastError);
if (error != SSL_ERROR_SYSCALL) {
if (error == SSL_ERROR_SSL) {
*errorOut = lastError;
}
if ((unsigned long)lastError < 0x8000) {
errno = ENOSYS;
} else {
errno = lastError;
}
}
ERR_clear_error(); ERR_clear_error();
return false; return false;
} }
...@@ -1042,12 +1058,14 @@ AsyncSSLSocket::handleAccept() noexcept { ...@@ -1042,12 +1058,14 @@ AsyncSSLSocket::handleAccept() noexcept {
errno = 0; errno = 0;
int ret = SSL_accept(ssl_); int ret = SSL_accept(ssl_);
if (ret <= 0) { if (ret <= 0) {
int error; int sslError;
if (willBlock(ret, &error)) { unsigned long errError;
int errnoCopy = errno;
if (willBlock(ret, &sslError, &errError)) {
return; return;
} else { } else {
sslState_ = STATE_ERROR; sslState_ = STATE_ERROR;
SSLException ex(error, errno); SSLException ex(sslError, errError, ret, errnoCopy);
return failHandshake(__func__, ex); return failHandshake(__func__, ex);
} }
} }
...@@ -1104,12 +1122,14 @@ AsyncSSLSocket::handleConnect() noexcept { ...@@ -1104,12 +1122,14 @@ AsyncSSLSocket::handleConnect() noexcept {
errno = 0; errno = 0;
int ret = SSL_connect(ssl_); int ret = SSL_connect(ssl_);
if (ret <= 0) { if (ret <= 0) {
int error; int sslError;
if (willBlock(ret, &error)) { unsigned long errError;
int errnoCopy = errno;
if (willBlock(ret, &sslError, &errError)) {
return; return;
} else { } else {
sslState_ = STATE_ERROR; sslState_ = STATE_ERROR;
SSLException ex(error, errno); SSLException ex(sslError, errError, ret, errnoCopy);
return failHandshake(__func__, ex); return failHandshake(__func__, ex);
} }
} }
......
...@@ -35,13 +35,10 @@ namespace folly { ...@@ -35,13 +35,10 @@ namespace folly {
class SSLException: public folly::AsyncSocketException { class SSLException: public folly::AsyncSocketException {
public: public:
SSLException(int sslError, int errno_copy); SSLException(int sslError,
unsigned long errError,
int getSSLError() const { return error_; } int sslOperationReturnValue,
int errno_copy);
protected:
int error_;
char msg_[256];
}; };
/** /**
...@@ -782,7 +779,9 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -782,7 +779,9 @@ class AsyncSSLSocket : public virtual AsyncSocket {
void handleConnect() noexcept override; void handleConnect() noexcept override;
void invalidState(HandshakeCB* callback); void invalidState(HandshakeCB* callback);
bool willBlock(int ret, int *errorOut) noexcept; bool willBlock(int ret,
int* sslErrorOut,
unsigned long* errErrorOut) noexcept;
virtual void checkForImmediateRead() noexcept override; virtual void checkForImmediateRead() noexcept override;
// AsyncSocket calls this at the wrong time for SSL // AsyncSocket calls this at the wrong time for SSL
......
...@@ -879,7 +879,7 @@ TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) { ...@@ -879,7 +879,7 @@ TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
server.getEventBase().runInEventBaseThread([&handshakeCallback]{ server.getEventBase().runInEventBaseThread([&handshakeCallback]{
handshakeCallback.closeSocket();}); handshakeCallback.closeSocket();});
// give time for the cache lookup to come back and find it closed // give time for the cache lookup to come back and find it closed
usleep(500000); handshakeCallback.waitForHandshake();
EXPECT_EQ(server.getAsyncCallbacks(), 1); EXPECT_EQ(server.getAsyncCallbacks(), 1);
EXPECT_EQ(server.getAsyncLookups(), 1); EXPECT_EQ(server.getAsyncLookups(), 1);
...@@ -1520,6 +1520,71 @@ TEST(AsyncSSLSocketTest, UnencryptedTest) { ...@@ -1520,6 +1520,71 @@ TEST(AsyncSSLSocketTest, UnencryptedTest) {
EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState()); EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
} }
TEST(AsyncSSLSocketTest, ConnResetErrorString) {
// Start listening on a local port
WriteCallbackBase writeCallback;
WriteErrorCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback,
HandshakeCallback::EXPECT_ERROR);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
socket->open();
uint8_t buf[3] = {0x16, 0x03, 0x01};
socket->write(buf, sizeof(buf));
socket->closeWithReset();
handshakeCallback.waitForHandshake();
EXPECT_NE(handshakeCallback.errorString_.find("SSL_ERROR_SYSCALL"),
std::string::npos);
EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
}
TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
// Start listening on a local port
WriteCallbackBase writeCallback;
WriteErrorCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback,
HandshakeCallback::EXPECT_ERROR);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
socket->open();
uint8_t buf[3] = {0x16, 0x03, 0x01};
socket->write(buf, sizeof(buf));
socket->close();
handshakeCallback.waitForHandshake();
EXPECT_NE(handshakeCallback.errorString_.find("SSL_ERROR_SYSCALL"),
std::string::npos);
EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
}
TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
// Start listening on a local port
WriteCallbackBase writeCallback;
WriteErrorCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback,
HandshakeCallback::EXPECT_ERROR);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
socket->open();
uint8_t buf[256] = {0x16, 0x03};
memset(buf + 2, 'a', sizeof(buf) - 2);
socket->write(buf, sizeof(buf));
socket->close();
handshakeCallback.waitForHandshake();
EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
std::string::npos);
EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
std::string::npos);
}
} // namespace } // namespace
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
......
...@@ -305,6 +305,8 @@ public: ...@@ -305,6 +305,8 @@ public:
// Functions inherited from AsyncSSLSocketHandshakeCallback // Functions inherited from AsyncSSLSocketHandshakeCallback
void handshakeSuc(AsyncSSLSocket *sock) noexcept override { void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
std::lock_guard<std::mutex> g(mutex_);
cv_.notify_all();
EXPECT_EQ(sock, socket_.get()); EXPECT_EQ(sock, socket_.get());
std::cerr << "HandshakeCallback::connectionAccepted" << std::endl; std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
rcb_->setSocket(socket_); rcb_->setSocket(socket_);
...@@ -313,12 +315,20 @@ public: ...@@ -313,12 +315,20 @@ public:
} }
void handshakeErr(AsyncSSLSocket* /* sock */, void handshakeErr(AsyncSSLSocket* /* sock */,
const AsyncSocketException& ex) noexcept override { const AsyncSocketException& ex) noexcept override {
std::lock_guard<std::mutex> g(mutex_);
cv_.notify_all();
std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl; std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED; state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
if (expect_ == EXPECT_ERROR) { if (expect_ == EXPECT_ERROR) {
// rcb will never be invoked // rcb will never be invoked
rcb_->setState(STATE_SUCCEEDED); rcb_->setState(STATE_SUCCEEDED);
} }
errorString_ = ex.what();
}
void waitForHandshake() {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { return state != STATE_WAITING; });
} }
~HandshakeCallback() { ~HandshakeCallback() {
...@@ -334,6 +344,9 @@ public: ...@@ -334,6 +344,9 @@ public:
std::shared_ptr<AsyncSSLSocket> socket_; std::shared_ptr<AsyncSSLSocket> socket_;
ReadCallbackBase *rcb_; ReadCallbackBase *rcb_;
ExpectType expect_; ExpectType expect_;
std::mutex mutex_;
std::condition_variable cv_;
std::string errorString_;
}; };
class SSLServerAcceptCallbackBase: class SSLServerAcceptCallbackBase:
......
...@@ -45,6 +45,7 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback, ...@@ -45,6 +45,7 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
void close() { void close() {
sock_->close(); sock_->close();
} }
void closeWithReset() { sock_->closeWithReset(); }
int32_t write(uint8_t const* buf, size_t len) { int32_t write(uint8_t const* buf, size_t len) {
sock_->write(this, buf, len); sock_->write(this, buf, len);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment