Commit 7728c4b3 authored by Philip Pronin's avatar Philip Pronin Committed by Facebook Github Bot

return range from AsyncIO::cancel(), fix test

Summary:
Return not just number of cancelled ops, but all of them as well.
Test was incorrectly assuming `wait(1)` will return exactly one operation, fix
that as well.

Reviewed By: ot

Differential Revision: D5054684

fbshipit-source-id: 1c53c3f7ba855d1fcfeac8b1b27f90f0872d2c21
parent f0daf647
...@@ -193,15 +193,13 @@ Range<AsyncIO::Op**> AsyncIO::wait(size_t minRequests) { ...@@ -193,15 +193,13 @@ Range<AsyncIO::Op**> AsyncIO::wait(size_t minRequests) {
CHECK_EQ(pollFd_, -1) << "wait() only allowed on non-pollable object"; CHECK_EQ(pollFd_, -1) << "wait() only allowed on non-pollable object";
auto p = pending_.load(std::memory_order_acquire); auto p = pending_.load(std::memory_order_acquire);
CHECK_LE(minRequests, p); CHECK_LE(minRequests, p);
doWait(WaitType::COMPLETE, minRequests, p, &completed_); return doWait(WaitType::COMPLETE, minRequests, p, completed_);
return Range<Op**>(completed_.data(), completed_.size());
} }
size_t AsyncIO::cancel() { Range<AsyncIO::Op**> AsyncIO::cancel() {
CHECK(ctx_); CHECK(ctx_);
auto p = pending_.load(std::memory_order_acquire); auto p = pending_.load(std::memory_order_acquire);
doWait(WaitType::CANCEL, p, p, nullptr); return doWait(WaitType::CANCEL, p, p, canceled_);
return p;
} }
Range<AsyncIO::Op**> AsyncIO::pollCompleted() { Range<AsyncIO::Op**> AsyncIO::pollCompleted() {
...@@ -224,15 +222,14 @@ Range<AsyncIO::Op**> AsyncIO::pollCompleted() { ...@@ -224,15 +222,14 @@ Range<AsyncIO::Op**> AsyncIO::pollCompleted() {
DCHECK_LE(numEvents, pending_); DCHECK_LE(numEvents, pending_);
// Don't reap more than numEvents, as we've just reset the counter to 0. // Don't reap more than numEvents, as we've just reset the counter to 0.
doWait(WaitType::COMPLETE, numEvents, numEvents, &completed_); return doWait(WaitType::COMPLETE, numEvents, numEvents, completed_);
return Range<Op**>(completed_.data(), completed_.size());
} }
void AsyncIO::doWait( Range<AsyncIO::Op**> AsyncIO::doWait(
WaitType type, WaitType type,
size_t minRequests, size_t minRequests,
size_t maxRequests, size_t maxRequests,
std::vector<Op*>* result) { std::vector<Op*>& result) {
io_event events[maxRequests]; io_event events[maxRequests];
// Unfortunately, Linux AIO doesn't implement io_cancel, so even for // Unfortunately, Linux AIO doesn't implement io_cancel, so even for
...@@ -257,9 +254,7 @@ void AsyncIO::doWait( ...@@ -257,9 +254,7 @@ void AsyncIO::doWait(
} while (count < minRequests); } while (count < minRequests);
DCHECK_LE(count, maxRequests); DCHECK_LE(count, maxRequests);
if (result != nullptr) { result.clear();
result->clear();
}
for (size_t i = 0; i < count; ++i) { for (size_t i = 0; i < count; ++i) {
DCHECK(events[i].obj); DCHECK(events[i].obj);
Op* op = boost::intrusive::get_parent_from_member( Op* op = boost::intrusive::get_parent_from_member(
...@@ -273,10 +268,10 @@ void AsyncIO::doWait( ...@@ -273,10 +268,10 @@ void AsyncIO::doWait(
op->cancel(); op->cancel();
break; break;
} }
if (result != nullptr) { result.push_back(op);
result->push_back(op);
}
} }
return range(result);
} }
AsyncIOQueue::AsyncIOQueue(AsyncIO* asyncIO) AsyncIOQueue::AsyncIOQueue(AsyncIO* asyncIO)
......
...@@ -156,9 +156,10 @@ class AsyncIO : private boost::noncopyable { ...@@ -156,9 +156,10 @@ class AsyncIO : private boost::noncopyable {
Range<Op**> wait(size_t minRequests); Range<Op**> wait(size_t minRequests);
/** /**
* Cancel all pending requests and return their number. * Cancel all pending requests and return them; the returned range is
* valid until the next call to cancel().
*/ */
size_t cancel(); Range<Op**> cancel();
/** /**
* Return the number of pending requests. * Return the number of pending requests.
...@@ -201,11 +202,11 @@ class AsyncIO : private boost::noncopyable { ...@@ -201,11 +202,11 @@ class AsyncIO : private boost::noncopyable {
void initializeContext(); void initializeContext();
enum class WaitType { COMPLETE, CANCEL }; enum class WaitType { COMPLETE, CANCEL };
void doWait( Range<AsyncIO::Op**> doWait(
WaitType type, WaitType type,
size_t minRequests, size_t minRequests,
size_t maxRequests, size_t maxRequests,
std::vector<Op*>* result); std::vector<Op*>& result);
io_context_t ctx_{nullptr}; io_context_t ctx_{nullptr};
std::atomic<bool> ctxSet_{false}; std::atomic<bool> ctxSet_{false};
...@@ -216,6 +217,7 @@ class AsyncIO : private boost::noncopyable { ...@@ -216,6 +217,7 @@ class AsyncIO : private boost::noncopyable {
const size_t capacity_; const size_t capacity_;
int pollFd_{-1}; int pollFd_{-1};
std::vector<Op*> completed_; std::vector<Op*> completed_;
std::vector<Op*> canceled_;
}; };
/** /**
......
...@@ -393,48 +393,63 @@ TEST(AsyncIO, NonBlockingWait) { ...@@ -393,48 +393,63 @@ TEST(AsyncIO, NonBlockingWait) {
} }
TEST(AsyncIO, Cancel) { TEST(AsyncIO, Cancel) {
constexpr size_t kNumOps = 10; constexpr size_t kNumOpsBatch1 = 10;
constexpr size_t kNumOpsBatch2 = 10;
AsyncIO aioReader(kNumOps, AsyncIO::NOT_POLLABLE); AsyncIO aioReader(kNumOpsBatch1 + kNumOpsBatch2, AsyncIO::NOT_POLLABLE);
int fd = ::open(tempFile.path().c_str(), O_DIRECT | O_RDONLY); int fd = ::open(tempFile.path().c_str(), O_DIRECT | O_RDONLY);
PCHECK(fd != -1); PCHECK(fd != -1);
SCOPE_EXIT { SCOPE_EXIT {
::close(fd); ::close(fd);
}; };
std::vector<AsyncIO::Op> ops(kNumOps);
std::vector<ManagedBuffer> bufs;
size_t completed = 0; size_t completed = 0;
for (auto& op : ops) {
std::vector<std::unique_ptr<AsyncIO::Op>> ops;
std::vector<ManagedBuffer> bufs;
const auto schedule = [&](size_t n) {
for (size_t i = 0; i < n; ++i) {
const size_t size = 2 * kAlign; const size_t size = 2 * kAlign;
bufs.push_back(allocateAligned(size)); bufs.push_back(allocateAligned(size));
ops.push_back(std::make_unique<AsyncIO::Op>());
auto& op = *ops.back();
op.setNotificationCallback([&](AsyncIOOp*) { ++completed; }); op.setNotificationCallback([&](AsyncIOOp*) { ++completed; });
op.pread(fd, bufs.back().get(), size, 0); op.pread(fd, bufs.back().get(), size, 0);
aioReader.submit(&op); aioReader.submit(&op);
} }
};
// Mix completed and canceled operations for this test.
// In order to achieve that, schedule in two batches and do partial
// wait() after the first one.
EXPECT_EQ(aioReader.pending(), kNumOps); schedule(kNumOpsBatch1);
EXPECT_EQ(aioReader.pending(), kNumOpsBatch1);
EXPECT_EQ(completed, 0); EXPECT_EQ(completed, 0);
{
auto result = aioReader.wait(1); auto result = aioReader.wait(1);
EXPECT_EQ(result.size(), 1); EXPECT_GE(result.size(), 1);
} EXPECT_EQ(completed, result.size());
EXPECT_EQ(completed, 1); EXPECT_EQ(aioReader.pending(), kNumOpsBatch1 - result.size());
EXPECT_EQ(aioReader.pending(), kNumOps - 1);
schedule(kNumOpsBatch2);
EXPECT_EQ(aioReader.pending(), ops.size() - result.size());
EXPECT_EQ(completed, result.size());
EXPECT_EQ(aioReader.cancel(), kNumOps - 1); auto canceled = aioReader.cancel();
EXPECT_EQ(canceled.size(), ops.size() - result.size());
EXPECT_EQ(aioReader.pending(), 0); EXPECT_EQ(aioReader.pending(), 0);
EXPECT_EQ(completed, 1); EXPECT_EQ(completed, result.size());
completed = 0; size_t foundCompleted = 0;
for (auto& op : ops) { for (auto& op : ops) {
if (op.state() == AsyncIOOp::State::COMPLETED) { if (op->state() == AsyncIOOp::State::COMPLETED) {
++completed; ++foundCompleted;
} else { } else {
EXPECT_TRUE(op.state() == AsyncIOOp::State::CANCELED) << op; EXPECT_TRUE(op->state() == AsyncIOOp::State::CANCELED) << *op;
} }
} }
EXPECT_EQ(completed, 1); EXPECT_EQ(foundCompleted, completed);
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment