Commit 7dbfd2f8 authored by James Sedgwick's avatar James Sedgwick Committed by Sara Golemon

AsyncSocket::writeRequest() and its first user wangle::FileRegion

Summary: similar to D2050808, but move the functionality into AsyncSocket itself so that you have a consistent interface and contiguous writes for a single file

Test Plan: added unit, will hook this up to a file server example next

Reviewed By: davejwatson@fb.com

Subscribers: fugalh, net-systems@, folly-diffs@, jsedgwick, yfeldblum, chalfant

FB internal diff: D2084452

Signature: t1:2084452:1433181933:175158618966706db00bf6620cc86ae145d04ecf
parent 7cae5640
......@@ -282,6 +282,7 @@ nobase_follyinclude_HEADERS = \
wangle/bootstrap/ClientBootstrap.h \
wangle/channel/AsyncSocketHandler.h \
wangle/channel/EventBaseHandler.h \
wangle/channel/FileRegion.h \
wangle/channel/Handler.h \
wangle/channel/HandlerContext.h \
wangle/channel/HandlerContext-inl.h \
......
......@@ -17,6 +17,8 @@
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/EventBase.h>
#include <folly/io/async/EventHandler.h>
#include <folly/Singleton.h>
#include <folly/SocketAddress.h>
#include <folly/io/IOBuf.h>
......@@ -24,6 +26,7 @@
#include <errno.h>
#include <limits.h>
#include <unistd.h>
#include <thread>
#include <fcntl.h>
#include <sys/types.h>
#include <sys/socket.h>
......@@ -43,7 +46,7 @@ const AsyncSocketException socketClosedLocallyEx(
const AsyncSocketException socketShutdownForWritesEx(
AsyncSocketException::END_OF_FILE, "socket shutdown for writes");
// TODO: It might help performance to provide a version of WriteRequest that
// TODO: It might help performance to provide a version of BytesWriteRequest that
// users could derive from, so we can avoid the extra allocation for each call
// to write()/writev(). We could templatize TFramedAsyncChannel just like the
// protocols are currently templatized for transports.
......@@ -52,53 +55,6 @@ const AsyncSocketException socketShutdownForWritesEx(
// storage space, and only our internal version would allocate it at the end of
// the WriteRequest.
/**
* A WriteRequest object tracks information about a pending write operation.
*/
class AsyncSocket::WriteRequest {
public:
WriteRequest(AsyncSocket* socket,
WriteRequest* next,
WriteCallback* callback,
uint32_t totalBytesWritten) :
socket_(socket), next_(next), callback_(callback),
totalBytesWritten_(totalBytesWritten) {}
virtual void destroy() = 0;
virtual bool performWrite() = 0;
virtual void consume() = 0;
virtual bool isComplete() = 0;
WriteRequest* getNext() const {
return next_;
}
WriteCallback* getCallback() const {
return callback_;
}
uint32_t getTotalBytesWritten() const {
return totalBytesWritten_;
}
void append(WriteRequest* next) {
assert(next_ == nullptr);
next_ = next;
}
protected:
// protected destructor, to ensure callers use destroy()
virtual ~WriteRequest() {}
AsyncSocket* socket_; ///< parent socket
WriteRequest* next_; ///< pointer to next WriteRequest
WriteCallback* callback_; ///< completion callback
uint32_t totalBytesWritten_; ///< total bytes written
};
/* The default WriteRequest implementation, used for write(), writev() and
* writeChain()
*
......@@ -181,7 +137,7 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
uint32_t bytesWritten,
unique_ptr<IOBuf>&& ioBuf,
WriteFlags flags)
: AsyncSocket::WriteRequest(socket, nullptr, callback, 0)
: AsyncSocket::WriteRequest(socket, callback)
, opCount_(opCount)
, opIndex_(0)
, flags_(flags)
......@@ -773,6 +729,17 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
}
}
void AsyncSocket::writeRequest(WriteRequest* req) {
if (writeReqTail_ == nullptr) {
assert(writeReqHead_ == nullptr);
writeReqHead_ = writeReqTail_ = req;
req->start();
} else {
writeReqTail_->append(req);
writeReqTail_ = req;
}
}
void AsyncSocket::close() {
VLOG(5) << "AsyncSocket::close(): this=" << this << ", fd_=" << fd_
<< ", state=" << state_ << ", shutdownFlags="
......
......@@ -334,6 +334,12 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
std::unique_ptr<folly::IOBuf>&& buf,
WriteFlags flags = WriteFlags::NONE) override;
class WriteRequest;
virtual void writeRequest(WriteRequest* req);
void writeRequestReady() {
handleWrite();
}
// Methods inherited from AsyncTransport
void close() override;
void closeNow() override;
......@@ -477,6 +483,60 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
ERROR
};
/**
* A WriteRequest object tracks information about a pending write operation.
*/
class WriteRequest {
public:
WriteRequest(AsyncSocket* socket, WriteCallback* callback) :
socket_(socket), callback_(callback) {}
virtual void start() {};
virtual void destroy() = 0;
virtual bool performWrite() = 0;
virtual void consume() = 0;
virtual bool isComplete() = 0;
WriteRequest* getNext() const {
return next_;
}
WriteCallback* getCallback() const {
return callback_;
}
uint32_t getTotalBytesWritten() const {
return totalBytesWritten_;
}
void append(WriteRequest* next) {
assert(next_ == nullptr);
next_ = next;
}
void fail(const char* fn, const AsyncSocketException& ex) {
socket_->failWrite(fn, ex);
}
void bytesWritten(size_t count) {
totalBytesWritten_ += count;
socket_->appBytesWritten_ += count;
}
protected:
// protected destructor, to ensure callers use destroy()
virtual ~WriteRequest() {}
AsyncSocket* socket_; ///< parent socket
WriteRequest* next_{nullptr}; ///< pointer to next WriteRequest
WriteCallback* callback_; ///< completion callback
uint32_t totalBytesWritten_{0}; ///< total bytes written
};
protected:
enum ReadResultEnum {
READ_EOF = 0,
......@@ -516,7 +576,6 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
SHUT_READ = 0x04,
};
class WriteRequest;
class BytesWriteRequest;
class WriteTimeout : public AsyncTimeout {
......
/*
* Copyright 2015 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/test/BlockingSocket.h>
#include <boost/scoped_array.hpp>
#include <poll.h>
// This is a test-only header
/* using override */
using namespace folly;
enum StateEnum {
STATE_WAITING,
STATE_SUCCEEDED,
STATE_FAILED
};
typedef std::function<void()> VoidCallback;
class ConnCallback : public AsyncSocket::ConnectCallback {
public:
ConnCallback()
: state(STATE_WAITING)
, exception(AsyncSocketException::UNKNOWN, "none") {}
void connectSuccess() noexcept override {
state = STATE_SUCCEEDED;
if (successCallback) {
successCallback();
}
}
void connectErr(const AsyncSocketException& ex) noexcept override {
state = STATE_FAILED;
exception = ex;
if (errorCallback) {
errorCallback();
}
}
StateEnum state;
AsyncSocketException exception;
VoidCallback successCallback;
VoidCallback errorCallback;
};
class WriteCallback : public AsyncTransportWrapper::WriteCallback {
public:
WriteCallback()
: state(STATE_WAITING)
, bytesWritten(0)
, exception(AsyncSocketException::UNKNOWN, "none") {}
void writeSuccess() noexcept override {
state = STATE_SUCCEEDED;
if (successCallback) {
successCallback();
}
}
void writeErr(size_t bytesWritten,
const AsyncSocketException& ex) noexcept override {
state = STATE_FAILED;
this->bytesWritten = bytesWritten;
exception = ex;
if (errorCallback) {
errorCallback();
}
}
StateEnum state;
size_t bytesWritten;
AsyncSocketException exception;
VoidCallback successCallback;
VoidCallback errorCallback;
};
class ReadCallback : public AsyncTransportWrapper::ReadCallback {
public:
ReadCallback()
: state(STATE_WAITING)
, exception(AsyncSocketException::UNKNOWN, "none")
, buffers() {}
~ReadCallback() {
for (std::vector<Buffer>::iterator it = buffers.begin();
it != buffers.end();
++it) {
it->free();
}
currentBuffer.free();
}
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
if (!currentBuffer.buffer) {
currentBuffer.allocate(4096);
}
*bufReturn = currentBuffer.buffer;
*lenReturn = currentBuffer.length;
}
void readDataAvailable(size_t len) noexcept override {
currentBuffer.length = len;
buffers.push_back(currentBuffer);
currentBuffer.reset();
if (dataAvailableCallback) {
dataAvailableCallback();
}
}
void readEOF() noexcept override {
state = STATE_SUCCEEDED;
}
void readErr(const AsyncSocketException& ex) noexcept override {
state = STATE_FAILED;
exception = ex;
}
void verifyData(const char* expected, size_t expectedLen) const {
size_t offset = 0;
for (size_t idx = 0; idx < buffers.size(); ++idx) {
const auto& buf = buffers[idx];
size_t cmpLen = std::min(buf.length, expectedLen - offset);
CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
CHECK_EQ(cmpLen, buf.length);
offset += cmpLen;
}
CHECK_EQ(offset, expectedLen);
}
class Buffer {
public:
Buffer() : buffer(nullptr), length(0) {}
Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
void reset() {
buffer = nullptr;
length = 0;
}
void allocate(size_t length) {
assert(buffer == nullptr);
this->buffer = static_cast<char*>(malloc(length));
this->length = length;
}
void free() {
::free(buffer);
reset();
}
char* buffer;
size_t length;
};
StateEnum state;
AsyncSocketException exception;
std::vector<Buffer> buffers;
Buffer currentBuffer;
VoidCallback dataAvailableCallback;
};
class ReadVerifier {
};
class TestServer {
public:
// Create a TestServer.
// This immediately starts listening on an ephemeral port.
TestServer()
: fd_(-1) {
fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
if (fd_ < 0) {
throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
"failed to create test server socket", errno);
}
if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
"failed to put test server socket in "
"non-blocking mode", errno);
}
if (listen(fd_, 10) != 0) {
throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
"failed to listen on test server socket",
errno);
}
address_.setFromLocalAddress(fd_);
// The local address will contain 0.0.0.0.
// Change it to 127.0.0.1, so it can be used to connect to the server
address_.setFromIpPort("127.0.0.1", address_.getPort());
}
// Get the address for connecting to the server
const folly::SocketAddress& getAddress() const {
return address_;
}
int acceptFD(int timeout=50) {
struct pollfd pfd;
pfd.fd = fd_;
pfd.events = POLLIN;
int ret = poll(&pfd, 1, timeout);
if (ret == 0) {
throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
"test server accept() timed out");
} else if (ret < 0) {
throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
"test server accept() poll failed", errno);
}
int acceptedFd = ::accept(fd_, nullptr, nullptr);
if (acceptedFd < 0) {
throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
"test server accept() failed", errno);
}
return acceptedFd;
}
std::shared_ptr<BlockingSocket> accept(int timeout=50) {
int fd = acceptFD(timeout);
return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
}
std::shared_ptr<AsyncSocket> acceptAsync(EventBase* evb, int timeout=50) {
int fd = acceptFD(timeout);
return AsyncSocket::newSocket(evb, fd);
}
/**
* Accept a connection, read data from it, and verify that it matches the
* data in the specified buffer.
*/
void verifyConnection(const char* buf, size_t len) {
// accept a connection
std::shared_ptr<BlockingSocket> acceptedSocket = accept();
// read the data and compare it to the specified buffer
boost::scoped_array<uint8_t> readbuf(new uint8_t[len]);
acceptedSocket->readAll(readbuf.get(), len);
CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
// make sure we get EOF next
uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
CHECK_EQ(bytesRead, 0);
}
private:
int fd_;
folly::SocketAddress address_;
};
......@@ -20,7 +20,7 @@
#include <folly/SocketAddress.h>
#include <folly/io/IOBuf.h>
#include <folly/io/async/test/BlockingSocket.h>
#include <folly/io/async/test/AsyncSocketTest.h>
#include <folly/io/async/test/Util.h>
#include <gtest/gtest.h>
......@@ -47,246 +47,6 @@ using boost::scoped_array;
using namespace folly;
enum StateEnum {
STATE_WAITING,
STATE_SUCCEEDED,
STATE_FAILED
};
typedef std::function<void()> VoidCallback;
class ConnCallback : public AsyncSocket::ConnectCallback {
public:
ConnCallback()
: state(STATE_WAITING)
, exception(AsyncSocketException::UNKNOWN, "none") {}
void connectSuccess() noexcept override {
state = STATE_SUCCEEDED;
if (successCallback) {
successCallback();
}
}
void connectErr(const AsyncSocketException& ex) noexcept override {
state = STATE_FAILED;
exception = ex;
if (errorCallback) {
errorCallback();
}
}
StateEnum state;
AsyncSocketException exception;
VoidCallback successCallback;
VoidCallback errorCallback;
};
class WriteCallback : public AsyncTransportWrapper::WriteCallback {
public:
WriteCallback()
: state(STATE_WAITING)
, bytesWritten(0)
, exception(AsyncSocketException::UNKNOWN, "none") {}
void writeSuccess() noexcept override {
state = STATE_SUCCEEDED;
if (successCallback) {
successCallback();
}
}
void writeErr(size_t bytesWritten,
const AsyncSocketException& ex) noexcept override {
state = STATE_FAILED;
this->bytesWritten = bytesWritten;
exception = ex;
if (errorCallback) {
errorCallback();
}
}
StateEnum state;
size_t bytesWritten;
AsyncSocketException exception;
VoidCallback successCallback;
VoidCallback errorCallback;
};
class ReadCallback : public AsyncTransportWrapper::ReadCallback {
public:
ReadCallback()
: state(STATE_WAITING)
, exception(AsyncSocketException::UNKNOWN, "none")
, buffers() {}
~ReadCallback() {
for (vector<Buffer>::iterator it = buffers.begin();
it != buffers.end();
++it) {
it->free();
}
currentBuffer.free();
}
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
if (!currentBuffer.buffer) {
currentBuffer.allocate(4096);
}
*bufReturn = currentBuffer.buffer;
*lenReturn = currentBuffer.length;
}
void readDataAvailable(size_t len) noexcept override {
currentBuffer.length = len;
buffers.push_back(currentBuffer);
currentBuffer.reset();
if (dataAvailableCallback) {
dataAvailableCallback();
}
}
void readEOF() noexcept override {
state = STATE_SUCCEEDED;
}
void readErr(const AsyncSocketException& ex) noexcept override {
state = STATE_FAILED;
exception = ex;
}
void verifyData(const char* expected, size_t expectedLen) const {
size_t offset = 0;
for (size_t idx = 0; idx < buffers.size(); ++idx) {
const auto& buf = buffers[idx];
size_t cmpLen = std::min(buf.length, expectedLen - offset);
CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
CHECK_EQ(cmpLen, buf.length);
offset += cmpLen;
}
CHECK_EQ(offset, expectedLen);
}
class Buffer {
public:
Buffer() : buffer(nullptr), length(0) {}
Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
void reset() {
buffer = nullptr;
length = 0;
}
void allocate(size_t length) {
assert(buffer == nullptr);
this->buffer = static_cast<char*>(malloc(length));
this->length = length;
}
void free() {
::free(buffer);
reset();
}
char* buffer;
size_t length;
};
StateEnum state;
AsyncSocketException exception;
vector<Buffer> buffers;
Buffer currentBuffer;
VoidCallback dataAvailableCallback;
};
class ReadVerifier {
};
class TestServer {
public:
// Create a TestServer.
// This immediately starts listening on an ephemeral port.
TestServer()
: fd_(-1) {
fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
if (fd_ < 0) {
throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
"failed to create test server socket", errno);
}
if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
"failed to put test server socket in "
"non-blocking mode", errno);
}
if (listen(fd_, 10) != 0) {
throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
"failed to listen on test server socket",
errno);
}
address_.setFromLocalAddress(fd_);
// The local address will contain 0.0.0.0.
// Change it to 127.0.0.1, so it can be used to connect to the server
address_.setFromIpPort("127.0.0.1", address_.getPort());
}
// Get the address for connecting to the server
const folly::SocketAddress& getAddress() const {
return address_;
}
int acceptFD(int timeout=50) {
struct pollfd pfd;
pfd.fd = fd_;
pfd.events = POLLIN;
int ret = poll(&pfd, 1, timeout);
if (ret == 0) {
throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
"test server accept() timed out");
} else if (ret < 0) {
throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
"test server accept() poll failed", errno);
}
int acceptedFd = ::accept(fd_, nullptr, nullptr);
if (acceptedFd < 0) {
throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
"test server accept() failed", errno);
}
return acceptedFd;
}
std::shared_ptr<BlockingSocket> accept(int timeout=50) {
int fd = acceptFD(timeout);
return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
}
std::shared_ptr<AsyncSocket> acceptAsync(EventBase* evb, int timeout=50) {
int fd = acceptFD(timeout);
return AsyncSocket::newSocket(evb, fd);
}
/**
* Accept a connection, read data from it, and verify that it matches the
* data in the specified buffer.
*/
void verifyConnection(const char* buf, size_t len) {
// accept a connection
std::shared_ptr<BlockingSocket> acceptedSocket = accept();
// read the data and compare it to the specified buffer
scoped_array<uint8_t> readbuf(new uint8_t[len]);
acceptedSocket->readAll(readbuf.get(), len);
CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
// make sure we get EOF next
uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
CHECK_EQ(bytesRead, 0);
}
private:
int fd_;
folly::SocketAddress address_;
};
class DelayedWrite: public AsyncTimeout {
public:
DelayedWrite(const std::shared_ptr<AsyncSocket>& socket,
......
/*
* Copyright 2015 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/wangle/channel/FileRegion.h>
using namespace folly;
using namespace folly::wangle;
namespace {
struct FileRegionReadPool {};
Singleton<IOThreadPoolExecutor, FileRegionReadPool> readPool(
[]{
return new IOThreadPoolExecutor(
sysconf(_SC_NPROCESSORS_ONLN),
std::make_shared<NamedThreadFactory>("FileRegionReadPool"));
});
}
namespace folly { namespace wangle {
FileRegion::FileWriteRequest::FileWriteRequest(AsyncSocket* socket,
WriteCallback* callback, int fd, off_t offset, size_t count)
: WriteRequest(socket, callback),
readFd_(fd), offset_(offset), count_(count) {
}
void FileRegion::FileWriteRequest::destroy() {
readBase_->runInEventBaseThread([this]{
delete this;
});
}
bool FileRegion::FileWriteRequest::performWrite() {
if (!started_) {
start();
return true;
}
int flags = SPLICE_F_NONBLOCK | SPLICE_F_MORE;
ssize_t spliced = ::splice(pipe_out_, nullptr,
socket_->getFd(), nullptr,
bytesInPipe_, flags);
if (spliced == -1) {
if (errno == EAGAIN) {
return true;
}
return false;
}
bytesInPipe_ -= spliced;
bytesWritten(spliced);
return true;
}
void FileRegion::FileWriteRequest::consume() {
// do nothing
}
bool FileRegion::FileWriteRequest::isComplete() {
return totalBytesWritten_ == count_;
}
void FileRegion::FileWriteRequest::messageAvailable(size_t&& count) {
bool shouldWrite = bytesInPipe_ == 0;
bytesInPipe_ += count;
if (shouldWrite) {
socket_->writeRequestReady();
}
}
#ifdef __GLIBC__
# if (__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 9))
# define GLIBC_AT_LEAST_2_9 1
# endif
#endif
void FileRegion::FileWriteRequest::start() {
started_ = true;
readBase_ = readPool.get()->getEventBase();
readBase_->runInEventBaseThread([this]{
auto flags = fcntl(readFd_, F_GETFL);
if (flags == -1) {
fail(__func__, AsyncSocketException(
AsyncSocketException::INTERNAL_ERROR,
"fcntl F_GETFL failed", errno));
return;
}
flags &= O_ACCMODE;
if (flags == O_WRONLY) {
fail(__func__, AsyncSocketException(
AsyncSocketException::BAD_ARGS, "file not open for reading"));
return;
}
#ifndef GLIBC_AT_LEAST_2_9
fail(__func__, AsyncSocketException(
AsyncSocketException::NOT_SUPPORTED,
"writeFile unsupported on glibc < 2.9"));
return;
#else
int pipeFds[2];
if (::pipe2(pipeFds, O_NONBLOCK) == -1) {
fail(__func__, AsyncSocketException(
AsyncSocketException::INTERNAL_ERROR,
"pipe2 failed", errno));
return;
}
// Max size for unprevileged processes as set in /proc/sys/fs/pipe-max-size
// Ignore failures and just roll with it
// TODO maybe read max size from /proc?
fcntl(pipeFds[0], F_SETPIPE_SZ, 1048576);
fcntl(pipeFds[1], F_SETPIPE_SZ, 1048576);
pipe_out_ = pipeFds[0];
socket_->getEventBase()->runInEventBaseThreadAndWait([&]{
startConsuming(socket_->getEventBase(), &queue_);
});
readHandler_ = folly::make_unique<FileReadHandler>(
this, pipeFds[1], count_);
#endif
});
}
FileRegion::FileWriteRequest::~FileWriteRequest() {
CHECK(readBase_->isInEventBaseThread());
socket_->getEventBase()->runInEventBaseThreadAndWait([&]{
stopConsuming();
if (pipe_out_ > -1) {
::close(pipe_out_);
}
});
}
void FileRegion::FileWriteRequest::fail(
const char* fn,
const AsyncSocketException& ex) {
socket_->getEventBase()->runInEventBaseThread([=]{
WriteRequest::fail(fn, ex);
});
}
FileRegion::FileWriteRequest::FileReadHandler::FileReadHandler(
FileWriteRequest* req, int pipe_in, size_t bytesToRead)
: req_(req), pipe_in_(pipe_in), bytesToRead_(bytesToRead) {
CHECK(req_->readBase_->isInEventBaseThread());
initHandler(req_->readBase_, pipe_in);
if (!registerHandler(EventFlags::WRITE | EventFlags::PERSIST)) {
req_->fail(__func__, AsyncSocketException(
AsyncSocketException::INTERNAL_ERROR,
"registerHandler failed"));
}
}
FileRegion::FileWriteRequest::FileReadHandler::~FileReadHandler() {
CHECK(req_->readBase_->isInEventBaseThread());
unregisterHandler();
::close(pipe_in_);
}
void FileRegion::FileWriteRequest::FileReadHandler::handlerReady(
uint16_t events) noexcept {
CHECK(events & EventHandler::WRITE);
if (bytesToRead_ == 0) {
unregisterHandler();
return;
}
int flags = SPLICE_F_NONBLOCK | SPLICE_F_MORE;
ssize_t spliced = ::splice(req_->readFd_, &req_->offset_,
pipe_in_, nullptr,
bytesToRead_, flags);
if (spliced == -1) {
if (errno == EAGAIN) {
return;
} else {
req_->fail(__func__, AsyncSocketException(
AsyncSocketException::INTERNAL_ERROR,
"splice failed", errno));
return;
}
}
if (spliced > 0) {
bytesToRead_ -= spliced;
try {
req_->queue_.putMessage(static_cast<size_t>(spliced));
} catch (...) {
req_->fail(__func__, AsyncSocketException(
AsyncSocketException::INTERNAL_ERROR,
"putMessage failed"));
return;
}
}
}
}} // folly::wangle
/*
* Copyright 2015 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <folly/Singleton.h>
#include <folly/io/async/AsyncTransport.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/NotificationQueue.h>
#include <folly/futures/Future.h>
#include <folly/futures/Promise.h>
#include <folly/wangle/concurrent/IOThreadPoolExecutor.h>
namespace folly { namespace wangle {
class FileRegion {
public:
FileRegion(int fd, off_t offset, size_t count)
: fd_(fd), offset_(offset), count_(count) {}
Future<void> transferTo(std::shared_ptr<AsyncTransport> transport) {
auto socket = std::dynamic_pointer_cast<AsyncSocket>(
transport);
CHECK(socket);
auto cb = new WriteCallback();
auto f = cb->promise_.getFuture();
auto req = new FileWriteRequest(socket.get(), cb, fd_, offset_, count_);
socket->writeRequest(req);
return f;
}
private:
class WriteCallback : private AsyncSocket::WriteCallback {
void writeSuccess() noexcept override {
promise_.setValue();
delete this;
}
void writeErr(size_t bytesWritten,
const AsyncSocketException& ex)
noexcept override {
promise_.setException(ex);
delete this;
}
friend class FileRegion;
folly::Promise<void> promise_;
};
const int fd_;
const off_t offset_;
const size_t count_;
class FileWriteRequest : public AsyncSocket::WriteRequest,
public NotificationQueue<size_t>::Consumer {
public:
FileWriteRequest(AsyncSocket* socket, WriteCallback* callback,
int fd, off_t offset, size_t count);
void destroy() override;
bool performWrite() override;
void consume() override;
bool isComplete() override;
void messageAvailable(size_t&& count) override;
void start() override;
class FileReadHandler : public folly::EventHandler {
public:
FileReadHandler(FileWriteRequest* req, int pipe_in, size_t bytesToRead);
~FileReadHandler();
void handlerReady(uint16_t events) noexcept override;
private:
FileWriteRequest* req_;
int pipe_in_;
size_t bytesToRead_;
};
private:
~FileWriteRequest();
void fail(const char* fn, const AsyncSocketException& ex);
const int readFd_;
off_t offset_;
const size_t count_;
bool started_{false};
int pipe_out_{-1};
size_t bytesInPipe_{0};
folly::EventBase* readBase_;
folly::NotificationQueue<size_t> queue_;
std::unique_ptr<FileReadHandler> readHandler_;
};
};
}} // folly::wangle
/*
* Copyright 2015 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/wangle/channel/FileRegion.h>
#include <folly/io/async/test/AsyncSocketTest.h>
#include <gtest/gtest.h>
using namespace folly;
using namespace folly::wangle;
using namespace testing;
struct FileRegionTest : public Test {
FileRegionTest() {
// Connect
socket = AsyncSocket::newSocket(&evb);
socket->connect(&ccb, server.getAddress(), 30);
// Accept the connection
acceptedSocket = server.acceptAsync(&evb);
acceptedSocket->setReadCB(&rcb);
// Create temp file
char path[] = "/tmp/AsyncSocketTest.WriteFile.XXXXXX";
fd = mkostemp(path, O_RDWR);
EXPECT_TRUE(fd > 0);
EXPECT_EQ(0, unlink(path));
}
~FileRegionTest() {
// Close up shop
close(fd);
acceptedSocket->close();
socket->close();
}
TestServer server;
EventBase evb;
std::shared_ptr<AsyncSocket> socket;
std::shared_ptr<AsyncSocket> acceptedSocket;
ConnCallback ccb;
ReadCallback rcb;
int fd;
};
TEST_F(FileRegionTest, Basic) {
size_t count = 1000000000; // 1 GB
void* zeroBuf = calloc(1, count);
write(fd, zeroBuf, count);
FileRegion fileRegion(fd, 0, count);
auto f = fileRegion.transferTo(socket);
try {
f.getVia(&evb);
} catch (std::exception& e) {
LOG(FATAL) << exceptionStr(e);
}
// Let the reads run to completion
socket->shutdownWrite();
evb.loop();
ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
size_t receivedBytes = 0;
for (auto& buf : rcb.buffers) {
receivedBytes += buf.length;
ASSERT_EQ(memcmp(buf.buffer, zeroBuf, buf.length), 0);
}
ASSERT_EQ(receivedBytes, count);
}
TEST_F(FileRegionTest, Repeated) {
size_t count = 1000000;
void* zeroBuf = calloc(1, count);
write(fd, zeroBuf, count);
int sendCount = 1000;
FileRegion fileRegion(fd, 0, count);
std::vector<Future<void>> fs;
for (int i = 0; i < sendCount; i++) {
fs.push_back(fileRegion.transferTo(socket));
}
auto f = collect(fs);
ASSERT_NO_THROW(f.getVia(&evb));
// Let the reads run to completion
socket->shutdownWrite();
evb.loop();
ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
size_t receivedBytes = 0;
for (auto& buf : rcb.buffers) {
receivedBytes += buf.length;
}
ASSERT_EQ(receivedBytes, sendCount*count);
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment