Commit 118a3923 authored by Cameron Pickett's avatar Cameron Pickett Committed by Facebook GitHub Bot

Do not swallow child exceptions on cancellation

Summary:
Aligns `collectAll` and `collectAny` behaviour and simplifies the logic for propagating child exceptions on cancellation.

Prior to this change, we were conditionally setting `firstException` dependent on whether cancellation was requested or not. Additionally, `collectAny` was unconditionally returning `co_cancelled` even if all the child tasks completed successfully.

Now, the behaviour is consistent between general errors and cancellation: If any child task completes in error (including folly::OperationCancelled), then the parent `collectAll`/`collectAny` will propagate that failure. Otherwise, if no child fails, then `collectAll`/`collectAny` will propagate that success.

(Note: this ignores all push blocking failures!)

Reviewed By: iahs

Differential Revision: D31266586

fbshipit-source-id: b6eba6ab2a0a3634b112318b1810819d7916acdb
parent ff7c3177
...@@ -135,7 +135,6 @@ auto collectAllImpl( ...@@ -135,7 +135,6 @@ auto collectAllImpl(
CancellationToken::merge(parentCancelToken, cancelSource.getToken()); CancellationToken::merge(parentCancelToken, cancelSource.getToken());
exception_wrapper firstException; exception_wrapper firstException;
std::atomic<bool> anyFailures{false};
auto makeTask = [&](auto&& awaitable, auto& result) -> BarrierTask { auto makeTask = [&](auto&& awaitable, auto& result) -> BarrierTask {
using await_result = semi_await_result_t<decltype(awaitable)>; using await_result = semi_await_result_t<decltype(awaitable)>;
...@@ -153,9 +152,7 @@ auto collectAllImpl( ...@@ -153,9 +152,7 @@ auto collectAllImpl(
cancelToken, static_cast<decltype(awaitable)>(awaitable)))); cancelToken, static_cast<decltype(awaitable)>(awaitable))));
} }
} catch (...) { } catch (...) {
anyFailures.store(true, std::memory_order_relaxed); if (!cancelSource.requestCancellation()) {
if (!cancelSource.requestCancellation() &&
!parentCancelToken.isCancellationRequested()) {
// This was the first failure, remember its error. // This was the first failure, remember its error.
firstException = exception_wrapper{std::current_exception()}; firstException = exception_wrapper{std::current_exception()};
} }
...@@ -194,17 +191,10 @@ auto collectAllImpl( ...@@ -194,17 +191,10 @@ auto collectAllImpl(
// the use of co_viaIfAsync() within makeBarrierTask(). // the use of co_viaIfAsync() within makeBarrierTask().
co_await UnsafeResumeInlineSemiAwaitable{barrier.arriveAndWait()}; co_await UnsafeResumeInlineSemiAwaitable{barrier.arriveAndWait()};
if (anyFailures.load(std::memory_order_relaxed)) {
if (firstException) { if (firstException) {
co_yield co_error(std::move(firstException)); co_yield co_error(std::move(firstException));
} }
// Parent task was cancelled before any child tasks failed.
// Complete with the OperationCancelled error instead of the
// child task's errors.
co_yield co_cancelled;
}
co_return std::tuple<collect_all_component_t<SemiAwaitables>...>{ co_return std::tuple<collect_all_component_t<SemiAwaitables>...>{
getValueOrUnit(std::get<Indices>(std::move(results)))...}; getValueOrUnit(std::get<Indices>(std::move(results)))...};
} }
...@@ -311,10 +301,6 @@ auto collectAnyImpl( ...@@ -311,10 +301,6 @@ auto collectAnyImpl(
} }
}))...); }))...);
if (parentCancelToken.isCancellationRequested()) {
co_yield co_cancelled;
}
co_return firstCompletion; co_return firstCompletion;
} }
...@@ -377,7 +363,6 @@ auto collectAllRange(InputRange awaitables) ...@@ -377,7 +363,6 @@ auto collectAllRange(InputRange awaitables)
tryResults; tryResults;
exception_wrapper firstException; exception_wrapper firstException;
std::atomic<bool> anyFailures = false;
using awaitable_type = remove_cvref_t<detail::range_reference_t<InputRange>>; using awaitable_type = remove_cvref_t<detail::range_reference_t<InputRange>>;
auto makeTask = [&](awaitable_type semiAwaitable, auto makeTask = [&](awaitable_type semiAwaitable,
...@@ -389,7 +374,6 @@ auto collectAllRange(InputRange awaitables) ...@@ -389,7 +374,6 @@ auto collectAllRange(InputRange awaitables)
executor.get_alias(), executor.get_alias(),
co_withCancellation(cancelToken, std::move(semiAwaitable)))); co_withCancellation(cancelToken, std::move(semiAwaitable))));
} catch (...) { } catch (...) {
anyFailures.store(true, std::memory_order_relaxed);
if (!cancelSource.requestCancellation()) { if (!cancelSource.requestCancellation()) {
firstException = exception_wrapper{std::current_exception()}; firstException = exception_wrapper{std::current_exception()};
} }
...@@ -429,16 +413,10 @@ auto collectAllRange(InputRange awaitables) ...@@ -429,16 +413,10 @@ auto collectAllRange(InputRange awaitables)
} }
// Check if there were any exceptions and rethrow the first one. // Check if there were any exceptions and rethrow the first one.
if (anyFailures.load(std::memory_order_relaxed)) {
if (firstException) { if (firstException) {
co_yield co_error(std::move(firstException)); co_yield co_error(std::move(firstException));
} }
// Cancellation was requested of the parent Task before any of the
// child tasks failed.
co_yield co_cancelled;
}
std::vector<detail::collect_all_range_component_t< std::vector<detail::collect_all_range_component_t<
detail::range_reference_t<InputRange>>> detail::range_reference_t<InputRange>>>
results; results;
...@@ -463,7 +441,6 @@ auto collectAllRange(InputRange awaitables) -> folly::coro::Task<void> { ...@@ -463,7 +441,6 @@ auto collectAllRange(InputRange awaitables) -> folly::coro::Task<void> {
co_await co_current_cancellation_token, cancelSource.getToken()); co_await co_current_cancellation_token, cancelSource.getToken());
exception_wrapper firstException; exception_wrapper firstException;
std::atomic<bool> anyFailures = false;
using awaitable_type = remove_cvref_t<detail::range_reference_t<InputRange>>; using awaitable_type = remove_cvref_t<detail::range_reference_t<InputRange>>;
auto makeTask = [&](awaitable_type semiAwaitable) -> detail::BarrierTask { auto makeTask = [&](awaitable_type semiAwaitable) -> detail::BarrierTask {
...@@ -472,7 +449,6 @@ auto collectAllRange(InputRange awaitables) -> folly::coro::Task<void> { ...@@ -472,7 +449,6 @@ auto collectAllRange(InputRange awaitables) -> folly::coro::Task<void> {
executor.get_alias(), executor.get_alias(),
co_withCancellation(cancelToken, std::move(semiAwaitable))); co_withCancellation(cancelToken, std::move(semiAwaitable)));
} catch (...) { } catch (...) {
anyFailures.store(true, std::memory_order_relaxed);
if (!cancelSource.requestCancellation()) { if (!cancelSource.requestCancellation()) {
firstException = exception_wrapper{std::current_exception()}; firstException = exception_wrapper{std::current_exception()};
} }
...@@ -509,11 +485,9 @@ auto collectAllRange(InputRange awaitables) -> folly::coro::Task<void> { ...@@ -509,11 +485,9 @@ auto collectAllRange(InputRange awaitables) -> folly::coro::Task<void> {
} }
// Check if there were any exceptions and rethrow the first one. // Check if there were any exceptions and rethrow the first one.
if (anyFailures.load(std::memory_order_relaxed)) {
if (firstException) { if (firstException) {
co_yield co_error(std::move(firstException)); co_yield co_error(std::move(firstException));
} }
}
} }
template <typename InputRange> template <typename InputRange>
...@@ -607,10 +581,8 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency) ...@@ -607,10 +581,8 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency)
co_await co_current_cancellation_token, cancelSource.getToken()); co_await co_current_cancellation_token, cancelSource.getToken());
exception_wrapper firstException; exception_wrapper firstException;
std::atomic<bool> anyFailures = false;
const auto trySetFirstException = [&](exception_wrapper&& e) noexcept { const auto trySetFirstException = [&](exception_wrapper&& e) noexcept {
anyFailures.store(true, std::memory_order_relaxed);
if (!cancelSource.requestCancellation()) { if (!cancelSource.requestCancellation()) {
// This is first entity to request cancellation. // This is first entity to request cancellation.
firstException = std::move(e); firstException = std::move(e);
...@@ -700,14 +672,8 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency) ...@@ -700,14 +672,8 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency)
co_await detail::UnsafeResumeInlineSemiAwaitable{barrier.arriveAndWait()}; co_await detail::UnsafeResumeInlineSemiAwaitable{barrier.arriveAndWait()};
if (iterationException) { if (auto& ex = iterationException ? iterationException : firstException) {
co_yield co_error(std::move(iterationException)); co_yield co_error(std::move(ex));
} else if (anyFailures.load(std::memory_order_relaxed)) {
if (firstException) {
co_yield co_error(std::move(firstException));
}
co_yield co_cancelled;
} }
} }
...@@ -731,12 +697,9 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency) ...@@ -731,12 +697,9 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency)
CancellationToken::merge(parentCancelToken, cancelSource.getToken()); CancellationToken::merge(parentCancelToken, cancelSource.getToken());
exception_wrapper firstException; exception_wrapper firstException;
std::atomic<bool> anyFailures = false;
auto trySetFirstException = [&](exception_wrapper&& e) noexcept { auto trySetFirstException = [&](exception_wrapper&& e) noexcept {
anyFailures.store(true, std::memory_order_relaxed); if (!cancelSource.requestCancellation()) {
if (!cancelSource.requestCancellation() &&
!parentCancelToken.isCancellationRequested()) {
// This is first entity to request cancellation. // This is first entity to request cancellation.
firstException = std::move(e); firstException = std::move(e);
} }
...@@ -846,16 +809,8 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency) ...@@ -846,16 +809,8 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency)
co_await detail::UnsafeResumeInlineSemiAwaitable{barrier.arriveAndWait()}; co_await detail::UnsafeResumeInlineSemiAwaitable{barrier.arriveAndWait()};
if (iterationException) { if (auto& ex = iterationException ? iterationException : firstException) {
co_yield co_error(std::move(iterationException)); co_yield co_error(std::move(ex));
} else if (anyFailures.load(std::memory_order_relaxed)) {
if (firstException) {
co_yield co_error(std::move(firstException));
}
// Otherwise, cancellation was requested before any of the child tasks
// failed so complete with the OperationCancelled error.
co_yield co_cancelled;
} }
std::vector<detail::collect_all_range_component_t< std::vector<detail::collect_all_range_component_t<
...@@ -1100,10 +1055,6 @@ auto collectAnyRange(InputRange awaitables) -> folly::coro::Task<std::pair< ...@@ -1100,10 +1055,6 @@ auto collectAnyRange(InputRange awaitables) -> folly::coro::Task<std::pair<
co_await folly::coro::co_withCancellation( co_await folly::coro::co_withCancellation(
cancelToken, folly::coro::collectAllRange(tasks | ranges::views::move)); cancelToken, folly::coro::collectAllRange(tasks | ranges::views::move));
if (parentCancelToken.isCancellationRequested()) {
co_yield co_cancelled;
}
co_return firstCompletion; co_return firstCompletion;
} }
......
...@@ -2039,7 +2039,6 @@ TEST_F(CollectAnyTest, CollectAnyCancelsSubtasksWhenParentTaskCancelled) { ...@@ -2039,7 +2039,6 @@ TEST_F(CollectAnyTest, CollectAnyCancelsSubtasksWhenParentTaskCancelled) {
folly::coro::blockingWait([]() -> folly::coro::Task<void> { folly::coro::blockingWait([]() -> folly::coro::Task<void> {
auto start = std::chrono::steady_clock::now(); auto start = std::chrono::steady_clock::now();
folly::CancellationSource cancelSource; folly::CancellationSource cancelSource;
try {
auto [index, result] = co_await folly::coro::co_withCancellation( auto [index, result] = co_await folly::coro::co_withCancellation(
cancelSource.getToken(), cancelSource.getToken(),
folly::coro::collectAny( folly::coro::collectAny(
...@@ -2059,11 +2058,8 @@ TEST_F(CollectAnyTest, CollectAnyCancelsSubtasksWhenParentTaskCancelled) { ...@@ -2059,11 +2058,8 @@ TEST_F(CollectAnyTest, CollectAnyCancelsSubtasksWhenParentTaskCancelled) {
co_await sleepThatShouldBeCancelled(15s); co_await sleepThatShouldBeCancelled(15s);
co_return 123; co_return 123;
}())); }()));
ADD_FAILURE() << "Hit unexpected codepath";
} catch (const folly::OperationCancelled&) {
auto end = std::chrono::steady_clock::now(); auto end = std::chrono::steady_clock::now();
EXPECT_LT(end - start, 1s); EXPECT_LT(end - start, 1s);
}
}()); }());
} }
...@@ -2530,7 +2526,6 @@ TEST_F(CollectAnyRangeTest, CollectAnyCancelsSubtasksWhenParentTaskCancelled) { ...@@ -2530,7 +2526,6 @@ TEST_F(CollectAnyRangeTest, CollectAnyCancelsSubtasksWhenParentTaskCancelled) {
folly::coro::blockingWait([]() -> folly::coro::Task<void> { folly::coro::blockingWait([]() -> folly::coro::Task<void> {
auto start = std::chrono::steady_clock::now(); auto start = std::chrono::steady_clock::now();
folly::CancellationSource cancelSource; folly::CancellationSource cancelSource;
try {
auto generateTasks = auto generateTasks =
[&]() -> folly::coro::Generator<folly::coro::Task<int>&&> { [&]() -> folly::coro::Generator<folly::coro::Task<int>&&> {
co_yield [&]() -> folly::coro::Task<int> { co_yield [&]() -> folly::coro::Task<int> {
...@@ -2551,13 +2546,9 @@ TEST_F(CollectAnyRangeTest, CollectAnyCancelsSubtasksWhenParentTaskCancelled) { ...@@ -2551,13 +2546,9 @@ TEST_F(CollectAnyRangeTest, CollectAnyCancelsSubtasksWhenParentTaskCancelled) {
}(); }();
}; };
auto [index, result] = co_await folly::coro::co_withCancellation( auto [index, result] = co_await folly::coro::co_withCancellation(
cancelSource.getToken(), cancelSource.getToken(), folly::coro::collectAnyRange(generateTasks()));
folly::coro::collectAnyRange(generateTasks()));
ADD_FAILURE() << "Hit unexpected codepath";
} catch (const folly::OperationCancelled&) {
auto end = std::chrono::steady_clock::now(); auto end = std::chrono::steady_clock::now();
EXPECT_LT(end - start, 1s); EXPECT_LT(end - start, 1s);
}
}()); }());
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment