bkietz commented on a change in pull request #9528: URL: https://github.com/apache/arrow/pull/9528#discussion_r587606004
########## File path: cpp/src/arrow/util/thread_pool_test.cc ########## @@ -192,32 +208,50 @@ TEST_F(TestThreadPool, StressSpawnThreaded) { TEST_F(TestThreadPool, SpawnSlow) { // This checks that Shutdown() waits for all tasks to finish auto pool = this->MakeThreadPool(2); - SpawnAdds(pool.get(), 7, [](int x, int y, int* out) { - return task_slow_add(0.02 /* seconds */, x, y, out); - }); + SpawnAdds(pool.get(), 7, task_slow_add<int>{0.02 /* seconds */}); Review comment: ```suggestion SpawnAdds(pool.get(), 7, task_slow_add<int>{ /*seconds=*/0.02}); ``` ########## File path: cpp/src/arrow/util/io_util.cc ########## @@ -1608,6 +1640,39 @@ Result<SignalHandler> SetSignalHandler(int signum, const SignalHandler& handler) return Status::OK(); } +void ReinstateSignalHandler(int signum, SignalHandler::Callback handler) { +#if !ARROW_HAVE_SIGACTION + // Cannot report any errors from signal() (but there shouldn't be any) + signal(signum, handler); +#endif +} + +Status SendSignal(int signum) { + if (raise(signum) == 0) { + return Status::OK(); + } + if (errno == EINVAL) { + return Status::Invalid("Invalid signal number"); + } + return IOErrorFromErrno(errno, "Failed to raise signal"); +} + +Status SendSignalToThread(int signum, uint64_t thread_id) { +#ifdef _WIN32 + return Status::NotImplemented("Cannot send signal to specific thread on Windows"); +#else + // Have to use a C-style cast because pthread_t can be a pointer *or* integer type + int r = pthread_kill((pthread_t)thread_id, signum); + if (r == 0) { + return Status::OK(); + } + if (r == EINVAL) { + return Status::Invalid("Invalid signal number"); Review comment: ```suggestion return Status::Invalid("Invalid signal number ", signum); ``` ########## File path: cpp/src/arrow/csv/reader.cc ########## @@ -833,9 +843,10 @@ class AsyncThreadedTableReader AsyncThreadedTableReader(MemoryPool* pool, std::shared_ptr<io::InputStream> input, const ReadOptions& read_options, const ParseOptions& parse_options, - const ConvertOptions& convert_options, Executor* cpu_executor, - Executor* io_executor) - : BaseTableReader(pool, input, read_options, parse_options, convert_options), + const ConvertOptions& convert_options, StopToken stop_token, + Executor* cpu_executor, Executor* io_executor) Review comment: Seems we might prefer to rewrite this constructor to take an IOContext ########## File path: cpp/src/arrow/util/io_util_test.cc ########## @@ -623,5 +655,46 @@ TEST(FileUtils, LongPaths) { } #endif +static std::atomic<int> signal_received; + +static void handle_signal(int signum) { + ReinstateSignalHandler(signum, &handle_signal); + signal_received.store(signum); +} + +TEST(SendSignal, Generic) { + signal_received.store(0); + SignalHandlerGuard guard(SIGINT, &handle_signal); + + ASSERT_EQ(signal_received.load(), 0); + ASSERT_OK(SendSignal(SIGINT)); + BusyWait(1.0, [&]() { return signal_received.load() != 0; }); + ASSERT_EQ(signal_received.load(), SIGINT); + + // Re-try (exercise ReinstateSignalHandler) + signal_received.store(0); + ASSERT_OK(SendSignal(SIGINT)); + BusyWait(1.0, [&]() { return signal_received.load() != 0; }); + ASSERT_EQ(signal_received.load(), SIGINT); +} + +TEST(SendSignal, ToThread) { +#ifdef _WIN32 + uint64_t dummy_thread_id = 42; + ASSERT_RAISES(NotImplemented, SendSignalToThread(SIGINT, dummy_thread_id)); +#else + // Have to use a C-style cast because pthread_t can be a pointer *or* integer type + uint64_t thread_id = (uint64_t)(pthread_self()); Review comment: ```suggestion uint64_t thread_id = (uint64_t)(pthread_self()); // NOLINT readability-casting ``` ########## File path: cpp/src/arrow/util/cancel.cc ########## @@ -0,0 +1,167 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/util/cancel.h" + +#include <atomic> +#include <mutex> +#include <sstream> +#include <utility> + +#include "arrow/result.h" +#include "arrow/util/atomic_shared_ptr.h" +#include "arrow/util/io_util.h" +#include "arrow/util/logging.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +#if ATOMIC_INT_LOCK_FREE != 2 +#error Lock-free atomic int required for signal safety +#endif + +using internal::ReinstateSignalHandler; +using internal::SetSignalHandler; +using internal::SignalHandler; + +// NOTE: We care mainly about the making the common case (not cancelled) fast. + +struct StopSourceImpl { + std::atomic<int> requested_{0}; // will be -1 or signal number if requested + std::mutex mutex_; + Status cancel_error_; +}; + +StopSource::StopSource() : impl_(new StopSourceImpl) {} + +StopSource::~StopSource() = default; + +void StopSource::RequestStop() { RequestStop(Status::Cancelled("Operation cancelled")); } + +void StopSource::RequestStop(Status st) { + std::lock_guard<std::mutex> lock(impl_->mutex_); + DCHECK(!st.ok()); + if (!impl_->requested_) { + impl_->requested_ = -1; + impl_->cancel_error_ = std::move(st); + } +} + +void StopSource::RequestStopFromSignal(int signum) { + // Only async-signal-safe code allowed here + impl_->requested_.store(signum); +} + +StopToken StopSource::token() { return StopToken(impl_); } + +bool StopToken::IsStopRequested() { + if (!impl_) { + return false; + } + return impl_->requested_.load() != 0; +} + +Status StopToken::Poll() { + if (!impl_) { + return Status::OK(); + } + if (!impl_->requested_.load()) { + return Status::OK(); + } + + std::lock_guard<std::mutex> lock(impl_->mutex_); + if (impl_->cancel_error_.ok()) { + auto signum = impl_->requested_.load(); + DCHECK_GT(signum, 0); + impl_->cancel_error_ = internal::CancelledFromSignal(signum, "Operation cancelled"); + } Review comment: ```suggestion } else { DCHECK_EQ(impl_->requested_.load(), -1); } ``` ########## File path: cpp/src/arrow/util/task_group.cc ########## @@ -67,34 +75,54 @@ class SerialTaskGroup : public TaskGroup { class ThreadedTaskGroup : public TaskGroup { public: - explicit ThreadedTaskGroup(Executor* executor) - : executor_(executor), nremaining_(0), ok_(true) {} + explicit ThreadedTaskGroup(Executor* executor, StopToken stop_token) Review comment: ```suggestion ThreadedTaskGroup(Executor* executor, StopToken stop_token) ``` ########## File path: cpp/src/arrow/util/cancel_test.cc ########## @@ -0,0 +1,301 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <atomic> +#include <cmath> +#include <sstream> +#include <string> +#include <thread> +#include <utility> +#include <vector> + +#include <gtest/gtest.h> + +#include <signal.h> +#ifndef _WIN32 +#include <sys/time.h> // for setitimer() +#endif + +#include "arrow/testing/gtest_util.h" +#include "arrow/util/cancel.h" +#include "arrow/util/future.h" +#include "arrow/util/io_util.h" +#include "arrow/util/logging.h" +#include "arrow/util/optional.h" + +namespace arrow { + +class CancelTest : public ::testing::Test {}; + +TEST_F(CancelTest, StopBasics) { + { + StopSource source; + StopToken token = source.token(); + ASSERT_FALSE(token.IsStopRequested()); + ASSERT_OK(token.Poll()); + + source.RequestStop(); + ASSERT_TRUE(token.IsStopRequested()); + ASSERT_RAISES(Cancelled, token.Poll()); + } + { + StopSource source; + StopToken token = source.token(); + source.RequestStop(Status::IOError("Operation cancelled")); + ASSERT_TRUE(token.IsStopRequested()); + ASSERT_RAISES(IOError, token.Poll()); + } +} + +TEST_F(CancelTest, StopTokenCopy) { + StopSource source; + StopToken token = source.token(); + ASSERT_FALSE(token.IsStopRequested()); + ASSERT_OK(token.Poll()); + + source.RequestStop(); + ASSERT_TRUE(token.IsStopRequested()); + ASSERT_RAISES(Cancelled, token.Poll()); + + StopToken new_token = token; + ASSERT_TRUE(new_token.IsStopRequested()); + ASSERT_RAISES(Cancelled, new_token.Poll()); +} + +TEST_F(CancelTest, RequestStopTwice) { + StopSource source; + StopToken token = source.token(); + source.RequestStop(); + // Second RequestStop() call is ignored + source.RequestStop(Status::IOError("Operation cancelled")); + ASSERT_TRUE(token.IsStopRequested()); + ASSERT_RAISES(Cancelled, token.Poll()); +} + +TEST_F(CancelTest, Unstoppable) { + StopToken token = StopToken::Unstoppable(); + ASSERT_FALSE(token.IsStopRequested()); + ASSERT_OK(token.Poll()); +} + +TEST_F(CancelTest, SourceVanishes) { + { + util::optional<StopSource> source{StopSource()}; + StopToken token = source->token(); + ASSERT_FALSE(token.IsStopRequested()); + ASSERT_OK(token.Poll()); + + source.reset(); + ASSERT_FALSE(token.IsStopRequested()); + ASSERT_OK(token.Poll()); + } + { + util::optional<StopSource> source{StopSource()}; + StopToken token = source->token(); + source->RequestStop(); + + source.reset(); + ASSERT_TRUE(token.IsStopRequested()); + ASSERT_RAISES(Cancelled, token.Poll()); + } +} + +static void noop_signal_handler(int signum) { + internal::ReinstateSignalHandler(signum, &noop_signal_handler); +} + +#ifndef _WIN32 +static util::optional<StopSource> signal_stop_source; + +static void signal_handler(int signum) { + signal_stop_source->RequestStopFromSignal(signum); +} + +static void SetITimer(double seconds) { Review comment: ```suggestion // SIGALRM will be received once after the specified wait static void SetITimer(double seconds) { ``` ########## File path: cpp/src/arrow/util/task_group_test.cc ########## @@ -114,6 +114,49 @@ void TestTaskGroupErrors(std::shared_ptr<TaskGroup> task_group) { ASSERT_RAISES(Invalid, task_group->Finish()); } +void TestTaskGroupCancel(std::shared_ptr<TaskGroup> task_group, StopSource* stop_source) { + const int NSUCCESSES = 2; + const int NCANCELS = 20; + + std::atomic<int> count(0); + + auto task_group_was_ok = false; + task_group->Append([&]() -> Status { + for (int i = 0; i < NSUCCESSES; ++i) { + task_group->Append([&]() { + count++; + return Status::OK(); + }); + } + task_group_was_ok = task_group->ok(); + for (int i = 0; i < NCANCELS; ++i) { + task_group->Append([&]() { + SleepFor(1e-2); + stop_source->RequestStop(); + count++; + return Status::OK(); + }); + } + + return Status::OK(); + }); + + // Cancellation is propagated + ASSERT_RAISES(Cancelled, task_group->Finish()); + ASSERT_TRUE(task_group_was_ok); + ASSERT_FALSE(task_group->ok()); + if (task_group->parallelism() == 1) { + // Serial: exactly three successes + ASSERT_EQ(count.load(), 3); + } else { + // Parallel: at least three successes + ASSERT_GE(count.load(), 3); + ASSERT_LE(count.load(), 2 * task_group->parallelism()); Review comment: ```suggestion ASSERT_LE(count.load(), NSUCCESSES + task_group->parallelism()); ``` ########## File path: cpp/src/arrow/util/cancel.cc ########## @@ -0,0 +1,167 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/util/cancel.h" + +#include <atomic> +#include <mutex> +#include <sstream> +#include <utility> + +#include "arrow/result.h" +#include "arrow/util/atomic_shared_ptr.h" +#include "arrow/util/io_util.h" +#include "arrow/util/logging.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +#if ATOMIC_INT_LOCK_FREE != 2 +#error Lock-free atomic int required for signal safety +#endif + +using internal::ReinstateSignalHandler; +using internal::SetSignalHandler; +using internal::SignalHandler; + +// NOTE: We care mainly about the making the common case (not cancelled) fast. + +struct StopSourceImpl { + std::atomic<int> requested_{0}; // will be -1 or signal number if requested + std::mutex mutex_; + Status cancel_error_; +}; + +StopSource::StopSource() : impl_(new StopSourceImpl) {} + +StopSource::~StopSource() = default; + +void StopSource::RequestStop() { RequestStop(Status::Cancelled("Operation cancelled")); } + +void StopSource::RequestStop(Status st) { + std::lock_guard<std::mutex> lock(impl_->mutex_); + DCHECK(!st.ok()); + if (!impl_->requested_) { + impl_->requested_ = -1; + impl_->cancel_error_ = std::move(st); + } +} + +void StopSource::RequestStopFromSignal(int signum) { + // Only async-signal-safe code allowed here + impl_->requested_.store(signum); +} + +StopToken StopSource::token() { return StopToken(impl_); } + +bool StopToken::IsStopRequested() { + if (!impl_) { + return false; + } + return impl_->requested_.load() != 0; +} + +Status StopToken::Poll() { + if (!impl_) { + return Status::OK(); + } + if (!impl_->requested_.load()) { + return Status::OK(); + } + + std::lock_guard<std::mutex> lock(impl_->mutex_); + if (impl_->cancel_error_.ok()) { + auto signum = impl_->requested_.load(); + DCHECK_GT(signum, 0); + impl_->cancel_error_ = internal::CancelledFromSignal(signum, "Operation cancelled"); + } + return impl_->cancel_error_; +} + +namespace { + +void HandleSignal(int signum); + +struct SignalStopState { + struct SavedSignalHandler { + int signum; + SignalHandler handler; + }; + + Status RegisterHandlers(const std::vector<int>& signals) { + if (!saved_handlers.empty()) { + return Status::Invalid("Signal handlers already registered"); + } + for (int signum : signals) { + ARROW_ASSIGN_OR_RAISE(auto handler, + SetSignalHandler(signum, SignalHandler{&HandleSignal})); + saved_handlers.push_back({signum, handler}); + } + return Status::OK(); + } + + void UnregisterHandlers() { + auto handlers = std::move(saved_handlers); + for (const auto& h : handlers) { + ARROW_CHECK_OK(SetSignalHandler(h.signum, h.handler).status()); + } + } + + ~SignalStopState() { UnregisterHandlers(); } + + StopSource stop_source; + std::vector<SavedSignalHandler> saved_handlers; +}; + +std::shared_ptr<SignalStopState> g_signal_stop_state; + +void HandleSignal(int signum) { + ReinstateSignalHandler(signum, &HandleSignal); + std::shared_ptr<SignalStopState> state = internal::atomic_load(&g_signal_stop_state); + if (state) { + state->stop_source.RequestStopFromSignal(signum); + } +} + +} // namespace + +Result<StopSource*> SetSignalStopSource() { + if (g_signal_stop_state) { + return Status::Invalid("Signal stop source already set up"); + } + internal::atomic_store(&g_signal_stop_state, std::make_shared<SignalStopState>()); + return &g_signal_stop_state->stop_source; +} + +void ResetSignalStopSource() { + internal::atomic_store(&g_signal_stop_state, std::shared_ptr<SignalStopState>{}); Review comment: Is this intended to be idempotent? If not, ```suggestion DCHECK_NE(g_signal_stop_state, nullptr); internal::atomic_store(&g_signal_stop_state, std::shared_ptr<SignalStopState>{}); ``` ########## File path: cpp/src/arrow/util/thread_pool_test.cc ########## @@ -137,30 +136,47 @@ class TestThreadPool : public ::testing::Test { return *ThreadPool::Make(threads); } - void SpawnAdds(ThreadPool* pool, int nadds, AddTaskFunc add_func) { - AddTester add_tester(nadds); + void SpawnAdds(ThreadPool* pool, int nadds, AddTaskFunc add_func, + StopToken stop_token = StopToken::Unstoppable(), + StopSource* stop_source = nullptr) { Review comment: ```suggestion void SpawnAdds(ThreadPool* pool, int nadds, AddTaskFunc add_func, StopToken stop_token = StopToken::Unstoppable()) { SpawnAddsImpl(pool, nadds, add_func, stop_token, nullptr); } void SpawnAdds(ThreadPool* pool, int nadds, AddTaskFunc add_func, StopSource* stop_source) { SpawnAddsImpl(pool, nadds, add_func, stop_source->token(), stop_source); } void SpawnAddsImpl(ThreadPool* pool, int nadds, AddTaskFunc add_func, StopToken stop_token, StopSource* stop_source) { ``` I'm not sure this is correct for all the cases below; there are a few which use a StopSource's token but don't pass the StopSource. If this *isn't* correct, I think that it'd be best to separate the may-be-cancelled and the won't-be-cancelled cases into separate overloads of SpawnAdds, with comments clarifying the intent of each. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org