This is an automated email from the ASF dual-hosted git repository.
kxiao pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.0 by this push:
new 62e2a74538 [feature](agg) Make 'map_agg' support array type as value
(#22945) (#22991)
62e2a74538 is described below
commit 62e2a745387b70606f982b0734c1a2b6fb0f40da
Author: Jerry Hu <[email protected]>
AuthorDate: Thu Aug 17 15:59:17 2023 +0800
[feature](agg) Make 'map_agg' support array type as value (#22945) (#22991)
---
.../aggregate_functions/aggregate_function_map.h | 243 +++++++++++----------
be/src/vec/columns/column_map.cpp | 6 +-
be/src/vec/exec/vaggregation_node.cpp | 12 +-
.../java/org/apache/doris/catalog/FunctionSet.java | 13 +-
.../data/query_p0/aggregate/map_agg.out | 3 +
.../suites/query_p0/aggregate/map_agg.groovy | 4 +
6 files changed, 150 insertions(+), 131 deletions(-)
diff --git a/be/src/vec/aggregate_functions/aggregate_function_map.h
b/be/src/vec/aggregate_functions/aggregate_function_map.h
index 5901c6eb66..d04f85973b 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_map.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_map.h
@@ -55,7 +55,7 @@ struct AggregateFunctionMapAggData {
_value_column->clear();
}
- void add(const StringRef& key, const StringRef& value) {
+ void add(const StringRef& key, const Field& value) {
DCHECK(key.data != nullptr);
if (UNLIKELY(_map.find(key) != _map.end())) {
return;
@@ -68,7 +68,41 @@ struct AggregateFunctionMapAggData {
_map.emplace(key_holder.key, _key_column->size());
_key_column->insert_data(key_holder.key.data, key_holder.key.size);
- _value_column->insert_data(value.data, value.size);
+ _value_column->insert(value);
+ }
+
+ // Merges a deserialized (keys, values) pair into this aggregation state.
+ //
+ // `key_` and `value` are read as two parallel Arrays of equal length (see
+ // the DCHECK_EQ below). Every key not yet present in `_map` is appended to
+ // `_key_column` together with its value; a key that is already present keeps
+ // the value that was inserted first.
+ void add(const Field& key_, const Field& value) {
+     DCHECK(!key_.is_null());
+     auto key_array = vectorized::get<Array>(key_);
+     auto value_array = vectorized::get<Array>(value);
+
+     const auto count = key_array.size();
+     DCHECK_EQ(count, value_array.size());
+
+     for (size_t i = 0; i != count; ++i) {
+         StringRef key;
+         if constexpr (std::is_same_v<K, String>) {
+             // Bind by reference: a by-value copy would be destroyed at the end
+             // of this branch and leave `key` pointing at freed storage.
+             const auto& string = key_array[i].get<K>();
+             key = string;
+         } else {
+             auto& k = key_array[i].get<KeyType>();
+             key.data = reinterpret_cast<const char*>(&k);
+             key.size = sizeof(k);
+         }
+
+         // Skip only the duplicate entry; the remaining pairs of this array
+         // still have to be merged, so do not leave the function here.
+         if (UNLIKELY(_map.find(key) != _map.end())) {
+             continue;
+         }
+
+         // Copy the key bytes into the arena so the map entry outlives `key_array`.
+         ArenaKeyHolder key_holder {key, _arena};
+         if (key.size > 0) {
+             key_holder_persist_key(key_holder);
+         }
+
+         _map.emplace(key_holder.key, _key_column->size());
+         _key_column->insert_data(key_holder.key.data, key_holder.key.size);
+         _value_column->insert(value_array[i]);
+     }
}
void merge(const AggregateFunctionMapAggData& other) {
@@ -98,65 +132,6 @@ struct AggregateFunctionMapAggData {
}
}
- static void serialize(BufferWritable& buf, const IColumn& key_column,
- const IColumn& value_column, const DataTypePtr&
key_type,
- const DataTypePtr& value_type) {
- size_t element_number = key_column.size();
- write_binary(element_number, buf);
-
- DCHECK(!key_column.is_nullable());
- DCHECK(!key_type->is_nullable());
- DCHECK(value_column.is_nullable());
- DCHECK(value_type->is_nullable());
-
- if (element_number > 0) {
- size_t serialized_size =
key_type->get_uncompressed_serialized_bytes(key_column, 0);
- serialized_size +=
value_type->get_uncompressed_serialized_bytes(value_column, 0);
-
- std::string serialized_buffer;
- serialized_buffer.resize(serialized_size);
- auto* serialized_data = serialized_buffer.data();
-
- serialized_data = key_type->serialize(key_column, serialized_data,
0);
- value_type->serialize(value_column, serialized_data, 0);
-
- write_binary(serialized_size, buf);
- buf.write(serialized_buffer.data(), serialized_buffer.size());
- }
- }
-
- void write(BufferWritable& buf) const {
- serialize(buf, *_key_column, *_value_column, _key_type, _value_type);
- }
-
- void read(BufferReadable& buf) {
- size_t element_number = 0;
- read_binary(element_number, buf);
-
- if (element_number > 0) {
- _map.reserve(element_number);
-
- size_t serialized_size;
- read_binary(serialized_size, buf);
- std::string serialized_buffer;
- serialized_buffer.resize(serialized_size);
-
- buf.read(serialized_buffer.data(), serialized_size);
- const auto* serialized_data = serialized_buffer.data();
- serialized_data = _key_type->deserialize(serialized_data,
_key_column.get(), 0);
- _value_type->deserialize(serialized_data, _value_column.get(), 0);
-
- DCHECK_EQ(element_number, _key_column->size());
- DCHECK_EQ(element_number, _value_column->size());
-
- for (size_t i = 0; i != element_number; ++i) {
- auto key =
static_cast<KeyColumnType&>(*_key_column).get_data_at(i);
- DCHECK(_map.find(key) == _map.cend());
- _map.emplace(key, i);
- }
- }
- }
-
void insert_result_into(IColumn& to) const {
auto& dst = assert_cast<ColumnMap&>(to);
size_t num_rows = _key_column->size();
@@ -211,14 +186,17 @@ public:
if (nullable_map[row_num]) {
return;
}
+ Field value;
+ columns[1]->get(row_num, value);
this->data(place).add(
assert_cast<const
KeyColumnType&>(nullable_col.get_nested_column())
.get_data_at(row_num),
- columns[1]->get_data_at(row_num));
+ value);
} else {
+ Field value;
+ columns[1]->get(row_num, value);
this->data(place).add(
- assert_cast<const
KeyColumnType&>(*columns[0]).get_data_at(row_num),
- columns[1]->get_data_at(row_num));
+ assert_cast<const
KeyColumnType&>(*columns[0]).get_data_at(row_num), value);
}
}
@@ -233,80 +211,107 @@ public:
this->data(place).merge(this->data(rhs));
}
- void serialize(ConstAggregateDataPtr __restrict place, BufferWritable&
buf) const override {
- this->data(place).write(buf);
+ void serialize(ConstAggregateDataPtr /* __restrict place */,
+ BufferWritable& /* buf */) const override {
+ __builtin_unreachable(); // NOTE(review): buffer-based serialization is not implemented; presumably only the column-based overrides of this class are ever invoked — confirm
+ }
+
+ void deserialize(AggregateDataPtr /* __restrict place */, BufferReadable&
/* buf */,
+ Arena*) const override {
+ __builtin_unreachable(); // NOTE(review): buffer-based deserialization is not implemented; presumably only the column-based overrides of this class are ever invoked — confirm
}
- template <bool key_nullable, bool value_nullable>
- void streaming_agg_serialize_to_column_impl(const size_t num_rows, const
IColumn& key_column,
- const IColumn& value_column,
- const NullMap& null_map,
- BufferWritable& writer) const {
- auto& key_col = assert_cast<const KeyColumnType&>(key_column);
- auto key_to_serialize = key_col.clone_empty();
- auto val_to_serialize = value_column.clone_empty();
- auto key_type = remove_nullable(argument_types[0]);
- auto val_type = make_nullable(argument_types[1]);
+ void streaming_agg_serialize_to_column(const IColumn** columns,
MutableColumnPtr& dst,
+ const size_t num_rows, Arena*
arena) const override {
+ auto& col = assert_cast<ColumnMap&>(*dst);
for (size_t i = 0; i != num_rows; ++i) {
- key_to_serialize->clear();
- val_to_serialize->clear();
- if constexpr (key_nullable) {
- if (!null_map[i]) {
- key_to_serialize->insert_range_from(key_col, i, 1);
- val_to_serialize->insert_range_from(value_column, i, 1);
- }
- } else {
- key_to_serialize->insert_range_from(key_col, i, 1);
- val_to_serialize->insert_range_from(value_column, i, 1);
+ Map map(2);
+ columns[0]->get(i, map[0]);
+ if (map[0].is_null()) {
+ continue;
}
+ columns[1]->get(i, map[1]);
+ col.insert(map);
+ }
+ }
- if constexpr (value_nullable) {
- Data::serialize(writer, *key_to_serialize, *val_to_serialize,
key_type, val_type);
- } else {
- auto nullable_value_col =
make_nullable(val_to_serialize->assume_mutable(), false);
- Data::serialize(writer, *key_to_serialize,
*nullable_value_col, key_type, val_type);
- val_to_serialize = value_column.clone_empty();
- }
- writer.commit();
+ // Reads the first `num_rows` rows of the serialized map column and adds the
+ // (keys, values) pair of every row to the state located at `places`.
+ // NOTE(review): all rows are merged into that single state — confirm this
+ // matches what callers of deserialize_from_column expect.
+ void deserialize_from_column(AggregateDataPtr places, const IColumn& column, Arena* arena,
+                              size_t num_rows) const override {
+     const auto& map_column = assert_cast<const ColumnMap&>(column);
+     auto& state = this->data(places);
+     for (size_t row = 0; row < num_rows; ++row) {
+         const auto entry = doris::vectorized::get<Map>(map_column[row]);
+         state.add(entry[0], entry[1]);
}
}
- void streaming_agg_serialize_to_column(const IColumn** columns,
MutableColumnPtr& dst,
- const size_t num_rows, Arena*
arena) const override {
- auto& col = assert_cast<ColumnString&>(*dst);
- col.reserve(num_rows);
- VectorBufferWriter writer(col);
+ // Appends the result of each of the first `num_rows` aggregate states to
+ // `dst`; the state of row i lives at `places[i] + offset`.
+ void serialize_to_column(const std::vector<AggregateDataPtr>& places, size_t offset,
+                          MutableColumnPtr& dst, const size_t num_rows) const override {
+     for (size_t row = 0; row < num_rows; ++row) {
+         this->data(places[row] + offset).insert_result_into(*dst);
+     }
+ }
- if (columns[0]->is_nullable()) {
- auto& nullable_col = assert_cast<const
ColumnNullable&>(*columns[0]);
- auto& null_map = nullable_col.get_null_map_data();
- if (columns[0]->is_nullable()) {
- this->streaming_agg_serialize_to_column_impl<true, true>(
- num_rows, nullable_col.get_nested_column(),
*columns[1], null_map, writer);
- } else {
- this->streaming_agg_serialize_to_column_impl<true, false>(
- num_rows, nullable_col.get_nested_column(),
*columns[1], null_map, writer);
- }
- } else {
- if (columns[0]->is_nullable()) {
- this->streaming_agg_serialize_to_column_impl<false,
true>(num_rows, *columns[0],
-
*columns[1], {}, writer);
- } else {
- this->streaming_agg_serialize_to_column_impl<false,
false>(num_rows, *columns[0],
-
*columns[1], {}, writer);
+ // Merges every row of the serialized map column into the single state at
+ // `place`, adding each row's (keys, values) pair via Data::add.
+ void deserialize_and_merge_from_column(AggregateDataPtr __restrict place,
+                                        const IColumn& column, Arena* arena) const override {
+     const auto& map_column = assert_cast<const ColumnMap&>(column);
+     const size_t row_count = column.size();
+     auto& state = this->data(place);
+     for (size_t row = 0; row < row_count; ++row) {
+         const auto entry = doris::vectorized::get<Map>(map_column[row]);
+         state.add(entry[0], entry[1]);
+     }
+ }
+
+ void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict
place,
+ const IColumn& column, size_t
begin, size_t end,
+ Arena* arena) const override { // Merges rows begin..end of the serialized map column into the single state at `place`.
+ DCHECK(end <= column.size() && begin <= end)
+ << ", begin:" << begin << ", end:" << end << ",
column.size():" << column.size();
+ auto& col = assert_cast<const ColumnMap&>(column);
+ for (size_t i = begin; i <= end; ++i) { // NOTE(review): `end` is inclusive here, yet the DCHECK above still permits end == column.size() — confirm callers never pass that value
+ auto map = doris::vectorized::get<Map>(col[i]); // map[0] holds the keys, map[1] the values of row i
+ this->data(place).add(map[0], map[1]);
+ }
+ }
+
+ // Merges row i of the serialized map column into the aggregate state of
+ // row i, for the first `num_rows` rows.
+ //
+ // The state of row i lives at `places[i] + offset`, the same addressing
+ // serialize_to_column uses; `rhs` and `arena` are not needed because rows
+ // are added straight from the column.
+ void deserialize_and_merge_vec(const AggregateDataPtr* places, size_t offset,
+                                AggregateDataPtr rhs, const ColumnString* column, Arena* arena,
+                                const size_t num_rows) const override {
+     // The interface hands the column over as ColumnString*, but the serialized
+     // type of this function is a map column, hence the cast through IColumn.
+     auto& col = assert_cast<const ColumnMap&>(*assert_cast<const IColumn*>(column));
+     for (size_t i = 0; i != num_rows; ++i) {
+         auto map = doris::vectorized::get<Map>(col[i]);
+         // Apply `offset`; without it the write lands on whatever state sits
+         // at the start of the row's aggregate data.
+         this->data(places[i] + offset).add(map[0], map[1]);
+     }
+ }
+
+ // Same as deserialize_and_merge_vec, but rows whose `places[i]` is null are
+ // skipped.
+ //
+ // The state of a selected row i lives at `places[i] + offset`, the same
+ // addressing serialize_to_column uses; `rhs` and `arena` are not needed
+ // because rows are added straight from the column.
+ void deserialize_and_merge_vec_selected(const AggregateDataPtr* places, size_t offset,
+                                         AggregateDataPtr rhs, const ColumnString* column,
+                                         Arena* arena, const size_t num_rows) const override {
+     // The interface hands the column over as ColumnString*, but the serialized
+     // type of this function is a map column, hence the cast through IColumn.
+     auto& col = assert_cast<const ColumnMap&>(*assert_cast<const IColumn*>(column));
+     for (size_t i = 0; i != num_rows; ++i) {
+         if (places[i]) {
+             auto map = doris::vectorized::get<Map>(col[i]);
+             // Apply `offset`; without it the write lands on whatever state
+             // sits at the start of the row's aggregate data.
+             this->data(places[i] + offset).add(map[0], map[1]);
}
}
}
- void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
- Arena*) const override {
- this->data(place).read(buf);
+ void serialize_without_key_to_column(ConstAggregateDataPtr __restrict
place,
+ IColumn& to) const override {
+ this->data(place).insert_result_into(to); // the serialized type equals the return type (see get_serialized_type), so the final-result writer is reused
}
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn&
to) const override {
this->data(place).insert_result_into(to); // delegates to the state's insert_result_into, which writes the accumulated entries into `to`
}
+ [[nodiscard]] MutableColumnPtr create_serialize_column() const override {
+ return get_return_type()->create_column(); // the intermediate column is created from the function's return type
+ }
+
+ [[nodiscard]] DataTypePtr get_serialized_type() const override { return
get_return_type(); } // intermediate (serialized) type is the same as the return type
+
protected:
using IAggregateFunction::argument_types;
};
diff --git a/be/src/vec/columns/column_map.cpp
b/be/src/vec/columns/column_map.cpp
index ac7c5da1a9..8126b3e4e6 100644
--- a/be/src/vec/columns/column_map.cpp
+++ b/be/src/vec/columns/column_map.cpp
@@ -98,8 +98,6 @@ MutableColumnPtr ColumnMap::clone_resized(size_t to_size)
const {
// to support field functions
Field ColumnMap::operator[](size_t n) const {
- // Map is FieldVector, now we keep key value in seperate , see in field.h
- Map m(2);
size_t start_offset = offset_at(n);
size_t element_size = size_at(n);
@@ -116,9 +114,7 @@ Field ColumnMap::operator[](size_t n) const {
v[i] = get_values()[start_offset + i];
}
- m.push_back(k);
- m.push_back(v);
- return m;
+ return Map {k, v};
}
// here to compare to below
diff --git a/be/src/vec/exec/vaggregation_node.cpp
b/be/src/vec/exec/vaggregation_node.cpp
index 93cc3d97e9..c483e02ffa 100644
--- a/be/src/vec/exec/vaggregation_node.cpp
+++ b/be/src/vec/exec/vaggregation_node.cpp
@@ -685,11 +685,13 @@ Status
AggregationNode::_get_without_key_result(RuntimeState* state, Block* bloc
}
}
- ColumnPtr ptr = std::move(columns[i]);
- // unless `count`, other aggregate function dispose empty set
should be null
- // so here check the children row return
- ptr = make_nullable(ptr, _children[0]->rows_returned() == 0);
- columns[i] = std::move(*ptr).mutate();
+ if (column_type->is_nullable() && !data_types[i]->is_nullable()) {
+ ColumnPtr ptr = std::move(columns[i]);
+ // unless `count`, other aggregate function dispose empty set
should be null
+ // so here check the children row return
+ ptr = make_nullable(ptr, _children[0]->rows_returned() == 0);
+ columns[i] = ptr->assume_mutable();
+ }
}
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
index 2391e1ec84..41a9de9b66 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
@@ -1040,12 +1040,21 @@ public class FunctionSet<T> {
}
if (!Type.JSONB.equals(t)) {
- for (Type valueType : Type.getTrivialTypes()) {
- addBuiltin(AggregateFunction.createBuiltin(MAP_AGG,
Lists.newArrayList(t, valueType), new MapType(t, valueType),
+ for (Type valueType : Type.getMapSubTypes()) {
+ addBuiltin(AggregateFunction.createBuiltin(MAP_AGG,
Lists.newArrayList(t, valueType),
+ new MapType(t, valueType),
Type.VARCHAR,
"", "", "", "", "", null, "",
true, true, false, true));
}
+
+ for (Type v : Type.getArraySubTypes()) {
+ addBuiltin(AggregateFunction.createBuiltin(MAP_AGG,
Lists.newArrayList(t, new ArrayType(v)),
+ new MapType(t, new ArrayType(v)),
+ new MapType(t, new ArrayType(v)),
+ "", "", "", "", "", null, "",
+ true, true, false, true));
+ }
}
if (STDDEV_UPDATE_SYMBOL.containsKey(t)) {
diff --git a/regression-test/data/query_p0/aggregate/map_agg.out
b/regression-test/data/query_p0/aggregate/map_agg.out
index 0b8d5f3be0..62c8ecc101 100644
--- a/regression-test/data/query_p0/aggregate/map_agg.out
+++ b/regression-test/data/query_p0/aggregate/map_agg.out
@@ -20,3 +20,6 @@
4 V4_1 V4_2 V4_3
5 V5_1 V5_2 V5_3
+-- !sql3 --
+{"key":["ab", "efg", NULL]}
+
diff --git a/regression-test/suites/query_p0/aggregate/map_agg.groovy
b/regression-test/suites/query_p0/aggregate/map_agg.groovy
index e779e6061e..2337f2fcea 100644
--- a/regression-test/suites/query_p0/aggregate/map_agg.groovy
+++ b/regression-test/suites/query_p0/aggregate/map_agg.groovy
@@ -168,6 +168,10 @@ suite("map_agg") {
ORDER BY `id`;
"""
+ qt_sql3 """
+ select map_agg(k, v) from (select 'key' as k, array('ab', 'efg', null)
v) a;
+ """
+
sql "DROP TABLE `test_map_agg`"
sql "DROP TABLE `test_map_agg_nullable`"
sql "DROP TABLE `test_map_agg_numeric_key`"
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]