bkietz commented on a change in pull request #9621:
URL: https://github.com/apache/arrow/pull/9621#discussion_r599009321



##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
##########
@@ -0,0 +1,728 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <algorithm>
+#include <limits>
+#include <memory>
+#include <type_traits>
+#include <unordered_map>
+#include <utility>
+
+#include <gtest/gtest.h>
+
+#include "arrow/array.h"
+#include "arrow/chunked_array.h"
+#include "arrow/compute/api_aggregate.h"
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/cast.h"
+#include "arrow/compute/kernels/aggregate_internal.h"
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/compute/registry.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bitmap_reader.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/int_util_internal.h"
+
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::BitmapReader;
+using internal::checked_cast;
+using internal::checked_pointer_cast;
+
+namespace compute {
+namespace {
+
+/// Test helper that packs a vector of same-typed scalars into one array.
+///
+/// The array type is taken from the first scalar and dispatched with
+/// VisitTypeInline: primitive C types and types with a string_view representation
+/// go through the matching builder; structs are assembled field by field by
+/// recursing on the children. Other types yield NotImplemented.
+struct ScalarVectorToArray {
+  /// Append each scalar with `append` (or a null if the scalar is invalid) to a
+  /// builder for T, then finish the builder into data_.
+  template <typename T, typename AppendScalar,
+            typename BuilderType = typename TypeTraits<T>::BuilderType,
+            typename ScalarType = typename TypeTraits<T>::ScalarType>
+  Status UseBuilder(const AppendScalar& append) {
+    BuilderType builder(type_, default_memory_pool());
+    for (const auto& s : scalars_) {
+      if (s->is_valid) {
+        RETURN_NOT_OK(append(checked_cast<const ScalarType&>(*s), &builder));
+      } else {
+        RETURN_NOT_OK(builder.AppendNull());
+      }
+    }
+    return builder.FinishInternal(&data_);
+  }
+
+  /// Appends the scalar's `value` member as-is.
+  struct AppendValue {
+    template <typename BuilderType, typename ScalarType>
+    Status operator()(const ScalarType& s, BuilderType* builder) const {
+      return builder->Append(s.value);
+    }
+  };
+
+  /// Appends the bytes of the scalar's `value` buffer as a string_view.
+  struct AppendBuffer {
+    template <typename BuilderType, typename ScalarType>
+    Status operator()(const ScalarType& s, BuilderType* builder) const {
+      const Buffer& buffer = *s.value;
+      return builder->Append(util::string_view{buffer});
+    }
+  };
+
+  template <typename T>
+  enable_if_primitive_ctype<T, Status> Visit(const T&) {
+    return UseBuilder<T>(AppendValue{});
+  }
+
+  template <typename T>
+  enable_if_has_string_view<T, Status> Visit(const T&) {
+    return UseBuilder<T>(AppendBuffer{});
+  }
+
+  Status Visit(const StructType& type) {
+    const int num_fields = type.num_fields();
+
+    // The parent array below gets no validity bitmap, so null struct scalars
+    // cannot be represented; a null StructScalar is also not guaranteed to hold
+    // one child per field. Check up front instead of indexing `value` out of
+    // bounds in the gathering loop.
+    for (const auto& s : scalars_) {
+      if (!s->is_valid) {
+        return Status::NotImplemented("ScalarVectorToArray with null struct scalars");
+      }
+      const auto& children = checked_cast<const StructScalar&>(*s).value;
+      if (children.size() != static_cast<size_t>(num_fields)) {
+        return Status::Invalid("StructScalar has ", children.size(),
+                               " children but its type has ", num_fields, " fields");
+      }
+    }
+
+    data_ = ArrayData::Make(type_, static_cast<int64_t>(scalars_.size()),
+                            {/*null_bitmap=*/nullptr});
+    ScalarVector field_scalars(scalars_.size());
+
+    for (int field_index = 0; field_index < num_fields; ++field_index) {
+      // Gather this field's child scalar from every struct, then convert them
+      // recursively into the child array.
+      for (size_t i = 0; i < scalars_.size(); ++i) {
+        field_scalars[i] =
+            checked_cast<const StructScalar&>(*scalars_[i]).value[field_index];
+      }
+
+      ARROW_ASSIGN_OR_RAISE(Datum field, ScalarVectorToArray{}.Convert(field_scalars));
+      data_->child_data.push_back(field.array());
+    }
+    return Status::OK();
+  }
+
+  Status Visit(const DataType& type) {
+    return Status::NotImplemented("ScalarVectorToArray for type ", type);
+  }
+
+  /// Convert `scalars` into a single array whose type is that of the first scalar.
+  ///
+  /// Must be called on a temporary, e.g. `ScalarVectorToArray{}.Convert(scalars)`.
+  /// Returns NotImplemented for an empty vector or an unsupported type, and
+  /// TypeError if the scalars do not all share the first scalar's type.
+  Result<Datum> Convert(ScalarVector scalars) && {
+    if (scalars.empty()) {
+      return Status::NotImplemented("ScalarVectorToArray with no scalars");
+    }
+    scalars_ = std::move(scalars);
+    type_ = scalars_[0]->type;
+    // The visitors downcast every scalar to the first scalar's type, so a mixed
+    // vector must be rejected before dispatching.
+    for (const auto& s : scalars_) {
+      if (!s->type->Equals(*type_)) {
+        return Status::TypeError("ScalarVectorToArray: scalar of type ", *s->type,
+                                 " differs from the first scalar's type ", *type_);
+      }
+    }
+    RETURN_NOT_OK(VisitTypeInline(*type_, this));
+    return Datum(std::move(data_));
+  }
+
+  std::shared_ptr<DataType> type_;
+  ScalarVector scalars_;
+  std::shared_ptr<ArrayData> data_;
+};
+
+/// Reference implementation of GroupBy, used to validate the hash aggregate kernels.
+///
+/// The keys are grouped with a Grouper; each argument is then split by group and
+/// the scalar aggregate named by `aggregates[i].function` minus its "hash_" prefix
+/// (e.g. "hash_sum" -> "sum") is called on every group's slice.
+/// Returns a StructArray holding one column per aggregate followed by one column
+/// per key (the grouper's uniques), all with empty field names.
+/// Returns Invalid if `arguments` and `aggregates` differ in size or a function
+/// name lacks the "hash_" prefix.
+Result<Datum> NaiveGroupBy(std::vector<Datum> arguments, std::vector<Datum> keys,
+                           const std::vector<internal::Aggregate>& aggregates) {
+  // arguments[i] is aggregated with aggregates[i]; a size mismatch would index
+  // aggregates out of bounds.
+  if (arguments.size() != aggregates.size()) {
+    return Status::Invalid("NaiveGroupBy: got ", arguments.size(), " arguments but ",
+                           aggregates.size(), " aggregates");
+  }
+
+  ARROW_ASSIGN_OR_RAISE(auto key_batch, ExecBatch::Make(std::move(keys)));
+
+  ARROW_ASSIGN_OR_RAISE(auto grouper,
+                        internal::Grouper::Make(key_batch.GetDescriptors()));
+
+  // One uint32 group id per key row.
+  ARROW_ASSIGN_OR_RAISE(Datum id_batch, grouper->Consume(key_batch));
+
+  ARROW_ASSIGN_OR_RAISE(
+      auto groupings, internal::Grouper::MakeGroupings(*id_batch.array_as<UInt32Array>(),
+                                                       grouper->num_groups()));
+
+  const std::string hash_prefix = "hash_";
+  ArrayVector out_columns;
+
+  for (size_t i = 0; i < arguments.size(); ++i) {
+    const std::string& hash_function = aggregates[i].function;
+    // Map the hash aggregate to its scalar counterpart by trimming the "hash_"
+    // prefix; reject names without it instead of calling a bogus function name.
+    if (hash_function.compare(0, hash_prefix.size(), hash_prefix) != 0) {
+      return Status::Invalid("NaiveGroupBy: expected a \"hash_\" function, got \"",
+                             hash_function, "\"");
+    }
+    const std::string scalar_agg_function = hash_function.substr(hash_prefix.size());
+
+    // Split arguments[i] by group; value_slice(g) below is taken as group g's rows.
+    ARROW_ASSIGN_OR_RAISE(
+        auto grouped_argument,
+        internal::Grouper::ApplyGroupings(*groupings, *arguments[i].make_array()));
+
+    ScalarVector aggregated_scalars;
+
+    for (int64_t i_group = 0; i_group < grouper->num_groups(); ++i_group) {
+      auto slice = grouped_argument->value_slice(i_group);
+      // NOTE(review): skipping an empty group would misalign this column with the
+      // key columns; presumably unreachable since every group comes from a key row.
+      if (slice->length() == 0) continue;
+      ARROW_ASSIGN_OR_RAISE(
+          Datum d, CallFunction(scalar_agg_function, {slice}, aggregates[i].options));
+      aggregated_scalars.push_back(d.scalar());
+    }
+
+    ARROW_ASSIGN_OR_RAISE(Datum aggregated_column,
+                          ScalarVectorToArray{}.Convert(aggregated_scalars));
+    out_columns.push_back(aggregated_column.make_array());
+  }
+
+  // Key columns: the distinct keys, indexed by group id.
+  ARROW_ASSIGN_OR_RAISE(auto uniques, grouper->GetUniques());
+  for (const Datum& key : uniques.values) {
+    out_columns.push_back(key.make_array());
+  }
+
+  std::vector<std::string> out_names(out_columns.size(), "");
+  return StructArray::Make(std::move(out_columns), std::move(out_names));
+}
+
+/// Assert that GroupBy produces exactly what the naive reference implementation
+/// produces for the same aggregates, arguments and keys.
+void ValidateGroupBy(const std::vector<internal::Aggregate>& aggregates,
+                     std::vector<Datum> arguments, std::vector<Datum> keys) {
+  // The reference result is computed first; if it fails, this helper returns early.
+  ASSERT_OK_AND_ASSIGN(Datum naive_result, NaiveGroupBy(arguments, keys, aggregates));
+  ASSERT_OK_AND_ASSIGN(Datum grouped_result, GroupBy(arguments, keys, aggregates));
+  AssertDatumsEqual(naive_result, grouped_result, /*verbose=*/true);
+}
+
+}  // namespace
+
+TEST(Grouper, SupportedKeys) {
+  // Key type combinations Grouper::Make must accept, checked in this order.
+  std::vector<std::vector<ValueDescr>> supported = {
+      {boolean()},
+      {int8(), uint16(), int32(), uint64()},
+      {dictionary(int64(), utf8())},
+      {float16(), float32(), float64()},
+      {utf8(), binary(), large_utf8(), large_binary()},
+      {fixed_size_binary(16), fixed_size_binary(32)},
+      {decimal128(32, 10), decimal256(76, 20)},
+      {date32(), date64()},
+  };
+  for (auto unit : internal::AllTimeUnits()) {
+    supported.push_back({timestamp(unit), duration(unit)});
+  }
+  supported.push_back({day_time_interval(), month_interval()});
+
+  for (const auto& descrs : supported) {
+    ASSERT_OK(internal::Grouper::Make(descrs));
+  }
+
+  // Nested key types must be rejected with NotImplemented.
+  const std::vector<std::vector<ValueDescr>> unsupported = {
+      {struct_({field("", int64())})},
+      {struct_({})},
+      {list(int32())},
+      {fixed_size_list(int32(), 5)},
+      {dense_union({field("", int32())})},
+  };
+  for (const auto& descrs : unsupported) {
+    ASSERT_RAISES(NotImplemented, internal::Grouper::Make(descrs));
+  }
+}
+
+/// Test harness wrapping a Grouper built for a fixed list of key descriptors.
+///
+/// Every batch consumed through it is validated: the grouper's uniques may only
+/// grow by appending (previously seen uniques stay a prefix), and taking the
+/// uniques at the returned ids must reproduce the consumed keys.
+struct TestGrouper {
+  explicit TestGrouper(std::vector<ValueDescr> descrs) : descrs_(std::move(descrs)) {
+    grouper_ = internal::Grouper::Make(descrs_).ValueOrDie();
+
+    // One unnamed field per key, used to parse JSON key batches.
+    FieldVector fields;
+    for (const auto& descr : descrs_) {
+      fields.push_back(field("", descr.type));
+    }
+    key_schema_ = schema(std::move(fields));
+  }
+
+  /// Consume keys given as JSON rows of key values and expect the group ids given
+  /// as a JSON uint32 array.
+  void ExpectConsume(const std::string& key_json, const std::string& expected) {
+    ExpectConsume(ExecBatch(*RecordBatchFromJSON(key_schema_, key_json)),
+                  ArrayFromJSON(uint32(), expected));
+  }
+
+  /// Consume keys given as a list of columns and expect `expected` as group ids.
+  void ExpectConsume(const std::vector<Datum>& key_batch, Datum expected) {
+    // ExecBatch::Make returns a Result that may hold an error: report it as a
+    // test failure instead of dereferencing it unchecked.
+    ASSERT_OK_AND_ASSIGN(ExecBatch batch, ExecBatch::Make(key_batch));
+    ExpectConsume(batch, std::move(expected));
+  }
+
+  void ExpectConsume(const ExecBatch& key_batch, Datum expected) {
+    Datum ids;
+    // If consuming or validating failed, `ids` is meaningless; stop here rather
+    // than stacking a second, misleading failure on top.
+    ASSERT_NO_FATAL_FAILURE(ConsumeAndValidate(key_batch, &ids));
+    AssertDatumsEqual(expected, ids, /*verbose=*/true);
+  }
+
+  /// Consume `key_batch`, validate the resulting ids and, when `ids` is non-null,
+  /// move them out through it.
+  void ConsumeAndValidate(const ExecBatch& key_batch, Datum* ids = nullptr) {
+    ASSERT_OK_AND_ASSIGN(Datum id_batch, grouper_->Consume(key_batch));
+
+    ASSERT_NO_FATAL_FAILURE(ValidateConsume(key_batch, id_batch));
+
+    if (ids) {
+      *ids = std::move(id_batch);
+    }
+  }
+
+  /// Check `id_batch` against the grouper's uniques, refreshing the cached
+  /// uniques_ whenever the number of groups has grown.
+  void ValidateConsume(const ExecBatch& key_batch, const Datum& id_batch) {
+    if (uniques_.length == -1) {
+      // Nothing cached yet (length -1 is the initial sentinel).
+      ASSERT_OK_AND_ASSIGN(uniques_, grouper_->GetUniques());
+    } else if (static_cast<int64_t>(grouper_->num_groups()) > uniques_.length) {
+      ASSERT_OK_AND_ASSIGN(ExecBatch new_uniques, grouper_->GetUniques());
+
+      // check that uniques_ are prefixes of new_uniques
+      for (int i = 0; i < uniques_.num_values(); ++i) {
+        auto prefix = new_uniques[i].array()->Slice(0, uniques_.length);
+        AssertDatumsEqual(uniques_[i], prefix, /*verbose=*/true);
+      }
+
+      uniques_ = std::move(new_uniques);
+    }
+
+    // check that the ids encode an equivalent key sequence
+    for (int i = 0; i < key_batch.num_values(); ++i) {
+      SCOPED_TRACE(std::to_string(i) + "th key array");
+      ASSERT_OK_AND_ASSIGN(auto expected, Take(uniques_[i], id_batch));
+      AssertDatumsEqual(expected, key_batch[i], /*verbose=*/true);
+    }
+  }
+
+  std::vector<ValueDescr> descrs_;
+  std::shared_ptr<Schema> key_schema_;
+  std::unique_ptr<internal::Grouper> grouper_;
+  // Uniques fetched so far; length -1 means "not fetched yet".
+  ExecBatch uniques_ = ExecBatch({}, -1);
+};
+
+TEST(Grouper, BooleanKey) {
+  TestGrouper grouper({boolean()});
+
+  // (keys, expected ids) for each batch, consumed in order through one grouper.
+  // Ids are assigned on first sight and reused by later batches:
+  // true -> 0, false -> 1, null -> 2.
+  const std::vector<std::pair<std::string, std::string>> batches = {
+      {"[[true], [true]]", "[0, 0]"},
+      {"[[true], [true]]", "[0, 0]"},
+      {"[[false], [null]]", "[1, 2]"},
+      {"[[true], [false], [true], [false], [null], [false], [null]]",
+       "[0, 1, 0, 1, 2, 1, 2]"},
+  };
+  for (const auto& batch : batches) {
+    grouper.ExpectConsume(batch.first, batch.second);
+  }
+}
+
+TEST(Grouper, NumericKey) {
+  for (auto ty : internal::NumericTypes()) {
+    SCOPED_TRACE("key type: " + ty->ToString());
+
+    TestGrouper g({ty});
+
+    g.ExpectConsume("[[3], [3]]", "[0, 0]");
+
+    g.ExpectConsume("[[3], [3]]", "[0, 0]");
+
+    g.ExpectConsume("[[27], [81]]", "[1, 2]");
+
+    g.ExpectConsume("[[3], [27], [3], [27], [null], [81], [27], [81]]",
+                    "[0, 1, 0, 1, 3, 2, 1, 2]");

Review comment:
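       Because 81 was already encountered in the previous call, where it was
assigned id 2. The grouper preserves its hash table between calls to Consume,
so null is the only new key in this batch and receives the next id, 3. After
this call, the internal unique keys will be: `[[3], [27], [81], [null]]`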




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to