icexelloss commented on code in PR #13028:
URL: https://github.com/apache/arrow/pull/13028#discussion_r887278231


##########
cpp/src/arrow/compute/exec/asof_join_node.cc:
##########
@@ -0,0 +1,800 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
#include <cassert>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <unordered_map>
#include <utility>

#include <arrow/api.h>
#include <arrow/compute/api.h>
#include <arrow/util/optional.h>
#include "arrow/compute/exec/asof_join.h"
#include "arrow/compute/exec/exec_plan.h"
#include "arrow/compute/exec/options.h"
#include "arrow/compute/exec/schema_util.h"
#include "arrow/compute/exec/util.h"
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/future.h"
#include "arrow/util/make_unique.h"
+
+namespace arrow {
+namespace compute {
+
+/**
+ * Simple implementation for an unbound concurrent queue
+ */
+template <class T>
+class ConcurrentQueue {
+ public:
+  T pop() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    cond_.wait(lock, [&] { return !queue_.empty(); });
+    auto item = queue_.front();
+    queue_.pop();
+    return item;
+  }
+
+  void push(const T& item) {
+    std::unique_lock<std::mutex> lock(mutex_);
+    queue_.push(item);
+    cond_.notify_one();
+  }
+
+  util::optional<T> try_pop() {
+    // Try to pop the oldest value from the queue (or return nullopt if none)
+    std::unique_lock<std::mutex> lock(mutex_);
+    if (queue_.empty()) {
+      return util::nullopt;
+    } else {
+      auto item = queue_.front();
+      queue_.pop();
+      return item;
+    }
+  }
+
+  bool empty() const {
+    std::unique_lock<std::mutex> lock(mutex_);
+    return queue_.empty();
+  }
+
+  // Un-synchronized access to front
+  // For this to be "safe":
+  // 1) the caller logically guarantees that queue is not empty
+  // 2) pop/try_pop cannot be called concurrently with this
+  const T& unsync_front() const { return queue_.front(); }
+
+ private:
+  std::queue<T> queue_;
+  mutable std::mutex mutex_;
+  std::condition_variable cond_;
+};
+
+struct MemoStore {
+  // Stores last known values for all the keys
+
+  struct Entry {
+    // Timestamp associated with the entry
+    int64_t _time;
+
+    // Batch associated with the entry (perf is probably OK for this; batches 
change
+    // rarely)
+    std::shared_ptr<arrow::RecordBatch> _batch;
+
+    // Row associated with the entry
+    row_index_t _row;
+  };
+
+  std::unordered_map<KeyType, Entry> _entries;
+
+  void store(const std::shared_ptr<RecordBatch>& batch, row_index_t row, 
int64_t time,
+             KeyType key) {
+    auto& e = _entries[key];
+    // that we can do this assignment optionally, is why we
+    // can get array with using shared_ptr above (the batch
+    // shouldn't change that often)
+    if (e._batch != batch) e._batch = batch;
+    e._row = row;
+    e._time = time;
+  }
+
+  util::optional<const Entry*> get_entry_for_key(KeyType key) const {
+    auto e = _entries.find(key);
+    if (_entries.end() == e) return util::nullopt;
+    return util::optional<const Entry*>(&e->second);
+  }
+
+  void remove_entries_with_lesser_time(int64_t ts) {
+    size_t dbg_size0 = _entries.size();
+    for (auto e = _entries.begin(); e != _entries.end();)
+      if (e->second._time < ts)
+        e = _entries.erase(e);
+      else
+        ++e;
+    size_t dbg_size1 = _entries.size();
+    if (dbg_size1 < dbg_size0) {
+      // cerr << "Removed " << dbg_size0-dbg_size1 << " memo entries.\n";
+    }
+  }
+};
+
+class InputState {
+  // InputState correponds to an input
+  // Input record batches are queued up in InputState until processed and
+  // turned into output record batches.
+
+ public:
+  InputState(const std::shared_ptr<arrow::Schema>& schema,
+             const std::string& time_col_name, const std::string& key_col_name,
+             util::optional<KeyType> wildcard_key)
+      : queue_(),
+        wildcard_key_(wildcard_key),
+        schema_(schema),
+        time_col_index_(
+            schema->GetFieldIndex(time_col_name)),  // TODO: handle missing 
field name
+        key_col_index_(schema->GetFieldIndex(key_col_name)) {}
+
+  col_index_t init_src_to_dst_mapping(col_index_t dst_offset,
+                                      bool skip_time_and_key_fields) {
+    src_to_dst_.resize(schema_->num_fields());
+    for (int i = 0; i < schema_->num_fields(); ++i)
+      if (!(skip_time_and_key_fields && is_time_or_key_column(i)))
+        src_to_dst_[i] = dst_offset++;
+    return dst_offset;
+  }
+
+  const util::optional<col_index_t>& map_src_to_dst(col_index_t src) const {
+    return src_to_dst_[src];
+  }
+
+  bool is_time_or_key_column(col_index_t i) const {
+    assert(i < schema_->num_fields());
+    return (i == time_col_index_) || (i == key_col_index_);
+  }
+
+  // Gets the latest row index,  assuming the queue isn't empty
+  row_index_t get_latest_row() const { return latest_ref_row_; }
+
+  bool empty() const {
+    if (latest_ref_row_ > 0)
+      return false;  // cannot be empty if ref row is >0 -- can avoid slow 
queue lock
+                     // below
+    return queue_.empty();
+  }
+
+  int countbatches_processed_() const { return batches_processed_; }
+  int count_total_batches() const { return total_batches_; }
+
+  // Gets latest batch (precondition: must not be empty)
+  const std::shared_ptr<arrow::RecordBatch>& get_latest_batch() const {
+    return queue_.unsync_front();
+  }
+  KeyType get_latest_key() const {
+    return queue_.unsync_front()
+        ->column_data(key_col_index_)
+        ->GetValues<KeyType>(1)[latest_ref_row_];
+  }
+  int64_t get_latest_time() const {
+    return queue_.unsync_front()
+        ->column_data(time_col_index_)
+        ->GetValues<int64_t>(1)[latest_ref_row_];
+  }
+
+  bool finished() const { return batches_processed_ == total_batches_; }
+
+  bool advance() {
+    // Returns true if able to advance, false if not.
+
+    bool have_active_batch =
+        (latest_ref_row_ > 0 /*short circuit the lock on the queue*/) || 
!queue_.empty();
+    if (have_active_batch) {
+      // If we have an active batch
+      if (++latest_ref_row_ >= (row_index_t)queue_.unsync_front()->num_rows()) 
{
+        // hit the end of the batch, need to get the next batch if possible.
+        ++batches_processed_;
+        latest_ref_row_ = 0;
+        have_active_batch &= !queue_.try_pop();
+        if (have_active_batch)
+          assert(queue_.unsync_front()->num_rows() > 0);  // empty batches 
disallowed
+      }
+    }
+    return have_active_batch;
+  }
+
+  // Advance the data to be immediately past the specified TS, updating latest 
and
+  // latest_ref_row to the latest data prior to that immediate just past 
Returns true if
+  // updates were made, false if not.
+  bool advance_and_memoize(int64_t ts) {
+    // Advance the right side row index until we reach the latest right row 
(for each key)
+    // for the given left timestamp.
+
+    // Check if already updated for TS (or if there is no latest)
+    if (empty()) return false;  // can't advance if empty
+    auto latest_time = get_latest_time();
+    if (latest_time > ts) return false;  // already advanced
+
+    // Not updated.  Try to update and possibly advance.
+    bool updated = false;
+    do {
+      latest_time = get_latest_time();
+      // if advance() returns true, then the latest_ts must also be valid
+      // Keep advancing right table until we hit the latest row that has
+      // timestamp <= ts. This is because we only need the latest row for the
+      // match given a left ts.
+      if (latest_time <= ts) {
+        memo_.store(get_latest_batch(), latest_ref_row_, latest_time, 
get_latest_key());
+      } else {
+        break;  // hit a future timestamp -- done updating for now
+      }
+      updated = true;
+    } while (advance());
+    return updated;
+  }
+
+  void push(const std::shared_ptr<arrow::RecordBatch>& rb) {
+    if (rb->num_rows() > 0) {
+      queue_.push(rb);
+    } else {
+      ++batches_processed_;  // don't enqueue empty batches, just record as 
processed
+    }
+  }
+
+  util::optional<const MemoStore::Entry*> get_memo_entry_for_key(KeyType key) {
+    auto r = memo_.get_entry_for_key(key);
+    if (r.has_value()) return r;
+    if (wildcard_key_.has_value()) r = memo_.get_entry_for_key(*wildcard_key_);
+    return r;
+  }
+
+  util::optional<int64_t> get_memo_time_for_key(KeyType key) {
+    auto r = get_memo_entry_for_key(key);
+    return r.has_value() ? util::make_optional((*r)->_time) : util::nullopt;

Review Comment:
   Update (slightly different)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to