Commit 4708133f authored by Subodh Iyengar's avatar Subodh Iyengar Committed by Facebook Github Bot 0

Stop abusing errno

Summary:
We abuse errno to propagate exceptions from AsyncSSLSocket.
Stop doing this and propagate exceptions correctly.

This also formats the exception messages better.

Reviewed By: anirudhvr

Differential Revision: D3226808

fb-gh-sync-id: 15a5e67b0332136857e5fb85b1765757e548e040
fbshipit-source-id: 15a5e67b0332136857e5fb85b1765757e548e040
parent 38c0b1ab
......@@ -234,6 +234,7 @@ nobase_follyinclude_HEADERS = \
io/async/HHWheelTimer.h \
io/async/ssl/OpenSSLPtrTypes.h \
io/async/ssl/OpenSSLUtils.h \
io/async/ssl/SSLErrors.h \
io/async/ssl/TLSDefinitions.h \
io/async/Request.h \
io/async/SSLContext.h \
......@@ -417,6 +418,7 @@ libfolly_la_SOURCES = \
io/async/test/SocketPair.cpp \
io/async/test/TimeUtil.cpp \
io/async/ssl/OpenSSLUtils.cpp \
io/async/ssl/SSLErrors.cpp \
json.cpp \
detail/MemoryIdler.cpp \
MacAddress.cpp \
......
This diff is collapsed.
......@@ -27,6 +27,7 @@
#include <folly/io/async/TimeoutManager.h>
#include <folly/io/async/ssl/OpenSSLPtrTypes.h>
#include <folly/io/async/ssl/OpenSSLUtils.h>
#include <folly/io/async/ssl/SSLErrors.h>
#include <folly/io/async/ssl/TLSDefinitions.h>
#include <folly/Bits.h>
......@@ -35,14 +36,6 @@
namespace folly {
class SSLException: public folly::AsyncSocketException {
public:
SSLException(int sslError,
unsigned long errError,
int sslOperationReturnValue,
int errno_copy);
};
/**
* A class for performing asynchronous I/O on an SSL connection.
*
......@@ -143,18 +136,6 @@ class AsyncSSLSocket : public virtual AsyncSocket {
AsyncSSLSocket* sslSocket_;
};
/**
* These are passed to the application via errno, packed in an SSL err which
* are outside the valid errno range. The values are chosen to be unique
* against values in ssl.h
*/
enum SSLError {
SSL_CLIENT_RENEGOTIATION_ATTEMPT = 900,
SSL_INVALID_RENEGOTIATION = 901,
SSL_EARLY_WRITE = 902
};
/**
* Create a client AsyncSSLSocket
*/
......@@ -365,6 +346,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
*/
SSL_SESSION *getSSLSession();
/**
* Get a handle to the SSL struct.
*/
const SSL* getSSL() const;
/**
* Set the SSL session to be used during sslConn. AsyncSSLSocket will
* hold a reference to the session until it is destroyed or released by the
......@@ -760,11 +746,14 @@ class AsyncSSLSocket : public virtual AsyncSocket {
// AsyncSocket calls this at the wrong time for SSL
void handleInitialReadWrite() noexcept override {}
int interpretSSLError(int rc, int error);
ssize_t performRead(void** buf, size_t* buflen, size_t* offset) override;
ssize_t performWrite(const iovec* vec, uint32_t count, WriteFlags flags,
uint32_t* countWritten, uint32_t* partialWritten)
override;
WriteResult interpretSSLError(int rc, int error);
ReadResult performRead(void** buf, size_t* buflen, size_t* offset) override;
WriteResult performWrite(
const iovec* vec,
uint32_t count,
WriteFlags flags,
uint32_t* countWritten,
uint32_t* partialWritten) override;
ssize_t performWriteIovec(const iovec* vec, uint32_t count,
WriteFlags flags, uint32_t* countWritten,
......
......@@ -91,14 +91,13 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
free(this);
}
bool performWrite() override {
WriteResult performWrite() override {
WriteFlags writeFlags = flags_;
if (getNext() != nullptr) {
writeFlags = writeFlags | WriteFlags::CORK;
}
bytesWritten_ = socket_->performWrite(getOps(), getOpCount(), writeFlags,
&opsWritten_, &partialBytes_);
return bytesWritten_ >= 0;
return socket_->performWrite(
getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_);
}
bool isComplete() override {
......@@ -694,10 +693,14 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
assert(writeReqTail_ == nullptr);
assert((eventFlags_ & EventHandler::WRITE) == 0);
bytesWritten = performWrite(vec, count, flags,
&countWritten, &partialWritten);
auto writeResult =
performWrite(vec, count, flags, &countWritten, &partialWritten);
bytesWritten = writeResult.writeReturn;
if (bytesWritten < 0) {
auto errnoCopy = errno;
if (writeResult.exception) {
return failWrite(__func__, callback, 0, *writeResult.exception);
}
AsyncSocketException ex(
AsyncSocketException::INTERNAL_ERROR,
withAddr("writev failed"),
......@@ -1259,11 +1262,10 @@ void AsyncSocket::ioReady(uint16_t events) noexcept {
}
}
ssize_t AsyncSocket::performRead(void** buf,
size_t* buflen,
size_t* /* offset */) {
VLOG(5) << "AsyncSocket::performRead() this=" << this
<< ", buf=" << *buf << ", buflen=" << *buflen;
AsyncSocket::ReadResult
AsyncSocket::performRead(void** buf, size_t* buflen, size_t* /* offset */) {
VLOG(5) << "AsyncSocket::performRead() this=" << this << ", buf=" << *buf
<< ", buflen=" << *buflen;
int recvFlags = 0;
if (peek_) {
......@@ -1274,13 +1276,13 @@ ssize_t AsyncSocket::performRead(void** buf,
if (bytes < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// No more data to read right now.
return READ_BLOCKING;
return ReadResult(READ_BLOCKING);
} else {
return READ_ERROR;
return ReadResult(READ_ERROR);
}
} else {
appBytesReceived_ += bytes;
return bytes;
return ReadResult(bytes);
}
}
......@@ -1347,7 +1349,8 @@ void AsyncSocket::handleRead() noexcept {
}
// Perform the read
ssize_t bytesRead = performRead(&buf, &buflen, &offset);
auto readResult = performRead(&buf, &buflen, &offset);
auto bytesRead = readResult.readReturn;
VLOG(4) << "this=" << this << ", AsyncSocket::handleRead() got "
<< bytesRead << " bytes";
if (bytesRead > 0) {
......@@ -1376,6 +1379,9 @@ void AsyncSocket::handleRead() noexcept {
return;
} else if (bytesRead == READ_ERROR) {
readErr_ = READ_ERROR;
if (readResult.exception) {
return failRead(__func__, *readResult.exception);
}
auto errnoCopy = errno;
AsyncSocketException ex(
AsyncSocketException::INTERNAL_ERROR,
......@@ -1439,7 +1445,11 @@ void AsyncSocket::handleWrite() noexcept {
// (See the comment in handleRead() explaining how this can happen.)
EventBase* originalEventBase = eventBase_;
while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) {
if (!writeReqHead_->performWrite()) {
auto writeResult = writeReqHead_->performWrite();
if (writeResult.writeReturn < 0) {
if (writeResult.exception) {
return failWrite(__func__, *writeResult.exception);
}
auto errnoCopy = errno;
AsyncSocketException ex(
AsyncSocketException::INTERNAL_ERROR,
......@@ -1697,11 +1707,12 @@ void AsyncSocket::timeoutExpired() noexcept {
}
}
ssize_t AsyncSocket::performWrite(const iovec* vec,
uint32_t count,
WriteFlags flags,
uint32_t* countWritten,
uint32_t* partialWritten) {
AsyncSocket::WriteResult AsyncSocket::performWrite(
const iovec* vec,
uint32_t count,
WriteFlags flags,
uint32_t* countWritten,
uint32_t* partialWritten) {
// We use sendmsg() instead of writev() so that we can pass in MSG_NOSIGNAL
// We correctly handle EPIPE errors, so we never want to receive SIGPIPE
// (since it may terminate the program if the main program doesn't explicitly
......@@ -1736,12 +1747,12 @@ ssize_t AsyncSocket::performWrite(const iovec* vec,
// TCP buffer is full; we can't write any more data right now.
*countWritten = 0;
*partialWritten = 0;
return 0;
return WriteResult(0);
}
// error
*countWritten = 0;
*partialWritten = 0;
return -1;
return WriteResult(WRITE_ERROR);
}
appBytesWritten_ += totalWritten;
......@@ -1754,7 +1765,7 @@ ssize_t AsyncSocket::performWrite(const iovec* vec,
// Partial write finished in the middle of this iovec
*countWritten = n;
*partialWritten = bytesWritten;
return totalWritten;
return WriteResult(totalWritten);
}
bytesWritten -= v->iov_len;
......@@ -1763,7 +1774,7 @@ ssize_t AsyncSocket::performWrite(const iovec* vec,
assert(bytesWritten == 0);
*countWritten = n;
*partialWritten = 0;
return totalWritten;
return WriteResult(totalWritten);
}
/**
......
......@@ -16,16 +16,17 @@
#pragma once
#include <sys/types.h>
#include <sys/socket.h>
#include <folly/Optional.h>
#include <folly/SocketAddress.h>
#include <folly/io/ShutdownSocketSet.h>
#include <folly/io/IOBuf.h>
#include <folly/io/async/AsyncTimeout.h>
#include <folly/io/ShutdownSocketSet.h>
#include <folly/io/async/AsyncSocketException.h>
#include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/AsyncTransport.h>
#include <folly/io/async/EventHandler.h>
#include <folly/io/async/DelayedDestruction.h>
#include <folly/io/async/EventHandler.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <chrono>
#include <memory>
......@@ -517,6 +518,41 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
void setBufferCallback(BufferCallback* cb);
/**
* writeReturn is the total number of bytes written, or WRITE_ERROR on error.
* If no data has been written, 0 is returned.
* exception is a more specific exception that cause a write error.
* Not all writes have exceptions associated with them thus writeReturn
* should be checked to determine whether the operation resulted in an error.
*/
struct WriteResult {
explicit WriteResult(ssize_t ret) : writeReturn(ret) {}
WriteResult(ssize_t ret, std::unique_ptr<const AsyncSocketException> e)
: writeReturn(ret), exception(std::move(e)) {}
ssize_t writeReturn;
std::unique_ptr<const AsyncSocketException> exception;
};
/**
* readReturn is the number of bytes read, or READ_EOF on EOF, or
* READ_ERROR on error, or READ_BLOCKING if the operation will
* block.
* exception is a more specific exception that may have caused a read error.
* Not all read errors have exceptions associated with them thus readReturn
* should be checked to determine whether the operation resulted in an error.
*/
struct ReadResult {
explicit ReadResult(ssize_t ret) : readReturn(ret) {}
ReadResult(ssize_t ret, std::unique_ptr<const AsyncSocketException> e)
: readReturn(ret), exception(std::move(e)) {}
ssize_t readReturn;
std::unique_ptr<const AsyncSocketException> exception;
};
/**
* A WriteRequest object tracks information about a pending write operation.
*/
......@@ -529,7 +565,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
virtual void destroy() = 0;
virtual bool performWrite() = 0;
virtual WriteResult performWrite() = 0;
virtual void consume() = 0;
......@@ -579,6 +615,10 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
READ_NO_ERROR = -3,
};
enum WriteResultEnum {
WRITE_ERROR = -1,
};
/**
* Protected destructor.
*
......@@ -683,11 +723,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* @param buf The buffer to read data into.
* @param buflen The length of the buffer.
*
* @return Returns the number of bytes read, or READ_EOF on EOF, or
* READ_ERROR on error, or READ_BLOCKING if the operation will
* block.
* @return Returns a read result. See read result for details.
*/
virtual ssize_t performRead(void** buf, size_t* buflen, size_t* offset);
virtual ReadResult performRead(void** buf, size_t* buflen, size_t* offset);
/**
* Populate an iovec array from an IOBuf and attempt to write it.
......@@ -736,12 +774,14 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* will contain the number of bytes written in the
* partially written iovec entry.
*
* @return Returns the total number of bytes written, or -1 on error. If no
* data can be written immediately, 0 is returned.
* @return Returns a WriteResult. See WriteResult for more details.
*/
virtual ssize_t performWrite(const iovec* vec, uint32_t count,
WriteFlags flags, uint32_t* countWritten,
uint32_t* partialWritten);
virtual WriteResult performWrite(
const iovec* vec,
uint32_t count,
WriteFlags flags,
uint32_t* countWritten,
uint32_t* partialWritten);
bool updateEventRegistration();
......
/*
* Copyright 2016 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/io/async/ssl/SSLErrors.h>
#include <folly/Range.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
using namespace folly;
namespace {
std::string decodeOpenSSLError(
int sslError,
unsigned long errError,
int sslOperationReturnValue) {
if (sslError == SSL_ERROR_SYSCALL && errError == 0) {
if (sslOperationReturnValue == 0) {
return "SSL_ERROR_SYSCALL: EOF";
} else {
// In this case errno is set, AsyncSocketException will add it.
return "SSL_ERROR_SYSCALL";
}
} else if (sslError == SSL_ERROR_ZERO_RETURN) {
// This signifies a TLS closure alert.
return "SSL_ERROR_ZERO_RETURN";
} else {
std::array<char, 256> buf;
std::string msg(ERR_error_string(errError, buf.data()));
return msg;
}
}
const StringPiece getSSLErrorString(SSLError error) {
StringPiece ret;
switch (error) {
case SSLError::CLIENT_RENEGOTIATION:
ret = "Client tried to renegotiate with server";
break;
case SSLError::INVALID_RENEGOTIATION:
ret = "Attempt to start renegotiation, but unsupported";
break;
case SSLError::EARLY_WRITE:
ret = "Attempt to write before SSL connection established";
break;
case SSLError::OPENSSL_ERR:
// decodeOpenSSLError should be used for this type.
ret = "OPENSSL error";
break;
}
return ret;
}
}
namespace folly {
SSLException::SSLException(
int sslError,
unsigned long errError,
int sslOperationReturnValue,
int errno_copy)
: AsyncSocketException(
AsyncSocketException::SSL_ERROR,
decodeOpenSSLError(sslError, errError, sslOperationReturnValue),
sslError == SSL_ERROR_SYSCALL ? errno_copy : 0),
sslError(SSLError::OPENSSL_ERR),
opensslSSLError(sslError),
opensslErr(errError) {}
SSLException::SSLException(SSLError error)
: AsyncSocketException(
AsyncSocketException::SSL_ERROR,
getSSLErrorString(error).str(),
0),
sslError(error) {}
}
/*
* Copyright 2016 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <folly/Optional.h>
#include <folly/io/async/AsyncSocketException.h>
namespace folly {
enum class SSLError {
CLIENT_RENEGOTIATION, // A client tried to renegotiate with this server
INVALID_RENEGOTIATION, // We attempted to start a renegotiation.
EARLY_WRITE, // Wrote before SSL connection established.
// An openssl error type. The openssl specific methods should be used
// to find the real error type.
// This exists for compatibility until all error types can be move to proper
// errors.
OPENSSL_ERR,
};
class SSLException : public folly::AsyncSocketException {
public:
SSLException(
int sslError,
unsigned long errError,
int sslOperationReturnValue,
int errno_copy);
explicit SSLException(SSLError error);
SSLError getType() const {
return sslError;
}
// These methods exist for compatibility until there are proper exceptions
// for all ssl error types.
int getOpensslSSLError() const {
return opensslSSLError;
}
unsigned long getOpensslErr() const {
return opensslErr;
}
private:
SSLError sslError;
int opensslSSLError;
unsigned long opensslErr;
};
}
......@@ -201,13 +201,89 @@ TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
cerr << "ConnectWriteReadClose test completed" << endl;
}
/**
* Test reading after server close.
*/
TEST(AsyncSSLSocketTest, ReadAfterClose) {
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadEOFCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
// Set up SSL context.
auto sslContext = std::make_shared<SSLContext>();
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
auto socket =
std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
socket->open();
// This should trigger an EOF on the client.
auto evb = handshakeCallback.getSocket()->getEventBase();
evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
std::array<uint8_t, 128> readbuf;
auto bytesRead = socket->read(readbuf.data(), readbuf.size());
EXPECT_EQ(0, bytesRead);
}
/**
* Test bad renegotiation
*/
TEST(AsyncSSLSocketTest, Renegotiate) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto dfServerCtx = std::make_shared<SSLContext>();
std::array<int, 2> fds;
getfds(fds.data());
getctx(clientCtx, dfServerCtx);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), true, true);
RenegotiatingServer server(std::move(serverSock));
while (!client.handshakeSuccess_ && !client.handshakeError_) {
eventBase.loopOnce();
}
ASSERT_TRUE(client.handshakeSuccess_);
auto sslSock = std::move(client).moveSocket();
sslSock->detachEventBase();
// This is nasty, however we don't want to add support for
// renegotiation in AsyncSSLSocket.
SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
std::thread t([&]() { eventBase.loopForever(); });
// Trigger the renegotiation.
std::array<uint8_t, 128> buf;
memset(buf.data(), 'a', buf.size());
try {
socket->write(buf.data(), buf.size());
} catch (AsyncSocketException& e) {
LOG(INFO) << "client got error " << e.what();
}
eventBase.terminateLoopSoon();
t.join();
eventBase.loop();
ASSERT_TRUE(server.renegotiationError_);
}
/**
* Negative test for handshakeError().
*/
TEST(AsyncSSLSocketTest, HandshakeError) {
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
WriteErrorCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
HandshakeErrorCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
......
......@@ -18,13 +18,15 @@
#include <signal.h>
#include <pthread.h>
#include <folly/io/async/AsyncServerSocket.h>
#include <folly/ExceptionWrapper.h>
#include <folly/SocketAddress.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/AsyncServerSocket.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/AsyncTransport.h>
#include <folly/io/async/EventBase.h>
#include <folly/io/async/AsyncTimeout.h>
#include <folly/SocketAddress.h>
#include <folly/io/async/ssl/SSLErrors.h>
#include <gtest/gtest.h>
#include <iostream>
......@@ -58,7 +60,7 @@ public:
, exception(AsyncSocketException::UNKNOWN, "none") {}
~WriteCallbackBase() {
EXPECT_EQ(state, STATE_SUCCEEDED);
EXPECT_EQ(STATE_SUCCEEDED, state);
}
void setSocket(
......@@ -92,10 +94,9 @@ public:
class ReadCallbackBase :
public AsyncTransportWrapper::ReadCallback {
public:
explicit ReadCallbackBase(WriteCallbackBase *wcb)
: wcb_(wcb)
, state(STATE_WAITING) {}
public:
explicit ReadCallbackBase(WriteCallbackBase* wcb)
: wcb_(wcb), state(STATE_WAITING) {}
~ReadCallbackBase() {
EXPECT_EQ(state, STATE_SUCCEEDED);
......@@ -222,6 +223,27 @@ public:
}
};
class ReadEOFCallback : public ReadCallbackBase {
public:
explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
// Return nullptr buffer to trigger readError()
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
*bufReturn = nullptr;
*lenReturn = 0;
}
void readDataAvailable(size_t /* len */) noexcept override {
// This should never to called.
FAIL();
}
void readEOF() noexcept override {
ReadCallbackBase::readEOF();
setState(STATE_SUCCEEDED);
}
};
class WriteErrorCallback : public ReadCallback {
public:
explicit WriteErrorCallback(WriteCallbackBase *wcb)
......@@ -340,6 +362,10 @@ public:
state = STATE_SUCCEEDED;
}
std::shared_ptr<AsyncSSLSocket> getSocket() {
return socket_;
}
StateEnum state;
std::shared_ptr<AsyncSSLSocket> socket_;
ReadCallbackBase *rcb_;
......@@ -879,6 +905,48 @@ class NpnServer :
AsyncSSLSocket::UniquePtr socket_;
};
class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
public AsyncTransportWrapper::ReadCallback {
public:
explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
: socket_(std::move(socket)) {
socket_->sslAccept(this);
}
~RenegotiatingServer() {
socket_->setReadCB(nullptr);
}
void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
LOG(INFO) << "Renegotiating server handshake success";
socket_->setReadCB(this);
}
void handshakeErr(
AsyncSSLSocket*,
const AsyncSocketException& ex) noexcept override {
ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
}
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
*lenReturn = sizeof(buf);
*bufReturn = buf;
}
void readDataAvailable(size_t /* len */) noexcept override {}
void readEOF() noexcept override {}
void readErr(const AsyncSocketException& ex) noexcept override {
LOG(INFO) << "server got read error " << ex.what();
auto exPtr = dynamic_cast<const SSLException*>(&ex);
ASSERT_NE(nullptr, exPtr);
std::string exStr(ex.what());
SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
renegotiationError_ = true;
}
AsyncSSLSocket::UniquePtr socket_;
unsigned char buf[128];
bool renegotiationError_{false};
};
#ifndef OPENSSL_NO_TLSEXT
class SNIClient :
private AsyncSSLSocket::HandshakeCB,
......@@ -1139,6 +1207,10 @@ class SSLHandshakeBase :
verifyResult_(verifyResult) {
}
AsyncSSLSocket::UniquePtr moveSocket() && {
return std::move(socket_);
}
bool handshakeVerify_;
bool handshakeSuccess_;
bool handshakeError_;
......@@ -1160,12 +1232,15 @@ class SSLHandshakeBase :
}
void handshakeSuc(AsyncSSLSocket*) noexcept override {
LOG(INFO) << "Handshake success";
handshakeSuccess_ = true;
handshakeTime = socket_->getHandshakeTime();
}
void handshakeErr(AsyncSSLSocket*,
const AsyncSocketException& /* ex */) noexcept override {
void handshakeErr(
AsyncSSLSocket*,
const AsyncSocketException& ex) noexcept override {
LOG(INFO) << "Handshake error " << ex.what();
handshakeError_ = true;
handshakeTime = socket_->getHandshakeTime();
}
......
......@@ -58,8 +58,12 @@ class MockAsyncSSLSocket : public AsyncSSLSocket{
MOCK_CONST_METHOD0(getRawBytesWritten, size_t());
// public wrapper for protected interface
ssize_t testPerformWrite(const iovec* vec, uint32_t count, WriteFlags flags,
uint32_t* countWritten, uint32_t* partialWritten) {
WriteResult testPerformWrite(
const iovec* vec,
uint32_t count,
WriteFlags flags,
uint32_t* countWritten,
uint32_t* partialWritten) {
return performWrite(vec, count, flags, countWritten, partialWritten);
}
......
......@@ -35,6 +35,11 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
new folly::AsyncSocket(&eventBase_)),
address_(address) {}
explicit BlockingSocket(folly::AsyncSocket::UniquePtr socket)
: sock_(std::move(socket)) {
sock_->attachEventBase(&eventBase_);
}
void open() {
sock_->connect(this, address_);
eventBase_.loop();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment