Commit c3156ea8 authored by Alan Frindell's avatar Alan Frindell Committed by Facebook GitHub Bot

coro::Socket wraps AsyncTransport

Summary: This allows coro::Socket to be used with non-socket AsyncTransports, notably Fizz.  A subsequent diff renames the class and file, and creates a shim for callers.

Reviewed By: yairgott

Differential Revision: D26610473

fbshipit-source-id: d64597ef0de3c90ab084249b20e85d97b34857a6
parent 2f7a1264
......@@ -114,8 +114,10 @@ Task<std::unique_ptr<Socket>> ServerSocket::accept() {
if (cb.error) {
co_yield co_error(std::move(cb.error));
}
co_return std::make_unique<Socket>(AsyncSocket::newSocket(
socket_->getEventBase(), NetworkSocket::fromFd(cb.acceptFd)));
co_return std::make_unique<Socket>(
socket_->getEventBase(),
AsyncSocket::newSocket(
socket_->getEventBase(), NetworkSocket::fromFd(cb.acceptFd)));
}
} // namespace coro
......
......@@ -33,8 +33,8 @@ namespace {
class CallbackBase {
public:
explicit CallbackBase(std::shared_ptr<folly::AsyncSocket> socket)
: socket_{std::move(socket)} {}
explicit CallbackBase(folly::AsyncTransport& transport)
: transport_{transport} {}
virtual ~CallbackBase() noexcept = default;
......@@ -65,10 +65,10 @@ class CallbackBase {
protected:
// we use this to notify the other side of completion
Baton baton_;
// needed to modify AsyncSocket state, e.g. cacncel callbacks
const std::shared_ptr<folly::AsyncSocket> socket_;
// needed to modify AsyncTransport state, e.g. cacncel callbacks
folly::AsyncTransport& transport_;
// to wrap AsyncSocket errors
// to wrap AsyncTransport errors
folly::exception_wrapper error_;
private:
......@@ -82,11 +82,11 @@ class CallbackBase {
class ConnectCallback : public CallbackBase,
public folly::AsyncSocket::ConnectCallback {
public:
explicit ConnectCallback(std::shared_ptr<folly::AsyncSocket> socket)
: CallbackBase(std::move(socket)) {}
explicit ConnectCallback(folly::AsyncSocket& socket)
: CallbackBase(socket), socket_(socket) {}
private:
void cancel() noexcept override { socket_->cancelConnect(); }
void cancel() noexcept override { socket_.cancelConnect(); }
void connectSuccess() noexcept override { post(); }
......@@ -94,14 +94,15 @@ class ConnectCallback : public CallbackBase,
error_ = folly::exception_wrapper(ex);
post();
}
folly::AsyncSocket& socket_;
};
//
// Handle data read for AsyncSocket
// Handle data read for AsyncTransport
//
class ReadCallback : public CallbackBase,
public folly::AsyncSocket::ReadCallback,
public folly::AsyncTransport::ReadCallback,
public folly::HHWheelTimer::Callback {
public:
// we need to pass the socket into ReadCallback so we can clear the callback
......@@ -111,27 +112,29 @@ class ReadCallback : public CallbackBase,
// socket to call readDataAvailable and readEOF in sequence, causing the
// promise to be fulfilled twice (oops!)
ReadCallback(
std::shared_ptr<folly::AsyncSocket> socket,
folly::HHWheelTimer& timer,
folly::AsyncTransport& transport,
folly::MutableByteRange buf,
std::chrono::milliseconds timeout)
: CallbackBase(socket), buf_{buf} {
: CallbackBase(transport), buf_{buf} {
if (timeout.count() > 0) {
socket->getEventBase()->timer().scheduleTimeout(this, timeout);
timer.scheduleTimeout(this, timeout);
}
}
ReadCallback(
std::shared_ptr<folly::AsyncSocket> socket,
folly::HHWheelTimer& timer,
folly::AsyncTransport& transport,
folly::IOBufQueue* readBuf,
size_t minReadSize,
size_t newAllocationSize,
std::chrono::milliseconds timeout)
: CallbackBase(socket),
: CallbackBase(transport),
readBuf_(readBuf),
minReadSize_(minReadSize),
newAllocationSize_(newAllocationSize) {
if (timeout.count() > 0) {
socket->getEventBase()->timer().scheduleTimeout(this, timeout);
timer.scheduleTimeout(this, timeout);
}
}
......@@ -147,7 +150,7 @@ class ReadCallback : public CallbackBase,
size_t newAllocationSize_{0};
void cancel() noexcept override {
socket_->setReadCB(nullptr);
transport_.setReadCB(nullptr);
cancelTimeout();
}
......@@ -176,7 +179,7 @@ class ReadCallback : public CallbackBase,
if (readBuf_) {
readBuf_->postallocate(len);
} else if (length == buf_.size()) {
socket_->setReadCB(nullptr);
transport_.setReadCB(nullptr);
cancelTimeout();
}
post();
......@@ -185,7 +188,7 @@ class ReadCallback : public CallbackBase,
void readEOF() noexcept override {
VLOG(5) << "readEOF()";
// disable callbacks
socket_->setReadCB(nullptr);
transport_.setReadCB(nullptr);
cancelTimeout();
eof = true;
post();
......@@ -194,7 +197,7 @@ class ReadCallback : public CallbackBase,
void readErr(const folly::AsyncSocketException& ex) noexcept override {
VLOG(5) << "readErr()";
// disable callbacks
socket_->setReadCB(nullptr);
transport_.setReadCB(nullptr);
cancelTimeout();
error_ = folly::exception_wrapper(ex);
post();
......@@ -210,7 +213,7 @@ class ReadCallback : public CallbackBase,
using Error = folly::AsyncSocketException::AsyncSocketExceptionType;
// uninstall read callback. it takes another read to bring it back.
socket_->setReadCB(nullptr);
transport_.setReadCB(nullptr);
// If the timeout fires but this ReadCallback did get some data, ignore it.
// post() has already happend from readDataAvailable.
if (length == 0) {
......@@ -222,21 +225,21 @@ class ReadCallback : public CallbackBase,
};
//
// Handle data write for AsyncSocket
// Handle data write for AsyncTransport
//
class WriteCallback : public CallbackBase,
public folly::AsyncSocket::WriteCallback {
public folly::AsyncTransport::WriteCallback {
public:
explicit WriteCallback(std::shared_ptr<folly::AsyncSocket> socket)
: CallbackBase(socket) {}
explicit WriteCallback(folly::AsyncTransport& transport)
: CallbackBase(transport) {}
~WriteCallback() override = default;
size_t bytesWritten{0};
std::optional<folly::AsyncSocketException> error;
private:
void cancel() noexcept override { socket_->closeWithReset(); }
void cancel() noexcept override { transport_.closeWithReset(); }
//
// Methods of WriteCallback
//
......@@ -264,10 +267,10 @@ Task<Socket> Socket::connect(
folly::EventBase* evb,
const folly::SocketAddress& destAddr,
std::chrono::milliseconds connectTimeout) {
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(evb);
auto socket = AsyncSocket::newSocket(evb);
socket->setReadCB(nullptr);
ConnectCallback cb{socket};
ConnectCallback cb{*socket};
socket->connect(&cb, destAddr, connectTimeout.count());
auto waitRet =
co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token));
......@@ -277,7 +280,7 @@ Task<Socket> Socket::connect(
if (cb.error()) {
co_yield co_error(std::move(cb.error()));
}
co_return Socket(socket);
co_return Socket(evb, std::move(socket));
}
Task<size_t> Socket::read(
......@@ -288,8 +291,8 @@ Task<size_t> Socket::read(
}
VLOG(5) << "Socket::read(), expecting max len " << buf.size();
ReadCallback cb{socket_, buf, timeout};
socket_->setReadCB(&cb);
ReadCallback cb{eventBase_->timer(), *transport_, buf, timeout};
transport_->setReadCB(&cb);
auto waitRet =
co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token));
......@@ -299,7 +302,7 @@ Task<size_t> Socket::read(
if (cb.error()) {
co_yield co_error(std::move(cb.error()));
}
socket_->setReadCB(nullptr);
transport_->setReadCB(nullptr);
deferredReadEOF_ = (cb.eof && cb.length > 0);
co_return cb.length;
}
......@@ -315,8 +318,14 @@ Task<size_t> Socket::read(
}
VLOG(5) << "Socket::read(), expecting minReadSize=" << minReadSize;
ReadCallback cb{socket_, &readBuf, minReadSize, newAllocationSize, timeout};
socket_->setReadCB(&cb);
ReadCallback cb{
eventBase_->timer(),
*transport_,
&readBuf,
minReadSize,
newAllocationSize,
timeout};
transport_->setReadCB(&cb);
auto waitRet =
co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token));
if (waitRet.hasException()) {
......@@ -325,7 +334,7 @@ Task<size_t> Socket::read(
if (cb.error()) {
co_yield co_error(std::move(cb.error()));
}
socket_->setReadCB(nullptr);
transport_->setReadCB(nullptr);
deferredReadEOF_ = (cb.eof && cb.length > 0);
co_return cb.length;
}
......@@ -334,9 +343,9 @@ Task<folly::Unit> Socket::write(
folly::ByteRange buf,
std::chrono::milliseconds timeout,
WriteInfo* writeInfo) {
socket_->setSendTimeout(timeout.count());
WriteCallback cb{socket_};
socket_->write(&cb, buf.begin(), buf.size());
transport_->setSendTimeout(timeout.count());
WriteCallback cb{*transport_};
transport_->write(&cb, buf.begin(), buf.size());
auto waitRet =
co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token));
if (waitRet.hasException()) {
......@@ -359,10 +368,10 @@ Task<folly::Unit> Socket::write(
folly::IOBufQueue& ioBufQueue,
std::chrono::milliseconds timeout,
WriteInfo* writeInfo) {
socket_->setSendTimeout(timeout.count());
WriteCallback cb{socket_};
transport_->setSendTimeout(timeout.count());
WriteCallback cb{*transport_};
auto iovec = ioBufQueue.front()->getIov();
socket_->writev(&cb, iovec.data(), iovec.size());
transport_->writev(&cb, iovec.data(), iovec.size());
auto waitRet =
co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token));
if (waitRet.hasException()) {
......
......@@ -21,7 +21,7 @@
#include <folly/experimental/coro/Task.h>
#include <folly/io/IOBufQueue.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/AsyncSocketException.h>
#if FOLLY_HAS_COROUTINES
......@@ -71,11 +71,12 @@ class Transport {
class Socket : public Transport {
public:
explicit Socket(std::shared_ptr<AsyncSocket> socket)
: socket_(std::move(socket)) {}
explicit Socket(AsyncSocket::UniquePtr socket)
: socket_(socket.release(), AsyncSocket::Destructor()) {}
: eventBase_(socket->getEventBase()), transport_(std::move(socket)) {}
Socket(
folly::EventBase* eventBase, folly::AsyncTransport::UniquePtr transport)
: eventBase_(eventBase), transport_(std::move(transport)) {}
Socket(Socket&&) = default;
Socket& operator=(Socket&&) = default;
......@@ -84,9 +85,7 @@ class Socket : public Transport {
EventBase* evb,
const SocketAddress& destAddr,
std::chrono::milliseconds connectTimeout);
virtual EventBase* getEventBase() noexcept override {
return socket_->getEventBase();
}
virtual EventBase* getEventBase() noexcept override { return eventBase_; }
Task<size_t> read(
MutableByteRange buf, std::chrono::milliseconds timeout) override;
......@@ -105,42 +104,40 @@ class Socket : public Transport {
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
WriteInfo* writeInfo = nullptr) override;
AsyncTransport* getTransport() const override { return transport_.get(); }
SocketAddress getLocalAddress() const noexcept override {
SocketAddress addr;
socket_->getLocalAddress(&addr);
transport_->getLocalAddress(&addr);
return addr;
}
folly::AsyncTransport* getTransport() const override { return socket_.get(); }
SocketAddress getPeerAddress() const noexcept override {
SocketAddress addr;
socket_->getPeerAddress(&addr);
transport_->getPeerAddress(&addr);
return addr;
}
void shutdownWrite() noexcept override {
if (socket_) {
socket_->shutdownWrite();
if (transport_) {
transport_->shutdownWrite();
}
}
void close() noexcept override {
if (socket_) {
socket_->close();
if (transport_) {
transport_->close();
}
}
void closeWithReset() noexcept override {
if (socket_) {
socket_->closeWithReset();
if (transport_) {
transport_->closeWithReset();
}
}
std::shared_ptr<AsyncSocket> getAsyncSocket() { return socket_; }
const AsyncTransportCertificate* getPeerCertificate() const override {
return socket_->getPeerCertificate();
return transport_->getPeerCertificate();
}
private:
......@@ -148,7 +145,8 @@ class Socket : public Transport {
Socket(const Socket&) = delete;
Socket& operator=(const Socket&) = delete;
std::shared_ptr<AsyncSocket> socket_;
EventBase* eventBase_;
AsyncTransport::UniquePtr transport_;
bool deferredReadEOF_{false};
};
......
......@@ -252,7 +252,9 @@ TEST_F(ServerSocketTest, WriteCancelled) {
run([&]() -> Task<> {
auto cs = co_await connect();
// reduce the send buffer size so the write wouldn't complete immediately
EXPECT_EQ(cs.getAsyncSocket()->setSendBufSize(4096), 0);
auto asyncSocket = dynamic_cast<folly::AsyncSocket*>(cs.getTransport());
CHECK(asyncSocket);
EXPECT_EQ(asyncSocket->setSendBufSize(4096), 0);
// produces blocking socket
auto ss = srv.accept(-1);
constexpr auto kBufSize = 65536;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment