Commit d2e690a6 authored by Yair Gottdenker's avatar Yair Gottdenker Committed by Facebook GitHub Bot

moving CoroSocket to folly/experimental/coro

Summary: This is the second attempt, the first one was D22958650. Decided to do a different diff as some affected files were moved from experimental/afrind/coro/h2proxy to proxygen/facebook/lib/experimental/coro/ which created some confusion while arc pulling

Reviewed By: yfeldblum

Differential Revision: D25432869

fbshipit-source-id: a183898302a79084d890548b9b7ecc4409f501d2
parent be76ab69
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/Portability.h>
#if FOLLY_HAS_COROUTINES
#include <folly/experimental/coro/Baton.h>
#include <folly/io/coro/ServerSocket.h>
using namespace folly::coro;
namespace {
class AcceptCallback : public folly::AsyncServerSocket::AcceptCallback {
public:
explicit AcceptCallback(
Baton& baton, std::shared_ptr<folly::AsyncServerSocket> socket)
: baton_{baton}, socket_(std::move(socket)) {}
~AcceptCallback() override = default;
int acceptFd{-1};
folly::exception_wrapper error;
private:
// to notify the caller of the result
Baton& baton_;
// the server socket
std::shared_ptr<folly::AsyncServerSocket> socket_;
//
// AcceptCallback methods
//
void connectionAccepted(
folly::NetworkSocket fdNetworkSocket,
const folly::SocketAddress& clientAddr) noexcept override {
VLOG(5) << "Connection accepted from: " << clientAddr.describe();
// unregister handlers while in the callback
socket_->pauseAccepting();
socket_->removeAcceptCallback(this, nullptr);
acceptFd = fdNetworkSocket.toFd();
baton_.post();
}
void acceptError(folly::exception_wrapper ex) noexcept override {
VLOG(5) << "acceptError";
// unregister handlers while in the callback
socket_->pauseAccepting();
socket_->removeAcceptCallback(this, nullptr);
error = std::move(ex);
acceptFd = -1;
baton_.post();
}
void acceptStarted() noexcept override { VLOG(5) << "acceptStarted"; }
void acceptStopped() noexcept override { VLOG(5) << "acceptStopped"; }
};
} // namespace
namespace folly {
namespace coro {
ServerSocket::ServerSocket(
std::shared_ptr<AsyncServerSocket> socket,
std::optional<SocketAddress> bindAddr,
uint32_t listenQueueDepth)
: socket_{socket} {
socket_->setReusePortEnabled(true);
if (bindAddr.has_value()) {
VLOG(1) << "ServerSocket binds on IP: " << bindAddr->describe();
socket_->bind(*bindAddr);
} else {
VLOG(1) << "ServerSocket binds on any addr, random port";
socket_->bind(0);
}
socket_->listen(listenQueueDepth);
}
Task<std::unique_ptr<Socket>> ServerSocket::accept() {
VLOG(5) << "accept() called";
co_await folly::coro::co_safe_point;
Baton baton;
AcceptCallback cb(baton, socket_);
socket_->addAcceptCallback(&cb, nullptr);
socket_->startAccepting();
auto cancelToken = co_await folly::coro::co_current_cancellation_token;
CancellationCallback cancellationCallback{cancelToken, [&baton, this] {
this->socket_->stopAccepting();
baton.post();
}};
co_await baton;
co_await folly::coro::co_safe_point;
if (cb.error) {
co_yield co_error(std::move(cb.error));
}
co_return std::make_unique<Socket>(AsyncSocket::newSocket(
socket_->getEventBase(), NetworkSocket::fromFd(cb.acceptFd)));
}
} // namespace coro
} // namespace folly
#endif // FOLLY_HAS_COROUTINES
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <folly/ExceptionWrapper.h>
#include <folly/Expected.h>
#include <folly/SocketAddress.h>
#include <folly/io/async/AsyncServerSocket.h>
#include <folly/io/coro/Socket.h>
#include <optional>
namespace folly {
namespace coro {
//
// This server socket will accept connections on the
// same event base as the socket itself
//
class ServerSocket {
public:
ServerSocket(
std::shared_ptr<AsyncServerSocket> socket,
std::optional<SocketAddress> bindAddr,
uint32_t listenQueueDepth);
ServerSocket(ServerSocket&&) = default;
ServerSocket& operator=(ServerSocket&&) = default;
Task<std::unique_ptr<Socket>> accept();
void close() noexcept {
if (socket_) {
socket_->stopAccepting();
}
}
const AsyncServerSocket* getAsyncServerSocket() const {
return socket_.get();
}
private:
// non-copyable
ServerSocket(const ServerSocket&) = delete;
ServerSocket& operator=(const ServerSocket&) = delete;
std::shared_ptr<AsyncServerSocket> socket_;
};
} // namespace coro
} // namespace folly
This diff is collapsed.
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <folly/Range.h>
#include <folly/SocketAddress.h>
#include <folly/experimental/coro/Task.h>
#include <folly/experimental/coro/Utils.h>
#include <folly/io/IOBufQueue.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncTimeout.h>
namespace folly {
namespace coro {
class Transport {
public:
using ErrorCode = AsyncSocketException::AsyncSocketExceptionType;
// on write error, report the issue and how many bytes were written
virtual ~Transport() = default;
virtual EventBase* getEventBase() noexcept = 0;
virtual Task<size_t> read(
MutableByteRange buf, std::chrono::milliseconds timeout) = 0;
Task<size_t> read(
void* buf, size_t buflen, std::chrono::milliseconds timeout) {
return read(MutableByteRange((unsigned char*)buf, buflen), timeout);
}
virtual Task<size_t> read(
IOBufQueue& buf,
size_t minReadSize,
size_t newAllocationSize,
std::chrono::milliseconds timeout) = 0;
struct WriteInfo {
size_t bytesWritten{0};
};
virtual Task<Unit> write(
ByteRange buf,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
WriteInfo* writeInfo = nullptr) = 0;
virtual Task<Unit> write(
IOBufQueue& ioBufQueue,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
WriteInfo* writeInfo = nullptr) = 0;
virtual SocketAddress getLocalAddress() const noexcept = 0;
virtual SocketAddress getPeerAddress() const noexcept = 0;
virtual void close() = 0;
virtual void shutdownWrite() = 0;
virtual void closeWithReset() = 0;
virtual folly::AsyncTransport* getTransport() const = 0;
virtual const AsyncTransportCertificate* getPeerCertificate() const = 0;
};
class Socket : public Transport {
public:
explicit Socket(std::shared_ptr<AsyncSocket> socket)
: socket_(std::move(socket)) {}
Socket(Socket&&) = default;
Socket& operator=(Socket&&) = default;
static Task<Socket> connect(
EventBase* evb,
const SocketAddress& destAddr,
std::chrono::milliseconds connectTimeout);
virtual EventBase* getEventBase() noexcept override {
return socket_->getEventBase();
}
Task<size_t> read(
MutableByteRange buf, std::chrono::milliseconds timeout) override;
Task<size_t> read(
IOBufQueue& buf,
size_t minReadSize,
size_t newAllocationSize,
std::chrono::milliseconds timeout) override;
Task<Unit> write(
ByteRange buf,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
WriteInfo* writeInfo = nullptr) override;
Task<folly::Unit> write(
IOBufQueue& ioBufQueue,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
WriteInfo* writeInfo = nullptr) override;
SocketAddress getLocalAddress() const noexcept override {
SocketAddress addr;
socket_->getLocalAddress(&addr);
return addr;
}
folly::AsyncTransport* getTransport() const override { return socket_.get(); }
SocketAddress getPeerAddress() const noexcept override {
SocketAddress addr;
socket_->getPeerAddress(&addr);
return addr;
}
void shutdownWrite() noexcept override {
if (socket_) {
socket_->shutdownWrite();
}
}
void close() noexcept override {
if (socket_) {
socket_->close();
}
}
void closeWithReset() noexcept override {
if (socket_) {
socket_->closeWithReset();
}
}
std::shared_ptr<AsyncSocket> getAsyncSocket() { return socket_; }
const AsyncTransportCertificate* getPeerCertificate() const override {
return socket_->getPeerCertificate();
}
private:
// non-copyable
Socket(const Socket&) = delete;
Socket& operator=(const Socket&) = delete;
std::shared_ptr<AsyncSocket> socket_;
bool deferredReadEOF_{false};
};
} // namespace coro
} // namespace folly
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/Portability.h>
#if FOLLY_HAS_COROUTINES
#include <folly/experimental/coro/BlockingWait.h>
#include <folly/experimental/coro/Collect.h>
#include <folly/io/async/test/AsyncSocketTest.h>
#include <folly/io/async/test/ScopedBoundPort.h>
#include <folly/io/coro/ServerSocket.h>
#include <folly/io/coro/Socket.h>
#include <folly/portability/GTest.h>
using namespace std::chrono_literals;
using namespace folly;
using namespace folly::coro;
class SocketTest : public testing::Test {
public:
template <typename F>
void run(F f) {
blockingWait(co_invoke(std::move(f)), &evb);
}
folly::coro::Task<> requestCancellation() {
cancelSource.requestCancellation();
co_return;
}
EventBase evb;
CancellationSource cancelSource;
};
class ServerSocketTest : public SocketTest {
public:
folly::coro::Task<Socket> connect() {
co_return co_await Socket::connect(&evb, srv.getAddress(), 0ms);
}
TestServer srv;
};
TEST_F(SocketTest, ConnectFailure) {
run([&]() -> Task<> {
ScopedBoundPort ph;
auto serverAddr = ph.getAddress();
EXPECT_THROW(
co_await Socket::connect(&evb, serverAddr, 0ms), AsyncSocketException);
});
}
TEST_F(ServerSocketTest, ConnectSuccess) {
run([&]() -> Task<> {
auto cs = co_await connect();
EXPECT_EQ(srv.getAddress(), cs.getPeerAddress());
});
}
TEST_F(ServerSocketTest, ConnectCancelled) {
run([&]() -> Task<> {
co_await folly::coro::collectAll(
// token would be cancelled while waiting on connect
[&]() -> Task<> {
EXPECT_THROW(
co_await co_withCancellation(cancelSource.getToken(), connect()),
OperationCancelled);
}(),
requestCancellation());
// token was cancelled before read was called
EXPECT_THROW(
co_await co_withCancellation(
cancelSource.getToken(),
Socket::connect(&evb, srv.getAddress(), 0ms)),
OperationCancelled);
});
}
TEST_F(ServerSocketTest, SimpleRead) {
run([&]() -> Task<> {
constexpr auto kBufSize = 65536;
auto cs = co_await connect();
// produces blocking socket
auto ss = srv.accept(-1);
std::array<uint8_t, kBufSize> sndBuf;
std::memset(sndBuf.data(), 'a', sndBuf.size());
ss->write(sndBuf.data(), sndBuf.size());
// read using coroutines
std::array<uint8_t, kBufSize> rcvBuf;
auto reader = [&rcvBuf, &cs]() -> Task<Unit> {
int totalBytes{0};
while (totalBytes < kBufSize) {
auto bytesRead = co_await cs.read(
MutableByteRange(
rcvBuf.data() + totalBytes,
(rcvBuf.data() + rcvBuf.size() - totalBytes)),
0ms);
totalBytes += bytesRead;
}
co_return unit;
};
co_await reader();
EXPECT_EQ(0, memcmp(sndBuf.data(), rcvBuf.data(), rcvBuf.size()));
});
}
TEST_F(ServerSocketTest, SimpleIOBufRead) {
run([&]() -> Task<> {
// Exactly fills a buffer mid-loop and triggers deferredReadEOF handling
constexpr auto kBufSize = 55 * 1184;
auto cs = co_await connect();
// produces blocking socket
auto ss = srv.accept(-1);
std::array<uint8_t, kBufSize> sndBuf;
std::memset(sndBuf.data(), 'a', sndBuf.size());
ss->write(sndBuf.data(), sndBuf.size());
ss->close();
// read using coroutines
IOBufQueue rcvBuf(IOBufQueue::cacheChainLength());
int totalBytes{0};
while (totalBytes < kBufSize) {
auto bytesRead = co_await cs.read(rcvBuf, 1000, 1000, 0ms);
totalBytes += bytesRead;
}
auto bytesRead = co_await cs.read(rcvBuf, 1000, 1000, 50ms);
EXPECT_EQ(bytesRead, 0); // closed
auto data = rcvBuf.move();
data->coalesce();
EXPECT_EQ(0, memcmp(sndBuf.data(), data->data(), data->length()));
});
}
TEST_F(ServerSocketTest, ReadCancelled) {
run([&]() -> Task<> {
auto cs = co_await connect();
auto reader = [&cs]() -> Task<Unit> {
std::array<uint8_t, 1024> rcvBuf;
EXPECT_THROW(
co_await cs.read(
MutableByteRange(rcvBuf.data(), (rcvBuf.data() + rcvBuf.size())),
0ms),
OperationCancelled);
co_return unit;
};
co_await co_withCancellation(
cancelSource.getToken(),
folly::coro::collectAll(requestCancellation(), reader()));
// token was cancelled before read was called
co_await co_withCancellation(cancelSource.getToken(), reader());
});
}
TEST_F(ServerSocketTest, ReadTimeout) {
run([&]() -> Task<> {
auto cs = co_await connect();
std::array<uint8_t, 1024> rcvBuf;
EXPECT_THROW(
co_await cs.read(
MutableByteRange(rcvBuf.data(), (rcvBuf.data() + rcvBuf.size())),
50ms),
AsyncSocketException);
});
}
TEST_F(ServerSocketTest, ReadError) {
run([&]() -> Task<> {
auto cs = co_await connect();
// produces blocking socket
auto ss = srv.accept(-1);
ss->closeWithReset();
std::array<uint8_t, 1024> rcvBuf;
EXPECT_THROW(
co_await cs.read(
MutableByteRange(rcvBuf.data(), (rcvBuf.data() + rcvBuf.size())),
50ms),
AsyncSocketException);
});
}
TEST_F(ServerSocketTest, SimpleWrite) {
run([&]() -> Task<> {
auto cs = co_await connect();
// produces blocking socket
auto ss = srv.accept(-1);
constexpr auto kBufSize = 65536;
std::array<uint8_t, kBufSize> sndBuf;
std::memset(sndBuf.data(), 'a', sndBuf.size());
// write use co-routine
co_await cs.write(ByteRange(sndBuf.data(), sndBuf.data() + sndBuf.size()));
// read on server side
std::array<uint8_t, kBufSize> rcvBuf;
ss->readAll(rcvBuf.data(), rcvBuf.size());
EXPECT_EQ(0, memcmp(sndBuf.data(), rcvBuf.data(), rcvBuf.size()));
});
}
TEST_F(ServerSocketTest, SimpleWritev) {
run([&]() -> Task<> {
auto cs = co_await connect();
// produces blocking socket
auto ss = srv.accept(-1);
IOBufQueue sndBuf;
constexpr auto kBufSize = 65536;
std::array<uint8_t, kBufSize> bufA;
std::memset(bufA.data(), 'a', bufA.size());
std::array<uint8_t, kBufSize> bufB;
std::memset(bufB.data(), 'b', bufB.size());
sndBuf.append(bufA.data(), bufA.size());
sndBuf.append(bufB.data(), bufB.size());
// write use co-routine
co_await cs.write(sndBuf);
// read on server side
std::array<uint8_t, kBufSize> rcvBufA;
ss->readAll(rcvBufA.data(), rcvBufA.size());
EXPECT_EQ(0, memcmp(bufA.data(), rcvBufA.data(), rcvBufA.size()));
std::array<uint8_t, kBufSize> rcvBufB;
ss->readAll(rcvBufB.data(), rcvBufB.size());
EXPECT_EQ(0, memcmp(bufB.data(), rcvBufB.data(), rcvBufB.size()));
});
}
TEST_F(ServerSocketTest, WriteCancelled) {
run([&]() -> Task<> {
auto cs = co_await connect();
// reduce the send buffer size so the write wouldn't complete immediately
EXPECT_EQ(cs.getAsyncSocket()->setSendBufSize(4096), 0);
// produces blocking socket
auto ss = srv.accept(-1);
constexpr auto kBufSize = 65536;
std::array<uint8_t, kBufSize> sndBuf;
std::memset(sndBuf.data(), 'a', sndBuf.size());
// write use co-routine
auto writter = [&]() -> Task<> {
EXPECT_THROW(
co_await co_withCancellation(
cancelSource.getToken(),
cs.write(
ByteRange(sndBuf.data(), sndBuf.data() + sndBuf.size()))),
OperationCancelled);
};
co_await folly::coro::collectAll(requestCancellation(), writter());
co_await co_withCancellation(cancelSource.getToken(), writter());
});
}
TEST_F(SocketTest, SimpleAccept) {
run([&]() -> Task<> {
ServerSocket css(AsyncServerSocket::newSocket(&evb), std::nullopt, 16);
auto serverAddr = css.getAsyncServerSocket()->getAddress();
co_await folly::coro::collectAll(
css.accept(), Socket::connect(&evb, serverAddr, 0ms));
});
}
TEST_F(SocketTest, AcceptCancelled) {
run([&]() -> Task<> {
co_await folly::coro::collectAll(requestCancellation(), [&]() -> Task<> {
ServerSocket css(AsyncServerSocket::newSocket(&evb), std::nullopt, 16);
EXPECT_THROW(
co_await co_withCancellation(cancelSource.getToken(), css.accept()),
OperationCancelled);
}());
});
}
TEST_F(SocketTest, AsyncClientAndServer) {
run([&]() -> Task<> {
constexpr int kSize = 128;
ServerSocket css(AsyncServerSocket::newSocket(&evb), std::nullopt, 16);
auto serverAddr = css.getAsyncServerSocket()->getAddress();
auto cs = co_await Socket::connect(&evb, serverAddr, 0ms);
co_await folly::coro::collectAll(
[&css]() -> Task<> {
auto sock = co_await css.accept();
std::array<uint8_t, kSize> buf;
memset(buf.data(), 'a', kSize);
co_await sock->write(ByteRange(buf.begin(), buf.end()));
css.close();
}(),
[&cs]() -> Task<> {
std::array<uint8_t, kSize> buf;
// For fun, shutdown the write half -- we don't need it
cs.shutdownWrite();
auto len =
co_await cs.read(MutableByteRange(buf.begin(), buf.end()), 0ms);
cs.close();
EXPECT_TRUE(len == buf.size());
}());
});
}
#endif // FOLLY_HAS_COROUTINES
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment