lidavidm commented on a change in pull request #11019: URL: https://github.com/apache/arrow/pull/11019#discussion_r701297612
########## File path: cpp/src/arrow/compute/api_vector.h ########## @@ -120,6 +120,29 @@ class ARROW_EXPORT SortOptions : public FunctionOptions { std::vector<SortKey> sort_keys; }; +/// \brief SelectK options for TopK/BottomK +class ARROW_EXPORT SelectKOptions : public FunctionOptions { + public: + explicit SelectKOptions(int64_t k = -1, std::vector<std::string> keys = {}, + bool keep_duplicates = false, + SortOrder order = SortOrder::Ascending); + constexpr static char const kTypeName[] = "SelectKOptions"; + static SelectKOptions TopKDefault() { + return SelectKOptions{-1, {}, false, SortOrder::Descending}; + } + static SelectKOptions BottomKDefault() { + return SelectKOptions{-1, {}, false, SortOrder::Ascending}; + } + /// The index into the equivalent sorted array of the partition pivot element. Review comment: This could be more clearly described as just the number of elements to keep, right? ########## File path: cpp/src/arrow/compute/kernels/vector_sort.cc ########## @@ -1778,6 +1799,736 @@ class SortIndicesMetaFunction : public MetaFunction { } }; +// ---------------------------------------------------------------------- +// TopK/BottomK implementations + +using SelectKOptionsState = internal::OptionsWrapper<SelectKOptions>; +const auto kDefaultTopKOptions = SelectKOptions::TopKDefault(); +const auto kDefaultBottomKOptions = SelectKOptions::BottomKDefault(); + +const FunctionDoc top_k_doc( + "Returns the first k elements ordered by `options.keys` in ascending order", + ("This function computes the k largest elements in ascending order of the input\n" + "array, record batch or table specified in the column names (`options.keys`). 
The\n" + "columns that are not specified are returned as well, but not used for ordering.\n" + "Null values are considered greater than any other value and are therefore sorted\n" + "at the end of the array.\n" + "For floating-point types, NaNs are considered greater than any\n" + "other non-null value, but smaller than null values."), + {"input"}, "SelectKOptions"); + +const FunctionDoc bottom_k_doc( + "Returns the first k elements ordered by `options.keys` in descending order", + ("This function computes the k smallest elements in descending order of the input\n" + "array, record batch or table specified in the column names (`options.keys`). The\n" + "columns that are not specified are returned as well, but not used for ordering.\n" + "Null values are considered greater than any other value and are therefore sorted\n" + "at the end of the array.\n" + "For floating-point types, NaNs are considered greater than any\n" + "other non-null value, but smaller than null values."), + {"input"}, "SelectKOptions"); + +Result<std::shared_ptr<ArrayData>> MakeMutableArrayForFixedSizedType( + std::shared_ptr<DataType> out_type, int64_t length, MemoryPool* memory_pool) { + auto buffer_size = BitUtil::BytesForBits( + length * std::static_pointer_cast<UInt64Type>(out_type)->bit_width()); + std::vector<std::shared_ptr<Buffer>> buffers(2); + ARROW_ASSIGN_OR_RAISE(buffers[1], AllocateResizableBuffer(buffer_size, memory_pool)); + auto out = std::make_shared<ArrayData>(out_type, length, buffers, 0); + return out; +} + +template <SortOrder order> +class SelectKComparator { + public: + template <typename Type> + bool operator()(const Type& lval, const Type& rval); +}; + +template <> +class SelectKComparator<SortOrder::Ascending> { + public: + template <typename Type> + bool operator()(const Type& lval, const Type& rval) { + return lval < rval; + } +}; + +template <> +class SelectKComparator<SortOrder::Descending> { + public: + template <typename Type> + bool operator()(const Type& lval, 
const Type& rval) { + return rval < lval; + } +}; + +template <SortOrder sort_order> +class ArraySelecter : public TypeVisitor { + public: + ArraySelecter(ExecContext* ctx, const Array& array, const SelectKOptions& options, + Datum* output) + : TypeVisitor(), + ctx_(ctx), + array_(array), + options_(options), + physical_type_(GetPhysicalType(array.type())), + output_(output) {} + + Status Run() { return VisitTypeInline(*physical_type_, this); } + +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); } + + VISIT_PHYSICAL_TYPES(VISIT) + +#undef VISIT + Status Visit(const DataType& type) { + return Status::TypeError("Unsupported type for ArraySelecter: ", type.ToString()); + } + + template <typename InType> + Status SelectKthInternal() { + using GetView = GetViewType<InType>; + using ArrayType = typename TypeTraits<InType>::ArrayType; + + ArrayType arr(array_.data()); + std::vector<uint64_t> indices(arr.length()); + + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + if (options_.k > arr.length()) { + options_.k = arr.length(); + } + auto end_iter = PartitionNulls<ArrayType, NonStablePartitioner>(indices_begin, + indices_end, arr, 0); + auto kth_begin = indices_begin + options_.k; + if (kth_begin > end_iter) { + kth_begin = end_iter; + } + std::function<bool(uint64_t, uint64_t)> cmp; + SelectKComparator<sort_order> comparator; + cmp = [&arr, &comparator](uint64_t left, uint64_t right) -> bool { + const auto lval = GetView::LogicalValue(arr.GetView(left)); + const auto rval = GetView::LogicalValue(arr.GetView(right)); + return comparator(lval, rval); + }; + arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp); + uint64_t* iter = indices_begin; + for (; iter != kth_begin; ++iter) { + heap.Push(*iter); + } + for (; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + const auto lval = 
GetView::LogicalValue(arr.GetView(x_index)); + const auto rval = GetView::LogicalValue(arr.GetView(heap.top())); + if (comparator(lval, rval)) { + heap.ReplaceTop(x_index); + } + } + if (options_.keep_duplicates == true) { + iter = indices_begin; + for (; iter != end_iter; ++iter) { + if (*iter != heap.top()) { + const auto lval = GetView::LogicalValue(arr.GetView(*iter)); + const auto rval = GetView::LogicalValue(arr.GetView(heap.top())); + if (lval == rval) { + heap.Push(*iter); + } + } + } + } + + int64_t out_size = static_cast<int64_t>(heap.size()); + ARROW_ASSIGN_OR_RAISE( + auto take_indices, + MakeMutableArrayForFixedSizedType(uint64(), out_size, ctx_->memory_pool())); + + auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size - 1; + while (heap.size() > 0) { + *out_cbegin = heap.top(); + heap.Pop(); + --out_cbegin; + } + ARROW_ASSIGN_OR_RAISE(*output_, Take(array_, Datum(std::move(take_indices)), + TakeOptions::NoBoundsCheck(), ctx_)); + return Status::OK(); + } + + ExecContext* ctx_; + const Array& array_; + SelectKOptions options_; + const std::shared_ptr<DataType> physical_type_; + Datum* output_; +}; + +template <typename ArrayType> +struct TypedHeapItem { + uint64_t index; + uint64_t offset; + ArrayType* array; +}; + +template <SortOrder sort_order> +class ChunkedArraySelecter : public TypeVisitor { + public: + ChunkedArraySelecter(ExecContext* ctx, const ChunkedArray& chunked_array, + const SelectKOptions& options, Datum* output) + : TypeVisitor(), + chunked_array_(chunked_array), + physical_type_(GetPhysicalType(chunked_array.type())), + physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)), + options_(options), + ctx_(ctx), + output_(output) {} + + Status Run() { return physical_type_->Accept(this); } + +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); } + + VISIT_PHYSICAL_TYPES(VISIT) + +#undef VISIT + + template <typename InType> + Status SelectKthInternal() { + using 
GetView = GetViewType<InType>; + using ArrayType = typename TypeTraits<InType>::ArrayType; + using HeapItem = TypedHeapItem<ArrayType>; + + const auto num_chunks = chunked_array_.num_chunks(); + if (num_chunks == 0) { + return Status::OK(); + } + if (options_.k > chunked_array_.length()) { + options_.k = chunked_array_.length(); + } + std::function<bool(const HeapItem&, const HeapItem&)> cmp; + SelectKComparator<sort_order> comparator; + + cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool { + const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); + const auto rval = GetView::LogicalValue(right.array->GetView(right.index)); + return comparator(lval, rval); + }; + arrow::internal::Heap<HeapItem, decltype(cmp)> heap(cmp); + std::vector<std::shared_ptr<ArrayType>> chunks_holder; + uint64_t offset = 0; + for (const auto& chunk : physical_chunks_) { + if (chunk->length() == 0) continue; + chunks_holder.emplace_back(std::make_shared<ArrayType>(chunk->data())); + ArrayType& arr = *chunks_holder[chunks_holder.size() - 1]; + + std::vector<uint64_t> indices(arr.length()); + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + + auto end_iter = PartitionNulls<ArrayType, NonStablePartitioner>( + indices_begin, indices_end, arr, 0); + auto kth_begin = indices_begin + options_.k; + + if (kth_begin > end_iter) { + kth_begin = end_iter; + } + uint64_t* iter = indices_begin; + for (; iter != kth_begin && heap.size() < static_cast<size_t>(options_.k); ++iter) { + heap.Push(HeapItem{*iter, offset, &arr}); + } + for (; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + const auto& xval = GetView::LogicalValue(arr.GetView(x_index)); + auto top_item = heap.top(); + const auto& top_value = + GetView::LogicalValue(top_item.array->GetView(top_item.index)); + if (comparator(xval, top_value)) { + heap.ReplaceTop(HeapItem{x_index, 
offset, &arr}); + } + } + offset += chunk->length(); + } + + if (options_.keep_duplicates == true) { + offset = 0; + for (const auto& chunk : chunks_holder) { + ArrayType& arr = *chunk; + + std::vector<uint64_t> indices(arr.length()); + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + + auto iter = indices_begin; + for (; iter != indices_end; ++iter) { + uint64_t x_index = *iter; + auto top_item = heap.top(); + if (x_index + offset != top_item.index + top_item.offset) { + const auto& xval = GetView::LogicalValue(arr.GetView(x_index)); + const auto& top_value = + GetView::LogicalValue(top_item.array->GetView(top_item.index)); + if (xval == top_value) { + heap.Push(HeapItem{x_index, offset, &arr}); + } + } + } + offset += chunk->length(); + } + } + + int64_t out_size = static_cast<int64_t>(heap.size()); + ARROW_ASSIGN_OR_RAISE( + auto take_indices, + MakeMutableArrayForFixedSizedType(uint64(), out_size, ctx_->memory_pool())); + auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size - 1; + while (heap.size() > 0) { + auto top_item = heap.top(); + *out_cbegin = top_item.index + top_item.offset; + heap.Pop(); + --out_cbegin; + } + ARROW_ASSIGN_OR_RAISE(auto chunked_select_k, + Take(Datum(chunked_array_), Datum(std::move(take_indices)), + TakeOptions::NoBoundsCheck(), ctx_)); + ARROW_ASSIGN_OR_RAISE( + auto select_k, + Concatenate(chunked_select_k.chunked_array()->chunks(), ctx_->memory_pool())); + *output_ = Datum(select_k); + return Status::OK(); + } + + const ChunkedArray& chunked_array_; + const std::shared_ptr<DataType> physical_type_; + const ArrayVector physical_chunks_; + SelectKOptions options_; + ExecContext* ctx_; + Datum* output_; +}; + +template <SortOrder sort_order> +class RecordBatchSelecter : public TypeVisitor { + private: + using ResolvedSortKey = MultipleKeyRecordBatchSorter::ResolvedSortKey; + using Comparator = 
MultipleKeyComparator<ResolvedSortKey>; + + public: + RecordBatchSelecter(ExecContext* ctx, const RecordBatch& record_batch, + const SelectKOptions& options, Datum* output) + : TypeVisitor(), + ctx_(ctx), + record_batch_(record_batch), + options_(options), + output_(output), + sort_keys_(ResolveSortKeys(record_batch, options.keys, options.order, &status_)), + comparator_(sort_keys_) {} + + Status Run() { + ARROW_RETURN_NOT_OK(status_); + return sort_keys_[0].type->Accept(this); + } + + protected: +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); } + + VISIT_PHYSICAL_TYPES(VISIT) + +#undef VISIT + + static std::vector<ResolvedSortKey> ResolveSortKeys( + const RecordBatch& batch, const std::vector<std::string>& sort_keys, + SortOrder order, Status* status) { + std::vector<ResolvedSortKey> resolved; + for (const auto& key_name : sort_keys) { + auto array = batch.GetColumnByName(key_name); + if (!array) { + *status = Status::Invalid("Nonexistent sort key column: ", key_name); + break; + } + resolved.emplace_back(array, order); + } + return resolved; + } + + template <typename InType> + Status SelectKthInternal() { + using GetView = GetViewType<InType>; + using ArrayType = typename TypeTraits<InType>::ArrayType; + auto& comparator = comparator_; + const auto& first_sort_key = sort_keys_[0]; + const ArrayType& arr = checked_cast<const ArrayType&>(first_sort_key.array); + + const auto num_rows = record_batch_.num_rows(); + if (num_rows == 0) { + return Status::OK(); + } + if (options_.k > record_batch_.num_rows()) { + options_.k = record_batch_.num_rows(); + } + std::function<bool(const uint64_t&, const uint64_t&)> cmp; + SelectKComparator<sort_order> select_k_comparator; + cmp = [&](const uint64_t& left, const uint64_t& right) -> bool { + const auto lval = GetView::LogicalValue(arr.GetView(left)); + const auto rval = GetView::LogicalValue(arr.GetView(right)); + if (lval == rval) { + // If the left value equals to the right value, 
+ // we need to compare the second and following + // sort keys. + return comparator.Compare(left, right, 1); + } + return select_k_comparator(lval, rval); + }; + arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp); + + std::vector<uint64_t> indices(arr.length()); + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + + auto end_iter = PartitionNulls<ArrayType, NonStablePartitioner>(indices_begin, + indices_end, arr, 0); + auto kth_begin = indices_begin + options_.k; + + if (kth_begin > end_iter) { + kth_begin = end_iter; + } + uint64_t* iter = indices_begin; + for (; iter != kth_begin; ++iter) { + heap.Push(*iter); + } + for (; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + auto top_item = heap.top(); + if (cmp(x_index, top_item)) { + heap.ReplaceTop(x_index); + } + } + if (options_.keep_duplicates == true) { + iter = indices_begin; + for (; iter != end_iter; ++iter) { + uint64_t x_index = *iter; + auto top_item = heap.top(); + if (x_index != top_item) { + const auto& xval = GetView::LogicalValue(arr.GetView(x_index)); + const auto& top_value = GetView::LogicalValue(arr.GetView(top_item)); + if (xval == top_value && comparator.Equals(x_index, top_item, 1)) { + heap.Push(x_index); + } + } + } + } + int64_t out_size = static_cast<int64_t>(heap.size()); + ARROW_ASSIGN_OR_RAISE( + auto take_indices, + MakeMutableArrayForFixedSizedType(uint64(), out_size, ctx_->memory_pool())); + auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size - 1; + while (heap.size() > 0) { + *out_cbegin = heap.top(); + heap.Pop(); + --out_cbegin; + } + ARROW_ASSIGN_OR_RAISE(*output_, + Take(Datum(record_batch_), Datum(std::move(take_indices)), + TakeOptions::NoBoundsCheck(), ctx_)); + return Status::OK(); + } + + ExecContext* ctx_; + const RecordBatch& record_batch_; + SelectKOptions options_; + Datum* output_; + std::vector<ResolvedSortKey> 
sort_keys_; + Comparator comparator_; + Status status_; +}; + +template <SortOrder sort_order> +class TableSelecter : public TypeVisitor { + private: + using ResolvedSortKey = MultipleKeyTableSorter::ResolvedSortKey; + using Comparator = MultipleKeyComparator<ResolvedSortKey>; + + public: + TableSelecter(ExecContext* ctx, const Table& table, const SelectKOptions& options, + Datum* output) + : TypeVisitor(), + ctx_(ctx), + table_(table), + options_(options), + output_(output), + sort_keys_(ResolveSortKeys(table, options.keys, options.order, &status_)), + comparator_(sort_keys_) {} + + Status Run() { + ARROW_RETURN_NOT_OK(status_); + return sort_keys_[0].type->Accept(this); + } + + protected: +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); } + + VISIT_PHYSICAL_TYPES(VISIT) + +#undef VISIT + + static std::vector<ResolvedSortKey> ResolveSortKeys( + const Table& table, const std::vector<std::string>& sort_keys, SortOrder order, + Status* status) { + std::vector<ResolvedSortKey> resolved; + for (const auto& key_name : sort_keys) { + auto chunked_array = table.GetColumnByName(key_name); + if (!chunked_array) { + *status = Status::Invalid("Nonexistent sort key column: ", key_name); + break; + } + resolved.emplace_back(*chunked_array, order); + } + return resolved; + } + + // Behaves like PatitionNulls() but this supports multiple sort keys. + // + // For non-float types. 
+ template <typename Type> + enable_if_t<!is_floating_type<Type>::value, uint64_t*> PartitionNullsInternal( + uint64_t* indices_begin, uint64_t* indices_end, + const ResolvedSortKey& first_sort_key) { + using ArrayType = typename TypeTraits<Type>::ArrayType; + if (first_sort_key.null_count == 0) { + return indices_end; + } + StablePartitioner partitioner; + auto nulls_begin = + partitioner(indices_begin, indices_end, [&first_sort_key](uint64_t index) { + const auto chunk = first_sort_key.GetChunk<ArrayType>((int64_t)index); + return !chunk.IsNull(); + }); + DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count); + auto& comparator = comparator_; + std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t right) { + return comparator.Compare(left, right, 1); + }); + return nulls_begin; + } + + // Behaves like PatitionNulls() but this supports multiple sort keys. + // + // For float types. + template <typename Type> + enable_if_t<is_floating_type<Type>::value, uint64_t*> PartitionNullsInternal( + uint64_t* indices_begin, uint64_t* indices_end, + const ResolvedSortKey& first_sort_key) { + using ArrayType = typename TypeTraits<Type>::ArrayType; + StablePartitioner partitioner; + uint64_t* nulls_begin; + if (first_sort_key.null_count == 0) { + nulls_begin = indices_end; + } else { + nulls_begin = partitioner(indices_begin, indices_end, [&](uint64_t index) { + const auto chunk = first_sort_key.GetChunk<ArrayType>(index); + return !chunk.IsNull(); + }); + } + DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count); + uint64_t* nans_begin = partitioner(indices_begin, nulls_begin, [&](uint64_t index) { + const auto chunk = first_sort_key.GetChunk<ArrayType>(index); + return !std::isnan(chunk.Value()); + }); + auto& comparator = comparator_; + // Sort all NaNs by the second and following sort keys. 
+ std::stable_sort(nans_begin, nulls_begin, [&](uint64_t left, uint64_t right) { + return comparator.Compare(left, right, 1); + }); + // Sort all nulls by the second and following sort keys. + std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t right) { + return comparator.Compare(left, right, 1); + }); + return nans_begin; + } + + template <typename InType> + Status SelectKthInternal() { + using ArrayType = typename TypeTraits<InType>::ArrayType; + auto& comparator = comparator_; + const auto& first_sort_key = sort_keys_[0]; + + const auto num_rows = table_.num_rows(); + if (num_rows == 0) { + return Status::OK(); + } + if (options_.k > table_.num_rows()) { + options_.k = table_.num_rows(); + } + std::function<bool(const uint64_t&, const uint64_t&)> cmp; + SelectKComparator<sort_order> select_k_comparator; + cmp = [&](const uint64_t& left, const uint64_t& right) -> bool { + auto chunk_left = first_sort_key.GetChunk<ArrayType>(left); + auto chunk_right = first_sort_key.GetChunk<ArrayType>(right); + auto value_left = chunk_left.Value(); + auto value_right = chunk_right.Value(); + if (value_left == value_right) { + return comparator.Compare(left, right, 1); + } + return select_k_comparator(value_left, value_right); + }; + arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp); + + std::vector<uint64_t> indices(num_rows); + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + + auto end_iter = + this->PartitionNullsInternal<InType>(indices_begin, indices_end, first_sort_key); + auto kth_begin = indices_begin + options_.k; + + if (kth_begin > end_iter) { + kth_begin = end_iter; + } + uint64_t* iter = indices_begin; + for (; iter != kth_begin; ++iter) { + heap.Push(*iter); + } + for (; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + uint64_t top_item = heap.top(); + if (cmp(x_index, top_item)) { + heap.ReplaceTop(x_index); + } + } 
+ if (options_.keep_duplicates == true) { + iter = indices_begin; + for (; iter != end_iter; ++iter) { + uint64_t x_index = *iter; + auto top_item = heap.top(); + if (x_index != top_item) { + auto chunk_left = first_sort_key.GetChunk<ArrayType>(x_index); + auto chunk_right = first_sort_key.GetChunk<ArrayType>(top_item); + auto xval = chunk_left.Value(); + auto top_value = chunk_right.Value(); + if (xval == top_value && comparator.Equals(x_index, top_item, 1)) { + heap.Push(x_index); + } + } + } + } + int64_t out_size = static_cast<int64_t>(heap.size()); + ARROW_ASSIGN_OR_RAISE( + auto take_indices, + MakeMutableArrayForFixedSizedType(uint64(), out_size, ctx_->memory_pool())); + auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size - 1; + while (heap.size() > 0) { + *out_cbegin = heap.top(); + heap.Pop(); + --out_cbegin; + } + ARROW_ASSIGN_OR_RAISE(*output_, Take(Datum(table_), Datum(std::move(take_indices)), + TakeOptions::NoBoundsCheck(), ctx_)); + return Status::OK(); + } + + ExecContext* ctx_; + const Table& table_; + SelectKOptions options_; + Datum* output_; + std::vector<ResolvedSortKey> sort_keys_; + Comparator comparator_; + Status status_; +}; Review comment: These selecters all share a very similar core structure. It might be worth considering how templating (e.g. via [CRTP](https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern), or via a set of helper comparator/iteration templates) could let you separate the core algorithm from the container-specific bits. That would also make it easier to share generated code between types with the same physical type (e.g., as mentioned, Int64, Timestamp, and Date64 should all use the same generated code underneath). ########## File path: cpp/src/arrow/compute/kernels/select_k_test.cc ########## @@ -0,0 +1,809 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements.
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <functional> +#include <iostream> +#include <limits> +#include <memory> +#include <string> +#include <vector> + +#include "arrow/array/array_decimal.h" +#include "arrow/array/concatenate.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/kernels/test_util.h" +#include "arrow/table.h" +#include "arrow/testing/gtest_common.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/testing/util.h" +#include "arrow/type_traits.h" + +namespace arrow { + +using internal::checked_cast; +using internal::checked_pointer_cast; + +namespace compute { + +namespace { +template <typename ArrayType> +auto GetLogicalValue(const ArrayType& array, uint64_t index) + -> decltype(array.GetView(index)) { + return array.GetView(index); +} + +Decimal128 GetLogicalValue(const Decimal128Array& array, uint64_t index) { + return Decimal128(array.Value(index)); +} + +} // namespace + +template <typename ArrayType, SortOrder order> +class SelectKComparator { + public: + template <typename Type> + bool operator()(const Type& lval, const Type& rval) { + if (is_floating_type<typename ArrayType::TypeClass>::value) { + if (rval != rval) return true; + if (lval != lval) return false; + } + if (order == SortOrder::Ascending) { + return lval <= rval; 
+ } else { + return rval <= lval; + } + } +}; + +template <SortOrder order> +Result<std::shared_ptr<Array>> SelectK(const Array& values, int64_t k) { + if (order == SortOrder::Descending) { + return TopK(values, k); + } else { + return BottomK(values, k); + } +} +template <SortOrder order> +Result<std::shared_ptr<Array>> SelectK(const ChunkedArray& values, int64_t k) { + if (order == SortOrder::Descending) { + return TopK(values, k); + } else { + return BottomK(values, k); + } +} + +template <SortOrder order> +Result<std::shared_ptr<RecordBatch>> SelectK(const RecordBatch& values, + const SelectKOptions& options) { + if (order == SortOrder::Descending) { + ARROW_ASSIGN_OR_RAISE(auto out, TopK(Datum(values), options.k, options)); + return out.record_batch(); + } else { + ARROW_ASSIGN_OR_RAISE(auto out, BottomK(Datum(values), options.k, options)); + return out.record_batch(); + } +} + +template <SortOrder order> +Result<std::shared_ptr<Table>> SelectK(const Table& values, + const SelectKOptions& options) { + if (order == SortOrder::Descending) { + ARROW_ASSIGN_OR_RAISE(auto out, TopK(Datum(values), options.k, options)); + return out.table(); + } else { + ARROW_ASSIGN_OR_RAISE(auto out, BottomK(Datum(values), options.k, options)); + return out.table(); + } +} + +template <typename ArrowType> +class TestSelectKBase : public TestBase { + using ArrayType = typename TypeTraits<ArrowType>::ArrayType; + + protected: + void Validate(const ArrayType& array, int k, ArrayType& select_k, SortOrder order) { + ASSERT_OK_AND_ASSIGN(auto sorted_indices, SortIndices(array, order)); + + // head(k) + auto head_k_indices = sorted_indices->Slice(0, select_k.length()); + + // sorted_indices + ASSERT_OK_AND_ASSIGN(Datum sorted_datum, + Take(array, head_k_indices, TakeOptions::NoBoundsCheck())); + std::shared_ptr<Array> sorted_array_out = sorted_datum.make_array(); + + const ArrayType& sorted_array = *checked_pointer_cast<ArrayType>(sorted_array_out); + + if (k < array.length()) { + 
AssertArraysEqual(sorted_array, select_k); + } + } + template <SortOrder order> + void AssertSelectKArray(const std::shared_ptr<Array> values, int n) { + std::shared_ptr<Array> select_k; + ASSERT_OK_AND_ASSIGN(select_k, SelectK<order>(*values, n)); + ASSERT_EQ(select_k->data()->null_count, 0); + ValidateOutput(*select_k); + Validate(*checked_pointer_cast<ArrayType>(values), n, + *checked_pointer_cast<ArrayType>(select_k), order); + } + + void AssertTopKArray(const std::shared_ptr<Array> values, int n) { + AssertSelectKArray<SortOrder::Descending>(values, n); + } + void AssertBottomKArray(const std::shared_ptr<Array> values, int n) { + AssertSelectKArray<SortOrder::Descending>(values, n); + } + + void AssertSelectKJson(const std::string& values, int n) { + AssertTopKArray(ArrayFromJSON(type_singleton(), values), n); + AssertBottomKArray(ArrayFromJSON(type_singleton(), values), n); + } + + virtual std::shared_ptr<DataType> type_singleton() = 0; +}; + +template <typename ArrowType> +class TestSelectK : public TestSelectKBase<ArrowType> { + protected: + std::shared_ptr<DataType> type_singleton() override { + return default_type_instance<ArrowType>(); + } +}; + +template <typename ArrowType> +class TestSelectKForReal : public TestSelectK<ArrowType> {}; +TYPED_TEST_SUITE(TestSelectKForReal, RealArrowTypes); + +template <typename ArrowType> +class TestSelectKForIntegral : public TestSelectK<ArrowType> {}; +TYPED_TEST_SUITE(TestSelectKForIntegral, IntegralArrowTypes); + +template <typename ArrowType> +class TestSelectKForBool : public TestSelectK<ArrowType> {}; +TYPED_TEST_SUITE(TestSelectKForBool, ::testing::Types<BooleanType>); + +template <typename ArrowType> +class TestSelectKForTemporal : public TestSelectK<ArrowType> {}; +TYPED_TEST_SUITE(TestSelectKForTemporal, TemporalArrowTypes); + +template <typename ArrowType> +class TestSelectKForDecimal : public TestSelectKBase<ArrowType> { + std::shared_ptr<DataType> type_singleton() override { + return 
std::make_shared<ArrowType>(5, 2); + } +}; +TYPED_TEST_SUITE(TestSelectKForDecimal, DecimalArrowTypes); + +template <typename ArrowType> +class TestSelectKForStrings : public TestSelectK<ArrowType> {}; +TYPED_TEST_SUITE(TestSelectKForStrings, testing::Types<StringType>); + +TYPED_TEST(TestSelectKForReal, SelectKDoesNotProvideDefaultOptions) { + auto input = ArrayFromJSON(this->type_singleton(), "[null, 1, 3.3, null, 2, 5.3]"); + ASSERT_RAISES(Invalid, CallFunction("top_k", {input})); +} + +TYPED_TEST(TestSelectKForReal, Real) { + this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 0); + this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 2); + this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 5); + this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 6); + + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 0); + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 1); + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 2); + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 3); + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 4); + this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 3); + this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 4); +} + +TYPED_TEST(TestSelectKForIntegral, Integral) { + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 0); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 2); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 5); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 6); + + this->AssertSelectKJson("[2, 4, 5, 7, 8, 0, 9, 1, 3]", 5); +} + +TYPED_TEST(TestSelectKForBool, Bool) { + this->AssertSelectKJson("[null, false, true, null, false, true]", 0); + this->AssertSelectKJson("[null, false, true, null, false, true]", 2); + this->AssertSelectKJson("[null, false, true, null, false, true]", 5); + this->AssertSelectKJson("[null, false, true, null, false, true]", 6); +} + +TYPED_TEST(TestSelectKForTemporal, Temporal) { + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 0); + this->AssertSelectKJson("[null, 1, 3, null, 
2, 5]", 2); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 5); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 6); +} + +TYPED_TEST(TestSelectKForDecimal, Decimal) { + const std::string values = R"(["123.45", null, "-123.45", "456.78", "-456.78"])"; + this->AssertSelectKJson(values, 0); + this->AssertSelectKJson(values, 2); + this->AssertSelectKJson(values, 4); + this->AssertSelectKJson(values, 5); +} + +TYPED_TEST(TestSelectKForStrings, Strings) { + this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 0); + this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 2); + this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 5); + this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 6); +} + +template <typename ArrowType> +class TestSelectKRandom : public TestSelectKBase<ArrowType> { + public: + std::shared_ptr<DataType> type_singleton() override { + EXPECT_TRUE(0) << "shouldn't be used"; + return nullptr; + } +}; + +using SelectKableNumericAndTemporal = + ::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type, + Int32Type, Int64Type, FloatType, DoubleType, Date32Type, Date64Type, + TimestampType, Time32Type, Time64Type>; + +using SelectKableTypes = + ::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type, + Int32Type, Int64Type, FloatType, DoubleType, Decimal128Type, + StringType>; + +TYPED_TEST_SUITE(TestSelectKRandom, SelectKableTypes); + +TYPED_TEST(TestSelectKRandom, RandomValues) { + Random<TypeParam> rand(0x61549225); + int length = 100; + for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) { + // Try n from 0 to out of bound + for (int n = 0; n <= length; ++n) { + auto array = rand.Generate(length, null_probability); + this->AssertTopKArray(array, n); + this->AssertBottomKArray(array, n); + } + } +} + +template <SortOrder order> +struct TestSelectKWithChunkedArray : public ::testing::Test { + 
TestSelectKWithChunkedArray() + : sizes_({0, 1, 2, 4, 16, 31, 1234}), + null_probabilities_({0.0, 0.1, 0.5, 0.9, 1.0}) {} + + void Check(const std::shared_ptr<DataType>& type, + const std::vector<std::string>& values, int64_t k, + const std::string& expected) { + std::shared_ptr<Array> actual; + ASSERT_OK(this->DoSelectK(type, values, k, &actual)); + ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual); + } + + void Check(const std::shared_ptr<DataType>& type, + const std::shared_ptr<ChunkedArray>& values, int64_t k, + const std::string& expected) { + ASSERT_OK_AND_ASSIGN(auto actual, SelectK<order>(*values, k)); + ValidateOutput(actual); + ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual); + } + + Status DoSelectK(const std::shared_ptr<DataType>& type, + const std::vector<std::string>& values, int64_t k, + std::shared_ptr<Array>* out) { + ARROW_ASSIGN_OR_RAISE(*out, SelectK<order>(*(ChunkedArrayFromJSON(type, values)), k)); + ValidateOutput(*out); + return Status::OK(); + } + std::vector<int32_t> sizes_; + std::vector<double> null_probabilities_; +}; + +template <typename ArrowType> +struct TestTopKWithChunkedArray + : public TestSelectKWithChunkedArray<SortOrder::Descending> { + std::shared_ptr<DataType> type_singleton() { + return default_type_instance<ArrowType>(); + } +}; + +TYPED_TEST_SUITE(TestTopKWithChunkedArray, SelectKableNumericAndTemporal); Review comment: For chunked array tests, you could get it "for free" by slicing a non-chunked array and building a chunked array to test. This is what the scalar kernel tests do. That way, you don't have to duplicate test cases. ########## File path: cpp/src/arrow/compute/kernels/select_k_test.cc ########## @@ -0,0 +1,714 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <functional> +#include <iostream> +#include <limits> +#include <memory> +#include <string> +#include <vector> + +#include "arrow/array/array_decimal.h" +#include "arrow/array/concatenate.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/kernels/test_util.h" +#include "arrow/table.h" +#include "arrow/testing/gtest_common.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/testing/util.h" +#include "arrow/type_traits.h" + +namespace arrow { + +using internal::checked_cast; +using internal::checked_pointer_cast; + +namespace compute { + +namespace { + +// Convert arrow::Type to arrow::DataType. If arrow::Type isn't +// parameter free, this returns an arrow::DataType with the default +// parameter. 
+template <typename ArrowType> +enable_if_t<TypeTraits<ArrowType>::is_parameter_free, std::shared_ptr<DataType>> +TypeToDataType() { + return TypeTraits<ArrowType>::type_singleton(); +} + +template <typename ArrowType> +enable_if_t<std::is_same<ArrowType, TimestampType>::value, std::shared_ptr<DataType>> +TypeToDataType() { + return timestamp(TimeUnit::MILLI); +} + +template <typename ArrowType> +enable_if_t<std::is_same<ArrowType, Time32Type>::value, std::shared_ptr<DataType>> +TypeToDataType() { + return time32(TimeUnit::MILLI); +} + +template <typename ArrowType> +enable_if_t<std::is_same<ArrowType, Time64Type>::value, std::shared_ptr<DataType>> +TypeToDataType() { + return time64(TimeUnit::NANO); +} + +// ---------------------------------------------------------------------- +// Tests for SelectK + +template <typename ArrayType> +auto GetLogicalValue(const ArrayType& array, uint64_t index) + -> decltype(array.GetView(index)) { + return array.GetView(index); +} + +Decimal128 GetLogicalValue(const Decimal128Array& array, uint64_t index) { + return Decimal128(array.Value(index)); +} + +Decimal256 GetLogicalValue(const Decimal256Array& array, uint64_t index) { + return Decimal256(array.Value(index)); +} + +} // namespace + +template <typename ArrayType, SortOrder order> +class SelectKComparator { + public: + template <typename Type> + bool operator()(const Type& lval, const Type& rval) { + if (is_floating_type<typename ArrayType::TypeClass>::value) { + // NaNs ordered after non-NaNs + if (rval != rval) return true; Review comment: Ok, I think I'm missing something - then why is it comparing rval to itself? ########## File path: cpp/src/arrow/compute/kernels/select_k_test.cc ########## @@ -0,0 +1,809 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <functional> +#include <iostream> +#include <limits> +#include <memory> +#include <string> +#include <vector> + +#include "arrow/array/array_decimal.h" +#include "arrow/array/concatenate.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/kernels/test_util.h" +#include "arrow/table.h" +#include "arrow/testing/gtest_common.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/testing/util.h" +#include "arrow/type_traits.h" + +namespace arrow { + +using internal::checked_cast; +using internal::checked_pointer_cast; + +namespace compute { + +namespace { +template <typename ArrayType> +auto GetLogicalValue(const ArrayType& array, uint64_t index) + -> decltype(array.GetView(index)) { + return array.GetView(index); +} + +Decimal128 GetLogicalValue(const Decimal128Array& array, uint64_t index) { + return Decimal128(array.Value(index)); +} + +} // namespace + +template <typename ArrayType, SortOrder order> +class SelectKComparator { + public: + template <typename Type> + bool operator()(const Type& lval, const Type& rval) { + if (is_floating_type<typename ArrayType::TypeClass>::value) { + if (rval != rval) return true; + if (lval != lval) return false; + } + if (order == SortOrder::Ascending) { + return lval <= rval; + } else { + return rval <= lval; + } + } +}; + +template <SortOrder order> +Result<std::shared_ptr<Array>> 
SelectK(const Array& values, int64_t k) { + if (order == SortOrder::Descending) { + return TopK(values, k); + } else { + return BottomK(values, k); + } +} +template <SortOrder order> +Result<std::shared_ptr<Array>> SelectK(const ChunkedArray& values, int64_t k) { + if (order == SortOrder::Descending) { + return TopK(values, k); + } else { + return BottomK(values, k); + } +} + +template <SortOrder order> +Result<std::shared_ptr<RecordBatch>> SelectK(const RecordBatch& values, + const SelectKOptions& options) { + if (order == SortOrder::Descending) { + ARROW_ASSIGN_OR_RAISE(auto out, TopK(Datum(values), options.k, options)); + return out.record_batch(); + } else { + ARROW_ASSIGN_OR_RAISE(auto out, BottomK(Datum(values), options.k, options)); + return out.record_batch(); + } +} + +template <SortOrder order> +Result<std::shared_ptr<Table>> SelectK(const Table& values, + const SelectKOptions& options) { + if (order == SortOrder::Descending) { + ARROW_ASSIGN_OR_RAISE(auto out, TopK(Datum(values), options.k, options)); + return out.table(); + } else { + ARROW_ASSIGN_OR_RAISE(auto out, BottomK(Datum(values), options.k, options)); + return out.table(); + } +} + +template <typename ArrowType> +class TestSelectKBase : public TestBase { + using ArrayType = typename TypeTraits<ArrowType>::ArrayType; + + protected: + void Validate(const ArrayType& array, int k, ArrayType& select_k, SortOrder order) { + ASSERT_OK_AND_ASSIGN(auto sorted_indices, SortIndices(array, order)); + + // head(k) + auto head_k_indices = sorted_indices->Slice(0, select_k.length()); + + // sorted_indices + ASSERT_OK_AND_ASSIGN(Datum sorted_datum, + Take(array, head_k_indices, TakeOptions::NoBoundsCheck())); + std::shared_ptr<Array> sorted_array_out = sorted_datum.make_array(); + + const ArrayType& sorted_array = *checked_pointer_cast<ArrayType>(sorted_array_out); + + if (k < array.length()) { + AssertArraysEqual(sorted_array, select_k); + } + } + template <SortOrder order> + void AssertSelectKArray(const 
std::shared_ptr<Array> values, int n) { + std::shared_ptr<Array> select_k; + ASSERT_OK_AND_ASSIGN(select_k, SelectK<order>(*values, n)); + ASSERT_EQ(select_k->data()->null_count, 0); + ValidateOutput(*select_k); + Validate(*checked_pointer_cast<ArrayType>(values), n, + *checked_pointer_cast<ArrayType>(select_k), order); + } + + void AssertTopKArray(const std::shared_ptr<Array> values, int n) { + AssertSelectKArray<SortOrder::Descending>(values, n); + } + void AssertBottomKArray(const std::shared_ptr<Array> values, int n) { + AssertSelectKArray<SortOrder::Descending>(values, n); + } + + void AssertSelectKJson(const std::string& values, int n) { + AssertTopKArray(ArrayFromJSON(type_singleton(), values), n); + AssertBottomKArray(ArrayFromJSON(type_singleton(), values), n); + } + + virtual std::shared_ptr<DataType> type_singleton() = 0; +}; + +template <typename ArrowType> +class TestSelectK : public TestSelectKBase<ArrowType> { + protected: + std::shared_ptr<DataType> type_singleton() override { + return default_type_instance<ArrowType>(); + } +}; + +template <typename ArrowType> +class TestSelectKForReal : public TestSelectK<ArrowType> {}; +TYPED_TEST_SUITE(TestSelectKForReal, RealArrowTypes); + +template <typename ArrowType> +class TestSelectKForIntegral : public TestSelectK<ArrowType> {}; +TYPED_TEST_SUITE(TestSelectKForIntegral, IntegralArrowTypes); + +template <typename ArrowType> +class TestSelectKForBool : public TestSelectK<ArrowType> {}; +TYPED_TEST_SUITE(TestSelectKForBool, ::testing::Types<BooleanType>); + +template <typename ArrowType> +class TestSelectKForTemporal : public TestSelectK<ArrowType> {}; +TYPED_TEST_SUITE(TestSelectKForTemporal, TemporalArrowTypes); + +template <typename ArrowType> +class TestSelectKForDecimal : public TestSelectKBase<ArrowType> { + std::shared_ptr<DataType> type_singleton() override { + return std::make_shared<ArrowType>(5, 2); + } +}; +TYPED_TEST_SUITE(TestSelectKForDecimal, DecimalArrowTypes); + +template <typename 
ArrowType> +class TestSelectKForStrings : public TestSelectK<ArrowType> {}; +TYPED_TEST_SUITE(TestSelectKForStrings, testing::Types<StringType>); + +TYPED_TEST(TestSelectKForReal, SelectKDoesNotProvideDefaultOptions) { + auto input = ArrayFromJSON(this->type_singleton(), "[null, 1, 3.3, null, 2, 5.3]"); + ASSERT_RAISES(Invalid, CallFunction("top_k", {input})); +} + +TYPED_TEST(TestSelectKForReal, Real) { + this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 0); + this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 2); + this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 5); + this->AssertSelectKJson("[null, 1, 3.3, null, 2, 5.3]", 6); + + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 0); + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 1); + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 2); + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 3); + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 4); + this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 3); + this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 4); +} + +TYPED_TEST(TestSelectKForIntegral, Integral) { + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 0); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 2); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 5); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 6); + + this->AssertSelectKJson("[2, 4, 5, 7, 8, 0, 9, 1, 3]", 5); +} + +TYPED_TEST(TestSelectKForBool, Bool) { + this->AssertSelectKJson("[null, false, true, null, false, true]", 0); + this->AssertSelectKJson("[null, false, true, null, false, true]", 2); + this->AssertSelectKJson("[null, false, true, null, false, true]", 5); + this->AssertSelectKJson("[null, false, true, null, false, true]", 6); +} + +TYPED_TEST(TestSelectKForTemporal, Temporal) { + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 0); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 2); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 5); + this->AssertSelectKJson("[null, 1, 3, null, 2, 5]", 
6); +} + +TYPED_TEST(TestSelectKForDecimal, Decimal) { + const std::string values = R"(["123.45", null, "-123.45", "456.78", "-456.78"])"; + this->AssertSelectKJson(values, 0); + this->AssertSelectKJson(values, 2); + this->AssertSelectKJson(values, 4); + this->AssertSelectKJson(values, 5); +} + +TYPED_TEST(TestSelectKForStrings, Strings) { + this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 0); + this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 2); + this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 5); + this->AssertSelectKJson(R"(["testing", null, "nth", "for", null, "strings"])", 6); +} + +template <typename ArrowType> +class TestSelectKRandom : public TestSelectKBase<ArrowType> { + public: + std::shared_ptr<DataType> type_singleton() override { + EXPECT_TRUE(0) << "shouldn't be used"; + return nullptr; + } +}; + +using SelectKableNumericAndTemporal = + ::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type, + Int32Type, Int64Type, FloatType, DoubleType, Date32Type, Date64Type, + TimestampType, Time32Type, Time64Type>; + +using SelectKableTypes = + ::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type, + Int32Type, Int64Type, FloatType, DoubleType, Decimal128Type, + StringType>; + +TYPED_TEST_SUITE(TestSelectKRandom, SelectKableTypes); + +TYPED_TEST(TestSelectKRandom, RandomValues) { + Random<TypeParam> rand(0x61549225); + int length = 100; + for (auto null_probability : {0.0, 0.1, 0.5, 1.0}) { + // Try n from 0 to out of bound + for (int n = 0; n <= length; ++n) { + auto array = rand.Generate(length, null_probability); + this->AssertTopKArray(array, n); + this->AssertBottomKArray(array, n); + } + } +} + +template <SortOrder order> +struct TestSelectKWithChunkedArray : public ::testing::Test { + TestSelectKWithChunkedArray() + : sizes_({0, 1, 2, 4, 16, 31, 1234}), + null_probabilities_({0.0, 0.1, 0.5, 0.9, 1.0}) {} + 
+ void Check(const std::shared_ptr<DataType>& type, + const std::vector<std::string>& values, int64_t k, + const std::string& expected) { + std::shared_ptr<Array> actual; + ASSERT_OK(this->DoSelectK(type, values, k, &actual)); + ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual); + } + + void Check(const std::shared_ptr<DataType>& type, + const std::shared_ptr<ChunkedArray>& values, int64_t k, + const std::string& expected) { + ASSERT_OK_AND_ASSIGN(auto actual, SelectK<order>(*values, k)); + ValidateOutput(actual); + ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected), *actual); + } + + Status DoSelectK(const std::shared_ptr<DataType>& type, + const std::vector<std::string>& values, int64_t k, + std::shared_ptr<Array>* out) { + ARROW_ASSIGN_OR_RAISE(*out, SelectK<order>(*(ChunkedArrayFromJSON(type, values)), k)); + ValidateOutput(*out); + return Status::OK(); + } + std::vector<int32_t> sizes_; + std::vector<double> null_probabilities_; +}; + +template <typename ArrowType> +struct TestTopKWithChunkedArray + : public TestSelectKWithChunkedArray<SortOrder::Descending> { + std::shared_ptr<DataType> type_singleton() { + return default_type_instance<ArrowType>(); + } +}; + +TYPED_TEST_SUITE(TestTopKWithChunkedArray, SelectKableNumericAndTemporal); Review comment: You could imagine doing the same thing for tables/record batches. 
########## File path: cpp/src/arrow/compute/kernels/vector_sort.cc ########## @@ -1778,6 +1799,736 @@ class SortIndicesMetaFunction : public MetaFunction { } }; +// ---------------------------------------------------------------------- +// TopK/BottomK implementations + +using SelectKOptionsState = internal::OptionsWrapper<SelectKOptions>; +const auto kDefaultTopKOptions = SelectKOptions::TopKDefault(); +const auto kDefaultBottomKOptions = SelectKOptions::BottomKDefault(); + +const FunctionDoc top_k_doc( + "Returns the first k elements ordered by `options.keys` in ascending order", + ("This function computes the k largest elements in ascending order of the input\n" + "array, record batch or table specified in the column names (`options.keys`). The\n" + "columns that are not specified are returned as well, but not used for ordering.\n" + "Null values are considered greater than any other value and are therefore sorted\n" + "at the end of the array.\n" + "For floating-point types, NaNs are considered greater than any\n" + "other non-null value, but smaller than null values."), + {"input"}, "SelectKOptions"); + +const FunctionDoc bottom_k_doc( + "Returns the first k elements ordered by `options.keys` in descending order", + ("This function computes the k smallest elements in descending order of the input\n" + "array, record batch or table specified in the column names (`options.keys`). 
The\n" + "columns that are not specified are returned as well, but not used for ordering.\n" + "Null values are considered greater than any other value and are therefore sorted\n" + "at the end of the array.\n" + "For floating-point types, NaNs are considered greater than any\n" + "other non-null value, but smaller than null values."), + {"input"}, "SelectKOptions"); + +Result<std::shared_ptr<ArrayData>> MakeMutableArrayForFixedSizedType( + std::shared_ptr<DataType> out_type, int64_t length, MemoryPool* memory_pool) { + auto buffer_size = BitUtil::BytesForBits( + length * std::static_pointer_cast<UInt64Type>(out_type)->bit_width()); + std::vector<std::shared_ptr<Buffer>> buffers(2); + ARROW_ASSIGN_OR_RAISE(buffers[1], AllocateResizableBuffer(buffer_size, memory_pool)); + auto out = std::make_shared<ArrayData>(out_type, length, buffers, 0); + return out; +} + +template <SortOrder order> +class SelectKComparator { + public: + template <typename Type> + bool operator()(const Type& lval, const Type& rval); +}; + +template <> +class SelectKComparator<SortOrder::Ascending> { + public: + template <typename Type> + bool operator()(const Type& lval, const Type& rval) { + return lval < rval; + } +}; + +template <> +class SelectKComparator<SortOrder::Descending> { + public: + template <typename Type> + bool operator()(const Type& lval, const Type& rval) { + return rval < lval; + } +}; + +template <SortOrder sort_order> +class ArraySelecter : public TypeVisitor { + public: + ArraySelecter(ExecContext* ctx, const Array& array, const SelectKOptions& options, + Datum* output) + : TypeVisitor(), + ctx_(ctx), + array_(array), + options_(options), + physical_type_(GetPhysicalType(array.type())), + output_(output) {} + + Status Run() { return VisitTypeInline(*physical_type_, this); } + +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); } + + VISIT_PHYSICAL_TYPES(VISIT) + +#undef VISIT + Status Visit(const DataType& type) { + return 
Status::TypeError("Unsupported type for ArraySelecter: ", type.ToString()); + } + + template <typename InType> + Status SelectKthInternal() { + using GetView = GetViewType<InType>; + using ArrayType = typename TypeTraits<InType>::ArrayType; + + ArrayType arr(array_.data()); + std::vector<uint64_t> indices(arr.length()); + + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + if (options_.k > arr.length()) { + options_.k = arr.length(); + } + auto end_iter = PartitionNulls<ArrayType, NonStablePartitioner>(indices_begin, + indices_end, arr, 0); + auto kth_begin = indices_begin + options_.k; + if (kth_begin > end_iter) { + kth_begin = end_iter; + } + std::function<bool(uint64_t, uint64_t)> cmp; + SelectKComparator<sort_order> comparator; + cmp = [&arr, &comparator](uint64_t left, uint64_t right) -> bool { + const auto lval = GetView::LogicalValue(arr.GetView(left)); + const auto rval = GetView::LogicalValue(arr.GetView(right)); + return comparator(lval, rval); + }; + arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp); + uint64_t* iter = indices_begin; + for (; iter != kth_begin; ++iter) { + heap.Push(*iter); + } + for (; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + const auto lval = GetView::LogicalValue(arr.GetView(x_index)); + const auto rval = GetView::LogicalValue(arr.GetView(heap.top())); + if (comparator(lval, rval)) { + heap.ReplaceTop(x_index); + } + } + if (options_.keep_duplicates == true) { + iter = indices_begin; + for (; iter != end_iter; ++iter) { + if (*iter != heap.top()) { + const auto lval = GetView::LogicalValue(arr.GetView(*iter)); + const auto rval = GetView::LogicalValue(arr.GetView(heap.top())); + if (lval == rval) { + heap.Push(*iter); + } + } + } + } + + int64_t out_size = static_cast<int64_t>(heap.size()); + ARROW_ASSIGN_OR_RAISE( + auto take_indices, + MakeMutableArrayForFixedSizedType(uint64(), out_size, 
ctx_->memory_pool())); + + auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size - 1; + while (heap.size() > 0) { + *out_cbegin = heap.top(); + heap.Pop(); + --out_cbegin; + } + ARROW_ASSIGN_OR_RAISE(*output_, Take(array_, Datum(std::move(take_indices)), + TakeOptions::NoBoundsCheck(), ctx_)); + return Status::OK(); + } + + ExecContext* ctx_; + const Array& array_; + SelectKOptions options_; + const std::shared_ptr<DataType> physical_type_; + Datum* output_; +}; + +template <typename ArrayType> +struct TypedHeapItem { + uint64_t index; + uint64_t offset; + ArrayType* array; +}; + +template <SortOrder sort_order> +class ChunkedArraySelecter : public TypeVisitor { + public: + ChunkedArraySelecter(ExecContext* ctx, const ChunkedArray& chunked_array, + const SelectKOptions& options, Datum* output) + : TypeVisitor(), + chunked_array_(chunked_array), + physical_type_(GetPhysicalType(chunked_array.type())), + physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)), + options_(options), + ctx_(ctx), + output_(output) {} + + Status Run() { return physical_type_->Accept(this); } + +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); } + + VISIT_PHYSICAL_TYPES(VISIT) + +#undef VISIT + + template <typename InType> + Status SelectKthInternal() { + using GetView = GetViewType<InType>; + using ArrayType = typename TypeTraits<InType>::ArrayType; + using HeapItem = TypedHeapItem<ArrayType>; + + const auto num_chunks = chunked_array_.num_chunks(); + if (num_chunks == 0) { + return Status::OK(); + } + if (options_.k > chunked_array_.length()) { + options_.k = chunked_array_.length(); + } + std::function<bool(const HeapItem&, const HeapItem&)> cmp; + SelectKComparator<sort_order> comparator; + + cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool { + const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); + const auto rval = 
GetView::LogicalValue(right.array->GetView(right.index)); + return comparator(lval, rval); + }; + arrow::internal::Heap<HeapItem, decltype(cmp)> heap(cmp); + std::vector<std::shared_ptr<ArrayType>> chunks_holder; + uint64_t offset = 0; + for (const auto& chunk : physical_chunks_) { + if (chunk->length() == 0) continue; + chunks_holder.emplace_back(std::make_shared<ArrayType>(chunk->data())); + ArrayType& arr = *chunks_holder[chunks_holder.size() - 1]; + + std::vector<uint64_t> indices(arr.length()); + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + + auto end_iter = PartitionNulls<ArrayType, NonStablePartitioner>( + indices_begin, indices_end, arr, 0); + auto kth_begin = indices_begin + options_.k; + + if (kth_begin > end_iter) { + kth_begin = end_iter; + } + uint64_t* iter = indices_begin; + for (; iter != kth_begin && heap.size() < static_cast<size_t>(options_.k); ++iter) { + heap.Push(HeapItem{*iter, offset, &arr}); + } + for (; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + const auto& xval = GetView::LogicalValue(arr.GetView(x_index)); + auto top_item = heap.top(); + const auto& top_value = + GetView::LogicalValue(top_item.array->GetView(top_item.index)); + if (comparator(xval, top_value)) { + heap.ReplaceTop(HeapItem{x_index, offset, &arr}); + } + } + offset += chunk->length(); + } + + if (options_.keep_duplicates == true) { + offset = 0; + for (const auto& chunk : chunks_holder) { + ArrayType& arr = *chunk; + + std::vector<uint64_t> indices(arr.length()); + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + + auto iter = indices_begin; + for (; iter != indices_end; ++iter) { + uint64_t x_index = *iter; + auto top_item = heap.top(); + if (x_index + offset != top_item.index + top_item.offset) { + const auto& xval = 
GetView::LogicalValue(arr.GetView(x_index)); + const auto& top_value = + GetView::LogicalValue(top_item.array->GetView(top_item.index)); + if (xval == top_value) { + heap.Push(HeapItem{x_index, offset, &arr}); + } + } + } + offset += chunk->length(); + } + } + + int64_t out_size = static_cast<int64_t>(heap.size()); + ARROW_ASSIGN_OR_RAISE( + auto take_indices, + MakeMutableArrayForFixedSizedType(uint64(), out_size, ctx_->memory_pool())); + auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size - 1; + while (heap.size() > 0) { + auto top_item = heap.top(); + *out_cbegin = top_item.index + top_item.offset; + heap.Pop(); + --out_cbegin; + } + ARROW_ASSIGN_OR_RAISE(auto chunked_select_k, + Take(Datum(chunked_array_), Datum(std::move(take_indices)), + TakeOptions::NoBoundsCheck(), ctx_)); + ARROW_ASSIGN_OR_RAISE( + auto select_k, + Concatenate(chunked_select_k.chunked_array()->chunks(), ctx_->memory_pool())); Review comment: Why do we concatenate here? If the input is a chunked array I would expect the output to also be a chunked array. ########## File path: cpp/src/arrow/compute/kernels/vector_sort.cc ########## @@ -1778,6 +1799,736 @@ class SortIndicesMetaFunction : public MetaFunction { } }; +// ---------------------------------------------------------------------- +// TopK/BottomK implementations + +using SelectKOptionsState = internal::OptionsWrapper<SelectKOptions>; +const auto kDefaultTopKOptions = SelectKOptions::TopKDefault(); +const auto kDefaultBottomKOptions = SelectKOptions::BottomKDefault(); + +const FunctionDoc top_k_doc( + "Returns the first k elements ordered by `options.keys` in ascending order", + ("This function computes the k largest elements in ascending order of the input\n" + "array, record batch or table specified in the column names (`options.keys`). 
The\n" + "columns that are not specified are returned as well, but not used for ordering.\n" + "Null values are considered greater than any other value and are therefore sorted\n" + "at the end of the array.\n" + "For floating-point types, NaNs are considered greater than any\n" + "other non-null value, but smaller than null values."), + {"input"}, "SelectKOptions"); + +const FunctionDoc bottom_k_doc( + "Returns the first k elements ordered by `options.keys` in descending order", + ("This function computes the k smallest elements in descending order of the input\n" + "array, record batch or table specified in the column names (`options.keys`). The\n" + "columns that are not specified are returned as well, but not used for ordering.\n" + "Null values are considered greater than any other value and are therefore sorted\n" + "at the end of the array.\n" + "For floating-point types, NaNs are considered greater than any\n" + "other non-null value, but smaller than null values."), + {"input"}, "SelectKOptions"); + +Result<std::shared_ptr<ArrayData>> MakeMutableArrayForFixedSizedType( + std::shared_ptr<DataType> out_type, int64_t length, MemoryPool* memory_pool) { + auto buffer_size = BitUtil::BytesForBits( + length * std::static_pointer_cast<UInt64Type>(out_type)->bit_width()); + std::vector<std::shared_ptr<Buffer>> buffers(2); + ARROW_ASSIGN_OR_RAISE(buffers[1], AllocateResizableBuffer(buffer_size, memory_pool)); + auto out = std::make_shared<ArrayData>(out_type, length, buffers, 0); + return out; +} + +template <SortOrder order> +class SelectKComparator { + public: + template <typename Type> + bool operator()(const Type& lval, const Type& rval); +}; + +template <> +class SelectKComparator<SortOrder::Ascending> { + public: + template <typename Type> + bool operator()(const Type& lval, const Type& rval) { + return lval < rval; + } +}; + +template <> +class SelectKComparator<SortOrder::Descending> { + public: + template <typename Type> + bool operator()(const Type& lval, 
const Type& rval) { + return rval < lval; + } +}; + +template <SortOrder sort_order> +class ArraySelecter : public TypeVisitor { + public: + ArraySelecter(ExecContext* ctx, const Array& array, const SelectKOptions& options, + Datum* output) + : TypeVisitor(), + ctx_(ctx), + array_(array), + options_(options), + physical_type_(GetPhysicalType(array.type())), + output_(output) {} + + Status Run() { return VisitTypeInline(*physical_type_, this); } + +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); } + + VISIT_PHYSICAL_TYPES(VISIT) + +#undef VISIT + Status Visit(const DataType& type) { + return Status::TypeError("Unsupported type for ArraySelecter: ", type.ToString()); + } + + template <typename InType> + Status SelectKthInternal() { + using GetView = GetViewType<InType>; + using ArrayType = typename TypeTraits<InType>::ArrayType; + + ArrayType arr(array_.data()); + std::vector<uint64_t> indices(arr.length()); + + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + if (options_.k > arr.length()) { + options_.k = arr.length(); + } + auto end_iter = PartitionNulls<ArrayType, NonStablePartitioner>(indices_begin, + indices_end, arr, 0); + auto kth_begin = indices_begin + options_.k; + if (kth_begin > end_iter) { + kth_begin = end_iter; + } + std::function<bool(uint64_t, uint64_t)> cmp; + SelectKComparator<sort_order> comparator; + cmp = [&arr, &comparator](uint64_t left, uint64_t right) -> bool { + const auto lval = GetView::LogicalValue(arr.GetView(left)); + const auto rval = GetView::LogicalValue(arr.GetView(right)); + return comparator(lval, rval); + }; + arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp); + uint64_t* iter = indices_begin; + for (; iter != kth_begin; ++iter) { + heap.Push(*iter); + } + for (; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + const auto lval = 
GetView::LogicalValue(arr.GetView(x_index)); + const auto rval = GetView::LogicalValue(arr.GetView(heap.top())); + if (comparator(lval, rval)) { + heap.ReplaceTop(x_index); + } + } + if (options_.keep_duplicates == true) { + iter = indices_begin; + for (; iter != end_iter; ++iter) { + if (*iter != heap.top()) { + const auto lval = GetView::LogicalValue(arr.GetView(*iter)); + const auto rval = GetView::LogicalValue(arr.GetView(heap.top())); + if (lval == rval) { + heap.Push(*iter); + } + } + } + } + + int64_t out_size = static_cast<int64_t>(heap.size()); + ARROW_ASSIGN_OR_RAISE( + auto take_indices, + MakeMutableArrayForFixedSizedType(uint64(), out_size, ctx_->memory_pool())); + + auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size - 1; + while (heap.size() > 0) { + *out_cbegin = heap.top(); + heap.Pop(); + --out_cbegin; + } + ARROW_ASSIGN_OR_RAISE(*output_, Take(array_, Datum(std::move(take_indices)), + TakeOptions::NoBoundsCheck(), ctx_)); + return Status::OK(); + } + + ExecContext* ctx_; + const Array& array_; + SelectKOptions options_; + const std::shared_ptr<DataType> physical_type_; + Datum* output_; +}; + +template <typename ArrayType> +struct TypedHeapItem { + uint64_t index; + uint64_t offset; + ArrayType* array; +}; + +template <SortOrder sort_order> +class ChunkedArraySelecter : public TypeVisitor { + public: + ChunkedArraySelecter(ExecContext* ctx, const ChunkedArray& chunked_array, + const SelectKOptions& options, Datum* output) + : TypeVisitor(), + chunked_array_(chunked_array), + physical_type_(GetPhysicalType(chunked_array.type())), + physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)), + options_(options), + ctx_(ctx), + output_(output) {} + + Status Run() { return physical_type_->Accept(this); } + +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); } + + VISIT_PHYSICAL_TYPES(VISIT) + +#undef VISIT + + template <typename InType> + Status SelectKthInternal() { + using 
GetView = GetViewType<InType>; + using ArrayType = typename TypeTraits<InType>::ArrayType; + using HeapItem = TypedHeapItem<ArrayType>; + + const auto num_chunks = chunked_array_.num_chunks(); + if (num_chunks == 0) { + return Status::OK(); + } + if (options_.k > chunked_array_.length()) { + options_.k = chunked_array_.length(); + } + std::function<bool(const HeapItem&, const HeapItem&)> cmp; + SelectKComparator<sort_order> comparator; + + cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool { + const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); + const auto rval = GetView::LogicalValue(right.array->GetView(right.index)); + return comparator(lval, rval); + }; + arrow::internal::Heap<HeapItem, decltype(cmp)> heap(cmp); + std::vector<std::shared_ptr<ArrayType>> chunks_holder; + uint64_t offset = 0; + for (const auto& chunk : physical_chunks_) { + if (chunk->length() == 0) continue; + chunks_holder.emplace_back(std::make_shared<ArrayType>(chunk->data())); + ArrayType& arr = *chunks_holder[chunks_holder.size() - 1]; + + std::vector<uint64_t> indices(arr.length()); + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + + auto end_iter = PartitionNulls<ArrayType, NonStablePartitioner>( + indices_begin, indices_end, arr, 0); + auto kth_begin = indices_begin + options_.k; + + if (kth_begin > end_iter) { + kth_begin = end_iter; + } + uint64_t* iter = indices_begin; + for (; iter != kth_begin && heap.size() < static_cast<size_t>(options_.k); ++iter) { + heap.Push(HeapItem{*iter, offset, &arr}); + } + for (; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + const auto& xval = GetView::LogicalValue(arr.GetView(x_index)); + auto top_item = heap.top(); + const auto& top_value = + GetView::LogicalValue(top_item.array->GetView(top_item.index)); + if (comparator(xval, top_value)) { + heap.ReplaceTop(HeapItem{x_index, 
offset, &arr}); + } + } + offset += chunk->length(); + } + + if (options_.keep_duplicates == true) { + offset = 0; + for (const auto& chunk : chunks_holder) { + ArrayType& arr = *chunk; + + std::vector<uint64_t> indices(arr.length()); + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + + auto iter = indices_begin; + for (; iter != indices_end; ++iter) { + uint64_t x_index = *iter; + auto top_item = heap.top(); + if (x_index + offset != top_item.index + top_item.offset) { + const auto& xval = GetView::LogicalValue(arr.GetView(x_index)); + const auto& top_value = + GetView::LogicalValue(top_item.array->GetView(top_item.index)); + if (xval == top_value) { + heap.Push(HeapItem{x_index, offset, &arr}); + } + } + } + offset += chunk->length(); + } + } + + int64_t out_size = static_cast<int64_t>(heap.size()); + ARROW_ASSIGN_OR_RAISE( + auto take_indices, + MakeMutableArrayForFixedSizedType(uint64(), out_size, ctx_->memory_pool())); + auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size - 1; + while (heap.size() > 0) { + auto top_item = heap.top(); + *out_cbegin = top_item.index + top_item.offset; + heap.Pop(); + --out_cbegin; + } + ARROW_ASSIGN_OR_RAISE(auto chunked_select_k, + Take(Datum(chunked_array_), Datum(std::move(take_indices)), + TakeOptions::NoBoundsCheck(), ctx_)); + ARROW_ASSIGN_OR_RAISE( + auto select_k, + Concatenate(chunked_select_k.chunked_array()->chunks(), ctx_->memory_pool())); + *output_ = Datum(select_k); + return Status::OK(); + } + + const ChunkedArray& chunked_array_; + const std::shared_ptr<DataType> physical_type_; + const ArrayVector physical_chunks_; + SelectKOptions options_; + ExecContext* ctx_; + Datum* output_; +}; + +template <SortOrder sort_order> +class RecordBatchSelecter : public TypeVisitor { + private: + using ResolvedSortKey = MultipleKeyRecordBatchSorter::ResolvedSortKey; + using Comparator = 
MultipleKeyComparator<ResolvedSortKey>; + + public: + RecordBatchSelecter(ExecContext* ctx, const RecordBatch& record_batch, + const SelectKOptions& options, Datum* output) + : TypeVisitor(), + ctx_(ctx), + record_batch_(record_batch), + options_(options), + output_(output), + sort_keys_(ResolveSortKeys(record_batch, options.keys, options.order, &status_)), + comparator_(sort_keys_) {} + + Status Run() { + ARROW_RETURN_NOT_OK(status_); + return sort_keys_[0].type->Accept(this); + } + + protected: +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); } + + VISIT_PHYSICAL_TYPES(VISIT) + +#undef VISIT + + static std::vector<ResolvedSortKey> ResolveSortKeys( + const RecordBatch& batch, const std::vector<std::string>& sort_keys, + SortOrder order, Status* status) { + std::vector<ResolvedSortKey> resolved; + for (const auto& key_name : sort_keys) { + auto array = batch.GetColumnByName(key_name); + if (!array) { + *status = Status::Invalid("Nonexistent sort key column: ", key_name); + break; + } + resolved.emplace_back(array, order); + } + return resolved; + } + + template <typename InType> + Status SelectKthInternal() { + using GetView = GetViewType<InType>; + using ArrayType = typename TypeTraits<InType>::ArrayType; + auto& comparator = comparator_; + const auto& first_sort_key = sort_keys_[0]; + const ArrayType& arr = checked_cast<const ArrayType&>(first_sort_key.array); + + const auto num_rows = record_batch_.num_rows(); + if (num_rows == 0) { + return Status::OK(); + } + if (options_.k > record_batch_.num_rows()) { + options_.k = record_batch_.num_rows(); + } + std::function<bool(const uint64_t&, const uint64_t&)> cmp; + SelectKComparator<sort_order> select_k_comparator; + cmp = [&](const uint64_t& left, const uint64_t& right) -> bool { + const auto lval = GetView::LogicalValue(arr.GetView(left)); + const auto rval = GetView::LogicalValue(arr.GetView(right)); + if (lval == rval) { + // If the left value equals to the right value, 
+ // we need to compare the second and following + // sort keys. + return comparator.Compare(left, right, 1); + } + return select_k_comparator(lval, rval); + }; + arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp); + + std::vector<uint64_t> indices(arr.length()); + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + + auto end_iter = PartitionNulls<ArrayType, NonStablePartitioner>(indices_begin, + indices_end, arr, 0); + auto kth_begin = indices_begin + options_.k; + + if (kth_begin > end_iter) { + kth_begin = end_iter; + } + uint64_t* iter = indices_begin; + for (; iter != kth_begin; ++iter) { + heap.Push(*iter); + } + for (; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + auto top_item = heap.top(); + if (cmp(x_index, top_item)) { + heap.ReplaceTop(x_index); + } + } + if (options_.keep_duplicates == true) { + iter = indices_begin; + for (; iter != end_iter; ++iter) { + uint64_t x_index = *iter; + auto top_item = heap.top(); + if (x_index != top_item) { + const auto& xval = GetView::LogicalValue(arr.GetView(x_index)); + const auto& top_value = GetView::LogicalValue(arr.GetView(top_item)); + if (xval == top_value && comparator.Equals(x_index, top_item, 1)) { + heap.Push(x_index); + } + } + } + } + int64_t out_size = static_cast<int64_t>(heap.size()); + ARROW_ASSIGN_OR_RAISE( + auto take_indices, + MakeMutableArrayForFixedSizedType(uint64(), out_size, ctx_->memory_pool())); + auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size - 1; + while (heap.size() > 0) { + *out_cbegin = heap.top(); + heap.Pop(); + --out_cbegin; + } + ARROW_ASSIGN_OR_RAISE(*output_, + Take(Datum(record_batch_), Datum(std::move(take_indices)), + TakeOptions::NoBoundsCheck(), ctx_)); + return Status::OK(); + } + + ExecContext* ctx_; + const RecordBatch& record_batch_; + SelectKOptions options_; + Datum* output_; + std::vector<ResolvedSortKey> 
sort_keys_; + Comparator comparator_; + Status status_; +}; + +template <SortOrder sort_order> +class TableSelecter : public TypeVisitor { + private: + using ResolvedSortKey = MultipleKeyTableSorter::ResolvedSortKey; + using Comparator = MultipleKeyComparator<ResolvedSortKey>; + + public: + TableSelecter(ExecContext* ctx, const Table& table, const SelectKOptions& options, + Datum* output) + : TypeVisitor(), + ctx_(ctx), + table_(table), + options_(options), + output_(output), + sort_keys_(ResolveSortKeys(table, options.keys, options.order, &status_)), + comparator_(sort_keys_) {} + + Status Run() { + ARROW_RETURN_NOT_OK(status_); + return sort_keys_[0].type->Accept(this); + } + + protected: +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { return SelectKthInternal<TYPE>(); } + + VISIT_PHYSICAL_TYPES(VISIT) + +#undef VISIT + + static std::vector<ResolvedSortKey> ResolveSortKeys( + const Table& table, const std::vector<std::string>& sort_keys, SortOrder order, + Status* status) { + std::vector<ResolvedSortKey> resolved; + for (const auto& key_name : sort_keys) { + auto chunked_array = table.GetColumnByName(key_name); + if (!chunked_array) { + *status = Status::Invalid("Nonexistent sort key column: ", key_name); + break; + } + resolved.emplace_back(*chunked_array, order); + } + return resolved; + } + + // Behaves like PatitionNulls() but this supports multiple sort keys. + // + // For non-float types. 
+ template <typename Type> + enable_if_t<!is_floating_type<Type>::value, uint64_t*> PartitionNullsInternal( + uint64_t* indices_begin, uint64_t* indices_end, + const ResolvedSortKey& first_sort_key) { + using ArrayType = typename TypeTraits<Type>::ArrayType; + if (first_sort_key.null_count == 0) { + return indices_end; + } + StablePartitioner partitioner; + auto nulls_begin = + partitioner(indices_begin, indices_end, [&first_sort_key](uint64_t index) { + const auto chunk = first_sort_key.GetChunk<ArrayType>((int64_t)index); + return !chunk.IsNull(); + }); + DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count); + auto& comparator = comparator_; + std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t right) { + return comparator.Compare(left, right, 1); + }); + return nulls_begin; + } + + // Behaves like PatitionNulls() but this supports multiple sort keys. + // + // For float types. + template <typename Type> + enable_if_t<is_floating_type<Type>::value, uint64_t*> PartitionNullsInternal( + uint64_t* indices_begin, uint64_t* indices_end, + const ResolvedSortKey& first_sort_key) { + using ArrayType = typename TypeTraits<Type>::ArrayType; + StablePartitioner partitioner; + uint64_t* nulls_begin; + if (first_sort_key.null_count == 0) { + nulls_begin = indices_end; + } else { + nulls_begin = partitioner(indices_begin, indices_end, [&](uint64_t index) { + const auto chunk = first_sort_key.GetChunk<ArrayType>(index); + return !chunk.IsNull(); + }); + } + DCHECK_EQ(indices_end - nulls_begin, first_sort_key.null_count); + uint64_t* nans_begin = partitioner(indices_begin, nulls_begin, [&](uint64_t index) { + const auto chunk = first_sort_key.GetChunk<ArrayType>(index); + return !std::isnan(chunk.Value()); + }); + auto& comparator = comparator_; + // Sort all NaNs by the second and following sort keys. 
+ std::stable_sort(nans_begin, nulls_begin, [&](uint64_t left, uint64_t right) { + return comparator.Compare(left, right, 1); + }); + // Sort all nulls by the second and following sort keys. + std::stable_sort(nulls_begin, indices_end, [&](uint64_t left, uint64_t right) { + return comparator.Compare(left, right, 1); + }); + return nans_begin; + } + + template <typename InType> + Status SelectKthInternal() { + using ArrayType = typename TypeTraits<InType>::ArrayType; + auto& comparator = comparator_; + const auto& first_sort_key = sort_keys_[0]; + + const auto num_rows = table_.num_rows(); + if (num_rows == 0) { + return Status::OK(); + } + if (options_.k > table_.num_rows()) { + options_.k = table_.num_rows(); + } + std::function<bool(const uint64_t&, const uint64_t&)> cmp; + SelectKComparator<sort_order> select_k_comparator; + cmp = [&](const uint64_t& left, const uint64_t& right) -> bool { + auto chunk_left = first_sort_key.GetChunk<ArrayType>(left); + auto chunk_right = first_sort_key.GetChunk<ArrayType>(right); + auto value_left = chunk_left.Value(); + auto value_right = chunk_right.Value(); + if (value_left == value_right) { + return comparator.Compare(left, right, 1); + } + return select_k_comparator(value_left, value_right); + }; + arrow::internal::Heap<uint64_t, decltype(cmp)> heap(cmp); + + std::vector<uint64_t> indices(num_rows); + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + + auto end_iter = + this->PartitionNullsInternal<InType>(indices_begin, indices_end, first_sort_key); + auto kth_begin = indices_begin + options_.k; + + if (kth_begin > end_iter) { + kth_begin = end_iter; + } + uint64_t* iter = indices_begin; + for (; iter != kth_begin; ++iter) { + heap.Push(*iter); + } + for (; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + uint64_t top_item = heap.top(); + if (cmp(x_index, top_item)) { + heap.ReplaceTop(x_index); + } + } 
+ if (options_.keep_duplicates == true) { + iter = indices_begin; + for (; iter != end_iter; ++iter) { + uint64_t x_index = *iter; + auto top_item = heap.top(); + if (x_index != top_item) { + auto chunk_left = first_sort_key.GetChunk<ArrayType>(x_index); + auto chunk_right = first_sort_key.GetChunk<ArrayType>(top_item); + auto xval = chunk_left.Value(); + auto top_value = chunk_right.Value(); + if (xval == top_value && comparator.Equals(x_index, top_item, 1)) { + heap.Push(x_index); + } + } + } + } + int64_t out_size = static_cast<int64_t>(heap.size()); + ARROW_ASSIGN_OR_RAISE( + auto take_indices, + MakeMutableArrayForFixedSizedType(uint64(), out_size, ctx_->memory_pool())); + auto* out_cbegin = take_indices->GetMutableValues<uint64_t>(1) + out_size - 1; + while (heap.size() > 0) { + *out_cbegin = heap.top(); + heap.Pop(); + --out_cbegin; + } + ARROW_ASSIGN_OR_RAISE(*output_, Take(Datum(table_), Datum(std::move(take_indices)), + TakeOptions::NoBoundsCheck(), ctx_)); + return Status::OK(); + } + + ExecContext* ctx_; + const Table& table_; + SelectKOptions options_; + Datum* output_; + std::vector<ResolvedSortKey> sort_keys_; + Comparator comparator_; + Status status_; +}; + +template <SortOrder sort_order> Review comment: Another thought: instead of a single SortOrder, why not accept SortKeys like the sort kernels? Then you could emulate a query like `... ORDER BY foo ASC, bar DESC LIMIT 5` which this kernel can't currently do. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org