Commit 8b4086af authored by Alan Frindell's avatar Alan Frindell Committed by Facebook GitHub Bot

rename coro::Socket to coro::Transport

Summary: This also renames the file and creates a Socket.h shim for compatabilty.

Reviewed By: yairgott

Differential Revision: D26837068

fbshipit-source-id: 47c8f7b410540d1ad0aa03f4bc96816f8a93ef1c
parent c3156ea8
......@@ -95,7 +95,7 @@ ServerSocket::ServerSocket(
socket_->listen(listenQueueDepth);
}
Task<std::unique_ptr<Socket>> ServerSocket::accept() {
Task<std::unique_ptr<Transport>> ServerSocket::accept() {
VLOG(5) << "accept() called";
co_await folly::coro::co_safe_point;
......@@ -114,7 +114,7 @@ Task<std::unique_ptr<Socket>> ServerSocket::accept() {
if (cb.error) {
co_yield co_error(std::move(cb.error));
}
co_return std::make_unique<Socket>(
co_return std::make_unique<Transport>(
socket_->getEventBase(),
AsyncSocket::newSocket(
socket_->getEventBase(), NetworkSocket::fromFd(cb.acceptFd)));
......
......@@ -22,7 +22,7 @@
#include <folly/Expected.h>
#include <folly/SocketAddress.h>
#include <folly/io/async/AsyncServerSocket.h>
#include <folly/io/coro/Socket.h>
#include <folly/io/coro/Transport.h>
#if FOLLY_HAS_COROUTINES
......@@ -43,7 +43,7 @@ class ServerSocket {
ServerSocket(ServerSocket&&) = default;
ServerSocket& operator=(ServerSocket&&) = default;
Task<std::unique_ptr<Socket>> accept();
Task<std::unique_ptr<Transport>> accept();
void close() noexcept {
if (socket_) {
......
......@@ -16,138 +16,32 @@
#pragma once
#include <folly/Range.h>
#include <folly/SocketAddress.h>
#include <folly/experimental/coro/Task.h>
#include <folly/io/IOBufQueue.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncSocketException.h>
#include <folly/io/coro/Transport.h>
#if FOLLY_HAS_COROUTINES
namespace folly {
namespace coro {
class Transport {
public:
using ErrorCode = AsyncSocketException::AsyncSocketExceptionType;
// on write error, report the issue and how many bytes were written
virtual ~Transport() = default;
virtual EventBase* getEventBase() noexcept = 0;
virtual Task<size_t> read(
MutableByteRange buf, std::chrono::milliseconds timeout) = 0;
Task<size_t> read(
void* buf, size_t buflen, std::chrono::milliseconds timeout) {
return read(MutableByteRange((unsigned char*)buf, buflen), timeout);
}
virtual Task<size_t> read(
IOBufQueue& buf,
size_t minReadSize,
size_t newAllocationSize,
std::chrono::milliseconds timeout) = 0;
struct WriteInfo {
size_t bytesWritten{0};
};
virtual Task<Unit> write(
ByteRange buf,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
WriteInfo* writeInfo = nullptr) = 0;
virtual Task<Unit> write(
IOBufQueue& ioBufQueue,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
WriteInfo* writeInfo = nullptr) = 0;
virtual SocketAddress getLocalAddress() const noexcept = 0;
virtual SocketAddress getPeerAddress() const noexcept = 0;
virtual void close() = 0;
virtual void shutdownWrite() = 0;
virtual void closeWithReset() = 0;
virtual folly::AsyncTransport* getTransport() const = 0;
virtual const AsyncTransportCertificate* getPeerCertificate() const = 0;
};
// Shim class -- will remove
class Socket : public Transport {
public:
explicit Socket(AsyncSocket::UniquePtr socket)
: eventBase_(socket->getEventBase()), transport_(std::move(socket)) {}
: Socket(socket->getEventBase(), std::move(socket)) {}
Socket(
folly::EventBase* eventBase, folly::AsyncTransport::UniquePtr transport)
: eventBase_(eventBase), transport_(std::move(transport)) {}
: Transport(eventBase, std::move(transport)) {}
Socket(Socket&&) = default;
Socket& operator=(Socket&&) = default;
Socket(Transport&& transport) : Transport(std::move(transport)) {}
static Task<Socket> connect(
EventBase* evb,
const SocketAddress& destAddr,
std::chrono::milliseconds connectTimeout);
virtual EventBase* getEventBase() noexcept override { return eventBase_; }
Task<size_t> read(
MutableByteRange buf, std::chrono::milliseconds timeout) override;
Task<size_t> read(
IOBufQueue& buf,
size_t minReadSize,
size_t newAllocationSize,
std::chrono::milliseconds timeout) override;
Task<Unit> write(
ByteRange buf,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
WriteInfo* writeInfo = nullptr) override;
Task<folly::Unit> write(
IOBufQueue& ioBufQueue,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
WriteInfo* writeInfo = nullptr) override;
AsyncTransport* getTransport() const override { return transport_.get(); }
SocketAddress getLocalAddress() const noexcept override {
SocketAddress addr;
transport_->getLocalAddress(&addr);
return addr;
std::chrono::milliseconds connectTimeout) {
auto transport = co_await newConnectedSocket(evb, destAddr, connectTimeout);
co_return Socket(std::move(transport));
}
SocketAddress getPeerAddress() const noexcept override {
SocketAddress addr;
transport_->getPeerAddress(&addr);
return addr;
}
void shutdownWrite() noexcept override {
if (transport_) {
transport_->shutdownWrite();
}
}
void close() noexcept override {
if (transport_) {
transport_->close();
}
}
void closeWithReset() noexcept override {
if (transport_) {
transport_->closeWithReset();
}
}
const AsyncTransportCertificate* getPeerCertificate() const override {
return transport_->getPeerCertificate();
}
private:
// non-copyable
Socket(const Socket&) = delete;
Socket& operator=(const Socket&) = delete;
EventBase* eventBase_;
AsyncTransport::UniquePtr transport_;
bool deferredReadEOF_{false};
};
} // namespace coro
......
......@@ -19,7 +19,7 @@
#include <functional>
#include <folly/experimental/coro/Baton.h>
#include <folly/io/coro/Socket.h>
#include <folly/io/coro/Transport.h>
#if FOLLY_HAS_COROUTINES
......@@ -263,7 +263,7 @@ class WriteCallback : public CallbackBase,
namespace folly {
namespace coro {
Task<Socket> Socket::connect(
Task<Transport> Transport::newConnectedSocket(
folly::EventBase* evb,
const folly::SocketAddress& destAddr,
std::chrono::milliseconds connectTimeout) {
......@@ -280,16 +280,16 @@ Task<Socket> Socket::connect(
if (cb.error()) {
co_yield co_error(std::move(cb.error()));
}
co_return Socket(evb, std::move(socket));
co_return Transport(evb, std::move(socket));
}
Task<size_t> Socket::read(
Task<size_t> Transport::read(
folly::MutableByteRange buf, std::chrono::milliseconds timeout) {
if (deferredReadEOF_) {
deferredReadEOF_ = false;
co_return 0;
}
VLOG(5) << "Socket::read(), expecting max len " << buf.size();
VLOG(5) << "Transport::read(), expecting max len " << buf.size();
ReadCallback cb{eventBase_->timer(), *transport_, buf, timeout};
transport_->setReadCB(&cb);
......@@ -307,7 +307,7 @@ Task<size_t> Socket::read(
co_return cb.length;
}
Task<size_t> Socket::read(
Task<size_t> Transport::read(
folly::IOBufQueue& readBuf,
std::size_t minReadSize,
std::size_t newAllocationSize,
......@@ -316,7 +316,7 @@ Task<size_t> Socket::read(
deferredReadEOF_ = false;
co_return 0;
}
VLOG(5) << "Socket::read(), expecting minReadSize=" << minReadSize;
VLOG(5) << "Transport::read(), expecting minReadSize=" << minReadSize;
ReadCallback cb{
eventBase_->timer(),
......@@ -339,7 +339,7 @@ Task<size_t> Socket::read(
co_return cb.length;
}
Task<folly::Unit> Socket::write(
Task<folly::Unit> Transport::write(
folly::ByteRange buf,
std::chrono::milliseconds timeout,
WriteInfo* writeInfo) {
......@@ -364,7 +364,7 @@ Task<folly::Unit> Socket::write(
co_return unit;
}
Task<folly::Unit> Socket::write(
Task<folly::Unit> Transport::write(
folly::IOBufQueue& ioBufQueue,
std::chrono::milliseconds timeout,
WriteInfo* writeInfo) {
......
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <folly/Range.h>
#include <folly/SocketAddress.h>
#include <folly/experimental/coro/Task.h>
#include <folly/io/IOBufQueue.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncSocketException.h>
#if FOLLY_HAS_COROUTINES
namespace folly {
namespace coro {
class TransportIf {
public:
using ErrorCode = AsyncSocketException::AsyncSocketExceptionType;
// on write error, report the issue and how many bytes were written
virtual ~TransportIf() = default;
virtual EventBase* getEventBase() noexcept = 0;
virtual Task<size_t> read(
MutableByteRange buf, std::chrono::milliseconds timeout) = 0;
Task<size_t> read(
void* buf, size_t buflen, std::chrono::milliseconds timeout) {
return read(MutableByteRange((unsigned char*)buf, buflen), timeout);
}
virtual Task<size_t> read(
IOBufQueue& buf,
size_t minReadSize,
size_t newAllocationSize,
std::chrono::milliseconds timeout) = 0;
struct WriteInfo {
size_t bytesWritten{0};
};
virtual Task<Unit> write(
ByteRange buf,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
WriteInfo* writeInfo = nullptr) = 0;
virtual Task<Unit> write(
IOBufQueue& ioBufQueue,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
WriteInfo* writeInfo = nullptr) = 0;
virtual SocketAddress getLocalAddress() const noexcept = 0;
virtual SocketAddress getPeerAddress() const noexcept = 0;
virtual void close() = 0;
virtual void shutdownWrite() = 0;
virtual void closeWithReset() = 0;
virtual folly::AsyncTransport* getTransport() const = 0;
virtual const AsyncTransportCertificate* getPeerCertificate() const = 0;
};
class Transport : public TransportIf {
public:
Transport(
folly::EventBase* eventBase, folly::AsyncTransport::UniquePtr transport)
: eventBase_(eventBase), transport_(std::move(transport)) {}
Transport(Transport&&) = default;
Transport& operator=(Transport&&) = default;
// Establish a TCP connection to the given address and return a Transport
// That wraps that socket
static Task<Transport> newConnectedSocket(
EventBase* evb,
const SocketAddress& destAddr,
std::chrono::milliseconds connectTimeout);
virtual EventBase* getEventBase() noexcept override { return eventBase_; }
Task<size_t> read(
MutableByteRange buf, std::chrono::milliseconds timeout) override;
Task<size_t> read(
IOBufQueue& buf,
size_t minReadSize,
size_t newAllocationSize,
std::chrono::milliseconds timeout) override;
Task<Unit> write(
ByteRange buf,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
WriteInfo* writeInfo = nullptr) override;
Task<folly::Unit> write(
IOBufQueue& ioBufQueue,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
WriteInfo* writeInfo = nullptr) override;
AsyncTransport* getTransport() const override { return transport_.get(); }
SocketAddress getLocalAddress() const noexcept override {
SocketAddress addr;
transport_->getLocalAddress(&addr);
return addr;
}
SocketAddress getPeerAddress() const noexcept override {
SocketAddress addr;
transport_->getPeerAddress(&addr);
return addr;
}
void shutdownWrite() noexcept override {
if (transport_) {
transport_->shutdownWrite();
}
}
void close() noexcept override {
if (transport_) {
transport_->close();
}
}
void closeWithReset() noexcept override {
if (transport_) {
transport_->closeWithReset();
}
}
const AsyncTransportCertificate* getPeerCertificate() const override {
return transport_->getPeerCertificate();
}
private:
// non-copyable
Transport(const Transport&) = delete;
Transport& operator=(const Transport&) = delete;
EventBase* eventBase_;
AsyncTransport::UniquePtr transport_;
bool deferredReadEOF_{false};
};
} // namespace coro
} // namespace folly
#endif // FOLLY_HAS_COROUTINES
......@@ -21,7 +21,7 @@
#include <folly/io/async/test/AsyncSocketTest.h>
#include <folly/io/async/test/ScopedBoundPort.h>
#include <folly/io/coro/ServerSocket.h>
#include <folly/io/coro/Socket.h>
#include <folly/io/coro/Transport.h>
#include <folly/portability/GTest.h>
#if FOLLY_HAS_COROUTINES
......@@ -30,7 +30,7 @@ using namespace std::chrono_literals;
using namespace folly;
using namespace folly::coro;
class SocketTest : public testing::Test {
class TransportTest : public testing::Test {
public:
template <typename F>
void run(F f) {
......@@ -46,32 +46,34 @@ class SocketTest : public testing::Test {
CancellationSource cancelSource;
};
class ServerSocketTest : public SocketTest {
class ServerTransportTest : public TransportTest {
public:
folly::coro::Task<Socket> connect() {
co_return co_await Socket::connect(&evb, srv.getAddress(), 0ms);
folly::coro::Task<Transport> connect() {
co_return co_await Transport::newConnectedSocket(
&evb, srv.getAddress(), 0ms);
}
TestServer srv;
};
TEST_F(SocketTest, ConnectFailure) {
TEST_F(TransportTest, ConnectFailure) {
run([&]() -> Task<> {
ScopedBoundPort ph;
auto serverAddr = ph.getAddress();
EXPECT_THROW(
co_await Socket::connect(&evb, serverAddr, 0ms), AsyncSocketException);
co_await Transport::newConnectedSocket(&evb, serverAddr, 0ms),
AsyncSocketException);
});
}
TEST_F(ServerSocketTest, ConnectSuccess) {
TEST_F(ServerTransportTest, ConnectSuccess) {
run([&]() -> Task<> {
auto cs = co_await connect();
EXPECT_EQ(srv.getAddress(), cs.getPeerAddress());
});
}
TEST_F(ServerSocketTest, ConnectCancelled) {
TEST_F(ServerTransportTest, ConnectCancelled) {
run([&]() -> Task<> {
co_await folly::coro::collectAll(
// token would be cancelled while waiting on connect
......@@ -85,12 +87,12 @@ TEST_F(ServerSocketTest, ConnectCancelled) {
EXPECT_THROW(
co_await co_withCancellation(
cancelSource.getToken(),
Socket::connect(&evb, srv.getAddress(), 0ms)),
Transport::newConnectedSocket(&evb, srv.getAddress(), 0ms)),
OperationCancelled);
});
}
TEST_F(ServerSocketTest, SimpleRead) {
TEST_F(ServerTransportTest, SimpleRead) {
run([&]() -> Task<> {
constexpr auto kBufSize = 65536;
auto cs = co_await connect();
......@@ -121,7 +123,7 @@ TEST_F(ServerSocketTest, SimpleRead) {
});
}
TEST_F(ServerSocketTest, SimpleIOBufRead) {
TEST_F(ServerTransportTest, SimpleIOBufRead) {
run([&]() -> Task<> {
// Exactly fills a buffer mid-loop and triggers deferredReadEOF handling
constexpr auto kBufSize = 55 * 1184;
......@@ -151,7 +153,7 @@ TEST_F(ServerSocketTest, SimpleIOBufRead) {
});
}
TEST_F(ServerSocketTest, ReadCancelled) {
TEST_F(ServerTransportTest, ReadCancelled) {
run([&]() -> Task<> {
auto cs = co_await connect();
auto reader = [&cs]() -> Task<Unit> {
......@@ -172,7 +174,7 @@ TEST_F(ServerSocketTest, ReadCancelled) {
});
}
TEST_F(ServerSocketTest, ReadTimeout) {
TEST_F(ServerTransportTest, ReadTimeout) {
run([&]() -> Task<> {
auto cs = co_await connect();
std::array<uint8_t, 1024> rcvBuf;
......@@ -184,7 +186,7 @@ TEST_F(ServerSocketTest, ReadTimeout) {
});
}
TEST_F(ServerSocketTest, ReadError) {
TEST_F(ServerTransportTest, ReadError) {
run([&]() -> Task<> {
auto cs = co_await connect();
// produces blocking socket
......@@ -200,7 +202,7 @@ TEST_F(ServerSocketTest, ReadError) {
});
}
TEST_F(ServerSocketTest, SimpleWrite) {
TEST_F(ServerTransportTest, SimpleWrite) {
run([&]() -> Task<> {
auto cs = co_await connect();
// produces blocking socket
......@@ -220,7 +222,7 @@ TEST_F(ServerSocketTest, SimpleWrite) {
});
}
TEST_F(ServerSocketTest, SimpleWritev) {
TEST_F(ServerTransportTest, SimpleWritev) {
run([&]() -> Task<> {
auto cs = co_await connect();
// produces blocking socket
......@@ -248,7 +250,7 @@ TEST_F(ServerSocketTest, SimpleWritev) {
});
}
TEST_F(ServerSocketTest, WriteCancelled) {
TEST_F(ServerTransportTest, WriteCancelled) {
run([&]() -> Task<> {
auto cs = co_await connect();
// reduce the send buffer size so the write wouldn't complete immediately
......@@ -276,17 +278,17 @@ TEST_F(ServerSocketTest, WriteCancelled) {
});
}
TEST_F(SocketTest, SimpleAccept) {
TEST_F(TransportTest, SimpleAccept) {
run([&]() -> Task<> {
ServerSocket css(AsyncServerSocket::newSocket(&evb), std::nullopt, 16);
auto serverAddr = css.getAsyncServerSocket()->getAddress();
co_await folly::coro::collectAll(
css.accept(), Socket::connect(&evb, serverAddr, 0ms));
css.accept(), Transport::newConnectedSocket(&evb, serverAddr, 0ms));
});
}
TEST_F(SocketTest, AcceptCancelled) {
TEST_F(TransportTest, AcceptCancelled) {
run([&]() -> Task<> {
co_await folly::coro::collectAll(requestCancellation(), [&]() -> Task<> {
ServerSocket css(AsyncServerSocket::newSocket(&evb), std::nullopt, 16);
......@@ -297,12 +299,12 @@ TEST_F(SocketTest, AcceptCancelled) {
});
}
TEST_F(SocketTest, AsyncClientAndServer) {
TEST_F(TransportTest, AsyncClientAndServer) {
run([&]() -> Task<> {
constexpr int kSize = 128;
ServerSocket css(AsyncServerSocket::newSocket(&evb), std::nullopt, 16);
auto serverAddr = css.getAsyncServerSocket()->getAddress();
auto cs = co_await Socket::connect(&evb, serverAddr, 0ms);
auto cs = co_await Transport::newConnectedSocket(&evb, serverAddr, 0ms);
co_await folly::coro::collectAll(
[&css]() -> Task<> {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment