This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new c544a8bb02 ARROW-15779: [Python] Create python bindings for Substrait
consumer
c544a8bb02 is described below
commit c544a8bb025d5fb1226637780275c8753f81dcb3
Author: Vibhatha Abeykoon <[email protected]>
AuthorDate: Fri May 20 13:50:01 2022 -1000
ARROW-15779: [Python] Create python bindings for Substrait consumer
This PR includes the initial integration of Substrait into Python:
- [x] Adding a util API for consuming Substrait
- [x] Adding a C++ example for using Substrait with Util API
- [x] Adding Python Bindings for Substrait using the Util API
- [x] Adding CMake changes to integrate the `engine` module (experimental);
comments and suggestions to improve this component are much appreciated
- [x] Adding a Python example to consume a Substrait plan (experimental)
Closes #12672 from vibhatha/arrow-15779
Authored-by: Vibhatha Abeykoon <[email protected]>
Signed-off-by: Weston Pace <[email protected]>
---
.github/workflows/python.yml | 1 +
ci/scripts/python_build.sh | 1 +
cpp/src/arrow/engine/CMakeLists.txt | 3 +-
cpp/src/arrow/engine/substrait/serde_test.cc | 62 +++++++++++
cpp/src/arrow/engine/substrait/type_internal.cc | 137 ++++++++++++------------
cpp/src/arrow/engine/substrait/util.cc | 130 ++++++++++++++++++++++
cpp/src/arrow/engine/substrait/util.h | 44 ++++++++
python/CMakeLists.txt | 20 ++++
python/pyarrow/_substrait.pyx | 79 ++++++++++++++
python/pyarrow/includes/libarrow_substrait.pxd | 26 +++++
python/pyarrow/substrait.py | 20 ++++
python/pyarrow/tests/conftest.py | 8 ++
python/pyarrow/tests/test_substrait.py | 93 ++++++++++++++++
python/setup.py | 9 ++
14 files changed, 566 insertions(+), 67 deletions(-)
diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index 5ababa3295..b14559d12a 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -130,6 +130,7 @@ jobs:
ARROW_PLASMA: ON
ARROW_PYTHON: ON
ARROW_S3: ON
+ ARROW_SUBSTRAIT: ON
ARROW_WITH_ZLIB: ON
ARROW_WITH_LZ4: ON
ARROW_WITH_BZ2: ON
diff --git a/ci/scripts/python_build.sh b/ci/scripts/python_build.sh
index e87117ce87..b90321643c 100755
--- a/ci/scripts/python_build.sh
+++ b/ci/scripts/python_build.sh
@@ -64,6 +64,7 @@ export PYARROW_WITH_PLASMA=${ARROW_PLASMA:-OFF}
export PYARROW_WITH_PARQUET=${ARROW_PARQUET:-OFF}
export PYARROW_WITH_PARQUET_ENCRYPTION=${PARQUET_REQUIRE_ENCRYPTION:-ON}
export PYARROW_WITH_S3=${ARROW_S3:-OFF}
+export PYARROW_WITH_SUBSTRAIT=${ARROW_SUBSTRAIT:-OFF}
export PYARROW_PARALLEL=${n_jobs}
diff --git a/cpp/src/arrow/engine/CMakeLists.txt
b/cpp/src/arrow/engine/CMakeLists.txt
index d09b8819fb..ea9797ea1d 100644
--- a/cpp/src/arrow/engine/CMakeLists.txt
+++ b/cpp/src/arrow/engine/CMakeLists.txt
@@ -26,7 +26,8 @@ set(ARROW_SUBSTRAIT_SRCS
substrait/serde.cc
substrait/plan_internal.cc
substrait/relation_internal.cc
- substrait/type_internal.cc)
+ substrait/type_internal.cc
+ substrait/util.cc)
add_arrow_lib(arrow_substrait
CMAKE_PACKAGE_NAME
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc
b/cpp/src/arrow/engine/substrait/serde_test.cc
index 775e2520e0..deee2d1445 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -16,6 +16,7 @@
// under the License.
#include "arrow/engine/substrait/serde.h"
+#include "arrow/engine/substrait/util.h"
#include <google/protobuf/descriptor.h>
#include <google/protobuf/util/json_util.h>
@@ -752,5 +753,66 @@ TEST(Substrait, ExtensionSetFromPlanMissingFunc) {
&ext_set));
}
+/// Builds a JSON Substrait plan that reads "binary.parquet" from the directory
+/// named by the PARQUET_TEST_DATA environment variable.
+///
+/// Returns an error Status if the variable is unset or the file path cannot be
+/// formed, instead of dereferencing a possibly-failed Result.
+Result<std::string> GetSubstraitJSON() {
+  ARROW_ASSIGN_OR_RAISE(std::string dir_string,
+                        arrow::internal::GetEnvVar("PARQUET_TEST_DATA"));
+  // Propagate path errors to the caller rather than dereferencing a Result
+  // that may hold an error.
+  ARROW_ASSIGN_OR_RAISE(auto dir_path,
+                        arrow::internal::PlatformFilename::FromString(dir_string));
+  ARROW_ASSIGN_OR_RAISE(auto file_name, dir_path.Join("binary.parquet"));
+  auto file_path = file_name.ToString();
+  std::string substrait_json = R"({
+ "relations": [
+ {"rel": {
+ "read": {
+ "base_schema": {
+ "struct": {
+ "types": [
+ {"binary": {}}
+ ]
+ },
+ "names": [
+ "foo"
+ ]
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file://FILENAME_PLACEHOLDER",
+ "format": "FILE_FORMAT_PARQUET"
+ }
+ ]
+ }
+ }
+ }}
+ ]
+ })";
+  // Splice the absolute test-data path into the file URI.
+  const std::string filename_placeholder = "FILENAME_PLACEHOLDER";
+  substrait_json.replace(substrait_json.find(filename_placeholder),
+                         filename_placeholder.size(), file_path);
+  return substrait_json;
+}
+
+// Serializes the JSON plan from GetSubstraitJSON(), executes it with
+// ExecuteSerializedPlan, drains the resulting reader into a Table and checks
+// the row count. Skipped on Windows (file URIs not supported, ARROW-16392).
+TEST(Substrait, GetRecordBatchReader) {
+#ifdef _WIN32
+ GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
+#else
+ ASSERT_OK_AND_ASSIGN(std::string substrait_json, GetSubstraitJSON());
+ ASSERT_OK_AND_ASSIGN(auto buf, substrait::SerializeJsonPlan(substrait_json));
+ ASSERT_OK_AND_ASSIGN(auto reader, substrait::ExecuteSerializedPlan(*buf));
+ ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatchReader(reader.get()));
+ // Note: assumes the binary.parquet file contains a fixed number of records (12);
+ // in case of a test failure, re-evaluate the content of the file.
+ EXPECT_EQ(table->num_rows(), 12);
+#endif
+}
+
+// A plan with an empty "relations" list must be rejected by
+// ExecuteSerializedPlan with Status::Invalid. (Presumably such a plan
+// serializes to an empty buffer, which the executor rejects up front.)
+TEST(Substrait, InvalidPlan) {
+ std::string substrait_json = R"({
+ "relations": [
+ ]
+ })";
+ ASSERT_OK_AND_ASSIGN(auto buf, substrait::SerializeJsonPlan(substrait_json));
+ ASSERT_RAISES(Invalid, substrait::ExecuteSerializedPlan(*buf));
+}
+
} // namespace engine
} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/type_internal.cc
b/cpp/src/arrow/engine/substrait/type_internal.cc
index c1dac97b68..c7b94b4104 100644
--- a/cpp/src/arrow/engine/substrait/type_internal.cc
+++ b/cpp/src/arrow/engine/substrait/type_internal.cc
@@ -46,7 +46,7 @@ Status CheckVariation(const TypeMessage& type) {
template <typename TypeMessage>
bool IsNullable(const TypeMessage& type) {
// FIXME what can we do with NULLABILITY_UNSPECIFIED
- return type.nullability() != substrait::Type::NULLABILITY_REQUIRED;
+ return type.nullability() != ::substrait::Type::NULLABILITY_REQUIRED;
}
template <typename ArrowType, typename TypeMessage, typename... A>
@@ -99,66 +99,66 @@ Result<FieldVector> FieldsFromProto(int size, const Types&
types,
} // namespace
Result<std::pair<std::shared_ptr<DataType>, bool>> FromProto(
- const substrait::Type& type, const ExtensionSet& ext_set) {
+ const ::substrait::Type& type, const ExtensionSet& ext_set) {
switch (type.kind_case()) {
- case substrait::Type::kBool:
+ case ::substrait::Type::kBool:
return FromProtoImpl<BooleanType>(type.bool_());
- case substrait::Type::kI8:
+ case ::substrait::Type::kI8:
return FromProtoImpl<Int8Type>(type.i8());
- case substrait::Type::kI16:
+ case ::substrait::Type::kI16:
return FromProtoImpl<Int16Type>(type.i16());
- case substrait::Type::kI32:
+ case ::substrait::Type::kI32:
return FromProtoImpl<Int32Type>(type.i32());
- case substrait::Type::kI64:
+ case ::substrait::Type::kI64:
return FromProtoImpl<Int64Type>(type.i64());
- case substrait::Type::kFp32:
+ case ::substrait::Type::kFp32:
return FromProtoImpl<FloatType>(type.fp32());
- case substrait::Type::kFp64:
+ case ::substrait::Type::kFp64:
return FromProtoImpl<DoubleType>(type.fp64());
- case substrait::Type::kString:
+ case ::substrait::Type::kString:
return FromProtoImpl<StringType>(type.string());
- case substrait::Type::kBinary:
+ case ::substrait::Type::kBinary:
return FromProtoImpl<BinaryType>(type.binary());
- case substrait::Type::kTimestamp:
+ case ::substrait::Type::kTimestamp:
return FromProtoImpl<TimestampType>(type.timestamp(), TimeUnit::MICRO);
- case substrait::Type::kTimestampTz:
+ case ::substrait::Type::kTimestampTz:
return FromProtoImpl<TimestampType>(type.timestamp_tz(), TimeUnit::MICRO,
TimestampTzTimezoneString());
- case substrait::Type::kDate:
+ case ::substrait::Type::kDate:
return FromProtoImpl<Date32Type>(type.date());
- case substrait::Type::kTime:
+ case ::substrait::Type::kTime:
return FromProtoImpl<Time64Type>(type.time(), TimeUnit::MICRO);
- case substrait::Type::kIntervalYear:
+ case ::substrait::Type::kIntervalYear:
return FromProtoImpl(type.interval_year(), interval_year);
- case substrait::Type::kIntervalDay:
+ case ::substrait::Type::kIntervalDay:
return FromProtoImpl(type.interval_day(), interval_day);
- case substrait::Type::kUuid:
+ case ::substrait::Type::kUuid:
return FromProtoImpl(type.uuid(), uuid);
- case substrait::Type::kFixedChar:
+ case ::substrait::Type::kFixedChar:
return FromProtoImpl(type.fixed_char(), fixed_char,
type.fixed_char().length());
- case substrait::Type::kVarchar:
+ case ::substrait::Type::kVarchar:
return FromProtoImpl(type.varchar(), varchar, type.varchar().length());
- case substrait::Type::kFixedBinary:
+ case ::substrait::Type::kFixedBinary:
return FromProtoImpl<FixedSizeBinaryType>(type.fixed_binary(),
type.fixed_binary().length());
- case substrait::Type::kDecimal: {
+ case ::substrait::Type::kDecimal: {
const auto& decimal = type.decimal();
return FromProtoImpl<Decimal128Type>(decimal, decimal.precision(),
decimal.scale());
}
- case substrait::Type::kStruct: {
+ case ::substrait::Type::kStruct: {
const auto& struct_ = type.struct_();
ARROW_ASSIGN_OR_RAISE(auto fields, FieldsFromProto(
@@ -168,7 +168,7 @@ Result<std::pair<std::shared_ptr<DataType>, bool>>
FromProto(
return FromProtoImpl<StructType>(struct_, std::move(fields));
}
- case substrait::Type::kList: {
+ case ::substrait::Type::kList: {
const auto& list = type.list();
if (!list.has_type()) {
@@ -182,7 +182,7 @@ Result<std::pair<std::shared_ptr<DataType>, bool>>
FromProto(
list, field("item", std::move(type_nullable.first),
type_nullable.second));
}
- case substrait::Type::kMap: {
+ case ::substrait::Type::kMap: {
const auto& map = type.map();
static const std::array<char const*, 4> kMissing = {"key and value",
"value", "key",
@@ -206,7 +206,7 @@ Result<std::pair<std::shared_ptr<DataType>, bool>>
FromProto(
field("value", std::move(value_nullable.first),
value_nullable.second));
}
- case substrait::Type::kUserDefinedTypeReference: {
+ case ::substrait::Type::kUserDefinedTypeReference: {
uint32_t anchor = type.user_defined_type_reference();
ARROW_ASSIGN_OR_RAISE(auto type_record, ext_set.DecodeType(anchor));
return std::make_pair(std::move(type_record.type), true);
@@ -226,18 +226,20 @@ struct DataTypeToProtoImpl {
Status Visit(const NullType& t) { return EncodeUserDefined(t); }
Status Visit(const BooleanType& t) {
- return SetWith(&substrait::Type::set_allocated_bool_);
+ return SetWith(&::substrait::Type::set_allocated_bool_);
}
- Status Visit(const Int8Type& t) { return
SetWith(&substrait::Type::set_allocated_i8); }
+ Status Visit(const Int8Type& t) {
+ return SetWith(&::substrait::Type::set_allocated_i8);
+ }
Status Visit(const Int16Type& t) {
- return SetWith(&substrait::Type::set_allocated_i16);
+ return SetWith(&::substrait::Type::set_allocated_i16);
}
Status Visit(const Int32Type& t) {
- return SetWith(&substrait::Type::set_allocated_i32);
+ return SetWith(&::substrait::Type::set_allocated_i32);
}
Status Visit(const Int64Type& t) {
- return SetWith(&substrait::Type::set_allocated_i64);
+ return SetWith(&::substrait::Type::set_allocated_i64);
}
Status Visit(const UInt8Type& t) { return EncodeUserDefined(t); }
@@ -247,26 +249,27 @@ struct DataTypeToProtoImpl {
Status Visit(const HalfFloatType& t) { return EncodeUserDefined(t); }
Status Visit(const FloatType& t) {
- return SetWith(&substrait::Type::set_allocated_fp32);
+ return SetWith(&::substrait::Type::set_allocated_fp32);
}
Status Visit(const DoubleType& t) {
- return SetWith(&substrait::Type::set_allocated_fp64);
+ return SetWith(&::substrait::Type::set_allocated_fp64);
}
Status Visit(const StringType& t) {
- return SetWith(&substrait::Type::set_allocated_string);
+ return SetWith(&::substrait::Type::set_allocated_string);
}
Status Visit(const BinaryType& t) {
- return SetWith(&substrait::Type::set_allocated_binary);
+ return SetWith(&::substrait::Type::set_allocated_binary);
}
Status Visit(const FixedSizeBinaryType& t) {
-
SetWithThen(&substrait::Type::set_allocated_fixed_binary)->set_length(t.byte_width());
+ SetWithThen(&::substrait::Type::set_allocated_fixed_binary)
+ ->set_length(t.byte_width());
return Status::OK();
}
Status Visit(const Date32Type& t) {
- return SetWith(&substrait::Type::set_allocated_date);
+ return SetWith(&::substrait::Type::set_allocated_date);
}
Status Visit(const Date64Type& t) { return NotImplemented(t); }
@@ -274,10 +277,10 @@ struct DataTypeToProtoImpl {
if (t.unit() != TimeUnit::MICRO) return NotImplemented(t);
if (t.timezone() == "") {
- return SetWith(&substrait::Type::set_allocated_timestamp);
+ return SetWith(&::substrait::Type::set_allocated_timestamp);
}
if (t.timezone() == TimestampTzTimezoneString()) {
- return SetWith(&substrait::Type::set_allocated_timestamp_tz);
+ return SetWith(&::substrait::Type::set_allocated_timestamp_tz);
}
return NotImplemented(t);
@@ -286,14 +289,14 @@ struct DataTypeToProtoImpl {
Status Visit(const Time32Type& t) { return NotImplemented(t); }
Status Visit(const Time64Type& t) {
if (t.unit() != TimeUnit::MICRO) return NotImplemented(t);
- return SetWith(&substrait::Type::set_allocated_time);
+ return SetWith(&::substrait::Type::set_allocated_time);
}
Status Visit(const MonthIntervalType& t) { return EncodeUserDefined(t); }
Status Visit(const DayTimeIntervalType& t) { return EncodeUserDefined(t); }
Status Visit(const Decimal128Type& t) {
- auto dec = SetWithThen(&substrait::Type::set_allocated_decimal);
+ auto dec = SetWithThen(&::substrait::Type::set_allocated_decimal);
dec->set_precision(t.precision());
dec->set_scale(t.scale());
return Status::OK();
@@ -304,18 +307,20 @@ struct DataTypeToProtoImpl {
// FIXME assert default field name; custom ones won't roundtrip
ARROW_ASSIGN_OR_RAISE(
auto type, ToProto(*t.value_type(), t.value_field()->nullable(),
ext_set_));
-
SetWithThen(&substrait::Type::set_allocated_list)->set_allocated_type(type.release());
+ SetWithThen(&::substrait::Type::set_allocated_list)
+ ->set_allocated_type(type.release());
return Status::OK();
}
Status Visit(const StructType& t) {
- auto types =
SetWithThen(&substrait::Type::set_allocated_struct_)->mutable_types();
+ auto types =
SetWithThen(&::substrait::Type::set_allocated_struct_)->mutable_types();
types->Reserve(t.num_fields());
for (const auto& field : t.fields()) {
if (field->metadata() != nullptr) {
- return Status::Invalid("substrait::Type::Struct does not support field
metadata");
+ return Status::Invalid(
+ "::substrait::Type::Struct does not support field metadata");
}
ARROW_ASSIGN_OR_RAISE(auto type,
ToProto(*field->type(), field->nullable(),
ext_set_));
@@ -330,7 +335,7 @@ struct DataTypeToProtoImpl {
Status Visit(const MapType& t) {
// FIXME assert default field names; custom ones won't roundtrip
- auto map = SetWithThen(&substrait::Type::set_allocated_map);
+ auto map = SetWithThen(&::substrait::Type::set_allocated_map);
ARROW_ASSIGN_OR_RAISE(auto key, ToProto(*t.key_type(), /*nullable=*/false,
ext_set_));
map->set_allocated_key(key.release());
@@ -344,25 +349,25 @@ struct DataTypeToProtoImpl {
Status Visit(const ExtensionType& t) {
if (UnwrapUuid(t)) {
- return SetWith(&substrait::Type::set_allocated_uuid);
+ return SetWith(&::substrait::Type::set_allocated_uuid);
}
if (auto length = UnwrapFixedChar(t)) {
-
SetWithThen(&substrait::Type::set_allocated_fixed_char)->set_length(*length);
+
SetWithThen(&::substrait::Type::set_allocated_fixed_char)->set_length(*length);
return Status::OK();
}
if (auto length = UnwrapVarChar(t)) {
-
SetWithThen(&substrait::Type::set_allocated_varchar)->set_length(*length);
+
SetWithThen(&::substrait::Type::set_allocated_varchar)->set_length(*length);
return Status::OK();
}
if (UnwrapIntervalYear(t)) {
- return SetWith(&substrait::Type::set_allocated_interval_year);
+ return SetWith(&::substrait::Type::set_allocated_interval_year);
}
if (UnwrapIntervalDay(t)) {
- return SetWith(&substrait::Type::set_allocated_interval_day);
+ return SetWith(&::substrait::Type::set_allocated_interval_day);
}
return NotImplemented(t);
@@ -376,10 +381,10 @@ struct DataTypeToProtoImpl {
Status Visit(const MonthDayNanoIntervalType& t) { return
EncodeUserDefined(t); }
template <typename Sub>
- Sub* SetWithThen(void (substrait::Type::*set_allocated_sub)(Sub*)) {
+ Sub* SetWithThen(void (::substrait::Type::*set_allocated_sub)(Sub*)) {
auto sub = internal::make_unique<Sub>();
- sub->set_nullability(nullable_ ? substrait::Type::NULLABILITY_NULLABLE
- : substrait::Type::NULLABILITY_REQUIRED);
+ sub->set_nullability(nullable_ ? ::substrait::Type::NULLABILITY_NULLABLE
+ : ::substrait::Type::NULLABILITY_REQUIRED);
auto out = sub.get();
(type_->*set_allocated_sub)(sub.release());
@@ -387,7 +392,7 @@ struct DataTypeToProtoImpl {
}
template <typename Sub>
- Status SetWith(void (substrait::Type::*set_allocated_sub)(Sub*)) {
+ Status SetWith(void (::substrait::Type::*set_allocated_sub)(Sub*)) {
return SetWithThen(set_allocated_sub), Status::OK();
}
@@ -399,25 +404,25 @@ struct DataTypeToProtoImpl {
}
Status NotImplemented(const DataType& t) {
- return Status::NotImplemented("conversion to substrait::Type from ",
t.ToString());
+ return Status::NotImplemented("conversion to ::substrait::Type from ",
t.ToString());
}
Status operator()(const DataType& type) { return VisitTypeInline(type,
this); }
- substrait::Type* type_;
+ ::substrait::Type* type_;
bool nullable_;
ExtensionSet* ext_set_;
};
} // namespace
-Result<std::unique_ptr<substrait::Type>> ToProto(const DataType& type, bool
nullable,
- ExtensionSet* ext_set) {
- auto out = internal::make_unique<substrait::Type>();
+Result<std::unique_ptr<::substrait::Type>> ToProto(const DataType& type, bool
nullable,
+ ExtensionSet* ext_set) {
+ auto out = internal::make_unique<::substrait::Type>();
RETURN_NOT_OK((DataTypeToProtoImpl{out.get(), nullable, ext_set})(type));
return std::move(out);
}
-Result<std::shared_ptr<Schema>> FromProto(const substrait::NamedStruct&
named_struct,
+Result<std::shared_ptr<Schema>> FromProto(const ::substrait::NamedStruct&
named_struct,
const ExtensionSet& ext_set) {
if (!named_struct.has_struct_()) {
return Status::Invalid("While converting ", named_struct.DebugString(),
@@ -461,25 +466,25 @@ void ToProtoGetDepthFirstNames(const FieldVector& fields,
}
} // namespace
-Result<std::unique_ptr<substrait::NamedStruct>> ToProto(const Schema& schema,
- ExtensionSet* ext_set)
{
+Result<std::unique_ptr<::substrait::NamedStruct>> ToProto(const Schema& schema,
+ ExtensionSet*
ext_set) {
if (schema.metadata()) {
- return Status::Invalid("substrait::NamedStruct does not support schema
metadata");
+ return Status::Invalid("::substrait::NamedStruct does not support schema
metadata");
}
- auto named_struct = internal::make_unique<substrait::NamedStruct>();
+ auto named_struct = internal::make_unique<::substrait::NamedStruct>();
auto names = named_struct->mutable_names();
names->Reserve(schema.num_fields());
ToProtoGetDepthFirstNames(schema.fields(), names);
- auto struct_ = internal::make_unique<substrait::Type::Struct>();
+ auto struct_ = internal::make_unique<::substrait::Type::Struct>();
auto types = struct_->mutable_types();
types->Reserve(schema.num_fields());
for (const auto& field : schema.fields()) {
if (field->metadata() != nullptr) {
- return Status::Invalid("substrait::NamedStruct does not support field
metadata");
+ return Status::Invalid("::substrait::NamedStruct does not support field
metadata");
}
ARROW_ASSIGN_OR_RAISE(auto type, ToProto(*field->type(),
field->nullable(), ext_set));
diff --git a/cpp/src/arrow/engine/substrait/util.cc
b/cpp/src/arrow/engine/substrait/util.cc
new file mode 100644
index 0000000000..bc2aa36856
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/util.cc
@@ -0,0 +1,130 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/engine/substrait/util.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/async_util.h"
+
+namespace arrow {
+
+namespace engine {
+
+namespace substrait {
+
+namespace {
+
+/// \brief A SinkNodeConsumer specialized to output ExecBatches via PushGenerator
+///
+/// Each batch delivered by the sink node is pushed into the generator's
+/// producer, and Finish() closes that producer. The schema received in Init()
+/// is retained so the reading side can build a RecordBatchReader.
+class SubstraitSinkConsumer : public compute::SinkNodeConsumer {
+ public:
+  explicit SubstraitSinkConsumer(
+      arrow::PushGenerator<util::optional<compute::ExecBatch>>::Producer producer)
+      : producer_(std::move(producer)) {}
+
+  Status Consume(compute::ExecBatch batch) override {
+    // `batch` is owned by value and not used afterwards, so move it into the
+    // producer instead of copying it.
+    bool did_push = producer_.Push(std::move(batch));
+    if (!did_push) return Status::Invalid("Producer closed already");
+    return Status::OK();
+  }
+
+  Status Init(const std::shared_ptr<Schema>& schema,
+              compute::BackpressureControl* backpressure_control) override {
+    // backpressure_control is ignored, so no backpressure is applied.
+    schema_ = schema;
+    return Status::OK();
+  }
+
+  Future<> Finish() override {
+    // The result of Close() is deliberately ignored.
+    ARROW_UNUSED(producer_.Close());
+    return Future<>::MakeFinished();
+  }
+
+  /// The output schema passed to Init(), or null if Init() has not run.
+  std::shared_ptr<Schema> schema() const { return schema_; }
+
+ private:
+  arrow::PushGenerator<util::optional<compute::ExecBatch>>::Producer producer_;
+  std::shared_ptr<Schema> schema_;
+};
+
+/// \brief An executor to run a Substrait Query
+/// This interface is provided as a utility when creating language
+/// bindings for consuming a Substrait plan.
+///
+/// Usage: construct with an ExecPlan, call Init() with the serialized plan,
+/// then Execute() to obtain a RecordBatchReader over the sink's output.
+class SubstraitExecutor {
+ public:
+  explicit SubstraitExecutor(std::shared_ptr<compute::ExecPlan> plan,
+                             compute::ExecContext exec_context)
+      : plan_(std::move(plan)), exec_context_(exec_context) {}
+
+  // A plan that finished with an error must not abort the whole process
+  // (which ARROW_CHECK_OK would do, e.g. underneath a Python interpreter),
+  // so a failed Close() is only logged here.
+  // NOTE(review): Close() reads plan_->finished().status(), so destruction
+  // presumably waits for the plan to complete — confirm this is intended.
+  ~SubstraitExecutor() {
+    ARROW_WARN_NOT_OK(this->Close(), "Substrait plan did not finish cleanly");
+  }
+
+  /// Adds the deserialized declarations to the plan, validates and starts it,
+  /// and returns a reader over the batches pushed by the sink consumer.
+  Result<std::shared_ptr<RecordBatchReader>> Execute() {
+    if (sink_consumer_ == nullptr) {
+      return Status::Invalid("SubstraitExecutor::Execute called before Init");
+    }
+    for (const compute::Declaration& decl : declarations_) {
+      RETURN_NOT_OK(decl.AddToPlan(plan_.get()).status());
+    }
+    RETURN_NOT_OK(plan_->Validate());
+    RETURN_NOT_OK(plan_->StartProducing());
+    auto schema = sink_consumer_->schema();
+    std::shared_ptr<RecordBatchReader> sink_reader = compute::MakeGeneratorReader(
+        std::move(schema), std::move(generator_), exec_context_.memory_pool());
+    return sink_reader;
+  }
+
+  /// Returns the status of the plan's finished() future.
+  Status Close() { return plan_->finished().status(); }
+
+  /// Deserializes `substrait_buffer` into declarations whose sinks all feed the
+  /// single SubstraitSinkConsumer created here. Rejects an empty buffer.
+  Status Init(const Buffer& substrait_buffer) {
+    if (substrait_buffer.size() == 0) {
+      return Status::Invalid("Empty substrait plan is passed.");
+    }
+    sink_consumer_ = std::make_shared<SubstraitSinkConsumer>(generator_.producer());
+    // Every sink in the plan shares the one consumer created above.
+    std::function<std::shared_ptr<compute::SinkNodeConsumer>()> consumer_factory =
+        [this] { return sink_consumer_; };
+    ARROW_ASSIGN_OR_RAISE(declarations_,
+                          engine::DeserializePlans(substrait_buffer, consumer_factory));
+    return Status::OK();
+  }
+
+ private:
+  arrow::PushGenerator<util::optional<compute::ExecBatch>> generator_;
+  std::vector<compute::Declaration> declarations_;
+  std::shared_ptr<compute::ExecPlan> plan_;
+  compute::ExecContext exec_context_;
+  std::shared_ptr<SubstraitSinkConsumer> sink_consumer_;
+};
+
+} // namespace
+
+/// Deserializes `substrait_buffer`, runs it on a fresh ExecPlan using the CPU
+/// thread pool and the default memory pool, and returns a reader over its output.
+Result<std::shared_ptr<RecordBatchReader>> ExecuteSerializedPlan(
+    const Buffer& substrait_buffer) {
+  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<compute::ExecPlan> exec_plan,
+                        compute::ExecPlan::Make());
+  // TODO(ARROW-15732)
+  compute::ExecContext ctx(arrow::default_memory_pool(),
+                           ::arrow::internal::GetCpuThreadPool());
+  SubstraitExecutor runner(std::move(exec_plan), ctx);
+  RETURN_NOT_OK(runner.Init(substrait_buffer));
+  // NOTE(review): `runner` is destroyed on return and its destructor calls
+  // Close(), so this call presumably waits for the plan to finish.
+  return runner.Execute();
+}
+
+/// Converts a Substrait plan written as JSON into its serialized binary form.
+Result<std::shared_ptr<Buffer>> SerializeJsonPlan(const std::string& substrait_json) {
+  // Parse the text as the top-level Substrait "Plan" message and re-encode it.
+  const char* const message_type = "Plan";
+  return engine::internal::SubstraitFromJSON(message_type, substrait_json);
+}
+
+} // namespace substrait
+
+} // namespace engine
+
+} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/util.h
b/cpp/src/arrow/engine/substrait/util.h
new file mode 100644
index 0000000000..860a459da2
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/util.h
@@ -0,0 +1,44 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include "arrow/engine/substrait/api.h"
+#include "arrow/util/iterator.h"
+#include "arrow/util/optional.h"
+
+namespace arrow {
+
+namespace engine {
+
+namespace substrait {
+
+/// \brief Retrieve a RecordBatchReader from a Substrait plan.
+///
+/// \param[in] substrait_buffer a serialized Substrait Plan message
+/// \return a reader over the batches produced by the plan's sink, or an error
+/// Status (e.g. Invalid for an empty buffer) if the plan cannot be
+/// deserialized, validated or started
+ARROW_ENGINE_EXPORT Result<std::shared_ptr<RecordBatchReader>>
ExecuteSerializedPlan(
+ const Buffer& substrait_buffer);
+
+/// \brief Get a Serialized Plan from a Substrait JSON plan.
+/// This is a helper method for Python tests.
+///
+/// \param[in] substrait_json a Substrait Plan message encoded as JSON
+/// \return a buffer holding the serialized Plan, or an error Status if the
+/// JSON cannot be parsed as a Plan
+ARROW_ENGINE_EXPORT Result<std::shared_ptr<Buffer>> SerializeJsonPlan(
+ const std::string& substrait_json);
+
+} // namespace substrait
+
+} // namespace engine
+
+} // namespace arrow
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt
index d17c76e288..3e08253f32 100644
--- a/python/CMakeLists.txt
+++ b/python/CMakeLists.txt
@@ -69,6 +69,7 @@ endif()
if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}")
option(PYARROW_BUILD_CUDA "Build the PyArrow CUDA support" OFF)
option(PYARROW_BUILD_FLIGHT "Build the PyArrow Flight integration" OFF)
+ option(PYARROW_BUILD_SUBSTRAIT "Build the PyArrow Substrait integration" OFF)
option(PYARROW_BUILD_DATASET "Build the PyArrow Dataset integration" OFF)
option(PYARROW_BUILD_GANDIVA "Build the PyArrow Gandiva integration" OFF)
option(PYARROW_BUILD_PARQUET "Build the PyArrow Parquet integration" OFF)
@@ -227,6 +228,10 @@ if(PYARROW_BUILD_FLIGHT)
set(ARROW_FLIGHT TRUE)
endif()
+if(PYARROW_BUILD_SUBSTRAIT)
+ set(ARROW_SUBSTRAIT TRUE)
+endif()
+
# Arrow
find_package(ArrowPython REQUIRED)
include_directories(SYSTEM ${ARROW_INCLUDE_DIR})
@@ -535,6 +540,17 @@ if(PYARROW_BUILD_FLIGHT)
set(CYTHON_EXTENSIONS ${CYTHON_EXTENSIONS} _flight)
endif()
+# Engine
+if(PYARROW_BUILD_SUBSTRAIT)
+ find_package(ArrowSubstrait REQUIRED)
+ if(PYARROW_BUNDLE_ARROW_CPP)
+ bundle_arrow_lib(ARROW_SUBSTRAIT_SHARED_LIB SO_VERSION ${ARROW_SO_VERSION})
+ endif()
+
+ set(SUBSTRAIT_LINK_LIBS arrow_substrait_shared)
+ set(CYTHON_EXTENSIONS ${CYTHON_EXTENSIONS} _substrait)
+endif()
+
# Gandiva
if(PYARROW_BUILD_GANDIVA)
find_package(Gandiva REQUIRED)
@@ -625,6 +641,10 @@ if(PYARROW_BUILD_FLIGHT)
target_link_libraries(_flight PRIVATE ${FLIGHT_LINK_LIBS})
endif()
+if(PYARROW_BUILD_SUBSTRAIT)
+ target_link_libraries(_substrait PRIVATE ${SUBSTRAIT_LINK_LIBS})
+endif()
+
if(PYARROW_BUILD_DATASET)
target_link_libraries(_dataset PRIVATE ${DATASET_LINK_LIBS})
target_link_libraries(_exec_plan PRIVATE ${DATASET_LINK_LIBS})
diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx
new file mode 100644
index 0000000000..7f079fb717
--- /dev/null
+++ b/python/pyarrow/_substrait.pyx
@@ -0,0 +1,79 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: language_level = 3
+from cython.operator cimport dereference as deref
+
+from pyarrow import Buffer
+from pyarrow.lib cimport *
+from pyarrow.includes.libarrow cimport *
+from pyarrow.includes.libarrow_substrait cimport *
+
+
+def run_query(plan):
+    """
+    Execute a Substrait plan and read the results as a RecordBatchReader.
+
+    Parameters
+    ----------
+    plan : Buffer
+        The serialized Substrait plan to execute.
+
+    Returns
+    -------
+    RecordBatchReader
+        A reader over the batches produced by the plan.
+
+    Raises
+    ------
+    TypeError
+        If `plan` is not a pyarrow.Buffer.
+    """
+
+    cdef:
+        CResult[shared_ptr[CRecordBatchReader]] c_res_reader
+        shared_ptr[CRecordBatchReader] c_reader
+        RecordBatchReader reader
+        shared_ptr[CBuffer] c_buf_plan
+
+    # Unwrapping a non-Buffer yields an empty pointer, and dereferencing it
+    # below would crash the interpreter, so reject such input up front.
+    if not isinstance(plan, Buffer):
+        raise TypeError(
+            "Expected a pyarrow.Buffer for the Substrait plan, got {}"
+            .format(type(plan).__name__))
+    c_buf_plan = pyarrow_unwrap_buffer(plan)
+    with nogil:
+        c_res_reader = ExecuteSerializedPlan(deref(c_buf_plan))
+
+    c_reader = GetResultValue(c_res_reader)
+
+    reader = RecordBatchReader.__new__(RecordBatchReader)
+    reader.reader = c_reader
+    return reader
+
+
+def _parse_json_plan(plan):
+    """
+    Parse a JSON plan into equivalent serialized Protobuf.
+
+    Parameters
+    ----------
+    plan: bytes
+        Substrait plan in JSON.
+
+    Returns
+    -------
+    Buffer
+        A buffer containing the serialized Protobuf plan.
+    """
+
+    cdef:
+        CResult[shared_ptr[CBuffer]] c_res_buffer
+        c_string c_str_plan
+        shared_ptr[CBuffer] c_buf_plan
+
+    c_str_plan = plan
+    # Run the pure C++ conversion without the GIL...
+    with nogil:
+        c_res_buffer = SerializeJsonPlan(c_str_plan)
+    # ...but unwrap the result with the GIL held: GetResultValue turns an
+    # error Status into a Python exception.
+    c_buf_plan = GetResultValue(c_res_buffer)
+    return pyarrow_wrap_buffer(c_buf_plan)
diff --git a/python/pyarrow/includes/libarrow_substrait.pxd
b/python/pyarrow/includes/libarrow_substrait.pxd
new file mode 100644
index 0000000000..2e1a17b06b
--- /dev/null
+++ b/python/pyarrow/includes/libarrow_substrait.pxd
@@ -0,0 +1,26 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# distutils: language = c++
+
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport *
+
+
+cdef extern from "arrow/engine/substrait/util.h" namespace
"arrow::engine::substrait" nogil:
+ CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(const
CBuffer& substrait_buffer)
+ CResult[shared_ptr[CBuffer]] SerializeJsonPlan(const c_string&
substrait_json)
diff --git a/python/pyarrow/substrait.py b/python/pyarrow/substrait.py
new file mode 100644
index 0000000000..e3ff28f4eb
--- /dev/null
+++ b/python/pyarrow/substrait.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pyarrow._substrait import ( # noqa
+ run_query,
+)
diff --git a/python/pyarrow/tests/conftest.py b/python/pyarrow/tests/conftest.py
index 466b1647fd..a5aae6f634 100644
--- a/python/pyarrow/tests/conftest.py
+++ b/python/pyarrow/tests/conftest.py
@@ -66,6 +66,7 @@ groups = [
'plasma',
's3',
'snappy',
+ 'substrait',
'tensorflow',
'flight',
'slow',
@@ -98,6 +99,7 @@ defaults = {
's3': False,
'slow': False,
'snappy': Codec.is_available('snappy'),
+ 'substrait': False,
'tensorflow': False,
'zstd': Codec.is_available('zstd'),
}
@@ -181,6 +183,12 @@ try:
except ImportError:
pass
+try:
+ import pyarrow.substrait # noqa
+ defaults['substrait'] = True
+except ImportError:
+ pass
+
def pytest_addoption(parser):
# Create options to selectively enable test groups
diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py
new file mode 100644
index 0000000000..8df35bbba4
--- /dev/null
+++ b/python/pyarrow/tests/test_substrait.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import sys
+import pytest
+
+import pyarrow as pa
+from pyarrow.lib import tobytes
+from pyarrow.lib import ArrowInvalid
+
+try:
+ import pyarrow.substrait as substrait
+except ImportError:
+ substrait = None
+
+# Marks all of the tests in this module
+# Ignore these with pytest ... -m 'not substrait'
+pytestmark = [pytest.mark.dataset, pytest.mark.substrait]
+
+
+@pytest.mark.skipif(sys.platform == 'win32',
+ reason="ARROW-16392: file based URI is" +
+ " not fully supported for Windows")
+def test_run_serialized_query(tmpdir):
+ substrait_query = """
+ {
+ "relations": [
+ {"rel": {
+ "read": {
+ "base_schema": {
+ "struct": {
+ "types": [
+ {"i64": {}}
+ ]
+ },
+ "names": [
+ "foo"
+ ]
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file://FILENAME_PLACEHOLDER"
+ }
+ ]
+ }
+ }
+ }}
+ ]
+ }
+ """
+ # TODO: replace with ipc when the support is finalized in C++
+ path = os.path.join(str(tmpdir), 'substrait_data.arrow')
+ table = pa.table([[1, 2, 3, 4, 5]], names=['foo'])
+ with pa.ipc.RecordBatchFileWriter(path, schema=table.schema) as writer:
+ writer.write_table(table)
+
+ query = tobytes(substrait_query.replace("FILENAME_PLACEHOLDER", path))
+
+ buf = pa._substrait._parse_json_plan(query)
+
+ reader = substrait.run_query(buf)
+ res_tb = reader.read_all()
+
+ assert table.select(["foo"]) == res_tb.select(["foo"])
+
+
+def test_invalid_plan():
+ query = """
+ {
+ "relations": [
+ ]
+ }
+ """
+ buf = pa._substrait._parse_json_plan(tobytes(query))
+ exec_message = "Empty substrait plan is passed."
+ with pytest.raises(ArrowInvalid, match=exec_message):
+ substrait.run_query(buf)
diff --git a/python/setup.py b/python/setup.py
index 1189357b23..79ec3c8447 100755
--- a/python/setup.py
+++ b/python/setup.py
@@ -108,6 +108,7 @@ class build_ext(_build_ext):
'namespace of boost (default: boost)'),
('with-cuda', None, 'build the Cuda extension'),
('with-flight', None, 'build the Flight extension'),
+ ('with-substrait', None, 'build the Substrait extension'),
('with-dataset', None, 'build the Dataset extension'),
('with-parquet', None, 'build the Parquet extension'),
('with-parquet-encryption', None,
@@ -160,6 +161,8 @@ class build_ext(_build_ext):
os.environ.get('PYARROW_WITH_HDFS', '0'))
self.with_cuda = strtobool(
os.environ.get('PYARROW_WITH_CUDA', '0'))
+ self.with_substrait = strtobool(
+ os.environ.get('PYARROW_WITH_SUBSTRAIT', '0'))
self.with_flight = strtobool(
os.environ.get('PYARROW_WITH_FLIGHT', '0'))
self.with_dataset = strtobool(
@@ -214,6 +217,7 @@ class build_ext(_build_ext):
'_orc',
'_plasma',
'_s3fs',
+ '_substrait',
'_hdfs',
'_hdfsio',
'gandiva']
@@ -268,6 +272,7 @@ class build_ext(_build_ext):
cmake_options += ['-G', self.cmake_generator]
append_cmake_bool(self.with_cuda, 'PYARROW_BUILD_CUDA')
+ append_cmake_bool(self.with_substrait, 'PYARROW_BUILD_SUBSTRAIT')
append_cmake_bool(self.with_flight, 'PYARROW_BUILD_FLIGHT')
append_cmake_bool(self.with_gandiva, 'PYARROW_BUILD_GANDIVA')
append_cmake_bool(self.with_dataset, 'PYARROW_BUILD_DATASET')
@@ -393,6 +398,8 @@ class build_ext(_build_ext):
move_shared_libs(build_prefix, build_lib, "arrow_python")
if self.with_cuda:
move_shared_libs(build_prefix, build_lib, "arrow_cuda")
+ if self.with_substrait:
+ move_shared_libs(build_prefix, build_lib, "arrow_substrait")
if self.with_flight:
move_shared_libs(build_prefix, build_lib, "arrow_flight")
move_shared_libs(build_prefix, build_lib,
@@ -438,6 +445,8 @@ class build_ext(_build_ext):
return True
if name == '_flight' and not self.with_flight:
return True
+ if name == '_substrait' and not self.with_substrait:
+ return True
if name == '_s3fs' and not self.with_s3:
return True
if name == '_hdfs' and not self.with_hdfs: