This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 8194398028c [fix](round) Fix incorrect decimal scale inference in
round functions (#34471)
8194398028c is described below
commit 8194398028c5a6731858dd96b0036f23bc3a4800
Author: zhiqiang <[email protected]>
AuthorDate: Fri May 10 16:09:46 2024 +0800
[fix](round) Fix incorrect decimal scale inference in round functions
(#34471)
* FIX NEEDED
* FORMAT
* FORMAT
* FIX TEST
---
be/src/vec/functions/round.h | 114 ++++++++++++-------
.../functions/ComputePrecisionForRound.java | 7 +-
.../sql_functions/math_functions/test_round.out | 123 +++++++++++++++++++++
.../sql_functions/math_functions/test_round.groovy | 35 +++++-
4 files changed, 237 insertions(+), 42 deletions(-)
diff --git a/be/src/vec/functions/round.h b/be/src/vec/functions/round.h
index 97a81f644ed..a17865914c4 100644
--- a/be/src/vec/functions/round.h
+++ b/be/src/vec/functions/round.h
@@ -21,13 +21,17 @@
#pragma once
#include <cstddef>
+#include <memory>
#include "common/exception.h"
#include "common/status.h"
#include "vec/columns/column_const.h"
#include "vec/columns/columns_number.h"
#include "vec/common/assert_cast.h"
+#include "vec/core/column_with_type_and_name.h"
#include "vec/core/types.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_nullable.h"
#include "vec/functions/function.h"
#if defined(__SSE4_1__) || defined(__aarch64__)
#include "util/sse_util.hpp"
@@ -430,7 +434,10 @@ struct Dispatcher {
FloatRoundingImpl<T, rounding_mode, scale_mode,
tie_breaking_mode>,
IntegerRoundingImpl<T, rounding_mode, scale_mode,
tie_breaking_mode>>>;
- static ColumnPtr apply_vec_const(const IColumn* col_general, Int16
scale_arg) {
+ // scale_arg: scale for function computation
+ // result_scale: scale for result decimal, this scale is got from planner
+ static ColumnPtr apply_vec_const(const IColumn* col_general, const Int16
scale_arg,
+ [[maybe_unused]] Int16 result_scale) {
if constexpr (IsNumber<T>) {
const auto* const col =
check_and_get_column<ColumnVector<T>>(col_general);
auto col_res = ColumnVector<T>::create();
@@ -457,10 +464,7 @@ struct Dispatcher {
} else if constexpr (IsDecimalNumber<T>) {
const auto* const decimal_col =
check_and_get_column<ColumnDecimal<T>>(col_general);
const auto& vec_src = decimal_col->get_data();
-
- UInt32 result_scale =
- std::min(static_cast<UInt32>(std::max(scale_arg,
static_cast<Int16>(0))),
- decimal_col->get_scale());
+ const size_t input_rows_count = vec_src.size();
auto col_res = ColumnDecimal<T>::create(vec_src.size(),
result_scale);
auto& vec_res = col_res->get_data();
@@ -468,6 +472,27 @@ struct Dispatcher {
FunctionRoundingImpl<ScaleMode::Negative>::apply(
decimal_col->get_data(), decimal_col->get_scale(),
vec_res, scale_arg);
}
+ // We must always make sure the result decimal's scale is exactly
the scale given in the plan,
+ // so we need to append enough zeros to the result.
+
+ // Case 0: scale_arg <= -(integer part digits count)
+ // do nothing, because result is 0
+ // Case 1: scale_arg <= 0 && scale_arg > -(integer part digits
count)
+ // decimal part has been erased, so add it back by
multiplying by 10^(result_scale)
+ // Case 2: scale_arg > 0 && scale_arg < result_scale
+ // decimal part now has scale_arg digits, so multiply
10^(result_scale - scale_arg)
+ // Case 3: scale_arg >= input_scale
+ // do nothing
+
+ if (scale_arg <= 0) {
+ for (size_t i = 0; i < input_rows_count; ++i) {
+ vec_res[i].value *= int_exp10(result_scale);
+ }
+ } else if (scale_arg > 0 && scale_arg < result_scale) {
+ for (size_t i = 0; i < input_rows_count; ++i) {
+ vec_res[i].value *= int_exp10(result_scale - scale_arg);
+ }
+ }
return col_res;
} else {
@@ -477,7 +502,9 @@ struct Dispatcher {
}
}
- static ColumnPtr apply_vec_vec(const IColumn* col_general, const IColumn*
col_scale) {
+ // result_scale: scale for result decimal, this scale is got from planner
+ static ColumnPtr apply_vec_vec(const IColumn* col_general, const IColumn*
col_scale,
+ [[maybe_unused]] Int16 result_scale) {
const auto& col_scale_i32 = assert_cast<const
ColumnInt32&>(*col_scale);
const size_t input_row_count = col_scale_i32.size();
for (size_t i = 0; i < input_row_count; ++i) {
@@ -515,10 +542,8 @@ struct Dispatcher {
return col_res;
} else if constexpr (IsDecimalNumber<T>) {
const auto* decimal_col = assert_cast<const
ColumnDecimal<T>*>(col_general);
-
- // ALWAYS use SAME scale with source Decimal column
const Int32 input_scale = decimal_col->get_scale();
- auto col_res = ColumnDecimal<T>::create(input_row_count,
input_scale);
+ auto col_res = ColumnDecimal<T>::create(input_row_count,
result_scale);
for (size_t i = 0; i < input_row_count; ++i) {
DecimalRoundingImpl<T, rounding_mode,
tie_breaking_mode>::apply(
@@ -534,15 +559,15 @@ struct Dispatcher {
// do nothing, because result is 0
// Case 1: scale_arg <= 0 && scale_arg > -(integer part digits
count)
// decimal parts has been erased, so add them back by
multiply 10^(scale_arg)
- // Case 2: scale_arg > 0 && scale_arg < decimal part digits
count
- // decimal part now has scale_arg digits, so multiply
10^(input_scale - scal_arg)
+ // Case 2: scale_arg > 0 && scale_arg < result_scale
+ // decimal part now has scale_arg digits, so multiply
10^(result_scale - scale_arg)
// Case 3: scale_arg >= input_scale
// do nothing
const Int32 scale_arg = col_scale_i32.get_data()[i];
if (scale_arg <= 0) {
- col_res->get_element(i).value *= int_exp10(input_scale);
- } else if (scale_arg > 0 && scale_arg < input_scale) {
- col_res->get_element(i).value *= int_exp10(input_scale -
scale_arg);
+ col_res->get_element(i).value *= int_exp10(result_scale);
+ } else if (scale_arg > 0 && scale_arg < result_scale) {
+ col_res->get_element(i).value *= int_exp10(result_scale -
scale_arg);
}
}
@@ -554,8 +579,9 @@ struct Dispatcher {
}
}
- static ColumnPtr apply_const_vec(const ColumnConst* const_col_general,
- const IColumn* col_scale) {
+ // result_scale: scale for result decimal, this scale is got from planner
+ static ColumnPtr apply_const_vec(const ColumnConst* const_col_general,
const IColumn* col_scale,
+ [[maybe_unused]] Int16 result_scale) {
const auto& col_scale_i32 = assert_cast<const
ColumnInt32&>(*col_scale);
const size_t input_rows_count = col_scale->size();
@@ -575,8 +601,7 @@ struct Dispatcher {
assert_cast<const
ColumnDecimal<T>&>(const_col_general->get_data_column());
const T& general_val = data_col_general.get_data()[0];
Int32 input_scale = data_col_general.get_scale();
-
- auto col_res = ColumnDecimal<T>::create(input_rows_count,
input_scale);
+ auto col_res = ColumnDecimal<T>::create(input_rows_count,
result_scale);
for (size_t i = 0; i < input_rows_count; ++i) {
DecimalRoundingImpl<T, rounding_mode,
tie_breaking_mode>::apply(
@@ -592,15 +617,15 @@ struct Dispatcher {
// do nothing, because result is 0
// Case 1: scale_arg <= 0 && scale_arg > -(integer part digits
count)
// decimal parts has been erased, so add them back by
multiply 10^(scale_arg)
- // Case 2: scale_arg > 0 && scale_arg < decimal part digits
count
- // decimal part now has scale_arg digits, so multiply
10^(input_scale - scal_arg)
+ // Case 2: scale_arg > 0 && scale_arg < result_scale
+ // decimal part now has scale_arg digits, so multiply
10^(result_scale - scale_arg)
// Case 3: scale_arg >= input_scale
// do nothing
const Int32 scale_arg = col_scale_i32.get_data()[i];
if (scale_arg <= 0) {
- col_res->get_element(i).value *= int_exp10(input_scale);
- } else if (scale_arg > 0 && scale_arg < input_scale) {
- col_res->get_element(i).value *= int_exp10(input_scale -
scale_arg);
+ col_res->get_element(i).value *= int_exp10(result_scale);
+ } else if (scale_arg > 0 && scale_arg < result_scale) {
+ col_res->get_element(i).value *= int_exp10(result_scale -
scale_arg);
}
}
@@ -679,26 +704,23 @@ public:
return Status::OK();
}
- /// SELECT number, truncate(123.345, 1) FROM number("numbers"="10")
- /// should NOT behave like two column arguments, so we can not use const
column default implementation
- bool use_default_implementation_for_constants() const override { return
false; }
+ bool use_default_implementation_for_constants() const override { return
true; }
- //// We moved and optimized the execute_impl logic of function_truncate.h
from PR#32746,
- //// as well as make it suitable for all functions.
Status execute_impl(FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
size_t result, size_t input_rows_count) const override
{
const ColumnWithTypeAndName& column_general =
block.get_by_position(arguments[0]);
+ ColumnWithTypeAndName& column_result = block.get_by_position(result);
+ const DataTypePtr result_type = block.get_by_position(result).type;
const bool is_col_general_const =
is_column_const(*column_general.column);
const auto* col_general = is_col_general_const
? assert_cast<const
ColumnConst&>(*column_general.column)
.get_data_column_ptr()
: column_general.column.get();
-
ColumnPtr res;
/// potential argument types:
/// if the SECOND argument is MISSING(would be considered as ZERO
const) or CONST, then we have the following type:
- /// 1. func(Column), func(ColumnConst), func(Column, ColumnConst),
func(ColumnConst, ColumnConst)
+ /// 1. func(Column), func(Column, ColumnConst)
/// otherwise, the SECOND arugment is COLUMN, we have another type:
/// 2. func(Column, Column), func(ColumnConst, Column)
@@ -706,6 +728,23 @@ public:
using Types = std::decay_t<decltype(types)>;
using DataType = typename Types::LeftType;
+ // For decimal, we will always make sure result Decimal has
exactly same precision and scale with
+ // arguments from query plan.
+ Int16 result_scale = 0;
+ if constexpr (IsDataTypeDecimal<DataType>) {
+ if (column_result.type->get_type_id() == TypeIndex::Nullable) {
+ if (auto nullable_type = std::dynamic_pointer_cast<const
DataTypeNullable>(
+ column_result.type)) {
+ result_scale =
nullable_type->get_nested_type()->get_scale();
+ } else {
+ throw doris::Exception(ErrorCode::INTERNAL_ERROR,
+ "Illegal nullable column");
+ }
+ } else {
+ result_scale = column_result.type->get_scale();
+ }
+ }
+
if constexpr (IsDataTypeNumber<DataType> ||
IsDataTypeDecimal<DataType>) {
using FieldType = typename DataType::FieldType;
if (arguments.size() == 1 ||
@@ -718,23 +757,20 @@ public:
}
res = Dispatcher<FieldType, rounding_mode,
tie_breaking_mode>::apply_vec_const(
- col_general, scale_arg);
-
- if (is_col_general_const) {
- // Important, make sure the result column has the same
size as the input column
- res = ColumnConst::create(std::move(res),
input_rows_count);
- }
+ col_general, scale_arg, result_scale);
} else {
// the SECOND arugment is COLUMN
if (is_col_general_const) {
res = Dispatcher<FieldType, rounding_mode,
tie_breaking_mode>::
apply_const_vec(
&assert_cast<const
ColumnConst&>(*column_general.column),
-
block.get_by_position(arguments[1]).column.get());
+
block.get_by_position(arguments[1]).column.get(),
+ result_scale);
} else {
res = Dispatcher<FieldType, rounding_mode,
tie_breaking_mode>::
apply_vec_vec(col_general,
-
block.get_by_position(arguments[1]).column.get());
+
block.get_by_position(arguments[1]).column.get(),
+ result_scale);
}
}
return true;
@@ -758,7 +794,7 @@ public:
column_general.type->get_name(),
name);
}
- block.replace_by_position(result, std::move(res));
+ column_result.column = std::move(res);
return Status::OK();
}
};
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java
index eedbfea6df9..b47804e23ff 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputePrecisionForRound.java
@@ -37,9 +37,12 @@ public interface ComputePrecisionForRound extends
ComputePrecision {
Expression floatLength = getArgument(1);
int scale;
- if (floatLength.isLiteral() || (floatLength instanceof Cast &&
floatLength.child(0).isLiteral()
+ // If scale arg is an integer literal, or it is a cast(Integer as
Integer)
+ // then we will try to use its value as result scale
+ // In any other cases, we will make sure result decimal has same
scale with input.
+ if ((floatLength.isLiteral() && floatLength.getDataType()
instanceof Int32OrLessType)
+ || (floatLength instanceof Cast &&
floatLength.child(0).isLiteral()
&& floatLength.child(0).getDataType() instanceof
Int32OrLessType)) {
- // Scale argument is a literal or cast from other literal
if (floatLength instanceof Cast) {
scale = ((IntegerLikeLiteral)
floatLength.child(0)).getIntValue();
} else {
diff --git
a/regression-test/data/query_p0/sql_functions/math_functions/test_round.out
b/regression-test/data/query_p0/sql_functions/math_functions/test_round.out
index 1ebc9cf5b89..ccdd9551f80 100644
--- a/regression-test/data/query_p0/sql_functions/math_functions/test_round.out
+++ b/regression-test/data/query_p0/sql_functions/math_functions/test_round.out
@@ -1,4 +1,115 @@
-- This file is automatically generated. You should know what you did if you
want to edit this
+-- !select --
+123.100
+
+-- !select --
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+
+-- !select --
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+
+-- !select --
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+
+-- !select --
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+
+-- !select --
+123.200
+123.200
+123.200
+123.200
+123.200
+123.200
+123.200
+123.200
+123.200
+123.200
+
+-- !select --
+130.000
+130.000
+130.000
+130.000
+130.000
+130.000
+130.000
+130.000
+130.000
+130.000
+
+-- !select --
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+123.100
+
+-- !select --
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+120.000
+
+-- !select --
+4434.41
+
+-- !select --
+0
+
+-- !select --
+false \N 4434
+
+-- !select --
+0
+
-- !select --
10
@@ -97,6 +208,18 @@
-- !select --
16.025 16.02500 16.02500
+-- !select_fix --
+16.025 16.02500 16.02500
+
+-- !select_fix --
+16.025 16.02500 16.02500
+
+-- !select_fix --
+16.025 16.02500 16.02500
+
+-- !select_fix --
+16.025 16.02500 16.02500
+
-- !nereids_round_arg1 --
10
diff --git
a/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy
b/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy
index 1d8bbb9df49..da361e15938 100644
---
a/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy
+++
b/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy
@@ -15,7 +15,35 @@
// specific language governing permissions and limitations
// under the License.
- suite("test_round") {
+suite("test_round") {
+ sql "set enable_fold_constant_by_be=false;"
+ sql "SET enable_nereids_planner=true"
+ sql "SET enable_fallback_to_original_planner=false"
+
+ qt_select "SELECT round(123.123, 1.123);"
+ qt_select """SELECT round(123.123, 1.123) FROM numbers("number"="10");"""
+ qt_select """SELECT round(123.123, -1.123) FROM numbers("number"="10");"""
+ qt_select """SELECT truncate(123.123, 1.123) FROM
numbers("number"="10");"""
+ qt_select """SELECT truncate(123.123, -1.123) FROM
numbers("number"="10");"""
+ qt_select """SELECT ceil(123.123, 1.123) FROM numbers("number"="10");"""
+ qt_select """SELECT ceil(123.123, -1.123) FROM numbers("number"="10");"""
+ qt_select """SELECT round_bankers(123.123, 1.123) FROM
numbers("number"="10");"""
+ qt_select """SELECT round_bankers(123.123, -1.123) FROM
numbers("number"="10");"""
+ sql """drop table if exists test_round_1; """
+ sql """
+ create table test_round_1(big_key bigint not NULL)
+ DISTRIBUTED BY HASH(big_key) BUCKETS 1 PROPERTIES
("replication_num" = "1");
+ """
+ qt_select """SELECT truncate(cast(round(8990.65 - 4556.2354, 2.4652) as
Decimal(9,4)), 2);"""
+ qt_select """SELECT cast(round(round(465.56,min(-5.987)),2) as DECIMAL)"""
+ qt_select """
+ SELECT truncate(100,2)<-2308.57 ,
cast(round(round(465.56,min(-5.987)),2) as DECIMAL) ,
cast(truncate(round(8990.65-4556.2354,2.4652),2)as DECIMAL) from test_round_1;
+ """
+
+ qt_select """
+ SELECT truncate(123456789.123456789, -9);
+ """
+
qt_select "SELECT round(10.12345)"
qt_select "SELECT round(10.12345, 2)"
qt_select "SELECT round_bankers(10.12345)"
@@ -62,6 +90,11 @@
qt_select """ SELECT truncate(col1, 7), truncate(col2, 7), truncate(col3,
7) FROM `${tableName}`; """
qt_select """ SELECT round_bankers(col1, 7), round_bankers(col2, 7),
round_bankers(col3, 7) FROM `${tableName}`; """
+ qt_select_fix """ SELECT round(col1, 6.234), round(col2, 6.234),
round(col3, 6.234) FROM `${tableName}`; """
+ qt_select_fix """ SELECT floor(col1, 6.234), floor(col2, 6.234),
floor(col3, 6.234) FROM `${tableName}`; """
+ qt_select_fix """ SELECT truncate(col1, 6.234), truncate(col2, 6.234),
truncate(col3, 6.234) FROM `${tableName}`; """
+ qt_select_fix """ SELECT round_bankers(col1, 6.234), round_bankers(col2,
6.234), round_bankers(col3, 6.234) FROM `${tableName}`; """
+
sql """ DROP TABLE IF EXISTS `${tableName}` """
sql "SET enable_nereids_planner=true"
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]