nirandaperera commented on a change in pull request #10845:
URL: https://github.com/apache/arrow/pull/10845#discussion_r697534271



##########
File path: cpp/src/arrow/compute/exec/hash_join_node.cc
##########
@@ -0,0 +1,555 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <mutex>
+
+#include "arrow/api.h"
+#include "arrow/compute/api.h"
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/util/bitmap_ops.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute {
+
+namespace {
+Status ValidateJoinInputs(const std::shared_ptr<Schema>& left_schema,
+                          const std::shared_ptr<Schema>& right_schema,
+                          const std::vector<int>& left_keys,
+                          const std::vector<int>& right_keys) {
+  if (left_keys.size() != right_keys.size()) {
+    return Status::Invalid("left and right key sizes do not match");
+  }
+
+  for (size_t i = 0; i < left_keys.size(); i++) {
+    auto l_type = left_schema->field(left_keys[i])->type();
+    auto r_type = right_schema->field(right_keys[i])->type();
+
+    if (!l_type->Equals(r_type)) {
+      return Status::Invalid("build and probe types do not match: " + l_type->ToString() +
+                             "!=" + r_type->ToString());
+    }
+  }
+
+  return Status::OK();
+}
+
+Result<std::vector<int>> PopulateKeys(const Schema& schema,
+                                      const std::vector<FieldRef>& keys) {
+  std::vector<int> key_field_ids(keys.size());
+  // Find input field indices for left key fields
+  for (size_t i = 0; i < keys.size(); ++i) {
+    ARROW_ASSIGN_OR_RAISE(auto match, keys[i].FindOne(schema));
+    key_field_ids[i] = match[0];
+  }
+  return key_field_ids;
+}
+}  // namespace
+
+template <bool anti_join = false>
+struct HashSemiJoinNode : ExecNode {
+  HashSemiJoinNode(ExecNode* build_input, ExecNode* probe_input, ExecContext* ctx,
+                   const std::vector<int>&& build_index_field_ids,
+                   const std::vector<int>&& probe_index_field_ids)
+      : ExecNode(build_input->plan(), {build_input, probe_input},
+                 {"hash_join_build", "hash_join_probe"}, probe_input->output_schema(),
+                 /*num_outputs=*/1),
+        ctx_(ctx),
+        build_index_field_ids_(build_index_field_ids),
+        probe_index_field_ids_(probe_index_field_ids),
+        build_result_index(-1),
+        hash_table_built_(false),
+        cached_probe_batches_consumed(false) {}
+
+ private:
+  struct ThreadLocalState;
+
+ public:
+  const char* kind_name() override { return "HashSemiJoinNode"; }
+
+  Status InitLocalStateIfNeeded(ThreadLocalState* state) {
+    ARROW_LOG(DEBUG) << "init state";
+
+    // Get input schema
+    auto build_schema = inputs_[0]->output_schema();
+
+    if (state->grouper != nullptr) return Status::OK();
+
+    // Build vector of key field data types
+    std::vector<ValueDescr> key_descrs(build_index_field_ids_.size());
+    for (size_t i = 0; i < build_index_field_ids_.size(); ++i) {
+      auto build_type = build_schema->field(build_index_field_ids_[i])->type();
+      key_descrs[i] = ValueDescr(build_type);
+    }
+
+    // Construct grouper
+    ARROW_ASSIGN_OR_RAISE(state->grouper, internal::Grouper::Make(key_descrs, ctx_));
+
+    return Status::OK();
+  }
+
+  // Finds an appropriate index which could accumulate all build indices (i.e. the grouper
+  // which has the highest # of groups)
+  void CalculateBuildResultIndex() {
+    int32_t curr_max = -1;
+    for (int i = 0; i < static_cast<int>(local_states_.size()); i++) {
+      auto* state = &local_states_[i];
+      ARROW_DCHECK(state);
+      if (state->grouper &&
+          curr_max < static_cast<int32_t>(state->grouper->num_groups())) {
+        curr_max = static_cast<int32_t>(state->grouper->num_groups());
+        build_result_index = i;
+      }
+    }
+    ARROW_DCHECK(build_result_index > -1);
+    ARROW_LOG(DEBUG) << "build_result_index " << build_result_index;
+  }
+
+  // Performs the housekeeping work after the build-side is completed.
+  // Note: this method is not thread safe, and hence should be guaranteed that it is
+  // not accessed concurrently!
+  Status BuildSideCompleted() {
+    ARROW_LOG(DEBUG) << "build side merge";
+
+    // if the hash table has already been built, return
+    if (hash_table_built_) return Status::OK();
+
+    CalculateBuildResultIndex();
+
+    // merge every group into the build_result_index grouper
+    ThreadLocalState* result_state = &local_states_[build_result_index];
+    for (int i = 0; i < static_cast<int>(local_states_.size()); ++i) {
+      ThreadLocalState* state = &local_states_[i];
+      ARROW_DCHECK(state);
+      if (i == build_result_index || !state->grouper) {
+        continue;
+      }
+      ARROW_ASSIGN_OR_RAISE(ExecBatch other_keys, state->grouper->GetUniques());
+
+      // TODO(niranda) replace with void consume method
+      ARROW_ASSIGN_OR_RAISE(Datum _, result_state->grouper->Consume(other_keys));
+      state->grouper.reset();
+    }
+
+    // enable flag that build side is completed
+    hash_table_built_ = true;
+
+    // since the build side is completed, consume cached probe batches
+    RETURN_NOT_OK(ConsumeCachedProbeBatches());
+
+    return Status::OK();
+  }
+
+  // consumes a build batch and increments the build_batches count. If the build batches
+  // total is reached at the end of consumption, all the local states will be merged
+  // before incrementing the total batches.
+  Status ConsumeBuildBatch(ExecBatch batch) {
+    size_t thread_index = get_thread_index_();
+    ARROW_DCHECK(thread_index < local_states_.size());
+
+    ARROW_LOG(DEBUG) << "ConsumeBuildBatch tid:" << thread_index
+                     << " len:" << batch.length;
+
+    auto state = &local_states_[thread_index];
+    RETURN_NOT_OK(InitLocalStateIfNeeded(state));
+
+    // Create a batch with key columns
+    std::vector<Datum> keys(build_index_field_ids_.size());
+    for (size_t i = 0; i < build_index_field_ids_.size(); ++i) {
+      keys[i] = batch.values[build_index_field_ids_[i]];
+    }
+    ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
+
+    // Create a batch with group ids
+    // TODO(niranda) replace with void consume method
+    ARROW_ASSIGN_OR_RAISE(Datum _, state->grouper->Consume(key_batch));
+
+    if (build_counter_.Increment()) {
+      // only one thread would get inside this block!
+      // while incrementing, if the total is reached, call BuildSideCompleted.
+      RETURN_NOT_OK(BuildSideCompleted());
+    }
+
+    return Status::OK();
+  }
+
+  // consumes cached probe batches by invoking executor::Spawn.
+  Status ConsumeCachedProbeBatches() {
+    ARROW_LOG(DEBUG) << "ConsumeCachedProbeBatches tid:" << get_thread_index_()
+                     << " len:" << cached_probe_batches.size();
+
+    // acquire the mutex to access cached_probe_batches, because while consuming, other
+    // batches should not be cached!
+    std::lock_guard<std::mutex> lck(cached_probe_batches_mutex);
+
+    if (!cached_probe_batches_consumed) {
+      auto executor = ctx_->executor();
+
+      while (!cached_probe_batches.empty()) {
+        ExecBatch cached = std::move(cached_probe_batches.back());
+        cached_probe_batches.pop_back();
+
+        if (executor) {
+          RETURN_NOT_OK(executor->Spawn(
+              // since cached will be going out-of-scope, it needs to be copied into the
+              // capture list
+              [&, cached]() mutable {
+                // since batch consumption is done asynchronously, a failed status would
+                // have to be propagated then and there!
+                ErrorIfNotOk(ConsumeProbeBatch(std::move(cached)));
+              }));
+        } else {
+          RETURN_NOT_OK(ConsumeProbeBatch(std::move(cached)));
+        }
+      }
+    }
+
+    // set flag
+    cached_probe_batches_consumed = true;
+    return Status::OK();
+  }
+
+  Status GenerateOutput(const ArrayData& group_ids_data, ExecBatch batch) {
+    if (group_ids_data.GetNullCount() == batch.length) {
+      // All NULLS! hence, there are no valid outputs!
+      ARROW_LOG(DEBUG) << "output seq:";
+      outputs_[0]->InputReceived(this, batch.Slice(0, 0));
+    } else if (group_ids_data.MayHaveNulls()) {  // values need to be filtered
+      auto filter_arr =
+          std::make_shared<BooleanArray>(group_ids_data.length, group_ids_data.buffers[0],
+                                         /*null_bitmap=*/nullptr, /*null_count=*/0,
+                                         /*offset=*/group_ids_data.offset);
+      ARROW_ASSIGN_OR_RAISE(auto rec_batch,
+                            batch.ToRecordBatch(output_schema_, ctx_->memory_pool()));
+      ARROW_ASSIGN_OR_RAISE(
+          auto filtered,
+          Filter(rec_batch, filter_arr,
+                 /* null_selection = DROP*/ FilterOptions::Defaults(), ctx_));
+      auto out_batch = ExecBatch(*filtered.record_batch());
+      ARROW_LOG(DEBUG) << "output seq:"
+                       << " " << out_batch.length;
+      outputs_[0]->InputReceived(this, std::move(out_batch));
+    } else {  // all values are valid for output
+      ARROW_LOG(DEBUG) << "output seq:"
+                       << " " << batch.length;
+      outputs_[0]->InputReceived(this, std::move(batch));
+    }
+
+    return Status::OK();
+  }
+
+  // consumes a probe batch and increments the probe batches count. Probing would query the
+  // grouper[build_result_index], which has been merged with all others.
+  Status ConsumeProbeBatch(ExecBatch batch) {
+    ARROW_LOG(DEBUG) << "ConsumeProbeBatch seq:";
+
+    auto& final_grouper = *local_states_[build_result_index].grouper;
+
+    // Create a batch with key columns
+    std::vector<Datum> keys(probe_index_field_ids_.size());
+    for (size_t i = 0; i < probe_index_field_ids_.size(); ++i) {
+      keys[i] = batch.values[probe_index_field_ids_[i]];
+    }
+    ARROW_ASSIGN_OR_RAISE(ExecBatch key_batch, ExecBatch::Make(keys));
+
+    // Query the grouper with key_batch. If no match was found, the returned group_ids would
+    // have null.
+    ARROW_ASSIGN_OR_RAISE(Datum group_ids, final_grouper.Find(key_batch));
+    auto group_ids_data = *group_ids.array();
+
+    RETURN_NOT_OK(GenerateOutput(group_ids_data, std::move(batch)));
+
+    if (out_counter_.Increment()) {
+      finished_.MarkFinished();
+    }
+    return Status::OK();
+  }
+
+  // Attempt to cache a probe batch. If it is not cached, return false.
+  // if cached_probe_batches_consumed is true, by the time a thread acquires
+  // cached_probe_batches_mutex, it should no longer be cached! instead, it can be
+  //  directly consumed!
+  bool AttemptToCacheProbeBatch(ExecBatch* batch) {
+    ARROW_LOG(DEBUG) << "cache tid:" << get_thread_index_() << " len:" << batch->length;
+    std::lock_guard<std::mutex> lck(cached_probe_batches_mutex);
+    if (cached_probe_batches_consumed) {
+      return false;
+    }
+    cached_probe_batches.push_back(std::move(*batch));
+    return true;
+  }
+
+  inline bool IsBuildInput(ExecNode* input) { return input == inputs_[0]; }
+
+  // If all build side batches have been received, continue streaming by probing;
+  // else cache the batches in thread-local state
+  void InputReceived(ExecNode* input, ExecBatch batch) override {
+    ARROW_LOG(DEBUG) << "input received input:" << (IsBuildInput(input) ? "b" : "p")
+                     << " seq:" << 0 << " len:" << batch.length;
+
+    ARROW_DCHECK(input == inputs_[0] || input == inputs_[1]);
+
+    if (finished_.is_finished()) {
+      return;
+    }
+
+    if (IsBuildInput(input)) {  // build input batch is received
+      // if a build input is received when build side is completed, something's wrong!
+      ARROW_DCHECK(!hash_table_built_);
+
+      ErrorIfNotOk(ConsumeBuildBatch(std::move(batch)));
+    } else {  // probe input batch is received
+      if (hash_table_built_) {

Review comment:
       hmm... yes, you are correct. `hash_table_built_` only changes from `false` --> `true`.
And `cached_probe_batches_mutex` guarantees that only a single thread can cache/consume the
`cached_probe_batches` vector. If an attempt to cache fails, that means the vector has
already been consumed, and that probe batch can be consumed directly.
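
To make that handshake concrete, here is a minimal standalone sketch of the cache-or-consume pattern described above, using a simplified stand-in for the node's state. The `CachedWork`, `TryCache`, and `DrainAndClose` names (and the `int` payload) are hypothetical and only illustrate the idea; they are not part of the PR or of Arrow's API:

```cpp
#include <mutex>
#include <vector>

// Simplified, hypothetical stand-in for the probe-batch cache.
struct CachedWork {
  std::mutex mu;             // plays the role of cached_probe_batches_mutex
  bool closed = false;       // plays the role of cached_probe_batches_consumed
  std::vector<int> pending;  // plays the role of cached_probe_batches

  // Returns true if the item was cached. A false return means the cache has
  // already been drained and closed, so the caller must process the item
  // directly (mirrors AttemptToCacheProbeBatch).
  bool TryCache(int item) {
    std::lock_guard<std::mutex> lock(mu);
    if (closed) return false;
    pending.push_back(item);
    return true;
  }

  // Called once when the build side completes: drains everything cached so
  // far and closes the cache under the same lock, so no item can be both
  // cached and missed by the drain (mirrors ConsumeCachedProbeBatches).
  template <typename Consume>
  void DrainAndClose(Consume&& consume) {
    std::lock_guard<std::mutex> lock(mu);
    for (int item : pending) consume(item);
    pending.clear();
    closed = true;
  }
};

// Probe-side usage: try to cache; if that fails, the drain has already
// happened and the batch can be processed right away, e.g.
//   if (!work.TryCache(batch)) ProcessDirectly(batch);
```

The point the comment makes is that the flag flip and the drain happen under the same lock, so the caching path and the drain path can never interleave.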



