westonpace commented on code in PR #12228:
URL: https://github.com/apache/arrow/pull/12228#discussion_r855748930


##########
cpp/src/arrow/compute/exec/sink_node.cc:
##########
@@ -46,31 +46,81 @@ using internal::checked_cast;
 namespace compute {
 namespace {
 
+/// Tracks how many bytes a sink currently buffers and decides when the
+/// producer should be paused (above `pause_if_above`) or resumed (below
+/// `resume_if_below`).
+///
+/// Transitions are numbered by `state_change_counter_`: an odd value means
+/// "paused", an even value means "running". RecordProduced/RecordConsumed
+/// return the new counter value when they cause a transition and -1 otherwise.
+/// All mutable state is guarded by `mutex_`.
+///
+/// NOTE(review): "Resevoir" is a misspelling of "Reservoir"; renaming would
+/// require updating every reference to this class.
+class BackpressureResevoir : public BackpressureMonitor {
+ public:
+  BackpressureResevoir(uint64_t resume_if_below, uint64_t pause_if_above)
+      : bytes_used_(0),
+        state_change_counter_(0),
+        resume_if_below_(resume_if_below),
+        pause_if_above_(pause_if_above) {}
+
+  uint64_t bytes_in_use() const override {
+    std::lock_guard<std::mutex> lg(mutex_);
+    return bytes_used_;
+  }
+
+  bool is_paused() const override {
+    std::lock_guard<std::mutex> lg(mutex_);
+    return IsPausedLocked();
+  }
+
+  /// Backpressure is considered disabled when no pause threshold is set.
+  /// NOTE(review): presumably callers check this before calling Record*;
+  /// confirm, since with a threshold of 0 any produced byte would pause.
+  bool enabled() const { return pause_if_above_ > 0; }
+
+  /// Account for `num_bytes` newly buffered bytes.
+  ///
+  /// \return the new state-change counter (odd => paused) if this call moved a
+  ///         running producer above `pause_if_above`, otherwise -1.
+  int32_t RecordProduced(uint64_t num_bytes) {
+    std::lock_guard<std::mutex> lg(mutex_);
+    bytes_used_ += num_bytes;
+    // Only a running producer may be paused. Gating on the current state
+    // (not merely on a threshold crossing) keeps the counter's parity in sync
+    // with the real pause state.
+    if (!IsPausedLocked() && bytes_used_ > pause_if_above_) {
+      return ++state_change_counter_;
+    }
+    return -1;
+  }
+
+  /// Account for `num_bytes` bytes that left the buffer.
+  ///
+  /// \return the new state-change counter (even => running) if this call moved
+  ///         a paused producer below `resume_if_below`, otherwise -1.
+  int32_t RecordConsumed(uint64_t num_bytes) {
+    std::lock_guard<std::mutex> lg(mutex_);
+    // Clamp at zero: over-consuming would otherwise wrap the unsigned total
+    // and leave the producer paused forever.
+    bytes_used_ = num_bytes > bytes_used_ ? 0 : bytes_used_ - num_bytes;
+    // Only a paused producer may be resumed.
+    if (IsPausedLocked() && bytes_used_ < resume_if_below_) {
+      return ++state_change_counter_;
+    }
+    return -1;
+  }
+
+ private:
+  // Odd counter => paused. Caller must hold `mutex_`.
+  bool IsPausedLocked() const { return state_change_counter_ % 2 == 1; }
+
+  // mutable so the const accessors above can lock it.
+  mutable std::mutex mutex_;
+  uint64_t bytes_used_;
+  int32_t state_change_counter_;
+  const uint64_t resume_if_below_;
+  const uint64_t pause_if_above_;
+};
+
 class SinkNode : public ExecNode {
  public:
   SinkNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
            AsyncGenerator<util::optional<ExecBatch>>* generator,
-           util::BackpressureOptions backpressure)
+           BackpressureOptions backpressure,
+           std::shared_ptr<BackpressureMonitor>* backpressure_monitor_out)
       : ExecNode(plan, std::move(inputs), {"collected"}, {},
                  /*num_outputs=*/0),
-        producer_(MakeProducer(generator, std::move(backpressure))) {}
+        backpressure_queue_(std::make_shared<BackpressureResevoir>(
+            backpressure.resume_if_below, backpressure.pause_if_above)),
+        push_gen_(),
+        producer_(push_gen_.producer()) {
+    if (backpressure_monitor_out) {
+      *backpressure_monitor_out = backpressure_queue_;
+    }
+    AsyncGenerator<util::optional<ExecBatch>> captured_gen = push_gen_;

Review Comment:
   No, I've changed this.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to