This is an automated email from the ASF dual-hosted git repository.
lihaopeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 7076744de8a [opt](BinaryArithmetic)Optimize FunctionBinaryArithmetic
by distributing types during the open phase. (#50082)
7076744de8a is described below
commit 7076744de8af3d8e2a00a53c221529d333d7541e
Author: Mryange <[email protected]>
AuthorDate: Thu Apr 17 11:35:17 2025 +0800
[opt](BinaryArithmetic)Optimize FunctionBinaryArithmetic by distributing
types during the open phase. (#50082)
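Previously, the concrete argument and result types were determined during
the exec phase, i.e. on every call to execute_impl.
The type dispatch checks the candidate types one by one in a fixed order,
so types that appear later in the list (see cast_type below) go through
several failed checks on every call before they finally match.
This commit performs the dispatch once, during the open phase, and stores
the selected implementation in the function state so that the exec phase
only has to invoke it.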
```C++
template <typename F>
static bool cast_type(const IDataType* type, F&& f) {
return cast_type_to_either<DataTypeUInt8, DataTypeInt8,
DataTypeInt16, DataTypeInt32,
DataTypeInt64, DataTypeInt128,
DataTypeFloat32, DataTypeFloat64,
DataTypeDecimal<Decimal32>,
DataTypeDecimal<Decimal64>,
DataTypeDecimal<Decimal128V2>,
DataTypeDecimal<Decimal128V3>,
DataTypeDecimal<Decimal256>>(type,
std::forward<F>(f));
}
```
---
be/src/vec/functions/function_binary_arithmetic.h | 149 +++++++++++++---------
1 file changed, 88 insertions(+), 61 deletions(-)
diff --git a/be/src/vec/functions/function_binary_arithmetic.h
b/be/src/vec/functions/function_binary_arithmetic.h
index 13efdf9ddbd..666ee6471f2 100644
--- a/be/src/vec/functions/function_binary_arithmetic.h
+++ b/be/src/vec/functions/function_binary_arithmetic.h
@@ -20,6 +20,8 @@
#pragma once
+#include <functional>
+#include <memory>
#include <type_traits>
#include "common/exception.h"
@@ -34,6 +36,7 @@
#include "vec/core/types.h"
#include "vec/core/wide_integer.h"
#include "vec/data_types/data_type_decimal.h"
+#include "vec/data_types/data_type_factory.hpp"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_number.h"
#include "vec/data_types/number_traits.h"
@@ -792,13 +795,13 @@ struct BinaryOperationTraits {
DataTypeFromFieldType<typename Op::ResultType>>>;
};
-template <typename LeftDataType, typename RightDataType, typename
ExpectedResultDataType,
+template <typename LeftDataType, typename RightDataType, typename
FEResultDataType,
template <typename, typename> class Operation, typename Name, bool
is_to_null_type,
bool check_overflow_for_decimal>
struct ConstOrVectorAdapter {
static constexpr bool result_is_decimal =
IsDataTypeDecimal<LeftDataType> ||
IsDataTypeDecimal<RightDataType>;
- using ResultDataType = ExpectedResultDataType;
+ using ResultDataType = FEResultDataType;
using ResultType = typename ResultDataType::FieldType;
using A = typename LeftDataType::FieldType;
using B = typename RightDataType::FieldType;
@@ -931,6 +934,13 @@ private:
}
};
+struct BinaryArithmeticState {
+ std::function<Status(FunctionContext*, Block&, const ColumnNumbers&,
uint32_t, size_t)> impl;
+ DataTypePtr left_type;
+ DataTypePtr right_type;
+ DataTypePtr result_type;
+};
+
template <template <typename, typename> class Operation, typename Name, bool
is_to_null_type>
class FunctionBinaryArithmetic : public IFunction {
using OpTraits = OperationTraits<Operation>;
@@ -1032,91 +1042,108 @@ public:
return type_res;
}
- Status execute_impl(FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
- uint32_t result, size_t input_rows_count) const
override {
- auto* left_generic = block.get_by_position(arguments[0]).type.get();
- auto* right_generic = block.get_by_position(arguments[1]).type.get();
- auto* result_generic = block.get_by_position(result).type.get();
- if (left_generic->is_nullable()) {
- left_generic =
- static_cast<const
DataTypeNullable*>(left_generic)->get_nested_type().get();
- }
- if (right_generic->is_nullable()) {
- right_generic =
- static_cast<const
DataTypeNullable*>(right_generic)->get_nested_type().get();
- }
- if (result_generic->is_nullable()) {
- result_generic =
- static_cast<const
DataTypeNullable*>(result_generic)->get_nested_type().get();
+ Status open(FunctionContext* context, FunctionContext::FunctionStateScope
scope) override {
+ if (scope == FunctionContext::THREAD_LOCAL) {
+ return Status::OK();
}
-
- bool check_overflow_for_decimal =
context->check_overflow_for_decimal();
- Status status;
+ std::shared_ptr<BinaryArithmeticState> state =
std::make_shared<BinaryArithmeticState>();
+ context->set_function_state(scope, state);
+
+ state->left_type =
+
DataTypeFactory::instance().create_data_type(*context->get_arg_type(0), false);
+ state->right_type =
+
DataTypeFactory::instance().create_data_type(*context->get_arg_type(1), false);
+ state->result_type =
+
DataTypeFactory::instance().create_data_type(context->get_return_type(), false);
+ const auto* left_generic = state->left_type.get();
+ const auto* right_generic = state->right_type.get();
+ const auto* result_generic = state->result_type.get();
+
+ const bool check_overflow_for_decimal =
context->check_overflow_for_decimal();
bool valid = cast_both_types(
left_generic, right_generic, result_generic,
[&](const auto& left, const auto& right, const auto& res) {
using LeftDataType = std::decay_t<decltype(left)>;
using RightDataType = std::decay_t<decltype(right)>;
- using ExpectedResultDataType = std::decay_t<decltype(res)>;
- using ResultDataType =
+ using FEResultDataType = std::decay_t<decltype(res)>;
+ using BEResultDataType =
typename BinaryOperationTraits<Operation,
LeftDataType,
RightDataType>::ResultDataType;
if constexpr (
- !std::is_same_v<ResultDataType, InvalidType> &&
- (IsDataTypeDecimal<ExpectedResultDataType> ==
+ (!std::is_same_v<BEResultDataType,
+ InvalidType> /* Cannot be
InvalidType */) &&
+ (IsDataTypeDecimal<FEResultDataType> ==
IsDataTypeDecimal<
-
ResultDataType>)&&(IsDataTypeDecimal<ExpectedResultDataType> ==
-
(IsDataTypeDecimal<LeftDataType> ||
-
IsDataTypeDecimal<RightDataType>))) {
+ BEResultDataType> /* The type planned by
FE and the type planned by BE must both be Decimal or not */) &&
+ (IsDataTypeDecimal<FEResultDataType> ==
+ (IsDataTypeDecimal<LeftDataType> ||
+ IsDataTypeDecimal<
+ RightDataType>)/* Only when at least one
of left or right is Decimal, the return value can be Decimal */)) {
if (check_overflow_for_decimal) {
// !is_to_null_type: plus, minus, multiply,
// pow, bitxor, bitor, bitand
// if check_overflow and params are decimal types:
// for functions pow, bitxor, bitor, bitand,
return error
- if constexpr (IsDataTypeDecimal<ResultDataType> &&
!is_to_null_type &&
- !OpTraits::is_multiply &&
!OpTraits::is_plus_minus) {
- status =
Status::Error<ErrorCode::NOT_IMPLEMENTED_ERROR>(
- "cannot check overflow with decimal
for function {}", name);
- return false;
- }
- auto column_result = ConstOrVectorAdapter<
- LeftDataType, RightDataType,
-
std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>,
- ExpectedResultDataType,
ResultDataType>,
- Operation, Name, is_to_null_type,
-
true>::execute(block.get_by_position(arguments[0]).column,
-
block.get_by_position(arguments[1]).column, left,
- right,
- remove_nullable(
-
block.get_by_position(result).type));
- block.replace_by_position(result,
std::move(column_result));
+ static_assert(
+ !(IsDataTypeDecimal<BEResultDataType> &&
!is_to_null_type &&
+ !OpTraits::is_multiply &&
!OpTraits::is_plus_minus),
+ "cannot check overflow with decimal for
function");
+
+ state->impl = execute_with_type<LeftDataType,
RightDataType,
+ FEResultDataType,
true>;
} else {
- auto column_result = ConstOrVectorAdapter<
- LeftDataType, RightDataType,
-
std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>,
- ExpectedResultDataType,
ResultDataType>,
- Operation, Name, is_to_null_type,
-
false>::execute(block.get_by_position(arguments[0]).column,
-
block.get_by_position(arguments[1]).column,
- left, right,
- remove_nullable(
-
block.get_by_position(result).type));
- block.replace_by_position(result,
std::move(column_result));
+ state->impl = execute_with_type<LeftDataType,
RightDataType,
+ FEResultDataType,
false>;
}
+
return true;
}
return false;
});
if (!valid) {
- if (status.ok()) {
- return Status::RuntimeError("{}'s arguments do not match the
expected data types",
- get_name());
- }
- return status;
+ return Status::RuntimeError("{}'s arguments do not match the
expected data types",
+ get_name());
}
return Status::OK();
}
+
+ Status execute_impl(FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
+ uint32_t result, size_t input_rows_count) const
override {
+ auto* state = reinterpret_cast<BinaryArithmeticState*>(
+ context->get_function_state(FunctionContext::FRAGMENT_LOCAL));
+ if (!state || !state->impl) {
+ return Status::RuntimeError("function context for function '{}'
must have state;",
+ get_name());
+ }
+ return state->impl(context, block, arguments, result,
input_rows_count);
+ }
+
+ template <typename LeftDataType, typename RightDataType, typename
FEResultDataType,
+ bool check_overflow_for_decimal>
+ static Status execute_with_type(FunctionContext* context, Block& block,
+ const ColumnNumbers& arguments, uint32_t
result,
+ size_t input_rows_count) {
+ const auto& left_type =
+ assert_cast<const
LeftDataType&>(*block.get_by_position(arguments[0]).type);
+ const auto& right_type =
+ assert_cast<const
RightDataType&>(*block.get_by_position(arguments[1]).type);
+
+ using BEResultDataType = typename BinaryOperationTraits<Operation,
LeftDataType,
+
RightDataType>::ResultDataType;
+
+ using ExpectedResultDataType =
std::conditional_t<IsDataTypeDecimal<FEResultDataType>,
+ FEResultDataType,
BEResultDataType>;
+ auto column_result =
+ ConstOrVectorAdapter<LeftDataType, RightDataType,
ExpectedResultDataType, Operation,
+ Name, is_to_null_type,
check_overflow_for_decimal>::
+ execute(block.get_by_position(arguments[0]).column,
+ block.get_by_position(arguments[1]).column,
left_type, right_type,
+
remove_nullable(block.get_by_position(result).type));
+ block.replace_by_position(result, std::move(column_result));
+
+ return Status::OK();
+ }
};
} // namespace doris::vectorized
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]