This is an automated email from the ASF dual-hosted git repository.

westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 54eedb95ec GH-33960: [C++] Add DeclarationToSchema and DeclarationToString helper methods. (#34013)
54eedb95ec is described below

commit 54eedb95ec504a715d557e71139ae4df9657fde6
Author: Weston Pace <[email protected]>
AuthorDate: Fri Feb 3 12:25:19 2023 -0800

    GH-33960: [C++] Add DeclarationToSchema and DeclarationToString helper methods. (#34013)
    
    Also cleans up the Acero server example to use the DeclarationToXyz methods
    * Closes: #33960
    
    Authored-by: Weston Pace <[email protected]>
    Signed-off-by: Weston Pace <[email protected]>
---
 cpp/src/arrow/compute/exec/exec_plan.cc          |  54 ++++++++-
 cpp/src/arrow/compute/exec/exec_plan.h           |  30 +++++
 cpp/src/arrow/compute/exec/plan_test.cc          | 116 +++++++++++--------
 cpp/src/arrow/flight/sql/example/acero_server.cc | 137 ++---------------------
 4 files changed, 160 insertions(+), 177 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc
index 896eafb58c..a187d4346f 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -575,6 +575,52 @@ bool Declaration::IsValid(ExecFactoryRegistry* registry) 
const {
   return !this->factory_name.empty() && this->options != nullptr;
 }
 
+namespace {
+
+Result<ExecNode*> EnsureSink(ExecNode* last_node, ExecPlan* plan) {
+  if (!last_node->is_sink()) {
+    Declaration null_sink =
+        Declaration("consuming_sink", {last_node},
+                    ConsumingSinkNodeOptions(NullSinkNodeConsumer::Make()));
+    return null_sink.AddToPlan(plan);
+  }
+  return last_node;
+}
+
+}  // namespace
+
+Result<std::shared_ptr<Schema>> DeclarationToSchema(const Declaration& 
declaration,
+                                                    FunctionRegistry* 
function_registry) {
+  // We pass in the default memory pool and the CPU executor but nothing we are doing
+  // should be starting new thread tasks or making large allocations.
+  ExecContext exec_context(default_memory_pool(), 
::arrow::internal::GetCpuThreadPool(),
+                           function_registry);
+  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ExecPlan> exec_plan,
+                        ExecPlan::Make(exec_context));
+  ARROW_ASSIGN_OR_RAISE(ExecNode * last_node, 
declaration.AddToPlan(exec_plan.get()));
+  ARROW_ASSIGN_OR_RAISE(last_node, EnsureSink(last_node, exec_plan.get()));
+  ARROW_RETURN_NOT_OK(exec_plan->Validate());
+  if (last_node->inputs().size() != 1) {
+    // Every sink node today has exactly one input
+    return Status::Invalid("Unexpected sink node with more than one input");
+  }
+  return last_node->inputs()[0]->output_schema();
+}
+
+Result<std::string> DeclarationToString(const Declaration& declaration,
+                                        FunctionRegistry* function_registry) {
+  // We pass in the default memory pool and the CPU executor but nothing we are doing
+  // should be starting new thread tasks or making large allocations.
+  ExecContext exec_context(default_memory_pool(), 
::arrow::internal::GetCpuThreadPool(),
+                           function_registry);
+  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ExecPlan> exec_plan,
+                        ExecPlan::Make(exec_context));
+  ARROW_ASSIGN_OR_RAISE(ExecNode * last_node, 
declaration.AddToPlan(exec_plan.get()));
+  ARROW_ASSIGN_OR_RAISE(last_node, EnsureSink(last_node, exec_plan.get()));
+  ARROW_RETURN_NOT_OK(exec_plan->Validate());
+  return exec_plan->ToString();
+}
+
 Future<std::shared_ptr<Table>> DeclarationToTableAsync(Declaration declaration,
                                                        ExecContext 
exec_context) {
   std::shared_ptr<std::shared_ptr<Table>> output_table =
@@ -817,11 +863,17 @@ Result<std::unique_ptr<RecordBatchReader>> 
DeclarationToReader(
     std::shared_ptr<Schema> schema() const override { return schema_; }
 
     Status ReadNext(std::shared_ptr<RecordBatch>* record_batch) override {
-      DCHECK(!!iterator_) << "call to ReadNext on already closed reader";
+      if (!iterator_) {
+        return Status::Invalid("call to ReadNext on already closed reader");
+      }
       return iterator_->Next().Value(record_batch);
     }
 
     Status Close() override {
+      if (!iterator_) {
+        // Already closed
+        return Status::OK();
+      }
       // End plan and read from generator until finished
       std::shared_ptr<RecordBatch> batch;
       do {
diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h
index 0fcfb36754..dc875ef479 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.h
+++ b/cpp/src/arrow/compute/exec/exec_plan.h
@@ -408,6 +408,36 @@ struct ARROW_EXPORT Declaration {
   std::string label;
 };
 
+/// \brief Calculate the output schema of a declaration
+///
+/// This does not actually execute the plan.  This operation may fail if the
+/// declaration represents an invalid plan (e.g. a project node with multiple inputs)
+///
+/// \param declaration A declaration describing an execution plan
+/// \param function_registry The function registry to use for function execution.  If null
+///                          then the default function registry will be used.
+///
+/// \return the schema that batches would have after going through the execution plan
+ARROW_EXPORT Result<std::shared_ptr<Schema>> DeclarationToSchema(
+    const Declaration& declaration, FunctionRegistry* function_registry = 
NULLPTR);
+
+/// \brief Create a string representation of a plan
+///
+/// This representation is for debug purposes only.
+///
+/// Conversion to a string may fail if the declaration represents an
+/// invalid plan.
+///
+/// Use Substrait for complete serialization of plans
+///
+/// \param declaration A declaration describing an execution plan
+/// \param function_registry The function registry to use for function execution.  If null
+///                          then the default function registry will be used.
+///
+/// \return a string representation of the plan suitable for debugging output
+ARROW_EXPORT Result<std::string> DeclarationToString(
+    const Declaration& declaration, FunctionRegistry* function_registry = 
NULLPTR);
+
 /// \brief Utility method to run a declaration and collect the results into a table
 ///
 /// \param declaration A declaration describing the plan to run
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index 497b719625..5b2af718df 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -440,6 +440,8 @@ TEST(ExecPlan, ToString) {
   auto basic_data = MakeBasicBatches();
   AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
+  // Cannot test the following mini-plans with DeclarationToString since validation
+  // would fail (no sink)
   ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
   ASSERT_OK(Declaration::Sequence(
                 {
@@ -456,40 +458,36 @@ TEST(ExecPlan, ToString) {
   :SourceNode{}
 )");
 
-  ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make());
   std::shared_ptr<CountOptions> options =
       std::make_shared<CountOptions>(CountOptions::ONLY_VALID);
-  ASSERT_OK(
-      Declaration::Sequence(
-          {
-              {"source",
-               SourceNodeOptions{basic_data.schema,
-                                 basic_data.gen(/*parallel=*/false, 
/*slow=*/false)},
-               "custom_source_label"},
-              {"filter", FilterNodeOptions{greater_equal(field_ref("i32"), 
literal(0))}},
-              {"project", ProjectNodeOptions{{
-                              field_ref("bool"),
-                              call("multiply", {field_ref("i32"), literal(2)}),
-                          }}},
-              {"aggregate",
-               AggregateNodeOptions{
-                   /*aggregates=*/{
-                       {"hash_sum", nullptr, "multiply(i32, 2)", 
"sum(multiply(i32, 2))"},
-                       {"hash_count", options, "multiply(i32, 2)",
-                        "count(multiply(i32, 2))"},
-                       {"hash_count_all", "count(*)"},
-                   },
-                   /*keys=*/{"bool"}}},
-              {"filter", 
FilterNodeOptions{greater(field_ref("sum(multiply(i32, 2))"),
-                                                   literal(10))}},
-              {"order_by_sink",
-               OrderBySinkNodeOptions{
-                   SortOptions({SortKey{"sum(multiply(i32, 2))", 
SortOrder::Ascending}}),
-                   &sink_gen},
-               "custom_sink_label"},
-          })
-          .AddToPlan(plan.get()));
-  EXPECT_EQ(plan->ToString(), R"a(ExecPlan with 6 nodes:
+  Declaration declaration = Declaration::Sequence({
+      {"source",
+       SourceNodeOptions{basic_data.schema,
+                         basic_data.gen(/*parallel=*/false, /*slow=*/false)},
+       "custom_source_label"},
+      {"filter", FilterNodeOptions{greater_equal(field_ref("i32"), 
literal(0))}},
+      {"project", ProjectNodeOptions{{
+                      field_ref("bool"),
+                      call("multiply", {field_ref("i32"), literal(2)}),
+                  }}},
+      {"aggregate",
+       AggregateNodeOptions{
+           /*aggregates=*/{
+               {"hash_sum", nullptr, "multiply(i32, 2)", "sum(multiply(i32, 
2))"},
+               {"hash_count", options, "multiply(i32, 2)", 
"count(multiply(i32, 2))"},
+               {"hash_count_all", "count(*)"},
+           },
+           /*keys=*/{"bool"}}},
+      {"filter",
+       FilterNodeOptions{greater(field_ref("sum(multiply(i32, 2))"), 
literal(10))}},
+      {"order_by_sink",
+       OrderBySinkNodeOptions{
+           SortOptions({SortKey{"sum(multiply(i32, 2))", 
SortOrder::Ascending}}),
+           &sink_gen},
+       "custom_sink_label"},
+  });
+  ASSERT_OK_AND_ASSIGN(std::string plan_str, DeclarationToString(declaration));
+  EXPECT_EQ(plan_str, R"a(ExecPlan with 6 nodes:
 
custom_sink_label:OrderBySinkNode{by={sort_keys=[FieldRef.Name(sum(multiply(i32,
 2))) ASC], null_placement=AtEnd}}
   :FilterNode{filter=(sum(multiply(i32, 2)) > 10)}
     :GroupByNode{keys=["bool"], aggregates=[
@@ -502,8 +500,6 @@ 
custom_sink_label:OrderBySinkNode{by={sort_keys=[FieldRef.Name(sum(multiply(i32,
           custom_source_label:SourceNode{}
 )a");
 
-  ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make());
-
   Declaration union_node{"union", ExecNodeOptions{}};
   Declaration lhs{"source",
                   SourceNodeOptions{basic_data.schema,
@@ -515,19 +511,17 @@ 
custom_sink_label:OrderBySinkNode{by={sort_keys=[FieldRef.Name(sum(multiply(i32,
   rhs.label = "rhs";
   union_node.inputs.emplace_back(lhs);
   union_node.inputs.emplace_back(rhs);
-  ASSERT_OK(Declaration::Sequence(
-                {
-                    union_node,
-                    {"aggregate",
-                     AggregateNodeOptions{/*aggregates=*/{
-                                              {"count", options, "i32", 
"count(i32)"},
-                                              {"count_all", "count(*)"},
-                                          },
-                                          /*keys=*/{}}},
-                    {"sink", SinkNodeOptions{&sink_gen}},
-                })
-                .AddToPlan(plan.get()));
-  EXPECT_EQ(plan->ToString(), R"a(ExecPlan with 5 nodes:
+  declaration = Declaration::Sequence({
+      union_node,
+      {"aggregate", AggregateNodeOptions{/*aggregates=*/{
+                                             {"count", options, "i32", 
"count(i32)"},
+                                             {"count_all", "count(*)"},
+                                         },
+                                         /*keys=*/{}}},
+      {"sink", SinkNodeOptions{&sink_gen}},
+  });
+  ASSERT_OK_AND_ASSIGN(plan_str, DeclarationToString(declaration));
+  EXPECT_EQ(plan_str, R"a(ExecPlan with 5 nodes:
 :SinkNode{}
   :ScalarAggregateNode{aggregates=[
        count(i32, {mode=NON_NULL}),
@@ -674,6 +668,34 @@ TEST(ExecPlanExecution, SourceTableConsumingSink) {
   }
 }
 
+TEST(ExecPlanExecution, DeclarationToSchema) {
+  auto basic_data = MakeBasicBatches();
+  auto plan = Declaration::Sequence(
+      {{"source", SourceNodeOptions(basic_data.schema, basic_data.gen(false, 
false))},
+       {"aggregate", AggregateNodeOptions({{"hash_sum", "i32", "int32_sum"}}, 
{"bool"})},
+       {"project",
+        ProjectNodeOptions({field_ref("int32_sum"),
+                            call("multiply", {field_ref("int32_sum"), 
literal(2)})})}});
+  auto expected_out_schema =
+      schema({field("int32_sum", int64()), field("multiply(int32_sum, 2)", 
int64())});
+  ASSERT_OK_AND_ASSIGN(auto actual_out_schema, 
DeclarationToSchema(std::move(plan)));
+  AssertSchemaEqual(expected_out_schema, actual_out_schema);
+}
+
+TEST(ExecPlanExecution, DeclarationToReader) {
+  auto basic_data = MakeBasicBatches();
+  auto plan = Declaration::Sequence(
+      {{"source", SourceNodeOptions(basic_data.schema, basic_data.gen(false, 
false))}});
+  ASSERT_OK_AND_ASSIGN(std::unique_ptr<RecordBatchReader> reader,
+                       DeclarationToReader(plan));
+
+  ASSERT_OK_AND_ASSIGN(std::shared_ptr<Table> out, reader->ToTable());
+  ASSERT_EQ(5, out->num_rows());
+  ASSERT_OK(reader->Close());
+  EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("already closed reader"),
+                                  reader->Next());
+}
+
 TEST(ExecPlanExecution, ConsumingSinkNames) {
   struct SchemaKeepingConsumer : public SinkNodeConsumer {
     std::shared_ptr<Schema> schema_;
diff --git a/cpp/src/arrow/flight/sql/example/acero_server.cc b/cpp/src/arrow/flight/sql/example/acero_server.cc
index 43b69d669f..ed5422e81f 100644
--- a/cpp/src/arrow/flight/sql/example/acero_server.cc
+++ b/cpp/src/arrow/flight/sql/example/acero_server.cc
@@ -35,108 +35,6 @@ namespace sql {
 namespace acero_example {
 
 namespace {
-/// \brief A SinkNodeConsumer that saves the schema as given to it by
-///   the ExecPlan. Used to retrieve the schema of a Substrait plan to
-///   fulfill the Flight SQL API contract.
-class GetSchemaSinkNodeConsumer : public compute::SinkNodeConsumer {
- public:
-  Status Init(const std::shared_ptr<Schema>& schema, 
compute::BackpressureControl*,
-              compute::ExecPlan* plan) override {
-    schema_ = schema;
-    return Status::OK();
-  }
-  Status Consume(compute::ExecBatch exec_batch) override { return 
Status::OK(); }
-  Future<> Finish() override { return Status::OK(); }
-
-  const std::shared_ptr<Schema>& schema() const { return schema_; }
-
- private:
-  std::shared_ptr<Schema> schema_;
-};
-
-/// \brief A SinkNodeConsumer that internally saves batches into a
-///   queue, so that it can be read from a RecordBatchReader. In other
-///   words, this bridges a push-based interface (ExecPlan) to a
-///   pull-based interface (RecordBatchReader).
-class QueuingSinkNodeConsumer : public compute::SinkNodeConsumer {
- public:
-  QueuingSinkNodeConsumer() : schema_(nullptr), finished_(false) {}
-
-  Status Init(const std::shared_ptr<Schema>& schema, 
compute::BackpressureControl*,
-              compute::ExecPlan* plan) override {
-    schema_ = schema;
-    return Status::OK();
-  }
-
-  Status Consume(compute::ExecBatch exec_batch) override {
-    {
-      std::lock_guard<std::mutex> guard(mutex_);
-      batches_.push_back(std::move(exec_batch));
-      batches_added_.notify_all();
-    }
-
-    return Status::OK();
-  }
-
-  Future<> Finish() override {
-    {
-      std::lock_guard<std::mutex> guard(mutex_);
-      finished_ = true;
-      batches_added_.notify_all();
-    }
-
-    return Status::OK();
-  }
-
-  const std::shared_ptr<Schema>& schema() const { return schema_; }
-
-  arrow::Result<std::shared_ptr<RecordBatch>> Next() {
-    compute::ExecBatch batch;
-    {
-      std::unique_lock<std::mutex> guard(mutex_);
-      batches_added_.wait(guard, [this] { return !batches_.empty() || 
finished_; });
-
-      if (finished_ && batches_.empty()) {
-        return nullptr;
-      }
-      batch = std::move(batches_.front());
-      batches_.pop_front();
-    }
-
-    return batch.ToRecordBatch(schema_);
-  }
-
- private:
-  std::mutex mutex_;
-  std::condition_variable batches_added_;
-  std::deque<compute::ExecBatch> batches_;
-  std::shared_ptr<Schema> schema_;
-  bool finished_;
-};
-
-/// \brief A RecordBatchReader that pulls from the
-///   QueuingSinkNodeConsumer above, blocking until results are
-///   available as necessary.
-class ConsumerBasedRecordBatchReader : public RecordBatchReader {
- public:
-  explicit ConsumerBasedRecordBatchReader(
-      std::shared_ptr<compute::ExecPlan> plan,
-      std::shared_ptr<QueuingSinkNodeConsumer> consumer)
-      : plan_(std::move(plan)), consumer_(std::move(consumer)) {}
-
-  std::shared_ptr<Schema> schema() const override { return 
consumer_->schema(); }
-
-  Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
-    return consumer_->Next().Value(batch);
-  }
-
-  // TODO(ARROW-17242): FlightDataStream needs to call Close()
-  Status Close() override { return plan_->finished().status(); }
-
- private:
-  std::shared_ptr<compute::ExecPlan> plan_;
-  std::shared_ptr<QueuingSinkNodeConsumer> consumer_;
-};
 
 /// \brief An implementation of a Flight SQL service backed by Acero.
 class AceroFlightSqlServer : public FlightSqlServerBase {
@@ -193,18 +91,14 @@ class AceroFlightSqlServer : public FlightSqlServerBase {
     // GetFlightInfoSubstraitPlan encodes the plan into the ticket
     std::shared_ptr<Buffer> serialized_plan =
         Buffer::FromString(command.statement_handle);
-    std::shared_ptr<QueuingSinkNodeConsumer> consumer =
-        std::make_shared<QueuingSinkNodeConsumer>();
-    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<compute::ExecPlan> plan,
-                          engine::DeserializePlan(*serialized_plan, consumer));
-
-    ARROW_LOG(INFO) << "DoGetStatement: executing plan " << plan->ToString();
+    ARROW_ASSIGN_OR_RAISE(compute::Declaration plan,
+                          engine::DeserializePlan(*serialized_plan));
 
-    plan->StartProducing();
+    ARROW_LOG(INFO) << "DoGetStatement: executing plan "
+                    << compute::DeclarationToString(plan).ValueOr("Invalid 
plan");
 
-    auto reader = 
std::make_shared<ConsumerBasedRecordBatchReader>(std::move(plan),
-                                                                   
std::move(consumer));
-    return std::make_unique<RecordBatchStream>(reader);
+    ARROW_ASSIGN_OR_RAISE(auto reader, compute::DeclarationToReader(plan));
+    return std::make_unique<RecordBatchStream>(std::move(reader));
   }
 
   arrow::Result<int64_t> DoPutCommandSubstraitPlan(
@@ -263,23 +157,8 @@ class AceroFlightSqlServer : public FlightSqlServerBase {
   arrow::Result<std::shared_ptr<arrow::Schema>> GetPlanSchema(
       const std::string& serialized_plan) {
     std::shared_ptr<Buffer> plan_buf = Buffer::FromString(serialized_plan);
-    std::shared_ptr<GetSchemaSinkNodeConsumer> consumer =
-        std::make_shared<GetSchemaSinkNodeConsumer>();
-    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<compute::ExecPlan> plan,
-                          engine::DeserializePlan(*plan_buf, consumer));
-    std::shared_ptr<Schema> output_schema;
-    for (compute::ExecNode* possible_sink : plan->nodes()) {
-      if (possible_sink->is_sink()) {
-        // Force SinkNodeConsumer::Init to be called
-        ARROW_RETURN_NOT_OK(possible_sink->StartProducing());
-        output_schema = consumer->schema();
-        break;
-      }
-    }
-    if (!output_schema) {
-      return Status::Invalid("Could not infer output schema");
-    }
-    return output_schema;
+    ARROW_ASSIGN_OR_RAISE(compute::Declaration plan, 
engine::DeserializePlan(*plan_buf));
+    return compute::DeclarationToSchema(plan);
   }
 
   arrow::Result<std::unique_ptr<FlightInfo>> MakeFlightInfo(

Reply via email to