This is an automated email from the ASF dual-hosted git repository.

westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 2eb2573790 ARROW-15590: [C++] Add support for joins to the Substrait 
consumer (#13078)
2eb2573790 is described below

commit 2eb257379062a5c349017397fd64a2b27bae3056
Author: Vibhatha Lakmal Abeykoon <[email protected]>
AuthorDate: Wed Jun 1 06:20:41 2022 +0530

    ARROW-15590: [C++] Add support for joins to the Substrait consumer (#13078)
    
    Initial Version of Substrait Join Support
    
    This PR doesn't support the complete join functionality, but it includes the 
following features.
    This will be followed by a set of PRs to implement the remaining features [1].
    
    Features included
    
    - [X] Supports only inner joins (a follow-up PR will add support for 
other join types)
    - [X] Supports join operations with a single call expression of type 
"equal" or "is_not_distinct_from"
    - [X] Test cases to check the basic functionality and limitations
    
    Todo:
    - [x] Fix the Windows CI Issue
    
    [1]. https://issues.apache.org/jira/browse/ARROW-16485
    
    Lead-authored-by: Vibhatha Abeykoon <[email protected]>
    Co-authored-by: Vibhatha Lakmal Abeykoon <[email protected]>
    Signed-off-by: Weston Pace <[email protected]>
---
 cpp/src/arrow/compute/exec/options.h               |  13 +
 cpp/src/arrow/engine/substrait/extension_set.cc    |   2 +
 .../arrow/engine/substrait/relation_internal.cc    |  82 +++++
 cpp/src/arrow/engine/substrait/serde_test.cc       | 372 +++++++++++++++++++++
 4 files changed, 469 insertions(+)

diff --git a/cpp/src/arrow/compute/exec/options.h 
b/cpp/src/arrow/compute/exec/options.h
index 48cbf9d371..31c910c25b 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -296,6 +296,19 @@ class ARROW_EXPORT HashJoinNodeOptions : public 
ExecNodeOptions {
       this->key_cmp[i] = JoinKeyCmp::EQ;
     }
   }
+  HashJoinNodeOptions(std::vector<FieldRef> in_left_keys,
+                      std::vector<FieldRef> in_right_keys)
+      : left_keys(std::move(in_left_keys)), 
right_keys(std::move(in_right_keys)) {
+    this->join_type = JoinType::INNER;
+    this->output_all = true;
+    this->output_suffix_for_left = default_output_suffix_for_left;
+    this->output_suffix_for_right = default_output_suffix_for_right;
+    this->key_cmp.resize(this->left_keys.size());
+    for (size_t i = 0; i < this->left_keys.size(); ++i) {
+      this->key_cmp[i] = JoinKeyCmp::EQ;
+    }
+    this->filter = literal(true);
+  }
   HashJoinNodeOptions(
       JoinType join_type, std::vector<FieldRef> left_keys,
       std::vector<FieldRef> right_keys, std::vector<FieldRef> left_output,
diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc 
b/cpp/src/arrow/engine/substrait/extension_set.cc
index b7d2f87b74..cd85678a72 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.cc
+++ b/cpp/src/arrow/engine/substrait/extension_set.cc
@@ -240,6 +240,8 @@ ExtensionIdRegistry* default_extension_id_registry() {
       // ARROW-15535.
       for (util::string_view name : {
                "add",
+               "equal",
+               "is_not_distinct_from",
            }) {
         DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, 
name.to_string()));
       }
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc 
b/cpp/src/arrow/engine/substrait/relation_internal.cc
index 723edfe2ec..89ab7ca4dc 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.cc
+++ b/cpp/src/arrow/engine/substrait/relation_internal.cc
@@ -225,6 +225,88 @@ Result<compute::Declaration> FromProto(const 
substrait::Rel& rel,
       });
     }
 
