bkietz commented on a change in pull request #9621:
URL: https://github.com/apache/arrow/pull/9621#discussion_r592456204



##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -91,6 +95,377 @@ struct CountImpl : public ScalarAggregator {
   int64_t nulls = 0;
 };
 
+struct GroupedAggregator {
+  virtual ~GroupedAggregator() = default;
+
+  virtual void Consume(KernelContext*, const Datum& aggregand,
+                       const uint32_t* group_ids) = 0;
+
+  virtual void Finalize(KernelContext* ctx, Datum* out) = 0;
+
+  virtual void Resize(KernelContext* ctx, int64_t new_num_groups) = 0;
+
+  virtual int64_t num_groups() const = 0;
+
+  void MaybeResize(KernelContext* ctx, int64_t length, const uint32_t* group_ids) {
+    if (length == 0) return;
+
+    // maybe a batch of group_ids should include the min/max group id
+    int64_t max_group = *std::max_element(group_ids, group_ids + length);
+    auto old_size = num_groups();
+
+    if (max_group >= old_size) {
+      auto new_size = BufferBuilder::GrowByFactor(old_size, max_group + 1);
+      Resize(ctx, new_size);
+    }
+  }
+
+  virtual std::shared_ptr<DataType> out_type() const = 0;
+};
+
+struct GroupedCountImpl : public GroupedAggregator {
+  static std::unique_ptr<GroupedCountImpl> Make(KernelContext* ctx,
+                                                const std::shared_ptr<DataType>&,
+                                                const FunctionOptions* options) {
+    auto out = ::arrow::internal::make_unique<GroupedCountImpl>();
+    out->options_ = checked_cast<const CountOptions&>(*options);
+    ctx->SetStatus(ctx->Allocate(0).Value(&out->counts_));
+    return out;
+  }
+
+  void Resize(KernelContext* ctx, int64_t new_num_groups) override {
+    auto old_size = num_groups();
+    KERNEL_RETURN_IF_ERROR(ctx, counts_->TypedResize<int64_t>(new_num_groups));
+    auto new_size = num_groups();
+
+    auto raw_counts = reinterpret_cast<int64_t*>(counts_->mutable_data());
+    for (auto i = old_size; i < new_size; ++i) {
+      raw_counts[i] = 0;
+    }
+  }
+
+  void Consume(KernelContext* ctx, const Datum& aggregand,
+               const uint32_t* group_ids) override {
+    MaybeResize(ctx, aggregand.length(), group_ids);
+    if (ctx->HasError()) return;
+
+    auto raw_counts = reinterpret_cast<int64_t*>(counts_->mutable_data());
+
+    const auto& input = aggregand.array();
+
+    if (options_.count_mode == CountOptions::COUNT_NULL) {
+      for (int64_t i = 0, input_i = input->offset; i < input->length; ++i, ++input_i) {
+        auto g = group_ids[i];
+        raw_counts[g] += !BitUtil::GetBit(input->buffers[0]->data(), input_i);
+      }
+      return;
+    }
+
+    arrow::internal::VisitSetBitRunsVoid(
+        input->buffers[0], input->offset, input->length,
+        [&](int64_t begin, int64_t length) {
+          for (int64_t input_i = begin, i = begin - input->offset;
+               input_i < begin + length; ++input_i, ++i) {
+            auto g = group_ids[i];
+            raw_counts[g] += 1;
+          }
+        });
+  }
+
+  void Finalize(KernelContext* ctx, Datum* out) override {
+    auto length = num_groups();
+    *out = std::make_shared<Int64Array>(length, std::move(counts_));
+  }
+
+  int64_t num_groups() const override { return counts_->size() / sizeof(int64_t); }
+
+  std::shared_ptr<DataType> out_type() const override { return int64(); }
+
+  CountOptions options_;
+  std::shared_ptr<ResizableBuffer> counts_;
+};
+
+struct GroupedSumImpl : public GroupedAggregator {
+  // NB: whether we are accumulating into double, int64_t, or uint64_t
+  // we always have 64 bits per group in the sums buffer.
+  static constexpr size_t kSumSize = sizeof(int64_t);
+
+  using ConsumeImpl = std::function<void(const std::shared_ptr<ArrayData>&,
+                                         const uint32_t*, Buffer*, Buffer*)>;
+
+  struct GetConsumeImpl {
+    template <typename T,
+              typename AccumulatorType = typename FindAccumulatorType<T>::Type>
+    Status Visit(const T&) {
+      consume_impl = [](const std::shared_ptr<ArrayData>& input,
+                        const uint32_t* group_ids, Buffer* sums, Buffer* counts) {
+        auto raw_input = reinterpret_cast<const typename TypeTraits<T>::CType*>(
+            input->buffers[1]->data());
+        auto raw_sums = reinterpret_cast<typename TypeTraits<AccumulatorType>::CType*>(
+            sums->mutable_data());
+        auto raw_counts = reinterpret_cast<int64_t*>(counts->mutable_data());
+
+        arrow::internal::VisitSetBitRunsVoid(
+            input->buffers[0], input->offset, input->length,
+            [&](int64_t begin, int64_t length) {
+              for (int64_t input_i = begin, i = begin - input->offset;
+                   input_i < begin + length; ++input_i, ++i) {
+                auto g = group_ids[i];
+                raw_sums[g] += raw_input[input_i];
+                raw_counts[g] += 1;
+              }
+            });
+      };
+      out_type = TypeTraits<AccumulatorType>::type_singleton();
+      return Status::OK();
+    }
+
+    Status Visit(const BooleanType&) {
+      consume_impl = [](const std::shared_ptr<ArrayData>& input,
+                        const uint32_t* group_ids, Buffer* sums, Buffer* counts) {
+        auto raw_input = input->buffers[1]->data();
+        auto raw_sums = reinterpret_cast<uint64_t*>(sums->mutable_data());
+        auto raw_counts = reinterpret_cast<int64_t*>(counts->mutable_data());
+
+        arrow::internal::VisitSetBitRunsVoid(
+            input->buffers[0], input->offset, input->length,
+            [&](int64_t begin, int64_t length) {
+              for (int64_t input_i = begin, i = begin - input->offset;
+                   input_i < begin + length; ++input_i, ++i) {
+                auto g = group_ids[i];
+                raw_sums[g] += BitUtil::GetBit(raw_input, input_i);
+                raw_counts[g] += 1;
+              }
+            });
+      };
+      out_type = uint64();
+      return Status::OK();
+    }
+
+    Status Visit(const HalfFloatType& type) {
+      return Status::NotImplemented("Summing data of type ", type);
+    }
+
+    Status Visit(const DataType& type) {
+      return Status::NotImplemented("Summing data of type ", type);
+    }
+
+    ConsumeImpl consume_impl;
+    std::shared_ptr<DataType> out_type;
+  };
+
+  static std::unique_ptr<GroupedSumImpl> Make(KernelContext* ctx,
+                                              const std::shared_ptr<DataType>& input_type,
+                                              const FunctionOptions* options) {
+    auto out = ::arrow::internal::make_unique<GroupedSumImpl>();
+
+    ctx->SetStatus(ctx->Allocate(0).Value(&out->sums_));
+    if (ctx->HasError()) return nullptr;
+
+    ctx->SetStatus(ctx->Allocate(0).Value(&out->counts_));
+    if (ctx->HasError()) return nullptr;
+
+    GetConsumeImpl get_consume_impl;
+    ctx->SetStatus(VisitTypeInline(*input_type, &get_consume_impl));
+
+    out->consume_impl_ = std::move(get_consume_impl.consume_impl);
+    out->out_type_ = std::move(get_consume_impl.out_type);
+    return out;
+  }
+
+  void Resize(KernelContext* ctx, int64_t new_num_groups) override {
+    auto old_size = num_groups() * kSumSize;
+    KERNEL_RETURN_IF_ERROR(ctx, sums_->Resize(new_num_groups * kSumSize));
+    KERNEL_RETURN_IF_ERROR(ctx, counts_->Resize(new_num_groups * sizeof(int64_t)));
+    auto new_size = num_groups() * kSumSize;
+    std::memset(sums_->mutable_data() + old_size, 0, new_size - old_size);
+    std::memset(counts_->mutable_data() + old_size, 0, new_size - old_size);
+  }
+
+  void Consume(KernelContext* ctx, const Datum& aggregand,
+               const uint32_t* group_ids) override {
+    MaybeResize(ctx, aggregand.length(), group_ids);
+    if (ctx->HasError()) return;
+    consume_impl_(aggregand.array(), group_ids, sums_.get(), counts_.get());
+  }
+
+  void Finalize(KernelContext* ctx, Datum* out) override {
+    std::shared_ptr<Buffer> null_bitmap;
+    int64_t null_count = 0;
+
+    for (int64_t i = 0; i < num_groups(); ++i) {
+      if (reinterpret_cast<const int64_t*>(counts_->data())[i] > 0) continue;
+
+      if (null_bitmap == nullptr) {
+        KERNEL_ASSIGN_OR_RAISE(null_bitmap, ctx, ctx->AllocateBitmap(num_groups()));
+        BitUtil::SetBitsTo(null_bitmap->mutable_data(), 0, num_groups(), true);
+      }
+
+      null_count += 1;
+      BitUtil::SetBitTo(null_bitmap->mutable_data(), i, false);
+    }
+
+    *out = ArrayData::Make(std::move(out_type_), num_groups(),
+                           {std::move(null_bitmap), std::move(sums_)}, null_count);
+  }
+
+  int64_t num_groups() const override { return counts_->size() / sizeof(int64_t); }
+
+  std::shared_ptr<DataType> out_type() const override { return out_type_; }
+
+  std::shared_ptr<ResizableBuffer> sums_, counts_;
+  std::shared_ptr<DataType> out_type_;
+  ConsumeImpl consume_impl_;
+};
+
+struct GroupedMinMaxImpl : public GroupedAggregator {
+  using ConsumeImpl = std::function<void(const std::shared_ptr<ArrayData>&,
+                                         const uint32_t*, BufferVector*)>;
+
+  using ResizeImpl = std::function<Status(Buffer*, int64_t)>;
+
+  struct GetImpl {
+    template <typename T, typename CType = typename TypeTraits<T>::CType>
+    enable_if_number<T, Status> Visit(const T&) {
+      consume_impl = [](const std::shared_ptr<ArrayData>& input,
+                        const uint32_t* group_ids, BufferVector* buffers) {
+        auto raw_inputs = reinterpret_cast<const CType*>(input->buffers[1]->data());
+
+        auto raw_mins = reinterpret_cast<CType*>(buffers->at(0)->mutable_data());
+        auto raw_maxes = reinterpret_cast<CType*>(buffers->at(1)->mutable_data());
+
+        auto raw_has_nulls = buffers->at(2)->mutable_data();
+        auto raw_has_values = buffers->at(3)->mutable_data();
+
+        for (int64_t i = 0, input_i = input->offset; i < input->length; ++i, ++input_i) {
+          auto g = group_ids[i];
+          bool is_valid = BitUtil::GetBit(input->buffers[0]->data(), input_i);
+          if (is_valid) {
+            raw_maxes[g] = std::max(raw_maxes[g], raw_inputs[input_i]);
+            raw_mins[g] = std::min(raw_mins[g], raw_inputs[input_i]);
+            BitUtil::SetBit(raw_has_values, g);
+          } else {
+            BitUtil::SetBit(raw_has_nulls, g);
+          }
+        }
+      };
+
+      for (auto pair :
+           {std::make_pair(&resize_min_impl, std::numeric_limits<CType>::max()),
+            std::make_pair(&resize_max_impl, std::numeric_limits<CType>::min())}) {
+        *pair.first = [pair](Buffer* vals, int64_t new_num_groups) {

Review comment:
       That's great, thanks!
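
       For anyone reading along: the quoted hunk is cut off right where the resize lambdas begin. Below is a rough sketch of what such a lambda could look like. It is illustrative only, not necessarily the code in this PR; it assumes the Buffer* handed to ResizeImpl is in fact a ResizableBuffer (like the other aggregators' buffers) and reuses checked_cast and TypedResize as seen earlier in the hunk.

           // Illustrative sketch only: grow the per-group values buffer and seed
           // the newly added slots with the paired sentinel
           // (numeric_limits<CType>::max() for mins, ::min() for maxes) so the
           // first consumed value for a new group always replaces it.
           *pair.first = [pair](Buffer* vals, int64_t new_num_groups) {
             int64_t old_num_groups = vals->size() / sizeof(CType);
             // Assumption: vals was allocated as a ResizableBuffer.
             RETURN_NOT_OK(
                 checked_cast<ResizableBuffer*>(vals)->TypedResize<CType>(new_num_groups));
             auto raw_vals = reinterpret_cast<CType*>(vals->mutable_data());
             for (int64_t i = old_num_groups; i < new_num_groups; ++i) {
               raw_vals[i] = pair.second;
             }
             return Status::OK();
           };

       Seeding new slots with the opposite extreme means the std::min/std::max calls in Consume never need to special-case a group that has not seen a value yet.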



