rtpsw commented on code in PR #35874:
URL: https://github.com/apache/arrow/pull/35874#discussion_r1233197482
##########
cpp/src/arrow/acero/asof_join_node_test.cc:
##########
@@ -1381,36 +1443,93 @@ void TestBackpressure(BatchesMaker maker, int
num_batches, int batch_size,
ASSERT_OK_AND_ASSIGN(auto r0_batches, make_shift(r0_schema, 1));
ASSERT_OK_AND_ASSIGN(auto r1_batches, make_shift(r1_schema, 2));
- Declaration l_src = {
- "source", SourceNodeOptions(
- l_schema, MakeDelayedGen(l_batches, "0:fast", fast_delay,
noisy))};
- Declaration r0_src = {
- "source", SourceNodeOptions(
- r0_schema, MakeDelayedGen(r0_batches, "1:slow",
slow_delay, noisy))};
- Declaration r1_src = {
- "source", SourceNodeOptions(
- r1_schema, MakeDelayedGen(r1_batches, "2:fast",
fast_delay, noisy))};
+ BackpressureCountingNode::Register();
+ RegisterTestNodes(); // for GatedNode
- Declaration asofjoin = {
- "asofjoin", {l_src, r0_src, r1_src}, GetRepeatedOptions(3, "time",
{"key"}, 1000)};
+ struct BackpressureSourceConfig {
+ std::string name_prefix;
+ bool is_gated;
+ std::shared_ptr<Schema> schema;
+ decltype(l_batches) batches;
+
+ std::string name() const {
+ return name_prefix + ";" + (is_gated ? "gated" : "ungated");
+ }
+ };
+
+ auto gate_ptr = Gate::Make();
+ auto& gate = *gate_ptr;
+ GatedNodeOptions gate_options(gate_ptr.get());
+
+ // Two ungated and one gated
+ std::vector<BackpressureSourceConfig> source_configs = {
+ {"0", false, l_schema, l_batches},
+ {"1", true, r0_schema, r0_batches},
+ {"2", false, r1_schema, r1_batches},
+ };
+
+ std::vector<BackpressureCounters> bp_counters(source_configs.size());
+ std::vector<Declaration> src_decls;
+ std::vector<std::shared_ptr<BackpressureCountingNodeOptions>> bp_options;
+ std::vector<Declaration::Input> bp_decls;
+ for (size_t i = 0; i < source_configs.size(); i++) {
+ const auto& config = source_configs[i];
+
+ src_decls.emplace_back("source",
+ SourceNodeOptions(config.schema,
GetGen(config.batches)));
+ bp_options.push_back(
+ std::make_shared<BackpressureCountingNodeOptions>(&bp_counters[i]));
+ std::shared_ptr<ExecNodeOptions> options = bp_options.back();
+ std::vector<Declaration::Input> bp_in = {src_decls.back()};
+ Declaration bp_decl = {BackpressureCountingNode::kFactoryName, bp_in,
+ std::move(options)};
+ if (config.is_gated) {
+ bp_decl = {std::string{GatedNodeOptions::kName}, {bp_decl},
gate_options};
+ }
+ bp_decls.push_back(bp_decl);
+ }
+
+ Declaration asofjoin = {"asofjoin", bp_decls,
+ GetRepeatedOptions(source_configs.size(), "time",
{"key"}, 0)};
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<internal::ThreadPool> tpool,
+ internal::ThreadPool::Make(1));
+ ExecContext exec_ctx(default_memory_pool(), tpool.get());
+ Future<BatchesWithCommonSchema> batches_fut =
+ DeclarationToExecBatchesAsync(asofjoin, exec_ctx);
+
+ auto has_bp_been_applied = [&] {
+ // One of the inputs is gated. The other two will eventually be paused by
the asof
+ // join node
+ for (size_t i = 0; i < source_configs.size(); i++) {
+ const auto& counters = bp_counters[i];
+ if (source_configs[i].is_gated) {
+ if (counters.pause_count > 0) return false;
Review Comment:
The logic around here checks the following expectations for correct
application of backpressure: a gated node should not have been paused (checked
by the current line), whereas a non-gated node should have been paused once
(checked by the next if-statement).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]