github-actions[bot] commented on code in PR #64774: URL: https://github.com/apache/doris/pull/64774#discussion_r3468566674
########## fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/SumMap.java: ########## @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.agg; + +import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.BooleanType; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DecimalV2Type; +import org.apache.doris.nereids.types.DecimalV3Type; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.FloatType; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.NullType; +import org.apache.doris.nereids.types.TinyIntType; +import org.apache.doris.qe.ConnectContext; + +import com.google.common.base.Preconditions; + +import java.util.List; + +/** AggregateFunction 'sum_map'. 
*/ +public class SumMap extends MapAggregateFunction { Review Comment: `sum_map` now has the same session-dependent decimal precision behavior as `sum`: `enable_decimal256` changes both the decimal argument cast and the return value precision. Persisted SQL objects only preserve creation-time session variables when the expression implements `NeedSessionVarGuard`, so a view or generated column created with `enable_decimal256=true` can later be re-analyzed with `enable_decimal256=false`, silently changing the `MAP<...,DECIMAL256>` value type and cast back to DECIMAL128. Please mark this class, and `AvgMap`, which reads the same session variable, with `NeedSessionVarGuard`, and add tests covering persisted session variables. ########## be/src/exprs/aggregate/aggregate_function_map_combinator.cpp: ########## @@ -0,0 +1,439 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "exprs/aggregate/aggregate_function_map_combinator.h" + +#include "agent/be_exec_version_manager.h" +#include "core/call_on_type_index.h" +#include "core/column/column_const.h" +#include "core/column/column_decimal.h" +#include "core/column/column_map.h" +#include "core/column/column_nullable.h" +#include "core/column/column_string.h" +#include "core/column/column_vector.h" +#include "core/data_type/data_type_map.h" +#include "core/data_type/data_type_nullable.h" +#include "core/data_type/data_type_string.h" +#include "core/string_buffer.hpp" +#include "exprs/aggregate/aggregate_function_simple_factory.h" +#include "exprs/aggregate/factory_helpers.h" +#include "exprs/aggregate/helpers.h" +#include "exprs/function/function_helpers.h" + +namespace doris { +namespace { + +std::string nested_function_name(const std::string& name) { + if (name == "sum_map") { + return "sum"; + } + if (name == "avg_map") { + return "avg"; + } + if (name == "min_map") { + return "min"; + } + if (name == "max_map") { + return "max"; + } + if (name == "count_map") { + return "count"; + } + throw Exception(ErrorCode::INTERNAL_ERROR, "Unknown map aggregate function {}", name); +} + +const DataTypeMap* get_map_type(const DataTypePtr& type) { + return type ? 
check_and_get_data_type<DataTypeMap>(remove_nullable(type).get()) : nullptr; +} + +DataTypePtr result_value_type_or_argument_value_type(const DataTypePtr& result_type, + const DataTypeMap& argument_map_type) { + const auto* result_map_type = get_map_type(result_type); + if (result_map_type != nullptr) { + return result_map_type->get_value_type(); + } + return argument_map_type.get_value_type(); +} + +AggregateFunctionPtr create_nested_function(const std::string& name, const DataTypeMap& map_type, + const DataTypePtr& result_value_type, + const AggregateFunctionAttr& attr) { + DataTypes nested_argument_types {map_type.get_value_type()}; + DataTypePtr nested_result_type = remove_nullable(result_value_type); + const bool nested_result_is_nullable = result_value_type->is_nullable(); + + auto nested_function = AggregateFunctionSimpleFactory::instance().get( + nested_function_name(name), nested_argument_types, nested_result_type, + nested_result_is_nullable, BeExecVersionManager::get_newest_version(), attr); + if (nested_function == nullptr) { + throw Exception(ErrorCode::INTERNAL_ERROR, + "Can not create nested aggregate function for {}", name); + } + return nested_function; +} + +template <PrimitiveType KeyType> +struct AggregateFunctionMapCombinatorDataTyped { + using KeyColumnType = typename PrimitiveTypeTraits<KeyType>::ColumnType; + using Key = typename PrimitiveTypeTraits<KeyType>::CppType; + using MapKey = std::conditional_t<is_string_type(KeyType), StringRef, Key>; + using Map = flat_hash_map<MapKey, size_t>; Review Comment: This fast path uses `flat_hash_map<MapKey, size_t>` with the default `phmap::Hash<MapKey>`. Because `dispatch_switch_all()` below includes `TYPE_LARGEINT`, this template is instantiated with `MapKey = PrimitiveTypeTraits<TYPE_LARGEINT>::CppType`, i.e. raw `__int128_t`. 
Generic `phmap::Hash<T>` falls back to `std::hash<T>()`, but libstdc++ does not provide `std::hash<__int128_t>` and Doris only specializes hashes for decimal wrappers/`UInt128`, not the raw `Int128` alias. Please add an explicit hash for this key type or keep `TYPE_LARGEINT` out of this typed path until it is supported. ########## be/src/exprs/aggregate/aggregate_function_map_combinator.cpp: ########## @@ -0,0 +1,439 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "exprs/aggregate/aggregate_function_map_combinator.h" + +#include "agent/be_exec_version_manager.h" +#include "core/call_on_type_index.h" +#include "core/column/column_const.h" +#include "core/column/column_decimal.h" +#include "core/column/column_map.h" +#include "core/column/column_nullable.h" +#include "core/column/column_string.h" +#include "core/column/column_vector.h" +#include "core/data_type/data_type_map.h" +#include "core/data_type/data_type_nullable.h" +#include "core/data_type/data_type_string.h" +#include "core/string_buffer.hpp" +#include "exprs/aggregate/aggregate_function_simple_factory.h" +#include "exprs/aggregate/factory_helpers.h" +#include "exprs/aggregate/helpers.h" +#include "exprs/function/function_helpers.h" + +namespace doris { +namespace { + +std::string nested_function_name(const std::string& name) { + if (name == "sum_map") { + return "sum"; + } + if (name == "avg_map") { + return "avg"; + } + if (name == "min_map") { + return "min"; + } + if (name == "max_map") { + return "max"; + } + if (name == "count_map") { + return "count"; + } + throw Exception(ErrorCode::INTERNAL_ERROR, "Unknown map aggregate function {}", name); +} + +const DataTypeMap* get_map_type(const DataTypePtr& type) { + return type ? 
check_and_get_data_type<DataTypeMap>(remove_nullable(type).get()) : nullptr; +} + +DataTypePtr result_value_type_or_argument_value_type(const DataTypePtr& result_type, + const DataTypeMap& argument_map_type) { + const auto* result_map_type = get_map_type(result_type); + if (result_map_type != nullptr) { + return result_map_type->get_value_type(); + } + return argument_map_type.get_value_type(); +} + +AggregateFunctionPtr create_nested_function(const std::string& name, const DataTypeMap& map_type, + const DataTypePtr& result_value_type, + const AggregateFunctionAttr& attr) { + DataTypes nested_argument_types {map_type.get_value_type()}; + DataTypePtr nested_result_type = remove_nullable(result_value_type); + const bool nested_result_is_nullable = result_value_type->is_nullable(); + + auto nested_function = AggregateFunctionSimpleFactory::instance().get( + nested_function_name(name), nested_argument_types, nested_result_type, + nested_result_is_nullable, BeExecVersionManager::get_newest_version(), attr); + if (nested_function == nullptr) { + throw Exception(ErrorCode::INTERNAL_ERROR, + "Can not create nested aggregate function for {}", name); + } + return nested_function; +} + +template <PrimitiveType KeyType> +struct AggregateFunctionMapCombinatorDataTyped { + using KeyColumnType = typename PrimitiveTypeTraits<KeyType>::ColumnType; + using Key = typename PrimitiveTypeTraits<KeyType>::CppType; + using MapKey = std::conditional_t<is_string_type(KeyType), StringRef, Key>; + using Map = flat_hash_map<MapKey, size_t>; + + AggregateFunctionMapCombinatorDataTyped() = default; + + explicit AggregateFunctionMapCombinatorDataTyped(DataTypePtr key_type_) + : key_type(std::move(key_type_)), + key_column(key_type->create_column()), + key_arena(std::make_unique<Arena>()) {} + + void clear() { + key_to_index.clear(); + places.clear(); + key_column->clear(); + has_null_key = false; + null_index = 0; + key_arena = std::make_unique<Arena>(); + } + + MapKey stable_key(MapKey key) { + 
if constexpr (is_string_type(KeyType)) { + // Aggregation states outlive input columns, so StringRef keys kept in the hash map + // must point to state-owned memory instead of ColumnString storage. + key.data = key_arena->insert(key.data, key.size); + } + return key; + } + + DataTypePtr key_type; + MutableColumnPtr key_column; + Map key_to_index; + std::vector<AggregateDataPtr> places; + bool has_null_key = false; + size_t null_index = 0; + std::unique_ptr<Arena> key_arena; +}; + +template <PrimitiveType KeyType> +class AggregateFunctionMapCombinatorTyped final + : public IAggregateFunctionDataHelper<AggregateFunctionMapCombinatorDataTyped<KeyType>, + AggregateFunctionMapCombinatorTyped<KeyType>>, + UnaryExpression, + NotNullableAggregateFunction { +public: + using Data = AggregateFunctionMapCombinatorDataTyped<KeyType>; + using Base = IAggregateFunctionDataHelper<Data, AggregateFunctionMapCombinatorTyped<KeyType>>; + using KeyColumnType = typename Data::KeyColumnType; + using MapKey = typename Data::MapKey; + + AggregateFunctionMapCombinatorTyped(std::string name_, AggregateFunctionPtr nested_function_, + const DataTypes& argument_types_) + : Base(argument_types_), + _name(std::move(name_)), + _nested_function(std::move(nested_function_)) { + const auto* map_type = + assert_cast<const DataTypeMap*>(remove_nullable(this->argument_types[0]).get()); + _key_type = make_nullable(map_type->get_key_type()); + _value_type = make_nullable(_nested_function->get_return_type()); + } + + void set_version(const int version_) override { + Base::set_version(version_); + _nested_function->set_version(version_); + } + + String get_name() const override { return _name; } + + DataTypePtr get_return_type() const override { + return std::make_shared<DataTypeMap>(_key_type, _value_type); + } + + void create(AggregateDataPtr __restrict place) const override { new (place) Data(_key_type); } + + void destroy(AggregateDataPtr __restrict place) const noexcept override { + auto& data_ = 
this->data(place); + destroy_nested_places(data_); + data_.~Data(); + } + + void reset(AggregateDataPtr place) const override { + auto& data_ = this->data(place); + destroy_nested_places(data_); + data_.clear(); + } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena& arena) const override { + const IColumn* column = columns[0]; + ssize_t actual_row = row_num; + if (const auto* const_column = check_and_get_column<ColumnConst>(*column)) { + column = &const_column->get_data_column(); + actual_row = 0; + } + + const auto& map_column = assert_cast<const ColumnMap&>(*column); + const auto& offsets = map_column.get_offsets(); + const size_t offset = actual_row == 0 ? 0 : offsets[actual_row - 1]; + const size_t size = offsets[actual_row] - offset; + const auto& key_column = map_column.get_keys(); + const auto& value_column = map_column.get_values(); + const IColumn* nested_columns[1] = {&value_column}; + + auto& data_ = this->data(place); + for (size_t i = 0; i != size; ++i) { + const size_t row = offset + i; + AggregateDataPtr nested_place = get_or_create_place(data_, key_column, row, arena); + _nested_function->add(nested_place, nested_columns, row, arena); + } + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena& arena) const override { + auto& data_ = this->data(place); + const auto& rhs_data = this->data(rhs); + DCHECK_EQ(rhs_data.key_column->size(), rhs_data.places.size()); + + for (size_t i = 0; i != rhs_data.places.size(); ++i) { + AggregateDataPtr nested_place = + get_or_create_place(data_, *rhs_data.key_column, i, arena); + _nested_function->merge(nested_place, rhs_data.places[i], arena); + } + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + const auto& data_ = this->data(place); + DCHECK_EQ(data_.key_column->size(), data_.places.size()); + + auto serialized_bytes = + 
_key_type->get_uncompressed_serialized_bytes(*data_.key_column, this->version); + buf.write_var_uint(serialized_bytes); + buf.resize(serialized_bytes); + auto* buf_ptr = _key_type->serialize(*data_.key_column, buf.data(), this->version); + DCHECK_EQ(buf_ptr, buf.data() + serialized_bytes); + buf.add_offset(serialized_bytes); + + for (auto* nested_place : data_.places) { + _nested_function->serialize(nested_place, buf); + } + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena& arena) const override { + auto& data_ = this->data(place); + DCHECK(data_.places.empty()); + + std::string serialized_buffer; + buf.read_binary(serialized_buffer); + const auto* ptr = + _key_type->deserialize(serialized_buffer.data(), &data_.key_column, this->version); + DCHECK_EQ(ptr - serialized_buffer.data(), serialized_buffer.size()); + + const size_t size = data_.key_column->size(); + data_.places.reserve(size); + for (size_t i = 0; i != size; ++i) { + AggregateDataPtr nested_place = create_nested_place(arena); + _nested_function->deserialize(nested_place, buf, arena); + data_.places.push_back(nested_place); + add_key_index(data_, *data_.key_column, i, i); + } + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + auto& map_column = assert_cast<ColumnMap&>(to); + auto& key_column = map_column.get_keys(); + auto& value_column = map_column.get_values(); + auto& offsets = map_column.get_offsets(); + + const auto& data_ = this->data(place); + DCHECK_EQ(data_.key_column->size(), data_.places.size()); + for (size_t i = 0; i != data_.places.size(); ++i) { + key_column.insert_from(*data_.key_column, i); + insert_nested_result(data_.places[i], value_column); + } + + offsets.push_back(value_column.size()); + } + +private: + AggregateDataPtr create_nested_place(Arena& arena) const { + auto* nested_place = arena.aligned_alloc(_nested_function->size_of_data(), + _nested_function->align_of_data()); + 
_nested_function->create(nested_place); + return nested_place; + } + + void destroy_nested_places(Data& data) const noexcept { + for (auto* nested_place : data.places) { + _nested_function->destroy(nested_place); + } + } + + static bool is_null_key(const IColumn& key_column, size_t row) { + if (const auto* nullable_column = check_and_get_column<ColumnNullable>(key_column)) { + return nullable_column->is_null_at(row); + } + return false; + } + + static const IColumn& nested_key_column(const IColumn& key_column) { + if (const auto* nullable_column = check_and_get_column<ColumnNullable>(key_column)) { + return nullable_column->get_nested_column(); + } + return key_column; + } + + static MapKey get_key(const IColumn& key_column, size_t row) { + const auto& typed_column = + assert_cast<const KeyColumnType&, TypeCheckOnRelease::DISABLE>(key_column); + if constexpr (is_string_type(KeyType)) { + return typed_column.get_data_at(row); + } else { + return typed_column.get_data()[row]; + } + } + + static void insert_key_from(Data& data_, const IColumn& key_column, size_t row) { + auto& nullable_column = + assert_cast<ColumnNullable&, TypeCheckOnRelease::DISABLE>(*data_.key_column); + if (is_null_key(key_column, row)) { + nullable_column.insert_default(); + return; + } + + nullable_column.get_nested_column().insert_from(nested_key_column(key_column), row); + nullable_column.get_null_map_data().push_back(0); + } + + void add_key_index(Data& data_, const IColumn& key_column, size_t row, size_t index) const { + if (is_null_key(key_column, row)) { + DCHECK(!data_.has_null_key); + data_.has_null_key = true; + data_.null_index = index; + return; + } + + auto key = get_key(nested_key_column(key_column), row); + key = data_.stable_key(key); + data_.key_to_index.emplace(key, index); + } + + AggregateDataPtr get_or_create_place(Data& data_, const IColumn& key_column, size_t row, + Arena& arena) const { + if (is_null_key(key_column, row)) { + if (data_.has_null_key) { + return 
data_.places[data_.null_index]; + } + + AggregateDataPtr nested_place = create_nested_place(arena); + data_.null_index = data_.places.size(); + data_.has_null_key = true; + insert_key_from(data_, key_column, row); + data_.places.push_back(nested_place); + return nested_place; + } + + auto key = get_key(nested_key_column(key_column), row); + AggregateDataPtr nested_place = nullptr; + bool inserted = false; + auto it = data_.key_to_index.lazy_emplace(key, [&](const auto& ctor) { + inserted = true; + nested_place = create_nested_place(arena); + ctor(data_.stable_key(key), data_.places.size()); + }); + if (!inserted) { + return data_.places[it->second]; + } + + insert_key_from(data_, key_column, row); + data_.places.push_back(nested_place); + return nested_place; + } + + void insert_nested_result(ConstAggregateDataPtr nested_place, IColumn& value_column) const { + if (_nested_function->get_return_type()->is_nullable()) { + _nested_function->insert_result_into(nested_place, value_column); + return; + } + + if (auto* nullable_column = check_and_get_column<ColumnNullable>(value_column)) { + _nested_function->insert_result_into(nested_place, + nullable_column->get_nested_column()); + nullable_column->get_null_map_data().push_back(0); + return; + } + + _nested_function->insert_result_into(nested_place, value_column); + } + + std::string _name; + AggregateFunctionPtr _nested_function; + DataTypePtr _key_type; + DataTypePtr _value_type; +}; + +template <PrimitiveType KeyType> +AggregateFunctionPtr create_aggregate_function_map_combinator_typed( + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable, + const AggregateFunctionAttr& attr, AggregateFunctionPtr nested_function) { + return creator_without_type::create<AggregateFunctionMapCombinatorTyped<KeyType>>( + argument_types, result_is_nullable, attr, name, nested_function); +} + +AggregateFunctionPtr create_aggregate_function_map_combinator(const std::string& name, + const DataTypes& 
argument_types, + const DataTypePtr& result_type, + const bool result_is_nullable, + const AggregateFunctionAttr& attr) { + assert_arity_range(name, argument_types, 1, 1); + + const auto* map_type = get_map_type(argument_types[0]); + if (map_type == nullptr) { + LOG(WARNING) << fmt::format("unsupported input type {} for aggregate function {}", + argument_types[0]->get_name(), name); + return nullptr; + } + + DataTypePtr result_value_type = + result_value_type_or_argument_value_type(result_type, *map_type); + AggregateFunctionPtr nested_function = + create_nested_function(name, *map_type, result_value_type, attr); + + AggregateFunctionPtr typed_function; + auto call = [&](const auto& type) -> bool { + using DispatchType = std::decay_t<decltype(type)>; + typed_function = create_aggregate_function_map_combinator_typed<DispatchType::PType>( + name, argument_types, result_is_nullable, attr, nested_function); + return true; + }; + + const PrimitiveType key_type = remove_nullable(map_type->get_key_type())->get_primitive_type(); + DORIS_CHECK(dispatch_switch_all(key_type, call)) Review Comment: This `DORIS_CHECK` turns map key types that the FE accepts into a BE abort. `dispatch_switch_all()` only covers the scalar, date, IP, and string type families, but FE validation still allows ARRAY/MAP/STRUCT as MAP subtypes, including key types, and `MapAggregateFunction.customSignature()` preserves the argument key type; existing regression coverage even creates a native `map<map<...>, map<...>>` column. A `count_map` call over one of those MAP columns can therefore reach this precondition and crash the BE instead of returning a result or a clean analysis error. Please restore a generic key path, or reject unsupported map aggregate key types during FE analysis. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
