This is an automated email from the ASF dual-hosted git repository.

pitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 293730128c GH-45193: [C++][Compute] Treat NaNs and nulls as distinct 
values in rank tie-breaking (#49304)
293730128c is described below

commit 293730128c3dbe7b59fa4ca0de8acb8bc3bbf018
Author: Abhishek Bansal <[email protected]>
AuthorDate: Tue Apr 21 13:18:12 2026 +0530

    GH-45193: [C++][Compute] Treat NaNs and nulls as distinct values in rank 
tie-breaking (#49304)
    
    ### Rationale for this change
    The rank kernel incorrectly treated NaNs and nulls as tied with each other. 
This fix ensures they are ranked as distinct values, in accordance with Arrow's 
sorting conventions.
    
    ### What changes are included in this PR?
    Updated the internal MarkDuplicates helper in vector_rank.cc to distinguish 
between NaNs and Nulls.
    
    ### Are these changes tested?
    Yes. Added a regression test TestRank.NaNsAndNulls in vector_sort_test.cc 
and verified all compute tests pass.
    
    ### Are there any user-facing changes?
    The output of the rank function now correctly differentiates between 
NaNs and nulls instead of ranking them as ties. This fixes incorrect ranking 
results for inputs containing both NaNs and nulls.
    
    * GitHub Issue: #45193
    
    Authored-by: Abhishek Bansal <[email protected]>
    Signed-off-by: Antoine Pitrou <[email protected]>
---
 cpp/src/arrow/compute/kernels/vector_rank.cc      | 30 ++++++++++++++++++-----
 cpp/src/arrow/compute/kernels/vector_sort_test.cc | 30 +++++++++++++++++++++++
 2 files changed, 54 insertions(+), 6 deletions(-)

diff --git a/cpp/src/arrow/compute/kernels/vector_rank.cc 
b/cpp/src/arrow/compute/kernels/vector_rank.cc
index adac794902..cb8c79953e 100644
--- a/cpp/src/arrow/compute/kernels/vector_rank.cc
+++ b/cpp/src/arrow/compute/kernels/vector_rank.cc
@@ -36,8 +36,9 @@ namespace {
 // is the same as the value at the previous sort index.
 constexpr uint64_t kDuplicateMask = 1ULL << 63;
 
-template <typename ValueSelector>
-void MarkDuplicates(const NullPartitionResult& sorted, ValueSelector&& 
value_selector) {
+template <typename ValueSelector, typename IsNullSelector>
+void MarkDuplicates(const NullPartitionResult& sorted, ValueSelector&& 
value_selector,
+                    IsNullSelector&& is_null_selector) {
   using T = decltype(value_selector(int64_t{}));
 
   // Process non-nulls
@@ -55,10 +56,14 @@ void MarkDuplicates(const NullPartitionResult& sorted, 
ValueSelector&& value_sel
 
   // Process nulls
   if (sorted.nulls_end != sorted.nulls_begin) {
-    // TODO this should be able to distinguish between NaNs and real nulls 
(GH-45193)
     auto it = sorted.nulls_begin;
+    bool prev_is_null = is_null_selector(*it);
     while (++it < sorted.nulls_end) {
-      *it |= kDuplicateMask;
+      bool curr_is_null = is_null_selector(*it);
+      if (curr_is_null == prev_is_null) {
+        *it |= kDuplicateMask;
+      }
+      prev_is_null = curr_is_null;
     }
   }
 }
@@ -82,7 +87,12 @@ Result<NullPartitionResult> DoSortAndMarkDuplicate(
     auto value_selector = [&array](int64_t index) {
       return GetView::LogicalValue(array.GetView(index));
     };
-    MarkDuplicates(sorted, value_selector);
+    if constexpr (has_null_like_values<ArrowType>()) {
+      auto is_null_selector = [&array](int64_t index) { return 
array.IsNull(index); };
+      MarkDuplicates(sorted, value_selector, is_null_selector);
+    } else {
+      MarkDuplicates(sorted, value_selector, [](int64_t) { return true; });
+    }
   }
   return sorted;
 }
@@ -105,7 +115,15 @@ Result<NullPartitionResult> DoSortAndMarkDuplicate(
                                
ChunkedArrayResolver(std::span(arrays))](int64_t index) {
       return resolver.Resolve(index).Value<ArrowType>();
     };
-    MarkDuplicates(sorted, value_selector);
+    if constexpr (has_null_like_values<ArrowType>()) {
+      auto is_null_selector =
+          [resolver = ChunkedArrayResolver(std::span(arrays))](int64_t index) {
+            return resolver.Resolve(index).IsNull();
+          };
+      MarkDuplicates(sorted, value_selector, is_null_selector);
+    } else {
+      MarkDuplicates(sorted, value_selector, [](int64_t) { return true; });
+    }
   }
   return sorted;
 }
diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc 
b/cpp/src/arrow/compute/kernels/vector_sort_test.cc
index e18fcf3771..cd0baff3dd 100644
--- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc
+++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc
@@ -2362,6 +2362,36 @@ TEST_F(TestRank, Real) {
   }
 }
 
+TEST_F(TestRank, NaNsAndNulls) {
+  auto type = float64();
+  auto array = ArrayFromJSON(type, "[1.0, null, NaN, 2.0, NaN, null]");
+  SetInput(array);
+
+  // Sorted order (at_end): [1.0, 2.0, NaN, NaN, null, null]
+  // Ranks (min): [1, 5, 3, 2, 3, 5]
+  auto expected_at_end = ArrayFromJSON(uint64(), "[1, 5, 3, 2, 3, 5]");
+  AssertRank(SortOrder::Ascending, NullPlacement::AtEnd, RankOptions::Min,
+             expected_at_end);
+
+  // Sorted order (at_start): [null, null, NaN, NaN, 1.0, 2.0]
+  // Ranks (min): [5, 1, 3, 6, 3, 1]
+  auto expected_at_start = ArrayFromJSON(uint64(), "[5, 1, 3, 6, 3, 1]");
+  AssertRank(SortOrder::Ascending, NullPlacement::AtStart, RankOptions::Min,
+             expected_at_start);
+
+  // Sorted order (descending, at_end): [2.0, 1.0, NaN, NaN, null, null]
+  // Ranks (min): [2, 5, 3, 1, 3, 5]
+  auto expected_desc_at_end = ArrayFromJSON(uint64(), "[2, 5, 3, 1, 3, 5]");
+  AssertRank(SortOrder::Descending, NullPlacement::AtEnd, RankOptions::Min,
+             expected_desc_at_end);
+
+  // Sorted order (descending, at_start): [null, null, NaN, NaN, 2.0, 1.0]
+  // Ranks (min): [6, 1, 3, 5, 3, 1]
+  auto expected_desc_at_start = ArrayFromJSON(uint64(), "[6, 1, 3, 5, 3, 1]");
+  AssertRank(SortOrder::Descending, NullPlacement::AtStart, RankOptions::Min,
+             expected_desc_at_start);
+}
+
 TEST_F(TestRank, Integral) {
   for (auto integer_type : ::arrow::IntTypes()) {
     SetInput(ArrayFromJSON(integer_type, "[2, 3, 1, 0, 5]"));

Reply via email to