This is an automated email from the ASF dual-hosted git repository. wesm pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push: new af4db7731b ARROW-16807: [C++][R] count distinct incorrectly merges state (#13583) af4db7731b is described below commit af4db7731b1f857e78221c53c2d8221849b1eeec Author: octalene <octalene....@pm.me> AuthorDate: Sat Jul 16 14:45:27 2022 -0700 ARROW-16807: [C++][R] count distinct incorrectly merges state (#13583) This addresses a bug where the `count_distinct` function simply added counts when merging state. The correct logic would be to return the number of distinct elements after both states have been merged. State for `count_distinct` is backed by a MemoTable, which is in turn backed by a HashTable. To properly merge state, this PR adds 2 functions to each MemoTable: `MaybeInsert` and `MergeTable`. The MaybeInsert function handles simplified logic for inserting an element into the MemoTable. The MergeTable function handles iteration over elements in the MemoTable _to be merged_. [Note: the diff quoted below adds only `MergeTable` to each MemoTable; it inserts each visited element through the existing `GetOrInsert`, and no `MaybeInsert` function appears in the diff.] This PR also adds an R test and a C++ test. The R test mirrors what was provided in ARROW-16807. The C++ test, `AllChunkedArrayTypesWithNulls`, mirrors another C++ test, `AllArrayTypesWithNulls`, but uses chunked arrays for test data. 
Lead-authored-by: Aldrin Montana <octalene....@pm.me> Co-authored-by: Aldrin M <octalene....@pm.me> Co-authored-by: Wes McKinney <w...@apache.org> Signed-off-by: Wes McKinney <w...@apache.org> --- cpp/src/arrow/compute/kernels/aggregate_basic.cc | 17 ++++-- cpp/src/arrow/compute/kernels/aggregate_test.cc | 72 ++++++++++++++++++++++++ cpp/src/arrow/compute/kernels/codegen_internal.h | 2 +- cpp/src/arrow/util/hashing.h | 32 +++++++++++ r/tests/testthat/test-dplyr-summarize.R | 9 +++ 5 files changed, 126 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc index 57cee87f00..fec483318e 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -136,27 +136,34 @@ struct CountDistinctImpl : public ScalarAggregator { Status Consume(KernelContext*, const ExecBatch& batch) override { if (batch[0].is_array()) { const ArrayData& arr = *batch[0].array(); + this->has_nulls = arr.GetNullCount() > 0; + auto visit_null = []() { return Status::OK(); }; auto visit_value = [&](VisitorArgType arg) { - int y; + int32_t y; return memo_table_->GetOrInsert(arg, &y); }; RETURN_NOT_OK(VisitArraySpanInline<Type>(arr, visit_value, visit_null)); - this->non_nulls += memo_table_->size(); - this->has_nulls = arr.GetNullCount() > 0; + } else { const Scalar& input = *batch[0].scalar(); this->has_nulls = !input.is_valid; + if (input.is_valid) { - this->non_nulls += batch.length; + int32_t unused; + RETURN_NOT_OK(memo_table_->GetOrInsert(UnboxScalar<Type>::Unbox(input), &unused)); } } + + this->non_nulls = memo_table_->size(); + return Status::OK(); } Status MergeFrom(KernelContext*, KernelState&& src) override { const auto& other_state = checked_cast<const CountDistinctImpl&>(src); - this->non_nulls += other_state.non_nulls; + RETURN_NOT_OK(this->memo_table_->MergeTable(*(other_state.memo_table_))); + this->non_nulls = this->memo_table_->size(); 
this->has_nulls = this->has_nulls || other_state.has_nulls; return Status::OK(); } diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index aa54fe5f3e..abd5b5210a 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -962,11 +962,83 @@ class TestCountDistinctKernel : public ::testing::Test { EXPECT_THAT(CallFunction("count_distinct", {input}, &all), one); } + void CheckChunkedArr(const std::shared_ptr<DataType>& type, + const std::vector<std::string>& json, int64_t expected_all, + bool has_nulls = true) { + Check(ChunkedArrayFromJSON(type, json), expected_all, has_nulls); + } + CountOptions only_valid{CountOptions::ONLY_VALID}; CountOptions only_null{CountOptions::ONLY_NULL}; CountOptions all{CountOptions::ALL}; }; +TEST_F(TestCountDistinctKernel, AllChunkedArrayTypesWithNulls) { + // Boolean + CheckChunkedArr(boolean(), {"[]", "[]"}, 0, /*has_nulls=*/false); + CheckChunkedArr(boolean(), {"[true, null]", "[false, null, false]", "[true]"}, 3); + + // Number + for (auto ty : NumericTypes()) { + CheckChunkedArr(ty, {"[1, 1, null, 2]", "[5, 8, 9, 9, null, 10]", "[6, 6, 8, 9, 10]"}, + 8); + CheckChunkedArr(ty, {"[1, 1, 8, 2]", "[5, 8, 9, 9, 10]", "[10, 6, 6]"}, 7, + /*has_nulls=*/false); + } + + // Date + CheckChunkedArr(date32(), {"[0, 11016]", "[0, null, 14241, 14241, null]"}, 4); + CheckChunkedArr(date64(), {"[0, null]", "[0, null, 0, 0, 1262217600000]"}, 3); + + // Time + CheckChunkedArr(time32(TimeUnit::SECOND), {"[ 0, 11, 0, null]", "[14, 14, null]"}, 4); + CheckChunkedArr(time32(TimeUnit::MILLI), {"[ 0, 11000, 0]", "[null, 11000, 11000]"}, 3); + + CheckChunkedArr(time64(TimeUnit::MICRO), {"[84203999999, 0, null, 84203999999]", "[0]"}, + 3); + CheckChunkedArr(time64(TimeUnit::NANO), + {"[11715003000000, 0, null, 0, 0]", "[0, 0, null]"}, 3); + + // Timestamp & Duration + for (auto u : TimeUnit::values()) { + CheckChunkedArr(duration(u), 
{"[123456789, null, 987654321]", "[123456789, null]"}, + 3); + + CheckChunkedArr(duration(u), + {"[123456789, 987654321, 123456789, 123456789]", "[123456789]"}, 2, + /*has_nulls=*/false); + + auto ts = + std::vector<std::string>{R"(["2009-12-31T04:20:20", "2009-12-31T04:20:20"])", + R"(["2020-01-01", null])", R"(["2020-01-01", null])"}; + CheckChunkedArr(timestamp(u), ts, 3); + CheckChunkedArr(timestamp(u, "Pacific/Marquesas"), ts, 3); + } + + // Interval + CheckChunkedArr(month_interval(), {"[9012, 5678, null, 9012]", "[5678, null, 9012]"}, + 3); + CheckChunkedArr(day_time_interval(), + {"[[0, 1], [0, 1]]", "[null, [0, 1], [1234, 5678]]"}, 3); + CheckChunkedArr(month_day_nano_interval(), + {"[[0, 1, 2]]", "[[0, 1, 2], null, [0, 1, 2]]"}, 2); + + // Binary & String & Fixed binary + auto samples = std::vector<std::string>{ + R"([null, "abc", null])", R"(["abc", "abc", "cba"])", R"(["bca", "cba", null])"}; + + CheckChunkedArr(binary(), samples, 4); + CheckChunkedArr(large_binary(), samples, 4); + CheckChunkedArr(utf8(), samples, 4); + CheckChunkedArr(large_utf8(), samples, 4); + CheckChunkedArr(fixed_size_binary(3), samples, 4); + + // Decimal + samples = {R"(["12345.679", "98765.421"])", R"([null, "12345.679", "98765.421"])"}; + CheckChunkedArr(decimal128(21, 3), samples, 3); + CheckChunkedArr(decimal256(13, 3), samples, 3); +} + TEST_F(TestCountDistinctKernel, AllArrayTypesWithNulls) { // Boolean Check(boolean(), "[]", 0, /*has_nulls=*/false); diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 1d5f5dd9bd..f008314e8b 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -343,7 +343,7 @@ struct UnboxScalar<Type, enable_if_has_string_view<Type>> { using T = util::string_view; static T Unbox(const Scalar& val) { if (!val.is_valid) return util::string_view(); - return util::string_view(*checked_cast<const BaseBinaryScalar&>(val).value); + 
return checked_cast<const ::arrow::internal::PrimitiveScalarBase&>(val).view(); } }; diff --git a/cpp/src/arrow/util/hashing.h b/cpp/src/arrow/util/hashing.h index d2c0178b00..ca5a6c766b 100644 --- a/cpp/src/arrow/util/hashing.h +++ b/cpp/src/arrow/util/hashing.h @@ -485,6 +485,20 @@ class ScalarMemoTable : public MemoTable { hash_t ComputeHash(const Scalar& value) const { return ScalarHelper<Scalar, 0>::ComputeHash(value); } + + public: + // defined here so that `HashTableType` is visible + // Merge entries from `other_table` into `this->hash_table_`. + Status MergeTable(const ScalarMemoTable& other_table) { + const HashTableType& other_hashtable = other_table.hash_table_; + + other_hashtable.VisitEntries([this](const HashTableEntry* other_entry) { + int32_t unused; + DCHECK_OK(this->GetOrInsert(other_entry->payload.value, &unused)); + }); + // TODO: ARROW-17074 - implement proper error handling + return Status::OK(); + } }; // ---------------------------------------------------------------------- @@ -568,6 +582,15 @@ class SmallScalarMemoTable : public MemoTable { // (which is also 1 + the largest memo index) int32_t size() const override { return static_cast<int32_t>(index_to_value_.size()); } + // Merge entries from `other_table` into `this`. 
+ Status MergeTable(const SmallScalarMemoTable& other_table) { + for (const Scalar& other_val : other_table.index_to_value_) { + int32_t unused; + RETURN_NOT_OK(this->GetOrInsert(other_val, &unused)); + } + return Status::OK(); + } + // Copy values starting from index `start` into `out_data` void CopyValues(int32_t start, Scalar* out_data) const { DCHECK_GE(start, 0); @@ -824,6 +847,15 @@ class BinaryMemoTable : public MemoTable { }; return hash_table_.Lookup(h, cmp_func); } + + public: + Status MergeTable(const BinaryMemoTable& other_table) { + other_table.VisitValues(0, [this](const util::string_view& other_value) { + int32_t unused; + DCHECK_OK(this->GetOrInsert(other_value, &unused)); + }); + return Status::OK(); + } }; template <typename T, typename Enable = void> diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R index c2207a1f27..3711b49975 100644 --- a/r/tests/testthat/test-dplyr-summarize.R +++ b/r/tests/testthat/test-dplyr-summarize.R @@ -236,6 +236,15 @@ test_that("Group by any/all", { ) }) +test_that("n_distinct() with many batches", { + tf <- tempfile() + write_parquet(dplyr::starwars, tf, chunk_size = 20) + + ds <- open_dataset(tf) + expect_equal(ds %>% summarise(n_distinct(sex, na.rm = FALSE)) %>% collect(), + ds %>% collect() %>% summarise(n_distinct(sex, na.rm = FALSE))) +}) + test_that("n_distinct() on dataset", { # With group_by compare_dplyr_binding(