HappenLee commented on code in PR #50181:
URL: https://github.com/apache/doris/pull/50181#discussion_r2057587756
##########
be/src/vec/aggregate_functions/aggregate_function_collect.cpp:
##########
@@ -17,125 +17,94 @@
#include "vec/aggregate_functions/aggregate_function_collect.h"
-#include <fmt/format.h>
-
-#include <boost/iterator/iterator_facade.hpp>
#include <type_traits>
+#include "common/exception.h"
+#include "common/status.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/helpers.h"
namespace doris::vectorized {
#include "common/compile_check_begin.h"
-template <typename T, typename HasLimit, typename ShowNull>
+template <typename T, typename HasLimit>
AggregateFunctionPtr do_create_agg_function_collect(bool distinct, const
DataTypes& argument_types,
const bool
result_is_nullable) {
- if (argument_types[0]->is_nullable()) {
- if constexpr (ShowNull::value) {
- return
creator_without_type::create_ignore_nullable<AggregateFunctionCollect<
- AggregateFunctionArrayAggData<T>, std::false_type,
std::true_type>>(
- argument_types, result_is_nullable);
- }
- }
-
- if constexpr (!std::is_same_v<T, void>) {
- if (distinct) {
- return creator_without_type::create<AggregateFunctionCollect<
- AggregateFunctionCollectSetData<T, HasLimit>, HasLimit,
std::false_type>>(
- argument_types, result_is_nullable);
+ if (distinct) {
+ if constexpr (std::is_same_v<T, void>) {
+ throw Exception(ErrorCode::INTERNAL_ERROR,
+ "unexpected type for collect, please check the
input");
} else {
return creator_without_type::create<AggregateFunctionCollect<
- AggregateFunctionCollectListData<T, HasLimit>, HasLimit,
std::false_type>>(
- argument_types, result_is_nullable);
+ AggregateFunctionCollectSetData<T, HasLimit>,
HasLimit>>(argument_types,
+
result_is_nullable);
}
- } else if (!distinct) {
- // void type means support array/map/struct type for collect_list
- return creator_without_type::create<AggregateFunctionCollect<
- AggregateFunctionCollectListData<void, HasLimit>, HasLimit,
std::false_type>>(
+ } else {
+ return creator_without_type::create<
+ AggregateFunctionCollect<AggregateFunctionCollectListData<T,
HasLimit>, HasLimit>>(
argument_types, result_is_nullable);
}
- return nullptr;
}
-template <typename HasLimit, typename ShowNull>
+template <typename HasLimit>
AggregateFunctionPtr create_aggregate_function_collect_impl(const std::string&
name,
const DataTypes&
argument_types,
const bool
result_is_nullable) {
- bool distinct = false;
- if (name == "collect_set") {
- distinct = true;
- }
+ bool distinct = name == "collect_set";
WhichDataType which(remove_nullable(argument_types[0]));
-#define DISPATCH(TYPE)
\
- if (which.idx == TypeIndex::TYPE)
\
- return do_create_agg_function_collect<TYPE, HasLimit,
ShowNull>(distinct, argument_types, \
-
result_is_nullable);
+#define DISPATCH(TYPE)
\
+ if (which.idx == TypeIndex::TYPE)
\
+ return do_create_agg_function_collect<TYPE, HasLimit>(distinct,
argument_types, \
+
result_is_nullable);
FOR_NUMERIC_TYPES(DISPATCH)
FOR_DECIMAL_TYPES(DISPATCH)
#undef DISPATCH
if (which.is_date_or_datetime()) {
- return do_create_agg_function_collect<Int64, HasLimit,
ShowNull>(distinct, argument_types,
-
result_is_nullable);
+ return do_create_agg_function_collect<Int64, HasLimit>(distinct,
argument_types,
+
result_is_nullable);
} else if (which.is_date_v2()) {
- return do_create_agg_function_collect<UInt32, HasLimit,
ShowNull>(distinct, argument_types,
-
result_is_nullable);
+ return do_create_agg_function_collect<UInt32, HasLimit>(distinct,
argument_types,
+
result_is_nullable);
} else if (which.is_date_time_v2()) {
- return do_create_agg_function_collect<UInt64, HasLimit,
ShowNull>(distinct, argument_types,
-
result_is_nullable);
+ return do_create_agg_function_collect<UInt64, HasLimit>(distinct,
argument_types,
+
result_is_nullable);
} else if (which.is_ipv6()) {
- return do_create_agg_function_collect<IPv6, HasLimit,
ShowNull>(distinct, argument_types,
-
result_is_nullable);
+ return do_create_agg_function_collect<IPv6, HasLimit>(distinct,
argument_types,
+
result_is_nullable);
} else if (which.is_ipv4()) {
- return do_create_agg_function_collect<IPv4, HasLimit,
ShowNull>(distinct, argument_types,
-
result_is_nullable);
+ return do_create_agg_function_collect<IPv4, HasLimit>(distinct,
argument_types,
+
result_is_nullable);
} else if (which.is_string()) {
- return do_create_agg_function_collect<StringRef, HasLimit, ShowNull>(
- distinct, argument_types, result_is_nullable);
+ return do_create_agg_function_collect<StringRef, HasLimit>(distinct,
argument_types,
+
result_is_nullable);
} else {
- // generic serialize which will not use specializations, ShowNull::value always means array_agg
- if constexpr (ShowNull::value) {
- return do_create_agg_function_collect<void, HasLimit, ShowNull>(
- distinct, argument_types, result_is_nullable);
- } else {
- return do_create_agg_function_collect<void, HasLimit, ShowNull>(
- distinct, argument_types, result_is_nullable);
- }
+ // generic serialize which will not use specializations::value always means array_agg
+ return do_create_agg_function_collect<void, HasLimit>(distinct,
argument_types,
+
result_is_nullable);
}
-
- LOG(WARNING) << fmt::format("unsupported input type {} for aggregate
function {}",
- argument_types[0]->get_name(), name);
- return nullptr;
}
AggregateFunctionPtr create_aggregate_function_collect(const std::string& name,
const DataTypes&
argument_types,
const bool
result_is_nullable,
const
AggregateFunctionAttr& attr) {
if (argument_types.size() == 1) {
- if (name == "array_agg") {
- return create_aggregate_function_collect_impl<std::false_type,
std::true_type>(
- name, argument_types, result_is_nullable);
- } else {
- return create_aggregate_function_collect_impl<std::false_type,
std::false_type>(
- name, argument_types, result_is_nullable);
- }
+ return create_aggregate_function_collect_impl<std::false_type>(name,
argument_types,
+
result_is_nullable);
}
if (argument_types.size() == 2) {
- return create_aggregate_function_collect_impl<std::true_type,
std::false_type>(
- name, argument_types, result_is_nullable);
+ return create_aggregate_function_collect_impl<std::true_type>(name,
argument_types,
+
result_is_nullable);
}
- LOG(WARNING) << fmt::format("number of parameters for aggregate function
{}, should be 1 or 2",
- name);
- return nullptr;
+ throw Exception(ErrorCode::INTERNAL_ERROR,
+ "unexpected type for collect, please check the input");
}
void register_aggregate_function_collect_list(AggregateFunctionSimpleFactory&
factory) {
// notice: array_agg only differs from collect_list in that array_agg will show null elements in array
Review Comment:
remove the comment
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]