This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 2eb2573790 ARROW-15590: [C++] Add support for joins to the Substrait
consumer (#13078)
2eb2573790 is described below
commit 2eb257379062a5c349017397fd64a2b27bae3056
Author: Vibhatha Lakmal Abeykoon <[email protected]>
AuthorDate: Wed Jun 1 06:20:41 2022 +0530
ARROW-15590: [C++] Add support for joins to the Substrait consumer (#13078)
Initial Version of Substrait Join Support
This PR doesn't support the complete join functionality, but it includes the
following features.
This will be followed by a set of PRs to address the remaining features [1].
Features included
- [X] Only Support Inner Join (A follow up PR would include the support for
other join types)
- [X] Support Join operations with a single call-expression of types
"equal" and "is_not_distinct_from"
- [X] Test cases to check the basic functionality and limitations
Todo:
- [x] Fix the Windows CI Issue
[1]. https://issues.apache.org/jira/browse/ARROW-16485
Lead-authored-by: Vibhatha Abeykoon <[email protected]>
Co-authored-by: Vibhatha Lakmal Abeykoon <[email protected]>
Signed-off-by: Weston Pace <[email protected]>
---
cpp/src/arrow/compute/exec/options.h | 13 +
cpp/src/arrow/engine/substrait/extension_set.cc | 2 +
.../arrow/engine/substrait/relation_internal.cc | 82 +++++
cpp/src/arrow/engine/substrait/serde_test.cc | 372 +++++++++++++++++++++
4 files changed, 469 insertions(+)
diff --git a/cpp/src/arrow/compute/exec/options.h
b/cpp/src/arrow/compute/exec/options.h
index 48cbf9d371..31c910c25b 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -296,6 +296,19 @@ class ARROW_EXPORT HashJoinNodeOptions : public
ExecNodeOptions {
this->key_cmp[i] = JoinKeyCmp::EQ;
}
}
+  /// \brief Construct options for a join on the given key columns.
+  ///
+  /// Sets join_type to JoinType::INNER, output_all to true, the output suffixes
+  /// to their defaults, compares every key pair with JoinKeyCmp::EQ, and sets
+  /// `filter` to literal(true).
+  HashJoinNodeOptions(std::vector<FieldRef> in_left_keys,
+                      std::vector<FieldRef> in_right_keys)
+      : left_keys(std::move(in_left_keys)), right_keys(std::move(in_right_keys)) {
+    this->join_type = JoinType::INNER;
+    this->output_all = true;
+    this->output_suffix_for_left = default_output_suffix_for_left;
+    this->output_suffix_for_right = default_output_suffix_for_right;
+    // One comparison per key, all equality.
+    this->key_cmp.assign(this->left_keys.size(), JoinKeyCmp::EQ);
+    this->filter = literal(true);
+  }
HashJoinNodeOptions(
JoinType join_type, std::vector<FieldRef> left_keys,
std::vector<FieldRef> right_keys, std::vector<FieldRef> left_output,
diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc
b/cpp/src/arrow/engine/substrait/extension_set.cc
index b7d2f87b74..cd85678a72 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.cc
+++ b/cpp/src/arrow/engine/substrait/extension_set.cc
@@ -240,6 +240,8 @@ ExtensionIdRegistry* default_extension_id_registry() {
// ARROW-15535.
for (util::string_view name : {
"add",
+ "equal",
+ "is_not_distinct_from",
}) {
DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name},
name.to_string()));
}
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc
b/cpp/src/arrow/engine/substrait/relation_internal.cc
index 723edfe2ec..89ab7ca4dc 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.cc
+++ b/cpp/src/arrow/engine/substrait/relation_internal.cc
@@ -225,6 +225,88 @@ Result<compute::Declaration> FromProto(const
substrait::Rel& rel,
});
}
+    case substrait::Rel::RelTypeCase::kJoin: {
+      // Translate a substrait::JoinRel into a "hashjoin" declaration whose two
+      // inputs are the deserialized left and right relations.
+      const auto& join = rel.join();
+      RETURN_NOT_OK(CheckRelCommon(join));
+
+      if (!join.has_left()) {
+        return Status::Invalid("substrait::JoinRel with no left relation");
+      }
+
+      if (!join.has_right()) {
+        return Status::Invalid("substrait::JoinRel with no right relation");
+      }
+
+      compute::JoinType join_type;
+      switch (join.type()) {
+        case substrait::JoinRel::JOIN_TYPE_UNSPECIFIED:
+          return Status::NotImplemented("Unspecified join type is not supported");
+        case substrait::JoinRel::JOIN_TYPE_INNER:
+          join_type = compute::JoinType::INNER;
+          break;
+        case substrait::JoinRel::JOIN_TYPE_OUTER:
+          join_type = compute::JoinType::FULL_OUTER;
+          break;
+        case substrait::JoinRel::JOIN_TYPE_LEFT:
+          join_type = compute::JoinType::LEFT_OUTER;
+          break;
+        case substrait::JoinRel::JOIN_TYPE_RIGHT:
+          join_type = compute::JoinType::RIGHT_OUTER;
+          break;
+        case substrait::JoinRel::JOIN_TYPE_SEMI:
+          join_type = compute::JoinType::LEFT_SEMI;
+          break;
+        case substrait::JoinRel::JOIN_TYPE_ANTI:
+          join_type = compute::JoinType::LEFT_ANTI;
+          break;
+        default:
+          return Status::Invalid("Unsupported join type");
+      }
+
+      ARROW_ASSIGN_OR_RAISE(auto left, FromProto(join.left(), ext_set));
+      ARROW_ASSIGN_OR_RAISE(auto right, FromProto(join.right(), ext_set));
+
+      if (!join.has_expression()) {
+        return Status::Invalid("substrait::JoinRel with no expression");
+      }
+
+      // The key-only HashJoinNodeOptions constructor used below sets `filter` to
+      // literal(true), so a post-join filter would otherwise be silently ignored.
+      if (join.has_post_join_filter()) {
+        return Status::NotImplemented(
+            "substrait::JoinRel with a post_join_filter is not supported");
+      }
+
+      ARROW_ASSIGN_OR_RAISE(auto expression, FromProto(join.expression(), ext_set));
+
+      const auto* callptr = expression.call();
+      if (!callptr) {
+        return Status::Invalid(
+            "A join rel's expression must be a simple equality between keys but got ",
+            expression.ToString());
+      }
+
+      compute::JoinKeyCmp join_key_cmp;
+      if (callptr->function_name == "equal") {
+        join_key_cmp = compute::JoinKeyCmp::EQ;
+      } else if (callptr->function_name == "is_not_distinct_from") {
+        join_key_cmp = compute::JoinKeyCmp::IS;
+      } else {
+        return Status::Invalid(
+            "Only `equal` or `is_not_distinct_from` are supported for join key "
+            "comparison but got ",
+            callptr->function_name);
+      }
+
+      // Both supported comparisons are binary; check the arity before indexing
+      // `arguments` so a malformed call cannot read out of bounds.
+      if (callptr->arguments.size() != 2) {
+        return Status::Invalid("Join key comparison `", callptr->function_name,
+                               "` must have exactly 2 arguments but got ",
+                               callptr->arguments.size());
+      }
+
+      // TODO(ARROW-16624): Add Suffix support for Substrait
+      const auto* left_keys = callptr->arguments[0].field_ref();
+      const auto* right_keys = callptr->arguments[1].field_ref();
+      if (!left_keys) {
+        return Status::Invalid("Left keys for join cannot be null");
+      }
+      if (!right_keys) {
+        return Status::Invalid("Right keys for join cannot be null");
+      }
+      // NOTE(review): Substrait presumably evaluates the join expression over the
+      // concatenated left+right fields, so a right-side reference would be offset
+      // by the number of left fields. The right key is passed through unadjusted
+      // here and is resolved against the right input alone -- TODO confirm and
+      // subtract the left field count once the left schema is available here.
+      compute::HashJoinNodeOptions join_options{{*left_keys}, {*right_keys}};
+      join_options.join_type = join_type;
+      join_options.key_cmp = {join_key_cmp};
+      compute::Declaration join_dec{"hashjoin", std::move(join_options)};
+      join_dec.inputs.emplace_back(std::move(left));
+      join_dec.inputs.emplace_back(std::move(right));
+      return std::move(join_dec);
+    }
+
default:
break;
}
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc
b/cpp/src/arrow/engine/substrait/serde_test.cc
index fae23f200d..9a0e93fc7a 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -801,5 +801,377 @@ TEST(Substrait, InvalidPlan) {
ASSERT_RAISES(Invalid, substrait::ExecuteSerializedPlan(*buf));
}
+TEST(Substrait, JoinPlanBasic) {
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "join": {
+ "left": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat1.parquet",
+ "format": "FILE_FORMAT_PARQUET"
+ }
+ ]
+ }
+ }
+ },
+ "right": {
+ "read": {
+ "base_schema": {
+ "names": ["X", "Y", "A"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat2.parquet",
+ "format": "FILE_FORMAT_PARQUET"
+ }
+ ]
+ }
+ }
+ },
+ "expression": {
+ "scalarFunction": {
+ "functionReference": 0,
+ "args": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ },
+ "rootReference": {
+ }
+ }
+ }, {
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 5
+ }
+ },
+ "rootReference": {
+ }
+ }
+ }]
+ }
+ },
+ "type": "JOIN_TYPE_INNER"
+ }
+ }
+ }],
+ "extension_uris": [
+ {
+ "extension_uri_anchor": 0,
+ "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }
+ ],
+ "extensions": [
+ {"extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "equal"
+ }}
+ ]
+ })"));
+  ExtensionSet ext_set;
+  ASSERT_OK_AND_ASSIGN(
+      auto sink_decls,
+      DeserializePlans(
+          *buf, [] { return std::shared_ptr<compute::SinkNodeConsumer>{nullptr}; },
+          &ext_set));
+
+  // Check sizes and variant types before indexing/dereferencing so an
+  // unexpected plan shape fails the test cleanly instead of crashing it.
+  ASSERT_EQ(sink_decls.size(), 1u);
+  ASSERT_EQ(sink_decls[0].inputs.size(), 1u);
+  auto* join_rel = sink_decls[0].inputs[0].get<compute::Declaration>();
+  ASSERT_NE(join_rel, nullptr);
+
+  EXPECT_EQ(join_rel->factory_name, "hashjoin");
+  const auto& join_options =
+      checked_cast<const compute::HashJoinNodeOptions&>(*join_rel->options);
+  EXPECT_EQ(join_options.join_type, compute::JoinType::INNER);
+  ASSERT_EQ(join_options.key_cmp.size(), 1u);
+  EXPECT_EQ(join_options.key_cmp[0], compute::JoinKeyCmp::EQ);
+
+  ASSERT_EQ(join_rel->inputs.size(), 2u);
+  auto* left_rel = join_rel->inputs[0].get<compute::Declaration>();
+  auto* right_rel = join_rel->inputs[1].get<compute::Declaration>();
+  ASSERT_NE(left_rel, nullptr);
+  ASSERT_NE(right_rel, nullptr);
+
+  const auto& l_options =
+      checked_cast<const dataset::ScanNodeOptions&>(*left_rel->options);
+  const auto& r_options =
+      checked_cast<const dataset::ScanNodeOptions&>(*right_rel->options);
+
+  AssertSchemaEqual(
+      l_options.dataset->schema(),
+      schema({field("A", int32()), field("B", int32()), field("C", int32())}));
+  AssertSchemaEqual(
+      r_options.dataset->schema(),
+      schema({field("X", int32()), field("Y", int32()), field("A", int32())}));
+}
+
+TEST(Substrait, JoinPlanInvalidKeyCmp) {
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "join": {
+ "left": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat1.parquet",
+ "format": "FILE_FORMAT_PARQUET"
+ }
+ ]
+ }
+ }
+ },
+ "right": {
+ "read": {
+ "base_schema": {
+ "names": ["X", "Y", "A"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat2.parquet",
+ "format": "FILE_FORMAT_PARQUET"
+ }
+ ]
+ }
+ }
+ },
+ "expression": {
+ "scalarFunction": {
+ "functionReference": 0,
+ "args": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ },
+ "rootReference": {
+ }
+ }
+ }, {
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 5
+ }
+ },
+ "rootReference": {
+ }
+ }
+ }]
+ }
+ },
+ "type": "JOIN_TYPE_INNER"
+ }
+ }
+ }],
+ "extension_uris": [
+ {
+ "extension_uri_anchor": 0,
+ "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }
+ ],
+ "extensions": [
+ {"extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "add"
+ }}
+ ]
+ })"));
+  ExtensionSet ext_set;
+  const Status st =
+      DeserializePlans(
+          *buf, [] { return std::shared_ptr<compute::SinkNodeConsumer>{nullptr}; },
+          &ext_set)
+          .status();
+  ASSERT_RAISES(Invalid, st);
+  // Pin the failure to the key-comparison check ("add" is not a supported join
+  // key comparison) rather than any other Invalid path.
+  EXPECT_NE(st.message().find("is_not_distinct_from"), std::string::npos)
+      << st.ToString();
+}
+
+TEST(Substrait, JoinPlanInvalidExpression) {
+  // The join expression is a well-formed boolean literal rather than a call, so
+  // deserialization reaches the join's "must be a simple equality" check instead
+  // of failing earlier while parsing the literal itself.
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "join": {
+ "left": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat1.parquet",
+ "format": "FILE_FORMAT_PARQUET"
+ }
+ ]
+ }
+ }
+ },
+ "right": {
+ "read": {
+ "base_schema": {
+ "names": ["X", "Y", "A"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat2.parquet",
+ "format": "FILE_FORMAT_PARQUET"
+ }
+ ]
+ }
+ }
+ },
+ "expression": {"literal": {"boolean": true}},
+ "type": "JOIN_TYPE_INNER"
+ }
+ }
+ }]
+ })"));
+  ExtensionSet ext_set;
+  const Status st =
+      DeserializePlans(
+          *buf, [] { return std::shared_ptr<compute::SinkNodeConsumer>{nullptr}; },
+          &ext_set)
+          .status();
+  ASSERT_RAISES(Invalid, st);
+  EXPECT_NE(st.message().find("must be a simple equality"), std::string::npos)
+      << st.ToString();
+}
+
+TEST(Substrait, JoinPlanInvalidKeys) {
+  // Both inputs are present and the comparison is a supported `equal`, but its
+  // second argument is a literal rather than a field reference, so key
+  // extraction itself must be rejected.
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "join": {
+ "left": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat1.parquet",
+ "format": "FILE_FORMAT_PARQUET"
+ }
+ ]
+ }
+ }
+ },
+ "right": {
+ "read": {
+ "base_schema": {
+ "names": ["X", "Y", "A"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat2.parquet",
+ "format": "FILE_FORMAT_PARQUET"
+ }
+ ]
+ }
+ }
+ },
+ "expression": {
+ "scalarFunction": {
+ "functionReference": 0,
+ "args": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ },
+ "rootReference": {
+ }
+ }
+ }, {
+ "literal": {
+ "i32": 0
+ }
+ }]
+ }
+ },
+ "type": "JOIN_TYPE_INNER"
+ }
+ }
+ }],
+ "extension_uris": [
+ {
+ "extension_uri_anchor": 0,
+ "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"
+ }
+ ],
+ "extensions": [
+ {"extension_function": {
+ "extension_uri_reference": 0,
+ "function_anchor": 0,
+ "name": "equal"
+ }}
+ ]
+ })"));
+  ExtensionSet ext_set;
+  const Status st =
+      DeserializePlans(
+          *buf, [] { return std::shared_ptr<compute::SinkNodeConsumer>{nullptr}; },
+          &ext_set)
+          .status();
+  ASSERT_RAISES(Invalid, st);
+  EXPECT_NE(st.message().find("keys for join cannot be null"), std::string::npos)
+      << st.ToString();
+}
+
+TEST(Substrait, JoinPlanMissingRightRelation) {
+  // A JoinRel without a right input must be rejected before anything else.
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+ "relations": [{
+ "rel": {
+ "join": {
+ "left": {
+ "read": {
+ "base_schema": {
+ "names": ["A", "B", "C"],
+ "struct": {
+ "types": [{
+ "i32": {}
+ }, {
+ "i32": {}
+ }, {
+ "i32": {}
+ }]
+ }
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "file:///tmp/dat1.parquet",
+ "format": "FILE_FORMAT_PARQUET"
+ }
+ ]
+ }
+ }
+ },
+ "expression": {
+ "scalarFunction": {
+ "functionReference": 0,
+ "args": [{
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0
+ }
+ },
+ "rootReference": {
+ }
+ }
+ }, {
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 5
+ }
+ },
+ "rootReference": {
+ }
+ }
+ }]
+ }
+ },
+ "type": "JOIN_TYPE_INNER"
+ }
+ }
+ }]
+ })"));
+  ExtensionSet ext_set;
+  const Status st =
+      DeserializePlans(
+          *buf, [] { return std::shared_ptr<compute::SinkNodeConsumer>{nullptr}; },
+          &ext_set)
+          .status();
+  ASSERT_RAISES(Invalid, st);
+  EXPECT_NE(st.message().find("no right relation"), std::string::npos)
+      << st.ToString();
+}
+
} // namespace engine
} // namespace arrow