bkietz commented on a change in pull request #9621:
URL: https://github.com/apache/arrow/pull/9621#discussion_r592463457



##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -91,6 +95,377 @@ struct CountImpl : public ScalarAggregator {
   int64_t nulls = 0;
 };
 
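+// Interface for grouped ("hash") aggregation: Consume() folds a batch of
+// values into per-group state keyed by group_ids, MaybeResize() grows that
+// state to cover the largest group id seen so far, and Finalize() emits one
+// output value per group.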
+struct GroupedAggregator {
+  virtual ~GroupedAggregator() = default;
+
+  virtual void Consume(KernelContext*, const Datum& aggregand,
+                       const uint32_t* group_ids) = 0;
+
+  virtual void Finalize(KernelContext* ctx, Datum* out) = 0;
+
+  virtual void Resize(KernelContext* ctx, int64_t new_num_groups) = 0;
+
+  virtual int64_t num_groups() const = 0;
+
+  void MaybeResize(KernelContext* ctx, int64_t length, const uint32_t* group_ids) {
+    if (length == 0) return;
+
+    // maybe a batch of group_ids should include the min/max group id
+    int64_t max_group = *std::max_element(group_ids, group_ids + length);
+    auto old_size = num_groups();
+
+    if (max_group >= old_size) {
+      auto new_size = BufferBuilder::GrowByFactor(old_size, max_group + 1);
+      Resize(ctx, new_size);
+    }
+  }
+
+  virtual std::shared_ptr<DataType> out_type() const = 0;
+};
+
+struct GroupedCountImpl : public GroupedAggregator {
+  static std::unique_ptr<GroupedCountImpl> Make(KernelContext* ctx,
+                                                const std::shared_ptr<DataType>&,
+                                                const FunctionOptions* options) {
+    auto out = ::arrow::internal::make_unique<GroupedCountImpl>();
+    out->options_ = checked_cast<const CountOptions&>(*options);
+    ctx->SetStatus(ctx->Allocate(0).Value(&out->counts_));
+    return out;
+  }
+
+  void Resize(KernelContext* ctx, int64_t new_num_groups) override {
+    auto old_size = num_groups();
+    KERNEL_RETURN_IF_ERROR(ctx, counts_->TypedResize<int64_t>(new_num_groups));
+    auto new_size = num_groups();
+
+    auto raw_counts = reinterpret_cast<int64_t*>(counts_->mutable_data());
+    for (auto i = old_size; i < new_size; ++i) {
+      raw_counts[i] = 0;
+    }
+  }
+
+  void Consume(KernelContext* ctx, const Datum& aggregand,
+               const uint32_t* group_ids) override {
+    MaybeResize(ctx, aggregand.length(), group_ids);
+    if (ctx->HasError()) return;
+
+    auto raw_counts = reinterpret_cast<int64_t*>(counts_->mutable_data());
+
+    const auto& input = aggregand.array();
+
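+    // In COUNT_NULL mode, count unset validity bits; otherwise count set bits.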
+    if (options_.count_mode == CountOptions::COUNT_NULL) {
+      for (int64_t i = 0, input_i = input->offset; i < input->length; ++i, ++input_i) {
+        auto g = group_ids[i];
+        raw_counts[g] += !BitUtil::GetBit(input->buffers[0]->data(), input_i);
+      }
+      return;
+    }
+
+    arrow::internal::VisitSetBitRunsVoid(
+        input->buffers[0], input->offset, input->length,
+        [&](int64_t begin, int64_t length) {
+          for (int64_t input_i = begin, i = begin - input->offset;
+               input_i < begin + length; ++input_i, ++i) {
+            auto g = group_ids[i];
+            raw_counts[g] += 1;
+          }
+        });
+  }
+
+  void Finalize(KernelContext* ctx, Datum* out) override {
+    auto length = num_groups();
+    *out = std::make_shared<Int64Array>(length, std::move(counts_));
+  }
+
+  int64_t num_groups() const override { return counts_->size() / sizeof(int64_t); }
+
+  std::shared_ptr<DataType> out_type() const override { return int64(); }
+
+  CountOptions options_;
+  std::shared_ptr<ResizableBuffer> counts_;
+};
+
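+// Computes per-group sums; groups that never saw a non-null value emit null.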
+struct GroupedSumImpl : public GroupedAggregator {
+  // NB: whether we are accumulating into double, int64_t, or uint64_t
+  // we always have 64 bits per group in the sums buffer.
+  static constexpr size_t kSumSize = sizeof(int64_t);
+
+  using ConsumeImpl = std::function<void(const std::shared_ptr<ArrayData>&,
+                                         const uint32_t*, Buffer*, Buffer*)>;
+
+  struct GetConsumeImpl {
+    template <typename T,
+              typename AccumulatorType = typename FindAccumulatorType<T>::Type>
+    Status Visit(const T&) {
+      consume_impl = [](const std::shared_ptr<ArrayData>& input,
+                        const uint32_t* group_ids, Buffer* sums, Buffer* counts) {
+        auto raw_input = reinterpret_cast<const typename TypeTraits<T>::CType*>(
+            input->buffers[1]->data());
+        auto raw_sums = reinterpret_cast<typename TypeTraits<AccumulatorType>::CType*>(
+            sums->mutable_data());
+        auto raw_counts = reinterpret_cast<int64_t*>(counts->mutable_data());
+
+        arrow::internal::VisitSetBitRunsVoid(
+            input->buffers[0], input->offset, input->length,
+            [&](int64_t begin, int64_t length) {
+              for (int64_t input_i = begin, i = begin - input->offset;
+                   input_i < begin + length; ++input_i, ++i) {
+                auto g = group_ids[i];
+                raw_sums[g] += raw_input[input_i];
+                raw_counts[g] += 1;
+              }
+            });
+      };
+      out_type = TypeTraits<AccumulatorType>::type_singleton();
+      return Status::OK();
+    }
+
+    Status Visit(const BooleanType&) {
+      consume_impl = [](const std::shared_ptr<ArrayData>& input,
+                        const uint32_t* group_ids, Buffer* sums, Buffer* counts) {
+        auto raw_input = input->buffers[1]->data();
+        auto raw_sums = reinterpret_cast<uint64_t*>(sums->mutable_data());
+        auto raw_counts = reinterpret_cast<int64_t*>(counts->mutable_data());
+
+        arrow::internal::VisitSetBitRunsVoid(
+            input->buffers[0], input->offset, input->length,
+            [&](int64_t begin, int64_t length) {
+              for (int64_t input_i = begin, i = begin - input->offset;
+                   input_i < begin + length; ++input_i, ++i) {
+                auto g = group_ids[i];
+                raw_sums[g] += BitUtil::GetBit(raw_input, input_i);
+                raw_counts[g] += 1;
+              }
+            });
+      };
+      out_type = uint64();
+      return Status::OK();
+    }
+
+    Status Visit(const HalfFloatType& type) {
+      return Status::NotImplemented("Summing data of type ", type);
+    }
+
+    Status Visit(const DataType& type) {
+      return Status::NotImplemented("Summing data of type ", type);
+    }
+
+    ConsumeImpl consume_impl;
+    std::shared_ptr<DataType> out_type;
+  };
+
+  static std::unique_ptr<GroupedSumImpl> Make(KernelContext* ctx,
+                                              const std::shared_ptr<DataType>& input_type,
+                                              const FunctionOptions* options) {
+    auto out = ::arrow::internal::make_unique<GroupedSumImpl>();
+
+    ctx->SetStatus(ctx->Allocate(0).Value(&out->sums_));
+    if (ctx->HasError()) return nullptr;
+
+    ctx->SetStatus(ctx->Allocate(0).Value(&out->counts_));
+    if (ctx->HasError()) return nullptr;
+
+    GetConsumeImpl get_consume_impl;
+    ctx->SetStatus(VisitTypeInline(*input_type, &get_consume_impl));
+
+    out->consume_impl_ = std::move(get_consume_impl.consume_impl);
+    out->out_type_ = std::move(get_consume_impl.out_type);
+    return out;
+  }
+
+  void Resize(KernelContext* ctx, int64_t new_num_groups) override {
+    auto old_size = num_groups() * kSumSize;
+    KERNEL_RETURN_IF_ERROR(ctx, sums_->Resize(new_num_groups * kSumSize));
+    KERNEL_RETURN_IF_ERROR(ctx, counts_->Resize(new_num_groups * sizeof(int64_t)));
+    auto new_size = num_groups() * kSumSize;
+    std::memset(sums_->mutable_data() + old_size, 0, new_size - old_size);
+    std::memset(counts_->mutable_data() + old_size, 0, new_size - old_size);
+  }
+
+  void Consume(KernelContext* ctx, const Datum& aggregand,
+               const uint32_t* group_ids) override {
+    MaybeResize(ctx, aggregand.length(), group_ids);
+    if (ctx->HasError()) return;
+    consume_impl_(aggregand.array(), group_ids, sums_.get(), counts_.get());
+  }
+
+  void Finalize(KernelContext* ctx, Datum* out) override {
+    std::shared_ptr<Buffer> null_bitmap;
+    int64_t null_count = 0;
+
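+    // A group's sum is null iff the group saw no non-null values; the
+    // validity bitmap is allocated lazily, only once such a group is found.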
+    for (int64_t i = 0; i < num_groups(); ++i) {
+      if (reinterpret_cast<const int64_t*>(counts_->data())[i] > 0) continue;
+
+      if (null_bitmap == nullptr) {
+        KERNEL_ASSIGN_OR_RAISE(null_bitmap, ctx, ctx->AllocateBitmap(num_groups()));
+        BitUtil::SetBitsTo(null_bitmap->mutable_data(), 0, num_groups(), true);
+      }
+
+      null_count += 1;
+      BitUtil::SetBitTo(null_bitmap->mutable_data(), i, false);
+    }
+
+    *out = ArrayData::Make(std::move(out_type_), num_groups(),
+                           {std::move(null_bitmap), std::move(sums_)}, null_count);
+  }
+
+  int64_t num_groups() const override { return counts_->size() / sizeof(int64_t); }
+
+  std::shared_ptr<DataType> out_type() const override { return out_type_; }
+
+  std::shared_ptr<ResizableBuffer> sums_, counts_;
+  std::shared_ptr<DataType> out_type_;
+  ConsumeImpl consume_impl_;
+};
+
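+// Computes per-group min and max. buffers_ holds, in order: mins, maxes,
+// a "has nulls" bitmap, and a "has values" bitmap.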
+struct GroupedMinMaxImpl : public GroupedAggregator {
+  using ConsumeImpl = std::function<void(const std::shared_ptr<ArrayData>&,
+                                         const uint32_t*, BufferVector*)>;
+
+  using ResizeImpl = std::function<Status(Buffer*, int64_t)>;
+
+  struct GetImpl {
+    template <typename T, typename CType = typename TypeTraits<T>::CType>
+    enable_if_number<T, Status> Visit(const T&) {
+      consume_impl = [](const std::shared_ptr<ArrayData>& input,
+                        const uint32_t* group_ids, BufferVector* buffers) {
+        auto raw_inputs = reinterpret_cast<const CType*>(input->buffers[1]->data());
+
+        auto raw_mins = reinterpret_cast<CType*>(buffers->at(0)->mutable_data());
+        auto raw_maxes = reinterpret_cast<CType*>(buffers->at(1)->mutable_data());
+
+        auto raw_has_nulls = buffers->at(2)->mutable_data();
+        auto raw_has_values = buffers->at(3)->mutable_data();
+
+        for (int64_t i = 0, input_i = input->offset; i < input->length; ++i, ++input_i) {
+          auto g = group_ids[i];
+          bool is_valid = BitUtil::GetBit(input->buffers[0]->data(), input_i);
+          if (is_valid) {
+            raw_maxes[g] = std::max(raw_maxes[g], raw_inputs[input_i]);
+            raw_mins[g] = std::min(raw_mins[g], raw_inputs[input_i]);
+            BitUtil::SetBit(raw_has_values, g);
+          } else {
+            BitUtil::SetBit(raw_has_nulls, g);
+          }
+        }
+      };
+
+      for (auto pair :
+           {std::make_pair(&resize_min_impl, std::numeric_limits<CType>::max()),
+            std::make_pair(&resize_max_impl, std::numeric_limits<CType>::min())}) {
+        *pair.first = [pair](Buffer* vals, int64_t new_num_groups) {
+          int64_t old_num_groups = vals->size() / sizeof(CType);
+
+          int64_t new_size = new_num_groups * sizeof(CType);
+          RETURN_NOT_OK(checked_cast<ResizableBuffer*>(vals)->Resize(new_size));
+
+          auto raw_vals = reinterpret_cast<CType*>(vals->mutable_data());
+          for (int64_t i = old_num_groups; i != new_num_groups; ++i) {
+            raw_vals[i] = pair.second;
+          }
+          return Status::OK();
+        };
+      }
+
+      return Status::OK();
+    }
+
+    Status Visit(const BooleanType& type) {
+      return Status::NotImplemented("Grouped MinMax data of type ", type);
+    }
+
+    Status Visit(const HalfFloatType& type) {
+      return Status::NotImplemented("Grouped MinMax data of type ", type);
+    }
+
+    Status Visit(const DataType& type) {
+      return Status::NotImplemented("Grouped MinMax data of type ", type);
+    }
+
+    ConsumeImpl consume_impl;
+    ResizeImpl resize_min_impl, resize_max_impl;
+  };
+
+  static std::unique_ptr<GroupedMinMaxImpl> Make(
+      KernelContext* ctx, const std::shared_ptr<DataType>& input_type,
+      const FunctionOptions* options) {
+    auto out = ::arrow::internal::make_unique<GroupedMinMaxImpl>();
+    out->options_ = *checked_cast<const MinMaxOptions*>(options);
+    out->type_ = input_type;
+
+    out->buffers_.resize(4);
+    for (auto& buf : out->buffers_) {
+      ctx->SetStatus(ctx->Allocate(0).Value(&buf));
+      if (ctx->HasError()) return nullptr;
+    }
+
+    GetImpl get_impl;
+    ctx->SetStatus(VisitTypeInline(*input_type, &get_impl));
+
+    out->consume_impl_ = std::move(get_impl.consume_impl);
+    out->resize_min_impl_ = std::move(get_impl.resize_min_impl);
+    out->resize_max_impl_ = std::move(get_impl.resize_max_impl);
+    return out;
+  }
+
+  void Resize(KernelContext* ctx, int64_t new_num_groups) override {
+    auto old_num_groups = num_groups_;
+    num_groups_ = new_num_groups;
+
+    KERNEL_RETURN_IF_ERROR(ctx, resize_min_impl_(buffers_[0].get(), new_num_groups));
+    KERNEL_RETURN_IF_ERROR(ctx, resize_max_impl_(buffers_[1].get(), new_num_groups));
+
+    for (auto buffer : {buffers_[2].get(), buffers_[3].get()}) {
+      KERNEL_RETURN_IF_ERROR(ctx, checked_cast<ResizableBuffer*>(buffer)->Resize(
+                                      BitUtil::BytesForBits(new_num_groups)));
+      BitUtil::SetBitsTo(buffer->mutable_data(), old_num_groups, new_num_groups, false);
+    }
+  }
+
+  void Consume(KernelContext* ctx, const Datum& aggregand,
+               const uint32_t* group_ids) override {
+    MaybeResize(ctx, aggregand.length(), group_ids);
+    if (ctx->HasError()) return;
+    consume_impl_(aggregand.array(), group_ids, &buffers_);
+  }
+
+  void Finalize(KernelContext* ctx, Datum* out) override {
+    // the aggregation for a group is valid if there was at least one value in that group
+    std::shared_ptr<Buffer> null_bitmap = std::move(buffers_[3]);
+
+    if (options_.null_handling == MinMaxOptions::EMIT_NULL) {
+      // ... and there were no nulls in that group

Review comment:
       At the time of writing this, I wasn't immediately confident that we could use a buffer as both an input and an output to UnalignedBitmapOp. After looking through BitmapWordWriter, I think this would be safe.
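
       A minimal sketch of the aliased call in question, assuming a BitmapAndNot helper with the same shape as BitmapAnd in arrow/util/bitmap_ops.h (a thin wrapper over the unaligned word ops); `null_bitmap` here is the "has values" bitmap already moved out of `buffers_[3]`, and `buffers_[2]` is the "has nulls" bitmap:

```cpp
// Sketch of the in-place use: null_bitmap is both the left input and the
// output. BitmapWordWriter reads each input word before the corresponding
// output word is written, so the aliasing should be safe.
arrow::internal::BitmapAndNot(
    /*left=*/null_bitmap->data(), /*left_offset=*/0,
    /*right=*/buffers_[2]->data(), /*right_offset=*/0,
    /*length=*/num_groups_, /*out_offset=*/0,
    /*out=*/null_bitmap->mutable_data());
```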




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

