lidavidm commented on a change in pull request #10792:
URL: https://github.com/apache/arrow/pull/10792#discussion_r676687330
##########
File path: cpp/src/arrow/compute/kernels/hash_aggregate.cc
##########
@@ -1005,6 +1007,325 @@ struct GroupedSumFactory {
   InputType argument_type;
 };
 
+// ----------------------------------------------------------------------
+// Mean implementation
+
+template <typename Type>
+struct GroupedMeanImpl : public GroupedSumImpl<Type> {
+  Result<Datum> Finalize() override {
+    using SumType = typename GroupedSumImpl<Type>::SumType;
+    std::shared_ptr<Buffer> null_bitmap;
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> values,
+                          AllocateBuffer(num_groups_ * sizeof(double), pool_));
+    int64_t null_count = 0;
+
+    const int64_t* counts = reinterpret_cast<const int64_t*>(counts_.data());
+    const auto* sums = reinterpret_cast<const SumType*>(sums_.data());
+    double* means = reinterpret_cast<double*>(values->mutable_data());
+    for (int64_t i = 0; i < num_groups_; ++i) {
+      if (counts[i] > 0) {
+        means[i] = static_cast<double>(sums[i]) / counts[i];
+        continue;
+      }
+      means[i] = 0;
+
+      if (null_bitmap == nullptr) {
+        ARROW_ASSIGN_OR_RAISE(null_bitmap, AllocateBitmap(num_groups_, pool_));
+        BitUtil::SetBitsTo(null_bitmap->mutable_data(), 0, num_groups_, true);
+      }
+
+      null_count += 1;
+      BitUtil::SetBitTo(null_bitmap->mutable_data(), i, false);
+    }
+
+    return ArrayData::Make(float64(), num_groups_,
+                           {std::move(null_bitmap), std::move(values)}, null_count);
+  }
+
+  std::shared_ptr<DataType> out_type() const override { return float64(); }
+
+  using GroupedSumImpl<Type>::num_groups_;
+  using GroupedSumImpl<Type>::pool_;
+  using GroupedSumImpl<Type>::counts_;
+  using GroupedSumImpl<Type>::sums_;
+};
+
+struct GroupedMeanFactory {
+  template <typename T, typename AccType = typename FindAccumulatorType<T>::Type>
+  Status Visit(const T&) {
+    kernel = MakeKernel(std::move(argument_type), HashAggregateInit<GroupedMeanImpl<T>>);
+    return Status::OK();
+  }
+
+  Status Visit(const HalfFloatType& type) {
+    return Status::NotImplemented("Computing mean of type ", type);
+  }
+
+  Status Visit(const DataType& type) {
+    return Status::NotImplemented("Computing mean of type ", type);
+  }
+
+  static Result<HashAggregateKernel> Make(const std::shared_ptr<DataType>& type) {
+    GroupedMeanFactory factory;
+    factory.argument_type = InputType::Array(type);
+    RETURN_NOT_OK(VisitTypeInline(*type, &factory));
+    return std::move(factory.kernel);
+  }
+
+  HashAggregateKernel kernel;
+  InputType argument_type;
+};
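One detail worth noting in the Mean hunk above: the per-group mean must be computed as a floating-point division of the accumulated sum (dividing before the cast would truncate integer sums), and a group that received no values comes out null. Below is a minimal standalone sketch of that Finalize() behavior over plain vectors; the names and the std::optional representation are illustrative only, not the Arrow API. The hunk continues after the sketch.

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Per-group mean as in GroupedMeanImpl::Finalize(): cast the sum to double
// *before* dividing, and leave empty groups null (nullopt stands in for the
// null bitmap).
std::vector<std::optional<double>> FinalizeMeans(const std::vector<int64_t>& sums,
                                                 const std::vector<int64_t>& counts) {
  std::vector<std::optional<double>> means(sums.size());
  for (std::size_t i = 0; i < sums.size(); ++i) {
    if (counts[i] > 0) {
      means[i] = static_cast<double>(sums[i]) / counts[i];
    }
  }
  return means;
}

int main() {
  // Groups 0 and 2 have values; group 1 is empty and stays null.
  const auto means = FinalizeMeans(/*sums=*/{7, 0, 9}, /*counts=*/{2, 0, 4});
  for (const auto& m : means) {
    if (m) {
      std::cout << *m << '\n';
    } else {
      std::cout << "null\n";
    }
  }  // prints 3.5, null, 2.25
}
```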
+
+// Variance/Stdev implementation
+
+using arrow::internal::int128_t;
+
+template <typename Type>
+struct GroupedVarStdImpl : public GroupedAggregator {
+  using CType = typename Type::c_type;
+
+  Status Init(ExecContext* ctx, const FunctionOptions* options) override {
+    options_ = *checked_cast<const VarianceOptions*>(options);
+    ctx_ = ctx;
+    pool_ = ctx->memory_pool();
+    counts_ = BufferBuilder(pool_);
+    means_ = BufferBuilder(pool_);
+    m2s_ = BufferBuilder(pool_);
+    return Status::OK();
+  }
+
+  Status Resize(int64_t new_num_groups) override {
+    auto added_groups = new_num_groups - num_groups_;
+    num_groups_ = new_num_groups;
+    RETURN_NOT_OK(counts_.Append(added_groups * sizeof(int64_t), 0));
+    RETURN_NOT_OK(means_.Append(added_groups * sizeof(double), 0));
+    RETURN_NOT_OK(m2s_.Append(added_groups * sizeof(double), 0));
+    return Status::OK();
+  }
+
+  Status Consume(const ExecBatch& batch) override { return ConsumeImpl(batch); }
+
+  // float/double/int64: calculate `m2` (sum((X-mean)^2)) with the two-pass
+  // algorithm (see aggregate_var_std.cc)
+  template <typename T = Type>
+  enable_if_t<is_floating_type<T>::value || (sizeof(CType) > 4), Status> ConsumeImpl(
+      const ExecBatch& batch) {
+    using SumType =
+        typename std::conditional<is_floating_type<T>::value, double, int128_t>::type;
+
+    int64_t* counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
+    double* means = reinterpret_cast<double*>(means_.mutable_data());
+    double* m2s = reinterpret_cast<double*>(m2s_.mutable_data());
+
+    // XXX this uses naive summation; we should switch to pairwise summation as was
+    // done for the scalar aggregate kernel in ARROW-11567
+    std::vector<SumType> sums(num_groups_);
+    auto g = batch[1].array()->GetValues<uint32_t>(1);
+    VisitArrayDataInline<Type>(
+        *batch[0].array(),
+        [&](typename TypeTraits<Type>::CType value) {
+          sums[*g] += value;
+          counts[*g] += 1;
+          ++g;
+        },
+        [&] { ++g; });
+
+    for (int64_t i = 0; i < num_groups_; i++) {
+      means[i] = static_cast<double>(sums[i]) / counts[i];
+    }
+
+    g = batch[1].array()->GetValues<uint32_t>(1);
+    VisitArrayDataInline<Type>(
+        *batch[0].array(),
+        [&](typename TypeTraits<Type>::CType value) {
+          const double v = static_cast<double>(value);
+          m2s[*g] += (v - means[*g]) * (v - means[*g]);
+          ++g;
+        },
+        [&] { ++g; });
+
+    return Status::OK();
+  }
+
+  // int32/16/8: textbook one-pass algorithm with integer arithmetic (see
+  // aggregate_var_std.cc)
+  template <typename T = Type>
+  enable_if_t<is_integer_type<T>::value && (sizeof(CType) <= 4), Status> ConsumeImpl(
+      const ExecBatch& batch) {
+    // max number of elements such that the sum cannot overflow int64
+    // (2Gi int32 elements)
+    // for uint32:    0 <= sum < 2^63 (int64 >= 0)
+    // for int32: -2^62 <= sum < 2^62
+    constexpr int64_t max_length = 1ULL << (63 - sizeof(CType) * 8);
+
+    const auto& array = *batch[0].array();
+    const auto g = batch[1].array()->GetValues<uint32_t>(1);
+
+    std::vector<int64_t> sum(num_groups_);
+    std::vector<int128_t> square_sum(num_groups_);
+
+    ARROW_ASSIGN_OR_RAISE(auto mapping,
+                          AllocateBuffer(num_groups_ * sizeof(uint32_t), pool_));
+    for (uint32_t i = 0; static_cast<int64_t>(i) < num_groups_; i++) {
+      reinterpret_cast<uint32_t*>(mapping->mutable_data())[i] = i;
+    }
+    ArrayData group_id_mapping(uint32(), num_groups_, {nullptr, std::move(mapping)},
+                               /*null_count=*/0);
+
+    const CType* values = array.GetValues<CType>(1);
+
+    for (int64_t start_index = 0; start_index < batch.length;
+         start_index += max_length) {
+      // process in chunks so that overflow can never happen
+
+      // reset state
+      std::fill(sum.begin(), sum.end(), 0);
+      std::fill(square_sum.begin(), square_sum.end(), 0);
+      GroupedVarStdImpl<Type> state;
+      RETURN_NOT_OK(state.Init(ctx_, &options_));
+      RETURN_NOT_OK(state.Resize(num_groups_));
+      int64_t* other_counts = reinterpret_cast<int64_t*>(state.counts_.mutable_data());
+      double* other_means = reinterpret_cast<double*>(state.means_.mutable_data());
+      double* other_m2s = reinterpret_cast<double*>(state.m2s_.mutable_data());
+
+      arrow::internal::VisitSetBitRunsVoid(
+          array.buffers[0], array.offset + start_index,
+          std::min(max_length, batch.length - start_index),
+          [&](int64_t pos, int64_t len) {
+            for (int64_t i = 0; i < len; ++i) {
+              const int64_t index = start_index + pos + i;
+              const auto value = values[index];
+              sum[g[index]] += value;
+              square_sum[g[index]] += static_cast<uint64_t>(value) * value;
+              other_counts[g[index]]++;
+            }
+          });
+
+      for (int64_t i = 0; i < num_groups_; i++) {
+        if (other_counts[i] == 0) continue;
+
+        const double mean = static_cast<double>(sum[i]) / other_counts[i];
+        // calculate m2 = square_sum - sum * sum / count
+        // decompose `sum * sum / count` into integer and fractional parts
+        const int128_t sum_square = static_cast<int128_t>(sum[i]) * sum[i];
+        const int128_t integers = sum_square / other_counts[i];
+        const double fractions =
+            static_cast<double>(sum_square % other_counts[i]) / other_counts[i];
+        const double m2 = static_cast<double>(square_sum[i] - integers) - fractions;
+
+        other_means[i] = mean;
+        other_m2s[i] = m2;
+      }
+      RETURN_NOT_OK(this->Merge(std::move(state), group_id_mapping));
+    }
+    return Status::OK();
+  }
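The trickiest part of the integer path above is the m2 computation: with exact integer accumulators, m2 = square_sum - sum * sum / count, where the quotient is split into an integer part and a fractional part so no precision is lost before the final subtraction. Here is a self-contained sketch of just that decomposition (illustrative only; __int128 is a GCC/Clang builtin standing in for arrow::internal::int128_t). The hunk continues below with Merge().

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  using int128_t = __int128;  // GCC/Clang builtin, stands in for arrow::internal::int128_t

  const std::vector<int32_t> values = {1, 2, 3, 4};  // mean = 2.5
  int64_t sum = 0;
  int128_t square_sum = 0;
  for (int32_t v : values) {
    sum += v;
    square_sum += static_cast<int64_t>(v) * v;
  }
  const int64_t count = static_cast<int64_t>(values.size());

  // m2 = square_sum - sum * sum / count, with the quotient decomposed into an
  // exact integer part and a double fractional part.
  const int128_t sum_square = static_cast<int128_t>(sum) * sum;
  const int128_t integers = sum_square / count;
  const double fractions = static_cast<double>(sum_square % count) / count;
  const double m2 = static_cast<double>(square_sum - integers) - fractions;

  std::cout << "m2 = " << m2 << '\n';  // sum((x - 2.5)^2) = 5
}
```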
+
+  Status Merge(GroupedAggregator&& raw_other,
+               const ArrayData& group_id_mapping) override {
+    // Combine m2 from two chunks (see aggregate_var_std.cc)
+    auto other = checked_cast<GroupedVarStdImpl*>(&raw_other);
+
+    auto counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
+    auto means = reinterpret_cast<double*>(means_.mutable_data());
+    auto m2s = reinterpret_cast<double*>(m2s_.mutable_data());
+
+    const auto* other_counts = reinterpret_cast<const int64_t*>(other->counts_.data());
+    const auto* other_means = reinterpret_cast<const double*>(other->means_.data());
+    const auto* other_m2s = reinterpret_cast<const double*>(other->m2s_.data());
+
+    auto g = group_id_mapping.GetValues<uint32_t>(1);
+    for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
+      if (other_counts[other_g] == 0) continue;
+      const double mean =
+          (means[*g] * counts[*g] + other_means[other_g] * other_counts[other_g]) /
+          (counts[*g] + other_counts[other_g]);
+      m2s[*g] += other_m2s[other_g] +
+                 counts[*g] * (means[*g] - mean) * (means[*g] - mean) +
+                 other_counts[other_g] * (other_means[other_g] - mean) *
+                     (other_means[other_g] - mean);
+      counts[*g] += other_counts[other_g];
+      means[*g] = mean;
+    }

Review comment:
   Another option that Ben's mentioned would be to treat scalar aggregation as a hash aggregation with one group, though then we should immediately tackle the pairwise summation issue.
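For context on that last point: pairwise summation, which the XXX comment above also references (ARROW-11567 adopted it for the scalar aggregate kernels), recursively splits the range and sums the halves, so the accumulated rounding error grows roughly as O(log n) instead of the O(n) of naive left-to-right summation. A minimal sketch of the technique, independent of Arrow's internals (names are illustrative):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Pairwise (cascade) summation: recurse on halves, summing small blocks
// naively to bound the recursion depth.
double PairwiseSum(const double* data, std::size_t n) {
  if (n <= 8) {
    double sum = 0.0;
    for (std::size_t i = 0; i < n; ++i) sum += data[i];
    return sum;
  }
  const std::size_t half = n / 2;
  return PairwiseSum(data, half) + PairwiseSum(data + half, n - half);
}

int main() {
  const std::vector<double> values(1 << 20, 0.1);  // 1048576 copies of 0.1
  std::cout << PairwiseSum(values.data(), values.size()) << '\n';  // ~104857.6
}
```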