+    case substrait::Rel::RelTypeCase::kJoin: {
+      const auto& join = rel.join();
+      RETURN_NOT_OK(CheckRelCommon(join));
+
+      if (!join.has_left()) {
+        return Status::Invalid("substrait::JoinRel with no left relation");
+      }
+
+      if (!join.has_right()) {
+        return Status::Invalid("substrait::JoinRel with no right relation");
+      }
+
+      compute::JoinType join_type;
+      switch (join.type()) {
+        case substrait::JoinRel::JOIN_TYPE_UNSPECIFIED:
+          return Status::NotImplemented("Unspecified join type is not 
supported");
+        case substrait::JoinRel::JOIN_TYPE_INNER:
+          join_type = compute::JoinType::INNER;
+          break;
+        case substrait::JoinRel::JOIN_TYPE_OUTER:
+          join_type = compute::JoinType::FULL_OUTER;
+          break;
+        case substrait::JoinRel::JOIN_TYPE_LEFT:
+          join_type = compute::JoinType::LEFT_OUTER;
+          break;
+        case substrait::JoinRel::JOIN_TYPE_RIGHT:
+          join_type = compute::JoinType::RIGHT_OUTER;
+          break;
+        case substrait::JoinRel::JOIN_TYPE_SEMI:
+          join_type = compute::JoinType::LEFT_SEMI;
+          break;
+        case substrait::JoinRel::JOIN_TYPE_ANTI:
+          join_type = compute::JoinType::LEFT_ANTI;
+          break;
+        default:
+          return Status::Invalid("Unsupported join type");
+      }
+
+      ARROW_ASSIGN_OR_RAISE(auto left, FromProto(join.left(), ext_set));
+      ARROW_ASSIGN_OR_RAISE(auto right, FromProto(join.right(), ext_set));
+
+      if (!join.has_expression()) {
+        return Status::Invalid("substrait::JoinRel with no expression");
+      }
+
+      ARROW_ASSIGN_OR_RAISE(auto expression, FromProto(join.expression(), 
ext_set));
+
+      const auto* callptr = expression.call();
+      if (!callptr) {
+        return Status::Invalid(
+            "A join rel's expression must be a simple equality between keys 
but got ",
+            expression.ToString());
+      }
+
+      compute::JoinKeyCmp join_key_cmp;
+      if (callptr->function_name == "equal") {
+        join_key_cmp = compute::JoinKeyCmp::EQ;
+      } else if (callptr->function_name == "is_not_distinct_from") {
+        join_key_cmp = compute::JoinKeyCmp::IS;
+      } else {
+        return Status::Invalid(
+            "Only `equal` or `is_not_distinct_from` are supported for join key 
"
+            "comparison but got ",
+            callptr->function_name);
+      }
+
+      // TODO: ARROW-166241 Add Suffix support for Substrait
+      const auto* left_keys = callptr->arguments[0].field_ref();
+      const auto* right_keys = callptr->arguments[1].field_ref();
+      if (!left_keys || !right_keys) {
+        return Status::Invalid("Left keys for join cannot be null");
+      }
+      compute::HashJoinNodeOptions join_options{{std::move(*left_keys)},
+                                                {std::move(*right_keys)}};
+      join_options.join_type = join_type;
+      join_options.key_cmp = {join_key_cmp};
+      compute::Declaration join_dec{"hashjoin", std::move(join_options)};
+      join_dec.inputs.emplace_back(std::move(left));
+      join_dec.inputs.emplace_back(std::move(right));
+      return std::move(join_dec);
+    }
+
     default:
       break;
   }
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc 
b/cpp/src/arrow/engine/substrait/serde_test.cc
index fae23f200d..9a0e93fc7a 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -801,5 +801,377 @@ TEST(Substrait, InvalidPlan) {
   ASSERT_RAISES(Invalid, substrait::ExecuteSerializedPlan(*buf));
 }
 
