rtpsw commented on code in PR #35874:
URL: https://github.com/apache/arrow/pull/35874#discussion_r1231405359
##########
cpp/src/arrow/acero/asof_join_node_test.cc:
##########
@@ -1360,9 +1366,209 @@ TRACED_TEST(AsofJoinTest, TestUnorderedOnKey, {
schema({field("time", int64()), field("key", int32()), field("r0_v0",
float64())}));
})
/// Running tallies of the backpressure signals seen by a BackpressureCountingNode.
/// Both counters start at zero and are atomic so they can be bumped concurrently.
struct BackpressureCounters {
  std::atomic<int32_t> pause_count{0};
  std::atomic<int32_t> resume_count{0};
};
+
+struct BackpressureCountingNodeOptions : public ExecNodeOptions {
+ BackpressureCountingNodeOptions(BackpressureCounters* counters) :
counters(counters) {}
+
+ BackpressureCounters* counters;
+};
+
+struct BackpressureCountingNode : public MapNode {
+ static constexpr const char* kKindName = "BackpressureCountingNode";
+ static constexpr const char* kFactoryName = "backpressure_count";
+
+ static void Register() {
+ auto exec_reg = default_exec_factory_registry();
+ if (!exec_reg->GetFactory(kFactoryName).ok()) {
+ ASSERT_OK(exec_reg->AddFactory(kFactoryName,
BackpressureCountingNode::Make));
+ }
+ }
+
+ BackpressureCountingNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ std::shared_ptr<Schema> output_schema,
+ const BackpressureCountingNodeOptions& options)
+ : MapNode(plan, inputs, output_schema), counters(options.counters) {}
+
+ static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+ const ExecNodeOptions& options) {
+ RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, kKindName));
+ auto bp_options = static_cast<const
BackpressureCountingNodeOptions&>(options);
+ return plan->EmplaceNode<BackpressureCountingNode>(
+ plan, inputs, inputs[0]->output_schema(), bp_options);
+ }
+
+ const char* kind_name() const override { return kKindName; }
+ Result<ExecBatch> ProcessBatch(ExecBatch batch) override { return batch; }
+
+ void PauseProducing(ExecNode* output, int32_t counter) override {
+ ++counters->pause_count;
+ inputs()[0]->PauseProducing(this, counter);
+ }
+ void ResumeProducing(ExecNode* output, int32_t counter) override {
+ ++counters->resume_count;
+ inputs()[0]->ResumeProducing(this, counter);
+ }
+
+ BackpressureCounters* counters;
+};
+
+class Gate {
+ public:
+ void ReleaseAllBatches() {
+ std::lock_guard lg(mutex_);
+ num_allowed_batches_ = -1;
+ NotifyAll();
+ }
+
+ void ReleaseOneBatch() {
+ std::lock_guard lg(mutex_);
+ DCHECK_GE(num_allowed_batches_, 0)
+ << "you can't call ReleaseOneBatch() after calling
ReleaseAllBatches()";
+ num_allowed_batches_++;
+ NotifyAll();
+ }
+
+ Future<> WaitForNextReleasedBatch() {
+ std::lock_guard lg(mutex_);
+ if (current_waiter_.is_valid()) {
+ return current_waiter_;
+ }
+ Future<> fut;
+ if (num_allowed_batches_ < 0 || num_released_batches_ <
num_allowed_batches_) {
+ num_released_batches_++;
+ return Future<>::MakeFinished();
+ }
+
+ current_waiter_ = Future<>::Make();
+ return current_waiter_;
+ }
+
+ private:
+ void NotifyAll() {
+ if (current_waiter_.is_valid()) {
+ Future<> to_unlock = current_waiter_;
+ current_waiter_ = {};
+ to_unlock.MarkFinished();
+ }
+ }
+
+ Future<> current_waiter_;
+ int num_released_batches_ = 0;
+ int num_allowed_batches_ = 0;
+ std::mutex mutex_;
+};
+
+struct GatedNodeOptions : public ExecNodeOptions {
+ explicit GatedNodeOptions(Gate* gate) : gate(gate) {}
+ Gate* gate;
+};
+
+struct GatedNode : public ExecNode, public TracedNode {
Review Comment:
Done.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]