This is an automated email from the ASF dual-hosted git repository. xyz pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/pulsar-client-cpp.git
The following commit(s) were added to refs/heads/main by this push: new 20f6fa0 Fix the buggy Future and Promise implementations (#299) 20f6fa0 is described below commit 20f6fa0a72929f8e2668f45790f44e6a00390a44 Author: Yunze Xu <xyzinfern...@163.com> AuthorDate: Wed Jul 5 14:49:31 2023 +0800 Fix the buggy Future and Promise implementations (#299) Fixes https://github.com/apache/pulsar-client-cpp/issues/298 ### Motivation Currently, `Future` and `Promise` are implemented manually by managing condition variables. However, the condition variable sometimes behaves incorrectly on macOS, while the existing `future` and `promise` from the C++ standard library work well. ### Modifications Redesign `Future` and `Promise` based on the utilities in the standard `<future>` header. In addition, fix the possible race condition when `addListener` is called after `setValue` or `setFailed`: - Thread 1: call `setValue`, swap out the existing listeners and call them one by one outside the lock. - Thread 2: call `addListener`, detect that `complete_` is true and call the listener directly. As a result, the previous listeners and the new listener are called concurrently in threads 1 and 2. This patch fixes the problem by adding a future to wait until all listeners that were added before completion are done. ### Verifications Ran the reproduction code in #298 10 times and found that it never failed or hung. 
Co-authored-by: Zike Yang <z...@apache.org> --------- Co-authored-by: Zike Yang <z...@apache.org> --- lib/BinaryProtoLookupService.cc | 2 +- lib/Future.h | 201 +++++++++++++++++----------------------- lib/stats/ProducerStatsImpl.cc | 6 +- tests/BasicEndToEndTest.cc | 4 +- tests/PromiseTest.cc | 27 ++++++ 5 files changed, 119 insertions(+), 121 deletions(-) diff --git a/lib/BinaryProtoLookupService.cc b/lib/BinaryProtoLookupService.cc index f563f63..dfa3cab 100644 --- a/lib/BinaryProtoLookupService.cc +++ b/lib/BinaryProtoLookupService.cc @@ -146,7 +146,7 @@ void BinaryProtoLookupService::handlePartitionMetadataLookup(const std::string& } uint64_t BinaryProtoLookupService::newRequestId() { - Lock lock(mutex_); + std::lock_guard<std::mutex> lock(mutex_); return ++requestIdGenerator_; } diff --git a/lib/Future.h b/lib/Future.h index 3593057..03e93e4 100644 --- a/lib/Future.h +++ b/lib/Future.h @@ -19,162 +19,133 @@ #ifndef LIB_FUTURE_H_ #define LIB_FUTURE_H_ -#include <condition_variable> +#include <atomic> +#include <chrono> #include <functional> +#include <future> #include <list> #include <memory> #include <mutex> - -using Lock = std::unique_lock<std::mutex>; +#include <thread> +#include <utility> namespace pulsar { template <typename Result, typename Type> -struct InternalState { - std::mutex mutex; - std::condition_variable condition; - Result result; - Type value; - bool complete; - - std::list<typename std::function<void(Result, const Type&)> > listeners; -}; - -template <typename Result, typename Type> -class Future { +class InternalState { public: - typedef std::function<void(Result, const Type&)> ListenerCallback; - - Future& addListener(ListenerCallback callback) { - InternalState<Result, Type>* state = state_.get(); - Lock lock(state->mutex); - - if (state->complete) { - lock.unlock(); - callback(state->result, state->value); - } else { - state->listeners.push_back(callback); - } + using Listener = std::function<void(Result, const Type &)>; + using Pair = 
std::pair<Result, Type>; + using Lock = std::unique_lock<std::mutex>; - return *this; - } + // NOTE: Add the constructor explicitly just to be compatible with GCC 4.8 + InternalState() {} - Result get(Type& result) { - InternalState<Result, Type>* state = state_.get(); - Lock lock(state->mutex); + void addListener(Listener listener) { + Lock lock{mutex_}; + listeners_.emplace_back(listener); + lock.unlock(); - if (!state->complete) { - // Wait for result - while (!state->complete) { - state->condition.wait(lock); - } + if (completed()) { + Type value; + Result result = get(value); + triggerListeners(result, value); } - - result = state->value; - return state->result; } - template <typename Duration> - bool get(Result& res, Type& value, Duration d) { - InternalState<Result, Type>* state = state_.get(); - Lock lock(state->mutex); - - if (!state->complete) { - // Wait for result - while (!state->complete) { - if (!state->condition.wait_for(lock, d, [&state] { return state->complete; })) { - // Timeout while waiting for the future to complete - return false; - } - } + bool complete(Result result, const Type &value) { + bool expected = false; + if (!completed_.compare_exchange_strong(expected, true)) { + return false; } - - value = state->value; - res = state->result; + triggerListeners(result, value); + promise_.set_value(std::make_pair(result, value)); return true; } - private: - typedef std::shared_ptr<InternalState<Result, Type> > InternalStatePtr; - Future(InternalStatePtr state) : state_(state) {} + bool completed() const noexcept { return completed_; } - std::shared_ptr<InternalState<Result, Type> > state_; - - template <typename U, typename V> - friend class Promise; -}; + Result get(Type &result) { + const auto &pair = future_.get(); + result = pair.second; + return pair.first; + } -template <typename Result, typename Type> -class Promise { - public: - Promise() : state_(std::make_shared<InternalState<Result, Type> >()) {} + // Only public for test + void 
triggerListeners(Result result, const Type &value) { + while (true) { + Lock lock{mutex_}; + if (listeners_.empty()) { + return; + } - bool setValue(const Type& value) const { - static Result DEFAULT_RESULT; - InternalState<Result, Type>* state = state_.get(); - Lock lock(state->mutex); + bool expected = false; + if (!listenerRunning_.compare_exchange_strong(expected, true)) { + // There is another thread that polled a listener that is running, skip polling and release + // the lock. Here we wait for some time to avoid busy waiting. + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + auto listener = std::move(listeners_.front()); + listeners_.pop_front(); + lock.unlock(); - if (state->complete) { - return false; + listener(result, value); + listenerRunning_ = false; } + } - state->value = value; - state->result = DEFAULT_RESULT; - state->complete = true; + private: + std::atomic_bool completed_{false}; + std::promise<Pair> promise_; + std::shared_future<Pair> future_{promise_.get_future()}; - decltype(state->listeners) listeners; - listeners.swap(state->listeners); + std::list<Listener> listeners_; + mutable std::mutex mutex_; + std::atomic_bool listenerRunning_{false}; +}; - lock.unlock(); +template <typename Result, typename Type> +using InternalStatePtr = std::shared_ptr<InternalState<Result, Type>>; - for (auto& callback : listeners) { - callback(DEFAULT_RESULT, value); - } +template <typename Result, typename Type> +class Future { + public: + using Listener = typename InternalState<Result, Type>::Listener; - state->condition.notify_all(); - return true; + Future &addListener(Listener listener) { + state_->addListener(listener); + return *this; } - bool setFailed(Result result) const { - static Type DEFAULT_VALUE; - InternalState<Result, Type>* state = state_.get(); - Lock lock(state->mutex); + Result get(Type &result) { return state_->get(result); } - if (state->complete) { - return false; - } + private: + InternalStatePtr<Result, 
Type> state_; - state->result = result; - state->complete = true; + Future(InternalStatePtr<Result, Type> state) : state_(state) {} - decltype(state->listeners) listeners; - listeners.swap(state->listeners); + template <typename U, typename V> + friend class Promise; +}; - lock.unlock(); +template <typename Result, typename Type> +class Promise { + public: + Promise() : state_(std::make_shared<InternalState<Result, Type>>()) {} - for (auto& callback : listeners) { - callback(result, DEFAULT_VALUE); - } + bool setValue(const Type &value) const { return state_->complete({}, value); } - state->condition.notify_all(); - return true; - } + bool setFailed(Result result) const { return state_->complete(result, {}); } - bool isComplete() const { - InternalState<Result, Type>* state = state_.get(); - Lock lock(state->mutex); - return state->complete; - } + bool isComplete() const { return state_->completed(); } - Future<Result, Type> getFuture() const { return Future<Result, Type>(state_); } + Future<Result, Type> getFuture() const { return Future<Result, Type>{state_}; } private: - typedef std::function<void(Result, const Type&)> ListenerCallback; - std::shared_ptr<InternalState<Result, Type> > state_; + const InternalStatePtr<Result, Type> state_; }; -class Void {}; - -} /* namespace pulsar */ +} // namespace pulsar -#endif /* LIB_FUTURE_H_ */ +#endif diff --git a/lib/stats/ProducerStatsImpl.cc b/lib/stats/ProducerStatsImpl.cc index 9b0f7e6..3d3629d 100644 --- a/lib/stats/ProducerStatsImpl.cc +++ b/lib/stats/ProducerStatsImpl.cc @@ -71,7 +71,7 @@ void ProducerStatsImpl::flushAndReset(const boost::system::error_code& ec) { return; } - Lock lock(mutex_); + std::unique_lock<std::mutex> lock(mutex_); std::ostringstream oss; oss << *this; numMsgsSent_ = 0; @@ -86,7 +86,7 @@ void ProducerStatsImpl::flushAndReset(const boost::system::error_code& ec) { } void ProducerStatsImpl::messageSent(const Message& msg) { - Lock lock(mutex_); + std::lock_guard<std::mutex> lock(mutex_); 
numMsgsSent_++; totalMsgsSent_++; numBytesSent_ += msg.getLength(); @@ -96,7 +96,7 @@ void ProducerStatsImpl::messageSent(const Message& msg) { void ProducerStatsImpl::messageReceived(Result res, const boost::posix_time::ptime& publishTime) { boost::posix_time::ptime currentTime = boost::posix_time::microsec_clock::universal_time(); double diffInMicros = (currentTime - publishTime).total_microseconds(); - Lock lock(mutex_); + std::lock_guard<std::mutex> lock(mutex_); totalLatencyAccumulator_(diffInMicros); latencyAccumulator_(diffInMicros); sendMap_[res] += 1; // Value will automatically be initialized to 0 in the constructor diff --git a/tests/BasicEndToEndTest.cc b/tests/BasicEndToEndTest.cc index 8599b92..9ca2ab0 100644 --- a/tests/BasicEndToEndTest.cc +++ b/tests/BasicEndToEndTest.cc @@ -191,7 +191,7 @@ TEST(BasicEndToEndTest, testBatchMessages) { } void resendMessage(Result r, const MessageId msgId, Producer producer) { - Lock lock(mutex_); + std::unique_lock<std::mutex> lock(mutex_); if (r != ResultOk) { LOG_DEBUG("globalResendMessageCount" << globalResendMessageCount); if (++globalResendMessageCount >= 3) { @@ -993,7 +993,7 @@ TEST(BasicEndToEndTest, testResendViaSendCallback) { // 3 seconds std::this_thread::sleep_for(std::chrono::microseconds(3 * 1000 * 1000)); producer.close(); - Lock lock(mutex_); + std::lock_guard<std::mutex> lock(mutex_); ASSERT_GE(globalResendMessageCount, 3); } diff --git a/tests/PromiseTest.cc b/tests/PromiseTest.cc index 25b6b72..29ee2a3 100644 --- a/tests/PromiseTest.cc +++ b/tests/PromiseTest.cc @@ -24,6 +24,9 @@ #include <vector> #include "lib/Future.h" +#include "lib/LogUtils.h" + +DECLARE_LOG_OBJECT() using namespace pulsar; @@ -84,3 +87,27 @@ TEST(PromiseTest, testListeners) { ASSERT_EQ(results, (std::vector<int>(2, 0))); ASSERT_EQ(values, (std::vector<std::string>(2, "hello"))); } + +TEST(PromiseTest, testTriggerListeners) { + InternalState<int, int> state; + state.addListener([](int, const int&) { + LOG_INFO("Start task 
1..."); + std::this_thread::sleep_for(std::chrono::seconds(1)); + LOG_INFO("Finish task 1..."); + }); + state.addListener([](int, const int&) { + LOG_INFO("Start task 2..."); + std::this_thread::sleep_for(std::chrono::seconds(1)); + LOG_INFO("Finish task 2..."); + }); + + auto start = std::chrono::high_resolution_clock::now(); + auto future1 = std::async(std::launch::async, [&state] { state.triggerListeners(0, 0); }); + auto future2 = std::async(std::launch::async, [&state] { state.triggerListeners(0, 0); }); + future1.wait(); + future2.wait(); + auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::high_resolution_clock::now() - start) + .count(); + ASSERT_TRUE(elapsed > 2000) << "elapsed: " << elapsed << "ms"; +}