bkietz commented on a change in pull request #11579:
URL: https://github.com/apache/arrow/pull/11579#discussion_r745800870



##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -255,17 +284,83 @@ std::shared_ptr<Schema> HashJoinSchema::MakeOutputSchema(
   return std::make_shared<Schema>(std::move(fields));
 }
 
+// Binds the join's residual filter expression against a schema built from
+// the filter-projection columns of the left input followed by those of the
+// right input, and verifies that the bound expression evaluates to bool.
+//
+// Returns the filter unchanged if it is already bound (NOTE(review): that
+// path skips the bool check below -- confirm callers guarantee it), and an
+// empty Expression if no filter was supplied. Returns TypeError if the bound
+// filter's type is not bool; Bind() errors are propagated.
+Result<Expression> HashJoinSchema::BindFilter(Expression filter,
+                                              const Schema& left_schema,
+                                              const Schema& right_schema) {
+  if (filter.IsBound()) {
+    return std::move(filter);
+  }
+  if (filter.IsEmpty()) {
+    return Expression();
+  }
+
+  // Combined filter schema: every left filter column, then every right
+  // filter column, in FILTER-projection order.
+  FieldVector fields;
+  auto left =
+      proj_maps[0].map(HashJoinProjection::FILTER, HashJoinProjection::INPUT);
+  auto right =
+      proj_maps[1].map(HashJoinProjection::FILTER, HashJoinProjection::INPUT);
+
+  auto AppendFieldsInMap = [&fields](const SchemaProjectionMap& map,
+                                     const Schema& schema) {
+    for (int i = 0; i < map.num_cols; i++) {
+      // FILTER -> INPUT map: position of the i-th filter column in `schema`.
+      int input_idx = map.get(i);
+      fields.push_back(schema.fields()[input_idx]);
+    }
+  };
+  AppendFieldsInMap(left, left_schema);
+  AppendFieldsInMap(right, right_schema);
+  // `fields` is dead after this point; hand its storage to the Schema.
+  Schema filter_schema(std::move(fields));
+
+  ARROW_ASSIGN_OR_RAISE(filter, filter.Bind(filter_schema));
+  if (filter.type()->id() != Type::BOOL) {
+    return Status::TypeError("Filter expression must evaluate to bool, but ",
+                             filter.ToString(), " evaluates to ",
+                             filter.type()->ToString());
+  }
+  return std::move(filter);
+}
+
+Result<std::vector<FieldRef>> HashJoinSchema::CollectFilterColumns(
+    const Expression& filter, const Schema& schema) {
+  std::vector<FieldRef> nonunique_refs;
+  RETURN_NOT_OK(TraverseExpression(nonunique_refs, filter, schema));
+
+  std::vector<FieldRef> result;
+  std::unordered_set<int> seen_paths;
+  for (auto ref : nonunique_refs) {
+    ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema));
+    if (seen_paths.find(match[0]) == seen_paths.end()) {
+      seen_paths.insert(match[0]);
+      result.push_back(ref);
+    }
+  }
+  return result;

Review comment:
       Isn't it the case that the output schema of a hash join node is always
   ```python
   schema({
     **left_payload,
     **left_keys,
     **right_payload,
     **right_keys,
   })
   ```
   In light of this, I think it's possible to map predictably from an index in that schema to a field in one of the inputs.
   
   Addendum to one of the unit tests making this explicit:
   ```diff
   diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
   index 437a93f9e..c250b305d 100644
   --- a/cpp/src/arrow/compute/exec/plan_test.cc
   +++ b/cpp/src/arrow/compute/exec/plan_test.cc
   @@ -1046,6 +1046,13 @@ TEST(ExecPlanExecution, SelfInnerHashJoinSink) {
            auto hashjoin,
         MakeExecNode("hashjoin", plan.get(), {left_filter, right_filter}, join_opts));
   
   +    ASSERT_EQ(*hashjoin->output_schema(), Schema({
   +                                              field("l_i32", int32()),
   +                                              field("l_str", utf8()),
   +                                              field("r_i32", int32()),
   +                                              field("r_str", utf8()),
   +                                          }));
   +
     ASSERT_OK_AND_ASSIGN(std::ignore, MakeExecNode("sink", plan.get(), {hashjoin},
                                                    SinkNodeOptions{&sink_gen}));
   
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to