rtpsw commented on code in PR #35874:
URL: https://github.com/apache/arrow/pull/35874#discussion_r1231400648


##########
cpp/src/arrow/acero/asof_join_node_test.cc:
##########
@@ -1381,36 +1587,85 @@ void TestBackpressure(BatchesMaker maker, int 
num_batches, int batch_size,
   ASSERT_OK_AND_ASSIGN(auto r0_batches, make_shift(r0_schema, 1));
   ASSERT_OK_AND_ASSIGN(auto r1_batches, make_shift(r1_schema, 2));
 
-  Declaration l_src = {
-      "source", SourceNodeOptions(
-                    l_schema, MakeDelayedGen(l_batches, "0:fast", fast_delay, 
noisy))};
-  Declaration r0_src = {
-      "source", SourceNodeOptions(
-                    r0_schema, MakeDelayedGen(r0_batches, "1:slow", 
slow_delay, noisy))};
-  Declaration r1_src = {
-      "source", SourceNodeOptions(
-                    r1_schema, MakeDelayedGen(r1_batches, "2:fast", 
fast_delay, noisy))};
+  BackpressureCountingNode::Register();
+  GatedNode::Register();
 
-  Declaration asofjoin = {
-      "asofjoin", {l_src, r0_src, r1_src}, GetRepeatedOptions(3, "time", 
{"key"}, 1000)};
+  struct BackpressureSourceConfig {
+    std::string name_prefix;
+    bool is_gated;
+    std::shared_ptr<Schema> schema;
+    decltype(l_batches) batches;
 
-  ASSERT_OK_AND_ASSIGN(std::unique_ptr<RecordBatchReader> batch_reader,
-                       DeclarationToReader(asofjoin, /*use_threads=*/false));
+    std::string name() const {
+      return name_prefix + ";" + (is_gated ? "gated" : "ungated");
+    }
+  };
+
+  Gate gate;
+  GatedNodeOptions gate_options(&gate);
+
+  // Two ungated and one gated
+  std::vector<BackpressureSourceConfig> source_configs = {
+      {"0", false, l_schema, l_batches},
+      {"1", true, r0_schema, r0_batches},
+      {"2", false, r1_schema, r1_batches},
+  };
 
-  int64_t total_length = 0;
-  for (;;) {
-    ASSERT_OK_AND_ASSIGN(auto batch, batch_reader->Next());
-    if (!batch) {
-      break;
+  std::vector<BackpressureCounters> bp_counters(source_configs.size());
+  std::vector<Declaration> src_decls;
+  std::vector<std::shared_ptr<BackpressureCountingNodeOptions>> bp_options;
+  std::vector<Declaration::Input> bp_decls;
+  for (size_t i = 0; i < source_configs.size(); i++) {
+    const auto& config = source_configs[i];
+
+    src_decls.emplace_back("source",
+                           SourceNodeOptions(config.schema, 
GetGen(config.batches)));
+    bp_options.push_back(
+        std::make_shared<BackpressureCountingNodeOptions>(&bp_counters[i]));
+    std::shared_ptr<ExecNodeOptions> options = bp_options.back();
+    std::vector<Declaration::Input> bp_in = {src_decls.back()};
+    Declaration bp_decl = {BackpressureCountingNode::kFactoryName, bp_in,
+                           std::move(options)};
+    if (config.is_gated) {
+      bp_decl = {GatedNode::kFactoryName, {bp_decl}, gate_options};
     }
-    total_length += batch->num_rows();
+    bp_decls.push_back(bp_decl);
+  }
+
+  Declaration asofjoin = {"asofjoin", bp_decls,
+                          GetRepeatedOptions(source_configs.size(), "time", 
{"key"}, 0)};
+
+  ASSERT_OK_AND_ASSIGN(std::shared_ptr<internal::ThreadPool> tpool,
+                       internal::ThreadPool::Make(1));
+  ExecContext exec_ctx(default_memory_pool(), tpool.get());
+  Future<BatchesWithCommonSchema> batches_fut =
+      DeclarationToExecBatchesAsync(asofjoin, exec_ctx);
+
+  auto has_bp_been_applied = [&] {
+    int total_paused = 0;
+    for (const auto& counters : bp_counters) {
+      total_paused += counters.pause_count;
+    }
+    // One of the inputs is gated.  The other two will eventually be paused by 
the asof
+    // join node
+    return total_paused >= 2;
+  };
+
+  BusyWait(10.0, has_bp_been_applied);
+  ASSERT_TRUE(has_bp_been_applied());
+
+  gate.ReleaseAllBatches();
+  ASSERT_FINISHES_OK_AND_ASSIGN(BatchesWithCommonSchema batches, batches_fut);
+
+  size_t total_resumed = 0;
+  for (const auto& counters : bp_counters) {
+    total_resumed += counters.resume_count;
   }
-  ASSERT_EQ(static_cast<int64_t>(num_batches * batch_size), total_length);
+  ASSERT_GE(total_resumed, 2);

Review Comment:
   See [this post](https://github.com/apache/arrow/pull/35874#discussion_r1231399976).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to