Commit 220215e9 authored by Brandon Schlinker's avatar Brandon Schlinker Committed by Facebook GitHub Bot

netops::Dispatcher

Summary:
Wrapper around `folly::netops` methods that makes it easier to mock these methods in unit tests.

When we want to mock out calls to `folly::netops` we currently:
- Use methods like `getSockOptVirtual` and `setSockOptVirtual`
- Mock part of the socket, like in the tests in `AsyncSSLSocketWriteTest`

I've the latter makes the tests particularly error prone, since we're mocking the object that's also under test.

This change introduces `netops::Dispatcher`, which is a class containing all of the functions in `folly::netops`:
- By default `AsyncSocket` uses a default, static instance of `Dispatcher` that forwards calls to the original `netops::` calls (e.g., calling `netops::Dispatcher::sendmsg` results in a call to `netops::sendmsg`.
- When a test wants to mock a a `folly::netops` call, it can call `setOverrideNetOpsDispatcher` to insert a mock `netops::Dispatcher`. I use it in this manner in D24094832

Differential Revision: D24661160

fbshipit-source-id: e9cb4ed28ffe409c74998a1c9501c0706fc853e0
parent 8ea0a28b
......@@ -476,7 +476,7 @@ void AsyncSocket::setShutdownSocketSet(
}
void AsyncSocket::setCloseOnExec() {
int rv = netops::set_socket_close_on_exec(fd_);
int rv = netops_->set_socket_close_on_exec(fd_);
if (rv != 0) {
auto errnoCopy = errno;
throw AsyncSocketException(
......@@ -520,7 +520,7 @@ void AsyncSocket::connect(
// constant (PF_xxx) rather than an address family (AF_xxx), but the
// distinction is mainly just historical. In pretty much all
// implementations the PF_foo and AF_foo constants are identical.
fd_ = netops::socket(address.getFamily(), SOCK_STREAM, 0);
fd_ = netops_->socket(address.getFamily(), SOCK_STREAM, 0);
if (fd_ == NetworkSocket()) {
auto errnoCopy = errno;
throw AsyncSocketException(
......@@ -537,7 +537,7 @@ void AsyncSocket::connect(
setCloseOnExec();
// Put the socket in non-blocking mode
int rv = netops::set_socket_non_blocking(fd_);
int rv = netops_->set_socket_non_blocking(fd_);
if (rv == -1) {
auto errnoCopy = errno;
throw AsyncSocketException(
......@@ -577,7 +577,7 @@ void AsyncSocket::connect(
// bind the socket
if (bindAddr != anyAddress()) {
int one = 1;
if (netops::setsockopt(
if (netops_->setsockopt(
fd_, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))) {
auto errnoCopy = errno;
doClose();
......@@ -589,7 +589,7 @@ void AsyncSocket::connect(
bindAddr.getAddress(&addrStorage);
if (netops::bind(fd_, saddr, bindAddr.getActualSize()) != 0) {
if (netops_->bind(fd_, saddr, bindAddr.getActualSize()) != 0) {
auto errnoCopy = errno;
doClose();
throw AsyncSocketException(
......@@ -649,7 +649,7 @@ void AsyncSocket::connect(
}
int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
int rv = netops::connect(fd_, saddr, len);
int rv = netops_->connect(fd_, saddr, len);
if (rv < 0) {
auto errnoCopy = errno;
if (errnoCopy == EINPROGRESS) {
......@@ -931,7 +931,7 @@ bool AsyncSocket::setZeroCopy(bool enable) {
int val = enable ? 1 : 0;
int ret =
netops::setsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, sizeof(val));
netops_->setsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, sizeof(val));
// if enable == false, set zeroCopyEnabled_ = false regardless
// if SO_ZEROCOPY is set or not
......@@ -946,7 +946,7 @@ bool AsyncSocket::setZeroCopy(bool enable) {
if (ret) {
val = 0;
socklen_t optlen = sizeof(val);
ret = netops::getsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, &optlen);
ret = netops_->getsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, &optlen);
if (!ret) {
enable = val != 0;
......@@ -1492,7 +1492,7 @@ void AsyncSocket::shutdownWriteNow() {
}
// Shutdown writes on the file descriptor
netops::shutdown(fd_, SHUT_WR);
netops_->shutdown(fd_, SHUT_WR);
// Immediately fail all write requests
failAllWrites(getSocketShutdownForWritesEx());
......@@ -1546,7 +1546,7 @@ bool AsyncSocket::readable() const {
fds[0].fd = fd_;
fds[0].events = POLLIN;
fds[0].revents = 0;
int rc = netops::poll(fds, 1, 0);
int rc = netops_->poll(fds, 1, 0);
return rc == 1;
}
......@@ -1558,7 +1558,7 @@ bool AsyncSocket::writable() const {
fds[0].fd = fd_;
fds[0].events = POLLOUT;
fds[0].revents = 0;
int rc = netops::poll(fds, 1, 0);
int rc = netops_->poll(fds, 1, 0);
return rc == 1;
}
......@@ -1577,7 +1577,7 @@ bool AsyncSocket::hangup() const {
fds[0].fd = fd_;
fds[0].events = POLLRDHUP | POLLHUP;
fds[0].revents = 0;
netops::poll(fds, 1, 0);
netops_->poll(fds, 1, 0);
return (fds[0].revents & (POLLRDHUP | POLLHUP)) != 0;
#else
return false;
......@@ -1713,7 +1713,7 @@ int AsyncSocket::setNoDelay(bool noDelay) {
}
int value = noDelay ? 1 : 0;
if (netops::setsockopt(
if (netops_->setsockopt(
fd_, IPPROTO_TCP, TCP_NODELAY, &value, sizeof(value)) != 0) {
int errnoCopy = errno;
VLOG(2) << "failed to update TCP_NODELAY option on AsyncSocket " << this
......@@ -1736,7 +1736,7 @@ int AsyncSocket::setCongestionFlavor(const std::string& cname) {
return EINVAL;
}
if (netops::setsockopt(
if (netops_->setsockopt(
fd_,
IPPROTO_TCP,
TCP_CONGESTION,
......@@ -1762,7 +1762,7 @@ int AsyncSocket::setQuickAck(bool quickack) {
#ifdef TCP_QUICKACK // Linux-only
int value = quickack ? 1 : 0;
if (netops::setsockopt(
if (netops_->setsockopt(
fd_, IPPROTO_TCP, TCP_QUICKACK, &value, sizeof(value)) != 0) {
int errnoCopy = errno;
VLOG(2) << "failed to update TCP_QUICKACK option on AsyncSocket" << this
......@@ -1784,7 +1784,7 @@ int AsyncSocket::setSendBufSize(size_t bufsize) {
return EINVAL;
}
if (netops::setsockopt(
if (netops_->setsockopt(
fd_, SOL_SOCKET, SO_SNDBUF, &bufsize, sizeof(bufsize)) != 0) {
int errnoCopy = errno;
VLOG(2) << "failed to update SO_SNDBUF option on AsyncSocket" << this
......@@ -1803,7 +1803,7 @@ int AsyncSocket::setRecvBufSize(size_t bufsize) {
return EINVAL;
}
if (netops::setsockopt(
if (netops_->setsockopt(
fd_, SOL_SOCKET, SO_RCVBUF, &bufsize, sizeof(bufsize)) != 0) {
int errnoCopy = errno;
VLOG(2) << "failed to update SO_RCVBUF option on AsyncSocket" << this
......@@ -1870,7 +1870,7 @@ int AsyncSocket::setTCPProfile(int profd) {
return EINVAL;
}
if (netops::setsockopt(
if (netops_->setsockopt(
fd_, SOL_SOCKET, SO_SET_NAMESPACE, &profd, sizeof(int)) != 0) {
int errnoCopy = errno;
VLOG(2) << "failed to set socket namespace option on AsyncSocket" << this
......@@ -1957,7 +1957,7 @@ AsyncSocket::ReadResult AsyncSocket::performRead(
// No callback to read ancillary data was set
if (readAncillaryDataCallback_ == nullptr) {
bytes = netops::recv(fd_, *buf, *buflen, MSG_DONTWAIT);
bytes = netops_->recv(fd_, *buf, *buflen, MSG_DONTWAIT);
} else {
struct msghdr msg;
struct iovec iov;
......@@ -2037,7 +2037,7 @@ size_t AsyncSocket::handleErrMessages() noexcept {
size_t num = 0;
// the socket may be closed by errMessage callback, so check on each iteration
while (fd_ != NetworkSocket()) {
ret = netops::recvmsg(fd_, &msg, MSG_ERRQUEUE);
ret = netops_->recvmsg(fd_, &msg, MSG_ERRQUEUE);
VLOG(5) << "AsyncSocket::handleErrMessages(): recvmsg returned " << ret;
if (ret < 0) {
......@@ -2325,7 +2325,7 @@ void AsyncSocket::handleWrite() noexcept {
}
} else {
// Reads are still enabled, so we are only doing a half-shutdown
netops::shutdown(fd_, SHUT_WR);
netops_->shutdown(fd_, SHUT_WR);
}
}
}
......@@ -2450,7 +2450,7 @@ void AsyncSocket::handleConnect() noexcept {
// Call getsockopt() to check if the connect succeeded
int error;
socklen_t len = sizeof(error);
int rv = netops::getsockopt(fd_, SOL_SOCKET, SO_ERROR, &error, &len);
int rv = netops_->getsockopt(fd_, SOL_SOCKET, SO_ERROR, &error, &len);
if (rv != 0) {
auto errnoCopy = errno;
AsyncSocketException ex(
......@@ -2480,7 +2480,7 @@ void AsyncSocket::handleConnect() noexcept {
// are still connecting we just abort the connect rather than waiting for
// it to complete.
assert((shutdownFlags_ & SHUT_READ) == 0);
netops::shutdown(fd_, SHUT_WR);
netops_->shutdown(fd_, SHUT_WR);
shutdownFlags_ |= SHUT_WRITE;
}
......@@ -2646,7 +2646,7 @@ AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
AsyncSocketException::UNKNOWN, "No more free local ports"));
}
} else {
totalWritten = netops::sendmsg(fd, msg, msg_flags);
totalWritten = netops_->sendmsg(fd, msg, msg_flags);
}
return WriteResult(totalWritten);
}
......@@ -3050,7 +3050,7 @@ void AsyncSocket::doClose() {
if (const auto shutdownSocketSet = wShutdownSocketSet_.lock()) {
shutdownSocketSet->close(fd_);
} else {
netops::close(fd_);
netops_->close(fd_);
}
fd_ = NetworkSocket();
......
......@@ -33,6 +33,7 @@
#include <folly/io/async/AsyncTransport.h>
#include <folly/io/async/DelayedDestruction.h>
#include <folly/io/async/EventHandler.h>
#include <folly/net/NetOpsDispatcher.h>
#include <folly/portability/Sockets.h>
#include <folly/small_vector.h>
......@@ -551,6 +552,28 @@ class AsyncSocket : public AsyncTransport {
*/
virtual SendMsgParamsCallback* getSendMsgParamsCB() const;
/**
* Override netops::Dispatcher to be used for netops:: calls.
*
* Pass empty shared_ptr to reset to default.
* Override can be used by unit tests to intercept and mock netops:: calls.
*/
virtual void setOverrideNetOpsDispatcher(
std::shared_ptr<netops::Dispatcher> dispatcher) {
netops_.setOverride(std::move(dispatcher));
}
/**
* Returns override netops::Dispatcher being used for netops:: calls.
*
* Returns empty shared_ptr if no override set.
* Override can be used by unit tests to intercept and mock netops:: calls.
*/
virtual std::shared_ptr<netops::Dispatcher> getOverrideNetOpsDispatcher()
const {
return netops_.getOverride();
}
// Read and write methods
void setReadCB(ReadCallback* callback) override;
ReadCallback* getReadCallback() const override;
......@@ -777,7 +800,7 @@ class AsyncSocket : public AsyncTransport {
*/
template <typename T>
int getSockOpt(int level, int optname, T* optval, socklen_t* optlen) {
return netops::getsockopt(fd_, level, optname, (void*)optval, optlen);
return netops_->getsockopt(fd_, level, optname, (void*)optval, optlen);
}
/**
......@@ -790,7 +813,7 @@ class AsyncSocket : public AsyncTransport {
*/
template <typename T>
int setSockOpt(int level, int optname, const T* optval) {
return netops::setsockopt(fd_, level, optname, optval, sizeof(T));
return netops_->setsockopt(fd_, level, optname, optval, sizeof(T));
}
/**
......@@ -806,7 +829,7 @@ class AsyncSocket : public AsyncTransport {
*/
virtual int getSockOptVirtual(
int level, int optname, void* optval, socklen_t* optlen) {
return netops::getsockopt(fd_, level, optname, optval, optlen);
return netops_->getsockopt(fd_, level, optname, optval, optlen);
}
/**
......@@ -822,7 +845,7 @@ class AsyncSocket : public AsyncTransport {
*/
virtual int setSockOptVirtual(
int level, int optname, void const* optval, socklen_t optlen) {
return netops::setsockopt(fd_, level, optname, optval, optlen);
return netops_->setsockopt(fd_, level, optname, optval, optlen);
}
/**
......@@ -1467,6 +1490,8 @@ class AsyncSocket : public AsyncTransport {
nullptr};
bool closeOnFailedWrite_{true};
netops::DispatcherContainer netops_;
};
} // namespace folly
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/net/NetOps.h>
#include <folly/net/NetOpsDispatcher.h>
namespace folly {
namespace netops {
Dispatcher* Dispatcher::getDefaultInstance() {
static Dispatcher wrapper = {};
return &wrapper;
}
NetworkSocket Dispatcher::accept(
NetworkSocket s, sockaddr* addr, socklen_t* addrlen) {
return folly::netops::accept(s, addr, addrlen);
}
int Dispatcher::bind(NetworkSocket s, const sockaddr* name, socklen_t namelen) {
return folly::netops::bind(s, name, namelen);
}
int Dispatcher::close(NetworkSocket s) {
return folly::netops::close(s);
}
int Dispatcher::connect(
NetworkSocket s, const sockaddr* name, socklen_t namelen) {
return folly::netops::connect(s, name, namelen);
}
int Dispatcher::getpeername(
NetworkSocket s, sockaddr* name, socklen_t* namelen) {
return folly::netops::getpeername(s, name, namelen);
}
int Dispatcher::getsockname(
NetworkSocket s, sockaddr* name, socklen_t* namelen) {
return folly::netops::getsockname(s, name, namelen);
}
int Dispatcher::getsockopt(
NetworkSocket s, int level, int optname, void* optval, socklen_t* optlen) {
return folly::netops::getsockopt(s, level, optname, optval, optlen);
}
int Dispatcher::inet_aton(const char* cp, in_addr* inp) {
return folly::netops::inet_aton(cp, inp);
}
int Dispatcher::listen(NetworkSocket s, int backlog) {
return folly::netops::listen(s, backlog);
}
int Dispatcher::poll(PollDescriptor fds[], nfds_t nfds, int timeout) {
return folly::netops::poll(fds, nfds, timeout);
}
ssize_t Dispatcher::recv(NetworkSocket s, void* buf, size_t len, int flags) {
return folly::netops::recv(s, buf, len, flags);
}
ssize_t Dispatcher::recvfrom(
NetworkSocket s,
void* buf,
size_t len,
int flags,
sockaddr* from,
socklen_t* fromlen) {
return folly::netops::recvfrom(s, buf, len, flags, from, fromlen);
}
ssize_t Dispatcher::recvmsg(NetworkSocket s, msghdr* message, int flags) {
return folly::netops::recvmsg(s, message, flags);
}
int Dispatcher::recvmmsg(
NetworkSocket s,
mmsghdr* msgvec,
unsigned int vlen,
unsigned int flags,
timespec* timeout) {
return folly::netops::recvmmsg(s, msgvec, vlen, flags, timeout);
}
ssize_t Dispatcher::send(
NetworkSocket s, const void* buf, size_t len, int flags) {
return folly::netops::send(s, buf, len, flags);
}
ssize_t Dispatcher::sendmsg(
NetworkSocket socket, const msghdr* message, int flags) {
return folly::netops::sendmsg(socket, message, flags);
}
int Dispatcher::sendmmsg(
NetworkSocket socket, mmsghdr* msgvec, unsigned int vlen, int flags) {
return folly::netops::sendmmsg(socket, msgvec, vlen, flags);
}
ssize_t Dispatcher::sendto(
NetworkSocket s,
const void* buf,
size_t len,
int flags,
const sockaddr* to,
socklen_t tolen) {
return folly::netops::sendto(s, buf, len, flags, to, tolen);
}
int Dispatcher::setsockopt(
NetworkSocket s,
int level,
int optname,
const void* optval,
socklen_t optlen) {
return folly::netops::setsockopt(s, level, optname, optval, optlen);
}
int Dispatcher::shutdown(NetworkSocket s, int how) {
return folly::netops::shutdown(s, how);
}
NetworkSocket Dispatcher::socket(int af, int type, int protocol) {
return folly::netops::socket(af, type, protocol);
}
int Dispatcher::socketpair(
int domain, int type, int protocol, NetworkSocket sv[2]) {
return folly::netops::socketpair(domain, type, protocol, sv);
}
int Dispatcher::set_socket_non_blocking(NetworkSocket s) {
return folly::netops::set_socket_non_blocking(s);
}
int Dispatcher::set_socket_close_on_exec(NetworkSocket s) {
return folly::netops::set_socket_close_on_exec(s);
}
} // namespace netops
} // namespace folly
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <folly/net/NetOps.h>
#include <memory>
namespace folly {
namespace netops {
/**
* Dispatcher for netops:: calls.
*
* Using a Dispatcher instead of calling netops:: directly enables tests to
* mock netops:: calls.
*/
class Dispatcher {
public:
static Dispatcher* getDefaultInstance();
virtual NetworkSocket accept(
NetworkSocket s, sockaddr* addr, socklen_t* addrlen);
virtual int bind(NetworkSocket s, const sockaddr* name, socklen_t namelen);
virtual int close(NetworkSocket s);
virtual int connect(NetworkSocket s, const sockaddr* name, socklen_t namelen);
virtual int getpeername(NetworkSocket s, sockaddr* name, socklen_t* namelen);
virtual int getsockname(NetworkSocket s, sockaddr* name, socklen_t* namelen);
virtual int getsockopt(
NetworkSocket s, int level, int optname, void* optval, socklen_t* optlen);
virtual int inet_aton(const char* cp, in_addr* inp);
virtual int listen(NetworkSocket s, int backlog);
virtual int poll(PollDescriptor fds[], nfds_t nfds, int timeout);
virtual ssize_t recv(NetworkSocket s, void* buf, size_t len, int flags);
virtual ssize_t recvfrom(
NetworkSocket s,
void* buf,
size_t len,
int flags,
sockaddr* from,
socklen_t* fromlen);
virtual ssize_t recvmsg(NetworkSocket s, msghdr* message, int flags);
virtual int recvmmsg(
NetworkSocket s,
mmsghdr* msgvec,
unsigned int vlen,
unsigned int flags,
timespec* timeout);
virtual ssize_t send(NetworkSocket s, const void* buf, size_t len, int flags);
virtual ssize_t sendto(
NetworkSocket s,
const void* buf,
size_t len,
int flags,
const sockaddr* to,
socklen_t tolen);
virtual ssize_t sendmsg(
NetworkSocket socket, const msghdr* message, int flags);
virtual int sendmmsg(
NetworkSocket socket, mmsghdr* msgvec, unsigned int vlen, int flags);
virtual int setsockopt(
NetworkSocket s,
int level,
int optname,
const void* optval,
socklen_t optlen);
virtual int shutdown(NetworkSocket s, int how);
virtual NetworkSocket socket(int af, int type, int protocol);
virtual int socketpair(
int domain, int type, int protocol, NetworkSocket sv[2]);
virtual int set_socket_non_blocking(NetworkSocket s);
virtual int set_socket_close_on_exec(NetworkSocket s);
protected:
Dispatcher() = default;
virtual ~Dispatcher() = default;
};
/**
* Container for netops::Dispatcher.
*
* Enables override Dispatcher to be installed for tests and special cases.
* If no override installed, returns default Dispatcher instance.
*/
class DispatcherContainer {
public:
/**
* Returns Dispatcher.
*
* If no override installed, returns default Dispatcher instance.
*/
netops::Dispatcher* getDispatcher() const {
return overrideDispatcher_ ? overrideDispatcher_.get()
: Dispatcher::getDefaultInstance();
}
/**
* Returns Dispatcher.
*
* If no override installed, returns default Dispatcher instance.
*/
netops::Dispatcher* operator->() const { return getDispatcher(); }
/**
* Sets override Dispatcher. To remove override, pass empty shared_ptr.
*/
void setOverride(std::shared_ptr<netops::Dispatcher> dispatcher) {
overrideDispatcher_ = std::move(dispatcher);
}
/**
* If installed, returns shared_ptr to override Dispatcher, else empty ptr.
*/
std::shared_ptr<netops::Dispatcher> getOverride() const {
return overrideDispatcher_;
}
private:
std::shared_ptr<netops::Dispatcher> overrideDispatcher_;
};
} // namespace netops
} // namespace folly
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment