icexelloss commented on code in PR #34912:
URL: https://github.com/apache/arrow/pull/34912#discussion_r1173124573
##########
cpp/src/arrow/compute/kernels/aggregate_basic_internal.h:
##########
@@ -272,8 +273,120 @@ struct MeanKernelInit : public SumLikeInit<KernelClass> {
};
// ----------------------------------------------------------------------
-// MinMax implementation
+// FirstLast implementation
+
+// Running state for the scalar "first"/"last" aggregates. The primary
+// template is intentionally empty; only the floating-point specialization
+// below is provided in this hunk.
+template <typename ArrowType, SimdLevel::type SimdLevel, typename Enable =
void>
+struct FirstLastState {};
+
+// State for floating-point inputs: remembers the first and the last value
+// merged so far and whether any consumed input contained nulls.
+template <typename ArrowType, SimdLevel::type SimdLevel>
+struct FirstLastState<ArrowType, SimdLevel,
enable_if_floating_point<ArrowType>> {
+ using ThisType = FirstLastState<ArrowType, SimdLevel>;
+ using T = typename ArrowType::c_type;
+ using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
+
+ // Merges `rhs` into this state. The merge is order-sensitive: `rhs` is
+ // treated as covering data that comes after this state, so this state's
+ // "first" is kept when set and `rhs`'s "last" wins when set.
+ ThisType& operator+=(const ThisType& rhs) {
+ this->has_nulls |= rhs.has_nulls;
+ this->first = this->first.has_value() ? this->first : rhs.first;
+ this->last = rhs.last.has_value() ? rhs.last : this->last;
+ return *this;
+ }
+
+ // Folds one value into the state: the first call fixes `first`, every call
+ // overwrites `last`. Callers are expected to pass only non-null values.
+ void MergeOne(T value) {
+ if (!this->first.has_value()) {
+ this->first = value;
+ }
+ this->last = value;
+ }
+
+ // Both stay empty until the first value is merged.
+ std::optional<T> first = std::nullopt;
+ std::optional<T> last = std::nullopt;
+ // True if any input merged into this state contained a null.
+ bool has_nulls = false;
+};
+
+template <typename ArrowType, SimdLevel::type SimdLevel>
+struct FirstLastImpl : public ScalarAggregator {
+ using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+ using ThisType = FirstLastImpl<ArrowType, SimdLevel>;
+ using StateType = FirstLastState<ArrowType, SimdLevel>;
+
+ FirstLastImpl(std::shared_ptr<DataType> out_type, ScalarAggregateOptions
options)
+ : out_type(std::move(out_type)), options(std::move(options)), count(0) {
+ this->options.min_count = std::max<uint32_t>(1, this->options.min_count);
+ }
+
+ Status Consume(KernelContext*, const ExecSpan& batch) override {
+ if (batch[0].is_array()) {
+ return ConsumeArray(batch[0].array);
+ }
+ return ConsumeScalar(*batch[0].scalar);
+ }
+
+ Status ConsumeScalar(const Scalar& scalar) {
+ return Status::NotImplemented("Consume scalar");
+ }
+
+ Status ConsumeArray(const ArraySpan& arr_span) {
+ StateType local;
+
+ ArrayType arr(arr_span.ToArrayData());
+ const auto null_count = arr.null_count();
+ local.has_nulls = null_count > 0;
+ this->count += arr.length() - null_count;
+
+ if (!local.has_nulls) {
+ for (int64_t i = 0; i < arr.length(); i++) {
+ local.MergeOne(arr.GetView(i));
+ }
Review Comment:
I updated this to be closer to what you have. Can you take a look and check
whether it looks fine to you?
##########
cpp/src/arrow/compute/kernels/hash_aggregate.cc:
##########
@@ -1251,6 +1251,210 @@ HashAggregateKernel
MakeApproximateMedianKernel(HashAggregateFunction* tdigest_f
return kernel;
}
+// ----------------------------------------------------------------------
+// FirstLast implementation
+
+// Placeholder payload written into the per-group first/last buffers for groups
+// that have not seen a value yet. Validity is tracked separately (has_values_
+// in GroupedFirstLastImpl), so this value only fills masked-out slots.
+template <typename CType>
+struct NullSentinel {
+ static constexpr CType value() { return std::numeric_limits<CType>::min(); }
+};
+
+// Floating-point types get an explicit sentinel, presumably because
+// numeric_limits<T>::min() is the smallest positive normal value for them,
+// not the most negative one.
+template <>
+struct NullSentinel<float> {
+ static constexpr float value() { return
std::numeric_limits<float>::infinity(); }
+};
+
+template <>
+struct NullSentinel<double> {
+ static constexpr double value() { return
std::numeric_limits<double>::infinity(); }
+};
+
+// Grouped (hash) aggregator that tracks, per group, the first and the last
+// non-null value consumed, plus whether the group saw any null. Finalize emits
+// a struct<first, last> array with one row per group.
+template <typename Type, typename Enable = void>
+struct GroupedFirstLastImpl final : public GroupedAggregator {
+ using CType = typename TypeTraits<Type>::CType;
+ using GetSet = GroupedValueTraits<Type>;
+ // NOTE(review): ArrType is not referenced anywhere in this struct; consider
+ // removing it.
+ using ArrType =
+ typename std::conditional<is_boolean_type<Type>::value, uint8_t,
CType>::type;
+
+ // Captures the aggregate options and creates empty per-group buffers that
+ // allocate from the execution context's memory pool.
+ Status Init(ExecContext* ctx, const KernelInitArgs& args) override {
+ options_ = *checked_cast<const ScalarAggregateOptions*>(args.options);
+
+ firsts_ = TypedBufferBuilder<CType>(ctx->memory_pool());
+ lasts_ = TypedBufferBuilder<CType>(ctx->memory_pool());
+ has_values_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+ has_nulls_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+ return Status::OK();
+ }
+
+ // Grows every per-group buffer to `new_num_groups`. New groups start with the
+ // sentinel payload and with both flags cleared.
+ // NOTE(review): num_groups_ has no in-class initializer, yet the first call
+ // reads it. This is only safe if the factory (HashAggregateInit)
+ // value-initializes the aggregator. Confirm this, or declare
+ // `int64_t num_groups_ = 0;`.
+ Status Resize(int64_t new_num_groups) override {
+ auto added_groups = new_num_groups - num_groups_;
+ num_groups_ = new_num_groups;
+ RETURN_NOT_OK(firsts_.Append(added_groups, NullSentinel<CType>::value()));
+ RETURN_NOT_OK(lasts_.Append(added_groups, NullSentinel<CType>::value()));
+ RETURN_NOT_OK(has_values_.Append(added_groups, false));
+ RETURN_NOT_OK(has_nulls_.Append(added_groups, false));
+ return Status::OK();
+ }
+
+ // Visits every (group, value) pair in the batch, in visit order.
+ // - Value callback: the first value a group receives is stored in firsts_ and
+ //   marks the group in has_values_. Every value overwrites lasts_.
+ // - Second callback (presumably invoked for null slots): only sets the
+ //   group's has_nulls_ bit.
+ Status Consume(const ExecSpan& batch) override {
+ auto raw_firsts = firsts_.mutable_data();
+ auto raw_lasts = lasts_.mutable_data();
+ auto raw_has_values = has_values_.mutable_data();
+
+ VisitGroupedValues<Type>(
+ batch,
+ [&](uint32_t g, CType val) {
+ if (!bit_util::GetBit(raw_has_values, g)) {
+ GetSet::Set(raw_firsts, g, val);
+ bit_util::SetBit(raw_has_values, g);
+ }
+ GetSet::Set(raw_lasts, g, val);
+ DCHECK(bit_util::GetBit(has_values_.mutable_data(), g));
+ },
+ [&](uint32_t g) { bit_util::SetBit(has_nulls_.mutable_data(), g); });
+ return Status::OK();
+ }
+
+ Status Merge(GroupedAggregator&& raw_other,
+ const ArrayData& group_id_mapping) override {
+ // The merge is asymmetric. "first" from this state is picked over "first"
from other
+ // state. "last" from other state is picked over "last" from this state. This is
so that when
+ // used with segmented aggregation, we still get the correct "first" and
"last"
+ // value for the entire segment.
+ auto other = checked_cast<GroupedFirstLastImpl*>(&raw_other);
+
+ auto raw_firsts = firsts_.mutable_data();
+ auto raw_lasts = lasts_.mutable_data();
+ auto raw_has_values = has_values_.mutable_data();
+ auto raw_has_nulls = has_nulls_.mutable_data();
+
+ auto other_raw_firsts = other->firsts_.mutable_data();
+ auto other_raw_lasts = other->lasts_.mutable_data();
+ auto other_raw_has_values = other->has_values_.mutable_data();
+ auto other_raw_has_nulls = other->has_nulls_.mutable_data();
+
+ // `g` advances in lockstep with `other_g`: *g is the group id in this state
+ // that the other state's group `other_g` maps to.
+ auto g = group_id_mapping.GetValues<uint32_t>(1);
+
+ for (uint32_t other_g = 0; static_cast<int64_t>(other_g) <
group_id_mapping.length;
+ ++other_g, ++g) {
+ // Adopt other's "first" only if this group has no value of its own yet.
+ if (!bit_util::GetBit(raw_has_values, *g)) {
+ if (bit_util::GetBit(other_raw_has_values, other_g)) {
+ GetSet::Set(raw_firsts, *g, GetSet::Get(other_raw_firsts, other_g));
+ }
+ }
+
+ // Other's "last" always wins when it has a value.
+ if (bit_util::GetBit(other_raw_has_values, other_g)) {
+ GetSet::Set(raw_lasts, *g, GetSet::Get(other_raw_lasts, other_g));
+ }
+
+ if (bit_util::GetBit(other_raw_has_values, other_g)) {
+ bit_util::SetBit(raw_has_values, *g);
+ }
+ if (bit_util::GetBit(other_raw_has_nulls, other_g)) {
+ bit_util::SetBit(raw_has_nulls, *g);
+ }
+ }
+ return Status::OK();
+ }
+
+ // Builds the struct<first, last> result.
+ // - has_values_ becomes the validity bitmap of both children, so groups that
+ //   never received a value come out null and their sentinel payload is masked.
+ // - has_nulls_ is not consulted here.
+ // - skip_nulls=false is rejected with NotImplemented.
+ Result<Datum> Finalize() override {
+ ARROW_ASSIGN_OR_RAISE(auto null_bitmap, has_values_.Finish());
+
+ if (!options_.skip_nulls) {
+ return Status::NotImplemented("Don't support first/last with skip nulls
= False");
+ }
+
+ // Both children share the same validity buffer: copied here, moved below.
+ auto firsts = ArrayData::Make(type_, num_groups_, {null_bitmap, nullptr});
+ auto lasts = ArrayData::Make(type_, num_groups_, {std::move(null_bitmap),
nullptr});
+ ARROW_ASSIGN_OR_RAISE(firsts->buffers[1], firsts_.Finish());
+ ARROW_ASSIGN_OR_RAISE(lasts->buffers[1], lasts_.Finish());
+
+ return ArrayData::Make(out_type(), num_groups_, {nullptr},
+ {std::move(firsts), std::move(lasts)});
+ }
+
+ // struct<first: type_, last: type_>
+ std::shared_ptr<DataType> out_type() const override {
+ return struct_({field("first", type_), field("last", type_)});
+ }
+
+ int64_t num_groups_;
+ // Per-group first/last payloads, indexed by group id.
+ TypedBufferBuilder<CType> firsts_, lasts_;
+ // Per-group flags. has_values_: the group received a value.
+ // has_nulls_: the group saw a null.
+ TypedBufferBuilder<bool> has_values_, has_nulls_;
+ // Logical input type; assigned after construction by FirstLastInit.
+ std::shared_ptr<DataType> type_;
+ ScalarAggregateOptions options_;
+};
+
+// Creates a GroupedFirstLastImpl<T> and records the kernel's first input type
+// as its output element type. T may be a physical type (the factory passes
+// T::PhysicalType for integer-like types), so type_ is taken from args.inputs
+// to keep the logical type in the result.
+template <typename T>
+Result<std::unique_ptr<KernelState>> FirstLastInit(KernelContext* ctx,
+ const KernelInitArgs& args)
{
+ ARROW_ASSIGN_OR_RAISE(auto impl,
HashAggregateInit<GroupedFirstLastImpl<T>>(ctx, args));
+ static_cast<GroupedFirstLastImpl<T>*>(impl.get())->type_ =
+ args.inputs[0].GetSharedPtr();
+ return impl;
+}
+
+// Builds a hash-aggregate kernel for either "first" or "last" by delegating to
+// the combined first/last function:
+// - init dispatches `first_last_func` on the exact input types and runs that
+//   kernel's init;
+// - resize/consume/merge are the generic hash-aggregate hooks;
+// - finalize computes the struct<first, last> result and returns one field of it.
+// NOTE(review): the lambda captures `first_last_func` as a raw pointer, so the
+// function presumably outlives every kernel built here. Confirm at the call site.
+template <FirstOrLast first_or_last>
+HashAggregateKernel MakeFirstOrLastKernel(HashAggregateFunction*
first_last_func) {
+ HashAggregateKernel kernel;
+ kernel.init = [first_last_func](
+ KernelContext* ctx,
+ const KernelInitArgs& args) ->
Result<std::unique_ptr<KernelState>> {
+ // Local copy of the input types; new_args below is built from it.
+ std::vector<TypeHolder> inputs = args.inputs;
+ ARROW_ASSIGN_OR_RAISE(auto kernel,
first_last_func->DispatchExact(args.inputs));
+ KernelInitArgs new_args{kernel, inputs, args.options};
+ return kernel->init(ctx, new_args);
+ };
+
+ // Accepts any value type plus uint32 group ids. The output type comes from
+ // FirstType, which is defined elsewhere (presumably it resolves to the value
+ // type; verify).
+ kernel.signature =
+ KernelSignature::Make({InputType::Any(), Type::UINT32},
OutputType(FirstType));
+ kernel.resize = HashAggregateResize;
+ kernel.consume = HashAggregateConsume;
+ kernel.merge = HashAggregateMerge;
+ // Picks the child field of the struct result by the enum's numeric value.
+ // This assumes FirstOrLast is ordered to match the struct fields
+ // (0 -> "first", 1 -> "last"). Confirm against the enum definition.
+ kernel.finalize = [](KernelContext* ctx, Datum* out) {
+ ARROW_ASSIGN_OR_RAISE(Datum temp,
+
checked_cast<GroupedAggregator*>(ctx->state())->Finalize());
+ *out =
temp.array_as<StructArray>()->field(static_cast<uint8_t>(first_or_last));
+ return Status::OK();
+ };
+ return kernel;
+}
+
+struct GroupedFirstLastFactory {
+ template <typename T>
+ enable_if_physical_integer<T, Status> Visit(const T&) {
+ using PhysicalType = typename T::PhysicalType;
+ kernel = MakeKernel(std::move(argument_type), FirstLastInit<PhysicalType>);
+ return Status::OK();
+ }
+ Status Visit(const FloatType&) {
+ kernel = MakeKernel(std::move(argument_type), FirstLastInit<FloatType>);
+ return Status::OK();
+ }
+
+ Status Visit(const DoubleType&) {
+ kernel = MakeKernel(std::move(argument_type), FirstLastInit<DoubleType>);
+ return Status::OK();
+ }
+
+ Status Visit(const HalfFloatType& type) {
+ return Status::NotImplemented("Computing first/last of data of type ",
type);
+ }
+
+ Status Visit(const DataType& type) {
Review Comment:
Done
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]