This is an automated email from the ASF dual-hosted git repository.
gabriellee pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new ee4196d9d23 [Improvement](agg) Improve count distinct distribute keys
(#33167)
ee4196d9d23 is described below
commit ee4196d9d23c252fc35a287c04adf9d705e7637f
Author: Gabriel <[email protected]>
AuthorDate: Fri Apr 26 18:31:11 2024 +0800
[Improvement](agg) Improve count distinct distribute keys (#33167)
---
.../aggregate_function_simple_factory.cpp | 2 +
.../aggregate_functions/aggregate_function_uniq.h | 2 +-
.../aggregate_function_uniq_distribute_key.cpp | 73 ++++++
.../aggregate_function_uniq_distribute_key.h | 253 +++++++++++++++++++++
4 files changed, 329 insertions(+), 1 deletion(-)
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index 00597b212be..d95d0ce6ccb 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -40,6 +40,7 @@ void
register_aggregate_function_count(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_count_by_enum(AggregateFunctionSimpleFactory&
factory);
void register_aggregate_function_HLL_union_agg(AggregateFunctionSimpleFactory&
factory);
void register_aggregate_function_uniq(AggregateFunctionSimpleFactory& factory);
+void
register_aggregate_function_uniq_distribute_key(AggregateFunctionSimpleFactory&
factory);
void register_aggregate_function_bit(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_bitmap(AggregateFunctionSimpleFactory&
factory);
void
register_aggregate_function_quantile_state(AggregateFunctionSimpleFactory&
factory);
@@ -80,6 +81,7 @@ AggregateFunctionSimpleFactory&
AggregateFunctionSimpleFactory::instance() {
register_aggregate_function_count(instance);
register_aggregate_function_count_by_enum(instance);
register_aggregate_function_uniq(instance);
+ register_aggregate_function_uniq_distribute_key(instance);
register_aggregate_function_bit(instance);
register_aggregate_function_bitmap(instance);
register_aggregate_function_group_array_intersect(instance);
diff --git a/be/src/vec/aggregate_functions/aggregate_function_uniq.h
b/be/src/vec/aggregate_functions/aggregate_function_uniq.h
index 2e8855134eb..58abd3842c2 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_uniq.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_uniq.h
@@ -75,7 +75,7 @@ struct AggregateFunctionUniqExactData {
Set set;
- static String get_name() { return "uniqExact"; }
+ static String get_name() { return "multi_distinct"; }
};
namespace detail {
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_uniq_distribute_key.cpp
b/be/src/vec/aggregate_functions/aggregate_function_uniq_distribute_key.cpp
new file mode 100644
index 00000000000..3bf979483b5
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_uniq_distribute_key.cpp
@@ -0,0 +1,73 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "vec/aggregate_functions/aggregate_function_uniq_distribute_key.h"
+
+#include <string>
+
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/factory_helpers.h"
+#include "vec/aggregate_functions/helpers.h"
+
+namespace doris::vectorized {
+
+// Creates the distribute-key distinct-count aggregate function for one argument.
+//
+// creator_with_numeric_type is tried first; if it yields no function, the
+// decimal and string argument types are dispatched explicitly below.
+//
+// name:               function name passed by the factory (not used here).
+// argument_types:     argument types; only a single argument is supported.
+// result_is_nullable: forwarded unchanged to the creator helpers.
+//
+// Returns the created function, or nullptr when the number of arguments is not
+// one or the (nullable-stripped) argument type matches none of the branches.
+template <template <typename> class Data>
+AggregateFunctionPtr create_aggregate_function_uniq(const std::string& name,
+                                                    const DataTypes& argument_types,
+                                                    const bool result_is_nullable) {
+    if (argument_types.size() != 1) {
+        return nullptr;
+    }
+
+    // Dispatch on the underlying type, ignoring a Nullable wrapper.
+    const IDataType& argument_type = *remove_nullable(argument_types[0]);
+    WhichDataType which(argument_type);
+
+    AggregateFunctionPtr res(
+            creator_with_numeric_type::create<AggregateFunctionUniqDistributeKey, Data>(
+                    argument_types, result_is_nullable));
+    if (res) {
+        return res;
+    }
+    if (which.is_decimal32()) {
+        return creator_without_type::create<
+                AggregateFunctionUniqDistributeKey<Decimal32, Data<Int32>>>(argument_types,
+                                                                            result_is_nullable);
+    }
+    if (which.is_decimal64()) {
+        return creator_without_type::create<
+                AggregateFunctionUniqDistributeKey<Decimal64, Data<Int64>>>(argument_types,
+                                                                            result_is_nullable);
+    }
+    if (which.is_decimal128v3()) {
+        return creator_without_type::create<
+                AggregateFunctionUniqDistributeKey<Decimal128V3, Data<Int128>>>(
+                argument_types, result_is_nullable);
+    }
+    // Decimal128V3 has already been handled above, so only V2 can reach this
+    // branch; testing for V3 again here would be dead code.
+    if (which.is_decimal128v2()) {
+        return creator_without_type::create<
+                AggregateFunctionUniqDistributeKey<Decimal128V2, Data<Int128>>>(
+                argument_types, result_is_nullable);
+    }
+    if (which.is_string_or_fixed_string()) {
+        return creator_without_type::create<
+                AggregateFunctionUniqDistributeKey<String, Data<String>>>(argument_types,
+                                                                          result_is_nullable);
+    }
+
+    return nullptr;
+}
+
+// Registers the distribute-key distinct-count creator with `factory` under the
+// name "multi_distinct_count_distribute_key" via register_function_both.
+void register_aggregate_function_uniq_distribute_key(AggregateFunctionSimpleFactory& factory) {
+    // Instantiate the creator for the distribute-key state type.
+    const AggregateFunctionCreator make_uniq_distribute_key =
+            create_aggregate_function_uniq<AggregateFunctionUniqDistributeKeyData>;
+    factory.register_function_both("multi_distinct_count_distribute_key",
+                                   make_uniq_distribute_key);
+}
+
+} // namespace doris::vectorized
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_uniq_distribute_key.h
b/be/src/vec/aggregate_functions/aggregate_function_uniq_distribute_key.h
new file mode 100644
index 00000000000..0fa66e34230
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_uniq_distribute_key.h
@@ -0,0 +1,253 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+//
https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionUniq.h
+// and modified by Doris
+
+#pragma once
+
+#include <stddef.h>
+
+#include <algorithm>
+#include <boost/iterator/iterator_facade.hpp>
+#include <memory>
+#include <vector>
+
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_uniq.h"
+#include "vec/columns/column.h"
+#include "vec/columns/column_fixed_length_object.h"
+#include "vec/columns/column_nullable.h"
+#include "vec/columns/column_vector.h"
+#include "vec/columns/columns_number.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_fixed_length_object.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/io/var_int.h"
+
+namespace doris {
+namespace vectorized {
+class Arena;
+class BufferReadable;
+class BufferWritable;
+} // namespace vectorized
+} // namespace doris
+template <typename T>
+struct HashCRC32;
+namespace doris::vectorized {
+
+// Aggregation state for the distribute-key distinct count.
+//
+// `set` holds the distinct keys and `count` is a separate counter that starts
+// at zero. For String arguments the key is not the string itself but a
+// 128-bit hash of its bytes (see get_key); for every other T the key is T.
+template <typename T>
+struct AggregateFunctionUniqDistributeKeyData {
+    static constexpr bool is_string_key = std::is_same_v<T, String>;
+
+    // Strings are keyed by their 128-bit hash, which is used as-is by
+    // UInt128TrivialHash; all other keys are hashed with HashCRC32.
+    using Key = std::conditional_t<is_string_key, UInt128, T>;
+    using Hash = std::conditional_t<is_string_key, UInt128TrivialHash, HashCRC32<Key>>;
+    using Set = flat_hash_set<Key, Hash>;
+
+    Set set;
+    UInt64 count = 0;
+
+    // Hashes `value.size` bytes at `value.data` with XXH128 (third argument 0)
+    // and packs the high and low 64-bit halves of the result into a UInt128.
+    static UInt128 ALWAYS_INLINE get_key(const StringRef& value) {
+        const auto digest = XXH_INLINE_XXH128(value.data, value.size, 0);
+        return UInt128 {digest.high64, digest.low64};
+    }
+};
+
+// Distinct-count aggregate whose intermediate state is a single UInt64.
+//
+// The add/add_batch paths insert keys into `Data::set`. Every serialization
+// path writes only `set.size()`, every deserialization path reads that number
+// into `Data::count`, merging sums the counts, and the final result is `count`.
+//
+// NOTE(review): summing per-state set sizes is only a correct distinct count if
+// the same key never reaches two states that are later merged; presumably the
+// planner guarantees this by distributing rows on the distinct key - confirm.
+// NOTE(review): insert_result_into() emits `count`, not `set.size()`, so a state
+// that was only fed through add() and never serialized/merged reports 0 -
+// confirm this function is only finalized after a merge step.
+// NOTE(review): get_name() returns "multi_distinct_distribute_key" while the
+// factory registers "multi_distinct_count_distribute_key" - confirm intended.
+template <typename T, typename Data>
+class AggregateFunctionUniqDistributeKey final
+        : public IAggregateFunctionDataHelper<Data, AggregateFunctionUniqDistributeKey<T, Data>> {
+public:
+    // Key stored in the set: a 128-bit hash for strings, otherwise T itself.
+    using KeyType = std::conditional_t<std::is_same_v<T, String>, UInt128, T>;
+
+    AggregateFunctionUniqDistributeKey(const DataTypes& argument_types_)
+            : IAggregateFunctionDataHelper<Data, AggregateFunctionUniqDistributeKey<T, Data>>(
+                      argument_types_) {}
+
+    String get_name() const override { return "multi_distinct_distribute_key"; }
+
+    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); }
+
+    // Adds row `row_num` of the first argument column to the state at `place`.
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
+             Arena*) const override {
+        detail::OneAdder<T, Data>::add(this->data(place), *columns[0], row_num);
+    }
+
+    // Returns a pointer to `batch_size` keys for `column`.
+    //
+    // For String, each row is hashed with Data::get_key into `keys_container`
+    // and the returned pointer refers to that vector, so it is valid only while
+    // `keys_container` is alive and unmodified. For other types the column's own
+    // data array is returned and `keys_container` is left untouched.
+    static ALWAYS_INLINE const KeyType* get_keys(std::vector<KeyType>& keys_container,
+                                                 const IColumn& column, size_t batch_size) {
+        if constexpr (std::is_same_v<T, String>) {
+            keys_container.resize(batch_size);
+            for (size_t i = 0; i != batch_size; ++i) {
+                StringRef value = column.get_data_at(i);
+                keys_container[i] = Data::get_key(value);
+            }
+            return keys_container.data();
+        } else {
+            using ColumnType =
+                    std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
+            return assert_cast<const ColumnType&>(column).get_data().data();
+        }
+    }
+
+    // Inserts row i of the first argument column into the set of the state at
+    // `places[i] + place_offset`, for every i in [0, batch_size).
+    void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
+                   const IColumn** columns, Arena* arena, bool /*agg_many*/) const override {
+        std::vector<KeyType> keys_container;
+        const KeyType* keys = get_keys(keys_container, *columns[0], batch_size);
+
+        // Resolve every row's target set up front so the insert loop can
+        // prefetch for a later row without recomputing its address.
+        std::vector<typename Data::Set*> array_of_data_set(batch_size);
+
+        for (size_t i = 0; i != batch_size; ++i) {
+            array_of_data_set[i] = &(this->data(places[i] + place_offset).set);
+        }
+
+        for (size_t i = 0; i != batch_size; ++i) {
+            // Prefetch for the row HASH_MAP_PREFETCH_DIST ahead of this one.
+            if (i + HASH_MAP_PREFETCH_DIST < batch_size) {
+                array_of_data_set[i + HASH_MAP_PREFETCH_DIST]->prefetch(
+                        keys[i + HASH_MAP_PREFETCH_DIST]);
+            }
+
+            array_of_data_set[i]->insert(keys[i]);
+        }
+    }
+
+    // Inserts the first `batch_size` rows of the first argument column into the
+    // set of the single state at `place`.
+    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
+                                Arena* arena) const override {
+        std::vector<KeyType> keys_container;
+        const KeyType* keys = get_keys(keys_container, *columns[0], batch_size);
+        auto& set = this->data(place).set;
+
+        for (size_t i = 0; i != batch_size; ++i) {
+            if (i + HASH_MAP_PREFETCH_DIST < batch_size) {
+                set.prefetch(keys[i + HASH_MAP_PREFETCH_DIST]);
+            }
+            set.insert(keys[i]);
+        }
+    }
+
+    // Adds the counter of `rhs` to the counter of `place`; the sets are not merged.
+    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+               Arena*) const override {
+        this->data(place).count += this->data(rhs).count;
+    }
+
+    // Writes the number of distinct keys in the set as a var-uint.
+    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
+        write_var_uint(this->data(place).set.size(), buf);
+    }
+
+    // Reads a var-uint into the counter; the set is not touched.
+    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+                     Arena*) const override {
+        read_var_uint(this->data(place).count, buf);
+    }
+
+    // Appends the counter to the Int64 result column `to`.
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
+        assert_cast<ColumnInt64&>(to).get_data().push_back(this->data(place).count);
+    }
+
+    // Creates `num_rows` states laid out contiguously from `places`
+    // (sizeof(Data) apart) and sets each counter from the matching UInt64 of
+    // the fixed-length-object `column`.
+    void deserialize_from_column(AggregateDataPtr places, const IColumn& column, Arena* arena,
+                                 size_t num_rows) const override {
+        auto data = reinterpret_cast<const UInt64*>(
+                assert_cast<const ColumnFixedLengthObject&>(column).get_data().data());
+        for (size_t i = 0; i != num_rows; ++i) {
+            auto rhs_place = places + sizeof(Data) * i;
+            this->create(rhs_place);
+            (reinterpret_cast<Data*>(rhs_place))->count = data[i];
+        }
+    }
+
+    // Resizes `dst` to `num_rows` UInt64 items and stores the set size of the
+    // state at `places[i] + offset` in item i.
+    void serialize_to_column(const std::vector<AggregateDataPtr>& places, size_t offset,
+                             MutableColumnPtr& dst, const size_t num_rows) const override {
+        auto& col = assert_cast<ColumnFixedLengthObject&>(*dst);
+        CHECK(col.item_size() == sizeof(UInt64))
+                << "size is not equal: " << col.item_size() << " " << sizeof(UInt64);
+        col.resize(num_rows);
+        auto* data = reinterpret_cast<UInt64*>(col.get_data().data());
+        for (size_t i = 0; i != num_rows; ++i) {
+            data[i] = this->data(places[i] + offset).set.size();
+        }
+    }
+
+    // Resizes `dst` to `num_rows` UInt64 items and writes 1 into each of them,
+    // without reading `columns`.
+    void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst,
+                                           const size_t num_rows, Arena* arena) const override {
+        auto& dst_col = assert_cast<ColumnFixedLengthObject&>(*dst);
+        CHECK(dst_col.item_size() == sizeof(UInt64))
+                << "size is not equal: " << dst_col.item_size() << " " << sizeof(UInt64);
+        dst_col.resize(num_rows);
+        auto* data = reinterpret_cast<UInt64*>(dst_col.get_data().data());
+        for (size_t i = 0; i != num_rows; ++i) {
+            data[i] = 1;
+        }
+    }
+
+    // Adds every UInt64 of the fixed-length-object `column` to the counter at `place`.
+    void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column,
+                                           Arena* arena) const override {
+        auto& col = assert_cast<const ColumnFixedLengthObject&>(column);
+        const size_t num_rows = column.size();
+        auto* data = reinterpret_cast<const UInt64*>(col.get_data().data());
+        for (size_t i = 0; i != num_rows; ++i) {
+            AggregateFunctionUniqDistributeKey::data(place).count += data[i];
+        }
+    }
+
+    // Adds rows `begin` through `end` of `column`, both inclusive, to the
+    // counter at `place`.
+    void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place,
+                                                 const IColumn& column, size_t begin, size_t end,
+                                                 Arena* arena) const override {
+        // `end` is read by the loop below, so it must be a valid row index;
+        // allowing end == column.size() would read one element past the column.
+        CHECK(end < column.size() && begin <= end)
+                << ", begin:" << begin << ", end:" << end << ", column.size():" << column.size();
+        auto& col = assert_cast<const ColumnFixedLengthObject&>(column);
+        auto* data = reinterpret_cast<const UInt64*>(col.get_data().data());
+        for (size_t i = begin; i <= end; ++i) {
+            this->data(place).count += data[i];
+        }
+    }
+
+    // Deserializes `num_rows` temporary states into `rhs`, merges them into
+    // `places` with merge_vec, and destroys the temporaries on scope exit.
+    //
+    // NOTE(review): `column` is typed ColumnString but deserialize_from_column
+    // assert_casts it to ColumnFixedLengthObject - confirm this path is used.
+    void deserialize_and_merge_vec(const AggregateDataPtr* places, size_t offset,
+                                   AggregateDataPtr rhs, const ColumnString* column, Arena* arena,
+                                   const size_t num_rows) const override {
+        this->deserialize_from_column(rhs, *column, arena, num_rows);
+        DEFER({ this->destroy_vec(rhs, num_rows); });
+        this->merge_vec(places, offset, rhs, arena, num_rows);
+    }
+
+    // Same as deserialize_and_merge_vec but merges with merge_vec_selected.
+    //
+    // NOTE(review): the ColumnString / ColumnFixedLengthObject mismatch noted
+    // on deserialize_and_merge_vec applies here as well.
+    void deserialize_and_merge_vec_selected(const AggregateDataPtr* places, size_t offset,
+                                            AggregateDataPtr rhs, const ColumnString* column,
+                                            Arena* arena, const size_t num_rows) const override {
+        this->deserialize_from_column(rhs, *column, arena, num_rows);
+        DEFER({ this->destroy_vec(rhs, num_rows); });
+        this->merge_vec_selected(places, offset, rhs, arena, num_rows);
+    }
+
+    // Appends one UInt64 item holding the set size of `place` to `to`.
+    void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
+                                         IColumn& to) const override {
+        auto& col = assert_cast<ColumnFixedLengthObject&>(to);
+        CHECK(col.item_size() == sizeof(UInt64))
+                << "size is not equal: " << col.item_size() << " " << sizeof(UInt64);
+        size_t old_size = col.size();
+        col.resize(old_size + 1);
+        *reinterpret_cast<UInt64*>(col.get_data().data() + old_size) =
+                AggregateFunctionUniqDistributeKey::data(place).set.size();
+    }
+
+    // The serialized column stores one UInt64 per state.
+    MutableColumnPtr create_serialize_column() const override {
+        return ColumnFixedLengthObject::create(sizeof(UInt64));
+    }
+
+    DataTypePtr get_serialized_type() const override {
+        return std::make_shared<DataTypeFixedLengthObject>();
+    }
+};
+
+} // namespace doris::vectorized
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]