pitrou commented on a change in pull request #11257: URL: https://github.com/apache/arrow/pull/11257#discussion_r718328803
########## File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc ########## @@ -121,6 +122,84 @@ Result<std::unique_ptr<KernelState>> CountInit(KernelContext*, static_cast<const CountOptions&>(*args.options)); } +// ---------------------------------------------------------------------- +// Distinct Count implementation + +template <typename Type> +struct CountDistinctImpl : public ScalarAggregator { + using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType; + + explicit CountDistinctImpl(MemoryPool* memory_pool, CountOptions options) + : options(std::move(options)), memo_table_(new MemoTable(memory_pool, 0)) {} + + Status Consume(KernelContext*, const ExecBatch& batch) override { + if (batch[0].is_array()) { + const ArrayData& arr = *batch[0].array(); + auto visit_null = [&]() { + if (this->nulls > 0) return Status::OK(); + ++this->nulls; + return Status::OK(); + }; + auto visit_value = [&](typename Type::c_type arg) { + int y; + RETURN_NOT_OK(memo_table_->GetOrInsert(arg, &y)); + return Status::OK(); + }; + RETURN_NOT_OK(VisitArrayDataInline<Type>(arr, visit_value, visit_null)); + this->non_nulls += this->memo_table_->size(); + } else { + const Scalar& input = *batch[0].scalar(); + this->nulls += !input.is_valid * batch.length; + this->non_nulls += input.is_valid * batch.length; + } + return Status::OK(); + } + + Status MergeFrom(KernelContext*, KernelState&& src) override { + const auto& other_state = checked_cast<const CountDistinctImpl&>(src); + this->non_nulls += other_state.non_nulls; + this->nulls += other_state.nulls; + return Status::OK(); + } + + Status Finalize(KernelContext* ctx, Datum* out) override { + const auto& state = checked_cast<const CountDistinctImpl&>(*ctx->state()); + switch (state.options.mode) { + case CountOptions::ONLY_VALID: + *out = Datum(state.non_nulls); + break; + case CountOptions::ALL: + *out = Datum(state.non_nulls + state.nulls); + break; + case CountOptions::ONLY_NULL: + *out = Datum(state.nulls); + 
break; + default: + DCHECK(false) << "unreachable"; + } + return Status::OK(); + } + + CountOptions options; + int64_t non_nulls = 0; Review comment: Make this `bool has_nulls = false` ########## File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc ########## @@ -754,6 +839,30 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) { aggregate::CountInit, func.get()); DCHECK_OK(registry->AddFunction(std::move(func))); + func = std::make_shared<ScalarAggregateFunction>( + "count_distinct", Arity::Unary(), &count_distinct_doc, &default_count_options); + + // Takes any input, outputs int64 scalar + aggregate::AddCountDistinctKernel<Int8Type>(int8(), func.get()); + aggregate::AddCountDistinctKernel<Int16Type>(int16(), func.get()); + aggregate::AddCountDistinctKernel<Int32Type>(int32(), func.get()); + aggregate::AddCountDistinctKernel<Date32Type>(date32(), func.get()); + aggregate::AddCountDistinctKernel<Int64Type>(int64(), func.get()); + aggregate::AddCountDistinctKernel<UInt8Type>(uint8(), func.get()); + aggregate::AddCountDistinctKernel<UInt16Type>(uint16(), func.get()); + aggregate::AddCountDistinctKernel<UInt32Type>(uint32(), func.get()); + aggregate::AddCountDistinctKernel<UInt64Type>(uint64(), func.get()); + aggregate::AddCountDistinctKernel<FloatType>(float32(), func.get()); + aggregate::AddCountDistinctKernel<DoubleType>(float64(), func.get()); + aggregate::AddCountDistinctKernel<Time32Type>(time32(TimeUnit::SECOND), func.get()); + aggregate::AddCountDistinctKernel<Time32Type>(time32(TimeUnit::MILLI), func.get()); + aggregate::AddCountDistinctKernel<Time64Type>(time64(TimeUnit::MICRO), func.get()); + aggregate::AddCountDistinctKernel<Time64Type>(time64(TimeUnit::NANO), func.get()); Review comment: You should take a look at how this is done for other kernels instead of enumerating types explicitly. For example, see `SumLikeInit`. 
########## File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc ########## @@ -121,6 +122,84 @@ Result<std::unique_ptr<KernelState>> CountInit(KernelContext*, static_cast<const CountOptions&>(*args.options)); } +// ---------------------------------------------------------------------- +// Distinct Count implementation + +template <typename Type> +struct CountDistinctImpl : public ScalarAggregator { + using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType; + + explicit CountDistinctImpl(MemoryPool* memory_pool, CountOptions options) + : options(std::move(options)), memo_table_(new MemoTable(memory_pool, 0)) {} + + Status Consume(KernelContext*, const ExecBatch& batch) override { + if (batch[0].is_array()) { + const ArrayData& arr = *batch[0].array(); + auto visit_null = [&]() { + if (this->nulls > 0) return Status::OK(); + ++this->nulls; + return Status::OK(); + }; Review comment: Right. Basically you can have: ```c++ this->has_nulls = arr.GetNullCount() > 0; auto visit_null = []() {}; auto visit_value = ... ``` ########## File path: cpp/src/arrow/compute/kernels/aggregate_test.cc ########## @@ -873,6 +873,65 @@ TYPED_TEST(TestRandomNumericCountKernel, RandomArrayCount) { } } +// +// Count Distinct +// + +class TestCountDistinctKernel : public ::testing::Test { + protected: + void SetUp() override { + only_valid = CountOptions(CountOptions::ONLY_VALID); + only_null = CountOptions(CountOptions::ONLY_NULL); + all = CountOptions(CountOptions::ALL); + } + + const Datum& expected(int64_t value) { + expected_values[value] = Datum(static_cast<int64_t>(value)); + return expected_values.at(value); Review comment: Indeed, this looks like gratuitous complication. 
########## File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc ########## @@ -754,6 +839,30 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) { aggregate::CountInit, func.get()); DCHECK_OK(registry->AddFunction(std::move(func))); + func = std::make_shared<ScalarAggregateFunction>( + "count_distinct", Arity::Unary(), &count_distinct_doc, &default_count_options); + + // Takes any input, outputs int64 scalar + aggregate::AddCountDistinctKernel<Int8Type>(int8(), func.get()); + aggregate::AddCountDistinctKernel<Int16Type>(int16(), func.get()); + aggregate::AddCountDistinctKernel<Int32Type>(int32(), func.get()); + aggregate::AddCountDistinctKernel<Date32Type>(date32(), func.get()); + aggregate::AddCountDistinctKernel<Int64Type>(int64(), func.get()); Review comment: Or use `Type::PhysicalType` when instantiating the kernel class. ########## File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc ########## @@ -121,6 +122,84 @@ Result<std::unique_ptr<KernelState>> CountInit(KernelContext*, static_cast<const CountOptions&>(*args.options)); } +// ---------------------------------------------------------------------- +// Distinct Count implementation + +template <typename Type> +struct CountDistinctImpl : public ScalarAggregator { + using MemoTable = typename arrow::internal::HashTraits<Type>::MemoTableType; + + explicit CountDistinctImpl(MemoryPool* memory_pool, CountOptions options) + : options(std::move(options)), memo_table_(new MemoTable(memory_pool, 0)) {} + + Status Consume(KernelContext*, const ExecBatch& batch) override { + if (batch[0].is_array()) { + const ArrayData& arr = *batch[0].array(); + auto visit_null = [&]() { + if (this->nulls > 0) return Status::OK(); + ++this->nulls; + return Status::OK(); + }; + auto visit_value = [&](typename Type::c_type arg) { + int y; + RETURN_NOT_OK(memo_table_->GetOrInsert(arg, &y)); + return Status::OK(); + }; + RETURN_NOT_OK(VisitArrayDataInline<Type>(arr, visit_value, visit_null)); + this->non_nulls 
+= this->memo_table_->size(); + } else { + const Scalar& input = *batch[0].scalar(); + this->nulls += !input.is_valid * batch.length; + this->non_nulls += input.is_valid * batch.length; + } + return Status::OK(); + } + + Status MergeFrom(KernelContext*, KernelState&& src) override { + const auto& other_state = checked_cast<const CountDistinctImpl&>(src); + this->non_nulls += other_state.non_nulls; + this->nulls += other_state.nulls; + return Status::OK(); + } + + Status Finalize(KernelContext* ctx, Datum* out) override { + const auto& state = checked_cast<const CountDistinctImpl&>(*ctx->state()); + switch (state.options.mode) { + case CountOptions::ONLY_VALID: + *out = Datum(state.non_nulls); + break; + case CountOptions::ALL: + *out = Datum(state.non_nulls + state.nulls); + break; + case CountOptions::ONLY_NULL: + *out = Datum(state.nulls); + break; + default: + DCHECK(false) << "unreachable"; + } + return Status::OK(); + } + + CountOptions options; Review comment: Make this `const`? ########## File path: cpp/src/arrow/compute/kernels/aggregate_test.cc ########## @@ -873,6 +873,65 @@ TYPED_TEST(TestRandomNumericCountKernel, RandomArrayCount) { } } +// +// Count Distinct +// + +class TestCountDistinctKernel : public ::testing::Test { + protected: + void SetUp() override { + only_valid = CountOptions(CountOptions::ONLY_VALID); + only_null = CountOptions(CountOptions::ONLY_NULL); + all = CountOptions(CountOptions::ALL); + } + + const Datum& expected(int64_t value) { + expected_values[value] = Datum(static_cast<int64_t>(value)); + return expected_values.at(value); + } + + CountOptions only_valid; + CountOptions only_null; + CountOptions all; + + private: + std::map<int64_t, Datum> expected_values; +}; + +TEST_F(TestCountDistinctKernel, NumericArrowTypesWithNulls) { + auto sample = "[1, 1, 2, 2, 5, 8, 9, 9, 9, 10, 6, 6]"; + auto sample_nulls = "[null, 8, null, null, 6, null, 8]"; + for (auto ty : NumericTypes()) { + auto input = ArrayFromJSON(ty, sample); + 
CheckScalar("count_distinct", {input}, expected(7), &only_valid); + CheckScalar("count_distinct", {input}, expected(0), &only_null); + CheckScalar("count_distinct", {input}, expected(7), &all); + auto input_nulls = ArrayFromJSON(ty, sample_nulls); + CheckScalar("count_distinct", {input_nulls}, expected(2), &only_valid); + CheckScalar("count_distinct", {input_nulls}, expected(1), &only_null); + CheckScalar("count_distinct", {input_nulls}, expected(3), &all); + } +} + +TEST_F(TestCountDistinctKernel, RandomValidsStdMap) { + UInt32Builder builder; + std::map<uint32_t, int64_t> hashmap; + auto visit_null = [&]() { return Status::OK(); }; + auto visit_value = [&](uint32_t arg) { + if (hashmap.count(arg) == 0) { + hashmap[arg] = 0; + RETURN_NOT_OK(builder.Append(arg)); + } + ++hashmap[arg]; + return Status::OK(); + }; + auto rand = random::RandomArrayGenerator(0x1205643); + auto arr = rand.Numeric<UInt32Type>(1024, 0, 100, 0.0)->data(); + auto r = VisitArrayDataInline<UInt32Type>(*arr, visit_value, visit_null); + auto input = builder.Finish().ValueOrDie(); + CheckScalar("count_distinct", {input}, expected(hashmap.size()), &all); Review comment: Just use `MakeScalar(static_cast<int64_t>(hashmap.size()))`? 
########## File path: cpp/src/arrow/compute/kernels/aggregate_test.cc ########## @@ -873,6 +873,65 @@ TYPED_TEST(TestRandomNumericCountKernel, RandomArrayCount) { } } +// +// Count Distinct +// + +class TestCountDistinctKernel : public ::testing::Test { + protected: + void SetUp() override { + only_valid = CountOptions(CountOptions::ONLY_VALID); + only_null = CountOptions(CountOptions::ONLY_NULL); + all = CountOptions(CountOptions::ALL); + } + + const Datum& expected(int64_t value) { + expected_values[value] = Datum(static_cast<int64_t>(value)); + return expected_values.at(value); + } + + CountOptions only_valid; + CountOptions only_null; + CountOptions all; + + private: + std::map<int64_t, Datum> expected_values; +}; + +TEST_F(TestCountDistinctKernel, NumericArrowTypesWithNulls) { + auto sample = "[1, 1, 2, 2, 5, 8, 9, 9, 9, 10, 6, 6]"; + auto sample_nulls = "[null, 8, null, null, 6, null, 8]"; + for (auto ty : NumericTypes()) { + auto input = ArrayFromJSON(ty, sample); + CheckScalar("count_distinct", {input}, expected(7), &only_valid); + CheckScalar("count_distinct", {input}, expected(0), &only_null); + CheckScalar("count_distinct", {input}, expected(7), &all); + auto input_nulls = ArrayFromJSON(ty, sample_nulls); + CheckScalar("count_distinct", {input_nulls}, expected(2), &only_valid); + CheckScalar("count_distinct", {input_nulls}, expected(1), &only_null); + CheckScalar("count_distinct", {input_nulls}, expected(3), &all); + } +} + +TEST_F(TestCountDistinctKernel, RandomValidsStdMap) { + UInt32Builder builder; + std::map<uint32_t, int64_t> hashmap; Review comment: You shouldn't call this `hashmap` if you're using a plain `map`... Also, the values are unused; just make it an `unordered_set<uint32_t>`. 
########## File path: cpp/src/arrow/compute/kernels/aggregate_test.cc ########## @@ -873,6 +873,65 @@ TYPED_TEST(TestRandomNumericCountKernel, RandomArrayCount) { } } +// +// Count Distinct +// + +class TestCountDistinctKernel : public ::testing::Test { + protected: + void SetUp() override { + only_valid = CountOptions(CountOptions::ONLY_VALID); + only_null = CountOptions(CountOptions::ONLY_NULL); + all = CountOptions(CountOptions::ALL); + } + + const Datum& expected(int64_t value) { + expected_values[value] = Datum(static_cast<int64_t>(value)); + return expected_values.at(value); + } + + CountOptions only_valid; + CountOptions only_null; + CountOptions all; + + private: + std::map<int64_t, Datum> expected_values; +}; + +TEST_F(TestCountDistinctKernel, NumericArrowTypesWithNulls) { + auto sample = "[1, 1, 2, 2, 5, 8, 9, 9, 9, 10, 6, 6]"; + auto sample_nulls = "[null, 8, null, null, 6, null, 8]"; + for (auto ty : NumericTypes()) { + auto input = ArrayFromJSON(ty, sample); + CheckScalar("count_distinct", {input}, expected(7), &only_valid); + CheckScalar("count_distinct", {input}, expected(0), &only_null); + CheckScalar("count_distinct", {input}, expected(7), &all); + auto input_nulls = ArrayFromJSON(ty, sample_nulls); + CheckScalar("count_distinct", {input_nulls}, expected(2), &only_valid); + CheckScalar("count_distinct", {input_nulls}, expected(1), &only_null); + CheckScalar("count_distinct", {input_nulls}, expected(3), &all); + } +} + +TEST_F(TestCountDistinctKernel, RandomValidsStdMap) { + UInt32Builder builder; + std::map<uint32_t, int64_t> hashmap; + auto visit_null = [&]() { return Status::OK(); }; + auto visit_value = [&](uint32_t arg) { + if (hashmap.count(arg) == 0) { Review comment: This can be made simpler: ```c++ const bool inserted = hashmap.insert(arg).second; if (inserted) { return builder.Append(arg); } return Status::OK(); ``` -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org. For queries about this service, please contact Infrastructure at: us...@infra.apache.org.