Alex-PLACET commented on code in PR #49679:
URL: https://github.com/apache/arrow/pull/49679#discussion_r3137441161


##########
cpp/src/arrow/compute/kernels/vector_search_sorted.cc:
##########
@@ -0,0 +1,935 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/api_vector.h"
+
+#include <algorithm>
+#include <memory>
+#include <optional>
+#include <ranges>
+#include <span>
+#include <type_traits>
+#include <utility>
+
+#include "arrow/array/array_primitive.h"
+#include "arrow/array/array_run_end.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/array/util.h"
+#include "arrow/chunk_resolver.h"
+#include "arrow/compute/function.h"
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/compute/kernels/vector_sort_internal.h"
+#include "arrow/compute/registry.h"
+#include "arrow/compute/registry_internal.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging_internal.h"
+#include "arrow/util/ree_util.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace compute::internal {
+namespace {
+
+/// Return a pointer to a shared default SearchSortedOptions instance.
+///
+/// The instance is a function-local static, so it is created on the first call
+/// and the returned pointer stays valid for the rest of the program.
+const SearchSortedOptions* GetDefaultSearchSortedOptions() {
+  static const auto default_options = SearchSortedOptions::Defaults();
+  return &default_options;
+}
+
+// User-facing documentation for search_sorted: summary, description, argument
+// names, and the name of the options class.
+const FunctionDoc search_sorted_doc(
+    /*summary=*/"Find insertion indices for sorted input",
+    /*description=*/
+    "Return the index where each needle should be inserted in a sorted input array\n"
+    "to maintain ascending order.\n"
+    "\n"
+    "With side='left', returns the first suitable index (lower bound).\n"
+    "With side='right', returns the last suitable index (upper bound).\n"
+    "\n"
+    "The searched values may be provided as an array or chunked array and must\n"
+    "already be sorted in ascending order. Null values in the searched array are\n"
+    "supported when clustered entirely at the start or\n"
+    "entirely at the end. Non-null needles are matched only against the non-null\n"
+    "portion of the searched array. Needles may be a scalar, array, or chunked\n"
+    "array. Null needles emit nulls in the output.",
+    /*arg_names=*/{"values", "needles"},
+    /*options_class=*/"SearchSortedOptions");
+
+// This file implements search_sorted as a small pipeline that first normalizes
+// Arrow input shapes and then runs one typed binary-search core on logical
+// values.
+//
+// Plain arrays, run-end encoded arrays, chunked arrays, and scalar needles are
+// all adapted into the same accessor and visitor model so the search logic does
+// not care about physical layout.
+//
+// After validation, the kernel isolates the contiguous non-null window of the
+// searched values, because nulls are only supported when clustered at one end.
+// Needles are then visited either as single values or as logical runs, and each
+// non-null needle is resolved with a lower-bound or upper-bound binary search
+// over the sorted non-null range.
+//
+// Output materialization is split by null handling: needles without nulls are
+// written directly into a preallocated uint64 buffer, while nullable needles
+// append null and non-null spans through a UInt64Builder. That builder path is
+// optimized for repeated runs by bulk-filling reserved memory instead of
+// appending one insertion index at a time.
+//
+// High-level flow:
+//
+//   values datum
+//       |
+//       +--> ValidateSortedValuesInput
+//       |
+//       +--> LogicalType / FindNonNullValuesRange
+//       |
+//       +--> VisitValuesAccessor
+//             |
+//             +--> PlainArrayAccessor
+//             |
+//             `--> RunEndEncodedValuesAccessor
+//
+//   needles datum
+//       |
+//       +--> ValidateNeedleInput
+//       |
+//       +--> DatumHasNulls
+//       |
+//       `--> VisitNeedles
+//             |
+//             +--> scalar needle -> one logical span
+//             |
+//             +--> plain array   -> one span per element
+//             |
+//             `--> REE array     -> one span per logical run
+//
+//   normalized values accessor + normalized needle spans
+//       |
+//       `--> FindInsertionPoint<T>
+//             |
+//             +--> side = left  -> lower_bound semantics
+//             |
+//             `--> side = right -> upper_bound semantics
+//
+//   result materialization
+//       |
+//       +--> no needle nulls
+//       |     `--> MakeMutableUInt64Array
+//       |           `--> fill output buffer directly
+//       |
+//       `--> nullable needles
+//             `--> UInt64Builder
+//                   +--> AppendNulls for null runs
+//                   `--> bulk fill + UnsafeAdvance for repeated indices
+//
+// A rough map of the file:
+//
+//   [validation + type helpers]
+//           |
+//   [value accessors]
+//           |
+//   [needle visitors]
+//           |
+//   [typed search + output helpers]
+//           |
+//   [meta-function dispatch]
+//
+
+// X-macro enumerating the Arrow types handled by search_sorted: VISIT(T) is
+// expanded once for each listed type. Presumably this drives the per-type kernel
+// dispatch; confirm at the expansion sites before adding or removing entries.
+#define VISIT_SEARCH_SORTED_TYPES(VISIT) \
+  VISIT(BooleanType)                     \
+  VISIT(Int8Type)                        \
+  VISIT(Int16Type)                       \
+  VISIT(Int32Type)                       \
+  VISIT(Int64Type)                       \
+  VISIT(UInt8Type)                       \
+  VISIT(UInt16Type)                      \
+  VISIT(UInt32Type)                      \
+  VISIT(UInt64Type)                      \
+  VISIT(FloatType)                       \
+  VISIT(DoubleType)                      \
+  VISIT(Date32Type)                      \
+  VISIT(Date64Type)                      \
+  VISIT(Time32Type)                      \
+  VISIT(Time64Type)                      \
+  VISIT(TimestampType)                   \
+  VISIT(DurationType)                    \
+  VISIT(BinaryType)                      \
+  VISIT(StringType)                      \
+  VISIT(LargeBinaryType)                 \
+  VISIT(LargeStringType)                 \
+  VISIT(BinaryViewType)                  \
+  VISIT(StringViewType)
+
+/// Value type the search operates on for ArrowType: the `T` of
+/// GetViewType<ArrowType>. Every values accessor in this file returns this type
+/// from Value().
+template <typename ArrowType>
+using SearchValue = typename GetViewType<ArrowType>::T;
+
+// Forward declaration so the helper can be referenced before its definition.
+// Judging by its signature it invokes `visitor` according to the run-end type
+// of `array` and uses `argument_name` for diagnostics; confirm at the definition.
+template <typename ReturnType, typename Visitor>
+ReturnType DispatchRunEndEncodedByRunEndType(const RunEndEncodedArray& array,
+                                             const char* argument_name,
+                                             Visitor&& visitor);
+
+/// Three-way comparator for the search: forwards to CompareTypeValues with
+/// ascending sort order and nulls placed at the end.
+///
+/// The sign convention is that of CompareTypeValues (presumably negative, zero,
+/// positive for less, equal, greater; confirm in vector_sort_internal.h).
+template <typename ArrowType>
+struct SearchSortedCompare {
+  using ValueType = SearchValue<ArrowType>;
+
+  int operator()(const ValueType& lhs, const ValueType& rhs) const {
+    constexpr auto kOrder = SortOrder::Ascending;
+    constexpr auto kNullPlacement = NullPlacement::AtEnd;
+    return CompareTypeValues<ArrowType>(lhs, rhs, kOrder, kNullPlacement);
+  }
+};
+
+/// Read-only view over the logical values of a plain (not run-end encoded)
+/// Arrow array, exposed through the accessor interface used by the search.
+template <typename ArrowType>
+class PlainArrayAccessor {
+ public:
+  using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+  using ValueType = SearchValue<ArrowType>;
+
+  /// Wrap `array_data` in the concrete typed array class for ArrowType.
+  explicit PlainArrayAccessor(const std::shared_ptr<ArrayData>& array_data)
+      : typed_array_(array_data) {}
+
+  /// Number of logical elements in the wrapped array.
+  int64_t length() const { return typed_array_.length(); }
+
+  /// Logical value at `index`. No bounds or validity check is performed here.
+  ValueType Value(int64_t index) const {
+    auto view = typed_array_.GetView(index);
+    return GetViewType<ArrowType>::LogicalValue(view);
+  }
+
+ private:
+  ArrayType typed_array_;
+};
+
+/// Access logical values from a run-end encoded Arrow array.
+///
+/// A logical position is translated to a physical position in the values child
+/// with ree_util::RunEndEncodedArraySpan::PhysicalIndex, and the value is read
+/// from there.
+///
+/// RunEndEncodedArraySpan keeps a reference to the ArraySpan it is built from.
+/// Storing one as a member bound to a sibling ArraySpan member would dangle as
+/// soon as the accessor is copied or moved. The span is therefore rebuilt inside
+/// Value(), which is cheap because it only captures pointers and offsets.
+/// The REE ArrayData is held by shared_ptr because ArraySpan only borrows raw
+/// buffer pointers from it.
+template <typename ArrowType, typename RunEndCType>
+class RunEndEncodedValuesAccessor {
+ public:
+  using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+  using ValueType = SearchValue<ArrowType>;
+
+  /// Build a typed accessor over a run-end encoded payload.
+  explicit RunEndEncodedValuesAccessor(const RunEndEncodedArray& array)
+      : values_(array.values()->data()), data_(array.data()), array_span_(*data_) {}
+
+  /// Return the logical length of the searched values.
+  int64_t length() const { return array_span_.length; }
+
+  /// Return the logical value at the given logical position.
+  ValueType Value(int64_t index) const {
+    // Built per call so that no member ever refers to another member.
+    const ::arrow::ree_util::RunEndEncodedArraySpan<RunEndCType> ree_span(
+        array_span_);
+    const auto physical_index = ree_span.PhysicalIndex(index);
+    return GetViewType<ArrowType>::LogicalValue(values_.GetView(physical_index));
+  }
+
+ private:
+  ArrayType values_;
+  // Keeps the run-end encoded buffers alive for the borrowed array_span_.
+  std::shared_ptr<ArrayData> data_;
+  ArraySpan array_span_;
+};
+
+/// Access logical values from a chunked Arrow array without combining chunks.
+///
+/// Each read resolves the global index to a (chunk, index-in-chunk) location
+/// with a ChunkResolver. It then reads through a PlainArrayAccessor or a
+/// RunEndEncodedValuesAccessor built for that single read. Only a reference to
+/// `chunked_array` is kept, so the chunked array must outlive the accessor.
+///
+/// NOTE(review): an accessor is constructed on every read; if this shows up in
+/// profiles, consider caching one accessor per chunk.
+template <typename ArrowType>
+class ChunkedArrayAccessor {
+ public:
+  using ValueType = SearchValue<ArrowType>;
+
+  explicit ChunkedArrayAccessor(const ChunkedArray& chunked_array)
+      : chunked_array_(chunked_array), resolver_(chunked_array.chunks()) {}
+
+  /// Total logical length across all chunks.
+  int64_t length() const { return chunked_array_.length(); }
+
+  /// Logical value at global position `index`.
+  ValueType Value(int64_t index) const {
+    const auto location = resolver_.Resolve(index);
+    DCHECK_LT(location.chunk_index, chunked_array_.num_chunks());
+    const auto& chunk = chunked_array_.chunk(static_cast<int>(location.chunk_index));
+    return ReadChunkValue(chunk, location.index_in_chunk);
+  }
+
+ private:
+  /// Read one value from `chunk`. The chunk is either a plain array or a
+  /// run-end encoded array with int16, int32, or int64 run ends.
+  static ValueType ReadChunkValue(const std::shared_ptr<Array>& chunk, int64_t index) {
+    if (chunk->type_id() != Type::RUN_END_ENCODED) {
+      return PlainArrayAccessor<ArrowType>(chunk->data()).Value(index);
+    }
+
+    const auto& ree_chunk = checked_cast<const RunEndEncodedArray&>(*chunk);
+    const auto run_end_type_id =
+        checked_cast<const RunEndEncodedType&>(*ree_chunk.type()).run_end_type()->id();
+    switch (run_end_type_id) {
+      case Type::INT16:
+        return RunEndEncodedValuesAccessor<ArrowType, int16_t>(ree_chunk).Value(index);
+      case Type::INT32:
+        return RunEndEncodedValuesAccessor<ArrowType, int32_t>(ree_chunk).Value(index);
+      case Type::INT64:
+        return RunEndEncodedValuesAccessor<ArrowType, int64_t>(ree_chunk).Value(index);
+      default:
+        break;
+    }
+    DCHECK(false) << "Unexpected run-end type for search_sorted values: "
+                  << ree_chunk.type()->ToString();
+    return ValueType{};
+  }
+
+  const ChunkedArray& chunked_array_;
+  ChunkResolver resolver_;
+};
+
+/// Half-open window [offset, offset + length) of searched-value positions that
+/// contains no nulls.
+struct NonNullValuesRange {
+  int64_t offset = 0;
+  int64_t length = 0;
+
+  /// True when the window starts at 0 and covers all `full_length` positions,
+  /// i.e. nothing was trimmed.
+  bool is_identity(int64_t full_length) const {
+    const bool starts_at_zero = offset == 0;
+    const bool covers_everything = length == full_length;
+    return starts_at_zero && covers_everything;
+  }
+};
+
+constexpr std::string_view kClusteredNullValuesError =
+    "search_sorted values with nulls must be clustered at the start or end.";
+
+/// Count the consecutive null positions at the start of [0, length).
+///
+/// `is_null_at(i)` reports whether logical position `i` is null. The scan stops
+/// at the first non-null position, so only the leading run is examined.
+template <typename IsNullAt>
+int64_t CountLeadingNulls(int64_t length, IsNullAt&& is_null_at) {
+  auto indices = std::views::iota(int64_t{0}, length);
+  auto first_non_null =
+      std::ranges::find_if_not(indices, std::forward<IsNullAt>(is_null_at));
+  // The difference type of iota_view<int64_t> may be an integer-class type that
+  // is only explicitly convertible to int64_t, so convert explicitly.
+  return static_cast<int64_t>(std::ranges::distance(indices.begin(), first_non_null));
+}
+
+/// Count the consecutive null positions at the end of [0, length).
+///
+/// `is_null_at(i)` reports whether logical position `i` is null. Positions are
+/// visited from `length - 1` downwards, and the scan stops at the first non-null.
+template <typename IsNullAt>
+int64_t CountTrailingNulls(int64_t length, IsNullAt&& is_null_at) {
+  auto indices = std::views::iota(int64_t{0}, length) | std::views::reverse;
+  auto first_non_null =
+      std::ranges::find_if_not(indices, std::forward<IsNullAt>(is_null_at));
+  // The difference type of iota_view<int64_t> may be an integer-class type that
+  // is only explicitly convertible to int64_t, so convert explicitly.
+  return static_cast<int64_t>(std::ranges::distance(indices.begin(), first_non_null));
+}
+
+/// Count the nulls at one edge of a sequence of chunks.
+///
+/// Chunks are visited in the order `chunks` yields them, so the counted edge is
+/// whichever end that order starts from. Empty chunks are skipped, and chunks
+/// that are entirely null contribute their full length. The walk stops at the
+/// first chunk that is not entirely null; if that chunk has any nulls,
+/// `count_partial_nulls(chunk)` is added.
+///
+/// NOTE(review): this relies on Array::null_count(); confirm how callers handle
+/// run-end encoded chunks, whose nulls are held in the values child.
+template <typename ChunkRange, typename CountPartialNulls>
+int64_t CountEdgeNullsInChunks(ChunkRange&& chunks,
+                               CountPartialNulls&& count_partial_nulls) {
+  int64_t edge_null_count = 0;
+  for (const auto& chunk : chunks) {
+    const int64_t chunk_length = chunk->length();
+    if (chunk_length == 0) {
+      continue;
+    }
+    const int64_t chunk_null_count = chunk->null_count();
+    if (chunk_null_count == chunk_length) {
+      edge_null_count += chunk_length;
+      continue;
+    }
+    if (chunk_null_count != 0) {
+      edge_null_count += count_partial_nulls(*chunk);
+    }
+    break;
+  }
+  return edge_null_count;
+}
+
+/// Build the non-null window of the searched values from null counts.
+///
+/// \param[in] full_length number of searched values
+/// \param[in] null_count total number of nulls among them
+/// \param[in] leading_null_count number of consecutive nulls at the start
+/// \param[in] trailing_null_count number of consecutive nulls at the end; only
+///   consulted when there are no leading nulls
+/// \return the window, or Status::Invalid if the nulls are not all clustered at
+///   one end
+inline Result<NonNullValuesRange> MakeNonNullValuesRange(int64_t full_length,
+                                                         int64_t null_count,
+                                                         int64_t leading_null_count,
+                                                         int64_t trailing_null_count) {
+  // No nulls: everything is searchable. Without this check the trailing-null
+  // test below would reject a null-free input.
+  if (null_count == 0) {
+    return NonNullValuesRange{.offset = 0, .length = full_length};
+  }
+
+  // Every value is null: nothing is searchable.
+  if (leading_null_count == full_length) {
+    return NonNullValuesRange{.offset = full_length, .length = 0};
+  }
+
+  // Nulls at the start must account for every null.
+  if (leading_null_count > 0) {
+    if (leading_null_count != null_count) {
+      return Status::Invalid(kClusteredNullValuesError);
+    }
+    return NonNullValuesRange{.offset = leading_null_count,
+                              .length = full_length - leading_null_count};
+  }
+
+  // Otherwise every null must sit at the end.
+  if (trailing_null_count != null_count) {
+    return Status::Invalid(kClusteredNullValuesError);
+  }
+  return NonNullValuesRange{.offset = 0, .length = full_length - trailing_null_count};
+}
+
+/// Read-only view of the contiguous non-null slice of another values accessor,
+/// exposed through the same accessor interface.
+///
+/// Index 0 of this view maps to `non_null_values_range.offset` in the wrapped
+/// accessor. Only a reference to `values` is kept, so it must outlive the view.
+template <typename ValuesAccessor>
+class NonNullValuesAccessor {
+ public:
+  explicit NonNullValuesAccessor(const ValuesAccessor& values,
+                                 const NonNullValuesRange& non_null_values_range)
+      : values_(values), range_(non_null_values_range) {}
+
+  /// Number of values in the non-null slice.
+  int64_t length() const noexcept { return range_.length; }
+
+  /// Value at `index`, counted from the start of the non-null slice.
+  auto Value(int64_t index) const { return values_.Value(range_.offset + index); }
+
+ private:
+  const ValuesAccessor& values_;
+  NonNullValuesRange range_;
+};
+
+/// Return the logical type of a datum: the value type when the datum is run-end
+/// encoded (one level only), otherwise the datum's own type.
+inline const DataType& LogicalType(const Datum& datum) {
+  const DataType& physical_type = *datum.type();
+  if (physical_type.id() != Type::RUN_END_ENCODED) {
+    return physical_type;
+  }
+  return *checked_cast<const RunEndEncodedType&>(physical_type).value_type();
+}
+
+/// Return whether a scalar, array, or chunked-array needle input has nulls.
+///
+/// Scalars report their validity flag. Arrays and chunked arrays report their own
+/// null count. For run-end encoded inputs, the null count of the values child is
+/// checked as well.
+///
+/// NOTE(review): values() is presumably the whole values child rather than only
+/// the runs visible in a sliced array, so this may report true conservatively.
+inline bool DatumHasNulls(const Datum& datum) {
+  const auto ree_values_have_nulls = [](const Array& array) {
+    return checked_cast<const RunEndEncodedArray&>(array).values()->null_count() != 0;
+  };
+
+  if (datum.is_scalar()) {
+    return !datum.scalar()->is_valid;
+  }
+
+  if (datum.is_chunked_array()) {
+    const ChunkedArray& chunked = *datum.chunked_array();
+    if (chunked.null_count() > 0) {
+      return true;
+    }
+    if (chunked.type()->id() != Type::RUN_END_ENCODED) {
+      return false;
+    }
+    for (const auto& chunk : chunked.chunks()) {
+      if (ree_values_have_nulls(*chunk)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  const std::shared_ptr<ArrayData>& data = datum.array();
+  if (data->GetNullCount() > 0) {
+    return true;
+  }
+  if (data->type->id() != Type::RUN_END_ENCODED) {
+    return false;
+  }
+  return RunEndEncodedArray(data).values()->null_count() != 0;
+}
+
+/// Reject a run-end encoded type whose value type is itself run-end encoded.
+///
+/// `type` must be a RunEndEncodedType; `name` names the offending argument in
+/// the error message.
+/// TODO: Support nested run-end encoding in the future if there is demand for it.
+inline Status ValidateRunEndEncodedLogicalValueType(const DataType& type,
+                                                    const char* name) {
+  const DataType& value_type = *checked_cast<const RunEndEncodedType&>(type).value_type();
+  if (value_type.id() != Type::RUN_END_ENCODED) {
+    return Status::OK();
+  }
+  return Status::TypeError("Nested run-end encoded ", name, " are not supported");
+}
+
+/// Compute the contiguous non-null window of the searched values.
+///
+inline Result<NonNullValuesRange> FindNonNullValuesRange(const ArrayData& 
values) {
+  NonNullValuesRange non_null_values_range{.offset = 0, .length = 
values.length};
+
+  const auto null_count = values.GetNullCount();
+  if (null_count == 0) {
+    return non_null_values_range;
+  }
+
+  const int64_t leading_null_count = CountLeadingNulls(
+      values.length, [&](int64_t index) { return values.IsNull(index); });
+  const int64_t trailing_null_count =
+      leading_null_count > 0 ? 0 : CountTrailingNulls(values.length, 
[&](int64_t index) {
+        return values.IsNull(index);
+      });
+

Review Comment:
   You are right, I simplified and refactored it



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to