+TEST(Substrait, JoinPlanBasic) {
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+  "relations": [{
+    "rel": {
+      "join": {
+        "left": {
+          "read": {
+            "base_schema": {
+              "names": ["A", "B", "C"],
+              "struct": {
+                "types": [{
+                  "i32": {}
+                }, {
+                  "i32": {}
+                }, {
+                  "i32": {}
+                }]
+              }
+            },
+            "local_files": { 
+              "items": [
+                {
+                  "uri_file": "file:///tmp/dat1.parquet",
+                  "format": "FILE_FORMAT_PARQUET"
+                }
+              ] 
+            }
+          }
+        },
+        "right": {
+          "read": {
+            "base_schema": {
+              "names": ["X", "Y", "A"],
+              "struct": {
+                "types": [{
+                  "i32": {}
+                }, {
+                  "i32": {}
+                }, {
+                  "i32": {}
+                }]
+              }
+            },
+            "local_files": { 
+              "items": [
+                {
+                  "uri_file": "file:///tmp/dat2.parquet",
+                  "format": "FILE_FORMAT_PARQUET"
+                }
+              ]
+            }
+          }
+        },
+        "expression": {
+          "scalarFunction": {
+            "functionReference": 0,
+            "args": [{
+              "selection": {
+                "directReference": {
+                  "structField": {
+                    "field": 0
+                  }
+                },
+                "rootReference": {
+                }
+              }
+            }, {
+              "selection": {
+                "directReference": {
+                  "structField": {
+                    "field": 5
+                  }
+                },
+                "rootReference": {
+                }
+              }
+            }]
+          }
+        },
+        "type": "JOIN_TYPE_INNER"
+      }
+    }
+  }],
+  "extension_uris": [
+      {
+        "extension_uri_anchor": 0,
+        "uri": 
"https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml";
+      }
+    ],
+    "extensions": [
+      {"extension_function": {
+        "extension_uri_reference": 0,
+        "function_anchor": 0,
+        "name": "equal"
+      }}
+    ]
+  })"));
+  ExtensionSet ext_set;
+  ASSERT_OK_AND_ASSIGN(
+      auto sink_decls,
+      DeserializePlans(
+          *buf, [] { return 
std::shared_ptr<compute::SinkNodeConsumer>{nullptr}; },
+          &ext_set));
+
+  auto join_decl = sink_decls[0].inputs[0];
+
+  const auto& join_rel = join_decl.get<compute::Declaration>();
+
+  const auto& join_options =
+      checked_cast<const compute::HashJoinNodeOptions&>(*join_rel->options);
+
+  EXPECT_EQ(join_rel->factory_name, "hashjoin");
+  EXPECT_EQ(join_options.join_type, compute::JoinType::INNER);
+
+  const auto& left_rel = join_rel->inputs[0].get<compute::Declaration>();
+  const auto& right_rel = join_rel->inputs[1].get<compute::Declaration>();
+
+  const auto& l_options =
+      checked_cast<const dataset::ScanNodeOptions&>(*left_rel->options);
+  const auto& r_options =
+      checked_cast<const dataset::ScanNodeOptions&>(*right_rel->options);
+
+  AssertSchemaEqual(
+      l_options.dataset->schema(),
+      schema({field("A", int32()), field("B", int32()), field("C", int32())}));
+  AssertSchemaEqual(
+      r_options.dataset->schema(),
+      schema({field("X", int32()), field("Y", int32()), field("A", int32())}));
+
+  EXPECT_EQ(join_options.key_cmp[0], compute::JoinKeyCmp::EQ);
+}
+
+TEST(Substrait, JoinPlanInvalidKeyCmp) {
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+  "relations": [{
+    "rel": {
+      "join": {
+        "left": {
+          "read": {
+            "base_schema": {
+              "names": ["A", "B", "C"],
+              "struct": {
+                "types": [{
+                  "i32": {}
+                }, {
+                  "i32": {}
+                }, {
+                  "i32": {}
+                }]
+              }
+            },
+            "local_files": { 
+              "items": [
+                {
+                  "uri_file": "file:///tmp/dat1.parquet",
+                  "format": "FILE_FORMAT_PARQUET"
+                }
+              ] 
+            }
+          }
+        },
+        "right": {
+          "read": {
+            "base_schema": {
+              "names": ["X", "Y", "A"],
+              "struct": {
+                "types": [{
+                  "i32": {}
+                }, {
+                  "i32": {}
+                }, {
+                  "i32": {}
+                }]
+              }
+            },
+            "local_files": { 
+              "items": [
+                {
+                  "uri_file": "file:///tmp/dat2.parquet",
+                  "format": "FILE_FORMAT_PARQUET"
+                }
+              ]
+            }
+          }
+        },
+        "expression": {
+          "scalarFunction": {
+            "functionReference": 0,
+            "args": [{
+              "selection": {
+                "directReference": {
+                  "structField": {
+                    "field": 0
+                  }
+                },
+                "rootReference": {
+                }
+              }
+            }, {
+              "selection": {
+                "directReference": {
+                  "structField": {
+                    "field": 5
+                  }
+                },
+                "rootReference": {
+                }
+              }
+            }]
+          }
+        },
+        "type": "JOIN_TYPE_INNER"
+      }
+    }
+  }],
+  "extension_uris": [
+      {
+        "extension_uri_anchor": 0,
+        "uri": 
"https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml";
+      }
+    ],
+    "extensions": [
+      {"extension_function": {
+        "extension_uri_reference": 0,
+        "function_anchor": 0,
+        "name": "add"
+      }}
+    ]
+  })"));
+  ExtensionSet ext_set;
+  ASSERT_RAISES(
+      Invalid,
+      DeserializePlans(
+          *buf, [] { return 
std::shared_ptr<compute::SinkNodeConsumer>{nullptr}; },
+          &ext_set));
+}
+
+TEST(Substrait, JoinPlanInvalidExpression) {
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+  "relations": [{
+    "rel": {
+      "join": {
+        "left": {
+          "read": {
+            "base_schema": {
+              "names": ["A", "B", "C"],
+              "struct": {
+                "types": [{
+                  "i32": {}
+                }, {
+                  "i32": {}
+                }, {
+                  "i32": {}
+                }]
+              }
+            },
+            "local_files": { 
+              "items": [
+                {
+                  "uri_file": "file:///tmp/dat1.parquet",
+                  "format": "FILE_FORMAT_PARQUET"
+                }
+              ] 
+            }
+          }
+        },
+        "right": {
+          "read": {
+            "base_schema": {
+              "names": ["X", "Y", "A"],
+              "struct": {
+                "types": [{
+                  "i32": {}
+                }, {
+                  "i32": {}
+                }, {
+                  "i32": {}
+                }]
+              }
+            },
+            "local_files": { 
+              "items": [
+                {
+                  "uri_file": "file:///tmp/dat2.parquet",
+                  "format": "FILE_FORMAT_PARQUET"
+                }
+              ]
+            }
+          }
+        },
+        "expression": {"literal": {"list": {"values": []}}},
+        "type": "JOIN_TYPE_INNER"
+      }
+    }
+  }]
+  })"));
+  ExtensionSet ext_set;
+  ASSERT_RAISES(
+      Invalid,
+      DeserializePlans(
+          *buf, [] { return 
std::shared_ptr<compute::SinkNodeConsumer>{nullptr}; },
+          &ext_set));
+}
+
+TEST(Substrait, JoinPlanInvalidKeys) {
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({
+  "relations": [{
+    "rel": {
+      "join": {
+        "left": {
+          "read": {
+            "base_schema": {
+              "names": ["A", "B", "C"],
+              "struct": {
+                "types": [{
+                  "i32": {}
+                }, {
+                  "i32": {}
+                }, {
+                  "i32": {}
+                }]
+              }
+            },
+            "local_files": { 
+              "items": [
+                {
+                  "uri_file": "file:///tmp/dat1.parquet",
+                  "format": "FILE_FORMAT_PARQUET"
+                }
+              ] 
+            }
+          }
+        },
+        "expression": {
+          "scalarFunction": {
+            "functionReference": 0,
+            "args": [{
+              "selection": {
+                "directReference": {
+                  "structField": {
+                    "field": 0
+                  }
+                },
+                "rootReference": {
+                }
+              }
+            }, {
+              "selection": {
+                "directReference": {
+                  "structField": {
+                    "field": 5
+                  }
+                },
+                "rootReference": {
+                }
+              }
+            }]
+          }
+        },
+        "type": "JOIN_TYPE_INNER"
+      }
+    }
+  }]
+  })"));
+  ExtensionSet ext_set;
+  ASSERT_RAISES(
+      Invalid,
+      DeserializePlans(
+          *buf, [] { return 
std::shared_ptr<compute::SinkNodeConsumer>{nullptr}; },
+          &ext_set));
+}
+
 }  // namespace engine
 }  // namespace arrow

Reply via email to