Commit 564a32e6 authored by Dan Melnic's avatar Dan Melnic Committed by Facebook Github Bot

Add support for setting a zerocopy enable function

Summary: Add support for setting a zerocopy enable function

Reviewed By: kevin-vigor

Differential Revision: D18579527

fbshipit-source-id: dc3ab7bb13b26528bc964c7eb616517e444487a4
parent c7dd6097
...@@ -332,11 +332,18 @@ void AsyncServerSocket::bindSocket( ...@@ -332,11 +332,18 @@ void AsyncServerSocket::bindSocket(
bool AsyncServerSocket::setZeroCopy(bool enable) { bool AsyncServerSocket::setZeroCopy(bool enable) {
if (msgErrQueueSupported) { if (msgErrQueueSupported) {
// save the enable flag here
zeroCopyVal_ = enable;
int val = enable ? 1 : 0; int val = enable ? 1 : 0;
int ret = netops::setsockopt( size_t num = 0;
getNetworkSocket(), SOL_SOCKET, SO_ZEROCOPY, &val, sizeof(val)); for (auto& s : sockets_) {
int ret = netops::setsockopt(
s.socket_, SOL_SOCKET, SO_ZEROCOPY, &val, sizeof(val));
num += (0 == ret) ? 1 : 0;
}
return (0 == ret); return num != 0;
} }
return false; return false;
...@@ -822,6 +829,16 @@ void AsyncServerSocket::setupSocket(NetworkSocket fd, int family) { ...@@ -822,6 +829,16 @@ void AsyncServerSocket::setupSocket(NetworkSocket fd, int family) {
} }
#endif #endif
if (zeroCopyVal_) {
int val = 1;
int ret =
netops::setsockopt(fd, SOL_SOCKET, SO_ZEROCOPY, &val, sizeof(val));
if (ret) {
LOG(WARNING) << "failed to set SO_ZEROCOPY on async server socket: "
<< folly::errnoStr(errno);
}
}
if (const auto shutdownSocketSet = wShutdownSocketSet_.lock()) { if (const auto shutdownSocketSet = wShutdownSocketSet_.lock()) {
shutdownSocketSet->add(fd); shutdownSocketSet->add(fd);
} }
......
...@@ -895,6 +895,7 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -895,6 +895,7 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
std::weak_ptr<ShutdownSocketSet> wShutdownSocketSet_; std::weak_ptr<ShutdownSocketSet> wShutdownSocketSet_;
ConnectionEventCallback* connectionEventCallback_{nullptr}; ConnectionEventCallback* connectionEventCallback_{nullptr};
bool tosReflect_{false}; bool tosReflect_{false};
bool zeroCopyVal_{false};
}; };
} // namespace folly } // namespace folly
...@@ -305,8 +305,10 @@ AsyncSocket::AsyncSocket(EventBase* evb) ...@@ -305,8 +305,10 @@ AsyncSocket::AsyncSocket(EventBase* evb)
AsyncSocket::AsyncSocket( AsyncSocket::AsyncSocket(
EventBase* evb, EventBase* evb,
const folly::SocketAddress& address, const folly::SocketAddress& address,
uint32_t connectTimeout) uint32_t connectTimeout,
bool useZeroCopy)
: AsyncSocket(evb) { : AsyncSocket(evb) {
setZeroCopy(useZeroCopy);
connect(nullptr, address, connectTimeout); connect(nullptr, address, connectTimeout);
} }
...@@ -314,8 +316,10 @@ AsyncSocket::AsyncSocket( ...@@ -314,8 +316,10 @@ AsyncSocket::AsyncSocket(
EventBase* evb, EventBase* evb,
const std::string& ip, const std::string& ip,
uint16_t port, uint16_t port,
uint32_t connectTimeout) uint32_t connectTimeout,
bool useZeroCopy)
: AsyncSocket(evb) { : AsyncSocket(evb) {
setZeroCopy(useZeroCopy);
connect(nullptr, ip, port, connectTimeout); connect(nullptr, ip, port, connectTimeout);
} }
...@@ -911,6 +915,10 @@ bool AsyncSocket::setZeroCopy(bool enable) { ...@@ -911,6 +915,10 @@ bool AsyncSocket::setZeroCopy(bool enable) {
return false; return false;
} }
void AsyncSocket::setZeroCopyEnableFunc(AsyncWriter::ZeroCopyEnableFunc func) {
zeroCopyEnableFunc_ = func;
}
void AsyncSocket::setZeroCopyReenableThreshold(size_t threshold) { void AsyncSocket::setZeroCopyReenableThreshold(size_t threshold) {
zeroCopyReenableThreshold_ = threshold; zeroCopyReenableThreshold_ = threshold;
} }
...@@ -1042,6 +1050,12 @@ void AsyncSocket::writeChain( ...@@ -1042,6 +1050,12 @@ void AsyncSocket::writeChain(
WriteFlags flags) { WriteFlags flags) {
adjustZeroCopyFlags(flags); adjustZeroCopyFlags(flags);
// adjustZeroCopyFlags can set zeroCopyEnabled_ to true
if (zeroCopyEnabled_ && !isSet(flags, WriteFlags::WRITE_MSG_ZEROCOPY) &&
zeroCopyEnableFunc_ && zeroCopyEnableFunc_(buf)) {
flags |= WriteFlags::WRITE_MSG_ZEROCOPY;
}
constexpr size_t kSmallSizeMax = 64; constexpr size_t kSmallSizeMax = 64;
size_t count = buf->countChainElements(); size_t count = buf->countChainElements();
if (count <= kSmallSizeMax) { if (count <= kSmallSizeMax) {
......
...@@ -81,7 +81,7 @@ namespace folly { ...@@ -81,7 +81,7 @@ namespace folly {
#endif #endif
class AsyncSocket : virtual public AsyncTransportWrapper { class AsyncSocket : virtual public AsyncTransportWrapper {
public: public:
typedef std::unique_ptr<AsyncSocket, Destructor> UniquePtr; using UniquePtr = std::unique_ptr<AsyncSocket, Destructor>;
class ConnectCallback { class ConnectCallback {
public: public:
...@@ -245,11 +245,13 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -245,11 +245,13 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* @param address The address to connect to. * @param address The address to connect to.
* @param connectTimeout Optional timeout in milliseconds for the connection * @param connectTimeout Optional timeout in milliseconds for the connection
* attempt. * attempt.
* @param useZeroCopy Optional zerocopy socket mode
*/ */
AsyncSocket( AsyncSocket(
EventBase* evb, EventBase* evb,
const folly::SocketAddress& address, const folly::SocketAddress& address,
uint32_t connectTimeout = 0); uint32_t connectTimeout = 0,
bool useZeroCopy = false);
/** /**
* Create a new AsyncSocket and begin the connection process. * Create a new AsyncSocket and begin the connection process.
...@@ -259,12 +261,14 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -259,12 +261,14 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* @param port Destination port in host byte order. * @param port Destination port in host byte order.
* @param connectTimeout Optional timeout in milliseconds for the connection * @param connectTimeout Optional timeout in milliseconds for the connection
* attempt. * attempt.
* @param useZeroCopy Optional zerocopy socket mode
*/ */
AsyncSocket( AsyncSocket(
EventBase* evb, EventBase* evb,
const std::string& ip, const std::string& ip,
uint16_t port, uint16_t port,
uint32_t connectTimeout = 0); uint32_t connectTimeout = 0,
bool useZeroCopy = false);
/** /**
* Create a AsyncSocket from an already connected socket file descriptor. * Create a AsyncSocket from an already connected socket file descriptor.
...@@ -305,9 +309,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -305,9 +309,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
static std::shared_ptr<AsyncSocket> newSocket( static std::shared_ptr<AsyncSocket> newSocket(
EventBase* evb, EventBase* evb,
const folly::SocketAddress& address, const folly::SocketAddress& address,
uint32_t connectTimeout = 0) { uint32_t connectTimeout = 0,
bool useZeroCopy = false) {
return std::shared_ptr<AsyncSocket>( return std::shared_ptr<AsyncSocket>(
new AsyncSocket(evb, address, connectTimeout), Destructor()); new AsyncSocket(evb, address, connectTimeout, useZeroCopy),
Destructor());
} }
/** /**
...@@ -317,9 +323,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -317,9 +323,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
EventBase* evb, EventBase* evb,
const std::string& ip, const std::string& ip,
uint16_t port, uint16_t port,
uint32_t connectTimeout = 0) { uint32_t connectTimeout = 0,
bool useZeroCopy = false) {
return std::shared_ptr<AsyncSocket>( return std::shared_ptr<AsyncSocket>(
new AsyncSocket(evb, ip, port, connectTimeout), Destructor()); new AsyncSocket(evb, ip, port, connectTimeout, useZeroCopy),
Destructor());
} }
/** /**
...@@ -392,7 +400,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -392,7 +400,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
}; };
// Maps from a socket option key to its value // Maps from a socket option key to its value
typedef std::map<OptionKey, int> OptionMap; using OptionMap = std::map<OptionKey, int>;
static const OptionMap emptyOptionMap; static const OptionMap emptyOptionMap;
static const folly::SocketAddress& anyAddress(); static const folly::SocketAddress& anyAddress();
...@@ -519,8 +527,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -519,8 +527,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
void setReadCB(ReadCallback* callback) override; void setReadCB(ReadCallback* callback) override;
ReadCallback* getReadCallback() const override; ReadCallback* getReadCallback() const override;
bool setZeroCopy(bool enable); bool setZeroCopy(bool enable) override;
bool getZeroCopy() const { bool getZeroCopy() const override {
return zeroCopyEnabled_; return zeroCopyEnabled_;
} }
...@@ -532,6 +540,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -532,6 +540,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
return zeroCopyReenableThreshold_; return zeroCopyReenableThreshold_;
} }
void setZeroCopyEnableFunc(AsyncWriter::ZeroCopyEnableFunc func) override;
void setZeroCopyReenableThreshold(size_t threshold); void setZeroCopyReenableThreshold(size_t threshold);
void write( void write(
...@@ -1258,6 +1268,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -1258,6 +1268,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
bool containsZeroCopyBuf(folly::IOBuf* ptr); bool containsZeroCopyBuf(folly::IOBuf* ptr);
void releaseZeroCopyBuf(uint32_t id); void releaseZeroCopyBuf(uint32_t id);
AsyncWriter::ZeroCopyEnableFunc zeroCopyEnableFunc_;
// a folly::IOBuf can be used in multiple partial requests // a folly::IOBuf can be used in multiple partial requests
// there is a that maps a buffer id to a raw folly::IOBuf ptr // there is a that maps a buffer id to a raw folly::IOBuf ptr
// and another one that adds a ref count for a folly::IOBuf that is either // and another one that adds a ref count for a folly::IOBuf that is either
......
...@@ -719,6 +719,21 @@ class AsyncWriter { ...@@ -719,6 +719,21 @@ class AsyncWriter {
std::unique_ptr<IOBuf>&& buf, std::unique_ptr<IOBuf>&& buf,
WriteFlags flags = WriteFlags::NONE) = 0; WriteFlags flags = WriteFlags::NONE) = 0;
/** zero copy related
* */
virtual bool setZeroCopy(bool /*enable*/) {
return false;
}
virtual bool getZeroCopy() const {
return false;
}
using ZeroCopyEnableFunc =
std::function<bool(const std::unique_ptr<folly::IOBuf>& buf)>;
virtual void setZeroCopyEnableFunc(ZeroCopyEnableFunc /*func*/) {}
protected: protected:
virtual ~AsyncWriter() = default; virtual ~AsyncWriter() = default;
}; };
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment