Commit 7749a469 authored by Yang Chi's avatar Yang Chi Committed by facebook-github-bot-9

Add a buffer callback to AsyncSocket

Summary: This is probably easier than D2612490. The idea is just to add a callback to write, writev and writeChain in AsyncSocket, so upper layer can know when data starts to buffer up

Reviewed By: mzlee

Differential Revision: D2623385

fb-gh-sync-id: 98d32ca83871aaa4f6c75a769b5f1bf0b5d62c3e
parent 596d6f6d
......@@ -148,7 +148,8 @@ void AsyncPipeWriter::write(unique_ptr<folly::IOBuf> buf,
void AsyncPipeWriter::writeChain(folly::AsyncWriter::WriteCallback* callback,
std::unique_ptr<folly::IOBuf>&& buf,
WriteFlags) {
WriteFlags,
BufferCallback*) {
write(std::move(buf), callback);
}
......
......@@ -148,16 +148,19 @@ class AsyncPipeWriter : public EventHandler,
// AsyncWriter methods
void write(folly::AsyncWriter::WriteCallback* callback, const void* buf,
size_t bytes, WriteFlags flags = WriteFlags::NONE) override {
writeChain(callback, IOBuf::wrapBuffer(buf, bytes), flags);
size_t bytes, WriteFlags flags = WriteFlags::NONE,
BufferCallback* bufCallback = nullptr) override {
writeChain(callback, IOBuf::wrapBuffer(buf, bytes), flags, bufCallback);
}
void writev(folly::AsyncWriter::WriteCallback*, const iovec*,
size_t, WriteFlags = WriteFlags::NONE) override {
size_t, WriteFlags = WriteFlags::NONE,
BufferCallback* = nullptr) override {
throw std::runtime_error("writev is not supported. Please use writeChain.");
}
void writeChain(folly::AsyncWriter::WriteCallback* callback,
std::unique_ptr<folly::IOBuf>&& buf,
WriteFlags flags = WriteFlags::NONE) override;
WriteFlags flags = WriteFlags::NONE,
BufferCallback* bufCallback = nullptr) override;
private:
void handlerReady(uint16_t events) noexcept override;
......
......@@ -63,14 +63,16 @@ const AsyncSocketException socketShutdownForWritesEx(
*/
class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
public:
static BytesWriteRequest* newRequest(AsyncSocket* socket,
WriteCallback* callback,
const iovec* ops,
uint32_t opCount,
uint32_t partialWritten,
uint32_t bytesWritten,
unique_ptr<IOBuf>&& ioBuf,
WriteFlags flags) {
static BytesWriteRequest* newRequest(
AsyncSocket* socket,
WriteCallback* callback,
const iovec* ops,
uint32_t opCount,
uint32_t partialWritten,
uint32_t bytesWritten,
unique_ptr<IOBuf>&& ioBuf,
WriteFlags flags,
BufferCallback* bufferCallback = nullptr) {
assert(opCount > 0);
// Since we put a variable size iovec array at the end
// of each BytesWriteRequest, we have to manually allocate the memory.
......@@ -82,7 +84,7 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
return new(buf) BytesWriteRequest(socket, callback, ops, opCount,
partialWritten, bytesWritten,
std::move(ioBuf), flags);
std::move(ioBuf), flags, bufferCallback);
}
void destroy() override {
......@@ -136,8 +138,9 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
uint32_t partialBytes,
uint32_t bytesWritten,
unique_ptr<IOBuf>&& ioBuf,
WriteFlags flags)
: AsyncSocket::WriteRequest(socket, callback)
WriteFlags flags,
BufferCallback* bufferCallback = nullptr)
: AsyncSocket::WriteRequest(socket, callback, bufferCallback)
, opCount_(opCount)
, opIndex_(0)
, flags_(flags)
......@@ -608,43 +611,46 @@ AsyncSocket::ReadCallback* AsyncSocket::getReadCallback() const {
}
void AsyncSocket::write(WriteCallback* callback,
const void* buf, size_t bytes, WriteFlags flags) {
const void* buf, size_t bytes, WriteFlags flags,
BufferCallback* bufCallback) {
iovec op;
op.iov_base = const_cast<void*>(buf);
op.iov_len = bytes;
writeImpl(callback, &op, 1, unique_ptr<IOBuf>(), flags);
writeImpl(callback, &op, 1, unique_ptr<IOBuf>(), flags, bufCallback);
}
void AsyncSocket::writev(WriteCallback* callback,
const iovec* vec,
size_t count,
WriteFlags flags) {
writeImpl(callback, vec, count, unique_ptr<IOBuf>(), flags);
WriteFlags flags,
BufferCallback* bufCallback) {
writeImpl(callback, vec, count, unique_ptr<IOBuf>(), flags, bufCallback);
}
void AsyncSocket::writeChain(WriteCallback* callback, unique_ptr<IOBuf>&& buf,
WriteFlags flags) {
WriteFlags flags, BufferCallback* bufCallback) {
constexpr size_t kSmallSizeMax = 64;
size_t count = buf->countChainElements();
if (count <= kSmallSizeMax) {
iovec vec[BOOST_PP_IF(FOLLY_HAVE_VLA, count, kSmallSizeMax)];
writeChainImpl(callback, vec, count, std::move(buf), flags);
writeChainImpl(callback, vec, count, std::move(buf), flags, bufCallback);
} else {
iovec* vec = new iovec[count];
writeChainImpl(callback, vec, count, std::move(buf), flags);
writeChainImpl(callback, vec, count, std::move(buf), flags, bufCallback);
delete[] vec;
}
}
void AsyncSocket::writeChainImpl(WriteCallback* callback, iovec* vec,
size_t count, unique_ptr<IOBuf>&& buf, WriteFlags flags) {
size_t count, unique_ptr<IOBuf>&& buf, WriteFlags flags,
BufferCallback* bufCallback) {
size_t veclen = buf->fillIov(vec, count);
writeImpl(callback, vec, veclen, std::move(buf), flags);
writeImpl(callback, vec, veclen, std::move(buf), flags, bufCallback);
}
void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
size_t count, unique_ptr<IOBuf>&& buf,
WriteFlags flags) {
WriteFlags flags, BufferCallback* bufCallback) {
VLOG(6) << "AsyncSocket::writev() this=" << this << ", fd=" << fd_
<< ", callback=" << callback << ", count=" << count
<< ", state=" << state_;
......@@ -688,7 +694,11 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
callback->writeSuccess();
}
return;
} // else { continue writing the next writeReq }
} else { // continue writing the next writeReq
if (bufCallback) {
bufCallback->onEgressBuffered();
}
}
mustRegister = true;
}
} else if (!connecting()) {
......@@ -701,7 +711,8 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
try {
req = BytesWriteRequest::newRequest(this, callback, vec + countWritten,
count - countWritten, partialWritten,
bytesWritten, std::move(ioBuf), flags);
bytesWritten, std::move(ioBuf), flags,
bufCallback);
} catch (const std::exception& ex) {
// we mainly expect to catch std::bad_alloc here
AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
......@@ -1473,6 +1484,11 @@ void AsyncSocket::handleWrite() noexcept {
}
// We'll continue around the loop, trying to write another request
} else {
// Notify BufferCallback:
BufferCallback* bufferCallback = writeReqHead_->getBufferCallback();
if (bufferCallback) {
bufferCallback->onEgressBuffered();
}
// Partial write.
writeReqHead_->consume();
// Stop after a partial write; it's highly likely that a subsequent write
......
......@@ -328,12 +328,15 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
ReadCallback* getReadCallback() const override;
void write(WriteCallback* callback, const void* buf, size_t bytes,
WriteFlags flags = WriteFlags::NONE) override;
WriteFlags flags = WriteFlags::NONE,
BufferCallback* bufCallback = nullptr) override;
void writev(WriteCallback* callback, const iovec* vec, size_t count,
WriteFlags flags = WriteFlags::NONE) override;
WriteFlags flags = WriteFlags::NONE,
BufferCallback* bufCallback = nullptr) override;
void writeChain(WriteCallback* callback,
std::unique_ptr<folly::IOBuf>&& buf,
WriteFlags flags = WriteFlags::NONE) override;
WriteFlags flags = WriteFlags::NONE,
BufferCallback* bufCallback = nullptr) override;
class WriteRequest;
virtual void writeRequest(WriteRequest* req);
......@@ -507,8 +510,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
*/
class WriteRequest {
public:
WriteRequest(AsyncSocket* socket, WriteCallback* callback) :
socket_(socket), callback_(callback) {}
WriteRequest(
AsyncSocket* socket,
WriteCallback* callback,
BufferCallback* bufferCallback = nullptr) :
socket_(socket), callback_(callback), bufferCallback_(bufferCallback) {}
virtual void start() {};
......@@ -546,6 +552,10 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
socket_->appBytesWritten_ += count;
}
BufferCallback* getBufferCallback() const {
return bufferCallback_;
}
protected:
// protected destructor, to ensure callers use destroy()
virtual ~WriteRequest() {}
......@@ -554,6 +564,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
WriteRequest* next_{nullptr}; ///< pointer to next WriteRequest
WriteCallback* callback_; ///< completion callback
uint32_t totalBytesWritten_{0}; ///< total bytes written
BufferCallback* bufferCallback_{nullptr};
};
protected:
......@@ -677,36 +688,39 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
/**
* Populate an iovec array from an IOBuf and attempt to write it.
*
* @param callback Write completion/error callback.
* @param vec Target iovec array; caller retains ownership.
* @param count Number of IOBufs to write, beginning at start of buf.
* @param buf Chain of iovecs.
* @param flags set of flags for the underlying write calls, like cork
* @param callback Write completion/error callback.
* @param vec Target iovec array; caller retains ownership.
* @param count Number of IOBufs to write, beginning at start of buf.
* @param buf Chain of iovecs.
* @param flags set of flags for the underlying write calls, like cork
* @param bufCallback Callback when egress data begins to buffer
*/
void writeChainImpl(WriteCallback* callback, iovec* vec,
size_t count, std::unique_ptr<folly::IOBuf>&& buf,
WriteFlags flags);
WriteFlags flags, BufferCallback* bufCallback = nullptr);
/**
* Write as much data as possible to the socket without blocking,
* and queue up any leftover data to send when the socket can
* handle writes again.
*
* @param callback The callback to invoke when the write is completed.
* @param vec Array of buffers to write; this method will make a
* copy of the vector (but not the buffers themselves)
* if the write has to be completed asynchronously.
* @param count Number of elements in vec.
* @param buf The IOBuf that manages the buffers referenced by
* vec, or a pointer to nullptr if the buffers are not
* associated with an IOBuf. Note that ownership of
* the IOBuf is transferred here; upon completion of
* the write, the AsyncSocket deletes the IOBuf.
* @param flags Set of write flags.
* @param callback The callback to invoke when the write is completed.
* @param vec Array of buffers to write; this method will make a
* copy of the vector (but not the buffers themselves)
* if the write has to be completed asynchronously.
* @param count Number of elements in vec.
* @param buf The IOBuf that manages the buffers referenced by
* vec, or a pointer to nullptr if the buffers are not
* associated with an IOBuf. Note that ownership of
* the IOBuf is transferred here; upon completion of
* the write, the AsyncSocket deletes the IOBuf.
* @param flags Set of write flags.
* @param bufCallback Callback when egress data buffers up
*/
void writeImpl(WriteCallback* callback, const iovec* vec, size_t count,
std::unique_ptr<folly::IOBuf>&& buf,
WriteFlags flags = WriteFlags::NONE);
WriteFlags flags = WriteFlags::NONE,
BufferCallback* bufCallback = nullptr);
/**
* Attempt to write to the socket.
......
......@@ -464,6 +464,12 @@ class AsyncReader {
class AsyncWriter {
public:
class BufferCallback {
public:
virtual ~BufferCallback() {}
virtual void onEgressBuffered() = 0;
};
class WriteCallback {
public:
virtual ~WriteCallback() = default;
......@@ -493,12 +499,15 @@ class AsyncWriter {
// Write methods that aren't part of AsyncTransport
virtual void write(WriteCallback* callback, const void* buf, size_t bytes,
WriteFlags flags = WriteFlags::NONE) = 0;
WriteFlags flags = WriteFlags::NONE,
BufferCallback* bufCallback = nullptr) = 0;
virtual void writev(WriteCallback* callback, const iovec* vec, size_t count,
WriteFlags flags = WriteFlags::NONE) = 0;
WriteFlags flags = WriteFlags::NONE,
BufferCallback* bufCallback = nullptr) = 0;
virtual void writeChain(WriteCallback* callback,
std::unique_ptr<IOBuf>&& buf,
WriteFlags flags = WriteFlags::NONE) = 0;
WriteFlags flags = WriteFlags::NONE,
BufferCallback* bufCallback = nullptr) = 0;
protected:
virtual ~AsyncWriter() = default;
......@@ -516,15 +525,19 @@ class AsyncTransportWrapper : virtual public AsyncTransport,
// to keep compatibility.
using ReadCallback = AsyncReader::ReadCallback;
using WriteCallback = AsyncWriter::WriteCallback;
using BufferCallback = AsyncWriter::BufferCallback;
virtual void setReadCB(ReadCallback* callback) override = 0;
virtual ReadCallback* getReadCallback() const override = 0;
virtual void write(WriteCallback* callback, const void* buf, size_t bytes,
WriteFlags flags = WriteFlags::NONE) override = 0;
WriteFlags flags = WriteFlags::NONE,
BufferCallback* bufCallback = nullptr) override = 0;
virtual void writev(WriteCallback* callback, const iovec* vec, size_t count,
WriteFlags flags = WriteFlags::NONE) override = 0;
WriteFlags flags = WriteFlags::NONE,
BufferCallback* bufCallback = nullptr) override = 0;
virtual void writeChain(WriteCallback* callback,
std::unique_ptr<IOBuf>&& buf,
WriteFlags flags = WriteFlags::NONE) override = 0;
WriteFlags flags = WriteFlags::NONE,
BufferCallback* bufCallback = nullptr) override = 0;
/**
* The transport wrapper may wrap another transport. This returns the
* transport that is wrapped. It returns nullptr if there is no wrapped
......
......@@ -60,6 +60,23 @@ class ConnCallback : public AsyncSocket::ConnectCallback {
VoidCallback errorCallback;
};
class BufferCallback : public AsyncTransportWrapper::BufferCallback {
public:
BufferCallback()
: buffered_(false) {}
void onEgressBuffered() override {
buffered_ = true;
}
bool hasBuffered() const {
return buffered_;
}
private:
bool buffered_{false};
};
class WriteCallback : public AsyncTransportWrapper::WriteCallback {
public:
WriteCallback()
......
......@@ -2238,3 +2238,32 @@ TEST(AsyncSocketTest, NumPendingMessagesInQueue) {
eventBase.loop();
}
TEST(AsyncSocketTest, BufferTest) {
TestServer server;
EventBase evb;
AsyncSocket::OptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30, option);
char buf[100 * 1024];
memset(buf, 'c', sizeof(buf));
WriteCallback wcb;
BufferCallback bcb;
socket->write(&wcb, buf, sizeof(buf), WriteFlags::NONE, &bcb);
evb.loop();
CHECK_EQ(ccb.state, STATE_SUCCEEDED);
CHECK_EQ(wcb.state, STATE_SUCCEEDED);
ASSERT_TRUE(bcb.hasBuffered());
socket->close();
server.verifyConnection(buf, sizeof(buf));
ASSERT_TRUE(socket->isClosedBySelf());
ASSERT_FALSE(socket->isClosedByPeer());
}
......@@ -27,23 +27,31 @@ class MockAsyncTransport: public AsyncTransportWrapper {
MOCK_METHOD1(setReadCB, void(ReadCallback*));
MOCK_CONST_METHOD0(getReadCallback, ReadCallback*());
MOCK_CONST_METHOD0(getReadCB, ReadCallback*());
MOCK_METHOD4(write, void(WriteCallback*,
MOCK_METHOD5(write, void(WriteCallback*,
const void*, size_t,
WriteFlags));
MOCK_METHOD4(writev, void(WriteCallback*,
WriteFlags,
BufferCallback*));
MOCK_METHOD5(writev, void(WriteCallback*,
const iovec*, size_t,
WriteFlags));
MOCK_METHOD3(writeChain,
WriteFlags,
BufferCallback*));
MOCK_METHOD4(writeChain,
void(WriteCallback*,
std::shared_ptr<folly::IOBuf>,
WriteFlags));
WriteFlags,
BufferCallback*));
void writeChain(WriteCallback* callback,
std::unique_ptr<folly::IOBuf>&& iob,
WriteFlags flags =
WriteFlags::NONE) override {
writeChain(callback, std::shared_ptr<folly::IOBuf>(iob.release()), flags);
WriteFlags::NONE,
BufferCallback* bufCB = nullptr) override {
writeChain(
callback,
std::shared_ptr<folly::IOBuf>(iob.release()),
flags,
bufCB);
}
MOCK_METHOD0(close, void());
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment